yonishafir committed
Commit d929725 · 1 Parent(s): e834bcc
Files changed (3):
  1. .gitignore +2 -0
  2. app.py +23 -91
  3. requirements.txt +8 -7
.gitignore ADDED
@@ -0,0 +1,2 @@
+ */__pycache__/
+ *.pyc
app.py CHANGED
@@ -3,20 +3,16 @@ import os
  import random
  import gradio as gr
 
- 
  import cv2
  import torch
  import numpy as np
  from PIL import Image
 
  from transformers import CLIPVisionModelWithProjection
- from diffusers.utils import load_image
  from diffusers.models import ControlNetModel
- # from diffusers.image_processor import IPAdapterMaskProcessor
+ 
  from insightface.app import FaceAnalysis
- # import sys
- # import glob
- # import os
+ 
  import io
  import spaces
 
@@ -25,8 +21,8 @@ from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPip
  import pandas as pd
  import json
  import requests
- from PIL import Image
  from io import BytesIO
+ from huggingface_hub import hf_hub_download
 
 
  def resize_img(input_image, max_side=1280, min_side=1024, size=None,
@@ -128,25 +124,6 @@ def calc_emb_cropped(image, app):
 
      return cropped_face_image
 
- def process_benchmark_csv(banchmark_csv_path):
-     # Reading the first CSV file into a DataFrame
-     df = pd.read_csv(banchmark_csv_path)
-
-     # Drop any unnamed columns
-     df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
-
-     # Drop columns with all NaN values
-     df.dropna(axis=1, how='all', inplace=True)
-
-     # Drop rows with all NaN values
-     df.dropna(axis=0, how='all', inplace=True)
-
-     df = df.loc[df['High resolution'] == 1]
-
-     df.reset_index(drop=True, inplace=True)
-
-     return df
-
  def make_canny_condition(image, min_val=100, max_val=200, w_bilateral=True):
      if w_bilateral:
          image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
@@ -167,50 +144,45 @@ default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly
  CURRENT_LORA_NAME = None
 
  # Load face detection and recognition package
- app = FaceAnalysis(name='antelopev2', root='./', providers=['CPUExecutionProvider'])
+ app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
  app.prepare(ctx_id=0, det_size=(640, 640))
 
 
  # download checkpoints
- from huggingface_hub import hf_hub_download
 
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="controlnet/config.json", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="controlnet/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="ip-adapter.bin", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/pytorch_model.bin", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="image_encoder/config.json", local_dir="./checkpoints")
-
+ # Download Lora weights
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/3D_illustration/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Avatar_internlm/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Characters/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Storyboards/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
  hf_hub_download(repo_id="briaai/ID_preservation_2.3", filename="LoRAs/Vangogh_Vanilla/pytorch_lora_weights.safetensors", local_dir="./checkpoints")
 
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
- # base_dir = "./instantID_ckpt/checkpoint_174000"
- # face_adapter = f'{base_dir}/pytorch_model.bin'
- # controlnet_path = f'{base_dir}/controlnet'
+ # ckpts paths
  face_adapter = f"./checkpoints/ip-adapter.bin"
  controlnet_path = f"./checkpoints/controlnet"
-
-
+ lora_base_path = "./checkpoints/LoRAs"
  base_model_path = f'briaai/BRIA-2.3'
  resolution = 1024
 
+ # Load ControlNet models
  controlnet_lnmks = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
-
  controlnet_canny = ControlNetModel.from_pretrained("briaai/BRIA-2.3-ControlNet-Canny",
-                                                    torch_dtype=torch.float16)
-
+                                                    torch_dtype=torch.float16)
+
  controlnet = [controlnet_lnmks, controlnet_canny]
 
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
  image_encoder = CLIPVisionModelWithProjection.from_pretrained(
      f"./checkpoints/image_encoder",
      torch_dtype=torch.float16,
  )
-
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
      base_model_path,
      controlnet=controlnet,
@@ -220,14 +192,13 @@ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
 
  pipe = pipe.to(device)
 
- use_native_ip_adapter = True
- pipe.use_native_ip_adapter=use_native_ip_adapter
+ # use_native_ip_adapter = True
+ pipe.use_native_ip_adapter=True
 
  pipe.load_ip_adapter_instantid(face_adapter)
 
  clip_embeds=None
 
-
  Loras_dict = {
      "":"",
      "Vangogh_Vanilla": "bold, dramatic brush strokes, vibrant colors, swirling patterns, intense, emotionally charged paintings of",
@@ -239,8 +210,6 @@ Loras_dict = {
 
  lora_names = Loras_dict.keys()
 
- lora_base_path = "./checkpoints/LoRAs"
-
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
      if randomize_seed:
          seed = random.randint(0, 99999999)
@@ -254,13 +223,11 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
      if image_path is None:
          raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
 
-     # img = np.array(Image.open(image_path))[:,:,::-1]
      img = Image.open(image_path)
 
-     face_image_orig = img #Image.open(BytesIO(response.content))
+     face_image_orig = img
      face_image_cropped = calc_emb_cropped(face_image_orig, app)
      face_image = resize_img(face_image_cropped, max_side=resolution, min_side=resolution)
-     # face_image_padded = resize_img(face_image_cropped, max_side=resolution, min_side=resolution, pad_to_max_side=True)
      face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
      face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
      face_emb = face_info['embedding']
@@ -305,19 +272,6 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
 
      generator = torch.Generator(device=device).manual_seed(seed)
 
-     # if lora_name != "":
-     #     lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
-     #     pipe.load_lora_weights(lora_path)
-     #     pipe.fuse_lora(lora_scale)
-     #     pipe.enable_lora()
-
-     #     lora_prefix = Loras_dict[lora_name]
-
-     #     prompt = f"{lora_prefix} {prompt}"
-
-     #     print("Using LoRA: ", lora_name)
-
-
      if lora_name != CURRENT_LORA_NAME: # Check if LoRA needs to be changed
          if CURRENT_LORA_NAME is not None: # If a LoRA is already loaded, unload it
              pipe.disable_lora()
@@ -326,12 +280,13 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
              print(f"Unloaded LoRA: {CURRENT_LORA_NAME}")
 
          if lora_name != "": # Load the new LoRA if specified
+             # pipe.enable_model_cpu_offload()
              lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
              pipe.load_lora_weights(lora_path)
              pipe.fuse_lora(lora_scale)
              pipe.enable_lora()
 
-             lora_prefix = Loras_dict[lora_name]
+             # lora_prefix = Loras_dict[lora_name]
 
              print(f"Loaded new LoRA: {lora_name}")
 
@@ -339,7 +294,7 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
 
          CURRENT_LORA_NAME = lora_name
 
      if lora_name != "":
-         full_prompt = f"{lora_prefix} {prompt}"
+         full_prompt = f"{Loras_dict[lora_name]} + " " + {prompt}"
      else:
          full_prompt = prompt
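The change above makes generate_image hot-swap LoRAs between calls: the previously fused LoRA is unloaded, then the requested one is loaded, fused, and enabled, with the module-level CURRENT_LORA_NAME tracking what is active. Below is a minimal sketch of that pattern as a hypothetical helper (switch_lora is not part of app.py); it uses only the pipeline calls that appear in this diff, and assumes pipe, lora_base_path, and lora_scale are defined as in app.py.

# Sketch of the LoRA hot-swap pattern used by generate_image (assumes `pipe` and
# `lora_base_path` exist as in app.py; `switch_lora` itself is hypothetical).
import os

CURRENT_LORA_NAME = None

def switch_lora(pipe, lora_name: str, lora_scale: float, lora_base_path: str) -> None:
    global CURRENT_LORA_NAME
    if lora_name == CURRENT_LORA_NAME:
        return  # requested LoRA is already active; nothing to do

    if CURRENT_LORA_NAME is not None:
        # Unload whatever LoRA is currently fused into the pipeline.
        pipe.disable_lora()
        pipe.unfuse_lora()
        pipe.unload_lora_weights()

    if lora_name != "":
        # Load, fuse, and enable the newly requested LoRA.
        lora_path = os.path.join(lora_base_path, lora_name, "pytorch_lora_weights.safetensors")
        pipe.load_lora_weights(lora_path)
        pipe.fuse_lora(lora_scale)
        pipe.enable_lora()

    CURRENT_LORA_NAME = lora_name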
 
@@ -348,9 +303,9 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
          prompt = full_prompt,
          negative_prompt = default_negative_prompt,
          image_embeds = face_emb,
-         image = [face_kps, canny_img] if canny_scale>0.0 else face_kps,
+         image = [face_kps, canny_img] if canny_scale > 0.0 else face_kps,
          controlnet_conditioning_scale = [kps_scale, canny_scale] if canny_scale>0.0 else kps_scale,
-         control_guidance_end = [1.0, 1.0] if canny_scale>0.0 else 1.0,
+         # control_guidance_end = [1.0, 1.0] if canny_scale>0.0 else 1.0,
          ip_adapter_scale = ip_adapter_scale,
          num_inference_steps = num_steps,
          guidance_scale = guidance_scale,
@@ -358,12 +313,8 @@ def generate_image(image_path, prompt, num_steps, guidance_scale, seed, num_imag
          visual_prompt_embds = clip_embeds,
          cross_attention_kwargs = None,
          num_images_per_prompt=num_images,
-     ).images #[0]
+     ).images
 
-     # if lora_name != "":
-     #     pipe.disable_lora()
-     #     pipe.unfuse_lora()
-     #     pipe.unload_lora_weights()
 
      gc.collect()
      torch.cuda.empty_cache()
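For readers of the hunk above: with both ControlNets attached (facial landmarks plus canny edges), the pipeline call passes list-valued image and conditioning-scale arguments, and falls back to scalars when canny_scale is 0. A short, abbreviated sketch of that selection logic follows; variable names are taken from generate_image, and the keyword arguments shown are a subset of the full call in the diff.

# Abbreviated sketch of the conditional ControlNet arguments in generate_image.
# `face_kps`, `canny_img`, `kps_scale`, `canny_scale`, and the other names are
# assumed to exist as in app.py.
if canny_scale > 0.0:
    control_images = [face_kps, canny_img]          # one control image per ControlNet
    conditioning_scales = [kps_scale, canny_scale]  # matching per-ControlNet scales
else:
    control_images = face_kps                       # landmarks ControlNet only
    conditioning_scales = kps_scale

images = pipe(
    prompt=full_prompt,
    negative_prompt=default_negative_prompt,
    image_embeds=face_emb,
    image=control_images,
    controlnet_conditioning_scale=conditioning_scales,
    ip_adapter_scale=ip_adapter_scale,
    num_inference_steps=num_steps,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=num_images,
).images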
@@ -412,12 +363,7 @@ with gr.Blocks(css=css) as demo:
          lora_name = gr.Dropdown(choices=lora_names, label="LoRA", value="", info="Select a LoRA name from the list, not selecting any will disable LoRA.")
 
          submit = gr.Button("Submit", variant="primary")
-
-         # use_lcm = gr.Checkbox(
-         #     label="Use LCM-LoRA to accelerate sampling", value=False,
-         #     info="Reduces sampling steps significantly, but may decrease quality.",
-         # )
-
+
          with gr.Accordion(open=False, label="Advanced Options"):
              num_steps = gr.Slider(
                  label="Number of sample steps",
@@ -436,7 +382,7 @@
              num_images = gr.Slider(
                  label="Number of output images",
                  minimum=1,
-                 maximum=3,
+                 maximum=2,
                  step=1,
                  value=1,
              )
@@ -491,22 +437,8 @@ with gr.Blocks(css=css) as demo:
          inputs=[img_file, prompt, num_steps, guidance_scale, seed, num_images, ip_adapter_scale, kps_scale, canny_scale, lora_name, lora_scale],
          outputs=[gallery]
      )
-
-     # use_lcm.input(
-     #     fn=toggle_lcm_ui,
-     #     inputs=[use_lcm],
-     #     outputs=[num_steps, guidance_scale],
-     #     queue=False,
-     # )
-
-     # gr.Examples(
-     #     examples=get_example(),
-     #     inputs=[img_file],
-     #     run_on_click=True,
-     #     fn=run_example,
-     #     outputs=[gallery],
-     # )
 
      gr.Markdown(Footer)
 
+ # demo.launch(server_port=7865)
  demo.launch()
 
requirements.txt CHANGED
@@ -1,15 +1,16 @@
  --extra-index-url https://download.pytorch.org/whl/cu121
- torch
- torchvision
- transformers
- accelerate
+ torch==2.4.0
+ torchvision==0.19.0
+ transformers==4.43.4
+ accelerate==0.33.0
  ftfy
  numpy
  matplotlib
  uuid
  opencv-python
- diffusers==0.26.0
+ diffusers==0.29.2
  spaces
- insightface==0.7.3
+ insightface==0.7.3
- onnxruntime
+ onnx==1.16.2
+ onnxruntime==1.18.1
  peft==0.12.0
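Because the commit pins exact versions, a quick sanity check of the installed environment can compare what is installed against the pins above. The snippet below simply mirrors that list; package names and versions are taken from requirements.txt as updated here, and the check itself uses only the Python standard library.

# Sketch: verify that installed package versions match the pins in requirements.txt.
from importlib.metadata import PackageNotFoundError, version

PINNED = {
    "torch": "2.4.0",
    "torchvision": "0.19.0",
    "transformers": "4.43.4",
    "accelerate": "0.33.0",
    "diffusers": "0.29.2",
    "insightface": "0.7.3",
    "onnx": "1.16.2",
    "onnxruntime": "1.18.1",
    "peft": "0.12.0",
}

for pkg, pinned in PINNED.items():
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed (pinned {pinned})")
        continue
    status = "OK" if installed == pinned else f"MISMATCH (pinned {pinned})"
    print(f"{pkg}: {installed} {status}")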