koukyo1994 commited on
Commit
07155e5
·
verified ·
1 Parent(s): 3b36933

upload LFQ implementation

Browse files
Files changed (1) hide show
  1. modeling_lfq_tokenizer.py +616 -27
modeling_lfq_tokenizer.py CHANGED
@@ -4,40 +4,629 @@ Code reference: https://github.com/TencentARC/Open-MAGVIT2
4
  """
5
 
6
 
7
- from transformers import PretrainedConfig
 
8
 
 
 
 
 
 
 
 
9
 
10
- class EncoderDecoderConfig(PretrainedConfig):
11
- model_type = "resnet_encoder_decoder"
12
 
13
- def __init__(self, **kwargs):
14
- super().__init__(**kwargs)
15
- self.ch = kwargs.get("ch", 128)
16
- self.in_channels = kwargs.get("in_channels", 3)
17
- self.out_ch = kwargs.get("out_ch", 3)
18
- self.z_channels = kwargs.get("z_channels", 18)
19
- self.num_res_blocks = kwargs.get("num_res_blocks", 2)
20
- self.ch_mult = kwargs.get("ch_mult", [1, 1, 2, 2, 4])
21
 
 
 
 
22
 
23
- class QuantizerConfig(PretrainedConfig):
24
- model_type = "lfq_quantizer"
25
 
26
- def __init__(self, **kwargs):
27
- super().__init__(**kwargs)
28
- self.dim = kwargs.get("dim", 18)
29
- self.codebook_size = kwargs.get("codebook_size", 262144)
30
- self.batch_maximization_weight = kwargs.get("batch_maximization_weight", 1.0)
31
- self.sample_minimization_weight = kwargs.get("sample_minimization_weight", 1.0)
 
32
 
 
 
 
33
 
34
- class LFQTokenizerConfig(PretrainedConfig):
35
- r"""
36
- This is the configuration class to store the configuration of a :class:`~transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  """
38
- model_type = "lfq_tokenizer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def __init__(self, **kwargs):
41
- super().__init__(**kwargs)
42
- self.encoder_decoder_config = kwargs.get("encoder_decoder_config", EncoderDecoderConfig())
43
- self.quantizer_config = kwargs.get("quantizer_config", QuantizerConfig())
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
 
7
+ from math import log2, ceil
8
+ from collections import namedtuple
9
 
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange, reduce, pack, unpack
14
+ from torch import einsum
15
+ from torch.nn import Module
16
+ from transformers import PreTrainedModel
17
 
18
+ from .configuration_lfq_tokenizer import LFQTokenizerConfig
 
19
 
 
 
 
 
 
 
 
 
20
 
21
+ def swish(x):
22
+ # swish
23
+ return x * torch.sigmoid(x)
24
 
 
 
25
 
26
+ class ResBlock(nn.Module):
27
+ def __init__(self,
28
+ in_filters,
29
+ out_filters,
30
+ use_conv_shortcut = False
31
+ ) -> None:
32
+ super().__init__()
33
 
34
+ self.in_filters = in_filters
35
+ self.out_filters = out_filters
36
+ self.use_conv_shortcut = use_conv_shortcut
37
 
38
+ self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
39
+ self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)
40
+
41
+ self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
42
+ self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
43
+
44
+ if in_filters != out_filters:
45
+ if self.use_conv_shortcut:
46
+ self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
47
+ else:
48
+ self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False)
49
+
50
+
51
+ def forward(self, x, **kwargs):
52
+ residual = x
53
+
54
+ x = self.norm1(x)
55
+ x = swish(x)
56
+ x = self.conv1(x)
57
+ x = self.norm2(x)
58
+ x = swish(x)
59
+ x = self.conv2(x)
60
+ if self.in_filters != self.out_filters:
61
+ if self.use_conv_shortcut:
62
+ residual = self.conv_shortcut(residual)
63
+ else:
64
+ residual = self.nin_shortcut(residual)
65
+
66
+ return x + residual
67
+
68
+ class Encoder(nn.Module):
69
+ def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)):
70
+ super().__init__()
71
+
72
+ self.in_channels = in_channels
73
+ self.z_channels = z_channels
74
+
75
+ self.num_res_blocks = num_res_blocks
76
+ self.num_blocks = len(ch_mult)
77
+
78
+ self.conv_in = nn.Conv2d(in_channels,
79
+ ch,
80
+ kernel_size=(3, 3),
81
+ padding=1,
82
+ bias=False
83
+ )
84
+
85
+ ## construct the model
86
+ self.down = nn.ModuleList()
87
+
88
+ in_ch_mult = (1,)+tuple(ch_mult)
89
+ for i_level in range(self.num_blocks):
90
+ block = nn.ModuleList()
91
+ block_in = ch*in_ch_mult[i_level] #[1, 1, 2, 2, 4]
92
+ block_out = ch*ch_mult[i_level] #[1, 2, 2, 4]
93
+ for _ in range(self.num_res_blocks):
94
+ block.append(ResBlock(block_in, block_out))
95
+ block_in = block_out
96
+
97
+ down = nn.Module()
98
+ down.block = block
99
+ if i_level < self.num_blocks - 1:
100
+ down.downsample = nn.Conv2d(block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
101
+
102
+ self.down.append(down)
103
+
104
+ ### mid
105
+ self.mid_block = nn.ModuleList()
106
+ for res_idx in range(self.num_res_blocks):
107
+ self.mid_block.append(ResBlock(block_in, block_in))
108
+
109
+ ### end
110
+ self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
111
+ self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1))
112
+
113
+ def forward(self, x):
114
+
115
+ ## down
116
+ x = self.conv_in(x)
117
+ for i_level in range(self.num_blocks):
118
+ for i_block in range(self.num_res_blocks):
119
+ x = self.down[i_level].block[i_block](x)
120
+
121
+ if i_level < self.num_blocks - 1:
122
+ x = self.down[i_level].downsample(x)
123
+
124
+ ## mid
125
+ for res in range(self.num_res_blocks):
126
+ x = self.mid_block[res](x)
127
+
128
+
129
+ x = self.norm_out(x)
130
+ x = swish(x)
131
+ x = self.conv_out(x)
132
+
133
+ return x
134
+
135
+
136
+ class Decoder(nn.Module):
137
+ def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)) -> None:
138
+ super().__init__()
139
+
140
+ self.ch = ch
141
+ self.num_blocks = len(ch_mult)
142
+ self.num_res_blocks = num_res_blocks
143
+ self.in_channels = in_channels
144
+
145
+ block_in = ch*ch_mult[self.num_blocks-1]
146
+
147
+ self.conv_in = nn.Conv2d(
148
+ z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True
149
+ )
150
+
151
+ self.mid_block = nn.ModuleList()
152
+ for res_idx in range(self.num_res_blocks):
153
+ self.mid_block.append(ResBlock(block_in, block_in))
154
+
155
+ self.up = nn.ModuleList()
156
+
157
+ for i_level in reversed(range(self.num_blocks)):
158
+ block = nn.ModuleList()
159
+ block_out = ch*ch_mult[i_level]
160
+ for i_block in range(self.num_res_blocks):
161
+ block.append(ResBlock(block_in, block_out))
162
+ block_in = block_out
163
+
164
+ up = nn.Module()
165
+ up.block = block
166
+ if i_level > 0:
167
+ up.upsample = Upsampler(block_in)
168
+ self.up.insert(0, up)
169
+
170
+ self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
171
+
172
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)
173
+
174
+ def forward(self, z):
175
+
176
+ z = self.conv_in(z)
177
+
178
+ ## mid
179
+ for res in range(self.num_res_blocks):
180
+ z = self.mid_block[res](z)
181
+
182
+ ## upsample
183
+ for i_level in reversed(range(self.num_blocks)):
184
+ for i_block in range(self.num_res_blocks):
185
+ z = self.up[i_level].block[i_block](z)
186
+
187
+ if i_level > 0:
188
+ z = self.up[i_level].upsample(z)
189
+
190
+ z = self.norm_out(z)
191
+ z = swish(z)
192
+ z = self.conv_out(z)
193
+
194
+ return z
195
+
196
+
197
+ def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
198
+ """ Depth-to-Space DCR mode (depth-column-row) core implementation.
199
+
200
+ Args:
201
+ x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported.
202
+ block_size (int): block side size
203
+ """
204
+ # check inputs
205
+ if x.dim() < 3:
206
+ raise ValueError(
207
+ f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
208
+ )
209
+ c, h, w = x.shape[-3:]
210
+
211
+ s = block_size**2
212
+ if c % s != 0:
213
+ raise ValueError(
214
+ f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
215
+ )
216
+
217
+ outer_dims = x.shape[:-3]
218
+
219
+ # splitting two additional dimensions from the channel dimension
220
+ x = x.view(-1, block_size, block_size, c // s, h, w)
221
+
222
+ # putting the two new dimensions along H and W
223
+ x = x.permute(0, 3, 4, 1, 5, 2)
224
+
225
+ # merging the two new dimensions with H and W
226
+ x = x.contiguous().view(*outer_dims, c // s, h * block_size,
227
+ w * block_size)
228
+
229
+ return x
230
+
231
+ class Upsampler(nn.Module):
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ dim_out = None
236
+ ):
237
+ super().__init__()
238
+ dim_out = dim * 4
239
+ self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1)
240
+ self.depth2space = depth_to_space
241
+
242
+ def forward(self, x):
243
+ """
244
+ input_image: [B C H W]
245
+ """
246
+ out = self.conv1(x)
247
+ out = self.depth2space(out, block_size=2)
248
+ return out
249
+
250
+
251
+ class AdaptiveGroupNorm(nn.Module):
252
+ def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
253
+ super().__init__()
254
+ self.gn = nn.GroupNorm(num_groups=32, num_channels=in_filters, eps=eps, affine=False)
255
+ # self.lin = nn.Linear(z_channels, in_filters * 2)
256
+ self.gamma = nn.Linear(z_channel, in_filters)
257
+ self.beta = nn.Linear(z_channel, in_filters)
258
+ self.eps = eps
259
+
260
+ def forward(self, x, quantizer):
261
+ B, C, _, _ = x.shape
262
+ # quantizer = F.adaptive_avg_pool2d(quantizer, (1, 1))
263
+ ### calcuate var for scale
264
+ scale = rearrange(quantizer, "b c h w -> b c (h w)")
265
+ scale = scale.var(dim=-1) + self.eps #not unbias
266
+ scale = scale.sqrt()
267
+ scale = self.gamma(scale).view(B, C, 1, 1)
268
+
269
+ ### calculate mean for bias
270
+ bias = rearrange(quantizer, "b c h w -> b c (h w)")
271
+ bias = bias.mean(dim=-1)
272
+ bias = self.beta(bias).view(B, C, 1, 1)
273
+
274
+ x = self.gn(x)
275
+ x = scale * x + bias
276
+
277
+ return x
278
+
279
+
280
+ # constants
281
+
282
+ LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs'])
283
+
284
+ # helper functions
285
+
286
+ def exists(v):
287
+ return v is not None
288
+
289
+ def default(*args):
290
+ for arg in args:
291
+ if exists(arg):
292
+ return arg() if callable(arg) else arg
293
+ return None
294
+
295
+ def pack_one(t, pattern):
296
+ return pack([t], pattern)
297
+
298
+ def unpack_one(t, ps, pattern):
299
+ return unpack(t, ps, pattern)[0]
300
+
301
+ # entropy
302
+
303
+ def entropy(prob):
304
+ return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
305
+
306
+ # class
307
+
308
+ def mult_along_first_dims(x, y):
309
+ """
310
+ returns x * y elementwise along the leading dimensions of y
311
+ """
312
+ ndim_to_expand = x.ndim - y.ndim
313
+ for _ in range(ndim_to_expand):
314
+ y = y.unsqueeze(-1)
315
+ return x * y
316
+
317
+ def masked_mean(x, m):
318
+ """
319
+ takes the mean of the elements of x that are not masked
320
+ the mean is taken along the shared leading dims of m
321
+ equivalent to: x[m].mean(tuple(range(m.ndim)))
322
+
323
+ The benefit of using masked_mean rather than using
324
+ tensor indexing is that masked_mean is much faster
325
+ for torch-compile on batches.
326
+
327
+ The drawback is larger floating point errors
328
+ """
329
+ x = mult_along_first_dims(x, m)
330
+ x = x / m.sum()
331
+ return x.sum(tuple(range(m.ndim)))
332
+
333
+ def entropy_loss(
334
+ logits,
335
+ mask=None,
336
+ temperature=0.01,
337
+ sample_minimization_weight=1.0,
338
+ batch_maximization_weight=1.0,
339
+ eps=1e-5,
340
+ ):
341
+ """
342
+ Entropy loss of unnormalized logits
343
+
344
+ logits: Affinities are over the last dimension
345
+
346
+ https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
347
+ LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)
348
  """
349
+ probs = F.softmax(logits / temperature, -1)
350
+ log_probs = F.log_softmax(logits / temperature + eps, -1)
351
+
352
+ if mask is not None:
353
+ avg_probs = masked_mean(probs, mask)
354
+ else:
355
+ avg_probs = reduce(probs, "... D -> D", "mean")
356
+
357
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))
358
+
359
+ sample_entropy = -torch.sum(probs * log_probs, -1)
360
+ if mask is not None:
361
+ sample_entropy = masked_mean(sample_entropy, mask).mean()
362
+ else:
363
+ sample_entropy = torch.mean(sample_entropy)
364
+
365
+ loss = (sample_minimization_weight * sample_entropy) - (
366
+ batch_maximization_weight * avg_entropy
367
+ )
368
+
369
+ return sample_entropy, avg_entropy, loss
370
+
371
+
372
+ class LFQ(Module):
373
+ def __init__(
374
+ self,
375
+ *,
376
+ dim = None,
377
+ codebook_size = None,
378
+ num_codebooks = 1,
379
+ sample_minimization_weight=1.0,
380
+ batch_maximization_weight=1.0,
381
+ token_factorization = False,
382
+ ):
383
+ super().__init__()
384
+
385
+ # some assert validations
386
+
387
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
388
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
389
+
390
+ self.codebook_size = default(codebook_size, lambda: 2 ** dim)
391
+ self.codebook_dim = int(log2(codebook_size))
392
+
393
+ codebook_dims = self.codebook_dim * num_codebooks
394
+ dim = default(dim, codebook_dims)
395
+
396
+ has_projections = dim != codebook_dims
397
+ self.has_projections = has_projections
398
+
399
+ self.dim = dim
400
+ self.codebook_dim = self.codebook_dim
401
+ self.num_codebooks = num_codebooks
402
+
403
+ # for entropy loss
404
+ self.sample_minimization_weight = sample_minimization_weight
405
+ self.batch_maximization_weight = batch_maximization_weight
406
+
407
+ # for no auxiliary loss, during inference
408
+ self.token_factorization = token_factorization ## only utilized in second stage
409
+ if not self.token_factorization: #for first stage model
410
+ self.register_buffer('mask', 2 ** torch.arange(self.codebook_dim - 1, -1, -1), persistent=False)
411
+ else:
412
+ k = self.codebook_dim // 2
413
+ self.register_buffer("mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False)
414
+
415
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
416
+
417
+ # codes
418
+ all_codes = torch.arange(codebook_size)
419
+ bits = self.indices_to_bits(all_codes)
420
+ codebook = bits * 2.0 - 1.0
421
+
422
+ self.register_buffer('codebook', codebook, persistent = False)
423
+
424
+ @property
425
+ def dtype(self):
426
+ return self.codebook.dtype
427
+
428
+ def indices_to_bits(self, x):
429
+ """
430
+ x: long tensor of indices for constructing codebook, but actually not utilized in all the experiments.
431
+
432
+ returns big endian bits
433
+ """
434
+ mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
435
+ # x is now big endian bits, the last dimension being the bits
436
+ x = (x.unsqueeze(-1) & mask) != 0
437
+ return x
438
+
439
+ def get_codebook_entry(self, x, bhwc):
440
+ if self.token_factorization:
441
+ k = self.codebook_dim // 2
442
+ mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long)
443
+ else:
444
+ mask = 2 ** torch.arange(self.codebook_dim-1, -1, -1, device=x.device, dtype=torch.long)
445
+
446
+ x = (x.unsqueeze(-1) & mask) != 0
447
+ x = x * 2.0 - 1.0 #back to the float
448
+ ## scale back to the desired shape
449
+ b, h, w, c = bhwc
450
+ x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
451
+ x = rearrange(x, "b h w c -> b c h w")
452
+ return x
453
+
454
+ def bits_to_indices(self, bits):
455
+ """
456
+ bits: bool tensor of big endian bits, where the last dimension is the bit dimension
457
+
458
+ returns indices, which are long integers from 0 to self.codebook_size
459
+ """
460
+ assert bits.shape[-1] == self.codebook_dim
461
+ indices = 2 ** torch.arange(
462
+ 0,
463
+ self.codebook_dim,
464
+ 1,
465
+ dtype=torch.long,
466
+ device=bits.device,
467
+ )
468
+ return (bits * indices).sum(-1)
469
+
470
+ def decode(self, x):
471
+ """
472
+ x: ... NH
473
+ where NH is number of codebook heads
474
+ A longtensor of codebook indices, containing values from
475
+ 0 to self.codebook_size
476
+ """
477
+ x = self.indices_to_bits(x)
478
+ # to some sort of float
479
+ x = x.to(self.dtype)
480
+ # -1 or 1
481
+ x = x * 2 - 1
482
+ x = rearrange(x, "... NC Z-> ... (NC Z)")
483
+ return x
484
+
485
+ def forward(
486
+ self,
487
+ x,
488
+ return_loss_breakdown = False,
489
+ mask = None,
490
+ return_loss = True,
491
+ ):
492
+ """
493
+ einstein notation
494
+ b - batch
495
+ n - sequence (or flattened spatial dimensions)
496
+ d - feature dimension, which is also log2(codebook size)
497
+ c - number of codebook dim
498
+ """
499
+
500
+
501
+ x = rearrange(x, 'b d ... -> b ... d')
502
+ x, ps = pack_one(x, 'b * d')
503
+ # split out number of codebooks
504
+
505
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
506
+
507
+
508
+ codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
509
+ quantized = torch.where(x > 0, codebook_value, -codebook_value) # higher than 0 filled
510
+
511
+ # calculate indices
512
+ if self.token_factorization:
513
+ k = self.codebook_dim // 2
514
+ indices_pre = reduce((quantized[..., :k] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
515
+ indices_post = reduce((quantized[..., k:] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
516
+ # indices_post = 2**k + indices_post #shifter to the 1024
517
+ else:
518
+ indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
519
+
520
+ # entropy aux loss
521
+
522
+ if self.training and return_loss:
523
+ logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
524
+ # the same as euclidean distance up to a constant
525
+ per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
526
+ logits = logits,
527
+ sample_minimization_weight = self.sample_minimization_weight,
528
+ batch_maximization_weight = self.batch_maximization_weight
529
+ )
530
+
531
+ avg_probs = self.zero
532
+ else:
533
+ ## calculate the codebook_entropy needed for one batch evaluation
534
+ #------------------------------------------------------------------
535
+ # logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)
536
+ # probs = F.softmax(logits / 0.01, -1)
537
+ # avg_probs = reduce(probs, "b n c d -> b d", "mean")
538
+ # avg_probs = torch.sum(avg_probs, 0) #batch dimension
539
+ #-------------------------------------------------------------------
540
+ # if not training, just return dummy 0
541
+ per_sample_entropy = codebook_entropy = self.zero
542
+ entropy_aux_loss = self.zero
543
+ avg_probs = self.zero
544
+
545
+ # commit loss
546
+
547
+ if self.training:
548
+ commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')
549
+
550
+ if exists(mask):
551
+ commit_loss = commit_loss[mask]
552
+
553
+ commit_loss = commit_loss.mean()
554
+ else:
555
+ commit_loss = self.zero
556
+
557
+
558
+ # use straight-through gradients (optionally with custom activation fn) if training
559
+
560
+ quantized = x + (quantized - x).detach() #transfer to quantized
561
+
562
+ # merge back codebook dim
563
+
564
+ quantized = rearrange(quantized, 'b n c d -> b n (c d)')
565
+
566
+ # reconstitute image or video dimensions
567
+
568
+ quantized = unpack_one(quantized, ps, 'b * d')
569
+ quantized = rearrange(quantized, 'b ... d -> b d ...')
570
+
571
+
572
+ if self.token_factorization:
573
+ indices_pre = unpack_one(indices_pre, ps, "b * c")
574
+ indices_post = unpack_one(indices_post, ps, "b * c")
575
+ indices_pre = indices_pre.flatten()
576
+ indices_post = indices_post.flatten()
577
+ indices = (indices_pre, indices_post)
578
+ else:
579
+ indices = unpack_one(indices, ps, 'b * c')
580
+ indices = indices.flatten()
581
+
582
+ ret = (quantized, entropy_aux_loss, indices)
583
+
584
+ if not return_loss_breakdown:
585
+ return ret
586
+
587
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)
588
+
589
+
590
+ class LFQTokenizer(PreTrainedModel):
591
+ config_class = LFQTokenizerConfig
592
+
593
+ def __init__(self, config: LFQTokenizerConfig):
594
+ super().__init__(config)
595
+
596
+ self.encoder = Encoder(**config.encoder_decoder_config)
597
+ self.decoder = Decoder(**config.encoder_decoder_config)
598
+ self.quantize = LFQ(**config.quantizer_config)
599
+
600
+ def encode(self, x):
601
+ h = self.encoder(x)
602
+ (quant, emb_loss, info), loss_breakdown = self.quantize(h, return_loss_breakdown=True)
603
+ return quant, emb_loss, info, loss_breakdown
604
+
605
+ def decode(self, quant):
606
+ return self.decoder(quant)
607
+
608
+ def forward(self, input):
609
+ quant, diff, _, loss_breakdown = self.encode(input)
610
+ dec = self.decoder(quant)
611
+ return dec, diff, loss_breakdown
612
+
613
+ def tokenize(self, input):
614
+ _, _, tokens, _ = self.encode(input)
615
+ return tokens
616
+
617
+ def get_last_layer(self):
618
+ return self.decoder.conv_out.weight
619
 
620
+ def decode_tokens(self, tokens, shape: tuple):
621
+ if self.quantize.token_factorization:
622
+ tokens_pre, tokens_post = tokens[0], tokens[1]
623
+ quant_pre = self.quantize.get_codebook_entry(tokens_pre, shape)
624
+ quant_post = self.quantize.get_codebook_entry(tokens_post, shape)
625
+ quant = torch.concat([quant_pre, quant_post], dim=1)
626
+ return self.decode(quant)
627
+ else:
628
+ if tokens.ndim == 1:
629
+ batch_size = shape[0]
630
+ tokens = tokens.view(batch_size, -1)
631
+ quant = self.quantize.get_codebook_entry(tokens, shape)
632
+ return self.decode(quant)