用 🤗 transformers, 🤗 datasets 和 FAISS 嵌入多模态数据进行相似度搜索
嵌入对于语义压缩信息很有用。它们可用于执行相似性搜索、零样本分类或简单地训练新模型。相似性搜索用例包括在电子商务中搜索类似的产品、在社交媒体中搜索内容等等。
本 notebook 指导你使用 🤗transformers、🤗datasets 和 FAISS 从特征提取模型创建和索引嵌入,以便稍后使用它们进行相似性搜索。
让我们安装必要的库。
!pip install -q datasets faiss-gpu transformers sentencepiece
对于本教程,我们将会使用 CLIP 模型来提取特征。 CLIP 是一个革命性的模型,其成功的将文本数据和图片数据两个模态联合起来训练。
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
import faiss
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
载入数据集,为了保持 notebook 的轻量性,我们这里会使用一个小型字幕数据集 jmhessel/newyorker_caption_contest.
from datasets import load_dataset
ds = load_dataset("jmhessel/newyorker_caption_contest", "explanation")
让我们来看一个例子
>>> ds["train"][0]["image"]
ds["train"][0]["image_description"]
我们没必要去写任何函数去嵌入例子或创建索引。 🤗 datasets 库的 FAISS 组件抽象这些过程。我们可以仅仅简单使用 dataset 的 map
方法就可以创建一个新的带有每个例子嵌入的列,就像下面所示。让我们针对提示列中的文本特征创建一个嵌入。
dataset = ds["train"]
ds_with_embeddings = dataset.map(
lambda example: {
"embeddings": model.get_text_features(
**tokenizer([example["image_description"]], truncation=True, return_tensors="pt").to("cuda")
)[0]
.detach()
.cpu()
.numpy()
}
)
ds_with_embeddings.add_faiss_index(column="embeddings")
我们可以同样处理图像嵌入
ds_with_embeddings = ds_with_embeddings.map(
lambda example: {
"image_embeddings": model.get_image_features(**processor([example["image"]], return_tensors="pt").to("cuda"))[
0
]
.detach()
.cpu()
.numpy()
}
)
ds_with_embeddings.add_faiss_index(column="image_embeddings")
用文本提示( prompts )查询相关数据
我们现在可以用文本或者图片查询数据集来获取相似的项目
prmt = "a snowy day"
prmt_embedding = (
model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> def downscale_images(image):
... width = 200
... ratio = width / float(image.size[0])
... height = int((float(image.size[1]) * float(ratio)))
... img = image.resize((width, height), Image.Resampling.LANCZOS)
... return img
>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # see the closest text and image
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['A man is in the snow. A boy with a huge snow shovel is there too. They are outside a house.']
用图片提示( prompts )来查询数据
图片相似度推理也类似,你只需要调用 get_image_features
函数即可。
>>> import requests
>>> # image of a beaver
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> display(downscale_images(image))
搜索相似的图片
img_embedding = (
model.get_image_features(**processor([image], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples("image_embeddings", img_embedding, k=1)
显示与海狸图像最相似的图像。
>>> images = [downscale_images(image) for image in retrieved_examples["image"]]
>>> # see the closest text and image
>>> print(retrieved_examples["image_description"])
>>> display(images[0])
['Salmon swim upstream but they see a grizzly bear and are in shock. The bear has a smug look on his face when he sees the salmon.']
保存,推送,加载嵌入( embeddings )
我们可以用 save_faiss_index
函数储存数据集的嵌入。
ds_with_embeddings.save_faiss_index("embeddings", "embeddings/embeddings.faiss")
ds_with_embeddings.save_faiss_index("image_embeddings", "embeddings/image_embeddings.faiss")
去储存一个数据集仓库的嵌入是一个很好的练习,所以我们将在那里创建一个,并将我们的嵌入稍后推送到那里。
我们会登录 Hugging Face Hub, 创建一个数据集仓库,推送我们的所以,然后使用 snapshot_download
加载。
from huggingface_hub import HfApi, notebook_login, snapshot_download
notebook_login()
from huggingface_hub import HfApi
api = HfApi()
api.create_repo("merve/faiss_embeddings", repo_type="dataset")
api.upload_folder(
folder_path="./embeddings",
repo_id="merve/faiss_embeddings",
repo_type="dataset",
)
snapshot_download(repo_id="merve/faiss_embeddings", repo_type="dataset", local_dir="downloaded_embeddings")
我们可以使用 load_faiss_index
将嵌入加载到没有嵌入的数据集中。
ds = ds["train"]
ds.load_faiss_index("embeddings", "./downloaded_embeddings/embeddings.faiss")
# infer again
prmt = "people under the rain"
prmt_embedding = (
model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0]
.detach()
.cpu()
.numpy()
)
scores, retrieved_examples = ds.get_nearest_examples("embeddings", prmt_embedding, k=1)
>>> display(retrieved_examples["image"][0])