Cross-Encoder for Natural Language Inference

This model was trained using the SentenceTransformers CrossEncoder class.

Training Data

The model was trained on the SNLI and MultiNLI datasets. For a given sentence pair, it outputs three scores, one per label, in this order: contradiction, entailment, neutral.

Performance

For evaluation results, see SBERT.net - Pretrained Cross-Encoder.

Usage

The pre-trained model can be used like this:

from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/nli-distilroberta-base')
scores = model.predict([('A man is eating pizza', 'A man eats something'), ('A black race car starts up in front of a crowd of people.', 'A man is driving down a lonely road.')])

# Convert scores to labels: pick the index of the highest score for each pair
label_mapping = ['contradiction', 'entailment', 'neutral']
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
print(labels)
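
The three scores per pair are easier to interpret as probabilities. The following is a minimal sketch, assuming the scores returned by predict are unnormalized logits; the softmax helper and the example_scores values are hypothetical and are not real model output. Recent SentenceTransformers releases also appear to expose an apply_softmax argument on predict, so check the installed version before rolling your own.

```python
import math

def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for one sentence pair, NOT real model output.
example_scores = [-2.0, 3.0, 0.5]
label_mapping = ['contradiction', 'entailment', 'neutral']

probs = softmax(example_scores)
# Index of the largest probability (same index as the largest raw score).
best = max(range(len(probs)), key=probs.__getitem__)
print(label_mapping[best], [round(p, 3) for p in probs])
```

Softmax preserves the ordering of the scores, so the predicted label is the same as with argmax on the raw scores; only the scale changes.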

Usage with Transformers AutoModel

You can also use the model directly with the Transformers library (without the SentenceTransformers library):

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/nli-distilroberta-base')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/nli-distilroberta-base')

# First list holds the premises, second list the matching hypotheses
features = tokenizer(
    ['A man is eating pizza', 'A black race car starts up in front of a crowd of people.'],
    ['A man eats something', 'A man is driving down a lonely road.'],
    padding=True, truncation=True, return_tensors="pt",
)

model.eval()
with torch.no_grad():
    scores = model(**features).logits
    label_mapping = ['contradiction', 'entailment', 'neutral']
    labels = [label_mapping[score_max] for score_max in scores.argmax(dim=1)]
    print(labels)

Zero-Shot Classification

This model can also be used for zero-shot classification:

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model='cross-encoder/nli-distilroberta-base')

sent = "Apple just announced the newest iPhone X"
candidate_labels = ["technology", "sports", "politics"]
res = classifier(sent, candidate_labels)
print(res)
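
To make the link between NLI and zero-shot classification concrete, here is a minimal sketch of how candidate labels can be turned into premise/hypothesis pairs. The build_nli_pairs helper is hypothetical and is not part of the Transformers API, and the template string "This example is {}." is assumed to match the pipeline's default; the pipeline's actual internals may differ.

```python
# Hypothetical helper for illustration only; not a Transformers function.
def build_nli_pairs(text, candidate_labels, template="This example is {}."):
    """Pair the input text (premise) with one hypothesis per candidate label."""
    return [(text, template.format(label)) for label in candidate_labels]

pairs = build_nli_pairs(
    "Apple just announced the newest iPhone X",
    ["technology", "sports", "politics"],
)

# Each pair would be scored by the NLI model; a higher entailment score
# suggests the label fits the text better.
for premise, hypothesis in pairs:
    print(premise, "->", hypothesis)
```

The pipeline result printed above is a dictionary with the keys sequence, labels, and scores, where labels are sorted from most to least likely, so res["labels"][0] should give the top prediction.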