import gc import os import random import gradio as gr import cv2 import torch import numpy as np from PIL import Image from transformers import CLIPVisionModelWithProjection from diffusers.models import ControlNetModel from huggingface_hub import snapshot_download from insightface.app import FaceAnalysis import io import spaces from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps import pandas as pd import json import requests from io import BytesIO from huggingface_hub import hf_hub_download, HfApi def resize_img(input_image, max_side=1280, min_side=1024, size=None, pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): w, h = input_image.size if size is not None: w_resize_new, h_resize_new = size else: ratio = min_side / min(h, w) w, h = round(ratio*w), round(ratio*h) ratio = max_side / max(h, w) input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number input_image = input_image.resize([w_resize_new, h_resize_new], mode) if pad_to_max_side: res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 offset_x = (max_side - w_resize_new) // 2 offset_y = (max_side - h_resize_new) // 2 res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) input_image = Image.fromarray(res) return input_image def process_image_by_bbox_larger(input_image, bbox_xyxy, min_bbox_ratio=0.2): """ Process an image based on a bounding box, cropping and resizing as necessary. Parameters: - input_image: PIL Image object. - bbox_xyxy: Tuple (x1, y1, x2, y2) representing the bounding box coordinates. Returns: - A processed image cropped and resized to 1024x1024 if the bounding box is valid, or None if the bounding box does not meet the required size criteria. """ # Constants target_size = 1024 # min_bbox_ratio = 0.2 # Bounding box should be at least 20% of the crop # Extract bounding box coordinates x1, y1, x2, y2 = bbox_xyxy bbox_w = x2 - x1 bbox_h = y2 - y1 # Calculate the area of the bounding box bbox_area = bbox_w * bbox_h # Start with the smallest square crop that allows bbox to be at least 20% of the crop area crop_size = max(bbox_w, bbox_h) initial_crop_area = crop_size * crop_size while (bbox_area / initial_crop_area) < min_bbox_ratio: crop_size += 10 # Gradually increase until bbox is at least 20% of the area initial_crop_area = crop_size * crop_size # Once the minimum condition is satisfied, try to expand the crop further max_possible_crop_size = min(input_image.width, input_image.height) while crop_size < max_possible_crop_size: # Calculate a potential new area new_crop_size = crop_size + 10 new_crop_area = new_crop_size * new_crop_size if (bbox_area / new_crop_area) < min_bbox_ratio: break # Stop if expanding further violates the 20% rule crop_size = new_crop_size # Determine the center of the bounding box center_x = (x1 + x2) // 2 center_y = (y1 + y2) // 2 # Calculate the crop coordinates centered around the bounding box crop_x1 = max(0, center_x - crop_size // 2) crop_y1 = max(0, center_y - crop_size // 2) crop_x2 = min(input_image.width, crop_x1 + crop_size) crop_y2 = min(input_image.height, crop_y1 + crop_size) # Ensure the crop is square, adjust if it goes out of image bounds if crop_x2 - crop_x1 != crop_y2 - crop_y1: side_length = min(crop_x2 - crop_x1, crop_y2 - crop_y1) crop_x2 = crop_x1 + side_length crop_y2 = crop_y1 + side_length # Crop the image cropped_image = input_image.crop((crop_x1, crop_y1, crop_x2, crop_y2)) # Resize the cropped image to 1024x1024 resized_image = cropped_image.resize((target_size, target_size), Image.LANCZOS) return resized_image def calc_emb_cropped(image, app, min_bbox_ratio=0.2): face_image = image.copy() face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) face_info = face_info[0] cropped_face_image = process_image_by_bbox_larger(face_image, face_info["bbox"], min_bbox_ratio=min_bbox_ratio) return cropped_face_image def make_canny_condition(image, min_val=100, max_val=200, w_bilateral=True): if w_bilateral: image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY) bilateral_filtered_image = cv2.bilateralFilter(image, d=9, sigmaColor=75, sigmaSpace=75) image = cv2.Canny(bilateral_filtered_image, min_val, max_val) else: image = np.array(image) image = cv2.Canny(image, min_val, max_val) image = image[:, :, None] image = np.concatenate([image, image, image], axis=2) image = Image.fromarray(image) return image def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: if randomize_seed: seed = random.randint(0, 99999999) return seed default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers" # Download face encoder snapshot_download( "fal/AuraFace-v1", local_dir="models/auraface", ) app = FaceAnalysis( name="auraface", providers=["CUDAExecutionProvider", "CPUExecutionProvider"], root=".", ) app.prepare(ctx_id=0, det_size=(640, 640)) # download checkpoints print("Downloading checkpoints") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/controlnet/config.json", local_dir="./checkpoints") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/controlnet/diffusion_pytorch_model.safetensors", local_dir="./checkpoints") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="checkpoint_105000/ip-adapter.bin", local_dir="./checkpoints") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="image_encoder/config.json", local_dir="./checkpoints") # Download Lora weights hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/3D_avatar/pytorch_lora_weights.safetensors", local_dir=".") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/coloringbook/pytorch_lora_weights.safetensors", local_dir=".") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/One_line_portraits_Light/pytorch_lora_weights.safetensors", local_dir=".") hf_hub_download(repo_id="briaai/BRIA-2.3-ID_Preservation", filename="LoRAs/Stickers/pytorch_lora_weights.safetensors", local_dir=".") device = "cuda" if torch.cuda.is_available() else "cpu" # ckpts paths face_adapter = f"./checkpoints/checkpoint_105000/ip-adapter.bin" controlnet_path = f"./checkpoints/checkpoint_105000/controlnet" base_model_path = f'briaai/BRIA-2.3' lora_base_path = f"./LoRAs" resolution = 1024 # Load ControlNet models controlnet_lnmks = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) controlnet_canny = ControlNetModel.from_pretrained("briaai/BRIA-2.3-ControlNet-Canny", torch_dtype=torch.float16) controlnet = [controlnet_lnmks, controlnet_canny] image_encoder = CLIPVisionModelWithProjection.from_pretrained( f"./checkpoints/image_encoder", torch_dtype=torch.float16, ) pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( base_model_path, controlnet=controlnet, torch_dtype=torch.float16, image_encoder=image_encoder # For compatibility issues - needs to be there ) pipe = pipe.to(device) # use_native_ip_adapter = True pipe.use_native_ip_adapter=True pipe.load_ip_adapter_instantid(face_adapter) clip_embeds=None Loras_dict = { "":"", "One_line_portraits_Light": "An illustration of ", "3D_avatar": "An illustration of ", "coloringbook": "An illustration of ", "Stickers": "An illustration of " } lora_names = Loras_dict.keys() @spaces.GPU def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale, progress=gr.Progress(track_tqdm=True)): # def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, progress=gr.Progress(track_tqdm=True)): global CURRENT_LORA_NAME # Use the global variable to track LoRA CURRENT_LORA_NAME = None if image_path is None: raise gr.Error(f"Cannot find any input face image! Please upload a face image.") img = Image.open(image_path) face_image_orig = img face_image_cropped = calc_emb_cropped(face_image_orig, app) face_image = resize_img(face_image_cropped, max_side=resolution, min_side=resolution) face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face face_emb = face_info['embedding'] face_kps = draw_kps(face_image, face_info['kps']) if canny_scale>0.0: # Convert PIL image to a file-like object image_file = io.BytesIO() face_image_cropped.save(image_file, format='JPEG') # Save in the desired format (e.g., 'JPEG' or 'PNG') image_file.seek(0) # Move to the start of the BytesIO stream url = "https://engine.prod.bria-api.com/v1/background/remove" payload = {} files = [ ('file', ('image_name.jpeg', image_file, 'image/jpeg')) # Specify file name, file-like object, and MIME type ] headers = { 'api_token': os.getenv('BRIA_RMBG_TOKEN') # Securely retrieve the token } response = requests.request("POST", url, headers=headers, data=payload, files=files) print(response.text) response_json = json.loads(response.content.decode('utf-8')) img = requests.get(response_json['result_url']) processed_image = Image.open(io.BytesIO(img.content)) # Assuming `processed_image` is the RGBA image returned if processed_image.mode == 'RGBA': # Create a white background image white_background = Image.new("RGB", processed_image.size, (255, 255, 255)) # Composite the RGBA image over the white background face_image = Image.alpha_composite(white_background.convert('RGBA'), processed_image).convert('RGB') else: face_image = processed_image.convert('RGB') # If already RGB, just ensure mode is correct canny_img = make_canny_condition(face_image, min_val=20, max_val=40, w_bilateral=True) generator = torch.Generator(device=device).manual_seed(seed) # full_prompt = prompt if lora_name != CURRENT_LORA_NAME: # Check if LoRA needs to be changed if CURRENT_LORA_NAME is not None: # If a LoRA is already loaded, unload it pipe.disable_lora() pipe.unfuse_lora() pipe.unload_lora_weights() print(f"Unloaded LoRA: {CURRENT_LORA_NAME}") if lora_name != "": # Load the new LoRA if specified # pipe.enable_model_cpu_offload() lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors") pipe.load_lora_weights(lora_path) pipe.fuse_lora(lora_scale) pipe.enable_lora() # lora_prefix = Loras_dict[lora_name] print(f"Loaded new LoRA: {lora_name}") # Update the current LoRA name CURRENT_LORA_NAME = lora_name if lora_name != "": full_prompt = f"{Loras_dict[lora_name]} + " " + {prompt}" else: full_prompt = prompt print("Start inference...") images = pipe( prompt = full_prompt, negative_prompt = default_negative_prompt, image_embeds = face_emb, image = [face_kps, canny_img] if canny_scale > 0.0 else face_kps, controlnet_conditioning_scale = [kps_scale, canny_scale] if canny_scale>0.0 else kps_scale, # control_guidance_end = [1.0, 1.0] if canny_scale>0.0 else 1.0, ip_adapter_scale = ip_adapter_scale, num_inference_steps = num_steps, guidance_scale = guidance_scale, generator = generator, visual_prompt_embds = clip_embeds, cross_attention_kwargs = None, num_images_per_prompt=num_images, ).images gc.collect() torch.cuda.empty_cache() columns = 1 if num_images == 1 else 2 return gr.update(value=images, columns=columns) # return images ### Description title = r"""