# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
import shutil
import datetime
import logging
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from accelerate.utils import set_seed

from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.syncnet import SyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
    init_dist,
    cosine_loss,
    reversed_forward,
)
from latentsync.utils.util import plot_loss_chart, gather_loss
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips


logger = get_logger(__name__)


def main(config):
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)

    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,  # load checkpoint
        device=device,
    )

    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = SyncNet(OmegaConf.to_container(syncnet_config.model)).to(device=device, dtype=torch.float16)
        syncnet_checkpoint = torch.load(syncnet_config.ckpt.inference_ckpt_path, map_location=device)
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

    unet.requires_grad_(True)
    trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if config.run.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # Get the training iteration
    if config.run.max_train_steps == -1:
        assert config.run.max_train_epochs != -1
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device)

    # Validation pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP wrapper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")
    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    train_step_list = []
    sync_loss_list = []
    recon_loss_list = []

    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue
                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_images = batch["gt"].to(device, dtype=torch.float16)
            gt_masked_images = batch["masked_gt"].to(device, dtype=torch.float16)
            mask = batch["mask"].to(device, dtype=torch.float16)
            ref_images = batch["ref"].to(device, dtype=torch.float16)

            gt_images = rearrange(gt_images, "b f c h w -> (b f) c h w")
            gt_masked_images = rearrange(gt_masked_images, "b f c h w -> (b f) c h w")
            mask = rearrange(mask, "b f c h w -> (b f) c h w")
            ref_images = rearrange(ref_images, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_images).latent_dist.sample()
                gt_masked_images = vae.encode(gt_masked_images).latent_dist.sample()
                ref_images = vae.encode(ref_images).latent_dist.sample()
                mask = torch.nn.functional.interpolate(mask, size=config.data.resolution // vae_scale_factor)

            gt_latents = (
                rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            gt_masked_images = (
                rearrange(gt_masked_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_images = (
                rearrange(ref_images, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            mask = rearrange(mask, "(b f) c h w -> b c f h w", f=config.data.num_frames)

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                noise = noise[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
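            # For reference (the standard diffusion forward process, not specific to this repo):
            #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
            # which is what noise_scheduler.add_noise() computes for each sampled timestep.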
            noisy_tensor = noise_scheduler.add_noise(gt_latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            unet_input = torch.cat([noisy_tensor, mask, gt_masked_images, ref_images], dim=1)

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

                if config.run.recon_loss_weight != 0:
                    recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
                else:
                    recon_loss = 0

                pred_latents = reversed_forward(noise_scheduler, pred_noise, timesteps, noisy_tensor)

                if config.run.pixel_space_supervise:
                    pred_images = vae.decode(
                        rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                        + vae.config.shift_factor
                    ).sample

                if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
                    pred_images_perceptual = pred_images[:, :, pred_images.shape[2] // 2 :, :]
                    gt_images_perceptual = gt_images[:, :, gt_images.shape[2] // 2 :, :]
                    lpips_loss = lpips_loss_func(pred_images_perceptual.float(), gt_images_perceptual.float()).mean()
                else:
                    lpips_loss = 0

                if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                    trepa_pred_images = rearrange(pred_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                    trepa_gt_images = rearrange(gt_images, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                    trepa_loss = trepa_loss_func(trepa_pred_images, trepa_gt_images)
                else:
                    trepa_loss = 0

                if config.model.add_audio_layer and config.run.use_syncnet:
                    if config.run.pixel_space_supervise:
                        syncnet_input = rearrange(pred_images, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
                    else:
                        syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                    if syncnet_config.data.lower_half:
                        height = syncnet_input.shape[2]
                        syncnet_input = syncnet_input[:, :, height // 2 :, :]
                    ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                    vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                    sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
                    sync_loss_list.append(gather_loss(sync_loss, device))
                else:
                    sync_loss = 0

                loss = (
                    recon_loss * config.run.recon_loss_weight
                    + sync_loss * config.run.sync_loss_weight
                    + lpips_loss * config.run.perceptual_loss_weight
                    + trepa_loss * config.run.trepa_loss_weight
                )

            train_step_list.append(global_step)
            if config.run.recon_loss_weight != 0:
                recon_loss_list.append(gather_loss(recon_loss, device))

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), config.optimizer.max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].audio_cross_attn.attn.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                if config.run.recon_loss_weight != 0:
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/recon_loss_chart-{global_step}.png"),
                        ("Reconstruction loss", train_step_list, recon_loss_list),
                    )
                if config.model.add_audio_layer:
                    if sync_loss_list != []:
                        plot_loss_chart(
                            os.path.join(output_dir, f"loss_charts/sync_loss_chart-{global_step}.png"),
                            ("Sync loss", train_step_list, sync_loss_list),
                        )
                model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),  # to unwrap DDP
                }
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
                validation_video_mask_path = os.path.join(output_dir, f"val_videos/val_video_mask.mp4")

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        validation_video_mask_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask=config.data.mask,
                    )

                logger.info(f"Saved validation video output to {validation_video_out_path}")

                val_step_list.append(global_step)

                if config.model.add_audio_layer:
                    try:
                        _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
                    except Exception as e:
                        logger.info(e)
                        conf = 0
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/sync_conf_chart-{global_step}.png"),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Config file path
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")

    args = parser.parse_args()
    config = OmegaConf.load(args.unet_config_path)
    config.unet_config_path = args.unet_config_path

    main(config)
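
# Example launch (an assumption, not part of the original file; adjust the script path and GPU count
# to your setup -- init_dist() is expected to read the standard torch.distributed environment
# variables that torchrun sets):
#   torchrun --nproc_per_node=<num_gpus> train_unet.py --unet_config_path configs/unet.yaml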