import os
import time
import logging
from itertools import chain

import pandas as pd
import torch
from torch.cuda.amp import GradScaler
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from fastprogress.fastprogress import master_bar, progress_bar

# Project-local helpers -- TrainConfig, ConvRNNEmbedder, GE2ELoss, UtteranceDS,
# SpecialCollater, lin_one_cycle and flatten_cfg -- are imported from the
# repository's own modules (imports elided here).


def train(rank, cfg: TrainConfig):
    world_size = cfg.distributed.n_nodes * cfg.distributed.n_gpus_per_node
    if world_size > 1:
        init_process_group(backend=cfg.distributed.dist_backend,
                           init_method=cfg.distributed.dist_url,
                           world_size=world_size,
                           rank=rank)

    device = torch.device(f'cuda:{rank:d}')

    model = ConvRNNEmbedder(cfg.model_cfg).to(device)
    loss_fn = GE2ELoss(device).to(device)
    logging.info(f"Initialized rank {rank}")

    if rank == 0:
        logging.getLogger().setLevel(logging.INFO)
        logging.info(f"Model initialized as:\n {model}")
        os.makedirs(cfg.checkpoint_path, exist_ok=True)
        logging.info(f"checkpoints directory : {cfg.checkpoint_path}")
        logging.info(f"Model has {sum(p.numel() for p in model.parameters()):,d} parameters.")

    # Optionally resume from a checkpoint.
    steps = 0
    if cfg.resume_checkpoint != '' and os.path.isfile(cfg.resume_checkpoint):
        state_dict = torch.load(cfg.resume_checkpoint, map_location=device)
        model.load_state_dict(state_dict['model_state_dict'])
        loss_fn.load_state_dict(state_dict['loss_fn_state_dict'])
        steps = state_dict['steps'] + 1
        last_epoch = state_dict['epoch']
        print(f"Checkpoint loaded from {cfg.resume_checkpoint}. "
              f"Resuming training from {steps} steps at epoch {last_epoch}")
    else:
        state_dict = None
        last_epoch = -1

    if world_size > 1:
        if rank == 0:
            logging.info("Multi-gpu detected")
        model = DDP(model, device_ids=[rank]).to(device)
        loss_fn = DDP(loss_fn, device_ids=[rank]).to(device)

    # Base lr of 1.0: the LambdaLR schedule below returns the absolute
    # learning rate for each step.
    optim = torch.optim.AdamW(chain(model.parameters(), loss_fn.parameters()),
                              1.0, betas=cfg.betas)
    if state_dict is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    train_df, valid_df = pd.read_csv(cfg.train_csv), pd.read_csv(cfg.valid_csv)

    trainset = UtteranceDS(train_df, cfg.sample_rate, cfg.n_uttr_per_spk)
    train_sampler = DistributedSampler(trainset) if world_size > 1 else None
    train_loader = DataLoader(trainset,
                              num_workers=cfg.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=cfg.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=SpecialCollater(cfg.min_seq_len, cfg.max_seq_len))

    if rank == 0:
        # Validation and tensorboard logging only run on rank 0.
        validset = UtteranceDS(valid_df, cfg.sample_rate, cfg.n_uttr_per_spk)
        validation_loader = DataLoader(validset,
                                       num_workers=cfg.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=cfg.batch_size,
                                       pin_memory=False,
                                       drop_last=True,
                                       collate_fn=SpecialCollater(cfg.min_seq_len, cfg.max_seq_len))
        sw = SummaryWriter(os.path.join(cfg.checkpoint_path, 'logs'))

    # Linear one-cycle learning-rate schedule over the whole run.
    total_iters = cfg.n_epochs * len(train_loader)

    def sched_lam(x):
        return lin_one_cycle(cfg.start_lr, cfg.max_lr, cfg.end_lr, cfg.warmup_pct,
                             total_iters, x)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim,
                                                  lr_lambda=[sched_lam],
                                                  last_epoch=steps - 1)
    if state_dict is not None:
        scheduler.load_state_dict(state_dict['scheduler_state_dict'])

    if cfg.fp16:
        scaler = GradScaler()
        if state_dict is not None and 'scaler_state_dict' in state_dict:
            scaler.load_state_dict(state_dict['scaler_state_dict'])

    model.train()

    if rank == 0:
        mb = master_bar(range(max(0, last_epoch), cfg.n_epochs))
        smooth_loss = None
    else:
        mb = range(max(0, last_epoch), cfg.n_epochs)

    for epoch in mb:
        if rank == 0:
            start = time.time()
            mb.write("Epoch: {}".format(epoch + 1))

        if world_size > 1:
            train_sampler.set_epoch(epoch)

        if rank == 0:
            pb = progress_bar(enumerate(train_loader), total=len(train_loader), parent=mb)
        else:
            pb = enumerate(train_loader)

        for i, batch in pb:
            if rank == 0:
                start_b = time.time()
            x, xlen = batch
            x = x.to(device, non_blocking=True)
            xlen = xlen.to(device, non_blocking=True)

            optim.zero_grad()

            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                embeds = model(x, xlen)
                loss = loss_fn(embeds)

            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), cfg.grad_clip / 2)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad_norm_(loss_fn.parameters(), cfg.grad_clip / 2)
                optim.step()

            if rank == 0:
                # Exponential moving average of the training loss for logging.
                if smooth_loss is None:
                    smooth_loss = float(loss.item())
                else:
                    smooth_loss = smooth_loss + 0.1 * (float(loss.item()) - smooth_loss)

                # STDOUT logging
                if steps % cfg.stdout_interval == 0:
                    mb.write('steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'.format(
                        steps, loss.item(), time.time() - start_b,
                        torch.cuda.max_memory_allocated() / 1e9))
                    mb.child.comment = 'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}'.format(
                        steps, loss.item(), time.time() - start_b)
                    # mb.write(f"lr = {float(optim.param_groups[0]['lr'])}")

                # checkpointing
                if steps % cfg.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = f"{cfg.checkpoint_path}/ckpt_{steps:08d}.pt"
                    torch.save(
                        {
                            'model_state_dict': (model.module if world_size > 1 else model).state_dict(),
                            'loss_fn_state_dict': (loss_fn.module if world_size > 1 else loss_fn).state_dict(),
                            'optim_state_dict': optim.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'scaler_state_dict': (scaler.state_dict() if cfg.fp16 else None),
                            'steps': steps,
                            'epoch': epoch
                        }, checkpoint_path)
                    logging.info(f"Saved checkpoint to {checkpoint_path}")

                # Tensorboard summary logging
                if steps % cfg.summary_interval == 0:
                    sw.add_scalar("training/loss_smooth", smooth_loss, steps)
                    sw.add_scalar("training/loss_raw", loss.item(), steps)
                    # Learned scale (w) and bias (b) of the GE2E similarity matrix.
                    ge2e = loss_fn.module if world_size > 1 else loss_fn
                    sw.add_scalar("ge2e/w", float(ge2e.w.item()), steps)
                    sw.add_scalar("ge2e/b", float(ge2e.b.item()), steps)
                    sw.add_scalar("opt/lr", float(optim.param_groups[0]['lr']), steps)
                    sw.add_scalar('opt/grad_norm', float(gnorm), steps)

                # Validation
                if steps % cfg.validation_interval == 0 and steps != 0:
                    model.eval()
                    loss_fn.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    flat_embeds = []
                    flat_lbls = []
                    with torch.no_grad():
                        for j, batch in progress_bar(enumerate(validation_loader),
                                                     total=len(validation_loader),
                                                     parent=mb):
                            x, xlen = batch
                            embeds = model(x.to(device), xlen.to(device))
                            val_err_tot += loss_fn(embeds)
                            # Collect embeddings from the first few batches for
                            # the tensorboard projector.
                            if j <= 2:
                                lbls = [
                                    f'spk-{j}-{indr:03d}'
                                    for indr in range(cfg.batch_size)
                                    for _ in range(cfg.n_uttr_per_spk)
                                ]
                                fembeds = embeds.view(cfg.batch_size * cfg.n_uttr_per_spk,
                                                      cfg.model_cfg.fc_dim)
                                flat_embeds.append(fembeds.cpu())
                                flat_lbls.extend(lbls)
                            elif j == 3:
                                flat_embeds = torch.cat(flat_embeds, dim=0)
                                sw.add_embedding(flat_embeds,
                                                 metadata=flat_lbls,
                                                 global_step=steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/loss", val_err, steps)
                        mb.write(f"validation run complete at {steps:,d} steps. "
                                 f"validation loss: {val_err:5.4f}")

                    model.train()
                    loss_fn.train()

                # Peak memory stats for this logging window.
                sw.add_scalar("memory/max_allocated_gb",
                              torch.cuda.max_memory_allocated() / 1e9, steps)
                sw.add_scalar("memory/max_reserved_gb",
                              torch.cuda.max_memory_reserved() / 1e9, steps)
                torch.cuda.reset_peak_memory_stats()
                torch.cuda.reset_accumulated_memory_stats()

            steps += 1
            scheduler.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))

    if rank == 0:
        sw.add_hparams(flatten_cfg(cfg),
                       metric_dict={'validation/loss': val_err},
                       run_name=f'run-{cfg.checkpoint_path}')

    print("Training completed!")
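
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original training script): one possible
# way to launch train() above, spawning one process per local GPU so that each
# process receives its rank as the first argument, matching the
# init_process_group() call at the top of train(). `launch_training` is a
# hypothetical helper name; building the TrainConfig is left to the repo's own
# entry point.

import torch.multiprocessing as mp


def launch_training(cfg: TrainConfig):
    n_gpus = cfg.distributed.n_gpus_per_node
    if n_gpus > 1:
        # mp.spawn calls train(i, cfg) for i in range(nprocs).
        mp.spawn(train, nprocs=n_gpus, args=(cfg,))
    else:
        train(0, cfg)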
import logging
from typing import Iterable, Optional, Union

import torch
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Optimizer

_logger = logging.getLogger(__name__)

# Assumed alias: a single tensor or an iterable of tensors (e.g. model.parameters()).
TensorOrIterableTensors = Union[Tensor, Iterable[Tensor]]


class Amp:
    """Thin wrapper around GradScaler that bundles autocast, loss scaling,
    optional gradient clipping, the optimizer step and zero_grad."""

    def __init__(
        self,
        enabled: bool = False,
        max_norm: Optional[float] = None,
    ) -> None:
        self.grad_scaler = GradScaler(enabled=enabled)
        self.enabled = enabled
        self.max_norm = max_norm
        _logger.info("amp: %s", self.enabled)
        if self.max_norm:
            _logger.info("you are using grad clip, don't forget to pass params in")

    def autocast(self):
        return autocast(enabled=self.enabled)

    def scale(self, outputs: TensorOrIterableTensors) -> TensorOrIterableTensors:
        return self.grad_scaler.scale(outputs)

    def unscale_(self, optimizer: Optimizer):
        return self.grad_scaler.unscale_(optimizer)

    def step(self, optimizer: Optimizer, *args, **kwargs):
        return self.grad_scaler.step(optimizer, *args, **kwargs)

    def update(self, new_scale: Union[float, Tensor, None] = None):
        return self.grad_scaler.update(new_scale=new_scale)

    def clip_grad_norm_(self, params: TensorOrIterableTensors):
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)

    def state_dict(self) -> dict:
        return self.grad_scaler.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.grad_scaler.load_state_dict(state_dict)

    def __call__(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
        zero_grad_set_to_none: bool = False,
    ):
        # Scaled backward pass, optional unscale + clip, then step/update/zero_grad.
        self.scale(loss).backward()
        if self.max_norm is not None:
            assert parameters is not None
            self.unscale_(optimizer)
            self.clip_grad_norm_(parameters)
        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()
        optimizer.zero_grad(set_to_none=zero_grad_set_to_none)

    def backward(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
    ):
        return self(loss, optimizer, parameters=parameters)
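
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal training
# loop using Amp. The forward pass runs under amp.autocast(), and calling
# amp(loss, optimizer, parameters=...) performs loss scaling, backward,
# gradient unscaling and clipping (because max_norm is set), the optimizer
# step, the scaler update, and optimizer.zero_grad(). `run_epoch`, `batches`,
# and `compute_loss` are placeholder names, not part of the original code.


def run_epoch(model: torch.nn.Module, optimizer: Optimizer, batches, compute_loss):
    amp = Amp(enabled=True, max_norm=1.0)  # mixed precision + gradient clipping
    model.train()
    for batch in batches:
        with amp.autocast():
            loss = compute_loss(model, batch)
        # parameters must be passed here because max_norm was given, so the
        # gradients can be unscaled and clipped before the optimizer step.
        amp(loss, optimizer, parameters=model.parameters())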