Mask2Former for Semantic Segmentation
This repository contains the Mask2Former model fine-tuned for semantic segmentation. The model predicts a per-pixel class map for an input image and is based on the pre-trained facebook/mask2former-swin-large-cityscapes-semantic checkpoint.
Model Overview
Mask2Former (Masked-attention Mask Transformer) is a universal image segmentation architecture that handles the following mask prediction tasks with a single design:
- Semantic Segmentation
- Instance Segmentation
- Panoptic Segmentation
This version is fine-tuned for semantic segmentation, where every pixel is assigned exactly one class label. You can use it for road scene understanding, autonomous driving, and other semantic segmentation applications.
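For road scene work it is often useful to know how much of the frame each class occupies. The sketch below is a hypothetical helper, not part of this repository or of transformers. It assumes `seg_map` is the integer `predicted_map` array produced by the script in the How to Use the Model section, and that `id2label` is a dict from class ID to class name, such as `model.config.id2label`. The toy label mapping in the demo is made up and is not this model's label set.

```python
import numpy as np


def class_coverage(seg_map, id2label):
    """Return {class_name: fraction_of_pixels} for every class ID present in seg_map."""
    seg_map = np.asarray(seg_map)
    total = seg_map.size
    # Count how many pixels carry each class ID
    ids, counts = np.unique(seg_map, return_counts=True)
    coverage = {}
    for class_id, count in zip(ids.tolist(), counts.tolist()):
        # Fall back to a generic name if the ID is missing from the mapping
        name = id2label.get(class_id, f"class_{class_id}")
        coverage[name] = count / total
    return coverage


if __name__ == "__main__":
    # Toy 2x3 map and a hypothetical label mapping, for illustration only
    toy_map = np.array([[0, 0, 1], [0, 2, 2]])
    toy_labels = {0: "road", 1: "sidewalk", 2: "car"}
    print(class_coverage(toy_map, toy_labels))
```

With the real model you would call `class_coverage(predicted_map, model.config.id2label)`; transformers stores `id2label` with integer keys, which matches the integer IDs in the map.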
How to Use the Model
You can use this model with the Hugging Face transformers library. The example below loads the model, preprocesses an image, runs inference, and visualizes the output.
Installation
First, ensure you have the required libraries installed:
pip install transformers torch torchvision pillow matplotlib
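To confirm the installation before running the example, you can check that each package is importable. This is a small illustrative sketch that uses only the standard library's `importlib.util.find_spec`; the module list is an assumption derived from the pip command above (pillow is imported as `PIL`), and `missing_modules` is a hypothetical helper name.

```python
import importlib.util

# Import names corresponding to the pip packages above (pillow -> PIL)
REQUIRED_MODULES = ["transformers", "torch", "torchvision", "PIL", "matplotlib"]


def missing_modules(modules=REQUIRED_MODULES):
    """Return the module names that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

`find_spec` only locates a module without importing it, so the check is fast even for large packages such as torch.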
Inference Example
The following script runs the model on a single image and plots the predicted segmentation map next to the original:
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from PIL import Image
import torch
import matplotlib.pyplot as plt
# Load the processor and model
model_name = "saninmohammedn/mask2former-deployment"
processor = AutoImageProcessor.from_pretrained(model_name)
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
# Load an input image
image_path = "your_image.jpg" # Replace with your image path
image = Image.open(image_path).convert("RGB")
# Prepare the image for the model
inputs = processor(images=image, return_tensors="pt")
# Perform inference
with torch.no_grad():
    outputs = model(**inputs)
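# Post-process into an (height, width) array of class IDs. target_sizes expects
# (height, width), while PIL's image.size is (width, height), hence the [::-1]
predicted_map = processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0].cpu().numpy()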
# Visualize the input and predicted segmentation map
plt.figure(figsize=(10, 5))
# Display original image
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis("off")
# Display predicted segmentation map
plt.subplot(1, 2, 2)
plt.imshow(predicted_map, cmap="jet")
plt.title("Predicted Segmentation Map")
plt.axis("off")
plt.tight_layout()
plt.show()
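The `cmap="jet"` plot renders class IDs on a continuous colormap. If you prefer a fixed color per class drawn over the photo, the following NumPy sketch is one option. `make_palette`, `colorize`, and `overlay` are hypothetical helpers, not part of transformers, and the random palette is an assumption; substitute the official Cityscapes colors if you need them.

```python
import numpy as np


def make_palette(num_classes, seed=0):
    """Hypothetical helper: a reproducible random RGB color for each class ID."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)


def colorize(seg_map, palette):
    """Map an (H, W) array of class IDs to an (H, W, 3) uint8 RGB image."""
    palette = np.asarray(palette, dtype=np.uint8)
    # Fancy indexing: each class ID selects its row (color) in the palette
    return palette[np.asarray(seg_map)]


def overlay(image_rgb, seg_map, palette, alpha=0.5):
    """Blend the colorized mask over an (H, W, 3) RGB image; alpha is the mask opacity."""
    image = np.asarray(image_rgb, dtype=np.float32)
    mask = colorize(seg_map, palette).astype(np.float32)
    # Blend in float, then clip and cast back to uint8 to avoid overflow
    blended = (1.0 - alpha) * image + alpha * mask
    return np.clip(blended, 0, 255).astype(np.uint8)
```

For example, after running the script above: `plt.imshow(overlay(image, predicted_map, make_palette(len(model.config.id2label))))`. Because `predicted_map` was resized to the image's (height, width), its shape matches `np.asarray(image)` and the blend works pixel for pixel.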