FLUX.1-Merged
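This repository provides the merged parameters of black-forest-labs/FLUX.1-dev and black-forest-labs/FLUX.1-schnell. Every transformer weight shared by both checkpoints is the element-wise average of the two, while the guidance-embedding weights (found only in FLUX.1-dev) are copied from FLUX.1-dev unchanged.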
Merge & Upload
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from huggingface_hub import upload_folder
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch
# Initialize the model with empty weights.
# The transformer architecture is built from the FLUX.1-dev config inside
# accelerate's `init_empty_weights` context, so no real weight storage is
# allocated here; the merged weights are loaded into this model further down
# via `load_model_dict_into_meta`.
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)
# Download only the `transformer/` sub-folder of each upstream checkpoint
# (FLUX.1-dev first, then FLUX.1-schnell).
dev_ckpt, schnell_ckpt = (
    snapshot_download(repo_id=repo, allow_patterns="transformer/*")
    for repo in ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")
)
# List each checkpoint's safetensors shards in sorted order; the merge step
# pairs shard i of FLUX.1-dev with shard i of FLUX.1-schnell.
dev_shards, schnell_shards = (
    sorted(glob.glob(f"{ckpt}/transformer/*.safetensors"))
    for ckpt in (dev_ckpt, schnell_ckpt)
)
# Merge the state dictionaries.
#
# Shards are processed pairwise: shard i of FLUX.1-dev with shard i of
# FLUX.1-schnell. Every non-guidance tensor becomes the element-wise mean of
# the two checkpoints' values. Tensors whose key contains "guidance" are taken
# from FLUX.1-dev unchanged. They are presumably absent from FLUX.1-schnell;
# if schnell did have them, they would be reported as residue below.
if not dev_shards:
    raise ValueError("No FLUX.1-dev transformer shards (*.safetensors) were found.")
if len(dev_shards) != len(schnell_shards):
    raise ValueError(
        f"Shard count mismatch: {len(dev_shards)} FLUX.1-dev shards vs "
        f"{len(schnell_shards)} FLUX.1-schnell shards; they are merged pairwise."
    )

merged_state_dict = {}
guidance_state_dict = {}
for dev_shard, schnell_shard in zip(dev_shards, schnell_shards):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shard)
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shard)
    # Iterate over a snapshot of the keys because entries are popped as we go
    # (popping releases each source tensor once it has been consumed).
    for k in list(state_dict_dev_temp):
        dev_param = state_dict_dev_temp.pop(k)
        if "guidance" in k:
            guidance_state_dict[k] = dev_param
            continue
        if k not in state_dict_schnell_temp:
            raise KeyError(
                f"{k!r} from {dev_shard} is missing from the paired FLUX.1-schnell shard {schnell_shard}."
            )
        merged_state_dict[k] = (dev_param + state_dict_schnell_temp.pop(k)) / 2
    # Any schnell tensor not consumed above has no FLUX.1-dev counterpart in the
    # paired shard, which means the two checkpoints do not line up.
    if state_dict_schnell_temp:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")
# Update the merged state dictionary with the guidance state dictionary
merged_state_dict.update(guidance_state_dict)
# Fill the empty-initialized model with the merged weights, cast it to
# bfloat16, and write it to the local "transformer" directory.
load_model_dict_into_meta(model, merged_state_dict)
model = model.to(torch.bfloat16)
model.save_pretrained("transformer")

# Push the saved directory to the Hub under the same relative path.
# Replace repo_id with your own Hugging Face username and desired repo name.
upload_folder(
    repo_id="prithivMLmods/Flux.1-Merged",
    path_in_repo="transformer",
    folder_path="transformer",
)
Inference
from diffusers import FluxPipeline
import torch
# Load the merged pipeline in bfloat16 and move it to the GPU.
pipeline = FluxPipeline.from_pretrained("prithivMLmods/Flux.1-Merged", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")

# Sampling settings: 4 denoising steps at 1184x880 with a fixed seed.
# The kwargs are built after loading so the torch.manual_seed call order is unchanged.
generation_kwargs = dict(
    prompt="a tiny astronaut hatching from an egg on the moon",
    guidance_scale=3.5,
    num_inference_steps=4,
    height=880,
    width=1184,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
)
output = pipeline(**generation_kwargs)
image = output.images[0]
image.save("merged_flux.png")
- Downloads last month: 0
This model does not have enough activity to be deployed to the Inference API (serverless) yet. Increase its social visibility and check back later, or deploy it to Inference Endpoints (dedicated) instead.
Model tree for prithivMLmods/Flux.1-Merged
Merge model: this model (a merge of black-forest-labs/FLUX.1-dev and black-forest-labs/FLUX.1-schnell)