Mojmir commited on
Commit
924d062
·
1 Parent(s): a52b4d4
Files changed (3) hide show
  1. app.py +90 -4
  2. custom_resnet.py +726 -0
  3. requirements.txt +5 -0
app.py CHANGED
@@ -1,7 +1,93 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import time
7
+ from concrete.ml.torch.compile import compile_torch_model
8
 
9
+ from custom_resnet import resnet18_custom # Assuming custom_resnet.py is in the same directory
 
10
 
11
+ # Load class names (FLIPPED as ['Fake', 'Real'])
12
+ class_names = ['Fake', 'Real'] # Fix the incorrect mapping
13
+
14
+ # Load the trained model
15
+ def load_model(model_path, device):
16
+ model = resnet18_custom(weights=None)
17
+ num_ftrs = model.fc.in_features
18
+ model.fc = nn.Linear(num_ftrs, len(class_names)) # Assuming 2 classes: Fake and Real
19
+ model.load_state_dict(torch.load(model_path, map_location=device))
20
+ model = model.to(device)
21
+ model.eval() # Set model to evaluation mode
22
+ return model
23
+
24
+
25
+ def load_secure_model(model):
26
+ print("Compiling secure model...")
27
+ secure_model = compile_torch_model(model.to("cpu"),
28
+ n_bits={"model_inputs": 4, "op_inputs": 3, "op_weights": 3, "model_outputs": 5},
29
+ rounding_threshold_bits={"n_bits": 7},
30
+ torch_inputset=torch.rand(10, 3, 224, 224))
31
+ return secure_model
32
+
33
+ # Image preprocessing (match with the transforms used during training)
34
+ data_transform = transforms.Compose([
35
+ transforms.Resize((224, 224)),
36
+ transforms.ToTensor(),
37
+ ])
38
+
39
+ # Prediction function
40
+ def predict(image, mode):
41
+ # Device configuration
42
+ device = torch.device(
43
+ "cuda:0" if torch.cuda.is_available() else
44
+ "mps" if torch.backends.mps.is_available() else
45
+ "cpu"
46
+ )
47
+
48
+ print(f"Device: {device}")
49
+ # Load model
50
+ model_path = 'models/deepfake_detection_model.pth'
51
+ model = load_model(model_path, device)
52
+
53
+ # Apply transformations to the input image
54
+ image = Image.open(image).convert('RGB')
55
+ image = data_transform(image).unsqueeze(0).to(device) # Add batch dimension
56
+
57
+ # Inference
58
+ with torch.no_grad():
59
+ start_time = time.time()
60
+
61
+ if mode == "Fast":
62
+ # Fast mode (less computation)
63
+ outputs = model(image)
64
+ elif mode == "Secure":
65
+ # Secure mode (e.g., running multiple times for higher confidence)
66
+ secure_model = load_secure_model(model)
67
+ detached_input = image.detach().numpy()
68
+ outputs = secure_model(detached_input, fhe="simulate")
69
+
70
+ print(outputs)
71
+ _, preds = torch.max(outputs, 1)
72
+ elapsed_time = time.time() - start_time
73
+
74
+ predicted_class = class_names[preds[0]]
75
+ return f"Predicted: {predicted_class}", f"Time taken: {elapsed_time:.2f} seconds"
76
+
77
+ # Gradio interface
78
+ iface = gr.Interface(
79
+ fn=predict,
80
+ inputs=[
81
+ gr.Image(type="filepath", label="Upload an Image"), # Update to gr.Image
82
+ gr.Radio(choices=["Fast", "Secure"], label="Inference Mode", value="Fast") # Update to gr.Radio
83
+ ],
84
+ outputs=[
85
+ gr.Textbox(label="Prediction"), # Update to gr.Textbox
86
+ gr.Textbox(label="Time Taken") # Update to gr.Textbox
87
+ ],
88
+ title="Deepfake Detection Model",
89
+ description="Upload an image and select the inference mode (Fast or Secure)."
90
+ )
91
+
92
+ if __name__ == "__main__":
93
+ iface.launch(share=True)
custom_resnet.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is a modification of the original ResNet implementation from:
3
+ https://github.com/pytorch/vision/blob/bf01bab6125c5f1152e4f336b470399e52a8559d/torchvision/models/resnet.py
4
+ """
5
+
6
+ from functools import partial
7
+ from typing import Any, Callable, List, Optional, Type, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torchvision.models._api import Weights, WeightsEnum, register_model
13
+ from torchvision.models._meta import _IMAGENET_CATEGORIES
14
+ from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
15
+ from torchvision.transforms._presets import ImageClassification
16
+ from torchvision.utils import _log_api_usage_once
17
+
18
+ __all__ = [
19
+ "ResNet",
20
+ "ResNet18_Weights",
21
+ "ResNet34_Weights",
22
+ "ResNet50_Weights",
23
+ "ResNet101_Weights",
24
+ "ResNet152_Weights",
25
+ "ResNeXt50_32X4D_Weights",
26
+ "ResNeXt101_32X8D_Weights",
27
+ "ResNeXt101_64X4D_Weights",
28
+ "Wide_ResNet50_2_Weights",
29
+ "Wide_ResNet101_2_Weights",
30
+ "resnet18",
31
+ "resnet34",
32
+ "resnet50",
33
+ "resnet101",
34
+ "resnet152",
35
+ "resnext50_32x4d",
36
+ "resnext101_32x8d",
37
+ "resnext101_64x4d",
38
+ "wide_resnet50_2",
39
+ "wide_resnet101_2",
40
+ ]
41
+
42
+
43
+ def conv3x3(
44
+ in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1
45
+ ) -> nn.Conv2d:
46
+ """3x3 convolution with padding"""
47
+ return nn.Conv2d(
48
+ in_planes,
49
+ out_planes,
50
+ kernel_size=3,
51
+ stride=stride,
52
+ padding=dilation,
53
+ groups=groups,
54
+ bias=False,
55
+ dilation=dilation,
56
+ )
57
+
58
+
59
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
60
+ """1x1 convolution"""
61
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
62
+
63
+
64
+ class BasicBlock(nn.Module):
65
+ expansion: int = 1
66
+
67
+ def __init__(
68
+ self,
69
+ inplanes: int,
70
+ planes: int,
71
+ stride: int = 1,
72
+ downsample: Optional[nn.Module] = None,
73
+ groups: int = 1,
74
+ base_width: int = 64,
75
+ dilation: int = 1,
76
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
77
+ ) -> None:
78
+ super().__init__()
79
+ if norm_layer is None:
80
+ norm_layer = nn.BatchNorm2d
81
+ if groups != 1 or base_width != 64:
82
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
83
+ if dilation > 1:
84
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
85
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
86
+ self.conv1 = conv3x3(inplanes, planes, stride)
87
+ self.bn1 = norm_layer(planes)
88
+ self.relu = nn.ReLU(inplace=True)
89
+ self.conv2 = conv3x3(planes, planes)
90
+ self.bn2 = norm_layer(planes)
91
+ self.downsample = downsample
92
+ self.stride = stride
93
+
94
+ def forward(self, x: Tensor) -> Tensor:
95
+ identity = x
96
+
97
+ out = self.conv1(x)
98
+ out = self.bn1(out)
99
+ out = self.relu(out)
100
+
101
+ out = self.conv2(out)
102
+ out = self.bn2(out)
103
+
104
+ if self.downsample is not None:
105
+ identity = self.downsample(x)
106
+
107
+ out += identity
108
+ out = self.relu(out)
109
+
110
+ return out
111
+
112
+
113
+ class Bottleneck(nn.Module):
114
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
115
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
116
+ # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
117
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
118
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
119
+
120
+ expansion: int = 4
121
+
122
+ def __init__(
123
+ self,
124
+ inplanes: int,
125
+ planes: int,
126
+ stride: int = 1,
127
+ downsample: Optional[nn.Module] = None,
128
+ groups: int = 1,
129
+ base_width: int = 64,
130
+ dilation: int = 1,
131
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
132
+ ) -> None:
133
+ super().__init__()
134
+ if norm_layer is None:
135
+ norm_layer = nn.BatchNorm2d
136
+ width = int(planes * (base_width / 64.0)) * groups
137
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
138
+ self.conv1 = conv1x1(inplanes, width)
139
+ self.bn1 = norm_layer(width)
140
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
141
+ self.bn2 = norm_layer(width)
142
+ self.conv3 = conv1x1(width, planes * self.expansion)
143
+ self.bn3 = norm_layer(planes * self.expansion)
144
+ self.relu = nn.ReLU(inplace=True)
145
+ self.downsample = downsample
146
+ self.stride = stride
147
+
148
+ def forward(self, x: Tensor) -> Tensor:
149
+ identity = x
150
+
151
+ out = self.conv1(x)
152
+ out = self.bn1(out)
153
+ out = self.relu(out)
154
+
155
+ out = self.conv2(out)
156
+ out = self.bn2(out)
157
+ out = self.relu(out)
158
+
159
+ out = self.conv3(out)
160
+ out = self.bn3(out)
161
+
162
+ if self.downsample is not None:
163
+ identity = self.downsample(x)
164
+
165
+ out += identity
166
+ out = self.relu(out)
167
+
168
+ return out
169
+
170
+
171
+ class ResNet(nn.Module):
172
+ def __init__(
173
+ self,
174
+ block: Type[Union[BasicBlock, Bottleneck]],
175
+ layers: List[int],
176
+ num_classes: int = 1000,
177
+ zero_init_residual: bool = False,
178
+ groups: int = 1,
179
+ width_per_group: int = 64,
180
+ replace_stride_with_dilation: Optional[List[bool]] = None,
181
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
182
+ ) -> None:
183
+ super().__init__()
184
+ _log_api_usage_once(self)
185
+ if norm_layer is None:
186
+ norm_layer = nn.BatchNorm2d
187
+ self._norm_layer = norm_layer
188
+
189
+ self.inplanes = 64
190
+ self.dilation = 1
191
+ if replace_stride_with_dilation is None:
192
+ # each element in the tuple indicates if we should replace
193
+ # the 2x2 stride with a dilated convolution instead
194
+ replace_stride_with_dilation = [False, False, False]
195
+ if len(replace_stride_with_dilation) != 3:
196
+ raise ValueError(
197
+ "replace_stride_with_dilation should be None "
198
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
199
+ )
200
+ self.groups = groups
201
+ self.base_width = width_per_group
202
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
203
+ self.bn1 = norm_layer(self.inplanes)
204
+ self.relu = nn.ReLU(inplace=True)
205
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
206
+ self.layer1 = self._make_layer(block, 64, layers[0])
207
+ self.layer2 = self._make_layer(
208
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
209
+ )
210
+ self.layer3 = self._make_layer(
211
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
212
+ )
213
+ self.layer4 = self._make_layer(
214
+ block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
215
+ )
216
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # FIXME
217
+ self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
218
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
219
+
220
+ for m in self.modules():
221
+ if isinstance(m, nn.Conv2d):
222
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
223
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
224
+ nn.init.constant_(m.weight, 1)
225
+ nn.init.constant_(m.bias, 0)
226
+
227
+ # Zero-initialize the last BN in each residual branch,
228
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
229
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
230
+ if zero_init_residual:
231
+ for m in self.modules():
232
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
233
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
234
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
235
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
236
+
237
+ def _make_layer(
238
+ self,
239
+ block: Type[Union[BasicBlock, Bottleneck]],
240
+ planes: int,
241
+ blocks: int,
242
+ stride: int = 1,
243
+ dilate: bool = False,
244
+ ) -> nn.Sequential:
245
+ norm_layer = self._norm_layer
246
+ downsample = None
247
+ previous_dilation = self.dilation
248
+ if dilate:
249
+ self.dilation *= stride
250
+ stride = 1
251
+ if stride != 1 or self.inplanes != planes * block.expansion:
252
+ downsample = nn.Sequential(
253
+ conv1x1(self.inplanes, planes * block.expansion, stride),
254
+ norm_layer(planes * block.expansion),
255
+ )
256
+
257
+ layers = []
258
+ layers.append(
259
+ block(
260
+ self.inplanes,
261
+ planes,
262
+ stride,
263
+ downsample,
264
+ self.groups,
265
+ self.base_width,
266
+ previous_dilation,
267
+ norm_layer,
268
+ )
269
+ )
270
+ self.inplanes = planes * block.expansion
271
+ for _ in range(1, blocks):
272
+ layers.append(
273
+ block(
274
+ self.inplanes,
275
+ planes,
276
+ groups=self.groups,
277
+ base_width=self.base_width,
278
+ dilation=self.dilation,
279
+ norm_layer=norm_layer,
280
+ )
281
+ )
282
+
283
+ return nn.Sequential(*layers)
284
+
285
+ def _forward_impl(self, x: Tensor) -> Tensor:
286
+ # See note [TorchScript super()]
287
+ x = self.conv1(x)
288
+ x = self.bn1(x)
289
+ x = self.relu(x)
290
+ x = self.maxpool(x)
291
+
292
+ x = self.layer1(x)
293
+ x = self.layer2(x)
294
+ x = self.layer3(x)
295
+ x = self.layer4(x)
296
+
297
+ x = self.avgpool(x)
298
+ x = torch.flatten(x, 1)
299
+ x = self.fc(x)
300
+
301
+ return x
302
+
303
+ def forward(self, x: Tensor) -> Tensor:
304
+ return self._forward_impl(x)
305
+
306
+
307
+ def _resnet(
308
+ block: Type[Union[BasicBlock, Bottleneck]],
309
+ layers: List[int],
310
+ weights: Optional[WeightsEnum],
311
+ progress: bool,
312
+ **kwargs: Any,
313
+ ) -> ResNet:
314
+ if weights is not None:
315
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
316
+
317
+ model = ResNet(block, layers, **kwargs)
318
+
319
+ if weights is not None:
320
+ model.load_state_dict(weights.get_state_dict(progress=progress))
321
+
322
+ return model
323
+
324
+
325
+ _COMMON_META = {
326
+ "min_size": (1, 1),
327
+ "categories": _IMAGENET_CATEGORIES,
328
+ }
329
+
330
+
331
+ class ResNet18_Weights(WeightsEnum):
332
+ IMAGENET1K_V1 = Weights(
333
+ url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
334
+ transforms=partial(ImageClassification, crop_size=224),
335
+ meta={
336
+ **_COMMON_META,
337
+ "num_params": 11689512,
338
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
339
+ "_metrics": {
340
+ "ImageNet-1K": {
341
+ "acc@1": 69.758,
342
+ "acc@5": 89.078,
343
+ }
344
+ },
345
+ "_ops": 1.814,
346
+ "_file_size": 44.661,
347
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
348
+ },
349
+ )
350
+ DEFAULT = IMAGENET1K_V1
351
+
352
+
353
+ class ResNet34_Weights(WeightsEnum):
354
+ IMAGENET1K_V1 = Weights(
355
+ url="https://download.pytorch.org/models/resnet34-b627a593.pth",
356
+ transforms=partial(ImageClassification, crop_size=224),
357
+ meta={
358
+ **_COMMON_META,
359
+ "num_params": 21797672,
360
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
361
+ "_metrics": {
362
+ "ImageNet-1K": {
363
+ "acc@1": 73.314,
364
+ "acc@5": 91.420,
365
+ }
366
+ },
367
+ "_ops": 3.664,
368
+ "_file_size": 83.275,
369
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
370
+ },
371
+ )
372
+ DEFAULT = IMAGENET1K_V1
373
+
374
+
375
+ class ResNet50_Weights(WeightsEnum):
376
+ IMAGENET1K_V1 = Weights(
377
+ url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
378
+ transforms=partial(ImageClassification, crop_size=224),
379
+ meta={
380
+ **_COMMON_META,
381
+ "num_params": 25557032,
382
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
383
+ "_metrics": {
384
+ "ImageNet-1K": {
385
+ "acc@1": 76.130,
386
+ "acc@5": 92.862,
387
+ }
388
+ },
389
+ "_ops": 4.089,
390
+ "_file_size": 97.781,
391
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
392
+ },
393
+ )
394
+ IMAGENET1K_V2 = Weights(
395
+ url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
396
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
397
+ meta={
398
+ **_COMMON_META,
399
+ "num_params": 25557032,
400
+ "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
401
+ "_metrics": {
402
+ "ImageNet-1K": {
403
+ "acc@1": 80.858,
404
+ "acc@5": 95.434,
405
+ }
406
+ },
407
+ "_ops": 4.089,
408
+ "_file_size": 97.79,
409
+ "_docs": """
410
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
411
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
412
+ """,
413
+ },
414
+ )
415
+ DEFAULT = IMAGENET1K_V2
416
+
417
+
418
+ class ResNet101_Weights(WeightsEnum):
419
+ IMAGENET1K_V1 = Weights(
420
+ url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
421
+ transforms=partial(ImageClassification, crop_size=224),
422
+ meta={
423
+ **_COMMON_META,
424
+ "num_params": 44549160,
425
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
426
+ "_metrics": {
427
+ "ImageNet-1K": {
428
+ "acc@1": 77.374,
429
+ "acc@5": 93.546,
430
+ }
431
+ },
432
+ "_ops": 7.801,
433
+ "_file_size": 170.511,
434
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
435
+ },
436
+ )
437
+ IMAGENET1K_V2 = Weights(
438
+ url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
439
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
440
+ meta={
441
+ **_COMMON_META,
442
+ "num_params": 44549160,
443
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
444
+ "_metrics": {
445
+ "ImageNet-1K": {
446
+ "acc@1": 81.886,
447
+ "acc@5": 95.780,
448
+ }
449
+ },
450
+ "_ops": 7.801,
451
+ "_file_size": 170.53,
452
+ "_docs": """
453
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
454
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
455
+ """,
456
+ },
457
+ )
458
+ DEFAULT = IMAGENET1K_V2
459
+
460
+
461
+ class ResNet152_Weights(WeightsEnum):
462
+ IMAGENET1K_V1 = Weights(
463
+ url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
464
+ transforms=partial(ImageClassification, crop_size=224),
465
+ meta={
466
+ **_COMMON_META,
467
+ "num_params": 60192808,
468
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
469
+ "_metrics": {
470
+ "ImageNet-1K": {
471
+ "acc@1": 78.312,
472
+ "acc@5": 94.046,
473
+ }
474
+ },
475
+ "_ops": 11.514,
476
+ "_file_size": 230.434,
477
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
478
+ },
479
+ )
480
+ IMAGENET1K_V2 = Weights(
481
+ url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
482
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
483
+ meta={
484
+ **_COMMON_META,
485
+ "num_params": 60192808,
486
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
487
+ "_metrics": {
488
+ "ImageNet-1K": {
489
+ "acc@1": 82.284,
490
+ "acc@5": 96.002,
491
+ }
492
+ },
493
+ "_ops": 11.514,
494
+ "_file_size": 230.474,
495
+ "_docs": """
496
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
497
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
498
+ """,
499
+ },
500
+ )
501
+ DEFAULT = IMAGENET1K_V2
502
+
503
+
504
+ class ResNeXt50_32X4D_Weights(WeightsEnum):
505
+ IMAGENET1K_V1 = Weights(
506
+ url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
507
+ transforms=partial(ImageClassification, crop_size=224),
508
+ meta={
509
+ **_COMMON_META,
510
+ "num_params": 25028904,
511
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
512
+ "_metrics": {
513
+ "ImageNet-1K": {
514
+ "acc@1": 77.618,
515
+ "acc@5": 93.698,
516
+ }
517
+ },
518
+ "_ops": 4.23,
519
+ "_file_size": 95.789,
520
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
521
+ },
522
+ )
523
+ IMAGENET1K_V2 = Weights(
524
+ url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
525
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
526
+ meta={
527
+ **_COMMON_META,
528
+ "num_params": 25028904,
529
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
530
+ "_metrics": {
531
+ "ImageNet-1K": {
532
+ "acc@1": 81.198,
533
+ "acc@5": 95.340,
534
+ }
535
+ },
536
+ "_ops": 4.23,
537
+ "_file_size": 95.833,
538
+ "_docs": """
539
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
540
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
541
+ """,
542
+ },
543
+ )
544
+ DEFAULT = IMAGENET1K_V2
545
+
546
+
547
+ class ResNeXt101_32X8D_Weights(WeightsEnum):
548
+ IMAGENET1K_V1 = Weights(
549
+ url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
550
+ transforms=partial(ImageClassification, crop_size=224),
551
+ meta={
552
+ **_COMMON_META,
553
+ "num_params": 88791336,
554
+ "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
555
+ "_metrics": {
556
+ "ImageNet-1K": {
557
+ "acc@1": 79.312,
558
+ "acc@5": 94.526,
559
+ }
560
+ },
561
+ "_ops": 16.414,
562
+ "_file_size": 339.586,
563
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
564
+ },
565
+ )
566
+ IMAGENET1K_V2 = Weights(
567
+ url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
568
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
569
+ meta={
570
+ **_COMMON_META,
571
+ "num_params": 88791336,
572
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
573
+ "_metrics": {
574
+ "ImageNet-1K": {
575
+ "acc@1": 82.834,
576
+ "acc@5": 96.228,
577
+ }
578
+ },
579
+ "_ops": 16.414,
580
+ "_file_size": 339.673,
581
+ "_docs": """
582
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
583
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
584
+ """,
585
+ },
586
+ )
587
+ DEFAULT = IMAGENET1K_V2
588
+
589
+
590
+ class ResNeXt101_64X4D_Weights(WeightsEnum):
591
+ IMAGENET1K_V1 = Weights(
592
+ url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
593
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
594
+ meta={
595
+ **_COMMON_META,
596
+ "num_params": 83455272,
597
+ "recipe": "https://github.com/pytorch/vision/pull/5935",
598
+ "_metrics": {
599
+ "ImageNet-1K": {
600
+ "acc@1": 83.246,
601
+ "acc@5": 96.454,
602
+ }
603
+ },
604
+ "_ops": 15.46,
605
+ "_file_size": 319.318,
606
+ "_docs": """
607
+ These weights were trained from scratch by using TorchVision's `new training recipe
608
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
609
+ """,
610
+ },
611
+ )
612
+ DEFAULT = IMAGENET1K_V1
613
+
614
+
615
+ class Wide_ResNet50_2_Weights(WeightsEnum):
616
+ IMAGENET1K_V1 = Weights(
617
+ url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
618
+ transforms=partial(ImageClassification, crop_size=224),
619
+ meta={
620
+ **_COMMON_META,
621
+ "num_params": 68883240,
622
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
623
+ "_metrics": {
624
+ "ImageNet-1K": {
625
+ "acc@1": 78.468,
626
+ "acc@5": 94.086,
627
+ }
628
+ },
629
+ "_ops": 11.398,
630
+ "_file_size": 131.82,
631
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
632
+ },
633
+ )
634
+ IMAGENET1K_V2 = Weights(
635
+ url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
636
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
637
+ meta={
638
+ **_COMMON_META,
639
+ "num_params": 68883240,
640
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
641
+ "_metrics": {
642
+ "ImageNet-1K": {
643
+ "acc@1": 81.602,
644
+ "acc@5": 95.758,
645
+ }
646
+ },
647
+ "_ops": 11.398,
648
+ "_file_size": 263.124,
649
+ "_docs": """
650
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
651
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
652
+ """,
653
+ },
654
+ )
655
+ DEFAULT = IMAGENET1K_V2
656
+
657
+
658
+ class Wide_ResNet101_2_Weights(WeightsEnum):
659
+ IMAGENET1K_V1 = Weights(
660
+ url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
661
+ transforms=partial(ImageClassification, crop_size=224),
662
+ meta={
663
+ **_COMMON_META,
664
+ "num_params": 126886696,
665
+ "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
666
+ "_metrics": {
667
+ "ImageNet-1K": {
668
+ "acc@1": 78.848,
669
+ "acc@5": 94.284,
670
+ }
671
+ },
672
+ "_ops": 22.753,
673
+ "_file_size": 242.896,
674
+ "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
675
+ },
676
+ )
677
+ IMAGENET1K_V2 = Weights(
678
+ url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
679
+ transforms=partial(ImageClassification, crop_size=224, resize_size=232),
680
+ meta={
681
+ **_COMMON_META,
682
+ "num_params": 126886696,
683
+ "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
684
+ "_metrics": {
685
+ "ImageNet-1K": {
686
+ "acc@1": 82.510,
687
+ "acc@5": 96.020,
688
+ }
689
+ },
690
+ "_ops": 22.753,
691
+ "_file_size": 484.747,
692
+ "_docs": """
693
+ These weights improve upon the results of the original paper by using TorchVision's `new training recipe
694
+ <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
695
+ """,
696
+ },
697
+ )
698
+ DEFAULT = IMAGENET1K_V2
699
+
700
+
701
+ @register_model()
702
+ @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
703
+ def resnet18_custom(
704
+ *, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any
705
+ ) -> ResNet:
706
+ """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
707
+
708
+ Args:
709
+ weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
710
+ pre-trained weights to use. See
711
+ :class:`~torchvision.models.ResNet18_Weights` below for
712
+ more details, and possible values. By default, no pre-trained
713
+ weights are used.
714
+ progress (bool, optional): If True, displays a progress bar of the
715
+ download to stderr. Default is True.
716
+ **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
717
+ base class. Please refer to the `source code
718
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
719
+ for more details about this class.
720
+
721
+ .. autoclass:: torchvision.models.ResNet18_Weights
722
+ :members:
723
+ """
724
+ weights = ResNet18_Weights.verify(weights)
725
+
726
+ return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ concrete-ml