Pooling models¶
Source examples/offline_inference/pooling.
Convert llm model to seq cls¶
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
python examples/offline_inference/pooling/convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
Embed jina_embeddings_v3 usage¶
Only the text-matching task is supported for now. See Pull Request #16120.
Embed matryoshka dimensions usage¶
Named Entity Recognition (NER) usage¶
Qwen3 reranker usage¶
Example materials¶
convert_model_to_seq_cls.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import argparse
import json
import torch
import transformers
# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """Build a single-logit classification head from two LM-head rows.

    The score weight is set to ``lm_head[tokens[1]] - lm_head[tokens[0]]``
    (computed in float32) and the score bias, if present, is zeroed.

    Args:
        causal_lm: Model exposing ``lm_head.weight``.
        seq_cls_model: Model exposing ``score.weight`` / ``score.bias``;
            modified in place.
        tokenizer: Object providing ``convert_tokens_to_ids``.
        tokens: Exactly two tokens, ordered ``[false_token, true_token]``.
        device: Device the selected rows are moved to before subtraction.

    Raises:
        ValueError: If ``tokens`` does not hold exactly two entries, or a
            token is not present in the tokenizer vocabulary.
    """
    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    if len(tokens) != 2:
        raise ValueError(
            f"from_2_way_softmax expects exactly 2 tokens, got {len(tokens)}"
        )

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    # A misspelled token (e.g. "Yes" vs "yes") would otherwise silently pick
    # the unknown-token row and produce a meaningless classifier.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    unk_token = getattr(tokenizer, "unk_token", None)
    for token, token_id in zip(tokens, (false_id, true_id)):
        if token_id is None or (token_id == unk_id and token != unk_token):
            raise ValueError(f"Token {token!r} is not in the tokenizer vocabulary")

    # Everything below is a pure weight copy; no autograd graph is needed.
    with torch.no_grad():
        lm_head_weights = causal_lm.lm_head.weight
        true_row = lm_head_weights[true_id].to(device).to(torch.float32)
        false_row = lm_head_weights[false_id].to(device).to(torch.float32)
        seq_cls_model.score.weight.copy_((true_row - false_row).unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()
def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """Copy the LM-head rows of ``tokens`` directly into the score head.

    Row ``i`` of ``seq_cls_model.score.weight`` becomes the LM-head row of
    ``tokens[i]``; the score bias, if present, is zeroed.

    Args:
        causal_lm: Model exposing ``lm_head.weight``.
        seq_cls_model: Model exposing ``score.weight`` / ``score.bias``;
            modified in place.
        tokenizer: Object providing ``convert_tokens_to_ids``.
        tokens: Tokens whose LM-head rows become the classifier rows.
        device: Device the selected rows are moved to before copying.

    Raises:
        ValueError: If a token is not present in the tokenizer vocabulary.
    """
    unk_id = getattr(tokenizer, "unk_token_id", None)
    unk_token = getattr(tokenizer, "unk_token", None)

    token_ids = []
    for token in tokens:
        token_id = tokenizer.convert_tokens_to_ids(token)
        # A misspelled token would otherwise silently select the
        # unknown-token row (or None) and yield a meaningless classifier.
        if token_id is None or (token_id == unk_id and token != unk_token):
            raise ValueError(f"Token {token!r} is not in the tokenizer vocabulary")
        token_ids.append(token_id)

    # Pure weight copy; keep autograd out of it.
    with torch.no_grad():
        score_weight = causal_lm.lm_head.weight[token_ids].to(device)
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()
# Conversion strategies, keyed by function name so that the ``--method``
# CLI value selects the function directly.
method_map = {
    "from_2_way_softmax": from_2_way_softmax,
    "no_post_processing": no_post_processing,
}
def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    """Convert a ``*ForCausalLM`` checkpoint to ``*ForSequenceClassification``.

    Loads the causal LM and a sequence-classification model from the same
    checkpoint, fills the classification head from LM-head rows using the
    selected method, then saves the model and tokenizer to ``path``.

    Args:
        model_name: Model id or path handed to the ``from_pretrained`` loaders.
        classifier_from_tokens: Tokens whose LM-head rows seed the classifier.
        path: Directory the converted model and tokenizer are written to.
        method: Key of ``method_map`` choosing the conversion strategy.
        use_pad_token: Value stored in the saved ``config.use_pad_token``.
        device: Passed as ``device_map`` when loading both models.

    Raises:
        ValueError: If ``method`` is unknown, ``classifier_from_tokens`` is
            empty, or ``from_2_way_softmax`` is not given exactly two tokens.
    """
    # Explicit exceptions rather than ``assert``: these guard user-supplied CLI
    # values and must still fire when Python runs with ``-O``.
    if method not in method_map:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(method_map)}"
        )
    if not classifier_from_tokens:
        raise ValueError("classifier_from_tokens must contain at least one token")

    if method == "from_2_way_softmax":
        if len(classifier_from_tokens) != 2:
            raise ValueError(
                "from_2_way_softmax requires exactly 2 tokens, got "
                f"{len(classifier_from_tokens)}"
            )
        # The two token rows are collapsed into a single score row.
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )
    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    # Overwrite the classification head with weights derived from the LM head.
    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` defaults to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)
def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with ``model_name``, ``classifier_from_tokens``
        (still a JSON string), ``method``, ``use_pad_token`` and ``path``.
    """
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="JSON list of tokens to build the classifier from",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="no_post_processing",
        # The original help text was the typo "Converting converting".
        help="Conversion method: from_2_way_softmax or no_post_processing",
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # --classifier_from_tokens arrives as a JSON string such as '["no", "yes"]'.
    token_list = json.loads(args.classifier_from_tokens)
    converting(
        args.model_name,
        token_list,
        args.path,
        args.method,
        use_pad_token=args.use_pad_token,
    )
