Example #1
class Fp16OptimizerHook(OptimizerHook):
    def __init__(self, grad_clip=None, grad_scaler_config=None):
        super().__init__(grad_clip)
        self._grad_scaler_config = grad_scaler_config
        self._scaler = None

    def before_train(self):
        if self._grad_scaler_config is None:
            self._scaler = GradScaler()
        else:
            self._scaler = GradScaler(**self._grad_scaler_config)

    def after_train_iter(self):
        # Average the loss over the accumulation window before scaling it.
        loss = self.trainer.output['loss'] / self.trainer.gradient_accumulation_steps
        self._scaler.scale(loss).backward()

        if (self.trainer.iter + 1) % self.trainer.gradient_accumulation_steps == 0:
            if self._grad_clip is not None:
                # unscale_() may only be called once between step() and update(),
                # so unscale and clip only on iterations that actually step.
                self._scaler.unscale_(self.trainer.optimizer)
                self._clip_grad_norm()
            self._scaler.step(self.trainer.optimizer)
            self._scaler.update()
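A minimal, self-contained sketch of the same pattern (scaled loss, gradient accumulation, and clipping on unscaled gradients) outside the hook framework. The toy model, data, accumulation length, and clip value below are assumptions for illustration, not part of the example above.

import torch
from torch.cuda.amp import GradScaler, autocast

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=(device == 'cuda'))  # scaling is a no-op on CPU
accum_steps = 4
grad_clip = 1.0

for it in range(8):
    x = torch.randn(32, 16, device=device)
    y = torch.randn(32, 1, device=device)
    with autocast(enabled=(device == 'cuda')):
        loss = torch.nn.functional.mse_loss(model(x), y)
    # Average over the accumulation window, then backprop the scaled loss.
    scaler.scale(loss / accum_steps).backward()
    if (it + 1) % accum_steps == 0:
        # unscale_() once per optimizer step so clipping sees true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()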
Example #2
class Trainer:
    def __init__(self, config, device='cuda:0'):
        self.config = config
        self.device = torch.device(device)
        self.model_dir = self.config['model_dir']
        self.dev_gold_path = os.path.join(self.model_dir, 'dev-gold.txt')
        self.dev_pred_path = os.path.join(self.model_dir, 'dev-pred.txt')
        self.best_smatch = -1  # May get reloaded if using a checkpoint
        self.start_epoch = 1  # May get reloaded if using a checkpoint
        os.makedirs(self.model_dir, exist_ok=True)

    def train(self, checkpoint=None):
        # Sets self.model, tokenizer, optimizer, ..., best_smatch, start_epoch
        self.load_model(checkpoint)
        self.load_training_data()  # sets self.train_loader
        self.load_eval_data()  # sets self.inference, graphs_gold
        # Loop through max epochs
        # May fail if reloading a checkpoint
        assert self.start_epoch < self.config['max_epochs']
        for epoch in range(self.start_epoch, self.config['max_epochs'] + 1):
            # Setup batch
            print('Training epoch %d' % epoch)
            trn_amr_loss = RunningAverage()
            self.optimizer.zero_grad()
            pbar = tqdm(total=len(self.train_loader.dataset), ncols=100)
            self.set_pbar_desc_train(pbar, None)
            self.model.train()
            # Loop through all the data
            for bnum, batch in enumerate(self.train_loader):
                x, y, extra = batch
                with autocast(enabled=self.config['fp16']):
                    rdict = self.model(**x, **y)
                    loss = rdict['loss']
                self.scaler.scale(loss / self.config['accum_steps']).backward()
                trn_amr_loss.update(loss.item())
                # Perform an update every accum_steps
                if (bnum + 1) % self.config['accum_steps'] == 0:
                    self.step_optimizer()
                # Update progress
                pbar.update(x['input_ids'].shape[0])
                self.set_pbar_desc_train(pbar, trn_amr_loss.value)
            pbar.close()
            # Perform an update with the last batch if it wasn't already done in the loop
            if (bnum + 1) % self.config['accum_steps'] != 0:
                self.step_optimizer()
            # Run evaluate, compute smatch and save the model if it's the new best
            try:
                smatch = self.evaluate()
                if smatch > self.best_smatch:
                    self.best_smatch = smatch
                    self.save_and_remove_checkpoints(epoch, smatch)
            except Exception:
                print('!! Evaluation / save failed !!')
                logger.exception('Evaluation or model save failed')
            print()

    # Run Inference and evaluate the model
    def evaluate(self):
        self.model.eval()
        sents = [g.metadata['snt'] for g in self.graphs_gold]
        graphs_gen = self.inference.parse_sents(sents,
                                                return_penman=True,
                                                disable_progress=False,
                                                pbar_desc='%-14s' % 'Evaluating:')
        assert len(graphs_gen) == len(self.graphs_gold)
        # Detect bad graphs. In Penman 1.2.0, metadata does not impact penman.Graph.__eq__()
        num_bad = sum(g == Inference.invalid_graph for g in graphs_gen)
        print('Out of %d graphs, %d did not generate properly.' %
              (len(graphs_gen), num_bad))
        # Save the final graphs
        print('Generated graphs written to', self.dev_pred_path)
        penman.dump(graphs_gen, self.dev_pred_path, indent=6, model=amr_model)
        # Run smatch
        try:
            gold_entries = get_entries(self.dev_gold_path)
            test_entries = get_entries(self.dev_pred_path)
            precision, recall, f_score = compute_smatch(
                test_entries, gold_entries)
            print('SMATCH -> P: %.3f,  R: %.3f,  F: %.3f' %
                  (precision, recall, f_score))
        except Exception:
            logger.exception('Failed to compute smatch score.')
            precision, recall, f_score = 0, 0, 0
        return f_score

    # Save the checkpoint if this is the new best score and remove previous ones
    def save_and_remove_checkpoints(self, epoch, smatch):
        prev_checkpoints = [
            fn for fn in os.listdir(self.model_dir) if fn.endswith('.pt')
        ]
        model_fn = 'checkpoint_epoch_%02d_smatch_%04d.pt' % (epoch,
                                                             smatch * 10000)
        model_fpath = os.path.join(self.model_dir, model_fn)
        # Create the dictionary with the optional optimizer and save it
        print('Saving new, best model to', model_fpath)
        save_dict = {'model': self.model.state_dict()}
        if self.config.get('save_optimizer'):
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['scheduler'] = self.scheduler.state_dict()
        torch.save(save_dict, model_fpath)
        # Save the config file
        self.config['smatch_dev'] = smatch
        self.config['last_epoch'] = epoch
        with open(os.path.join(self.model_dir, 'config.json'), 'w') as f:
            json.dump(self.config, f, indent=4)
        # Remove previous checkpoints
        for chkpt_fn in prev_checkpoints:
            os.remove(os.path.join(self.model_dir, chkpt_fn))

    # Load and set up the model, tokenizer, optimizer, etc.
    def load_model(self, checkpoint=None):
        print('Loading model from', self.config['model'])
        self.model, self.tokenizer = instantiate_model_and_tokenizer(
            self.config['model'],
            additional_tokens_smart_init=self.config['smart_init'],
            dropout=self.config['dropout'],
            attention_dropout=self.config['attention_dropout'],
            penman_linearization=self.config['penman_linearization'],
            collapse_name_ops=self.config['collapse_name_ops'],
            use_pointer_tokens=self.config['use_pointer_tokens'],
            raw_graph=self.config['raw_graph'])
        self.model.to(self.device)
        # Load optimization components
        self.optimizer = AdamW(self.model.parameters(),
                               lr=self.config['learning_rate'],
                               weight_decay=self.config['weight_decay'])
        self.scheduler = transformers.get_constant_schedule_with_warmup(
            self.optimizer, num_warmup_steps=self.config['warmup_steps'])
        self.scaler = GradScaler(enabled=self.config['fp16'])
        # Reload checkpoint model weights and optimizer params if loading from a checkpoint
        if checkpoint is not None:
            print('Checkpoint %s restored' % checkpoint)
            load_state_dict_from_checkpoint(checkpoint, self.model,
                                            self.optimizer, self.scheduler)
            # Try to load the smatch score and last_epoch from the config in the model directory.
            try:
                with open(os.path.join(self.model_dir, 'config.json')) as f:
                    model_config = json.load(f)
                self.best_smatch = model_config['smatch_dev']
                self.start_epoch = model_config['last_epoch'] + 1
            except Exception:
                logger.exception(
                    'Unable to load config file in model directory')

    # Set up the training data loader
    def load_training_data(self):
        print('Loading train data from', self.config['train'])
        self.train_loader = get_dataloader(
            self.tokenizer,
            glob_pattern=self.config['train'],
            evaluation=False,
            batch_size=self.config['batch_size'],
            use_recategorization=self.config['use_recategorization'],
            remove_longer_than=self.config['remove_longer_than'],
            remove_wiki=self.config['remove_wiki'],
            dereify=self.config['dereify'],
            device=self.device)

    # Set up the inference object and create the gold data file for evaluation
    def load_eval_data(self):
        print('Loading eval data from', self.config['dev'])
        self.inference = Inference(model=self.model,
                                   tokenizer=self.tokenizer,
                                   device=self.device,
                                   num_beams=self.config['eval_beam_size'],
                                   batch_size=self.config['eval_batch_sents'],
                                   config=self.config)
        self.graphs_gold = read_raw_amr_data(
            self.config['dev'],
            use_recategorization=self.config['use_recategorization'],
            dereify=self.config['dereify'],
            remove_wiki=self.config['remove_wiki'])
        penman.dump(self.graphs_gold,
                    self.dev_gold_path,
                    indent=6,
                    model=amr_model)

    # Update the model's parameters once gradients have accumulated
    def step_optimizer(self):
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.config['grad_norm'])
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        self.scheduler.step()

    # Update tqdm progress bar description with loss values
    @staticmethod
    def set_pbar_desc_train(pbar, av_loss):
        desc = 'Loss: '
        if av_loss is None:
            desc += ' ' * 8
        else:
            desc += '%8.3f' % av_loss
        pbar.set_description(desc)
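The checkpoint housekeeping in save_and_remove_checkpoints above reduces to a small pattern: write the new best checkpoint, then delete the older files. Below is a self-contained sketch of that pattern with a placeholder directory and toy model; it is not the Trainer's real configuration.

import os
import torch

def save_best_checkpoint(model_dir, model, epoch, smatch):
    # List existing checkpoints before writing, so the new file is never removed.
    os.makedirs(model_dir, exist_ok=True)
    previous = [fn for fn in os.listdir(model_dir) if fn.endswith('.pt')]
    model_fn = 'checkpoint_epoch_%02d_smatch_%04d.pt' % (epoch, smatch * 10000)
    torch.save({'model': model.state_dict()}, os.path.join(model_dir, model_fn))
    # Remove the older checkpoints only after the new one has been written.
    for fn in previous:
        os.remove(os.path.join(model_dir, fn))

save_best_checkpoint('/tmp/amr-demo', torch.nn.Linear(4, 2), epoch=3, smatch=0.8123)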
Example #3
def train(rank, cfg: TrainConfig):
    if cfg.distributed.n_gpus_per_node > 1:
        init_process_group(backend=cfg.distributed.dist_backend,
                           init_method=cfg.distributed.dist_url,
                           world_size=cfg.distributed.n_nodes *
                           cfg.distributed.n_gpus_per_node,
                           rank=rank)

    device = torch.device(f'cuda:{rank:d}')

    model = ConvRNNEmbedder(cfg.model_cfg).to(device)
    loss_fn = GE2ELoss(device).to(device)

    logging.info(f"Initialized rank {rank}")

    if rank == 0:
        logging.getLogger().setLevel(logging.INFO)
        logging.info(f"Model initialized as:\n {model}")
        os.makedirs(cfg.checkpoint_path, exist_ok=True)
        logging.info(f"checkpoints directory : {cfg.checkpoint_path}")
        logging.info(
            f"Model has {sum([p.numel() for p in model.parameters()]):,d} parameters."
        )

    steps = 0
    if cfg.resume_checkpoint != '' and os.path.isfile(cfg.resume_checkpoint):
        state_dict = torch.load(cfg.resume_checkpoint, map_location=device)
        model.load_state_dict(state_dict['model_state_dict'])
        loss_fn.load_state_dict(state_dict['loss_fn_state_dict'])
        steps = state_dict['steps'] + 1
        last_epoch = state_dict['epoch']
        print(
            f"Checkpoint loaded from {cfg.resume_checkpoint}. Resuming training from {steps} steps at epoch {last_epoch}"
        )
    else:
        state_dict = None
        last_epoch = -1

    if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
        if rank == 0: logging.info("Multi-gpu detected")
        model = DDP(model, device_ids=[rank]).to(device)
        loss_fn = DDP(loss_fn, device_ids=[rank]).to(device)

    # lr=1.0: the LambdaLR schedule below appears to supply the absolute learning rate
    optim = torch.optim.AdamW(chain(model.parameters(), loss_fn.parameters()),
                              1.0,
                              betas=cfg.betas)
    if state_dict is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    train_df, valid_df = pd.read_csv(cfg.train_csv), pd.read_csv(cfg.valid_csv)

    trainset = UtteranceDS(train_df, cfg.sample_rate, cfg.n_uttr_per_spk)

    train_sampler = (DistributedSampler(trainset)
                     if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1
                     else None)

    train_loader = DataLoader(trainset,
                              num_workers=cfg.num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=cfg.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=SpecialCollater(
                                  cfg.min_seq_len, cfg.max_seq_len))

    if rank == 0:
        validset = UtteranceDS(valid_df, cfg.sample_rate, cfg.n_uttr_per_spk)
        validation_loader = DataLoader(validset,
                                       num_workers=cfg.num_workers,
                                       shuffle=False,
                                       sampler=None,
                                       batch_size=cfg.batch_size,
                                       pin_memory=False,
                                       drop_last=True,
                                       collate_fn=SpecialCollater(
                                           cfg.min_seq_len, cfg.max_seq_len))

        sw = SummaryWriter(os.path.join(cfg.checkpoint_path, 'logs'))

    total_iters = cfg.n_epochs * len(train_loader)

    def sched_lam(x):
        return lin_one_cycle(cfg.start_lr, cfg.max_lr, cfg.end_lr,
                             cfg.warmup_pct, total_iters, x)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optim,
                                                  lr_lambda=[sched_lam],
                                                  last_epoch=steps - 1)

    if state_dict is not None:
        scheduler.load_state_dict(state_dict['scheduler_state_dict'])

    if cfg.fp16:
        scaler = GradScaler()
        if state_dict is not None and 'scaler_state_dict' in state_dict:
            scaler.load_state_dict(state_dict['scaler_state_dict'])

    model.train()

    if rank == 0:
        mb = master_bar(range(max(0, last_epoch), cfg.n_epochs))
        smooth_loss = None
    else:
        mb = range(max(0, last_epoch), cfg.n_epochs)

    for epoch in mb:
        if rank == 0:
            start = time.time()
            mb.write("Epoch: {}".format(epoch + 1))

        if cfg.distributed.n_gpus_per_node * cfg.distributed.n_nodes > 1:
            train_sampler.set_epoch(epoch)

        if rank == 0:
            pb = progress_bar(enumerate(train_loader),
                              total=len(train_loader),
                              parent=mb)
        else:
            pb = enumerate(train_loader)

        for i, batch in pb:
            if rank == 0: start_b = time.time()
            x, xlen = batch
            x = x.to(device, non_blocking=True)
            xlen = xlen.to(device, non_blocking=True)

            optim.zero_grad()

            with torch.cuda.amp.autocast(enabled=cfg.fp16):
                embeds = model(x, xlen)
                loss = loss_fn(embeds)
            if cfg.fp16:
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                scaler.step(optim)
                scaler.update()
            else:
                loss.backward()
                gnorm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), cfg.grad_clip)
                torch.nn.utils.clip_grad.clip_grad_norm_(
                    loss_fn.parameters(), cfg.grad_clip / 2)
                optim.step()

            if rank == 0:
                # Exponentially smoothed loss for logging
                if smooth_loss is None:
                    smooth_loss = float(loss.item())
                else:
                    smooth_loss += 0.1 * (float(loss.item()) - smooth_loss)
                # STDOUT logging
                if steps % cfg.stdout_interval == 0:
                    mb.write('steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
                            format(steps, loss.item(), time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
                    mb.child.comment = 'steps : {:,d}, loss : {:4.3f}, sec/batch : {:4.3f}'. \
                            format(steps, loss.item(), time.time() - start_b)
                    # mb.write(f"lr = {float(optim.param_groups[0]['lr'])}")

                # checkpointing
                if steps % cfg.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = f"{cfg.checkpoint_path}/ckpt_{steps:08d}.pt"
                    is_ddp = (cfg.distributed.n_gpus_per_node *
                              cfg.distributed.n_nodes > 1)
                    save_model = model.module if is_ddp else model
                    save_loss_fn = loss_fn.module if is_ddp else loss_fn
                    torch.save(
                        {
                            'model_state_dict': save_model.state_dict(),
                            'loss_fn_state_dict': save_loss_fn.state_dict(),
                            'optim_state_dict': optim.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'scaler_state_dict': scaler.state_dict() if cfg.fp16 else None,
                            'steps': steps,
                            'epoch': epoch,
                        }, checkpoint_path)
                    logging.info(f"Saved checkpoint to {checkpoint_path}")

                # Tensorboard summary logging
                if steps % cfg.summary_interval == 0:
                    sw.add_scalar("training/loss_smooth", smooth_loss, steps)
                    sw.add_scalar("training/loss_raw", loss.item(), steps)
                    ge2e = (loss_fn.module if cfg.distributed.n_gpus_per_node *
                            cfg.distributed.n_nodes > 1 else loss_fn)
                    sw.add_scalar("ge2e/w", float(ge2e.w.item()), steps)
                    sw.add_scalar("ge2e/b", float(ge2e.b.item()), steps)
                    sw.add_scalar("opt/lr", float(optim.param_groups[0]['lr']),
                                  steps)
                    sw.add_scalar('opt/grad_norm', float(gnorm), steps)

                # Validation
                if steps % cfg.validation_interval == 0 and steps != 0:
                    model.eval()
                    loss_fn.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    flat_embeds = []
                    flat_lbls = []
                    with torch.no_grad():
                        for j, batch in progress_bar(
                                enumerate(validation_loader),
                                total=len(validation_loader),
                                parent=mb):
                            x, xlen = batch
                            embeds = model(x.to(device), xlen.to(device))
                            val_err_tot += loss_fn(embeds)

                            if j <= 2:
                                lbls = [
                                    f'spk-{j}-{indr:03d}'
                                    for indr in range(cfg.batch_size)
                                    for _ in range(cfg.n_uttr_per_spk)
                                ]
                                fembeds = embeds.view(
                                    cfg.batch_size * cfg.n_uttr_per_spk,
                                    cfg.model_cfg.fc_dim)
                                flat_embeds.append(fembeds.cpu())
                                flat_lbls.extend(lbls)
                            elif j == 3:
                                flat_embeds = torch.cat(flat_embeds, dim=0)
                                sw.add_embedding(flat_embeds,
                                                 metadata=flat_lbls,
                                                 global_step=steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/loss", val_err, steps)
                        mb.write(
                            f"validation run complete at {steps:,d} steps. validation loss: {val_err:5.4f}"
                        )

                    model.train()
                    loss_fn.train()
                    sw.add_scalar("memory/max_allocated_gb",
                                  torch.cuda.max_memory_allocated() / 1e9,
                                  steps)
                    sw.add_scalar("memory/max_reserved_gb",
                                  torch.cuda.max_memory_reserved() / 1e9,
                                  steps)
                    torch.cuda.reset_peak_memory_stats()
                    torch.cuda.reset_accumulated_memory_stats()

            steps += 1
            scheduler.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
    if rank == 0:
        # val_err holds the loss from the most recent validation pass above
        sw.add_hparams(flatten_cfg(cfg),
                       metric_dict={'validation/loss': val_err},
                       run_name=f'run-{cfg.checkpoint_path}')
    print("Training completed!")
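train() above is written to be called once per process with that process's rank. A hypothetical launcher is sketched below; it assumes train and TrainConfig are importable from this module, and the default-constructed cfg is only a placeholder for however the config is really populated.

import torch.multiprocessing as mp

if __name__ == '__main__':
    cfg = TrainConfig()  # placeholder; real runs would load/override fields
    n_gpus = cfg.distributed.n_gpus_per_node
    if n_gpus > 1:
        # One process per local GPU; mp.spawn passes the rank as the first argument.
        mp.spawn(train, nprocs=n_gpus, args=(cfg,))
    else:
        train(0, cfg)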
Example #4
class Amp:
    def __init__(
        self,
        enabled: bool = False,
        max_norm: Optional[float] = None,
    ) -> None:
        self.grad_scaler = GradScaler(enabled=enabled)
        self.enabled = enabled
        self.max_norm = max_norm

        _logger.info("amp: %s", self.enabled)
        if self.max_norm:
            _logger.info(
                "gradient clipping is enabled; remember to pass parameters when calling")

    def autocast(self):
        return autocast(enabled=self.enabled)

    def scale(self, outputs: TensorOrIterableTensors) -> TensorOrIterableTensors:
        return self.grad_scaler.scale(outputs)

    def unscale_(self, optimizer: Optimizer):
        return self.grad_scaler.unscale_(optimizer)

    def step(self, optimizer: Optimizer, *args, **kwargs):
        return self.grad_scaler.step(optimizer, *args, **kwargs)

    def update(self, new_scale: Union[float, Tensor, None] = None):
        return self.grad_scaler.update(new_scale=new_scale)

    def clip_grad_norm_(self, params: TensorOrIterableTensors):
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)

    def state_dict(self) -> dict:
        return self.grad_scaler.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.grad_scaler.load_state_dict(state_dict)

    def __call__(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
        zero_grad_set_to_none: bool = False,
    ):
        self.scale(loss).backward()

        if self.max_norm is not None:
            assert parameters is not None
            self.unscale_(optimizer)
            self.clip_grad_norm_(parameters)

        self.grad_scaler.step(optimizer)
        self.grad_scaler.update()
        optimizer.zero_grad(set_to_none=zero_grad_set_to_none)

    def backward(
        self,
        loss: Tensor,
        optimizer: torch.optim.Optimizer,
        parameters: Optional[TensorOrIterableTensors] = None,
    ):
        return self(loss, optimizer, parameters=parameters)
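A hedged usage sketch of the Amp helper in a single training step. The toy model, data, and max_norm value are illustrative assumptions, and enabled is tied to CUDA availability so the sketch also runs on CPU (where scaling degrades to a plain optimizer step).

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
amp = Amp(enabled=(device == 'cuda'), max_norm=1.0)

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)
with amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)
# __call__ scales the loss, unscales and clips (because max_norm is set),
# then steps the optimizer, updates the scaler, and zeroes the gradients.
amp(loss, optimizer, parameters=model.parameters())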