# Spaces:
# Sleeping
# Sleeping
import os
import subprocess
import sys

# Install the transformers dev wheel shipped next to this app *before*
# `transformers` is imported below. Invoking pip as `sys.executable -m pip`
# guarantees the wheel lands in the interpreter that is running this file
# (a bare `pip` on PATH may belong to a different environment).
_TRANSFORMERS_WHEEL = "./transformers-4.47.0.dev0-py3-none-any.whl"
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", _TRANSFORMERS_WHEEL],
    check=False,
)
if _pip_result.returncode != 0:
    # Best-effort, like the original os.system call: continue with whatever
    # transformers is already installed, but make the failure visible.
    print(
        f"WARNING: installing {_TRANSFORMERS_WHEEL} failed "
        f"(exit code {_pip_result.returncode}); continuing with the "
        "preinstalled transformers.",
        file=sys.stderr,
    )

import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import string
import functools
import re
import numpy as np
import spaces

# Previously used checkpoints:
#adapter_id = "merve/paligemma2-3b-vqav2"
#adapter_id = "google/paligemma2-10b-pt-448"
#model_id = "google/paligemma2-10b-pt-448"

# Weights are loaded from `adapter_id`, the processor from `model_id`
# (currently the same checkpoint).
adapter_id = "google/paligemma2-3b-ft-docci-448"
model_id = "google/paligemma2-3b-ft-docci-448"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16

# Place the weights on `device` -- the same device `infer` moves its inputs
# to -- instead of hard-coding 'cuda', which fails on CPU-only hosts.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    adapter_id, device_map=device, torch_dtype=dtype
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
###### Transformers Inference | |
@spaces.GPU
def infer(
    text: str,
    image: PIL.Image.Image,
    max_new_tokens: int
) -> str:
    """Generate an answer to ``text`` about ``image`` with greedy decoding.

    Args:
        text: The user's question. It is prefixed with the ``"answer en "``
            task prompt before being passed to the processor.
        image: The uploaded image. Gradio passes ``None`` when no image is set.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text (prompt excluded), with leading
        newlines removed.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image.")

    # NOTE(review): the loaded checkpoint is the DOCCI caption fine-tune, while
    # this "answer en" VQA prefix looks carried over from the earlier VQAv2
    # adapter -- confirm the intended task prompt.
    prompt = "answer en " + (text or "")
    inputs = processor(
        text=prompt,
        # Convert RGBA / palette / grayscale uploads to 3-channel RGB.
        images=image.convert("RGB"),
        return_tensors="pt",
    ).to(device=device, dtype=dtype)
    # Slice the prompt off by token count rather than assuming the decoded
    # prompt reproduces `prompt` exactly, character for character.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            # Slider values may arrive as floats -- presumably; cast to be safe.
            max_new_tokens=int(max_new_tokens),
            do_sample=False,
        )

    new_tokens = generated_ids[0][prompt_len:]
    return processor.decode(new_tokens, skip_special_tokens=True).lstrip("\n")
######## Demo | |
# Markdown shown at the top of the demo page.
INTRO_TEXT = """## PaliGemma 2 demo\n\n
| [Github](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
| [Blogpost](https://huggingface.co/blog/paligemma2)
| [Fine-tuning notebook](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb)
|\n\n
PaliGemma 2 is an open vision-language model by Google, inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma 2](https://arxiv.org/abs/2408.00118) language model. PaliGemma 2 is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
answering, text reading, object detection and object segmentation.
\n\n
This space runs the [google/paligemma2-3b-ft-docci-448](https://huggingface.co/google/paligemma2-3b-ft-docci-448)
checkpoint, a PaliGemma 2 model fine-tuned on the DOCCI dataset of detailed image descriptions, with transformers.
See the [Blogpost](https://huggingface.co/blog/paligemma2), the project
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) and the
[fine-tuning notebook](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb)
for detailed information about how to use and fine-tune PaliGemma and PaliGemma 2 models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
# Example rows, one per gr.Examples entry: [question, image path, max new tokens].
examples = [
    ["What is the graphic about?", "./howto.jpg", 60],
    ["What is the password", "./password.jpg", 20],
    ["Who is in this image?", "./examples_bowie.jpg", 80],
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)

    # Input widgets, the submit button and the answer box, stacked vertically.
    with gr.Column():
        image = gr.Image(label="Input Image", type="pil", height=400)
        question = gr.Text(label="Question")
        tokens = gr.Slider(
            minimum=20,
            maximum=1600,
            step=10,
            value=256,
            label="Max New Tokens",
            info="Set to larger for longer generation.",
        )
        caption_btn = gr.Button("Submit")
        text_output = gr.Text(label="Text Output")

    # Order must match infer(text, image, max_new_tokens).
    caption_inputs = [question, image, tokens]
    caption_outputs = [text_output]
    caption_btn.click(infer, caption_inputs, caption_outputs)

    gr.Markdown("Example images are licensed CC0 by [akolesnikoff@](https://github.com/akolesnikoff), [mbosnjak@](https://github.com/mbosnjak), [maximneumann@](https://github.com/maximneumann) and [merve](https://huggingface.co/merve).")
    # Clicking an example only fills the input widgets (no fn/outputs given).
    gr.Examples(examples=examples, inputs=caption_inputs)
#########
if __name__ == "__main__":
    # Enable request queueing (at most 10 waiting requests), then start the
    # server in debug mode.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)