Text-to-Speech
ONNX
English
hexgrad commited on
Commit
3767727
·
verified ·
1 Parent(s): 4519644

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ TTS-Spaces-Arena-25-Dec-2024.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,135 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ base_model:
6
+ - yl4579/StyleTTS2-LJSpeech
7
+ ---
8
+ **Kokoro** is a frontier TTS model for its size of **82 million parameters** (text in/audio out).
9
+
10
+ On 25 Dec 2024, Kokoro v0.19 weights were permissively released in full fp32 precision along with 2 voicepacks (Bella and Sarah), all under an Apache 2.0 license.
11
+
12
+ At the time of release, Kokoro v0.19 was the #1🥇 ranked model in [TTS Spaces Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena). With 82M params trained for <20 epochs on <100 total hours of audio, Kokoro achieved higher Elo in this single-voice Arena setting over models such as:
13
+ - XTTS v2: 467M, CPML, >10k hours
14
+ - Edge TTS: Microsoft, proprietary
15
+ - MetaVoice: 1.2B, Apache, 100k hours
16
+ - Parler Mini: 880M, Apache, 45k hours
17
+ - Fish Speech: ~500M, CC-BY-NC-SA, 1M hours
18
+
19
+ Kokoro's ability to top this Elo ladder using relatively low compute and data suggests that the scaling law for traditional TTS models might have a steeper slope than previously expected.
20
+
21
+ You can find a hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
22
+
23
+ ### Usage
24
+
25
+ The following can be run in a single cell on [Google Colab](https://colab.research.google.com/).
26
+ ```py
27
+ # 1️⃣ Install dependencies silently
28
+ !git clone https://huggingface.co/hexgrad/Kokoro-82M
29
+ %cd Kokoro-82M
30
+ !apt-get -qq -y install espeak-ng > /dev/null 2>&1
31
+ !pip install -q phonemizer torch transformers scipy munch
32
+
33
+ # 2️⃣ Build the model and load the default voicepack
34
+ from models import build_model
35
+ import torch
36
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
37
+ MODEL = build_model('kokoro-v0_19.pth', device)
38
+ VOICEPACK = torch.load('voices/af.pt', weights_only=True).to(device)
39
+
40
+ # 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes
41
+ from kokoro import generate
42
+ text = "How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born."
43
+ audio, out_ps = generate(MODEL, text, VOICEPACK)
44
+
45
+ # 4️⃣ Display the 24khz audio and print the output phonemes
46
+ from IPython.display import display, Audio
47
+ display(Audio(data=audio, rate=24000, autoplay=True))
48
+ print(out_ps)
49
+ ```
50
+ This inference code was quickly hacked together on Christmas Day. It is not clean code and leaves a lot of room for improvement. If you'd like to contribute, feel free to open a PR.
51
+
52
+ ### Model Description
53
+
54
+ No affiliation can be assumed between parties on different lines.
55
+
56
+ **Architecture:**
57
+ - StyleTTS 2: https://arxiv.org/abs/2306.07691
58
+ - ISTFTNet: https://arxiv.org/abs/2203.02395
59
+ - Decoder only: no diffusion, no encoder release
60
+
61
+ **Architected by:** Li et al @ https://github.com/yl4579/StyleTTS2
62
+
63
+ **Trained by**: `@rzvzn` on Discord
64
+
65
+ **Supported Languages:** English
66
+
67
+ **Model SHA256 Hash:** `3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a`
68
+
69
+ **Model Release Date:**
70
+ - v0.19, Bella, Sarah: 25 Dec 2024
71
+
72
+ **Licenses:**
73
+ - Apache 2.0 weights in this repository
74
+ - MIT inference code in [spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS) adapted from [yl4579/StyleTTS2](https://github.com/yl4579/StyleTTS2)
75
+ - GPLv3 dependency in [espeak-ng](https://github.com/espeak-ng/espeak-ng)
76
+
77
+ The inference code was originally MIT licensed by the paper author. Note that this card applies only to this model, Kokoro. Original models published by the paper author can be found at [hf.co/yl4579](https://huggingface.co/yl4579).
78
+
79
+ ### Evaluation
80
+
81
+ **Metric:** Elo rating
82
+
83
+ **Leaderboard:** [hf.co/spaces/Pendrokar/TTS-Spaces-Arena](https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena)
84
+
85
+ ![TTS-Spaces-Arena-25-Dec-2024](TTS-Spaces-Arena-25-Dec-2024.png)
86
+
87
+ The voice ranked in the Arena is a 50-50 mix of Bella and Sarah. For your convenience, this mix is included in this repository as `af.pt`, but you can trivially reproduce it like this:
88
+
89
+ ```py
90
+ import torch
91
+ bella = torch.load('voices/af_bella.pt', weights_only=True)
92
+ sarah = torch.load('voices/af_sarah.pt', weights_only=True)
93
+ af = torch.mean(torch.stack([bella, sarah]), dim=0)
94
+ assert torch.equal(af, torch.load('voices/af.pt', weights_only=True))
95
+ ```
96
+
97
+ ### Training Details
98
+
99
+ **Compute:** Kokoro was trained on A100 80GB vRAM instances rented from [Vast.ai](https://cloud.vast.ai/?ref_id=79907) (referral link). Vast was chosen over other compute providers due to its competitive on-demand hourly rates. The average hourly cost for the A100 80GB vRAM instances used for training was below $1/hr per GPU, which was around half the quoted rates from other providers at the time.
100
+
101
+ **Data:** Kokoro was trained exclusively on **permissive/non-copyrighted audio data** and IPA phoneme labels. Examples of permissive/non-copyrighted audio include:
102
+ - Public domain audio
103
+ - Audio licensed under Apache, MIT, etc
104
+ - Synthetic audio<sup>[1]</sup> generated by closed<sup>[2]</sup> TTS models from large providers<br/>
105
+ [1] https://copyright.gov/ai/ai_policy_guidance.pdf<br/>
106
+ [2] No synthetic audio from open TTS models or "custom voice clones"
107
+
108
+ **Epochs:** Less than **20 epochs**
109
+
110
+ **Total Dataset Size:** Less than **100 hours** of audio
111
+
112
+ ### Limitations
113
+
114
+ Kokoro v0.19 is limited in some ways, in its training set and architecture:
115
+ - [Data] Lacks voice cloning capability, likely due to small <100h training set
116
+ - [Arch] Relies on external g2p (espeak-ng), which introduces a class of g2p failure modes
117
+ - [Data] Training dataset is mostly long-form reading and narration, not conversation
118
+ - [Arch] At 82M params, Kokoro almost certainly falls to a well-trained 1B+ param diffusion transformer, or a many-billion-param MLLM like GPT-4o / Gemini 2.0 Flash
119
+ - [Data] Multilingual capability is architecturally feasible, but training data is almost entirely English
120
+
121
+ **Will the other voicepacks be released?** There is currently no release date scheduled for the other voicepacks, but in the meantime you can try them in the hosted demo at [hf.co/spaces/hexgrad/Kokoro-TTS](https://huggingface.co/spaces/hexgrad/Kokoro-TTS).
122
+
123
+ ### Acknowledgements
124
+ - [@yl4579](https://huggingface.co/yl4579) for architecting StyleTTS 2
125
+ - [@Pendrokar](https://huggingface.co/Pendrokar) for adding Kokoro as a contender in the TTS Spaces Arena
126
+
127
+ ### Model Card Contact
128
+
129
+ `@rzvzn` on Discord
130
+
131
+ ```py
132
+ # TODO: Add Discord server
133
+ ```
134
+
135
+ <img src="https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/08/terminator-zero-41-1.jpg" width="400" alt="kokoro" />
TTS-Spaces-Arena-25-Dec-2024.png ADDED

Git LFS Details

  • SHA256: e78b5ec1557323fa0e62681c83f6b81777f9834b91bbf26bf7567b036f011d52
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
istftnet.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
2
+ from scipy.signal import get_window
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn.utils import weight_norm, remove_weight_norm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size*dilation - dilation)/2)
18
+
19
+ LRELU_SLOPE = 0.1
20
+
21
+ class AdaIN1d(nn.Module):
22
+ def __init__(self, style_dim, num_features):
23
+ super().__init__()
24
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
25
+ self.fc = nn.Linear(style_dim, num_features*2)
26
+
27
+ def forward(self, x, s):
28
+ h = self.fc(s)
29
+ h = h.view(h.size(0), h.size(1), 1)
30
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
31
+ return (1 + gamma) * self.norm(x) + beta
32
+
33
+ class AdaINResBlock1(torch.nn.Module):
34
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
35
+ super(AdaINResBlock1, self).__init__()
36
+ self.convs1 = nn.ModuleList([
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
38
+ padding=get_padding(kernel_size, dilation[0]))),
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
40
+ padding=get_padding(kernel_size, dilation[1]))),
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
42
+ padding=get_padding(kernel_size, dilation[2])))
43
+ ])
44
+ self.convs1.apply(init_weights)
45
+
46
+ self.convs2 = nn.ModuleList([
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1))),
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1))),
51
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
+ padding=get_padding(kernel_size, 1)))
53
+ ])
54
+ self.convs2.apply(init_weights)
55
+
56
+ self.adain1 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.adain2 = nn.ModuleList([
63
+ AdaIN1d(style_dim, channels),
64
+ AdaIN1d(style_dim, channels),
65
+ AdaIN1d(style_dim, channels),
66
+ ])
67
+
68
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
69
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
70
+
71
+
72
+ def forward(self, x, s):
73
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
74
+ xt = n1(x, s)
75
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
76
+ xt = c1(xt)
77
+ xt = n2(xt, s)
78
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
79
+ xt = c2(xt)
80
+ x = xt + x
81
+ return x
82
+
83
+ def remove_weight_norm(self):
84
+ for l in self.convs1:
85
+ remove_weight_norm(l)
86
+ for l in self.convs2:
87
+ remove_weight_norm(l)
88
+
89
+ class TorchSTFT(torch.nn.Module):
90
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
91
+ super().__init__()
92
+ self.filter_length = filter_length
93
+ self.hop_length = hop_length
94
+ self.win_length = win_length
95
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
96
+
97
+ def transform(self, input_data):
98
+ forward_transform = torch.stft(
99
+ input_data,
100
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
101
+ return_complex=True)
102
+
103
+ return torch.abs(forward_transform), torch.angle(forward_transform)
104
+
105
+ def inverse(self, magnitude, phase):
106
+ inverse_transform = torch.istft(
107
+ magnitude * torch.exp(phase * 1j),
108
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
109
+
110
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
111
+
112
+ def forward(self, input_data):
113
+ self.magnitude, self.phase = self.transform(input_data)
114
+ reconstruction = self.inverse(self.magnitude, self.phase)
115
+ return reconstruction
116
+
117
+ class SineGen(torch.nn.Module):
118
+ """ Definition of sine generator
119
+ SineGen(samp_rate, harmonic_num = 0,
120
+ sine_amp = 0.1, noise_std = 0.003,
121
+ voiced_threshold = 0,
122
+ flag_for_pulse=False)
123
+ samp_rate: sampling rate in Hz
124
+ harmonic_num: number of harmonic overtones (default 0)
125
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
126
+ noise_std: std of Gaussian noise (default 0.003)
127
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
128
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
129
+ Note: when flag_for_pulse is True, the first time step of a voiced
130
+ segment is always sin(np.pi) or cos(0)
131
+ """
132
+
133
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
134
+ sine_amp=0.1, noise_std=0.003,
135
+ voiced_threshold=0,
136
+ flag_for_pulse=False):
137
+ super(SineGen, self).__init__()
138
+ self.sine_amp = sine_amp
139
+ self.noise_std = noise_std
140
+ self.harmonic_num = harmonic_num
141
+ self.dim = self.harmonic_num + 1
142
+ self.sampling_rate = samp_rate
143
+ self.voiced_threshold = voiced_threshold
144
+ self.flag_for_pulse = flag_for_pulse
145
+ self.upsample_scale = upsample_scale
146
+
147
+ def _f02uv(self, f0):
148
+ # generate uv signal
149
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
150
+ return uv
151
+
152
+ def _f02sine(self, f0_values):
153
+ """ f0_values: (batchsize, length, dim)
154
+ where dim indicates fundamental tone and overtones
155
+ """
156
+ # convert to F0 in rad. The interger part n can be ignored
157
+ # because 2 * np.pi * n doesn't affect phase
158
+ rad_values = (f0_values / self.sampling_rate) % 1
159
+
160
+ # initial phase noise (no noise for fundamental component)
161
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
162
+ device=f0_values.device)
163
+ rand_ini[:, 0] = 0
164
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
165
+
166
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
167
+ if not self.flag_for_pulse:
168
+ # # for normal case
169
+
170
+ # # To prevent torch.cumsum numerical overflow,
171
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
172
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
173
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
174
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
175
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
176
+ # cumsum_shift = torch.zeros_like(rad_values)
177
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
178
+
179
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
180
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
181
+ scale_factor=1/self.upsample_scale,
182
+ mode="linear").transpose(1, 2)
183
+
184
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
185
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
186
+ # cumsum_shift = torch.zeros_like(rad_values)
187
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
188
+
189
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
190
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
191
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
192
+ sines = torch.sin(phase)
193
+
194
+ else:
195
+ # If necessary, make sure that the first time step of every
196
+ # voiced segments is sin(pi) or cos(0)
197
+ # This is used for pulse-train generation
198
+
199
+ # identify the last time step in unvoiced segments
200
+ uv = self._f02uv(f0_values)
201
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
202
+ uv_1[:, -1, :] = 1
203
+ u_loc = (uv < 1) * (uv_1 > 0)
204
+
205
+ # get the instantanouse phase
206
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
207
+ # different batch needs to be processed differently
208
+ for idx in range(f0_values.shape[0]):
209
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
210
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
211
+ # stores the accumulation of i.phase within
212
+ # each voiced segments
213
+ tmp_cumsum[idx, :, :] = 0
214
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
215
+
216
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
217
+ # within the previous voiced segment.
218
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
219
+
220
+ # get the sines
221
+ sines = torch.cos(i_phase * 2 * np.pi)
222
+ return sines
223
+
224
+ def forward(self, f0):
225
+ """ sine_tensor, uv = forward(f0)
226
+ input F0: tensor(batchsize=1, length, dim=1)
227
+ f0 for unvoiced steps should be 0
228
+ output sine_tensor: tensor(batchsize=1, length, dim)
229
+ output uv: tensor(batchsize=1, length, 1)
230
+ """
231
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
232
+ device=f0.device)
233
+ # fundamental component
234
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
235
+
236
+ # generate sine waveforms
237
+ sine_waves = self._f02sine(fn) * self.sine_amp
238
+
239
+ # generate uv signal
240
+ # uv = torch.ones(f0.shape)
241
+ # uv = uv * (f0 > self.voiced_threshold)
242
+ uv = self._f02uv(f0)
243
+
244
+ # noise: for unvoiced should be similar to sine_amp
245
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
246
+ # . for voiced regions is self.noise_std
247
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
248
+ noise = noise_amp * torch.randn_like(sine_waves)
249
+
250
+ # first: set the unvoiced part to 0 by uv
251
+ # then: additive noise
252
+ sine_waves = sine_waves * uv + noise
253
+ return sine_waves, uv, noise
254
+
255
+
256
+ class SourceModuleHnNSF(torch.nn.Module):
257
+ """ SourceModule for hn-nsf
258
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
259
+ add_noise_std=0.003, voiced_threshod=0)
260
+ sampling_rate: sampling_rate in Hz
261
+ harmonic_num: number of harmonic above F0 (default: 0)
262
+ sine_amp: amplitude of sine source signal (default: 0.1)
263
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
264
+ note that amplitude of noise in unvoiced is decided
265
+ by sine_amp
266
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
267
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
268
+ F0_sampled (batchsize, length, 1)
269
+ Sine_source (batchsize, length, 1)
270
+ noise_source (batchsize, length 1)
271
+ uv (batchsize, length, 1)
272
+ """
273
+
274
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
275
+ add_noise_std=0.003, voiced_threshod=0):
276
+ super(SourceModuleHnNSF, self).__init__()
277
+
278
+ self.sine_amp = sine_amp
279
+ self.noise_std = add_noise_std
280
+
281
+ # to produce sine waveforms
282
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
283
+ sine_amp, add_noise_std, voiced_threshod)
284
+
285
+ # to merge source harmonics into a single excitation
286
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
287
+ self.l_tanh = torch.nn.Tanh()
288
+
289
+ def forward(self, x):
290
+ """
291
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
292
+ F0_sampled (batchsize, length, 1)
293
+ Sine_source (batchsize, length, 1)
294
+ noise_source (batchsize, length 1)
295
+ """
296
+ # source for harmonic branch
297
+ with torch.no_grad():
298
+ sine_wavs, uv, _ = self.l_sin_gen(x)
299
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
300
+
301
+ # source for noise branch, in the same shape as uv
302
+ noise = torch.randn_like(uv) * self.sine_amp / 3
303
+ return sine_merge, noise, uv
304
+ def padDiff(x):
305
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
306
+
307
+
308
+ class Generator(torch.nn.Module):
309
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
310
+ super(Generator, self).__init__()
311
+
312
+ self.num_kernels = len(resblock_kernel_sizes)
313
+ self.num_upsamples = len(upsample_rates)
314
+ resblock = AdaINResBlock1
315
+
316
+ self.m_source = SourceModuleHnNSF(
317
+ sampling_rate=24000,
318
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
319
+ harmonic_num=8, voiced_threshod=10)
320
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
321
+ self.noise_convs = nn.ModuleList()
322
+ self.noise_res = nn.ModuleList()
323
+
324
+ self.ups = nn.ModuleList()
325
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
326
+ self.ups.append(weight_norm(
327
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
328
+ k, u, padding=(k-u)//2)))
329
+
330
+ self.resblocks = nn.ModuleList()
331
+ for i in range(len(self.ups)):
332
+ ch = upsample_initial_channel//(2**(i+1))
333
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
334
+ self.resblocks.append(resblock(ch, k, d, style_dim))
335
+
336
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
337
+
338
+ if i + 1 < len(upsample_rates): #
339
+ stride_f0 = np.prod(upsample_rates[i + 1:])
340
+ self.noise_convs.append(Conv1d(
341
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
342
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
343
+ else:
344
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
345
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
346
+
347
+
348
+ self.post_n_fft = gen_istft_n_fft
349
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
350
+ self.ups.apply(init_weights)
351
+ self.conv_post.apply(init_weights)
352
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
353
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
354
+
355
+
356
+ def forward(self, x, s, f0):
357
+ with torch.no_grad():
358
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
359
+
360
+ har_source, noi_source, uv = self.m_source(f0)
361
+ har_source = har_source.transpose(1, 2).squeeze(1)
362
+ har_spec, har_phase = self.stft.transform(har_source)
363
+ har = torch.cat([har_spec, har_phase], dim=1)
364
+
365
+ for i in range(self.num_upsamples):
366
+ x = F.leaky_relu(x, LRELU_SLOPE)
367
+ x_source = self.noise_convs[i](har)
368
+ x_source = self.noise_res[i](x_source, s)
369
+
370
+ x = self.ups[i](x)
371
+ if i == self.num_upsamples - 1:
372
+ x = self.reflection_pad(x)
373
+
374
+ x = x + x_source
375
+ xs = None
376
+ for j in range(self.num_kernels):
377
+ if xs is None:
378
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
379
+ else:
380
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
381
+ x = xs / self.num_kernels
382
+ x = F.leaky_relu(x)
383
+ x = self.conv_post(x)
384
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
385
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
386
+ return self.stft.inverse(spec, phase)
387
+
388
+ def fw_phase(self, x, s):
389
+ for i in range(self.num_upsamples):
390
+ x = F.leaky_relu(x, LRELU_SLOPE)
391
+ x = self.ups[i](x)
392
+ xs = None
393
+ for j in range(self.num_kernels):
394
+ if xs is None:
395
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
396
+ else:
397
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
398
+ x = xs / self.num_kernels
399
+ x = F.leaky_relu(x)
400
+ x = self.reflection_pad(x)
401
+ x = self.conv_post(x)
402
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
403
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
404
+ return spec, phase
405
+
406
+ def remove_weight_norm(self):
407
+ print('Removing weight norm...')
408
+ for l in self.ups:
409
+ remove_weight_norm(l)
410
+ for l in self.resblocks:
411
+ l.remove_weight_norm()
412
+ remove_weight_norm(self.conv_pre)
413
+ remove_weight_norm(self.conv_post)
414
+
415
+
416
+ class AdainResBlk1d(nn.Module):
417
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
418
+ upsample='none', dropout_p=0.0):
419
+ super().__init__()
420
+ self.actv = actv
421
+ self.upsample_type = upsample
422
+ self.upsample = UpSample1d(upsample)
423
+ self.learned_sc = dim_in != dim_out
424
+ self._build_weights(dim_in, dim_out, style_dim)
425
+ self.dropout = nn.Dropout(dropout_p)
426
+
427
+ if upsample == 'none':
428
+ self.pool = nn.Identity()
429
+ else:
430
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
431
+
432
+
433
+ def _build_weights(self, dim_in, dim_out, style_dim):
434
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
435
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
436
+ self.norm1 = AdaIN1d(style_dim, dim_in)
437
+ self.norm2 = AdaIN1d(style_dim, dim_out)
438
+ if self.learned_sc:
439
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
440
+
441
+ def _shortcut(self, x):
442
+ x = self.upsample(x)
443
+ if self.learned_sc:
444
+ x = self.conv1x1(x)
445
+ return x
446
+
447
+ def _residual(self, x, s):
448
+ x = self.norm1(x, s)
449
+ x = self.actv(x)
450
+ x = self.pool(x)
451
+ x = self.conv1(self.dropout(x))
452
+ x = self.norm2(x, s)
453
+ x = self.actv(x)
454
+ x = self.conv2(self.dropout(x))
455
+ return x
456
+
457
+ def forward(self, x, s):
458
+ out = self._residual(x, s)
459
+ out = (out + self._shortcut(x)) / np.sqrt(2)
460
+ return out
461
+
462
+ class UpSample1d(nn.Module):
463
+ def __init__(self, layer_type):
464
+ super().__init__()
465
+ self.layer_type = layer_type
466
+
467
+ def forward(self, x):
468
+ if self.layer_type == 'none':
469
+ return x
470
+ else:
471
+ return F.interpolate(x, scale_factor=2, mode='nearest')
472
+
473
+ class Decoder(nn.Module):
474
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
475
+ resblock_kernel_sizes = [3,7,11],
476
+ upsample_rates = [10, 6],
477
+ upsample_initial_channel=512,
478
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
479
+ upsample_kernel_sizes=[20, 12],
480
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
481
+ super().__init__()
482
+
483
+ self.decode = nn.ModuleList()
484
+
485
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
486
+
487
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
488
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
489
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
490
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
491
+
492
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
493
+
494
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
495
+
496
+ self.asr_res = nn.Sequential(
497
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
498
+ )
499
+
500
+
501
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
502
+ upsample_initial_channel, resblock_dilation_sizes,
503
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
504
+
505
+ def forward(self, asr, F0_curve, N, s):
506
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
507
+ N = self.N_conv(N.unsqueeze(1))
508
+
509
+ x = torch.cat([asr, F0, N], axis=1)
510
+ x = self.encode(x, s)
511
+
512
+ asr_res = self.asr_res(asr)
513
+
514
+ res = True
515
+ for block in self.decode:
516
+ if res:
517
+ x = torch.cat([x, asr_res, F0, N], axis=1)
518
+ x = block(x, s)
519
+ if block.upsample_type != "none":
520
+ res = False
521
+
522
+ x = self.generator(x, s, F0_curve)
523
+ return x
kokoro-v0_19.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b0c392f87508da38fad3a2f9d94c359f1b657ebd2ef79f9d56d69503e470b0a
3
+ size 327211206
kokoro.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+
5
+ def split_num(num):
6
+ num = num.group()
7
+ if '.' in num:
8
+ return num
9
+ elif ':' in num:
10
+ h, m = [int(n) for n in num.split(':')]
11
+ if m == 0:
12
+ return f"{h} o'clock"
13
+ elif m < 10:
14
+ return f'{h} oh {m}'
15
+ return f'{h} {m}'
16
+ year = int(num[:4])
17
+ if year < 1100 or year % 1000 < 10:
18
+ return num
19
+ left, right = num[:2], int(num[2:4])
20
+ s = 's' if num.endswith('s') else ''
21
+ if 100 <= year % 1000 <= 999:
22
+ if right == 0:
23
+ return f'{left} hundred{s}'
24
+ elif right < 10:
25
+ return f'{left} oh {right}{s}'
26
+ return f'{left} {right}{s}'
27
+
28
+ def flip_money(m):
29
+ m = m.group()
30
+ bill = 'dollar' if m[0] == '$' else 'pound'
31
+ if m[-1].isalpha():
32
+ return f'{m[1:]} {bill}s'
33
+ elif '.' not in m:
34
+ s = '' if m[1:] == '1' else 's'
35
+ return f'{m[1:]} {bill}{s}'
36
+ b, c = m[1:].split('.')
37
+ s = '' if b == '1' else 's'
38
+ c = int(c.ljust(2, '0'))
39
+ coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
40
+ return f'{b} {bill}{s} and {c} {coins}'
41
+
42
+ def point_num(num):
43
+ a, b = num.group().split('.')
44
+ return ' point '.join([a, ' '.join(b)])
45
+
46
+ def normalize_text(text):
47
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
48
+ text = text.replace('«', chr(8220)).replace('»', chr(8221))
49
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
50
+ text = text.replace('(', '«').replace(')', '»')
51
+ for a, b in zip('、。!,:;?', ',.!,:;?'):
52
+ text = text.replace(a, b+' ')
53
+ text = re.sub(r'[^\S \n]', ' ', text)
54
+ text = re.sub(r' +', ' ', text)
55
+ text = re.sub(r'(?<=\n) +(?=\n)', '', text)
56
+ text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
57
+ text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
58
+ text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
59
+ text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
60
+ text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
61
+ text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
62
+ text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
63
+ text = re.sub(r'(?<=\d),(?=\d)', '', text)
64
+ text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
65
+ text = re.sub(r'\d*\.\d+', point_num, text)
66
+ text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
67
+ text = re.sub(r'(?<=\d)S', ' S', text)
68
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
69
+ text = re.sub(r"(?<=X')S\b", 's', text)
70
+ text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
71
+ text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
72
+ return text.strip()
73
+
74
+ def get_vocab():
75
+ _pad = "$"
76
+ _punctuation = ';:,.!?¡¿—…"«»“” '
77
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
78
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
79
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
80
+ dicts = {}
81
+ for i in range(len((symbols))):
82
+ dicts[symbols[i]] = i
83
+ return dicts
84
+
85
+ VOCAB = get_vocab()
86
+ def tokenize(ps):
87
+ return [i for i in map(VOCAB.get, ps) if i is not None]
88
+
89
+ en_us = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
90
+ def phonemize(text, norm=True):
91
+ if norm:
92
+ text = normalize_text(text)
93
+ ps = en_us.phonemize([text])
94
+ ps = ps[0] if ps else ''
95
+ # https://en.wiktionary.org/wiki/kokoro#English
96
+ ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
97
+ ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
98
+ ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
99
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
100
+ ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
101
+ ps = ''.join(filter(lambda p: p in VOCAB, ps))
102
+ return ps.strip()
103
+
104
+ def length_to_mask(lengths):
105
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
106
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
107
+ return mask
108
+
109
+ @torch.no_grad()
110
+ def forward(model, tokens, ref_s, speed):
111
+ device = ref_s.device
112
+ tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
113
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
114
+ text_mask = length_to_mask(input_lengths).to(device)
115
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
116
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
117
+ s = ref_s[:, 128:]
118
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
119
+ x, _ = model.predictor.lstm(d)
120
+ duration = model.predictor.duration_proj(x)
121
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
122
+ pred_dur = torch.round(duration).clamp(min=1).long()
123
+ pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
124
+ c_frame = 0
125
+ for i in range(pred_aln_trg.size(0)):
126
+ pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
127
+ c_frame += pred_dur[0,i].item()
128
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
129
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
130
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
131
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
132
+ return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
133
+
134
+ def generate(model, text, voicepack, speed=1):
135
+ ps = phonemize(text)
136
+ tokens = tokenize(ps)
137
+ if not tokens:
138
+ return None
139
+ elif len(tokens) > 510:
140
+ tokens = tokens[:510]
141
+ print('Truncated to 510 tokens')
142
+ ref_s = voicepack[len(tokens)]
143
+ out = forward(model, tokens, ref_s, speed)
144
+ ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
145
+ return out, ps
models.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from istftnet import Decoder
3
+ from munch import Munch
4
+ from plbert import load_plbert
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+ import numpy as np
7
+ import os
8
+ import os.path as osp
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ class LearnedDownSample(nn.Module):
14
+ def __init__(self, layer_type, dim_in):
15
+ super().__init__()
16
+ self.layer_type = layer_type
17
+
18
+ if self.layer_type == 'none':
19
+ self.conv = nn.Identity()
20
+ elif self.layer_type == 'timepreserve':
21
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
22
+ elif self.layer_type == 'half':
23
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
24
+ else:
25
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
26
+
27
+ def forward(self, x):
28
+ return self.conv(x)
29
+
30
+ class LearnedUpSample(nn.Module):
31
+ def __init__(self, layer_type, dim_in):
32
+ super().__init__()
33
+ self.layer_type = layer_type
34
+
35
+ if self.layer_type == 'none':
36
+ self.conv = nn.Identity()
37
+ elif self.layer_type == 'timepreserve':
38
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
39
+ elif self.layer_type == 'half':
40
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
41
+ else:
42
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
43
+
44
+
45
+ def forward(self, x):
46
+ return self.conv(x)
47
+
48
+ class DownSample(nn.Module):
49
+ def __init__(self, layer_type):
50
+ super().__init__()
51
+ self.layer_type = layer_type
52
+
53
+ def forward(self, x):
54
+ if self.layer_type == 'none':
55
+ return x
56
+ elif self.layer_type == 'timepreserve':
57
+ return F.avg_pool2d(x, (2, 1))
58
+ elif self.layer_type == 'half':
59
+ if x.shape[-1] % 2 != 0:
60
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
61
+ return F.avg_pool2d(x, 2)
62
+ else:
63
+ raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
64
+
65
+
66
+ class UpSample(nn.Module):
67
+ def __init__(self, layer_type):
68
+ super().__init__()
69
+ self.layer_type = layer_type
70
+
71
+ def forward(self, x):
72
+ if self.layer_type == 'none':
73
+ return x
74
+ elif self.layer_type == 'timepreserve':
75
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
76
+ elif self.layer_type == 'half':
77
+ return F.interpolate(x, scale_factor=2, mode='nearest')
78
+ else:
79
+ raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
80
+
81
+
82
+ class ResBlk(nn.Module):
83
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
84
+ normalize=False, downsample='none'):
85
+ super().__init__()
86
+ self.actv = actv
87
+ self.normalize = normalize
88
+ self.downsample = DownSample(downsample)
89
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
90
+ self.learned_sc = dim_in != dim_out
91
+ self._build_weights(dim_in, dim_out)
92
+
93
+ def _build_weights(self, dim_in, dim_out):
94
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
95
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
96
+ if self.normalize:
97
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
98
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
99
+ if self.learned_sc:
100
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
101
+
102
+ def _shortcut(self, x):
103
+ if self.learned_sc:
104
+ x = self.conv1x1(x)
105
+ if self.downsample:
106
+ x = self.downsample(x)
107
+ return x
108
+
109
+ def _residual(self, x):
110
+ if self.normalize:
111
+ x = self.norm1(x)
112
+ x = self.actv(x)
113
+ x = self.conv1(x)
114
+ x = self.downsample_res(x)
115
+ if self.normalize:
116
+ x = self.norm2(x)
117
+ x = self.actv(x)
118
+ x = self.conv2(x)
119
+ return x
120
+
121
+ def forward(self, x):
122
+ x = self._shortcut(x) + self._residual(x)
123
+ return x / np.sqrt(2) # unit variance
124
+
125
+ class LinearNorm(torch.nn.Module):
126
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
127
+ super(LinearNorm, self).__init__()
128
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
129
+
130
+ torch.nn.init.xavier_uniform_(
131
+ self.linear_layer.weight,
132
+ gain=torch.nn.init.calculate_gain(w_init_gain))
133
+
134
+ def forward(self, x):
135
+ return self.linear_layer(x)
136
+
137
+ class Discriminator2d(nn.Module):
138
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
139
+ super().__init__()
140
+ blocks = []
141
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
142
+
143
+ for lid in range(repeat_num):
144
+ dim_out = min(dim_in*2, max_conv_dim)
145
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
146
+ dim_in = dim_out
147
+
148
+ blocks += [nn.LeakyReLU(0.2)]
149
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
150
+ blocks += [nn.LeakyReLU(0.2)]
151
+ blocks += [nn.AdaptiveAvgPool2d(1)]
152
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
153
+ self.main = nn.Sequential(*blocks)
154
+
155
+ def get_feature(self, x):
156
+ features = []
157
+ for l in self.main:
158
+ x = l(x)
159
+ features.append(x)
160
+ out = features[-1]
161
+ out = out.view(out.size(0), -1) # (batch, num_domains)
162
+ return out, features
163
+
164
+ def forward(self, x):
165
+ out, features = self.get_feature(x)
166
+ out = out.squeeze() # (batch)
167
+ return out, features
168
+
169
+ class ResBlk1d(nn.Module):
170
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
171
+ normalize=False, downsample='none', dropout_p=0.2):
172
+ super().__init__()
173
+ self.actv = actv
174
+ self.normalize = normalize
175
+ self.downsample_type = downsample
176
+ self.learned_sc = dim_in != dim_out
177
+ self._build_weights(dim_in, dim_out)
178
+ self.dropout_p = dropout_p
179
+
180
+ if self.downsample_type == 'none':
181
+ self.pool = nn.Identity()
182
+ else:
183
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
184
+
185
+ def _build_weights(self, dim_in, dim_out):
186
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
187
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
188
+ if self.normalize:
189
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
190
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
191
+ if self.learned_sc:
192
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
193
+
194
+ def downsample(self, x):
195
+ if self.downsample_type == 'none':
196
+ return x
197
+ else:
198
+ if x.shape[-1] % 2 != 0:
199
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
200
+ return F.avg_pool1d(x, 2)
201
+
202
+ def _shortcut(self, x):
203
+ if self.learned_sc:
204
+ x = self.conv1x1(x)
205
+ x = self.downsample(x)
206
+ return x
207
+
208
+ def _residual(self, x):
209
+ if self.normalize:
210
+ x = self.norm1(x)
211
+ x = self.actv(x)
212
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
213
+
214
+ x = self.conv1(x)
215
+ x = self.pool(x)
216
+ if self.normalize:
217
+ x = self.norm2(x)
218
+
219
+ x = self.actv(x)
220
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
221
+
222
+ x = self.conv2(x)
223
+ return x
224
+
225
+ def forward(self, x):
226
+ x = self._shortcut(x) + self._residual(x)
227
+ return x / np.sqrt(2) # unit variance
228
+
229
+ class LayerNorm(nn.Module):
230
+ def __init__(self, channels, eps=1e-5):
231
+ super().__init__()
232
+ self.channels = channels
233
+ self.eps = eps
234
+
235
+ self.gamma = nn.Parameter(torch.ones(channels))
236
+ self.beta = nn.Parameter(torch.zeros(channels))
237
+
238
+ def forward(self, x):
239
+ x = x.transpose(1, -1)
240
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
241
+ return x.transpose(1, -1)
242
+
243
+ class TextEncoder(nn.Module):
244
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
245
+ super().__init__()
246
+ self.embedding = nn.Embedding(n_symbols, channels)
247
+
248
+ padding = (kernel_size - 1) // 2
249
+ self.cnn = nn.ModuleList()
250
+ for _ in range(depth):
251
+ self.cnn.append(nn.Sequential(
252
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
253
+ LayerNorm(channels),
254
+ actv,
255
+ nn.Dropout(0.2),
256
+ ))
257
+ # self.cnn = nn.Sequential(*self.cnn)
258
+
259
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
260
+
261
+ def forward(self, x, input_lengths, m):
262
+ x = self.embedding(x) # [B, T, emb]
263
+ x = x.transpose(1, 2) # [B, emb, T]
264
+ m = m.to(input_lengths.device).unsqueeze(1)
265
+ x.masked_fill_(m, 0.0)
266
+
267
+ for c in self.cnn:
268
+ x = c(x)
269
+ x.masked_fill_(m, 0.0)
270
+
271
+ x = x.transpose(1, 2) # [B, T, chn]
272
+
273
+ input_lengths = input_lengths.cpu().numpy()
274
+ x = nn.utils.rnn.pack_padded_sequence(
275
+ x, input_lengths, batch_first=True, enforce_sorted=False)
276
+
277
+ self.lstm.flatten_parameters()
278
+ x, _ = self.lstm(x)
279
+ x, _ = nn.utils.rnn.pad_packed_sequence(
280
+ x, batch_first=True)
281
+
282
+ x = x.transpose(-1, -2)
283
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
284
+
285
+ x_pad[:, :, :x.shape[-1]] = x
286
+ x = x_pad.to(x.device)
287
+
288
+ x.masked_fill_(m, 0.0)
289
+
290
+ return x
291
+
292
+ def inference(self, x):
293
+ x = self.embedding(x)
294
+ x = x.transpose(1, 2)
295
+ x = self.cnn(x)
296
+ x = x.transpose(1, 2)
297
+ self.lstm.flatten_parameters()
298
+ x, _ = self.lstm(x)
299
+ return x
300
+
301
+ def length_to_mask(self, lengths):
302
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
303
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
304
+ return mask
305
+
306
+
307
+
308
+ class AdaIN1d(nn.Module):
309
+ def __init__(self, style_dim, num_features):
310
+ super().__init__()
311
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
312
+ self.fc = nn.Linear(style_dim, num_features*2)
313
+
314
+ def forward(self, x, s):
315
+ h = self.fc(s)
316
+ h = h.view(h.size(0), h.size(1), 1)
317
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
318
+ return (1 + gamma) * self.norm(x) + beta
319
+
320
+ class UpSample1d(nn.Module):
321
+ def __init__(self, layer_type):
322
+ super().__init__()
323
+ self.layer_type = layer_type
324
+
325
+ def forward(self, x):
326
+ if self.layer_type == 'none':
327
+ return x
328
+ else:
329
+ return F.interpolate(x, scale_factor=2, mode='nearest')
330
+
331
+ class AdainResBlk1d(nn.Module):
332
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
333
+ upsample='none', dropout_p=0.0):
334
+ super().__init__()
335
+ self.actv = actv
336
+ self.upsample_type = upsample
337
+ self.upsample = UpSample1d(upsample)
338
+ self.learned_sc = dim_in != dim_out
339
+ self._build_weights(dim_in, dim_out, style_dim)
340
+ self.dropout = nn.Dropout(dropout_p)
341
+
342
+ if upsample == 'none':
343
+ self.pool = nn.Identity()
344
+ else:
345
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
346
+
347
+
348
+ def _build_weights(self, dim_in, dim_out, style_dim):
349
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
350
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
351
+ self.norm1 = AdaIN1d(style_dim, dim_in)
352
+ self.norm2 = AdaIN1d(style_dim, dim_out)
353
+ if self.learned_sc:
354
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
355
+
356
+ def _shortcut(self, x):
357
+ x = self.upsample(x)
358
+ if self.learned_sc:
359
+ x = self.conv1x1(x)
360
+ return x
361
+
362
+ def _residual(self, x, s):
363
+ x = self.norm1(x, s)
364
+ x = self.actv(x)
365
+ x = self.pool(x)
366
+ x = self.conv1(self.dropout(x))
367
+ x = self.norm2(x, s)
368
+ x = self.actv(x)
369
+ x = self.conv2(self.dropout(x))
370
+ return x
371
+
372
+ def forward(self, x, s):
373
+ out = self._residual(x, s)
374
+ out = (out + self._shortcut(x)) / np.sqrt(2)
375
+ return out
376
+
377
+ class AdaLayerNorm(nn.Module):
378
+ def __init__(self, style_dim, channels, eps=1e-5):
379
+ super().__init__()
380
+ self.channels = channels
381
+ self.eps = eps
382
+
383
+ self.fc = nn.Linear(style_dim, channels*2)
384
+
385
+ def forward(self, x, s):
386
+ x = x.transpose(-1, -2)
387
+ x = x.transpose(1, -1)
388
+
389
+ h = self.fc(s)
390
+ h = h.view(h.size(0), h.size(1), 1)
391
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
392
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
393
+
394
+
395
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
396
+ x = (1 + gamma) * x + beta
397
+ return x.transpose(1, -1).transpose(-1, -2)
398
+
399
+ class ProsodyPredictor(nn.Module):
400
+
401
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
402
+ super().__init__()
403
+
404
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
405
+ d_model=d_hid,
406
+ nlayers=nlayers,
407
+ dropout=dropout)
408
+
409
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
410
+ self.duration_proj = LinearNorm(d_hid, max_dur)
411
+
412
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
413
+ self.F0 = nn.ModuleList()
414
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
415
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
416
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
417
+
418
+ self.N = nn.ModuleList()
419
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
420
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
421
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
422
+
423
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
424
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
425
+
426
+
427
+ def forward(self, texts, style, text_lengths, alignment, m):
428
+ d = self.text_encoder(texts, style, text_lengths, m)
429
+
430
+ batch_size = d.shape[0]
431
+ text_size = d.shape[1]
432
+
433
+ # predict duration
434
+ input_lengths = text_lengths.cpu().numpy()
435
+ x = nn.utils.rnn.pack_padded_sequence(
436
+ d, input_lengths, batch_first=True, enforce_sorted=False)
437
+
438
+ m = m.to(text_lengths.device).unsqueeze(1)
439
+
440
+ self.lstm.flatten_parameters()
441
+ x, _ = self.lstm(x)
442
+ x, _ = nn.utils.rnn.pad_packed_sequence(
443
+ x, batch_first=True)
444
+
445
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
446
+
447
+ x_pad[:, :x.shape[1], :] = x
448
+ x = x_pad.to(x.device)
449
+
450
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
451
+
452
+ en = (d.transpose(-1, -2) @ alignment)
453
+
454
+ return duration.squeeze(-1), en
455
+
456
+ def F0Ntrain(self, x, s):
457
+ x, _ = self.shared(x.transpose(-1, -2))
458
+
459
+ F0 = x.transpose(-1, -2)
460
+ for block in self.F0:
461
+ F0 = block(F0, s)
462
+ F0 = self.F0_proj(F0)
463
+
464
+ N = x.transpose(-1, -2)
465
+ for block in self.N:
466
+ N = block(N, s)
467
+ N = self.N_proj(N)
468
+
469
+ return F0.squeeze(1), N.squeeze(1)
470
+
471
+ def length_to_mask(self, lengths):
472
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
473
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
474
+ return mask
475
+
476
+ class DurationEncoder(nn.Module):
477
+
478
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
479
+ super().__init__()
480
+ self.lstms = nn.ModuleList()
481
+ for _ in range(nlayers):
482
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
483
+ d_model // 2,
484
+ num_layers=1,
485
+ batch_first=True,
486
+ bidirectional=True,
487
+ dropout=dropout))
488
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
489
+
490
+
491
+ self.dropout = dropout
492
+ self.d_model = d_model
493
+ self.sty_dim = sty_dim
494
+
495
+ def forward(self, x, style, text_lengths, m):
496
+ masks = m.to(text_lengths.device)
497
+
498
+ x = x.permute(2, 0, 1)
499
+ s = style.expand(x.shape[0], x.shape[1], -1)
500
+ x = torch.cat([x, s], axis=-1)
501
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
502
+
503
+ x = x.transpose(0, 1)
504
+ input_lengths = text_lengths.cpu().numpy()
505
+ x = x.transpose(-1, -2)
506
+
507
+ for block in self.lstms:
508
+ if isinstance(block, AdaLayerNorm):
509
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
510
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
511
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
512
+ else:
513
+ x = x.transpose(-1, -2)
514
+ x = nn.utils.rnn.pack_padded_sequence(
515
+ x, input_lengths, batch_first=True, enforce_sorted=False)
516
+ block.flatten_parameters()
517
+ x, _ = block(x)
518
+ x, _ = nn.utils.rnn.pad_packed_sequence(
519
+ x, batch_first=True)
520
+ x = F.dropout(x, p=self.dropout, training=self.training)
521
+ x = x.transpose(-1, -2)
522
+
523
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
524
+
525
+ x_pad[:, :, :x.shape[-1]] = x
526
+ x = x_pad.to(x.device)
527
+
528
+ return x.transpose(-1, -2)
529
+
530
+ def inference(self, x, style):
531
+ x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
532
+ style = style.expand(x.shape[0], x.shape[1], -1)
533
+ x = torch.cat([x, style], axis=-1)
534
+ src = self.pos_encoder(x)
535
+ output = self.transformer_encoder(src).transpose(0, 1)
536
+ return output
537
+
538
+ def length_to_mask(self, lengths):
539
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
540
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
541
+ return mask
542
+
543
+ # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
544
+ def recursive_munch(d):
545
+ if isinstance(d, dict):
546
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
547
+ elif isinstance(d, list):
548
+ return [recursive_munch(v) for v in d]
549
+ else:
550
+ return d
551
+
552
+ def build_model(path, device):
553
+ args = recursive_munch(dict(
554
+ decoder=dict(
555
+ type='istftnet', upsample_kernel_sizes=[20, 12], upsample_rates=[10, 6], gen_istft_hop_size=5, gen_istft_n_fft=20,
556
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], resblock_kernel_sizes=[3, 7, 11], upsample_initial_channel=512,
557
+ ),
558
+ dim_in=64, dropout=0.2, hidden_dim=512, max_conv_dim=512, max_dur=50,
559
+ multispeaker=True, n_layer=3, n_mels=80, n_token=178, style_dim=128
560
+ ))
561
+ assert args.decoder.type == 'istftnet', 'Decoder type unknown'
562
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
563
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
564
+ upsample_rates = args.decoder.upsample_rates,
565
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
566
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
567
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
568
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
569
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
570
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
571
+ bert = load_plbert()
572
+ bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
573
+ for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
574
+ for child in parent.children():
575
+ if isinstance(child, nn.RNNBase):
576
+ child.flatten_parameters()
577
+ model = Munch(
578
+ bert=bert.to(device).eval(),
579
+ bert_encoder=bert_encoder.to(device).eval(),
580
+ predictor=predictor.to(device).eval(),
581
+ decoder=decoder.to(device).eval(),
582
+ text_encoder=text_encoder.to(device).eval(),
583
+ )
584
+ for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
585
+ assert key in model, key
586
+ try:
587
+ model[key].load_state_dict(state_dict)
588
+ except:
589
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
590
+ model[key].load_state_dict(state_dict, strict=False)
591
+ return model
plbert.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
2
+ from transformers import AlbertConfig, AlbertModel
3
+
4
+ class CustomAlbert(AlbertModel):
5
+ def forward(self, *args, **kwargs):
6
+ # Call the original forward method
7
+ outputs = super().forward(*args, **kwargs)
8
+ # Only return the last_hidden_state
9
+ return outputs.last_hidden_state
10
+
11
+ def load_plbert():
12
+ plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
13
+ albert_base_configuration = AlbertConfig(**plbert_config)
14
+ bert = CustomAlbert(albert_base_configuration)
15
+ return bert
voices/af.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fad4192fd8a840f925b0e3fc2be54e20531f91a9ac816a485b7992ca0bd83ebf
3
+ size 524355
voices/af_bella.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2828c6c2f94275ef3441a2edfcf48293298ee0f9b56ce70fb2e344345487b922
3
+ size 524449
voices/af_sarah.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba7918c4ace6ace4221e7e01eb3a6d16596cba9729850551c758cd2ad3a4cd08
3
+ size 524449