Avijit Ghosh commited on
Commit
956fa05
·
1 Parent(s): e28cd55

added files

Browse files
Files changed (3) hide show
  1. app.py +161 -0
  2. css.py +17 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import AutoPipelineForText2Image
4
+ from transformers import BlipProcessor, BlipForConditionalGeneration
5
+ from pathlib import Path
6
+ import stone
7
+ import requests
8
+ import io
9
+ import os
10
+ from PIL import Image
11
+ import spaces
12
+
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from matplotlib.colors import hex2color
16
+
17
+
18
+ pipeline_text2image = None
19
+
20
+ @spaces.GPU
21
+ def loadpipeline():
22
+ global pipeline_text2image
23
+ pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
24
+ "stabilityai/sdxl-turbo",
25
+ torch_dtype=torch.float16,
26
+ variant="fp16",
27
+ )
28
+ pipeline_text2image = pipeline_text2image.to("cuda")
29
+
30
+ loadpipeline()
31
+
32
+ @spaces.GPU
33
+ def getimgen(prompt):
34
+
35
+ return pipeline_text2image(
36
+ prompt=prompt,
37
+ guidance_scale=0.0,
38
+ num_inference_steps=2
39
+ ).images[0]
40
+
41
+ blip_processor = None
42
+
43
+ @spaces.GPU
44
+ def loadblip():
45
+ global blip_processor
46
+ global blip_model
47
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
48
+ blip_model = BlipForConditionalGeneration.from_pretrained(
49
+ "Salesforce/blip-image-captioning-large",
50
+ torch_dtype=torch.float16
51
+ ).to("cuda")
52
+
53
+ loadblip()
54
+
55
+ @spaces.GPU
56
+ def blip_caption_image(image, prefix):
57
+ inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
58
+ out = blip_model.generate(**inputs)
59
+ return blip_processor.decode(out[0], skip_special_tokens=True)
60
+
61
+ def genderfromcaption(caption):
62
+ cc = caption.split()
63
+ if "man" in cc or "boy" in cc:
64
+ return "Man"
65
+ elif "woman" in cc or "girl" in cc:
66
+ return "Woman"
67
+ return "Unsure"
68
+
69
+ def genderplot(genlist):
70
+ order = ["Man", "Woman", "Unsure"]
71
+
72
+ # Sort the list based on the order of keys
73
+ words = sorted(genlist, key=lambda x: order.index(x))
74
+
75
+ # Define colors for each category
76
+ colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
77
+
78
+ # Map each word to its corresponding color
79
+ word_colors = [colors[word] for word in words]
80
+
81
+ # Plot the colors in a grid with reduced spacing
82
+ fig, axes = plt.subplots(2, 5, figsize=(5,5))
83
+
84
+ # Adjust spacing between subplots
85
+ plt.subplots_adjust(hspace=0.1, wspace=0.1)
86
+
87
+ for i, ax in enumerate(axes.flat):
88
+ ax.set_axis_off()
89
+ ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
90
+
91
+ return fig
92
+
93
+ def skintoneplot(hex_codes):
94
+ # Convert hex codes to RGB values
95
+ rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
96
+
97
+ # Calculate luminance for each color
98
+ luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
99
+
100
+ # Sort hex codes based on luminance in descending order (dark to light)
101
+ sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
102
+
103
+ # Plot the colors in a grid with reduced spacing
104
+ fig, axes = plt.subplots(2, 5, figsize=(5,5))
105
+
106
+ # Adjust spacing between subplots
107
+ plt.subplots_adjust(hspace=0.1, wspace=0.1)
108
+
109
+ for i, ax in enumerate(axes.flat):
110
+ ax.set_axis_off()
111
+ ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
112
+
113
+ return fig
114
+
115
+ @spaces.GPU
116
+ def generate_images_plots(prompt):
117
+ foldername = "temp"
118
+ # Generate 10 images
119
+ images = [getimgen(prompt) for _ in range(10)]
120
+
121
+ Path(foldername).mkdir(parents=True, exist_ok=True)
122
+
123
+ genders = []
124
+ skintones = []
125
+
126
+ for image, i in zip(images, range(10)):
127
+ prompt_prefix = "photo of a "
128
+ caption = blip_caption_image(image, prefix=prompt_prefix)
129
+ image.save(f"{foldername}/image_{i}.png")
130
+ try:
131
+ skintoneres = stone.process(f"{foldername}/image_{i}.png", return_report_image=False)
132
+ tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
133
+ skintones.append(tone)
134
+ except:
135
+ skintones.append(None)
136
+
137
+ genders.append(genderfromcaption(caption))
138
+
139
+ print(genders, skintones)
140
+
141
+ return images, skintoneplot(skintones), genderplot(genders)
142
+
143
+
144
+ with gr.Blocks(title = "Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
145
+
146
+ gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
147
+
148
+ prompt = gr.Textbox(label="Enter the Prompt")
149
+ gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery",
150
+ columns=[5], rows=[2], object_fit="contain", height="auto")
151
+ btn = gr.Button("Generate images", scale=0)
152
+ with gr.Row(equal_height=True):
153
+ skinplot = gr.Plot(label="Skin Tone")
154
+ genplot = gr.Plot(label="Gender")
155
+
156
+
157
+ btn.click(generate_images_plots, inputs = prompt, outputs = [gallery, skinplot, genplot])
158
+
159
+
160
+
161
+ demo.launch(debug=True)
css.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ custom_css = """
2
+ /* Full width space */
3
+ a {
4
+ text-decoration: underline;
5
+ # text-decoration-style: dotted;
6
+ }
7
+
8
+ h1, h2, h3, h4, h5, h6 {
9
+ margin: 0;
10
+ }
11
+
12
+ .tag {
13
+ padding: .1em .3em;
14
+ background-color: lightgrey;
15
+ border-radius: 12px;
16
+ }
17
+ """
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ diffusers
4
+ transformers
5
+ spaces
6
+ skin-tone-classifier
7
+ matplotlib
8
+ pillow
9
+ numpy