embed_jina_embeddings_v3.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
    """Parse engine CLI arguments, using this example's defaults."""
    # Set example specific arguments
    example_defaults = {
        "model": "jinaai/jina-embeddings-v3",
        "runner": "pooling",
        "trust_remote_code": True,
    }
    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    cli_parser.set_defaults(**example_defaults)
    return cli_parser.parse_args()
def main(args: Namespace):
    """Embed a few multilingual prompts and print a preview of each vector."""
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    engine = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    # Only text matching task is supported for now. See #16120
    results = engine.embed(prompts)

    # Print the outputs.
    separator = "-" * 60
    print("\nGenerated Outputs:")
    print("Only text matching task is supported for now. See #16120")
    print(separator)
    for text, result in zip(prompts, results):
        vector = result.outputs.embedding
        # Show only the first 16 values of long vectors.
        if len(vector) > 16:
            preview = str(vector[:16])[:-1] + ", ...]"
        else:
            preview = vector
        print(
            f"Prompt: {text!r} \n"
            f"Embeddings for text matching: {preview} "
            f"(size={len(vector)})"
        )
        print(separator)
if __name__ == "__main__":
    # Parse the CLI and run the example in one step.
    main(parse_args())
embed_matryoshka_fy.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from argparse import Namespace
from vllm import LLM, EngineArgs, PoolingParams
from vllm.utils import FlexibleArgumentParser
def parse_args():
    """Return parsed engine arguments with the matryoshka example defaults."""
    base_parser = FlexibleArgumentParser()
    engine_parser = EngineArgs.add_cli_args(base_parser)
    # Set example specific arguments
    defaults = dict(
        model="jinaai/jina-embeddings-v3",
        runner="pooling",
        trust_remote_code=True,
    )
    engine_parser.set_defaults(**defaults)
    parsed = engine_parser.parse_args()
    return parsed
def main(args: Namespace):
    """Embed sample prompts at 32 dimensions and print a preview of each."""
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    engine = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    reduced_dims = PoolingParams(dimensions=32)
    results = engine.embed(prompts, pooling_params=reduced_dims)

    # Print the outputs.
    rule = "-" * 60
    print("\nGenerated Outputs:")
    print(rule)
    for text, result in zip(prompts, results):
        vector = result.outputs.embedding
        preview = vector
        if len(vector) > 16:
            # Keep the first 16 values and mark the truncation.
            preview = str(vector[:16])[:-1] + ", ...]"
        print(f"Prompt: {text!r} \nEmbeddings: {preview} (size={len(vector)})")
        print(rule)
