Example #1
0
 def save_checkpoint(self, filename, extra_state):
     """Write the full training state to ``filename``.

     Only the worker for which ``distributed_utils.is_master(self.args)`` is
     true writes anything; every other worker returns immediately.
     ``extra_state`` is mutated in place: the trainer's meters are stored
     under ``'train_meters'`` before the state is passed to
     ``utils.save_state``.
     """
     if not distributed_utils.is_master(self.args):
         return  # one checkpoint is enough; leave the writing to the master
     extra_state['train_meters'] = self.meters
     utils.save_state(
         filename,
         self.args,
         self.model,
         self.criterion,
         self.optimizer,
         self.lr_scheduler,
         self._num_updates,
         self._optim_history,
         extra_state,
     )
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save the training state under every checkpoint name that applies.

    Does nothing when ``args.no_save`` is set or on non-master workers.
    The lowest validation loss seen so far is remembered across calls in the
    function attribute ``save_checkpoint.best``.

    Args:
        args: namespace providing ``no_save``, ``save_dir``, ``save_interval``,
            ``save_interval_updates``, ``no_epoch_checkpoints`` and
            ``keep_interval_updates``.
        trainer: object exposing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: object exposing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation loss, or ``None`` when no validation ran.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Map each candidate file name to whether it should be written this call.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    # Evaluated before ``best`` is updated below, hence the strict comparison.
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # always refreshed

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    # ``best`` does not exist until the first validation loss has been seen
    # (e.g. a mid-epoch save before any validation); only record it once it
    # is available instead of failing with AttributeError.
    extra_state = {}
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best
    extra_state['train_iterator'] = epoch_itr.state_dict()
    extra_state['val_loss'] = val_loss

    selected = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in selected:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        stale = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in stale[args.keep_interval_updates:]:
            # the file may already be gone; skip it rather than crash
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #3
0
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and collect the tracked metric.

    Returns a list with one entry per element of ``subsets``: the value of
    ``args.best_checkpoint_metric`` in that subset's validation stats.
    """
    if args.fixed_validation_seed is not None:
        # reseed so that every validation pass starts from the same seed
        utils.set_torch_seed(args.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the (unshuffled) batch iterator for this subset
        batches = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            batches = utils.tpu_data_loader(batches)

        # only the master worker is given a tensorboard directory
        if distributed_utils.is_master(args):
            tb_logdir = args.tensorboard_logdir
        else:
            tb_logdir = None
        fallback_format = "simple" if args.no_progress_bar else "tqdm"

        progress = progress_bar.progress_bar(
            batches,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=tb_logdir,
            default_log_format=fallback_format,
        )

        # a fresh root aggregator keeps validation metrics out of the other
        # aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # report this subset's stats and remember the tracked metric
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])

    return valid_losses
Example #4
0
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
    """Wrap ``iterator`` in the progress bar selected by ``args.log_format``.

    When ``args.log_format`` is ``None`` it is filled in (on ``args`` itself)
    with ``no_progress_bar`` if ``args.no_progress_bar`` is set, otherwise
    with ``default``. A ``'tqdm'`` format is downgraded to ``'simple'`` when
    stderr is not a TTY.

    Raises:
        ValueError: if the resolved format is not one of ``'json'``,
            ``'none'``, ``'simple'`` or ``'tqdm'``.
    """
    if args.log_format is None:
        if args.no_progress_bar:
            args.log_format = no_progress_bar
        else:
            args.log_format = default

    if args.log_format == 'tqdm' and not sys.stderr.isatty():
        args.log_format = 'simple'

    fmt = args.log_format
    if fmt == 'tqdm':
        bar = tqdm_progress_bar(iterator, epoch, prefix)
    elif fmt == 'simple':
        bar = simple_progress_bar(iterator, epoch, prefix, args.log_interval)
    elif fmt == 'json':
        bar = json_progress_bar(iterator, epoch, prefix, args.log_interval)
    elif fmt == 'none':
        bar = noop_progress_bar(iterator, epoch, prefix)
    else:
        raise ValueError('Unknown log format: {}'.format(fmt))

    # tensorboard logging is layered on top, on the master worker only
    if args.tensorboard_logdir and distributed_utils.is_master(args):
        bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir)

    return bar
Example #5
0
def validate_save_and_evaluate_bleu(
    args,
    trainer,
    dataset,
    extra_state: Dict[str, Any],
    do_validate: bool,
    do_save: bool,
    do_eval_bleu: bool,
) -> Tuple[Optional[float], Optional[float], Optional[float], bool]:
    """Optionally validate, save a checkpoint and score BLEU.

    ``extra_state["val_loss"]` is always overwritten with the validation loss
    (``None`` when validation is skipped). Saving, and the BLEU evaluation
    that follows it, only happen on the master worker.

    Returns:
        ``(val_loss, val_ppl, val_bleu, stop)`` where ``stop`` combines the
        stop signals returned by validation and BLEU evaluation.
    """
    # validation on the validation subset
    val_loss, val_ppl, stop_due_to_val_loss = None, None, False
    if do_validate:
        val_loss, val_ppl, stop_due_to_val_loss = validate(
            args=args,
            trainer=trainer,
            dataset=dataset,
            subset=args.valid_subset,
            epoch=extra_state["epoch"],
        )
    extra_state["val_loss"] = val_loss

    val_bleu, stop_due_to_val_bleu = None, False
    should_save = do_save and distributed_utils.is_master(args)
    if should_save:
        save_checkpoint(trainer=trainer, args=args, extra_state=extra_state)
    if should_save and do_eval_bleu:
        # BLEU is only scored right after a checkpoint has been written
        val_bleu, stop_due_to_val_bleu = evaluate_bleu(
            args=args,
            dataset=dataset,
            epoch=extra_state["epoch"],
            offset=extra_state["batch_offset"],
        )

    should_stop = stop_due_to_val_loss or stop_due_to_val_bleu
    return (val_loss, val_ppl, val_bleu, should_stop)
Example #6
0
def main(args, init_distributed=False):
    """Set up the task, model and trainer, then run the epoch training loop.

    The loop runs while the learning rate stays above ``args.min_lr`` and
    neither ``args.max_epoch`` nor ``args.max_update`` has been reached.
    After each epoch it optionally validates, steps the LR scheduler with the
    first validation loss, and optionally saves a checkpoint.
    """
    utils.import_user_module(args)

    assert not (args.max_tokens is None and args.max_sentences is None), \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Device selection, seeding and (optionally) distributed initialisation
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    print(args)

    # Task setup, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Validation data is loaded here; training data is loaded further down,
    # based on the latest checkpoint
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=False, epoch=0)

    # Model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(
        n_params, n_trainable))

    # Trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens, args.max_sentences))

    # Restore the latest checkpoint (if any) together with its train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # A falsy limit (0 / None) means "no limit"
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr
           and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # one epoch of training
        train(args, trainer, task, epoch_itr)

        if args.disable_validation or epoch_itr.epoch % args.validate_interval != 0:
            valid_losses = [None]
        else:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)

        # only the first validation loss drives the learning rate
        first_loss = valid_losses[0]
        lr = trainer.lr_step(epoch_itr.epoch, first_loss)

        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             first_loss)

        # sharded data (':' in the data path): reload for the next epoch
        reload_dataset = ':' in getattr(args, 'data', '')
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #7
0
def validate_metric(args, trainer, task, epoch_itr, subsets):
    """Dump per-token model predictions for each validation subset to text files.

    Only the master worker does any work; every other worker returns right
    after printing its rank. For each subset the model is run over the
    batches (unshuffled, single shard) and, per output stream of the model,
    a file ``<save_dir>/<stream>.<subset>.txt`` is created through
    ``init_output_file`` and filled with the reference labels, predicted
    labels and per-token scores of every sequence.

    Args:
        args: namespace with the distributed settings, ``save_dir``, ``seed``
            and the ``*_valid`` batching options read below.
        trainer: object exposing the model as ``trainer.model``.
        task: task providing ``dataset()``, ``get_batch_iterator()`` and
            ``target_dictionary``.
        epoch_itr: only ``epoch_itr.epoch`` is read (for the progress bar).
        subsets: iterable of subset names to evaluate.

    Returns:
        None.
    """
    # when training with a distributed trainer, only the worker with
    # args.distributed_rank == 0 does the work
    print('args.distributed_rank', args.distributed_rank)
    print('args.distributed_world_size', args.distributed_world_size)
    if not distributed_utils.is_master(args):
        return

    for subset in subsets:
        # one open output handle per model output stream, created lazily on
        # the first batch (presumably file-like objects; they are written,
        # flushed, closed and their ``name`` is read below)
        model_output_file_list = []

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=None,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=1,
            shard_id=0,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix="valid on '{}' subset".format(subset),
            no_progress_bar='simple')
        cnt = 0
        try:
            for sample in progress:
                preds = []
                scores = []
                trainer.model.eval()
                sample = utils.move_to_cuda(sample)
                with torch.no_grad():
                    net_output = trainer.model(**sample['net_input'])

                if isinstance(net_output[0], list):
                    # several output streams: one file per stream, named by
                    # its 0-based index
                    if len(model_output_file_list) < len(net_output[0]):
                        for idx in range(len(net_output[0])):
                            model_output_file_list.append(
                                init_output_file(
                                    os.path.join(
                                        args.save_dir,
                                        '{}.{}.txt'.format(idx, subset)),
                                    task.target_dictionary))
                    for sub_net_output, sub_score in zip(net_output[0],
                                                         net_output[1]):
                        preds.append(sub_net_output)
                        scores.append(sub_score)
                else:
                    # single output stream: written to "1.<subset>.txt"
                    if len(model_output_file_list) == 0:
                        model_output_file_list.append(
                            init_output_file(
                                os.path.join(
                                    args.save_dir,
                                    '{}.{}.txt'.format('1', subset)),
                                task.target_dictionary))
                    preds.append(net_output[0])
                    scores.append(net_output[1])

                if sample.get('target', None) is not None:
                    target = trainer.model.get_targets(sample, net_output)
                    # clip the reference to the predicted length
                    if target.size(1) > preds[0].size(1):
                        target = target[:, :preds[0].size(1)]
                else:
                    target = torch.ones_like(preds[0])
                # positions where the first prediction is 0 are zeroed in the
                # target as well
                target = torch.where(preds[0] == 0, torch.zeros_like(preds[0]),
                                     target.int())
                assert len(preds) == len(scores) == len(model_output_file_list)
                for pred, score, fout in zip(preds, scores,
                                             model_output_file_list):
                    for i in range(pred.size(0)):
                        labels = []
                        pred_labels = []
                        # distribution dump is disabled; the line is still
                        # emitted (empty) to keep the file layout unchanged
                        pred_dists = []
                        pred_scores = []
                        for j in range(pred.size(1)):
                            # stop at the first padding position of the target
                            if target[i, j] == task.target_dictionary.pad():
                                break
                            labels.append(task.target_dictionary[target[i, j]])
                            pred_labels.append(
                                task.target_dictionary[pred[i, j]])
                            pred_scores.append(
                                str(round(score[i, j].item(), 5)))
                        fout.write('True      Labels:\t%s\n' % ' '.join(labels))
                        fout.write('Predicted Labels:\t%s\n' %
                                   ' '.join(pred_labels))
                        fout.write('Score:\t%s\n' % ' '.join(pred_scores))
                        fout.write('Predicted Distri:\t%s\n' %
                                   ' | '.join(pred_dists))
                        fout.flush()
                # batches are expected in id order with no gaps
                assert cnt == sample['id'][0]
                cnt += sample['id'].shape[0]
        finally:
            # close every handle even if a batch raised, so that no file
            # descriptor is leaked
            for fout in model_output_file_list:
                fout.close()

        for fout in model_output_file_list:
            utils.xprintln('valid metric %s done!' % fout.name)
Example #8
0
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses.

    Args:
        cfg: configuration; the ``common``, ``dataset``, ``optimization``,
            ``checkpoint`` and ``distributed_training`` groups are read.
        trainer: trainer whose ``train_step`` is called once per grouped batch.
        task: task forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator providing the batches for this epoch.

    Returns:
        ``(valid_losses, should_stop)`` as produced by the last call to
        ``validate_and_save``; ``([None], False)`` if the epoch yielded no
        batches at all.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    # per-epoch update frequency; the last entry is reused once the list
    # is exhausted
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq) else
                   cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    # external loggers (tensorboard / wandb / azureml) are enabled on the
    # master worker only
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    # Defaults for an epoch that yields no batches; without them the final
    # ``return`` would raise UnboundLocalError.
    valid_losses = [None]
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(cfg, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
Example #9
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses.

    Args:
        args: namespace with the training options read below
            (``update_freq``, ``curriculum``, ``log_interval``, ...).
        trainer: trainer whose ``train_step`` is called once per grouped batch.
        task: task forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator providing the batches for this epoch.

    Returns:
        ``(valid_losses, should_stop)`` as produced by the last call to
        ``validate_and_save``; ``([None], False)`` if the epoch yielded no
        batches at all.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # per-epoch update frequency; the last entry is reused once the list
    # is exhausted
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = utils.tpu_data_loader(itr)
    # tensorboard logging is enabled on the master worker only
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    # Defaults for an epoch that yields no batches; without them the final
    # ``return`` would raise UnboundLocalError.
    valid_losses = [None]
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(args, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
Example #10
0
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save the training state and prune old checkpoints.

    The best validation score seen so far is kept across calls in the
    function attribute ``save_checkpoint.best``; it is updated on every
    worker, even when nothing is written (``args.no_save`` or non-master).

    The state is written once, to the first applicable file name, and then
    copied to the remaining applicable names.

    Args:
        args: namespace with the checkpoint options read below.
        trainer: object exposing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: object exposing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation score, or ``None`` when no validation ran.
    """
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        # non-strict on purpose: ``best`` above already includes ``val_loss``,
        # so a new best compares equal to it
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    is_new_best = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best))

    # Map each candidate file name to whether it should be written this call.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    checkpoint_conds["checkpoint_best.pt"] = is_new_best
    # The score-stamped name can only be built when a score exists:
    # formatting ``None`` with ``{:.2f}`` raises TypeError, which used to
    # break every save made without a validation pass.
    if val_loss is not None and args.keep_best_checkpoints > 0:
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)] = is_new_best
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if checkpoints:
        # write once, then copy to the other names
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "| saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args.best_checkpoint_metric))
        # reversed when lower scores are better, so the best come first
        if not args.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #11
0
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save the training state under every applicable name and prune old files.

    Does nothing when ``args.no_save`` is set or on non-master workers.
    The lowest validation loss seen so far is remembered across calls in the
    function attribute ``save_checkpoint.best``.

    Args:
        args: namespace with the checkpoint options read below.
        trainer: object exposing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: object exposing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation loss, or ``None`` when no validation ran.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Map each candidate file name to whether it should be written this call.
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    # Evaluated before ``best`` is updated below, hence the strict comparison.
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and (not hasattr(save_checkpoint, 'best')
                                  or val_loss < save_checkpoint.best))
    checkpoint_conds['checkpoint_last.pt'] = True  # always refreshed

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    # ``best`` does not exist until the first validation loss has been seen
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    # Never empty: 'checkpoint_last.pt' is always selected.
    saved = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in saved:
        trainer.save_checkpoint(cp, extra_state)

    # The pruning lists below use their own name so that the final log line
    # still reports a file written by this call (reusing one variable made it
    # log a stale path, or fail with IndexError on an empty listing).
    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        stale = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in stale[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0 and epoch > 1:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        stale = utils.checkpoint_paths(args.save_dir,
                                       pattern=r'checkpoint(\d+)\.pt')
        for old_chk in stale[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    write_timer.stop()

    print(
        '| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'
        .format(saved[0], epoch, updates, write_timer.sum))
Example #12
0
def main(args, init_distributed=False):
    """Set up task, model and trainer, then score the validation subsets.

    After restoring the latest checkpoint, ``validate`` is run once and, for
    every validation subset, sacreBLEU, the ``compute_cvpr_bleu`` score and
    ROUGE are computed over its hypotheses/references and printed.
    """
    utils.import_user_module(args)

    assert not (args.max_tokens is None and args.max_sentences is None), \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Device selection, seeding and (optionally) distributed initialisation
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    logger.info(args)

    # Task setup, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load the validation data
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=False, epoch=0)

    # Model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('num. model params: {} (num. trained: {})'.format(
        n_params, n_trainable))

    # Trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens, args.max_sentences))

    # Restore the latest checkpoint (if any) together with its train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    # hypotheses are scored as-is when they are already tokenized
    if args.eval_tokenized_bleu:
        tokenize = 'none'
    else:
        tokenize = sacrebleu.DEFAULT_TOKENIZER
    hyps, refs = validate(args, trainer, task, epoch_itr, valid_subsets)

    for h, r, split in zip(hyps, refs, args.valid_subset.split(',')):
        assert len(h) == len(r)

        sacrebleu_score = sacrebleu.corpus_bleu(h, [r], tokenize=tokenize)
        bleu = compute_cvpr_bleu(h, r)
        rouge_score = rouge.rouge(h, r)

        print('{} set has {} samples,\n'
              'sacrebleu: {},\n'
              'CVPR BLEU scripts: {}\n'
              'CVPR ROUGE: {}'.format(split, len(h), sacrebleu_score, bleu,
                                      rouge_score))

        bleu_text = ' '.join(str(b) for b in bleu)
        print('performance: {:.2f} {}'.format(
            rouge_score['rouge_l/f_score'] * 100, bleu_text))
Example #13
0
def main(args, init_distributed=False):
    """Build the task, model and trainer, then train or run a LASER comparison.

    The function seeds NumPy/PyTorch, optionally initialises distributed
    training, loads the validation subsets, restores the latest checkpoint and
    then does one of the following:

    * optionally pre-trains the data actor via ``trainer.pretrain_LASER``;
    * if ``args.compare_laser`` is set, scores every training sample with
      ``trainer.data_actor``, writes the scores to ``en-ps.dds_score`` and
      prints the sum of squared differences to the LASER scores, then returns;
    * otherwise runs the epoch loop (train, validate, LR step, checkpoint).

    Args:
        args: parsed command-line namespace read throughout this function.
        init_distributed: when true, ``distributed_utils.distributed_init`` is
            called and its result is stored in ``args.distributed_rank``.

    Returns:
        None.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr, filtered_maxpos_indices = checkpoint_utils.load_checkpoint(
        args, trainer)

    # Pretrain the data actor with LASER scores; only the 'ave' (language)
    # actor model can be pretrained.
    if args.pretrain_laser and args.pretrain_data_actor and args.data_actor == 'ave':
        trainer.pretrain_LASER('en-ps.laser-score', epoch_itr)

    if args.compare_laser:
        # Compare the trained data actor's scores against the LASER labels.
        epoch_itr, indices = trainer.get_train_iterator(1)
        print('Number of Indices: ', len(indices))
        scores = collections.defaultdict(float)
        data_actor = trainer.data_actor
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=False,
            offset=0,
            datasize=-1,
        )
        for sample in itr:
            sample = trainer._prepare_sample(sample)
            # Only the first entry of the prepared sample dict is scored.
            sample = list(sample.values())[0]
            score = data_actor(sample).cpu().detach().numpy().tolist()
            sample_ids = sample['id'].data.cpu().numpy().ravel().tolist()
            for k, v in zip(sample_ids, score):
                scores[k] = float(v[0])

        scores = sorted(scores.items(), key=lambda x: x[0])
        print('Number of Indices in Scoring file: ', len(scores))
        # NOTE(review): hard-coded, machine-specific directory.
        path = '/home/wtan12/multiDDS/'
        # Iterate the file line by line so that a missing trailing newline
        # does not cause the last score to be dropped.
        with open(path + 'en-ps.laser-score', 'r') as r:
            laser_score = [line.rstrip('\n') for line in r]
        # Despite the label printed below, this accumulates the sum of
        # squared differences between the LASER score and the actor score.
        r2 = 0.0
        with open(path + 'en-ps.dds_score', 'w') as f:
            for k, v in scores:
                f.write(str(v) + '\n')
                truth = float(laser_score[k])
                r2 += (truth - v)**2
        print('R2 Score compared to LASER file: ', r2)
        return

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    if args.eval_bleu:
        generator = task.build_generator(args)
        # BLEU is a higher-is-better metric.
        args.maximize_best_checkpoint_metric = True
    else:
        generator = None
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update:
        # train for one epoch
        epoch_itr = train(args, trainer, task, epoch_itr, generator,
                          filtered_maxpos_indices)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)[0]
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #14
0
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Validate on every subset, dump the predictions and return the metrics.

    For each subset the predictions returned by ``trainer.valid_step`` are
    appended to a running list, written to a per-device text file and, on the
    data-parallel master, concatenated into a single ``.txt`` file.

    Args:
        cfg: configuration object; the ``dataset``, ``common``,
            ``distributed_training``, ``checkpoint`` and ``criterion`` groups
            are read here.
        trainer: trainer providing the validation iterator and ``valid_step``.
        task: not used inside this function.
        epoch_itr: only ``epoch_itr.epoch`` is read.
        subsets: names of the validation subsets to evaluate.

    Returns:
        One entry per subset: ``stats[cfg.checkpoint.best_checkpoint_metric]``.
    """
    fixed_seed = cfg.dataset.fixed_validation_seed
    if fixed_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(fixed_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    losses = []
    # NOTE(review): this list is shared by all subsets, so later subsets also
    # re-write the predictions of earlier ones — confirm this is intended.
    predictions = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Build the (unshuffled, fixed) data iterator for this subset.
        batches = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False,
            set_dataset_epoch=False,
        )
        if cfg.common.tpu:
            batches = utils.tpu_data_loader(batches)

        # Tensorboard / wandb targets are only passed on the master worker.
        if distributed_utils.is_master(cfg.distributed_training):
            tb_logdir = cfg.common.tensorboard_logdir
        else:
            tb_logdir = None
        fallback_format = "simple" if cfg.common.no_progress_bar else "tqdm"
        if distributed_utils.is_master(cfg.distributed_training):
            wandb_project = cfg.common.wandb_project
        else:
            wandb_project = None
        wandb_run_name = os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir))

        progress = progress_bar.progress_bar(
            batches,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=tb_logdir,
            default_log_format=fallback_format,
            wandb_project=wandb_project,
            wandb_run_name=wandb_run_name,
        )

        # A fresh root aggregator keeps validation metrics separate from the
        # other aggregators (e.g. the train meters).
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                step_predictions, _ = trainer.valid_step(sample)
                predictions.extend(step_predictions)

        # Per-device dump: "<save_predictions><device index>.txt".
        device_file = (cfg.criterion.save_predictions +
                       str(torch.cuda.current_device()) + ".txt")
        with open(device_file, "w") as device_out:
            for line in predictions:
                device_out.write(line)
                device_out.write("\n")

        # The master concatenates the per-device files into one file.
        # NOTE(review): no synchronisation with other workers is visible
        # here before their files are read — confirm this is safe.
        if trainer.is_data_parallel_master:
            with open(cfg.criterion.save_predictions + ".txt", "w") as merged:
                for device_idx in range(torch.cuda.device_count()):
                    part_name = (cfg.criterion.save_predictions +
                                 str(device_idx) + ".txt")
                    with open(part_name, "r") as part:
                        merged.write(part.read())

        # log validation stats
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return losses
Example #15
0
 def is_data_parallel_master(self):
     """Return ``distributed_utils.is_master`` for this object's distributed-training config."""
     dist_cfg = self.cfg.distributed_training
     return distributed_utils.is_master(dist_cfg)
Example #16
0
def save_checkpoint_bleu(args, trainer, epoch_itr, valid_losses, valid_bleus,
                         valid_select, begin):
    """Save checkpoints, tracking the best BLEU seen for ``valid_select``.

    The best BLEU so far is stored on the function itself as
    ``save_checkpoint_bleu.best_bleu``.

    Args:
        args: namespace providing ``no_save``, ``no_epoch_checkpoints``,
            ``save_interval``, ``save_interval_updates``, ``save_dir``,
            ``keep_interval_updates`` and ``keep_last_epochs``.
        trainer: object providing ``get_num_updates`` and ``save_checkpoint``.
        epoch_itr: epoch iterator (``epoch``, ``end_of_epoch``, ``state_dict``).
        valid_losses: mapping from domain to validation loss; must contain
            every key of ``valid_bleus``.
        valid_bleus: mapping from domain to validation BLEU.
        valid_select: the domain whose BLEU selects ``checkpoint_best_bleu.pt``.
            It may be absent from ``valid_bleus``; then no best checkpoint is
            written and the stored best value is left untouched.
        begin: when truthy, the call is treated as an end-of-epoch save
            without consulting ``epoch_itr.end_of_epoch()``.

    Returns:
        None.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return
    epoch = epoch_itr.epoch
    end_of_epoch = True if begin else epoch_itr.end_of_epoch()

    updates = trainer.get_num_updates()

    has_selected = valid_select in valid_bleus

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(
        epoch, updates)] = (not end_of_epoch and args.save_interval_updates > 0
                            and updates % args.save_interval_updates == 0)
    # Decided against the previous best, i.e. before it is updated below.
    checkpoint_conds['checkpoint_best_bleu.pt'] = (
        has_selected
        and (not hasattr(save_checkpoint_bleu, 'best_bleu')
             or valid_bleus[valid_select] > save_checkpoint_bleu.best_bleu))

    checkpoint_conds[
        'checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    # Only look up valid_bleus[valid_select] once we know the key exists;
    # using it as an eager getattr default raised KeyError when it was absent.
    if has_selected:
        selected_bleu = valid_bleus[valid_select]
        prev_best_bleu = getattr(save_checkpoint_bleu, 'best_bleu',
                                 selected_bleu)
        save_checkpoint_bleu.best_bleu = max(selected_bleu, prev_best_bleu)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
    }
    for domain, bleu_domain in valid_bleus.items():
        extra_state['valid_loss_' + domain] = valid_losses[domain]
        extra_state['valid_bleu_' + domain] = bleu_domain

    if hasattr(save_checkpoint_bleu, 'best_bleu'):
        extra_state['best_bleu'] = save_checkpoint_bleu.best_bleu

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir,
                                             pattern=r'checkpoint(\d+)\.pt')
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #17
0
def main(args, init_distributed=False):
    """Entry point: build task, model and trainer, then run the epoch loop.

    Args:
        args: parsed command-line namespace read throughout this function.
        init_distributed: when true, ``distributed_utils.distributed_init`` is
            called and its result is stored in ``args.distributed_rank``.

    Returns:
        None.
    """
    utils.import_user_module(args)

    assert not (args.max_tokens is None and args.max_sentences is None), \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # CUDA device selection, RNG seeding and (optional) distributed init.
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    logger.info(args)

    # Task setup; the validation subsets are loaded right away, the training
    # data is loaded later based on the latest checkpoint.
    task = tasks.setup_task(args)
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=False, epoch=1)

    # Model and criterion.
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    criterion_name = criterion.__class__.__name__
    logger.info('model {}, criterion {}'.format(args.arch, criterion_name))
    num_params = sum(p.numel() for p in model.parameters())
    num_trained = sum(p.numel() for p in model.parameters()
                      if p.requires_grad)
    logger.info('num. model params: {} (num. trained: {})'.format(
        num_params, num_trained))

    # Quantization is optional and driven by a config path.
    if args.quantization_config_path is None:
        quantizer = None
    else:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )

    # The Megatron trainer is used whenever the model-parallel size is not 1.
    if args.model_parallel_size != 1:
        trainer = MegatronTrainer(args, task, model, criterion)
    else:
        trainer = Trainer(args, task, model, criterion, quantizer)

    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens, args.max_sentences))

    # Restore the latest checkpoint (if any) and its train iterator.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small or a limit is reached.
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # One epoch of training; returns the validation losses.
        valid_losses = train(args, trainer, task, epoch_itr, max_update)
        if should_stop_early(args, valid_losses[0]):
            break
        if trainer.get_num_updates() >= max_update:
            break

        # Only the first validation loss drives the LR schedule.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        next_epoch = epoch_itr.next_epoch_idx
        # Sharded data: the dataset must be reloaded for the next epoch.
        reload_dataset = os.pathsep in getattr(args, 'data', '')
        epoch_itr = trainer.get_train_iterator(
            next_epoch,
            load_dataset=reload_dataset,
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example #18
0
def validate(args, trainer, task, epoch_itr, subsets, prune=-1):
    """Evaluate the model on the validation set(s) and return the losses.

    On rank 0, and unless ``args.eval_mode == 'time'``, a template file is
    opened in ``args.save_dir`` for each subset and handed to
    ``task.write_lambda`` / ``task.write_template``.

    Args:
        args: namespace with the validation, logging and distributed options
            read below.
        trainer: trainer providing the model, ``valid_step`` and the update
            counter.
        task: task providing datasets, batch iterators and template writers.
        epoch_itr: only ``epoch_itr.epoch`` is read.
        subsets: names of the validation subsets to evaluate.
        prune: when positive, passed to the model's ``set_prune_index`` and
            the resulting index map is installed on the task.

    Returns:
        One entry per subset: ``stats[args.best_checkpoint_metric]``.
    """

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        # added by Junxian
        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # not write templates for time profiling
        write_template_flag = args.eval_mode != 'time'

        # only one worker deals with the template file in DDP
        fout = None
        if args.distributed_rank == 0 and write_template_flag:
            print('write template files')

            if args.eval_mode == 'none':
                template_name = 'templates_{}_{}.txt'.format(
                    epoch_itr.epoch, trainer.get_num_updates())
            else:
                template_name = 'templates_eval_{}.txt'.format(subset)
            fout = open(os.path.join(args.save_dir, template_name), 'w')

        # try/finally guarantees the template file is closed even when
        # validation raises part-way through.
        try:
            if fout is not None and prune <= 0:
                task.write_lambda(fout, trainer.get_model())

            # create a new root metrics aggregator so validation metrics
            # don't pollute other aggregators (e.g., train meters)
            with metrics.aggregate(new_root=True) as agg:
                for sample in progress:
                    trainer.valid_step(sample, split=subset)

                    # added by Junxian
                    # NOTE(review): in 'time' mode this is still called on
                    # rank 0 with fout=None — confirm write_template
                    # tolerates that.
                    if args.distributed_rank == 0:
                        task.write_template(sample, trainer.get_model(), fout)
        finally:
            if fout is not None:
                fout.close()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
Example #19
0
def main(cfg: DictConfig) -> None:
    """Build task, model and trainer from ``cfg`` and run the training loop.

    Args:
        cfg: structured configuration. An ``argparse.Namespace`` is also
            accepted and converted with ``convert_namespace_to_omegaconf``.

    Returns:
        None.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    # The format string previously ended in a stray ")".
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate drops to min_lr, max_epoch is exceeded,
    # or train() asks to stop.
    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #20
0
def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'):
    """Run ``trainer.valid_iw_step`` over the validation subsets and log stats.

    Args:
        args: namespace with the validation, logging and distributed options
            read below.
        trainer: trainer providing the model, ``valid_iw_step`` and the
            update counter.
        task: task providing datasets, batch iterators and index-map hooks.
        epoch_itr: not used inside this function.
        subsets: names of the validation subsets to evaluate.
        prune: when positive, a prune index is set on the model before each
            subset and reset afterwards.
        mode: forwarded to ``trainer.valid_iw_step``; ``'none'`` and
            ``'time'`` skip evaluation entirely.

    Returns:
        ``[0]`` when evaluation is skipped. Otherwise an empty list, because
        nothing is appended to it (the append below is disabled).
    """
    # NOTE(review): ``args.criterior`` looks like a misspelling of
    # ``args.criterion`` — confirm which attribute the parser defines.
    if mode in ('none', 'time') or args.criterior == 'lm_baseline':
        return [0]

    # top k instead of sampling to approximate sum of prototypes for evaluation
    for name in subsets:
        task.dataset(name).set_sampling(False)

    seed = args.fixed_validation_seed
    if seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(seed)

    valid_losses = []
    for subset in subsets:
        # Build the unshuffled batch iterator for this subset.
        dataset = task.dataset(subset)
        max_positions = utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        )
        batch_iterator = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        itr = batch_iterator.next_epoch_itr(shuffle=False)

        # Tensorboard logging only on the master worker.
        if distributed_utils.is_master(args):
            tb_logdir = args.tensorboard_logdir
        else:
            tb_logdir = None
        fallback_format = 'simple' if args.no_progress_bar else 'tqdm'
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=tb_logdir,
            default_log_format=fallback_format,
        )

        pruning = prune > 0
        if pruning:
            task.set_index_map(trainer.get_model().set_prune_index(prune))

        # A fresh root aggregator keeps these metrics separate from the
        # other aggregators (e.g. the train meters).
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_iw_step(sample, mode=mode)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag='valid_iw', step=trainer.get_num_updates())

        # Appending stats[args.best_checkpoint_metric] is intentionally
        # disabled here, so valid_losses stays empty.

        if pruning:
            trainer.get_model().reset_prune_index()
            task.reset_index_map()

    return valid_losses
Example #21
0
def _read_stripped_lines(path):
    """Return every line of ``path`` with its trailing newline removed."""
    with open(path, 'r') as f:
        # rstrip keeps the last character intact when the final line has no
        # trailing newline (slicing with [:-1] would chop it).
        return [line.rstrip('\n') for line in f]


def _generate_hypotheses(bart, args, source_path, source2_path, hypo_path,
                         bsz=8):
    """Decode ``source_path`` in batches and write one hypothesis per line.

    ``source2_path`` supplies the second view; it is passed to
    ``bart.sample`` only when ``args.multi_views`` is set.
    """
    decode_kwargs = dict(beam=4, lenpen=2.0, max_len_b=100, min_len=5,
                         no_repeat_ngram_size=3)
    with open(source_path) as source, open(source2_path) as source2, \
            open(hypo_path, 'wt', encoding='utf-8') as fout:
        s1 = [line.strip() for line in source]
        s2 = [line.strip() for line in source2]
        if len(s2) < len(s1):
            raise ValueError(
                '{} has fewer lines than {}'.format(source2_path, source_path))

        for start in tqdm(range(0, len(s1), bsz)):
            slines = s1[start:start + bsz]
            # Every batch, including the final partial one, is decoded
            # without building autograd graphs.
            with torch.no_grad():
                if args.multi_views:
                    hypotheses_batch = bart.sample(
                        slines, sentences2=s2[start:start + bsz],
                        balance=True, **decode_kwargs)
                else:
                    hypotheses_batch = bart.sample(slines, **decode_kwargs)
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
            fout.flush()


def _rouge_scores(hyp_path, ref_path):
    """Return the averaged ROUGE scores of ``hyp_path`` against ``ref_path``."""
    hypothesis = _read_stripped_lines(hyp_path)
    reference = _read_stripped_lines(ref_path)
    return Rouge().get_scores(hypothesis, reference, avg=True)


def main(args, init_distributed=False):
    """Train the model and, after every epoch, decode and score val/test sets.

    After each training epoch the current model is wrapped in a
    ``BARTHubInterface``, the validation and test sources under ``../data``
    are decoded to ``.hypo`` files in the working directory, and their ROUGE
    scores are printed. A checkpoint is saved between the two evaluations
    when the epoch matches ``args.save_interval``.

    Args:
        args: parsed command-line namespace read throughout this function.
        init_distributed: when true, ``distributed_utils.distributed_init`` is
            called and its result is stored in ``args.distributed_rank``.

    Returns:
        None.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    print(args.multi_views)

    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # Wrap the current model for generation.
        bart = BARTHubInterface(args, task, trainer.model).cuda()
        bart.eval()

        print("Test on val set: ")
        val_hyp_path = './val_best_multi_attn_' + str(args.lr_weight) + '_.hypo'
        _generate_hypotheses(
            bart, args,
            '../data/val_sent_trans_cons_label.source',
            '../data/val_sent_c99_label.source',
            val_hyp_path,
        )
        print("Val", _rouge_scores(
            val_hyp_path, '../data/val_sent_trans_cons_label.target'))

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Test on testing set: ")
        test_hyp_path = './test_best_multi_attn_' + str(args.lr_weight) + '_.hypo'
        _generate_hypotheses(
            bart, args,
            '../data/test_sent_trans_cons_label.source',
            '../data/test_sent_c99_label.source',
            test_hyp_path,
        )
        print('Test', _rouge_scores(
            test_hyp_path, '../data/test_sent_trans_cons_label.target'))

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example #22
0
def main(args, init_distributed=False):
    """Build the task, model and trainer, then either evaluate or train.

    When ``args.eval_mode`` is anything other than ``'none'`` the function
    only runs the validation passes under ``torch.no_grad()`` and returns
    without training. Otherwise it trains one epoch at a time while the
    learning rate stays above ``args.min_lr``, the next epoch index does not
    exceed ``args.max_epoch`` and the number of updates stays below
    ``args.max_update``; ``should_stop_early`` can end the loop sooner.

    Args:
        args: parsed command-line namespace.
        init_distributed: when true, ``distributed_utils.distributed_init``
            is called and its result is stored in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    # Evaluation-only mode: run the validation passes and skip training.
    if args.eval_mode != 'none':
        start_val_time = time.time()
        with torch.no_grad():
            # The plain validation pass is skipped in 'entropy' mode.
            if args.eval_mode != 'entropy':
                _ = validate(args, trainer, task, epoch_itr, valid_subsets,
                             args.prune_num)
            print('elapsed time (seconds): {}'.format(time.time() -
                                                      start_val_time))

            _ = validate_iw(args,
                            trainer,
                            task,
                            epoch_itr,
                            valid_subsets,
                            args.prune_num,
                            mode=args.eval_mode)
        return

    while (lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )

    # Stop the stopwatch before reading it, so the reported duration covers
    # the training run that just finished.
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))

    # _ = validate_iw(args, trainer, task, epoch_itr, valid_subsets)
Example #23
0
 def is_data_parallel_master(self):
     """Return the result of ``distributed_utils.is_master`` for ``self.args``."""
     master = distributed_utils.is_master(self.args)
     return master
Example #24
0
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task whose ``begin_epoch`` hook is called before training.
        epoch_itr: epoch iterator that supplies this epoch's batches.
        max_update: stop once this many updates have been performed.
        model: module whose sub-modules are scanned for ``selfattn_norm`` /
            ``endeattn_norm`` attributes on the first batch. When ``None``,
            ``trainer.get_model()`` is used instead.

    Returns:
        The value of the last ``validate_and_save`` call, or ``[None]`` when
        no training step completed.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    # Without an explicit model, inspect the one the trainer holds rather
    # than failing on ``None.modules()``.
    if model is None:
        model = trainer.get_model()

    valid_subsets = args.valid_subset.split(',')

    # Defaults for the case where the loop body never runs to completion
    # (empty iterator, or every step skipped), so the code after the loop
    # does not hit unbound names.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()

    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        # Once per epoch (on the first batch) dump the attention norms that
        # the modules expose.
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                self_norm = getattr(m, 'selfattn_norm', None)
                if self_norm is not None:
                    selfattn_norm.append(self_norm)
                ende_norm = getattr(m, 'endeattn_norm', None)
                if ende_norm is not None:
                    endeattn_norm.append(ende_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
Example #25
0
def main(args):
    """Set up task, model, criterion and trainer, then train epoch by epoch.

    Training continues while the learning rate is above ``args.min_lr`` and
    the next epoch index does not exceed ``args.max_epoch`` (unbounded when
    that option is falsy); ``train`` may also request an early stop.
    """
    utils.import_user_module(args)

    batch_size_given = (
        args.max_tokens is not None or args.max_sentences is not None
    )
    assert batch_size_given, "Must specify batch size either with --max-tokens or --max-sentences"

    metrics.reset()

    # Seed both numpy and torch from the same value.
    seed = args.seed
    np.random.seed(seed)
    utils.set_torch_seed(seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Record the full configuration.
    logger.info(args)

    # Task first (translation, language modeling, ...).
    task = tasks.setup_task(args)

    # Only validation splits are loaded here; training data is loaded later,
    # based on the latest checkpoint.
    for split_name in args.valid_subset.split(","):
        task.load_dataset(split_name, combine=False, epoch=1)

    # Model and criterion, followed by a short summary in the log.
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion,
                                            criterion.__class__.__name__))
    total_params = sum(p.numel() for p in model.parameters())
    trained_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info("num. model params: {} (num. trained: {})".format(
        total_params, trained_params))

    # Quantization is optional.
    quantizer = (
        None
        if args.quantization_config_path is None
        else quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    )

    # Pick the trainer implementation from the model-parallel size.
    if args.model_parallel_size != 1:
        trainer = MegatronTrainer(args, task, model, criterion)
    else:
        trainer = Trainer(args, task, model, criterion, quantizer)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences))

    # Resume from the latest checkpoint when one exists; this also yields the
    # matching train iterator.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    epoch_limit = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while True:
        # Stop when the learning rate got too small or the epoch budget is used up.
        if not (lr > args.min_lr and epoch_itr.next_epoch_idx <= epoch_limit):
            break

        # One epoch of training.
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # The learning-rate schedule only looks at the first validation loss.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        next_epoch = epoch_itr.next_epoch_idx
        epoch_itr = trainer.get_train_iterator(
            next_epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
        )

    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #26
0
def main(args):
    """Train a model and record per-epoch loss / perplexity curves.

    Besides the usual setup (task, model, criterion, trainer, checkpoint
    restore) this entry point writes one CSV row per epoch to
    ``<save_dir>/train_logs.csv`` and to
    ``<jason_log_dir>/train_logs_backup.csv``, and redraws the loss and
    perplexity plots in ``jason_log_dir`` after every epoch that produced
    both training and validation statistics.

    Training continues while the learning rate is above ``args.min_lr`` and
    the next epoch index does not exceed ``args.max_epoch``; ``train`` may
    also request an early stop.
    """
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)
        checkpoint_utils.verify_checkpoint_directory(args.jason_log_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    ##### begin jason #####
    updates_list = []
    train_loss_list = []
    train_ppl_list = []
    val_loss_list = []
    val_ppl_list = []
    train_uid_loss_list = []
    val_uid_loss_list = []

    # Both CSV files share one header that matches the seven columns written
    # per row.
    csv_header = 'updates,train_loss,train_ppl,val_loss,val_ppl,train_uid_loss,val_uid_loss'
    backup_writefile = os.path.join(args.jason_log_dir, 'train_logs_backup.csv')
    plot_prefix = args.jason_log_dir.split('/')[-1]

    def append_to_backup(line):
        # Append (never truncate) so rows from earlier runs are kept; plain
        # file I/O avoids building shell commands from paths and values.
        with open(backup_writefile, 'a') as backup:
            backup.write(line + '\n')

    # The context manager guarantees the main log is closed even when
    # training raises.
    with open(os.path.join(args.save_dir, 'train_logs.csv'), 'w') as log_writer:
        log_writer.write(csv_header + '\n')
        log_writer.flush()
        append_to_backup(csv_header)
        ##### end jason #####

        while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
            # train for one epoch
            valid_losses, should_stop, train_stats, valid_stats = train(args, trainer, task, epoch_itr)
            logger.debug("valid stats: %s, train stats: %s", valid_stats, train_stats)

            ##### begin jason #####
            if train_stats and valid_stats:
                # -1 marks "no uid loss reported" on either side.
                train_uid_loss = train_stats.get('uid_loss', -1)
                valid_uid_loss = valid_stats.get('uid_loss', -1)

                updates_list.append(train_stats['num_updates'])
                train_loss_list.append(train_stats['loss'])
                train_ppl_list.append(train_stats['ppl'])
                val_loss_list.append(valid_stats['loss'])
                val_ppl_list.append(valid_stats['ppl'])
                train_uid_loss_list.append(train_uid_loss)
                val_uid_loss_list.append(valid_uid_loss)

                log_line = f"{train_stats['num_updates']},{train_stats['loss']},{train_stats['ppl']},{valid_stats['loss']},{valid_stats['ppl']},{train_uid_loss},{valid_uid_loss}"
                log_writer.write(f"{log_line}\n")
                # Flush per row so the log is usable while training is running.
                log_writer.flush()
                append_to_backup(log_line)

                best_val_loss = min(val_loss_list)
                best_val_loss_idx = val_loss_list.index(best_val_loss)
                updates_to_best_val_loss = updates_list[best_val_loss_idx]
                train_loss_at_best_val_loss = train_loss_list[best_val_loss_idx]
                # Perplexity recorded at the epoch with the lowest dev loss.
                best_val_ppl = val_ppl_list[best_val_loss_idx]

                jasons_vis.plot_jasons_lineplot(
                    x_list = updates_list,
                    y_list_list = [train_loss_list, val_loss_list, train_uid_loss_list, val_uid_loss_list],
                    y_labels_list = ['train', 'dev', 'train uid', 'dev uid'],
                    x_ax_label = "Updates",
                    y_ax_label = "Loss",
                    title = f"dev_l={best_val_loss} updates={updates_to_best_val_loss} train_l={train_loss_at_best_val_loss}",
                    output_png_path = os.path.join(args.jason_log_dir, f"{plot_prefix}_loss.png"),
                )
                jasons_vis.plot_jasons_lineplot(
                    x_list = updates_list,
                    y_list_list = [train_ppl_list, val_ppl_list],
                    y_labels_list = ['train', 'dev'],
                    x_ax_label = "Updates",
                    y_ax_label = "Perplexity",
                    title = f" best_val_ppl={best_val_ppl} " + args.jason_log_dir[:20],
                    output_png_path = os.path.join(args.jason_log_dir, f"{plot_prefix}_perplexity.png"),
                )
            ##### end jason #####

            if should_stop:
                break

            # only use first validation loss to update the learning rate
            lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

            epoch_itr = trainer.get_train_iterator(
                epoch_itr.next_epoch_idx,
                # sharded data: get train iterator for next epoch
                load_dataset=task.has_sharded_data("train"),
                # don't cache epoch iterators for sharded datasets
                disable_iterator_cache=task.has_sharded_data("train"),
            )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
Example #27
0
def main(cfg: FairseqConfig) -> None:
    """Configure and run training from a structured config.

    An ``argparse.Namespace`` is accepted as well and converted first. The
    function builds the task, model, criterion and trainer, restores the
    latest checkpoint, then trains one epoch at a time until the epoch limit
    is passed, the learning rate drops to ``stop_min_lr`` or below, or
    ``train`` asks to stop. With asynchronous checkpoint writing enabled it
    finally waits for pending writes.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        job_logging = OmegaConf.to_container(cfg.job_logging_cfg)
        logging.config.dictConfig(job_logging)

    batch_size_given = (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    )
    assert batch_size_given, "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    # Seed numpy and torch from the same configured value.
    seed = cfg.common.seed
    np.random.seed(seed)
    utils.set_torch_seed(seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Record the full configuration.
    logger.info(cfg)

    # Asynchronous checkpointing needs iopath; bail out early without it.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Task first (translation, language modeling, ...).
    task = tasks.setup_task(cfg.task)
    # Only validation splits are loaded here; training data is loaded later,
    # based on the latest checkpoint.
    for split_name in cfg.dataset.valid_subset.split(","):
        task.load_dataset(split_name, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Model and criterion, followed by a short summary in the log.
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    total_params = sum(p.numel() for p in model.parameters())
    trained_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info("num. model params: {:,} (num. trained: {:,})".format(
        total_params, trained_params))

    # Quantization is optional.
    quantizer = (
        None
        if cfg.common.quantization_config_path is None
        else quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    )

    # Pick the trainer implementation from the model-parallel size.
    if cfg.common.model_parallel_size != 1:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    else:
        trainer = Trainer(cfg, task, model, criterion, quantizer)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Resume from the latest checkpoint when one exists; this also yields the
    # matching train iterator.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    epoch_limit = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while True:
        # Epoch budget used up.
        if not (epoch_itr.next_epoch_idx <= epoch_limit):
            break

        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # One epoch of training.
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # The learning-rate schedule only looks at the first validation loss.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        next_epoch = epoch_itr.next_epoch_idx
        epoch_itr = trainer.get_train_iterator(
            next_epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
Example #28
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed command-line namespace.
        trainer: object providing ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task whose ``begin_epoch`` hook is called before training.
        epoch_itr: epoch iterator that supplies this epoch's batches.

    Returns:
        ``True`` when training should end altogether (early stopping was
        triggered or ``args.max_update`` was reached), ``False`` otherwise.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False

    # Default for the end-of-epoch log when no step completes (empty iterator
    # or every step skipped); otherwise the name would be unbound below.
    num_updates = trainer.get_num_updates()

    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)

            except ResetTrainerException:
                # Drop the trainer's cached wrapped criterion, model and
                # optimizer, then retry the same batch once.
                # NOTE(review): presumably these are rebuilt lazily on next
                # access - confirm against the Trainer implementation.
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None

                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
Example #29
0
def main(args):
    # Entry point: sets up the task, datasets, model, criterion and trainer
    # on a CUDA device, then runs a validation-metric pass.
    # NOTE(review): this listing ends inside the final loop, so several
    # locals prepared below (e.g. sum_eval_pool, sum_pool_params,
    # make_params, max_epoch, max_update, lr) are not used in the visible
    # part - presumably the missing continuation uses them; confirm against
    # the full source.

    # we should not do this!
    '''
    if args.max_tokens is None:
        args.max_tokens = 6000
    '''
    utils.xpprint(args)

    # This script requires CUDA; there is no CPU fallback.
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    utils.xprintln('setup task done!')

    # Load dataset splits: 'train' first, then every validation subset
    # (the latter with shuffle=False)
    load_dataset_splits(args, task, ['train'])
    valid_dataset = args.valid_subset.split(',')
    load_dataset_splits(args, task, valid_dataset, shuffle=False)
    utils.xprintln('load dataset done!')

    # For extractive summarization tasks, the master process prepares an
    # evaluation pool and the default parameter sets for the valid/test splits.
    if args.task.startswith('extractive_summarization'):
        if distributed_utils.is_master(args):
            # Imported lazily: only needed on this path.
            from sum_eval import MultiProcSumEval
            sum_eval_pool = MultiProcSumEval(args.ncpu_eval)
            sum_valid_pool_params = dict(
                article_file=args.raw_valid + '.article',
                summary_file=args.raw_valid + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )

            # Same settings as the validation set, pointed at the test files.
            sum_test_pool_params = dict(
                article_file=args.raw_test + '.article',
                summary_file=args.raw_test + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_pool_params = dict(valid=sum_valid_pool_params,
                                   test=sum_test_pool_params)

            # Copy a default parameter dict and fill in the per-run output
            # files and flags; the default dict itself is left untouched.
            def make_params(default_dict,
                            result_file,
                            out_rouge_file,
                            rerank=False,
                            with_m=False):
                para_dict = dict(default_dict)
                para_dict['result_file'] = result_file
                para_dict['out_rouge_file'] = out_rouge_file
                para_dict['rerank'] = rerank
                para_dict['with_m'] = with_m
                return para_dict

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # print(model)
    import sys
    sys.stdout.flush()

    # if summarization try to load pretrained model
    # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling':
    #     # assume this is a single GPU program
    if args.init_from_pretrained_doc_model:
        task.load_pretrained_model(model, args.pretrained_doc_model_path)
    sys.stdout.flush()

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    max_positions = trainer.get_model().max_positions()
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False)

    # Load the latest checkpoint if one is available
    # load_checkpoint(args, trainer, epoch_itr)
    # make sure training from a different checkpoint will use different random seed
    cur_dataset = task.dataset('train')
    if hasattr(cur_dataset, 'rng'):
        print('epoch ', epoch_itr.epoch)
        cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')
    # range(10, 9, -1) yields only the value 10, so this body runs exactly once.
    for alpha in range(10, 9, -1):
        # train for one epoch
        # train(args, trainer, task, epoch_itr)

        # The returned iterator is discarded; the call is made without
        # running a training pass (the train() call above is commented out).
        epoch_itr.next_epoch_itr()

        # Only the master process computes the summarization validation metric.
        if epoch_itr.epoch % args.validate_interval == 0:
            if args.task.startswith('extractive_summarization'):
                if distributed_utils.is_master(args):
                    validate_metric(args, trainer, task, epoch_itr,
                                    valid_subsets)
Example #30
0
def setup_training_state(args, trainer, task, epoch_itr):
    """Create the checkpoint directory and restore any saved training state.

    Either an existing checkpoint is loaded into ``trainer`` or, when no
    checkpoint file is found and ``args.multi_model_restore_files`` is set,
    individual models are imported. The epoch iterator is positioned from the
    resulting extra state.

    Returns:
        A tuple ``(extra_state, epoch_itr, checkpoint_manager)``;
        ``checkpoint_manager`` is ``None`` on non-master processes.
    """
    os.makedirs(args.save_dir, exist_ok=True)

    # A --restore-file that already exists under --save-dir wins over
    # --pretrained-checkpoint-file. The pretrained file lets a run start from
    # another run's checkpoint (possibly with different training params)
    # without polluting that run's directory; but once this run has been
    # interrupted and restarted, it should resume from its own checkpoints
    # under --save-dir rather than start over from the old checkpoint.
    #
    # If args.restore_file is an absolute path, os.path.join() drops the
    # directory part and uses the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    else:
        pretrained_file = args.pretrained_checkpoint_file
        if pretrained_file and os.path.isfile(args.pretrained_checkpoint_file):
            checkpoint_path = args.pretrained_checkpoint_file
            restore_state = args.load_pretrained_checkpoint_state
            print(
                f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
                f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
            )

    extra_state = default_extra_state(args)
    if os.path.isfile(checkpoint_path) or not args.multi_model_restore_files:
        _, restored_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if restored_extra_state:
            extra_state.update(restored_extra_state)
    else:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)

    # The clock restarts for the current training run.
    extra_state["start_time"] = time.time()

    # Print an abbreviated progress history to prevent log spam, then put the
    # full history back.
    full_progress = extra_state["training_progress"]
    if len(full_progress) > 0:
        shown_progress = ["...truncated...", full_progress[-1]]
    else:
        shown_progress = []
    extra_state["training_progress"] = shown_progress
    print(f"| extra_state: {extra_state}")
    extra_state["training_progress"] = full_progress

    resume_epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        # epoch_itr.next_epoch_itr() will increment this again.
        resume_epoch -= 1
    iterator_state = {
        "epoch": resume_epoch,
        "iterations_in_epoch": extra_state["batch_offset"],
    }
    epoch_itr.load_state_dict(iterator_state)

    # Only the master process gets a checkpoint manager.
    if not distributed_utils.is_master(args):
        checkpoint_manager = None
    else:
        checkpoint_manager = checkpoint.CheckpointManager(
            num_avg_checkpoints=args.num_avg_checkpoints,
            auto_clear_checkpoints=args.auto_clear_checkpoints,
            log_verbose=args.log_verbose,
            checkpoint_files=extra_state["checkpoint_files"],
        )

    return extra_state, epoch_itr, checkpoint_manager
Example #31
0
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write the checkpoint files due at this point of training and prune old ones.

    The best validation value seen so far is tracked on the function
    attribute ``save_checkpoint.best``; it is refreshed on every call that
    receives a ``val_loss``, before the ``no_save`` / non-master early exit.

    Args:
        args: parsed options; the flags read here are
            ``maximize_best_checkpoint_metric``, ``no_save``,
            ``no_epoch_checkpoints``, ``save_interval``,
            ``save_interval_updates``, ``no_last_checkpoints``, ``save_dir``,
            ``keep_interval_updates`` and ``keep_last_epochs``.
        trainer: provides ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: latest validation value, or ``None`` when none is available.

    Returns:
        None.
    """
    from fairseq import distributed_utils, meters

    # Track the running best first so it is updated even when nothing is saved.
    if val_loss is not None:
        previous_best = getattr(save_checkpoint, 'best', val_loss)
        if args.maximize_best_checkpoint_metric:
            save_checkpoint.best = max(val_loss, previous_best)
        else:
            save_checkpoint.best = min(val_loss, previous_best)

    if args.no_save:
        return
    if not distributed_utils.is_master(args):
        return

    def at_least_as_good(candidate, reference):
        # Ties count as "better": `best` has already absorbed `val_loss` above.
        if args.maximize_best_checkpoint_metric:
            return candidate >= reference
        return candidate <= reference

    stopwatch = meters.StopwatchMeter()
    stopwatch.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    save_epoch = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    save_update = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    save_best = val_loss is not None and (
        not hasattr(save_checkpoint, 'best')
        or at_least_as_good(val_loss, save_checkpoint.best)
    )
    save_last = not args.no_last_checkpoints

    # Insertion order matters: the first enabled entry is the file actually
    # written by the trainer, the remaining ones are copies of it.
    plan = collections.OrderedDict([
        (f'checkpoint{epoch}.pt', save_epoch),
        (f'checkpoint_{epoch}_{updates}.pt', save_update),
        ('checkpoint_best.pt', save_best),
        ('checkpoint_last.pt', save_last),
    ])

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    targets = [
        os.path.join(args.save_dir, name)
        for name, wanted in plan.items()
        if wanted
    ]
    if targets:
        primary, *copies = targets
        trainer.save_checkpoint(primary, extra_state)
        for duplicate in copies:
            shutil.copyfile(primary, duplicate)

        stopwatch.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            primary, epoch, updates, stopwatch.sum))

    def prune(pattern, keep):
        # checkpoints are sorted in descending order, so everything past the
        # first `keep` entries is old and gets removed
        for stale in checkpoint_paths(args.save_dir, pattern=pattern)[keep:]:
            if os.path.lexists(stale):
                os.remove(stale)

    if not end_of_epoch and args.keep_interval_updates > 0:
        prune(r'checkpoint_\d+_(\d+)\.pt', args.keep_interval_updates)

    if args.keep_last_epochs > 0:
        prune(r'checkpoint(\d+)\.pt', args.keep_last_epochs)
Example #32
0
def main(args):
    """Build task, model and trainer from ``args`` and run the training loop.

    The function, in order:

    1. imports the user module, resets metrics and seeds NumPy / torch with
       ``args.seed``;
    2. verifies the checkpoint directory (master process only);
    3. sets up the task, loads every split named in ``args.valid_subset`` and
       builds the model and criterion;
    4. optionally initializes model parameters and embeddings from the
       pretrained model held in ``task.bart``; which rules apply depends on
       whether ``args.arch`` contains ``'bartsv'``, ``'bart'`` or
       ``'roberta'``;
    5. optionally builds a quantizer, then a ``Trainer`` (or a
       ``MegatronTrainer`` when ``args.model_parallel_size != 1``);
    6. restores the latest checkpoint and trains epoch by epoch until the
       learning rate drops to ``args.min_lr`` or below, ``args.max_epoch`` is
       passed, or ``train`` signals that training should stop.

    Args:
        args: parsed command-line namespace.

    Raises:
        AssertionError: if neither ``args.max_tokens`` nor ``args.batch_size``
            is set.
        ValueError: if ``args.arch`` contains none of ``'bartsv'``, ``'bart'``
            or ``'roberta'``.
    """
    import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion,
                                            criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # ========== initialize the model with pretrained BART parameters ==========
    # for shared embeddings and subtoken split for amr nodes
    if 'bartsv' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            # the embeddings are initialized separately below, since their
            # sizes differ from the pretrained ones
            logger.info(
                '-' * 10 +
                ' delay encoder embeddings, decoder input and output embeddings initialization '
                + '-' * 10)
            ignore_keys = {
                'encoder.embed_tokens.weight',
                'decoder.embed_tokens.weight',
                'decoder.output_projection.weight',
            }
            for k in ignore_keys:
                del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                # iterate over a snapshot of the keys: the dict is mutated
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

            # initialize the BART part of the embeddings
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            # NOTE we need to prune the pretrained BART embeddings, especially for bart.base
            bart_embed_weight = (
                task.bart.model.encoder.embed_tokens.weight.data[:bart_vocab_size]
            )
            assert len(bart_embed_weight) == bart_vocab_size

            with torch.no_grad():
                model.encoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.output_projection.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)

        if args.bart_emb_init_composition:
            logger.info(
                '-' * 10 +
                ' initialize extended target embeddings with compositional embeddings '
                'from BART vocabulary ' + '-' * 10)

            # Bind the vocabulary size here as well: this block must work even
            # when --initialize-with-bart is off, in which case the assignment
            # in the block above never ran.
            bart_vocab_size = task.target_dictionary.bart_vocab_size

            # symbols appended to the dictionary after the BART vocabulary
            symbols = [
                task.target_dictionary[idx]
                for idx in range(bart_vocab_size, len(task.target_dictionary))
            ]
            mapper = MapAvgEmbeddingBART(task.bart,
                                         task.bart.model.decoder.embed_tokens)
            comp_embed_weight, _ = mapper.map_avg_embeddings(
                symbols, transform=transform_action_symbol, add_noise=False)
            assert len(comp_embed_weight) == len(symbols)

            with torch.no_grad():
                model.encoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.output_projection.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)

    elif 'bart' in args.arch:

        if args.initialize_with_bart:
            logger.info(
                '-' * 10 +
                ' initializing model parameters with pretrained BART model ' +
                '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            if not args.bart_emb_decoder:
                logger.info('-' * 10 +
                            ' build a separate decoder dictionary embedding ' +
                            '-' * 10)
                if not args.bart_emb_decoder_input:
                    ignore_keys = {
                        'decoder.embed_tokens.weight',
                        'decoder.output_projection.weight',
                    }
                else:
                    logger.info(
                        '-' * 10 +
                        ' use BART dictionary embedding for target input ' +
                        '-' * 10)
                    ignore_keys = {'decoder.output_projection.weight'}
                for k in ignore_keys:
                    del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART encoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 +
                    ' do not initialize with BART decoder parameters ' +
                    '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of BART vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from BART vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart, task.bart.model.decoder.embed_tokens,
                task.target_dictionary)
            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    elif 'roberta' in args.arch:
        # initialize the target embeddings with average of subtoken embeddings in RoBERTa vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of RoBERTa vocabulary here'
            logger.info(
                '-' * 10 +
                ' initialize target embeddings with compositional embeddings from RoBERTa vocabulary '
                + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,  # NOTE here "bart" means roberta
                task.bart.model.encoder.sentence_encoder.embed_tokens,
                task.target_dictionary)

            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    else:
        raise ValueError(
            "unsupported --arch {!r}: expected a name containing 'bartsv', "
            "'bart' or 'roberta'".format(args.arch)
        )
    # ==========================================================================

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))