Example #1
0
def validate(args, trainer, task, epoch_itr, subsets):
    """Compute validation statistics for each requested subset.

    Every subset is iterated once, without shuffling, under a brand-new
    root metrics aggregator (so training meters are not touched). The
    smoothed stats are printed and the configured checkpoint metric is
    collected.

    Args:
        args: nested configuration dict.
        trainer: supplies the model, ``valid_step`` and the update counter.
        task: supplies the datasets and their batch iterators.
        epoch_itr: only its ``epoch`` attribute is used (progress labelling).
        subsets: names of the validation subsets to evaluate.

    Returns:
        list: per subset (same order as ``subsets``), the value of the stat
        named by ``args['checkpoint']['best_checkpoint_metric']``.
    """
    dataset_cfg = args['dataset']
    fixed_seed = dataset_cfg['fixed_validation_seed']
    if fixed_seed is not None:
        # re-seed torch so that every validation pass is reproducible
        set_seed.set_torch_seed(fixed_seed)

    subset_scores = []
    for subset_name in subsets:
        # unshuffled pass, sharded by distributed rank / world size
        batch_iterator = task.get_batch_iterator(
            dataset=task.dataset(subset_name),
            max_tokens=dataset_cfg['max_tokens_valid'],
            max_sentences=dataset_cfg['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=dataset_cfg[
                'skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=dataset_cfg[
                'required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=dataset_cfg['num_workers'],
        )
        subset_itr = batch_iterator.next_epoch_itr(shuffle=False)

        common_cfg = args['common']
        progress = progress_bar.progress_bar(
            subset_itr,
            log_format=common_cfg['log_format'],
            log_interval=common_cfg['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset_name}' subset",
            tensorboard_logdir=(common_cfg['tensorboard_logdir']
                                if distributed_utils.is_master(args) else None),
            default_log_format=('simple' if common_cfg['no_progress_bar']
                                else 'tqdm'),
        )

        # isolated aggregator: validation meters must not leak into training ones
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in progress:
                trainer.valid_step(batch)

        subset_stats = get_valid_stats(args, trainer,
                                       aggregator.get_smoothed_values())
        progress.print(subset_stats,
                       tag=subset_name,
                       step=trainer.get_num_updates())

        metric_key = args['checkpoint']['best_checkpoint_metric']
        subset_scores.append(subset_stats[metric_key])

    return subset_scores
Example #2
0
 def save_checkpoint(self, filename, extra_state):
     """Persist the full training state to ``filename``.

     Non-master processes return immediately without writing anything.
     On the master, the current metrics state is stored into
     ``extra_state["metrics"]`` (mutating the caller's dict) before the
     model, criterion, optimizer, LR scheduler, update count and optimizer
     history are handed to ``checkpoint_utils.save_state``.
     """
     if not distributed_utils.is_master(self.args):
         return  # exactly one process writes the checkpoint

     extra_state["metrics"] = metrics.state_dict()
     model_state = self.get_model().state_dict()
     checkpoint_utils.save_state(
         filename,
         self.args,
         model_state,
         self.get_criterion(),
         self.optimizer,
         self.lr_scheduler,
         self.get_num_updates(),
         self._optim_history,
         extra_state,
     )
Example #3
0
def train(args, trainer, task, epoch_itr):
    """Train the model for a single epoch.

    Batches from ``epoch_itr`` are grouped ``update_freq`` at a time and fed
    to ``trainer.train_step``. ``update_freq`` is the entry of
    ``args['optimization']['update_freq']`` for the current epoch, or its
    last entry once the list is exhausted.

    - Every ``log_interval`` updates, the smoothed ``train_inner`` meters
      are logged and then reset.
    - When validation is enabled and ``save_interval_updates > 0``, the
      model is validated and checkpointed every ``save_interval_updates``
      updates.
    - The loop stops early once ``max_update`` is reached.
    - Finally, the epoch-level ``train`` meters are printed and reset.
    """
    # Initialize data iterator
    epoch_batches = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training'][
            'fix_batches_to_gpus'],
        # shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
        shuffle=False,
    )

    freq_schedule = args['optimization']['update_freq']
    if epoch_itr.epoch <= len(freq_schedule):
        update_freq = freq_schedule[epoch_itr.epoch - 1]
    else:
        update_freq = freq_schedule[-1]
    grouped_batches = iterators.GroupedIterator(epoch_batches, update_freq)

    common_cfg = args['common']
    progress = progress_bar.progress_bar(
        grouped_batches,
        log_format=common_cfg['log_format'],
        log_interval=common_cfg['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(common_cfg['tensorboard_logdir']
                            if distributed_utils.is_master(args) else None),
        default_log_format=('simple' if common_cfg['no_progress_bar']
                            else 'tqdm'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf
    num_updates = 0  # stays 0 if no step succeeds (e.g. zero-shot learning)
    for grouped_samples in progress:
        with metrics.aggregate('train_inner'):
            step_log = trainer.train_step(grouped_samples)
        if step_log is None:  # OOM, overflow, ...
            continue

        num_updates = trainer.get_num_updates()

        # mid-epoch logging: emit the 'train_inner' meters, then reset them
        if num_updates % common_cfg['log_interval'] == 0:
            inner_stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(inner_stats, tag='train_inner', step=num_updates)
            metrics.reset_meters('train_inner')

        mid_epoch_save_due = (
            not args['dataset']['disable_validation']
            and args['checkpoint']['save_interval_updates'] > 0
            and num_updates % args['checkpoint']['save_interval_updates'] == 0
            and num_updates > 0)
        if mid_epoch_save_due:
            losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             losses[0])

        if num_updates >= max_update:
            break

    # end-of-epoch summary, then reset the epoch-level meters
    epoch_stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(epoch_stats, tag='train', step=num_updates)
    metrics.reset_meters('train')
Example #4
0
def single_main(args, init_distributed=False):
    """Run a complete training job in this process.

    Steps, in order:

    1. Select the CUDA device and seed the RNGs.
    2. Optionally initialise distributed training, storing the returned
       rank in ``args['distributed_training']['distributed_rank']``.
    3. On the master, verify ``save_dir`` and delete every ``*.pt`` file
       in it.
    4. Set up the task and load the valid split.
    5. Build the model, criterion and trainer.
    6. Restore from the latest checkpoint.
    7. Train epoch by epoch while ``lr > min_lr`` and the
       ``max_epoch``/``max_update`` limits are not reached, validating and
       checkpointing at the configured intervals.

    Training also stops when ``should_stop_early`` reports no improvement.

    Args:
        args: nested configuration dict.
        init_distributed: call ``distributed_utils.distributed_init`` first.
    """
    import glob

    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    set_seed.set_seed(args['common']['seed'])
    if init_distributed:
        args['distributed_training'][
            'distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        # NOTE: this removes every *.pt file in save_dir, including
        # pre-trained models (and presumably any checkpoint step 5 could
        # resume from). The wildcard is expanded here explicitly because an
        # os.remove-style delete does not expand a literal '*.pt' path;
        # save_dir itself is escaped so it is matched literally.
        for stale_ckpt in glob.glob(
                os.path.join(glob.escape(save_dir), '*.pt')):
            PathManager.rm(stale_ckpt)

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args['dataset']['max_tokens'],
            args['dataset']['max_sentences'],
        ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args,
                                                              trainer,
                                                              combine=False)

    # 6. Train until the learning rate gets too small
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] and epoch_itr.epoch % args[
                'dataset']['validate_interval'] == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            LOGGER.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args['checkpoint']['patience']))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )

    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))
