alibabasglab committed on
Commit
65bccad
·
verified ·
1 Parent(s): 1690585

Update utils/decode.py

Browse files
Files changed (1) hide show
  1. utils/decode.py +8 -1
utils/decode.py CHANGED
@@ -67,6 +67,8 @@ def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args):
67
  stride = int(window * 0.75) # Decoding stride if segmentation is used
68
  b, t = inputs.shape # Get batch size and input length
69
 
 
 
70
  # Check if input length exceeds one-time decode length to decide on segmentation
71
  if t > args.sampling_rate * args.one_time_decode_length:
72
  decode_do_segment = True # Enable segment decoding for long sequences
@@ -112,13 +114,18 @@ def decode_one_audio_mossformer2_ss_16k(model, device, inputs, args):
112
  out.append(out_list[spk][0, :].detach().cpu().numpy()) # Append output for each speaker
113
 
114
  # Normalize the outputs to the maximum absolute value for each speaker
 
115
  max_abs = 0
116
  for spk in range(args.num_spks):
117
  if max_abs < max(abs(out[spk])):
118
  max_abs = max(abs(out[spk]))
119
  for spk in range(args.num_spks):
120
  out[spk] = out[spk] / max_abs # Normalize output by max absolute value
121
-
 
 
 
 
122
  return out # Return the list of normalized outputs
123
 
124
  def decode_one_audio_frcrn_se_16k(model, device, inputs, args):
 
67
  stride = int(window * 0.75) # Decoding stride if segmentation is used
68
  b, t = inputs.shape # Get batch size and input length
69
 
70
+ rms_input = (inputs ** 2).mean() ** 0.5
71
+
72
  # Check if input length exceeds one-time decode length to decide on segmentation
73
  if t > args.sampling_rate * args.one_time_decode_length:
74
  decode_do_segment = True # Enable segment decoding for long sequences
 
114
  out.append(out_list[spk][0, :].detach().cpu().numpy()) # Append output for each speaker
115
 
116
  # Normalize the outputs to the maximum absolute value for each speaker
117
+ '''
118
  max_abs = 0
119
  for spk in range(args.num_spks):
120
  if max_abs < max(abs(out[spk])):
121
  max_abs = max(abs(out[spk]))
122
  for spk in range(args.num_spks):
123
  out[spk] = out[spk] / max_abs # Normalize output by max absolute value
124
+ '''
125
+ # Normalize the outputs back to the input magnitude for each speaker
126
+ for spk in range(args.num_spks):
127
+ rms_out = (out[spk] ** 2).mean() ** 0.5
128
+ out[spk] = out[spk] / rms_out * rms_input
129
  return out # Return the list of normalized outputs
130
 
131
  def decode_one_audio_frcrn_se_16k(model, device, inputs, args):