ssl_watermarking / utils.py
Pierre Fernandez
added encoding and decoding
9e6cbab
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from scipy.optimize import root_scalar
from scipy.special import betainc
# Default torch device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def build_backbone(path, name='resnet50'):
    """
    Builds a torchvision backbone and loads its weights from a checkpoint.

    The architecture is looked up by name in ``torchvision.models`` and
    created without pretrained weights. Its ``head`` and ``fc`` attributes
    are replaced by identity layers, then the weights found in the
    checkpoint are loaded non-strictly.

    Args:
        path: location of the checkpoint, passed to ``torch.load``
        name: name of the model constructor in ``torchvision.models``
            (default: ``'resnet50'``)

    Returns:
        The model, with the checkpoint weights loaded.
    """
    model = getattr(models, name)(pretrained=False)

    # Both attribute names are overwritten regardless of the architecture;
    # presumably so the model outputs features whichever name its last
    # layer uses -- TODO confirm.
    model.head = nn.Identity()
    model.fc = nn.Identity()

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # The weights may be nested under one of these keys; when several are
    # present, the last one in the list wins.
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]

    # Remove every occurrence of "module." and then of "backbone." from the
    # parameter names (presumably wrapper prefixes -- TODO confirm). The two
    # passes are kept separate on purpose: merging them could change which
    # value wins when several names collapse onto the same key.
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

    # strict=False: missing or unexpected keys do not raise.
    model.load_state_dict(state_dict, strict=False)
    return model
def get_linear_layer(weight, bias):
    """
    Creates a linear layer (used for feature whitening or centering) whose
    parameters are the given weight and bias.

    Args:
        weight: 2-D tensor; its first dimension is the number of output
            features and its second the number of input features
        bias: tensor used as the bias of the layer

    Returns:
        An ``nn.Linear`` whose weight and bias wrap the given tensors.
    """
    n_out, n_in = weight.shape
    linear = nn.Linear(in_features=n_in, out_features=n_out)
    # Swap the freshly initialised parameters for the provided ones.
    linear.weight, linear.bias = nn.Parameter(weight), nn.Parameter(bias)
    return linear
def load_normalization_layer(path):
    """
    Loads the normalization layer stored in a checkpoint.

    The checkpoint is read with ``torch.load`` and must provide the entries
    'weight' and 'bias'. When the path contains 'whitening' or 'out', both
    tensors are multiplied by D, the size of the second dimension of the
    stored weight; otherwise they are used as stored.

    Args:
        path: path of the checkpoint (a string: it is searched for the
            substrings above)

    Returns:
        The linear layer built by ``get_linear_layer``, moved to ``device``.
    """
    state = torch.load(path, map_location=device)

    if any(tag in path for tag in ('whitening', 'out')):
        # Rescale both parameters by the second dimension of the weight.
        scale = state['weight'].shape[1]
        weight = torch.nn.Parameter(scale * state['weight'])
        bias = torch.nn.Parameter(scale * state['bias'])
    else:
        weight, bias = state['weight'], state['bias']

    layer = get_linear_layer(weight, bias)
    return layer.to(device, non_blocking=True)
class NormLayerWrapper(nn.Module):
    """
    Chains a backbone model and a normalization layer (the head).

    Both sub-modules are switched to evaluation mode when the wrapper is
    constructed.
    """

    def __init__(self, backbone, head):
        super().__init__()
        for module in (backbone, head):
            module.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """ Applies the backbone to x, then the head to the result. """
        return self.head(self.backbone(x))
def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute value of the projection
    between random unit vectors is higher than c

    The value is the regularized incomplete beta function
    I_x((d - k) / 2, k / 2) evaluated at x = 1 - c^2.

    Args:
        c: cosine value (a scalar)
        d: dimension of the features
        k: number of dimensions of the projection (must be > 0)

    Returns:
        The p-value: 1.0 when c is negative, 0.0 when c is above 1, and the
        incomplete beta value otherwise.
    """
    assert k > 0
    a = (d - k) / 2.0
    b = k / 2.0
    if c < 0:
        return 1.0
    if c > 1:
        # A cosine that ends up slightly above 1 through floating-point
        # rounding would make 1 - c**2 negative and betainc return NaN;
        # a projection can never exceed such a value, so the p-value is 0.
        return 0.0
    return betainc(a, b, 1 - c ** 2)
def pvalue_angle(dim, k=1, angle=None, proba=None):
    """
    Returns the angle whose cosine has a p-value equal to proba.

    Solves ``cosine_pvalue(cos(a), dim, k) == proba`` for ``a`` with
    ``scipy.optimize.root_scalar``, searching in the bracket [0, pi/2].

    Args:
        dim: dimension of the features
        k: number of dimensions of the projection
        angle: unused by this function; kept so existing calls still work
        proba: target p-value (required)

    Returns:
        The angle, in radians, found by the solver.

    Raises:
        TypeError: if proba is not given.
    """
    if proba is None:
        # Fail early with a clear message instead of an opaque
        # "unsupported operand" error raised from inside the solver.
        raise TypeError("pvalue_angle() requires the 'proba' argument")

    def residual(theta):
        return cosine_pvalue(np.cos(theta), dim, k) - proba

    solution = root_scalar(residual, x0=0.49*np.pi, bracket=[0, np.pi/2])
    return solution.root