if __name__ == "__main__":
    # Parse the CLI and run the example in one step.
    main(parse_args())
ner.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser
def parse_args():
    """Parse engine CLI arguments, using the NER example's defaults."""
    # Set example specific arguments
    ner_defaults = {
        "model": "boltuix/NeuroBERT-NER",
        "runner": "pooling",
        "enforce_eager": True,
        "trust_remote_code": True,
    }
    cli_parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    cli_parser.set_defaults(**ner_defaults)
    return cli_parser.parse_args()
def main(args: Namespace):
    """Run token classification on a sample sentence and print token labels.

    Each prompt is encoded, the arg-max over the last dimension of the
    returned data picks a label id per token, and every non-special token is
    printed next to its label from the model config's ``id2label``.
    """
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM.
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

    # Looked up once instead of once per printed token; it does not change
    # inside the loops below.
    special_tokens = set(tokenizer.all_special_tokens)

    # Run inference
    outputs = llm.encode(prompts)

    # zip() is kept so iteration still stops at the shorter sequence.
    for _prompt, output in zip(prompts, outputs):
        logits = output.outputs.data
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]

        # Print results
        for token, label in zip(tokens, labels):
            if token not in special_tokens:
                print(f"{token:15} → {label}")
if __name__ == "__main__":
    # Parse the CLI and run the example in one step.
    main(parse_args())
qwen3_reranker.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from vllm import LLM
# Model id passed as `model=` to LLM() in get_llm().
model_name = "Qwen/Qwen3-Reranker-0.6B"
# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking by using the
# logits of the "no" and "yes" tokens.
# It needs to compute logits for all 151669 tokens, making this method
# extremely inefficient, not to mention incompatible with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method can not only be more efficient
# and support the vllm score API, but also make the init parameters more
# concise, for example.
# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
# If you want to load the official original version, the init parameters are
# as follows.
def get_llm() -> LLM:
    """Initializes and returns the LLM model for Qwen3-Reranker."""
    # Overrides applied to the model's Hugging Face config when loading the
    # official checkpoint.
    overrides = {
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    }
    return LLM(model=model_name, runner="pooling", hf_overrides=overrides)
# Why do we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - Firstly, we need to use `"architectures": ["Qwen3ForSequenceClassification"],`
# to manually route to Qwen3ForSequenceClassification.
# - Then, we will extract the vector corresponding to classifier_from_token
# from lm_head using `"classifier_from_token": ["no", "yes"]`.
# - Third, we will convert these two vectors into one vector. The use of
# conversion logic is controlled by using `"is_original_qwen3_reranker": True`.
# Please use the query_template and document_template to format the query and
# document for better reranker results.
# Text placed before the instruction and query: a system turn stating the
# judging task, followed by the opening of the user turn.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
# Text placed after the document: closes the user turn and opens the
# assistant turn with an empty <think> block.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Filled in main() with prefix, instruction and query.
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
# Filled in main() with doc and suffix.
document_template = "<Document>: {doc}{suffix}"
def main() -> None:
    """Format sample queries and documents, score them, and print the scores."""
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    raw_queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    raw_documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    # Wrap every query and document in the reranker's prompt templates.
    formatted_queries = []
    for raw_query in raw_queries:
        formatted_queries.append(
            query_template.format(
                prefix=prefix, instruction=instruction, query=raw_query
            )
        )
    formatted_documents = []
    for raw_document in raw_documents:
        formatted_documents.append(
            document_template.format(doc=raw_document, suffix=suffix)
        )

    reranker = get_llm()
    results = reranker.score(formatted_queries, formatted_documents)

    divider = "-" * 30
    print(divider)
    print([result.outputs.score for result in results])
    print(divider)
if __name__ == "__main__":
    # Run the reranking example when executed as a script.
    main()