Example #1

# Standard-library / third-party imports assumed by the snippet below.
# Project-local names (TrainConfig, ConvRNNEmbedder, GE2ELoss, UtteranceDS,
# SpecialCollater, lin_one_cycle, flatten_cfg) are assumed to come from the
# surrounding project and are not shown here.
import logging
import os
import time
from itertools import chain

import pandas as pd
import torch
from fastprogress.fastprogress import master_bar, progress_bar
from torch.cuda.amp import GradScaler
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter

def train(rank, cfg: TrainConfig):
    if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
        init_process_group(backend=cfg.distributed.dist_backend,
                           init_method=cfg.distributed.dist_url,
                           world_size=cfg.distributed.n_nodes *
                           cfg.distributed.n_gpus_per_node,
                           rank=rank)

    device = torch.device(f'cuda:{rank:d}')

    model = ConvRNNEmbedder(cfg.model_cfg).to(device)
    loss_fn = GE2ELoss(device).to(device)

    logging.info(f"Initialized rank {rank}")

    if rank == 0:
        logging.getLogger().setLevel(logging.INFO)
        logging.info(f"Model initialized as:\n {model}")
        os.makedirs(cfg.checkpoint_path, exist_ok=True)
        logging.info(f"checkpoints directory : {cfg.checkpoint_path}")
        logging.info(
            f"Model has {sum([p.numel() for p in model.parameters()]):,d} parameters."
        )

    steps = 0
    if cfg.resume_checkpoint != '' and os.path.isfile(cfg.resume_checkpoint):
        state_dict = torch.load(cfg.resume_checkpoint, map_location=device)
        model.load_state_dict(state_dict['model_state_dict'])
        loss_fn.load_state_dict(state_dict['loss_fn_state_dict'])
        steps = state_dict['steps'] + 1
        last_epoch = state_dict['epoch']
        print(
            f"Checkpoint loaded from {cfg.resume_checkpoint}. Resuming training from {steps} steps at epoch {last_epoch}"
        )
    else:
        state_dict = None
        last_epoch = -1

    if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
        if rank == 0: logging.info("Multi-gpu detected")
        model = DDP(model, device_ids=[rank])
        loss_fn = DDP(loss_fn, device_ids=[rank])

    # lr is set to 1.0 here because the LambdaLR scheduler below supplies the
    # actual learning rate as a multiplicative factor.
    optim = torch.optim.AdamW(chain(model.parameters(), loss_fn.parameters()),
                              1.0,
                              betas=cfg.betas)
    if state_dict is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    train_df, valid_df = pd.read_csv(cfg.train_csv), pd.read_csv(cfg.valid_csv)

    trainset = UtteranceDS(train_df, cfg.sample_rate, cfg.n_uttr_per_spk)

    train_sampler = DistributedSampler(
        trainset
    ) if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1 else None

    train_loader = DataLoader(trainset,
                              num_workers=cfg.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=cfg.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=SpecialCollater(
                                  cfg.min_seq_len, cfg.max_seq_len))

    if rank == 0:
        validset = UtteranceDS(valid_df, cfg.sample_rate, cfg.n_uttr_per_spk)
        validation_loader = DataLoader(validset,
                                       num_workers=cfg.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=cfg.batch_size,
                                       pin_memory=False,
                                       drop_last=True,
                                       collate_fn=SpecialCollater(
                                           cfg.min_seq_len, cfg.max_seq_len))

        sw = SummaryWriter(os.path.join(cfg.checkpoint_path, 'logs'))

    total_iters = cfg.n_epochs * len(train_loader)

    def sched_lam(x):
        return lin_one_cycle(cfg.start_lr, cfg.max_lr, cfg.end_lr,
                             cfg.warmup_pct, total_iters, x)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim,
                                                  lr_lambda=[sched_lam],
                                                  last_epoch=steps - 1)

    if state_dict is not None:
        scheduler.load_state_dict(state_dict['scheduler_state_dict'])

    if cfg.fp16:
        scaler = GradScaler()
        if state_dict is not None and 'scaler_state_dict' in state_dict:
            scaler.load_state_dict(state_dict['scaler_state_dict'])

    model.train()

    if rank == 0:
        mb = master_bar(range(max(0, last_epoch), cfg.n_epochs))
        smooth_loss = None
    else:
        mb = range(max(0, last_epoch), cfg.n_epochs)

    for epoch in mb:
        if rank == 0:
            start = time.time()
            mb.write("Epoch: {}".format(epoch + 1))

        if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
            train_sampler.set_epoch(epoch)

        if rank == 0:
            pb = progress_bar(enumerate(train_loader),
                              total=len(train_loader),
                              parent=mb)
        else:
            pb = enumerate(train_loader)

        for i, batch in pb:
            if rank == 0: start_b = time.time()
            x, xlen = batch
            x = x.to(device, non_blocking=True)
            xlen = xlen.to(device, non_blocking=True)

            optim.zero_grad()

            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                embeds = model(x, xlen)
                loss = loss_fn(embeds)
            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                gnorm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                gnorm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                optim.step()

            if rank == 0:
                if smooth_loss is None: smooth_loss = float(loss.item())
                else:
                    smooth_loss = smooth_loss + 0.1 * (float(loss.item()) -
                                                       smooth_loss)
                # STDOUT logging
                if steps % cfg.stdout_interval == 0:
                    mb.write('steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
                            format(steps, loss.item(), time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
                    mb.child.comment = 'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}'. \
                            format(steps, loss.item(), time.time() - start_b)
                    # mb.write(f"lr = {float(optim.param_groups[0]['lr'])}")

                # checkpointing
                if steps % cfg.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = f"{cfg.checkpoint_path}/ckpt_{steps:08d}.pt"
                    is_distributed = (cfg.distributed.n_gpus_per_node *
                                      cfg.distributed.n_nodes > 1)
                    torch.save(
                        {
                            'model_state_dict':
                            (model.module if is_distributed else model).state_dict(),
                            'loss_fn_state_dict':
                            (loss_fn.module if is_distributed else loss_fn).state_dict(),
                            'optim_state_dict': optim.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'scaler_state_dict':
                            (scaler.state_dict() if cfg.fp16 else None),
                            'steps': steps,
                            'epoch': epoch,
                        }, checkpoint_path)
                    logging.info(f"Saved checkpoint to {checkpoint_path}")

                # Tensorboard summary logging
                if steps % cfg.summary_interval == 0:
                    sw.add_scalar("training/loss_smooth", smooth_loss, steps)
                    sw.add_scalar("training/loss_raw", loss.item(), steps)
                    loss_fn_mod = (loss_fn.module
                                   if cfg.distributed.n_gpus_per_node *
                                   cfg.distributed.n_nodes > 1 else loss_fn)
                    sw.add_scalar("ge2e/w", float(loss_fn_mod.w.item()), steps)
                    sw.add_scalar("ge2e/b", float(loss_fn_mod.b.item()), steps)
                    sw.add_scalar("opt/lr", float(optim.param_groups[0]['lr']),
                                  steps)
                    sw.add_scalar('opt/grad_norm', float(gnorm), steps)

                # Validation
                if steps % cfg.validation_interval == 0 and steps != 0:
                    model.eval()
                    loss_fn.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    flat_embeds = []
                    flat_lbls = []
                    with torch.no_grad():
                        for j, batch in progress_bar(
                                enumerate(validation_loader),
                                total=len(validation_loader),
                                parent=mb):
                            x, xlen = batch
                            embeds = model(x.to(device), xlen.to(device))
                            val_err_tot += loss_fn(embeds)

                            if j <= 2:
                                lbls = [
                                    f'spk-{j}-{indr:03d}'
                                    for indr in range(cfg.batch_size)
                                    for _ in range(cfg.n_uttr_per_spk)
                                ]
                                fembeds = embeds.view(
                                    cfg.batch_size * cfg.n_uttr_per_spk,
                                    cfg.model_cfg.fc_dim)
                                flat_embeds.append(fembeds.cpu())
                                flat_lbls.extend(lbls)
                            elif j == 3:
                                flat_embeds = torch.cat(flat_embeds, dim=0)
                                sw.add_embedding(flat_embeds,
                                                 metadata=flat_lbls,
                                                 global_step=steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/loss", val_err, steps)
                        mb.write(
                            f"validation run complete at {steps:,d} steps. validation loss: {val_err:5.4f}"
                        )

                    model.train()
                    loss_fn.train()
                    sw.add_scalar("memory/max_allocated_gb",
                                  torch.cuda.max_memory_allocated() / 1e9,
                                  steps)
                    sw.add_scalar("memory/max_reserved_gb",
                                  torch.cuda.max_memory_reserved() / 1e9,
                                  steps)
                    torch.cuda.reset_peak_memory_stats()
                    torch.cuda.reset_accumulated_memory_stats()

            steps += 1
            scheduler.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
    if rank == 0:
        sw.add_hparams(flatten_cfg(cfg),
                       metric_dict={'validation/loss': val_err},
                       run_name=f'run-{cfg.checkpoint_path}')
    print("Training completed!")
Example #2

# Imports and helpers assumed by the class below; TensorOrIterableTensors and
# _logger are defined here for self-containment (the original snippet does not
# show them).
import logging
from typing import Iterable, Optional, Union

import torch
from torch import Tensor
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Optimizer

TensorOrIterableTensors = Union[Tensor, Iterable[Tensor]]
_logger = logging.getLogger(__name__)

class Amp:
    def __init__(
        self,
        enabled: bool = False,
        max_norm: Optional[float] = None,
    ) -> None:
        self.grad_scaler = GradScaler(enabled=enabled)
        self.enabled = enabled
        self.max_norm = max_norm

        _logger.info("amp: %s", self.enabled)
        if self.max_norm:
            _logger.info(
                "gradient clipping is enabled; remember to pass `parameters` "
                "when calling the Amp instance")

    def autocast(self):
        return autocast(enabled=self.enabled)

    def scale(self, outputs: TensorOrIterableTensors) -> TensorOrIterableTensors:
        return self.grad_scaler.scale(outputs)

    def unscale_(self, optimizer: Optimizer):
        return self.grad_scaler.unscale_(optimizer)

    def step(self, optimizer: Optimizer, *args, **kwargs):
        return self.grad_scaler.step(optimizer, *args, **kwargs)

    def update(self, new_scale: Union[float, Tensor, None] = None):
        return self.grad_scaler.update(new_scale=new_scale)

    def clip_grad_norm_(self, params: TensorOrIterableTensors):
        return torch.nn.utils.clip_grad_norm_(params, self.max_norm)

    def state_dict(self) -> dict:
        return self.grad_scaler.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.grad_scaler.load_state_dict(state_dict)

    def __call__(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
        zero_grad_set_to_none: bool = False,
    ):
        self.scale(loss).backward()

        if self.max_norm is not None:
            assert parameters is not None
            self.unscale_(optimizer)
            self.clip_grad_norm_(parameters)

        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()
        optimizer.zero_grad(set_to_none=zero_grad_set_to_none)

    def backward(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
    ):
        return self(loss, optimizer, parameters=parameters)
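
A minimal usage sketch of the Amp wrapper in a training step (model, criterion,
optimizer, and loader below are illustrative names, not part of the original):

amp = Amp(enabled=True, max_norm=1.0)

for x, y in loader:
    # Forward pass under autocast so eligible ops run in reduced precision.
    with amp.autocast():
        loss = criterion(model(x), y)
    # scale -> backward -> unscale + clip (since max_norm is set) -> step
    # -> update -> zero_grad, all handled by __call__.
    amp(loss, optimizer, parameters=model.parameters())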