#!/usr/bin/env python -u
# -*- coding: utf-8 -*-

# Future imports for Python 2/3 compatibility
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import torch
import torch.nn as nn
import numpy as np
from joblib import Parallel, delayed
from pesq import pesq  # PESQ metric for speech quality evaluation
import librosa  # Audio processing
import torchaudio  # Audio processing with PyTorch

# Constants
MAX_WAV_VALUE = 32768.0  # Full-scale value for 16-bit WAV files
EPS = 1e-6  # Small value to avoid division by zero


def read_and_config_file(input_path, decode=0):
    """Reads input paths from a file or directory and configures them for processing.

    Args:
        input_path (str): Path to the input directory or list file.
        decode (int): Flag indicating if decoding should occur (1 for decode, 0 for standard read).

    Returns:
        list: A list of audio paths, or of dictionaries containing input and label paths.
    """
    processed_list = []

    # If decoding is requested, find audio files in a directory
    if decode:
        if os.path.isdir(input_path):
            processed_list = librosa.util.find_files(input_path, ext="wav")  # Look for WAV files
            if len(processed_list) == 0:
                processed_list = librosa.util.find_files(input_path, ext="flac")  # Fall back to FLAC files
        else:
            # Read paths from a list file, keeping only the first column
            with open(input_path) as fid:
                for line in fid:
                    path_s = line.strip().split()
                    processed_list.append(path_s[0])
        return processed_list

    # Read input-label pairs from a list file
    with open(input_path) as fid:
        for line in fid:
            tmp_paths = line.strip().split()
            if len(tmp_paths) == 3:  # Input, label, and duration
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1], 'duration': float(tmp_paths[2])}
            elif len(tmp_paths) == 2:  # Input and label only
                sample = {'inputs': tmp_paths[0], 'labels': tmp_paths[1]}
            else:  # Skip malformed lines instead of silently reusing the previous sample
                print('Warning: skipping malformed line: {}'.format(line.strip()))
                continue
            processed_list.append(sample)
    return processed_list


def load_checkpoint(checkpoint_path, use_cuda):
    """Loads a model checkpoint from the specified path.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        use_cuda (bool): Flag indicating whether to use CUDA for loading.

    Returns:
        dict: The loaded checkpoint containing model parameters.
    """
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)  # Load onto the device the checkpoint was saved from
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)  # Load to CPU
    return checkpoint


def get_learning_rate(optimizer):
    """Retrieves the current learning rate from the optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer instance.

    Returns:
        float: The current learning rate.
    """
    return optimizer.param_groups[0]["lr"]
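
# A minimal usage sketch for read_and_config_file above (the file name
# 'tr_list.txt' and the paths are hypothetical; each line of a list file is
# expected to be "<input_path> <label_path> [duration]"):
#
#   samples = read_and_config_file('tr_list.txt', decode=0)
#   # -> [{'inputs': 'noisy/a.wav', 'labels': 'clean/a.wav', 'duration': 3.2}, ...]
#
#   wav_paths = read_and_config_file('/data/testset', decode=1)
#   # -> ['/data/testset/x.wav', ...]
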
def reload_for_eval(model, checkpoint_dir, use_cuda):
    """Reloads a model for evaluation from the specified checkpoint directory.

    Args:
        model (nn.Module): The model to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.

    Returns:
        None
    """
    print('Reloading from: {}'.format(checkpoint_dir))
    best_name = os.path.join(checkpoint_dir, 'last_best_checkpoint')  # Pointer file for the best checkpoint
    ckpt_name = os.path.join(checkpoint_dir, 'last_checkpoint')  # Pointer file for the last checkpoint
    if os.path.isfile(best_name):
        name = best_name
    elif os.path.isfile(ckpt_name):
        name = ckpt_name
    else:
        print('Warning: No existing checkpoint or best model found!')
        return
    with open(name, 'r') as f:
        model_name = f.readline().strip()  # Read the checkpoint file name from the pointer file
    checkpoint_path = os.path.join(checkpoint_dir, model_name)
    print('Checkpoint path: {}'.format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)  # Always load to CPU first
    if 'model' in checkpoint:
        pretrained_model = checkpoint['model']
    else:
        pretrained_model = checkpoint
    # Copy matching parameters, tolerating a 'module.' prefix mismatch between
    # DataParallel checkpoints and plain ones
    state = model.state_dict()
    for key in state.keys():
        if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
            state[key] = pretrained_model[key]
        elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
            state[key] = pretrained_model[key.replace('module.', '')]
        elif 'module.' + key in pretrained_model and state[key].shape == pretrained_model['module.' + key].shape:
            state[key] = pretrained_model['module.' + key]
    model.load_state_dict(state)
    print('=> Reloaded well-trained model {} for decoding.'.format(model_name))


def reload_model(model, optimizer, checkpoint_dir, use_cuda=True, strict=True):
    """Reloads the model and optimizer state from a checkpoint.

    Args:
        model (nn.Module): The model to be reloaded.
        optimizer (torch.optim.Optimizer): The optimizer to be reloaded.
        checkpoint_dir (str): Directory containing checkpoints.
        use_cuda (bool): Flag indicating whether to use CUDA.
        strict (bool): If True, requires keys in state_dict to match exactly.

    Returns:
        tuple: Current epoch and step.
    """
    ckpt_name = os.path.join(checkpoint_dir, 'checkpoint')  # Pointer file for the latest checkpoint
    if os.path.isfile(ckpt_name):
        with open(ckpt_name, 'r') as f:
            model_name = f.readline().strip()  # Read the checkpoint file name from the pointer file
        checkpoint_path = os.path.join(checkpoint_dir, model_name)
        checkpoint = load_checkpoint(checkpoint_path, use_cuda)
        model.load_state_dict(checkpoint['model'], strict=strict)  # Restore model parameters
        optimizer.load_state_dict(checkpoint['optimizer'])  # Restore optimizer state
        epoch = checkpoint['epoch']
        step = checkpoint['step']
        print('=> Reloaded previous model and optimizer.')
    else:
        print('[!] Checkpoint directory is empty. Training a new model ...')
        epoch = 0
        step = 0
    return epoch, step
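
# Checkpoint directory layout assumed by the reload helpers above and by
# save_checkpoint below (file names follow the save_checkpoint convention;
# the epoch/step values are illustrative):
#
#   checkpoint_dir/
#     checkpoint               # text pointer, contains e.g. "model.ckpt-10-5000.pt"
#     last_best_checkpoint     # text pointer to the best model so far
#     model.ckpt-10-5000.pt    # torch.save dict: {'model', 'optimizer', 'epoch', 'step'}
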
def save_checkpoint(model, optimizer, epoch, step, checkpoint_dir, mode='checkpoint'):
    """Saves the model and optimizer state to a checkpoint file.

    Args:
        model (nn.Module): The model to be saved.
        optimizer (torch.optim.Optimizer): The optimizer to be saved.
        epoch (int): Current epoch number.
        step (int): Current training step number.
        checkpoint_dir (str): Directory to save the checkpoint in.
        mode (str): Name of the pointer file to update ('checkpoint' by default).

    Returns:
        None
    """
    checkpoint_path = os.path.join(
        checkpoint_dir, 'model.ckpt-{}-{}.pt'.format(epoch, step))
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'step': step}, checkpoint_path)
    # Record the checkpoint file name in the pointer file for easy access
    with open(os.path.join(checkpoint_dir, mode), 'w') as f:
        f.write('model.ckpt-{}-{}.pt'.format(epoch, step))
    print("=> Saved checkpoint:", checkpoint_path)


def setup_lr(opt, lr):
    """Sets the learning rate for all parameter groups in the optimizer.

    Args:
        opt (torch.optim.Optimizer): The optimizer whose learning rate is to be set.
        lr (float): The new learning rate.

    Returns:
        None
    """
    for param_group in opt.param_groups:
        param_group['lr'] = lr


def pesq_loss(clean, noisy, sr=16000):
    """Calculates the PESQ (Perceptual Evaluation of Speech Quality) score between clean and noisy signals.

    Args:
        clean (ndarray): The clean audio signal.
        noisy (ndarray): The noisy audio signal.
        sr (int): Sample rate of the audio signals (default is 16000 Hz).

    Returns:
        float: The PESQ score, or -1 if the computation fails.
    """
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')  # Wide-band PESQ
    except Exception:
        # PESQ can fail, e.g. on silent periods in the audio
        pesq_score = -1
    return pesq_score


def batch_pesq(clean, noisy):
    """Computes PESQ scores for batches of clean and noisy audio signals.

    Args:
        clean (list of ndarray): List of clean audio signals.
        noisy (list of ndarray): List of noisy audio signals.

    Returns:
        torch.FloatTensor: A tensor of normalized PESQ scores, or None if any score is -1.
    """
    # Compute PESQ for each clean/noisy pair in parallel
    pesq_score = Parallel(n_jobs=-1)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy))
    pesq_score = np.array(pesq_score)
    if -1 in pesq_score:  # Any failed computation invalidates the batch
        return None
    # Normalize scores from the nominal [1, 4.5] PESQ range to roughly [0, 1]
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score).to('cuda')  # Note: assumes a CUDA device is available


def power_compress(x):
    """Compresses the power of a complex spectrogram.

    Args:
        x (torch.Tensor): Input tensor with real and imaginary components in the last dimension.

    Returns:
        torch.Tensor: Compressed real and imaginary parts, stacked along dimension 1.
    """
    real = x[..., 0]  # Extract real part
    imag = x[..., 1]  # Extract imaginary part
    spec = torch.complex(real, imag)
    mag = torch.abs(spec)  # Magnitude
    phase = torch.angle(spec)  # Phase
    mag = mag ** 0.3  # Compress magnitude with a power of 0.3
    real_compress = mag * torch.cos(phase)  # Reconstruct real part
    imag_compress = mag * torch.sin(phase)  # Reconstruct imaginary part
    return torch.stack([real_compress, imag_compress], 1)
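
# A minimal round-trip sketch for power_compress/power_uncompress (shapes are
# illustrative; note the compressed tensor is stacked on dim 1 while the
# uncompressed result is stacked on the last dim, matching torch.stft's
# return_complex=False layout):
#
#   spec = torch.randn(2, 257, 100, 2)              # (batch, freq, frames, real/imag)
#   comp = power_compress(spec)                     # (batch, 2, freq, frames)
#   rec = power_uncompress(comp[:, 0], comp[:, 1])  # (batch, freq, frames, 2)
#   # rec approximately equals spec, up to floating-point error
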
""" spec = torch.complex(real, imag) # Create complex tensor from real and imaginary parts mag = torch.abs(spec) # Compute magnitude phase = torch.angle(spec) # Compute phase mag = mag**(1./0.3) # Uncompress magnitude by raising to the power of 1/0.3 real_uncompress = mag * torch.cos(phase) # Reconstruct real part imag_uncompress = mag * torch.sin(phase) # Reconstruct imaginary part return torch.stack([real_uncompress, imag_uncompress], -1) # Stack uncompressed parts def stft(x, args, center=False): """Computes the Short-Time Fourier Transform (STFT) of an audio signal. Args: x (torch.Tensor): Input audio signal. args (Namespace): Configuration arguments containing window type and lengths. center (bool): Whether to center the window. Returns: torch.Tensor: The computed STFT of the input signal. """ win_type = args.win_type win_len = args.win_len win_inc = args.win_inc fft_len = args.fft_len # Select window type and create window tensor if win_type == 'hamming': window = torch.hamming_window(win_len, periodic=False).to(x.device) elif win_type == 'hanning': window = torch.hann_window(win_len, periodic=False).to(x.device) else: print(f"In STFT, {win_type} is not supported!") return # Compute and return the STFT return torch.stft(x, fft_len, win_inc, win_len, center=center, window=window, return_complex=False) def istft(x, args, slen=None, center=False, normalized=False, onsided=None, return_complex=False): """Computes the inverse Short-Time Fourier Transform (ISTFT) of a complex spectrogram. Args: x (torch.Tensor): Input complex spectrogram. args (Namespace): Configuration arguments containing window type and lengths. slen (int, optional): Length of the output signal. center (bool): Whether to center the window. normalized (bool): Whether to normalize the output. onsided (bool, optional): If True, computes only the one-sided transform. return_complex (bool): If True, returns complex output. Returns: torch.Tensor: The reconstructed audio signal from the spectrogram. """ win_type = args.win_type win_len = args.win_len win_inc = args.win_inc fft_len = args.fft_len # Select window type and create window tensor if win_type == 'hamming': window = torch.hamming_window(win_len, periodic=False).to(x.device) elif win_type == 'hanning': window = torch.hann_window(win_len, periodic=False).to(x.device) else: print(f"In ISTFT, {win_type} is not supported!") return try: # Attempt to compute ISTFT output = torch.istft(x, n_fft=fft_len, hop_length=win_inc, win_length=win_len, window=window, center=center, normalized=normalized, onesided=onsided, length=slen, return_complex=False) except: # Handle potential errors by converting x to a complex tensor x_complex = torch.view_as_complex(x) output = torch.istft(x_complex, n_fft=fft_len, hop_length=win_inc, win_length=win_len, window=window, center=center, normalized=normalized, onesided=onsided, length=slen, return_complex=False) return output def compute_fbank(audio_in, args): """Computes the filter bank features from an audio signal. Args: audio_in (torch.Tensor): Input audio signal. args (Namespace): Configuration arguments containing window length, shift, and sampling rate. Returns: torch.Tensor: Computed filter bank features. 
""" frame_length = args.win_len / args.sampling_rate * 1000 # Frame length in milliseconds frame_shift = args.win_inc / args.sampling_rate * 1000 # Frame shift in milliseconds # Compute and return filter bank features using Kaldi's implementation return torchaudio.compliance.kaldi.fbank(audio_in, dither=1.0, frame_length=frame_length, frame_shift=frame_shift, num_mel_bins=args.num_mels, sample_frequency=args.sampling_rate, window_type=args.win_type)