Spaces:
Running
on
Zero
Running
on
Zero
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
def zero_module(module: nn.Module) -> nn.Module:
    """Zero every parameter of ``module`` in place and return the same module.

    All tensors yielded by ``module.parameters()`` (weights and biases alike,
    including those of submodules) are overwritten with zeros. The module is
    returned rather than copied, so the call can wrap a constructor inline.

    Args:
        module: The module whose parameters are reset.

    Returns:
        The very same ``module`` object, now with all-zero parameters.
    """
    for param in module.parameters():
        # ``zeros_`` fills in place without recording the write in autograd.
        nn.init.zeros_(param)
    return module
class PoseGuider(ModelMixin):
    """Convolutional encoder that turns a conditioning image into patch tokens.

    The network is a stem convolution followed by one stage per consecutive
    pair in ``block_out_channels``. Each stage is a 3x3 convolution that keeps
    the channel count and a 3x3 stride-2 convolution that moves to the next
    channel count and halves the spatial size (rounding up). Every convolution
    is followed by SiLU. The final feature map is cut into non-overlapping
    2x2 patches, each patch is flattened into one token, and a linear layer
    projects every token to ``conditioning_embedding_channels``.

    The output projection is created through ``zero_module``, so a freshly
    constructed model returns an all-zero tensor until it has been trained.

    Args:
        conditioning_embedding_channels: Width of each output token.
        conditioning_channels: Number of channels in the input image.
        block_out_channels: Channel count of the stem output followed by the
            channel count after each downsampling stage. Must contain at
            least one entry.
    """

    # Side length of the square patch folded into each token. It fixes both
    # the input width of ``self.out`` and the patch split in ``forward``.
    _PATCH_SIZE = 2

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        super().__init__()

        # Stem: lift the input image to the first feature width, same size.
        self.conv_in = nn.Conv2d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        # One (same-size conv, stride-2 conv) pair per consecutive channel pair.
        self.blocks = nn.ModuleList([])
        for channel_in, channel_out in zip(block_out_channels, block_out_channels[1:]):
            self.blocks.append(
                nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                nn.Conv2d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        # Each token carries all channels of one _PATCH_SIZE x _PATCH_SIZE patch.
        self.out = zero_module(
            nn.Linear(
                block_out_channels[-1] * self._PATCH_SIZE**2,
                conditioning_embedding_channels,
            )
        )

    def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
        """Encode a batch of conditioning images into a token sequence.

        Args:
            conditioning: Tensor laid out as (batch, channels, height, width)
                with ``conditioning_channels`` channels. After all stride-2
                convolutions, height and width must still be divisible by 2,
                otherwise the patch split in ``rearrange`` fails.

        Returns:
            Tensor of shape (batch, h * w, conditioning_embedding_channels),
            where ``h`` and ``w`` are the downsampled height and width
            divided by 2.
        """
        embedding = F.silu(self.conv_in(conditioning))

        for block in self.blocks:
            embedding = F.silu(block(embedding))

        # Fold every 2x2 spatial patch into the channel axis: one token per patch.
        embedding = rearrange(
            embedding,
            "b c (h ph) (w pw) -> b (h w) (c ph pw)",
            ph=self._PATCH_SIZE,
            pw=self._PATCH_SIZE,
        )
        return self.out(embedding)
if __name__ == "__main__":
    # Smoke test: build the model with its default channel layout and push one
    # random batch through it.
    import torch

    model = PoseGuider(
        conditioning_embedding_channels=3072, block_out_channels=(16, 32, 96, 256)
    )
    inp = torch.randn((4, 3, 1024, 768))

    # Three stride-2 stages take 1024x768 down to 128x96; the 2x2 patch fold
    # then yields 64 * 48 = 3072 tokens, so ``out`` has shape (4, 3072, 3072).
    # Nothing is trained here, so skip autograd bookkeeping for this large input.
    with torch.no_grad():
        out = model(inp)