Example #5
0
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save the training state and prune old checkpoints according to ``args``.

    The state is written once (``trainer.save_checkpoint``) to the first
    qualifying filename and copied to every other qualifying name:

    * ``checkpoint{epoch}.pt``: at the end of an epoch, every
      ``save_interval`` epochs, unless ``no_epoch_checkpoints``.
    * ``checkpoint_{epoch}_{updates}.pt``: mid-epoch, every
      ``save_interval_updates`` updates.
    * ``checkpoint_best.pt``: when ``val_loss`` is at least as good as the
      best value seen so far.
    * ``checkpoint.best_{metric}_{val_loss:.2f}.pt``: when
      ``keep_best_checkpoints > 0`` and ``val_loss`` ranks among the
      ``keep_best_checkpoints`` best scores currently kept on disk.
    * ``checkpoint_last.pt``: unless ``no_last_checkpoints``.

    The running best score is stored on the function object
    (``save_checkpoint.best``) and is updated even when nothing is written.
    Only the master process writes, and ``no_save`` disables writing.
    Afterwards, surplus update/epoch/best checkpoints are deleted according
    to ``keep_interval_updates``, ``keep_last_epochs`` and
    ``keep_best_checkpoints``.

    Args:
        args: nested config dict; the ``args['checkpoint']`` options are read.
        trainer: provides ``get_num_updates()`` and ``save_checkpoint()``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: validation value of ``best_checkpoint_metric``, or None.
    """
    import re

    from ncc import meters
    from ncc.utils import distributed_utils
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args['checkpoint'][
            'maximize_best_checkpoint_metric'] else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args['checkpoint']['no_save'] or not distributed_utils.is_master(args):
        return

    maximize = args['checkpoint']['maximize_best_checkpoint_metric']
    save_dir = args['checkpoint']['save_dir']
    keep_best = args['checkpoint']['keep_best_checkpoints']

    def is_better(a, b):
        return a >= b if maximize else a <= b

    # "keep best N" filenames embed their score rounded to 2 decimals. The
    # metric name is escaped so it is matched literally, and an optional
    # minus sign lets negative scores be tracked and pruned as well.
    best_pattern = r"checkpoint\.best_{}_(-?\d+\.?\d*)\.pt".format(
        re.escape(args['checkpoint']['best_checkpoint_metric']))

    def deserves_best_slot(score):
        """Return True if ``score`` belongs among the ``keep_best`` kept scores."""
        kept = checkpoint_paths(save_dir, pattern=best_pattern)
        if len(kept) < keep_best:
            return True
        # checkpoint_paths sorts by score in descending order, so the worst
        # kept score is the last entry when maximizing and the first otherwise.
        worst_path = kept[-1] if maximize else kept[0]
        found = re.match(best_pattern, os.path.basename(worst_path))
        return found is None or is_better(score, float(found.group(1)))

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args['checkpoint']['no_epoch_checkpoints']
        and epoch % args['checkpoint']['save_interval'] == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch and args['checkpoint']['save_interval_updates'] > 0
        and updates % args['checkpoint']['save_interval_updates'] == 0)
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best))
    if val_loss is not None and keep_best > 0:
        # Save whenever val_loss would rank among the best `keep_best`
        # scores on disk (not only when it is a new overall best); the
        # surplus file is pruned at the end of this function.
        checkpoint_conds["checkpoint.best_{}_{:.2f}.pt".format(
            args['checkpoint']['best_checkpoint_metric'],
            val_loss)] = deserves_best_slot(val_loss)
    checkpoint_conds[
        "checkpoint_last.pt"] = not args['checkpoint']['no_last_checkpoints']

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write once, then copy to every other qualifying name
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp)

        write_timer.stop()
        LOGGER.info(
            "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {:.6f} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss, write_timer.sum))

    if not end_of_epoch and args['checkpoint']['keep_interval_updates'] > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(save_dir,
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[
                args['checkpoint']['keep_interval_updates']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args['checkpoint']['keep_last_epochs'] > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(save_dir,
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args['checkpoint']['keep_last_epochs']:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if keep_best > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(save_dir, pattern=best_pattern)
        if not maximize:
            # descending order puts the worst first when minimizing
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[keep_best:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
Example #6
0
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate code/description retrieval metrics on each validation subset.

    For every batch:

    1. The code and description encodings from
       ``trainer.model.code_forward`` / ``desc_forward`` are L2-normalized.
    2. They are multiplied into a similarity matrix.
    3. ``inference`` scores that matrix as accuracy, MRR, MAP and NDCG.

    The per-batch means are averaged over the subset (every batch weighted
    equally) and rounded to 6 decimals.

    Args:
        args: nested configuration dict.
        trainer: supplies the model, criterion and sample preparation.
        task: supplies the datasets and their batch iterators.
        epoch_itr: only its ``epoch`` is used (progress labelling).
        subsets: names of the validation subsets to evaluate.

    Returns:
        list: per subset, the stat named by
        ``args['checkpoint']['best_checkpoint_metric']``, which must be one
        of ``'acc'``, ``'mrr'``, ``'map'`` or ``'ndcg'``.
    """

    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    # order matches the tuple returned by `inference`
    metric_names = ('acc', 'mrr', 'map', 'ndcg')
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']
            ['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']
            ['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'simple'),
        )

        # per-batch means of every metric
        batch_means = {name: [] for name in metric_names}
        trainer.model.eval()
        trainer.criterion.eval()
        with torch.no_grad():
            for sample in progress:
                sample = trainer._prepare_sample(sample)
                # NOTE(review): assumes the first 6 net_input values feed the
                # code encoder and the next 2 the description encoder; this
                # relies on the dict's insertion order -- confirm.
                inputs = list(sample['net_input'].values())
                code_repr = trainer.model.code_forward(*inputs[:6])
                desc_repr = trainer.model.desc_forward(*inputs[6:8])
                # L2-normalize so the dot products are cosine similarities
                code_repr = code_repr / code_repr.norm(dim=-1, keepdim=True)
                desc_repr = desc_repr / desc_repr.norm(dim=-1, keepdim=True)
                similarity = code_repr @ desc_repr.t()
                acc, mrr, map_, ndcg = inference(similarity)
                for name, score in zip(metric_names, (acc, mrr, map_, ndcg)):
                    batch_means[name].append(score.mean().item())
        stats = {
            name: round(float(np.mean(values)), 6)
            for name, values in batch_means.items()
        }
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(
            stats[args['checkpoint']['best_checkpoint_metric']])
    return valid_losses
Example #7
0
def save_expert_outputs(args, task, trainer):
    """Dump this rank's expert (teacher) outputs; on the master, merge them.

    Every rank writes the result of ``gen_outputs(args, task, trainer)`` as
    JSON to ``<save_dir>/train_output.json.<rank>``. The master process then:

    1. copies ``<save_dir>/val_bleu.json`` to
       ``<data>/expert_bleu_<langs>_<src>_<tgt>.json`` (best effort);
    2. reads back and deletes every rank's JSON dump;
    3. fills each position with the first non-None entry across ranks;
    4. saves element 0 of every merged entry to ``..._topk_idx`` (int32)
       and element 1 to ``..._topk_prob`` (float) via
       ``TeacherOutputDataset.save_bin``.

    NOTE(review): the barrier below is commented out, so the master may read
    another rank's dump before that rank has written it -- confirm that the
    callers keep ranks synchronized.
    """
    import shutil

    print("| Start saving expert outputs..")
    expert_outputs = gen_outputs(args, task, trainer)
    output_path = os.path.join(
        args['checkpoint']['save_dir'], 'train_output.json.{}'.format(
            args['distributed_training']['distributed_rank']))
    print('Save topk output at {}'.format(output_path))
    # close (and thereby flush) the dump before anyone reads it back
    with open(output_path, 'w') as out_f:
        json.dump(expert_outputs, out_f)
    # distributed_utils.barrier(args, 'save_expert_outputs')
    if not distributed_utils.is_master(args):
        return

    save_dir = args['checkpoint']['save_dir']
    data_dir = args['task']['data']
    world_size = args['distributed_training']['distributed_world_size']
    # "<langs>_<src>_<tgt>" tag shared by every output name below
    lang_tag = '{}_{}_{}'.format('_'.join(args['task']['programming_langs']),
                                 args['task']['source_lang'],
                                 args['task']['target_lang'])

    # copy valid bleu result (best effort, like the former `cp` shell call)
    val_bleu_path1 = os.path.join(save_dir, 'val_bleu.json')
    val_bleu_path2 = os.path.join(data_dir,
                                  'expert_bleu_{}.json'.format(lang_tag))
    print('cp {} {}'.format(val_bleu_path1, val_bleu_path2))
    try:
        shutil.copy(val_bleu_path1, val_bleu_path2)
    except OSError as err:
        LOGGER.warning('could not copy {} to {}: {}'.format(
            val_bleu_path1, val_bleu_path2, err))

    expert_outputs_ = []
    for rank in range(world_size):
        rank_path = os.path.join(save_dir,
                                 'train_output.json.{}'.format(rank))
        with open(rank_path, 'r') as in_f:
            expert_outputs_.append(json.load(in_f))
        try:
            os.remove(rank_path)
        except OSError:
            pass  # best-effort cleanup of the per-rank dump

    # position j takes the first non-None entry found across ranks
    for j in range(len(expert_outputs_[0])):
        for rank_outputs in expert_outputs_:
            if rank_outputs[j] is not None:
                expert_outputs[j] = rank_outputs[j]
                break
        assert expert_outputs[j] is not None

    idx_path = os.path.join(data_dir, '{}_topk_idx'.format(lang_tag))
    TeacherOutputDataset.save_bin(idx_path, [o[0] for o in expert_outputs],
                                  np.int32)

    path = os.path.join(data_dir, '{}_topk_prob'.format(lang_tag))
    # `np.float` (removed in NumPy 1.24) was an alias of the builtin `float`;
    # passing `float` keeps exactly the same dtype.
    TeacherOutputDataset.save_bin(path, [o[1] for o in expert_outputs],
                                  float)

    LOGGER.info("| Save expert@{}. Bleu.Json: {}, TopK.Idx/Prob: {}.".format(
        lang_tag, val_bleu_path2, path))
Example #8
0
 def is_data_parallel_master(self):
     """Report whether ``distributed_utils.is_master`` holds for ``self.args``."""
     master_flag = distributed_utils.is_master(self.args)
     return master_flag
Example #9
0
def validate(args, trainer, task, epoch_itr, valid_subsets, dev_subsets,
             dev_refs):
    """Evaluate the model on the validation set(s) and return the accuracy.

    Each subset in ``valid_subsets`` is iterated once, without shuffling,
    under a fresh root metrics aggregator. Accuracy is ``match / total``
    from the aggregated stats and is printed together with BLEU and loss.

    NOTE(review): dev-set BLEU evaluation (the commented-out block below) is
    disabled, so ``dev_subsets`` and ``dev_refs`` are currently unused.

    Returns:
        tuple: ``(valid_acc, None)``. ``valid_acc`` is the accuracy of the
        LAST subset in ``valid_subsets``, or None when it is empty. The
        second slot is reserved for the disabled dev BLEU score.
    """

    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args['dataset']['fixed_validation_seed'])

    # stays None when there is no subset to validate (avoids UnboundLocalError)
    valid_acc = None
    for subset in valid_subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']
            ['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']
            ['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        # calculate accuracy
        match = stats.pop('match')
        total = stats.pop('total')
        valid_acc = match / total
        progress.print(
            {
                'accuracy': f'{round(100. * valid_acc, 2)}%',
                'bleu': stats['bleu'],
                'loss': stats['loss'],
            },
            tag=subset,
            step=trainer.get_num_updates())

    # for subset in dev_subsets:
    #     hypotheses, references = {}, dev_refs
    #
    #     # Initialize data iterator
    #     itr = task.get_batch_iterator(
    #         dataset=task.dataset(subset),
    #         max_tokens=args['dataset']['max_tokens_valid'],
    #         max_sentences=args['dataset']['max_sentences_valid'],
    #         max_positions=utils.resolve_max_positions(
    #             task.max_positions(),
    #             trainer.get_model().max_positions(),
    #         ),
    #         ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
    #         required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
    #         seed=args['common']['seed'],
    #         num_shards=args['distributed_training']['distributed_world_size'],
    #         shard_id=args['distributed_training']['distributed_rank'],
    #         num_workers=args['dataset']['num_workers'],
    #     ).next_epoch_itr(shuffle=False)
    #     progress = progress_bar.progress_bar(
    #         itr,
    #         log_format=args['common']['log_format'],
    #         log_interval=args['common']['log_interval'],
    #         epoch=epoch_itr.epoch,
    #         prefix=f"valid on '{subset}' subset",
    #         tensorboard_logdir=(
    #             args['common']['tensorboard_logdir'] if distributed_utils.is_master(args) else None
    #         ),
    #         default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    #     )
    #
    #     # create a new root metrics aggregator so validation metrics
    #     # don't pollute other aggregators (e.g., train meters)
    #     with metrics.aggregate(new_root=True) as agg:
    #         for sample in progress:
    #             with torch.no_grad():
    #                 trainer.model.eval()
    #                 trainer.criterion.eval()
    #                 sample = trainer._prepare_sample(sample)
    #                 hyps, _, _, ids = trainer.task.step_out(sample, trainer.model)
    #                 for idx, hypo in zip(ids, hyps):
    #                     hypotheses[idx] = hypo
    #
    #     from third_party.pycocoevalcap.bleu.google_bleu import compute_bleu
    #     assert set(hypotheses.keys()) == set(references.keys())
    #     bleus = [
    #         compute_bleu([references[idx]], [hypotheses[idx]], smooth=Trainer)[0]
    #         for idx in hypotheses.keys()
    #     ]
    #     dev_bleu = round(100. * sum(bleus) / len(bleus), 2)
    #     # log validation stats
    #     stats = agg.get_smoothed_values()
    #     stats['bleu'] = dev_bleu
    #     stats = get_dev_stats(args, trainer, stats)
    #     progress.print(stats, tag=subset, step=trainer.get_num_updates())
    # return valid_acc, dev_bleu
    return valid_acc, None
Example #10
0
def single_main(args, init_distributed=False):
    """Run a complete training job, using ``validate``'s accuracy for checkpoints.

    Steps, in order:

    1. Seed the Python, NumPy and torch RNGs.
    2. Optionally initialise distributed training.
    3. On the master, call ``remove_files(save_dir, 'pt')``.
    4. Load the valid split.
    5. Build the model, criterion and trainer.
    6. Restore the latest checkpoint.
    7. Train epoch by epoch while ``lr > min_lr`` and the
       ``max_epoch``/``max_update`` limits are not reached.

    Every ``save_interval`` epochs a checkpoint is saved with ``valid_acc``
    (the first value returned by ``validate``) as its score.

    Args:
        args: nested configuration dict.
        init_distributed: call ``distributed_utils.distributed_init`` first
            and store the returned rank in
            ``args['distributed_training']['distributed_rank']``.
    """
    assert args['dataset']['max_tokens'] is not None or args['dataset']['max_sentences'] is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # 0. Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args['common']['cpu']:
        torch.cuda.set_device(args['distributed_training']['device_id'])
    # seed Python, NumPy and torch (CPU and current CUDA device) RNGs
    random.seed(args['common']['seed'])
    np.random.seed(args['common']['seed'])
    torch.manual_seed(args['common']['seed'])
    torch.cuda.manual_seed(args['common']['seed'])
    if init_distributed:
        args['distributed_training'][
            'distributed_rank'] = distributed_utils.distributed_init(args)

    # Verify checkpoint directory
    if distributed_utils.is_master(args):
        save_dir = args['checkpoint']['save_dir']
        checkpoint_utils.verify_checkpoint_directory(save_dir)
        # presumably deletes every file with extension 'pt' in save_dir -- confirm
        remove_files(save_dir,
                     'pt')  # this code will remove pre-trained models

    # 1. Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # 2. Load valid dataset (we load training data below, based on the latest checkpoint)
    # calculate accuracy for decay learning rate
    task.load_dataset(args['dataset']['valid_subset'], combine=False, epoch=1)
    # # compute meteor to select model
    # task.load_dataset(args['dataset']['dev_subset'], combine=False, epoch=1)
    # # load dev/ref.txt
    # dev_refs = load_refs(os.path.join(args['task']['data'], args['dataset']['dev_ref_subset']))

    # 3. Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    LOGGER.info(model)
    LOGGER.info('model {}, criterion {}'.format(args['model']['arch'],
                                                criterion.__class__.__name__))
    LOGGER.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # 4. Build trainer
    trainer = Trainer(args, task, model, criterion)
    LOGGER.info('training on {} GPUs'.format(
        args['distributed_training']['distributed_world_size']))
    LOGGER.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args['dataset']['max_tokens'],
            args['dataset']['max_sentences'],
        ))

    # 5. Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args,
                                                              trainer,
                                                              combine=False)

    # 6. Train until the learning rate gets too small
    max_epoch = args['optimization']['max_epoch'] or math.inf
    max_update = args['optimization']['max_update'] or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args['dataset']['valid_subset'].split(',')
    # NOTE(review): dev_subsets is passed to validate, but dev_refs is always
    # None here (dev-ref loading above is commented out) -- confirm dev
    # evaluation is intentionally disabled.
    dev_subsets = args['dataset']['dev_subset'].split(',')
    # NOTE(review): nothing in this function ever appends to this list (the
    # appends are commented out below), so the lr-shrink check further down
    # can never fire. If the list were filled while valid_acc is None
    # (validation skipped), that `>=` comparison would raise TypeError.
    valid_accs_after_60e = []
    while (lr > args['optimization']['min_lr']
           and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args['dataset']['disable_validation'] and epoch_itr.epoch % args[
                'dataset']['validate_interval'] == 0:
            valid_acc, dev_prf = validate(args,
                                          trainer,
                                          task,
                                          epoch_itr,
                                          valid_subsets,
                                          dev_subsets,
                                          dev_refs=None)
        else:
            valid_acc, dev_prf = None, None

        # if epoch_itr.next_epoch_idx > 61 and valid_acc < valid_accs_after_60e[-1]:
        #     """
        #     We start with a learning rate of 0.5 and start
        #     decaying it by a factor of 0.8 after 60 epochs if
        #     accuracy on the validation set goes down, and
        #     terminate training when the learning rate goes
        #     below 0.001.
        #     """
        #     lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])
        #
        # if epoch_itr.epoch >= 60:
        #     valid_accs_after_60e.append(valid_acc)

        # if len(valid_accs_after_60e) > 10 and valid_accs_after_60e[-5] >= valid_acc:
        #     lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])
        # valid_accs_after_60e.append(valid_acc)

        # shrink lr by `lr_shrink` when accuracy has not improved on the value
        # from 5 entries back; currently inert (see NOTE at the list definition)
        if len(valid_accs_after_60e
               ) > 10 and valid_accs_after_60e[-5] >= valid_acc:
            lr = trainer.set_lr(lr * trainer.args['optimization']['lr_shrink'])

        # eval on dev and dev.ref data

        # save checkpoint
        # NOTE(review): valid_acc (higher is better) is passed as val_loss;
        # "best" tracking presumably needs maximize_best_checkpoint_metric
        # set -- confirm.
        if epoch_itr.epoch % args['checkpoint']['save_interval'] == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_acc)

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            combine=False,  # TODO to be checked
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in args['task']['data']),
        )

    train_meter.stop()
    LOGGER.info('done training in {:.1f} seconds'.format(train_meter.sum))