Example #1
File: train.py Project: ahiroto/ParlAI
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
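
Note: the snippets in these examples lean on fairseq-style meter helpers (AverageMeter, TimeMeter, get_perplexity) that each project imports elsewhere. A rough sketch of the two used above follows; the real implementations live in fairseq/meters.py and fairseq's train.py and differ in detail.

import math

class AverageMeter(object):
    """Minimal sketch: running average weighted by an optional count n."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0

def get_perplexity(loss):
    # assumes `loss` is a base-2 cross-entropy, as fairseq reports it
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')
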
Example #2
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if callable(getattr(trainer.criterion, 'set_epoch', None)):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if callable(getattr(trainer.criterion, 'set_num_updates', None)):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #3
def validate(args, trainer, task, epoch_itr, subsets, sampled_arch_name):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens_valid,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=args.required_batch_size_multiple,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
                num_workers=args.num_workers,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args,
                itr,
                epoch_itr.epoch,
                prefix='validate on \'{}\' subset'.format(subset),
            )
            return progress

        progress = get_itr()

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = utils.get_valid_stats(trainer, args, extra_meters)

        stats[sampled_arch_name + '_loss'] = deepcopy(stats['loss'])
        stats[sampled_arch_name + '_nll_loss'] = deepcopy(stats['nll_loss'])

        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric])
    return valid_losses
Example #4
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple")

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if ("validate" not in extra_state
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {
            "lowest_loss": val_loss,
            "num_since_best": 0
        }
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] >
            args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} (current loss: {val_loss})"
            f"was {extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
Example #5
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        if args.sep_optim:
            self.meters['dec_gnorm'] = AverageMeter()  # gradient norm for decoder
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()  # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
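
For completeness, the time-based meters registered above can be sketched as follows (assumed behavior only; fairseq's TimeMeter and StopwatchMeter in fairseq/meters.py track a few extra fields).

import time

class TimeMeter(object):
    """Sketch: average rate of `n` events per second since the last reset."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / self.elapsed_time

class StopwatchMeter(object):
    """Sketch: accumulates wall time between start() and stop() calls."""
    def __init__(self):
        self.sum = 0
        self.start_time = None

    def start(self):
        self.start_time = time.time()

    def stop(self):
        if self.start_time is not None:
            self.sum += time.time() - self.start_time
            self.start_time = None
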
Example #6
def validate(args, trainer, task, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())
        cnt = 0
        if args.distributed_world_size > 1:
            fout_hidden = './tmp/hidden_{}.h5'.format(args.distributed_rank)
            fout_target = './tmp/target_{}.h5'.format(args.distributed_rank)
        else:
            fout_hidden = './tmp/hidden.h5'
            fout_target = './tmp/target.h5'
        fout_hidden, hidden_list = open_h5(fout_hidden, 1024)
        fout_target, target_list = open_h5(fout_target, 1)
        for sample in progress:
            record, log_output = trainer.valid_step(sample)
            hidden_list.append(record[0].cpu().numpy().astype('float16'))
            target_list.append(record[1].cpu().numpy().astype('float16'))
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)
            cnt += 1
            if (cnt > 10):
                break
        # log validation stats
        fout_hidden.close()
        fout_target.close()

        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

    return valid_losses
Example #7
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix="valid on '{}' subset".format(subset),
            no_progress_bar="simple",
        )

        # reset validation loss meters
        for k in ["valid_loss", "valid_nll_loss"]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        "loss", "nll_loss", "ntokens", "nsentences",
                        "sample_size"
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == "loss"
            else stats[args.best_checkpoint_metric])
    return valid_losses
Example #8
    update_freq = args_transformer.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args_transformer.update_freq) \
        else args_transformer.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args_transformer.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args_transformer.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args_transformer,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for count, ind in enumerate(inds):
        # Get the coordinates of the loss value being calculated
        coord = coords[count]
        dx = directions[0]
        dy = directions[1]
        changes = [d0 * coord[0] + d1 * coord[1] for (d0, d1) in zip(dx, dy)]
        new_states = copy.deepcopy(states)
        assert len(new_states) == len(changes)
        for (k, v), d in zip(new_states.items(), changes):
            d = torch.tensor(d)
            v.add_(d.type(v.type()))

        # load the perturbed weights into the model
        model.load_state_dict(new_states)
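
Example #8 is a fragment of a loss-landscape sweep: `directions` holds two directions in parameter space, `coords` is a grid of (x, y) step sizes, and each iteration shifts every parameter by coord[0]*dx + coord[1]*dy before re-evaluating the model. A hedged sketch of how such inputs are commonly built; the helper names below are illustrative, not taken from the original script.

import numpy as np
import torch

def make_random_direction(states):
    # one random tensor per parameter, rescaled to that parameter's norm
    # (filter-wise normalization is often used instead; this is the simplest variant)
    direction = []
    for v in states.values():
        d = torch.randn_like(v.float())
        d = d * (v.float().norm() / (d.norm() + 1e-10))
        direction.append(d.cpu().numpy())
    return direction

def make_grid(xmin=-1.0, xmax=1.0, ymin=-1.0, ymax=1.0, steps=21):
    # coords[i] is the (x, y) offset applied at the i-th evaluated point
    xs = np.linspace(xmin, xmax, steps)
    ys = np.linspace(ymin, ymax, steps)
    return [(x, y) for x in xs for y in ys]

# states = model.state_dict()
# directions = [make_random_direction(states), make_random_direction(states)]
# coords = make_grid(); inds = range(len(coords))
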
Example #9
def validate(args, trainer, task, epoch_itr, subsets, force_refine_step=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_random = np.random.RandomState(3)
    valid_task_random = np.random.RandomState(3)
    if not hasattr(task, "random"):
        task.random = None
    task_random_bak = task.random
    task.random = valid_task_random
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        dataset = task.dataset(subset)
        if hasattr(dataset, "random"):
            random_bak = dataset.random
        else:
            random_bak = None
        dataset.random = valid_random
        set_valid_tokens(task, dataset, trainer, args)
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        if hasattr(args, "progressive") and args.progressive:
            dataset.set_random_refine_step(args.refinetot,
                                           force_refine_step=force_refine_step)
        for sample in progress:
            if trainer._oom_batch is None:
                trainer._oom_batch = sample
            if sample is None or len(sample) == 0:
                sys.stderr.write("empty valid sample detected\n")
                sys.stderr.flush()
                break
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

            if hasattr(args, "progressive") and args.progressive:
                dataset.set_random_refine_step(
                    args.refinetot, force_refine_step=force_refine_step)
        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(
            stats[args.best_checkpoint_metric]
            if type(stats[args.best_checkpoint_metric]) == float
            else stats[args.best_checkpoint_metric].avg)
        dataset.random = random_bak

        if HAS_WANDB and distributed_utils.is_master(args):
            stat_dict = {}
            for k, v in stats.items():
                if isinstance(v, AverageMeter):
                    stat_dict[k] = v.val
                else:
                    stat_dict[k] = v
            wandb.log(stat_dict, step=trainer.get_num_updates())
    task.random = task_random_bak
    return valid_losses
Example #10
def train(args, trainer, task, epoch_itr, force_refine_step=None):
    """Train the model for one epoch."""

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(args, "progressive") and args.progressive:
        task.dataset("train").set_random_refine_step(
            args.refinetot, force_refine_step=force_refine_step)
    last_samples = None
    last_best_update = 0
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if samples is None or len(samples) == 0:
            sys.stderr.write("Empty sample detected\n")
            sys.stderr.flush()
            samples = last_samples
        else:
            last_samples = samples
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue
        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args,
                                    trainer,
                                    task,
                                    epoch_itr,
                                    valid_subsets,
                                    force_refine_step=force_refine_step)
            # if distributed_utils.is_master(args):
            #     print("saving:", trainer.get_num_updates())
            #     nsml.save(str(trainer.get_num_updates()))
            if not hasattr(checkpoint_utils.save_checkpoint,
                           'best') or is_better(
                               valid_losses[0],
                               checkpoint_utils.save_checkpoint.best):
                last_best_update = num_updates
                if distributed_utils.is_master(args):
                    print("saving checkpoint ...")
                    sys.stdout.flush()
                    if getattr(args, "save_path",
                               False) and len(args.save_path) > 0:
                        if not os.path.exists(args.save_path):
                            os.mkdir(args.save_path)
                        torch.save({"model": trainer.get_model().state_dict()},
                                   "{}/best.pt".format(args.save_path))
                    elif HAS_NSML:
                        nsml.save("best")
                    else:
                        torch.save({"model": trainer.get_model().state_dict()},
                                   "/tmp/best.pt")
                        if HAS_WANDB:
                            wandb.save("/tmp/best.pt")
                    sys.stdout.flush()
                checkpoint_utils.save_checkpoint.best = valid_losses[0]

        if args.decoder_wise_training and update_num_to_refine_step(
                num_updates) != force_refine_step:
            if HAS_NSML:
                nsml.load("best")
            else:
                # Retrieve the model
                if distributed_utils.is_master(args):
                    state = torch.load("/tmp/best.pt", map_location="cpu")
                    trainer.model.load_state_dict(state["model"])
                # Sync
                assert isinstance(trainer.model,
                                  parallel.DistributedDataParallel)
                if isinstance(trainer.model, parallel.DistributedDataParallel):
                    trainer.model._sync_params()

            checkpoint_utils.save_checkpoint.best = 0.
            force_refine_step = update_num_to_refine_step(num_updates)
            trainer.criterion.pool.clear()
            print("| Start refinement step:", force_refine_step)

        if num_updates >= max_update:
            break

        if args.early_stop and num_updates - last_best_update >= 3000:
            if distributed_utils.is_master(args):
                print("early stop")
            setattr(args, "early_stopping", True)
            break

        if hasattr(args, "progressive") and args.progressive:
            task.dataset("train").set_random_refine_step(
                args.refinetot, force_refine_step=force_refine_step)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #11
def train(args, trainer, task, epoch_itr, shuffling_seeds):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    if args.enable_parallel_backward_allred_opt and update_freq > 1:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1'
        )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    if args.time_step:
        begin = time.time()
        end = time.time()
    count = 0

    #profile_count = 13
    profile_count = 10000000000

    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if args.time_step:
            start_step = time.time()
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample,
                               update_params=False,
                               last_step=(i == len(itr) - 1))
            continue
        else:
            log_output = trainer.train_step(sample,
                                            update_params=True,
                                            last_step=(i == len(itr) - 1))
        if args.time_step:
            end_step = time.time()
            #if count > 10  and sample['target'].size(0) > 248 :
            seqs = sample['target'].size(0)
            srclen = sample['net_input']['src_tokens'].size(1)
            tgtlen = sample['target'].size(1)
            srcbatch = srclen * seqs
            tgtbatch = tgtlen * seqs
            #print("ITER {}> Seqs: {} SrcLen: {} TgtLen: {} Src Batch: {} Tgt Batch {}".format( count, seqs, srclen, tgtlen, srcbatch, tgtbatch))
            print("ITER {}> Seqs: {} SrcLen: {} TgtLen: {} Total Time: {:.3} Step Time: {:.3} Load Time: {:.3}".format( \
                count,                                                                                                  \
                sample['target'].size(0),                                                                               \
                sample['net_input']['src_tokens'].size(1),                                                              \
                sample['target'].size(1),                                                                               \
                (end_step-begin)*1000.0,                                                                                \
                (end_step-start_step)*1000.0,                                                                           \
                (start_step-end)*1000.0))
            count += 1
            begin = time.time()

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        if args.profile is not None and i == args.profile:
            import sys
            sys.exit()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid], shuffling_seeds)
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break
        if args.time_step:
            end = time.time()

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
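
Example #11 realizes --update-freq by buffering gradients: train_step(sample, update_params=False) only accumulates, and the optimizer step happens on the last call of each group. The same idea in plain PyTorch, as a sketch; the real fairseq Trainer additionally handles gradient clipping, loss scaling, and distributed sync.

def train_one_epoch(model, criterion, optimizer, data_loader, update_freq=4):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(data_loader):
        loss = criterion(model(inputs), targets)
        # divide so the accumulated gradient matches one large batch
        (loss / update_freq).backward()
        if (i + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()
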
Example #12
def train(args, epoch, batch_offset, trainer, dataset, max_positions):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)

            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)

            nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
Example #13
def estimate_head_importance(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    if args.n_pruning_steps > 0:
        itr = islice(itr, args.n_pruning_steps)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )
    # Initialize meters
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    # Initialize head importance scores
    encoder_layers = trainer.args.encoder_layers
    decoder_layers = trainer.args.decoder_layers
    encoder_heads = trainer.args.encoder_attention_heads
    decoder_heads = trainer.args.decoder_attention_heads
    device = next(trainer.model.parameters()).device

    head_importance = {
        "encoder_self": torch.zeros(encoder_layers, encoder_heads).to(device),
        "encoder_decoder": torch.zeros(decoder_layers,
                                       decoder_heads).to(device),
        "decoder_self": torch.zeros(decoder_layers, decoder_heads).to(device),
    }
    # Denominators to normalize properly
    denoms = {
        attn_type: val.clone()
        for attn_type, val in head_importance.items()
    }
    head_stats = {
        attn_type: [{} for _ in range(val.size(0))]
        for attn_type, val in head_importance.items()
    }
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Compute gradients
        log_output = trainer.prune_step(samples)
        # Retrieve importance scores for the encoder
        for layer in range(encoder_layers):
            self_attn_variables = trainer.model.encoder.layers[
                layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_self"][layer] += importance
            denoms["encoder_self"][layer] += denom
            # Stats
            aggregate_stats(head_stats["encoder_self"][layer],
                            batch_head_stats(self_attn_variables)[0])
        # Retrieve importance scores for the decoder
        for layer in range(decoder_layers):
            # Self attention
            self_attn_variables = trainer.model.decoder.layers[
                layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["decoder_self"][layer] += importance
            denoms["decoder_self"][layer] += denom
            aggregate_stats(
                head_stats["decoder_self"][layer],
                batch_head_stats(self_attn_variables, triu_masking=True)[0])
            # Encoder attention
            encoder_attn_variables = trainer.model.decoder.layers[
                layer].encoder_attn_variables
            importance, denom = batch_head_importance(encoder_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_decoder"][layer] += importance
            denoms["encoder_decoder"][layer] += denom
            aggregate_stats(head_stats["encoder_decoder"][layer],
                            batch_head_stats(encoder_attn_variables)[0])
        # log mid-epoch stats
        stats = get_pruning_stats(trainer)
        for k, v in log_output.items():
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)
        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()
    # log end-of-epoch stats
    stats = get_pruning_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)
    # Normalize by type
    for attn_type in denoms:
        head_importance[attn_type] /= denoms[attn_type]
    # Normalize head stats
    for attn_type in denoms:
        for layer in range(len(head_stats[attn_type])):
            for key in head_stats[attn_type][layer]:
                head_stats[attn_type][layer][key] /= denoms[attn_type].mean().cpu()
    # Normalize by layer
    if args.normalize_by_layer:
        for layer in range(encoder_layers):
            for attn_type, importance in head_importance.items():
                head_importance[attn_type][layer] /= torch.sqrt(
                    torch.sum(importance[layer]**2))
    return {k: v.cpu() for k, v in head_importance.items()}, head_stats
Example #14
def main():

    args = parser.parse_args()

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    model = TransformerModel.build_model(args, task).cuda()
    criterion = task.build_criterion(args).cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=eval(args.adam_betas),
                                 eps=args.adam_eps,
                                 weight_decay=args.weight_decay)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=(args.max_source_positions, args.max_target_positions),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=1,
        num_shards=1,
        shard_id=0,
    )

    losses = AverageMeter()

    encoder_layer_forward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_forward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]
    encoder_layer_backward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_backward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]

    def measure_hook(forward, backward):
        def hook(module, input, output):
            for i, layer in enumerate(module.layer):

                if len(input) == 2:
                    x, _ = input
                else:
                    x, = input
                x = x.detach().clone().requires_grad_()

                # warm-up
                for _ in range(5):
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    torch.autograd.backward(out, out)

                starter, ender = torch.cuda.Event(
                    enable_timing=True), torch.cuda.Event(enable_timing=True)
                for _ in range(50):
                    starter.record()
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    ender.record()
                    torch.cuda.synchronize()
                    forward[i].update(starter.elapsed_time(ender))

                    starter.record()
                    torch.autograd.backward(out, out)
                    ender.record()
                    torch.cuda.synchronize()
                    backward[i].update(starter.elapsed_time(ender))

        return hook

    for layer in model.encoder.layers:
        layer.register_forward_hook(
            measure_hook(encoder_layer_forward, encoder_layer_backward))

    for layer in model.decoder.layers:
        layer.register_forward_hook(
            measure_hook(decoder_layer_forward, decoder_layer_backward))

    embed_forward = AverageMeter()
    embed_backward = AverageMeter()

    def embed_hook(module, input, output):
        tokens, _ = input

        # warm-up
        for _ in range(5):
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            torch.autograd.backward(x, x)

        starter, ender = torch.cuda.Event(
            enable_timing=True), torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            ender.record()
            torch.cuda.synchronize()
            embed_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(x, x)
            ender.record()
            torch.cuda.synchronize()
            embed_backward.update(starter.elapsed_time(ender))

    model.encoder.register_forward_hook(embed_hook)

    linear_forward = AverageMeter()
    linear_backward = AverageMeter()

    def linear_hook(module, input, output):
        _, encode_out = input
        encode_out = encode_out.detach().clone().requires_grad_()

        # warm-up
        for _ in range(5):
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            torch.autograd.backward(out, out)

        starter, ender = torch.cuda.Event(
            enable_timing=True), torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            ender.record()
            torch.cuda.synchronize()
            linear_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(out, out)
            ender.record()
            torch.cuda.synchronize()
            linear_backward.update(starter.elapsed_time(ender))

    model.decoder.register_forward_hook(linear_hook)

    itr = epoch_itr.next_epoch_itr()
    max_positions = (args.max_source_positions, args.max_target_positions)
    for i, sample in enumerate(itr):
        sample = task.dataset('train').get_dummy_batch(args.max_tokens,
                                                       max_positions)
        sample = utils.move_to_cuda(sample)
        loss, _, logging_output = criterion(model, sample)
        num_tokens = logging_output['ntokens']
        losses.update(loss.item() / num_tokens / math.log(2), num_tokens)
        if i % 100 == 0:
            print('Loss: {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
            print(
                'Time: {forward_time.avg:.3f} ({backward_time.avg:.3f}) '
                '{forward_time_decoder.avg:.3f} ({backward_time_decoder.avg:.3f})'
                .format(forward_time=encoder_layer_forward[0],
                        backward_time=encoder_layer_backward[0],
                        forward_time_decoder=decoder_layer_forward[-1],
                        backward_time_decoder=decoder_layer_backward[-1]))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break

    stat = {i: {} for i in range(len(decoder_layer_forward))}
    for i, (f, b) in enumerate(zip(encoder_layer_forward, encoder_layer_backward)):
        stat[i]['encoder'] = {}
        stat[i]['encoder']['forward'] = f.avg
        stat[i]['encoder']['backward'] = b.avg

    for i, (f, b) in enumerate(zip(decoder_layer_forward, decoder_layer_backward)):
        stat[i]['decoder'] = {}
        stat[i]['decoder']['forward'] = f.avg
        stat[i]['decoder']['backward'] = b.avg

    stat['embed'] = {}
    stat['embed']['forward'] = embed_forward.avg
    stat['embed']['backward'] = embed_backward.avg

    stat['linear'] = {}
    stat['linear']['forward'] = linear_forward.avg
    stat['linear']['backward'] = linear_backward.avg

    with open('time.json', 'w') as file:
        json.dump(stat, file, indent=4)
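
All of the hooks in example #14 time GPU work with the same CUDA-event pattern: record a start event, run the op, record an end event, synchronize, then read the elapsed milliseconds. Extracted into a standalone helper, the pattern looks roughly like this (sketch; the helper name is made up).

import torch

def time_cuda_op(fn, warmup=5, iters=50):
    """Average milliseconds per call of fn() on the current CUDA device."""
    for _ in range(warmup):
        fn()
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(iters):
        starter.record()
        fn()
        ender.record()
        torch.cuda.synchronize()  # make sure the kernels finished before reading the timer
        total_ms += starter.elapsed_time(ender)
    return total_ms / iters
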
Example #15
File: train.py Project: ahiroto/ParlAI
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
Example #16
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )

    itr_groups = OrderedDict()
    progress_len = 0

    for domain in itr.keys():
        itr_groups[domain] = iterators_dtn.GroupedIteratorDtn(
            itr[domain], update_freq)
        progress_len = progress_len + len(itr_groups[domain])
    progress = progress_bar.build_progress_bar(
        args,
        range(progress_len),
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    valid_select = args.valid_select[0]
    max_update = args.max_update or math.inf

    def sample_a_training_set(prob=None):
        train_idx = np.random.choice(np.arange(len(prob)), p=prob)
        return train_idx

    train_domains = args.train_domains

    if args.random_select:
        lens_map = OrderedDict()
        count = 0.0
        for domain in train_domains:
            lens_map[domain] = len(itr_groups[domain])
            count += lens_map[domain]

        prob_map = OrderedDict()
        for domain in train_domains:
            prob_map[domain] = lens_map[domain] / count

        keys = []
        probs = []
        for key, value in prob_map.items():
            keys.append(key)
            probs.append(value)
        probs_new = []
        probs_norm = 0.0
        for prob in probs:
            probs_norm += prob**args.random_select_factor
        for prob in probs:
            probs_new.append(prob**args.random_select_factor / probs_norm)

    for i, _ in enumerate(progress, start=epoch_itr.iterations_in_epoch):

        if args.random_select:
            train_idx = sample_a_training_set(probs_new)
            domain = keys[train_idx]
        else:
            domain = train_domains[i % len(train_domains)]
        samples = next(itr_groups[domain])

        # domain, samples = sample_a_training_set(itr_groups)
        log_output = trainer.train_step(samples, domain=domain)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        stats['dataset'] = domain
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()

        if args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0 and num_updates > 0:
            valid_losses, valid_bleus = validate(args, trainer, task,
                                                 epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses,
                            valid_bleus, valid_select)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
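
Example #16 picks a training domain per step with probability proportional to its size raised to random_select_factor, a temperature-style rebalancing that up-weights small domains. The probability computation in isolation, as a sketch with made-up sizes:

import numpy as np

sizes = {'news': 8000, 'law': 2000, 'medical': 500}  # hypothetical batch counts per domain
factor = 0.5                                         # stands in for args.random_select_factor

probs = np.array(list(sizes.values()), dtype=float)
probs /= probs.sum()                # raw proportions
probs = probs ** factor
probs /= probs.sum()                # flattened distribution: small domains sampled more often

domain = np.random.choice(list(sizes.keys()), p=probs)
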
Example #17
    def __init__(self,
                 args,
                 task,
                 model,
                 criterion,
                 allreduce_communicators=None):
        super().__init__(args, task, model, criterion, allreduce_communicators)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

        self.grad_denom = 1.0

        if self.args.enable_parallel_backward_allred_opt:
            import numpy as np

            self._flat_grads_parallel = torch.tensor(
                [], dtype=torch.float16).cuda()
            self._grads_info = []
            grads_size = 0
            p_offset = 0
            for p_i, p in enumerate(
                [p for p in self.model.parameters() if p.requires_grad]):
                p_grads_size = np.prod(list(p.size()))
                grads_size += p_grads_size

                # register hooks
                def wrapper(param, param_i, param_grads_size, param_offset):
                    def allreduce_hook(grad):
                        self._do_allreduce(param_i, param_grads_size,
                                           param_offset, grad)

                    if param.requires_grad:
                        param.register_hook(allreduce_hook)

                # print(p_i, p.size(), p_grads_size, p_offset)
                self._grads_info.append({
                    "param_grads_size": p_grads_size,
                    "param_offset": p_offset
                })
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
            self._flat_grads_parallel.resize_(grads_size)
            # print(grads_size, len(self._flat_grads_parallel), self._flat_grads_parallel.dtype, self._flat_grads_parallel.get_device())

            self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold
            print("| parallel all-reduce ENABLED. all-reduce threshold: " +
                  str(self._allreduce_flush_min_threshold))
            self._grads_generated = [False] * len(self._grads_info)
            self._allreduce_processed_idx = len(self._grads_info) - 1

            self._num_allreduce_sent = 0
            print("| # of parallel all-reduce cuda streams: " +
                  str(self.args.parallel_backward_allred_cuda_nstreams))
            if allreduce_communicators:
                self._allreduce_groups = allreduce_communicators[0]
                self._allreduce_streams = allreduce_communicators[1]
            else:
                raise RuntimeError(
                    'Moved communicator init before RUN_START (invalid code path)'
                )
                self._allreduce_groups = [
                    torch.distributed.new_group() for _ in range(
                        self.args.parallel_backward_allred_cuda_nstreams)
                ]
                self._allreduce_streams = [
                    torch.cuda.Stream() for _ in range(
                        self.args.parallel_backward_allred_cuda_nstreams)
                ]

            if self.args.enable_parallel_backward_allred_opt_correctness_check:
                self._num_grads_generated = 0
                self._all_grads_generated = False
                self._allreduce_schedule = []
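The FP16 trainer above multiplies the loss by a dynamic scale before backward and backs the scale off when gradients overflow. A rough sketch of what a DynamicLossScaler of this kind does, assuming the usual grow-on-stability / shrink-on-overflow policy (names and defaults are illustrative, not the exact fairseq implementation):

class DynamicLossScaler(object):
    """Grows the loss scale while gradients stay finite, shrinks it on overflow."""

    def __init__(self, init_scale=2.**7, scale_factor=2.0, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window      # updates between growth attempts
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            # gradients contained inf/nan: back off and skip this update
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            # a long stretch without overflow: it is safe to grow the scale again
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # inf or nan norms indicate an FP16 overflow
        return grad_norm == float('inf') or grad_norm != grad_norm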
Example #18
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = OrderedDict()
    valid_bleus = OrderedDict()
    valid_select = task.args.valid_select[0]
    assert len(subsets) == 1

    for subset in subsets:
        # Initialize data iterator
        valid_loss_all = []
        valid_nll_loss_all = []
        valid_bleu_all = []

        for k in ['valid_loss', 'valid_nll_loss', 'valid_bleu']:
            meter = trainer.get_meter(k + '_all')
            meter.reset()

        for domain, data_valid in task.dataset(subset).items():

            itr = task.get_batch_iterator_valid(
                dataset=data_valid,
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=8,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
                num_workers=args.num_workers,
            ).next_epoch_itr(shuffle=False)

            progress = progress_bar.build_progress_bar(
                args,
                itr,
                epoch_itr.epoch,
                prefix='valid on \'{}\' subset \'{}\' domain'.format(
                    subset, domain),
                no_progress_bar='simple')
            # reset validation loss meters
            for k in ['valid_loss', 'valid_nll_loss', 'valid_bleu']:
                meter = trainer.get_meter(k + '_' + domain)
                meter.reset()

            extra_meters = collections.defaultdict(lambda: AverageMeter())

            src_target_hypo_strs = []
            for sample in progress:
                log_output, src_target_hypo_str = trainer.valid_step(
                    sample, domain=domain)
                src_target_hypo_strs.extend(src_target_hypo_str)

                for k, v in log_output.items():
                    if k in [
                            'loss', 'nll_loss', 'ntokens', 'nsentences',
                            'sample_size'
                    ]:
                        continue
                    extra_meters[k].update(v)

            src_target_hypo_strs_filter = []
            for sents in src_target_hypo_strs:
                for sent in sents:
                    if sent is None or len(sent) == 0:
                        continue
                    src_target_hypo_strs_filter.append(sent)

            src_target_hypo_strs_filter = sorted(src_target_hypo_strs_filter,
                                                 key=lambda elem: int(elem[0]),
                                                 reverse=False)
            if args.valid_decoding_path is not None:
                with open(
                        os.path.join(
                            args.valid_decoding_path, domain,
                            'decoding_{}.txt'.format(args.distributed_rank)),
                        'w') as f:
                    for sent in src_target_hypo_strs_filter:
                        if len(sent) == 0:
                            continue
                        f.write(sent[-1] + '\n')

            num_ref = args.num_ref[domain]
            ref_path = []
            for i in range(int(num_ref)):
                ref_path.append(
                    os.path.join(args.valid_decoding_path, domain,
                                 'valid.tok.' + args.target_lang + str(i)))

            valid_decoding_path = os.path.join(
                args.valid_decoding_path, domain,
                'decoding_{}.txt'.format(args.distributed_rank))

            with open(valid_decoding_path) as out_file:
                out_file.seek(0)
                res = subprocess.check_output(
                    'perl %s/multi-bleu.perl %s' %
                    (args.multi_bleu_path, ' '.join(ref_path)),
                    stdin=out_file,
                    shell=True).decode("utf-8")

            trainer.get_meter('valid_bleu_' + domain).update(
                float(res.split(',')[0].split('=')[1]), 1.0)

            stats = get_valid_stats(trainer,
                                    domain=domain,
                                    valid_select=valid_select)

            for k in ['loss', 'nll_loss', 'bleu']:
                stats[k] = stats[k].avg
            for k, meter in extra_meters.items():
                stats[k] = meter.avg

            progress.print(stats,
                           tag=os.path.join(subset, domain),
                           step=trainer.get_num_updates())
            valid_losses.update({domain: stats['loss']})
            valid_bleus.update({domain: stats['bleu']})
            valid_loss_all.append(stats['loss'])
            valid_nll_loss_all.append(stats['nll_loss'])
            valid_bleu_all.append(stats['bleu'])

        trainer.get_meter('valid_loss_all').update(np.mean(valid_loss_all),
                                                   1.0)
        trainer.get_meter('valid_nll_loss_all').update(
            np.mean(valid_nll_loss_all), 1.0)
        trainer.get_meter('valid_bleu_all').update(np.mean(valid_bleu_all),
                                                   1.0)

        stats = get_valid_stats(trainer,
                                domain='all',
                                valid_select=valid_select)

        for k in ['loss', 'nll_loss', 'bleu']:
            stats[k] = stats[k].avg

        progress = progress_bar.build_progress_bar(
            args, [0],
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset \'{}\' domain'.format(
                subset, 'all'),
            no_progress_bar='simple')

        progress.print(stats,
                       tag=os.path.join(subset, 'all'),
                       step=trainer.get_num_updates())
        valid_losses.update({'all': stats['loss']})
        valid_bleus.update({'all': stats['bleu']})

    return valid_losses, valid_bleus
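This validate() shells out to Moses' multi-bleu.perl and parses the leading 'BLEU = ...' field of its output. The same invocation, isolated into a small helper for clarity (it mirrors the subprocess call above; providing valid paths is the caller's responsibility):

import subprocess

def compute_multi_bleu(multi_bleu_path, ref_paths, hyp_path):
    """Run multi-bleu.perl on a hypothesis file and return the BLEU score.

    Assumes the script prints a line like 'BLEU = 33.10, 65.2/40.1/...'.
    """
    with open(hyp_path) as hyp_file:
        out = subprocess.check_output(
            'perl %s/multi-bleu.perl %s' % (multi_bleu_path, ' '.join(ref_paths)),
            stdin=hyp_file,
            shell=True,
        ).decode('utf-8')
    # the first comma-separated field is 'BLEU = <score>'
    return float(out.split(',')[0].split('=')[1])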
Example #19
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar="simple",
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(",")
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    "loss", "nll_loss", "ntokens", "nsentences", "sample_size"
            ]:
                continue  # these are already logged above
            if "loss" in k or k == "accuracy":
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag="train", step=stats["num_updates"])

        # ignore the first mini-batch in words-per-second and updates-per-second calculation
        if i == 0:
            trainer.get_meter("wps").reset()
            trainer.get_meter("ups").reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag="train", step=stats["num_updates"])

    # reset training meters
    for k in [
            "train_loss",
            "train_nll_loss",
            "wps",
            "ups",
            "wpb",
            "bsz",
            "gnorm",
            "clip",
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
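The train() variants above wrap the epoch iterator in iterators.GroupedIterator so that each training step receives update_freq samples for gradient accumulation. A minimal stand-in with the same yielding behavior (a sketch; the fairseq class also exposes lengths and update counts):

def grouped_iterator(iterable, chunk_size):
    """Yield lists of up to chunk_size consecutive items from iterable."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk  # last, possibly smaller, group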
Example #20
File: train.py  Project: liufly/refreader
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = data.EpochBatchIterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=trainer.get_model().max_positions(),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        predicted_results, gold_clusters = None, None
        if 'gap_bert' in args.task:
            predicted_results, gold_clusters = collections.defaultdict(
                dict), {}

        for sample in progress:
            log_output = trainer.valid_step(sample)

            if 'gap_bert' in args.task:
                for threshold, predicted_dict in log_output[
                        'predicted_results'].items():
                    len_before = len(predicted_results[threshold])
                    predicted_results[threshold].update(predicted_dict)
                    assert len_before + len(predicted_dict) == len(
                        predicted_results[threshold])
                len_before = len(gold_clusters)
                gold_clusters.update(log_output['gold_clusters'])
                assert len_before + len(
                    log_output['gold_clusters']) == len(gold_clusters)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'sample_size', 'predicted_results',
                        'gold_clusters'
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        if 'gap_bert' in args.task:
            best_f1, best_mf1, best_ff1 = float('-inf'), None, None
            best_threshold = None
            for idx, (k, predicted_result) in enumerate(
                    predicted_results.items()):
                scores = trainer.criterion.coref_evaluator.eval(
                    gold_clusters, predicted_result)
                masculine_score = scores[1]
                _, _, _, mf1 = masculine_score
                feminine_score = scores[2]
                _, _, _, ff1 = feminine_score
                overall_score = scores[0]
                _, _, _, f1 = overall_score

                if f1 > best_f1:
                    best_f1, best_mf1, best_ff1 = f1, mf1, ff1
                    best_threshold = k

                if idx == 0:
                    continue
                if idx == 1:
                    stats['valid_loss'] = -f1
                else:
                    stats['valid_loss'] = min(-f1, stats['valid_loss'])

            if not args.no_train:
                if hasattr(save_checkpoint, 'best'):
                    if stats['valid_loss'] < save_checkpoint.best:
                        stats['best'] = -1 * stats['valid_loss']
                        stats['best_threshold'] = best_threshold
                        save_checkpoint.best_threshold = best_threshold
                    else:
                        stats['best'] = -1 * save_checkpoint.best
                        stats[
                            'best_threshold'] = save_checkpoint.best_threshold
                else:
                    stats['best'] = best_f1
                    stats['best_threshold'] = best_threshold
                    save_checkpoint.best_threshold = best_threshold

            stats['f@%.2f' % best_threshold] = best_f1
            stats['mf@%.2f' % best_threshold] = best_mf1
            stats['ff@%.2f' % best_threshold] = best_ff1

        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses
Example #21
def train(args, trainer, dataset, epoch, batch_offset):
    """Train the model for one epoch."""

    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
        min(args.max_target_positions, trainer.get_model().max_decoder_positions())
    )

    # Initialize dataloader, starting at batch_offset
    itr = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
    itr = itertools.islice(progress, batch_offset, None)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'src_train_loss', 'src_train_nll_loss', 'reg_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(itr, start=batch_offset):
        log_output = trainer.train_step(sample)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'reg_loss', 'src_loss']:
                continue  # these are already logged above
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # save mid-epoch checkpoints
        if i == batch_offset:
            # ignore the first mini-batch in words-per-second calculation
            trainer.get_meter('wps').reset()
        if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
            save_checkpoint(trainer, args, epoch, i + 1)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)
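The train() above resumes a partially finished epoch by slicing the progress iterator at batch_offset. The idiom in isolation (a sketch; progress can be any iterable of batches):

import itertools

def resume_from_offset(progress, batch_offset):
    """Skip the batches that were already consumed before the checkpoint was written."""
    return itertools.islice(progress, batch_offset, None)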
Example #22
File: train.py  Project: yf1291/nlp4
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple'
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        preds, targets, all_results = [], [], []
        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size', 'targets', 'preds', 'starts', 'ends']:
                    continue
                extra_meters[k].update(v)
            if 'targets' in log_output:
                preds.append(log_output['preds'])
                targets.append(log_output['targets'])

            if 'starts' in log_output:
                for i in range(len(sample['id'])):
                    indice = sample['id'][i].tolist()
                    start  = log_output['starts'][i].cpu().tolist()
                    end    = log_output['ends'][i].cpu().tolist()
                    unique_id = task.features[indice].unique_id
                    result = SquadResult(unique_id, start, end)
                    all_results.append(result)

        if len(preds) > 0:
            preds = torch.cat(preds, 0).cpu().numpy()
            targets = torch.cat(targets, 0).cpu().numpy()
        else:
            preds = None
            targets = None
        
        if len(all_results) > 0:
            results = task.compute_predictions_logits(all_results)
            for k, v in results.items():
                print("({}, {})".format(k, v))
            exit()

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters, preds, targets)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric]
        )
    return valid_losses
Example #23
def baseline_with_meta_evaluation(model, meta_learning_task,
                                  meta_learning_args, meta_learning_criterion,
                                  fine_tune_args):
    meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update, valid_subsets = prepare_meta_task(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)
    # Combine and do fine-tuning on combined data
    meta_train = meta_learning_task.dataset(meta_learning_args.train_subset)
    combined_fairseq_task = combine_data(meta_train=meta_train,
                                         fine_tune_args=fine_tune_args)
    # Fine-tune using the combined task
    criterion = combined_fairseq_task.build_criterion(fine_tune_args)
    import math
    from fairseq.trainer import Trainer
    combined_fairseq_task.load_dataset(fine_tune_args.train_subset)
    train_dataset = combined_fairseq_task.dataset(fine_tune_args.train_subset)
    # Make a dummy batch to (i) warm the caching allocator and (ii) serve as a placeholder
    # for DistributedDataParallel when there's an uneven number of batches per worker.
    max_positions = utils.resolve_max_positions(
        combined_fairseq_task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = train_dataset.get_dummy_batch(
        num_tokens=fine_tune_args.max_tokens, max_positions=max_positions)
    oom_batch = combined_fairseq_task.dataset(
        fine_tune_args.train_subset).get_dummy_batch(1, max_positions)
    # Create a trainer for training the model
    trainer = Trainer(fine_tune_args, combined_fairseq_task, model, criterion,
                      dummy_batch, oom_batch)
    epoch_itr = utils.create_epoch_iterator(task=combined_fairseq_task,
                                            dataset=train_dataset,
                                            args=fine_tune_args,
                                            max_positions=max_positions)
    max_epoch = fine_tune_args.max_epoch or math.inf
    max_update = fine_tune_args.max_update or math.inf
    # Do SGD on this task
    valid_subsets = fine_tune_args.valid_subset.split(',')
    lr = trainer.get_lr()
    batch_info = []
    # Always validate once before training
    valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                     combined_fairseq_task, epoch_itr,
                                     valid_subsets)
    while lr > fine_tune_args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates(
    ) < max_update:
        # Train the model for one epoch
        import collections
        import math
        from fairseq.data import iterators
        from fairseq import progress_bar
        from fairseq.meters import AverageMeter, ConcatentateMeter, BleuMeter
        """Train the model for one epoch."""
        # Update parameters every N batches
        update_freq = fine_tune_args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(fine_tune_args.update_freq) else fine_tune_args.update_freq[-1]

        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=fine_tune_args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= fine_tune_args.curriculum),
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            fine_tune_args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        )

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        extra_meters['strings'] = ConcatentateMeter()
        extra_meters['bleu_stats'] = BleuMeter()

        valid_subsets = fine_tune_args.valid_subset.split(',')
        max_update = fine_tune_args.max_update or math.inf
        for i, samples in enumerate(progress,
                                    start=epoch_itr.iterations_in_epoch):
            log_output = trainer.train_step(samples)
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = utils.get_training_stats(trainer)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue  # these are already logged above
                if 'loss' in k:
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats,
                         tag=fine_tune_args.train_subset,
                         step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if fine_tune_args.save_interval_updates > 0 and num_updates % fine_tune_args.save_interval_updates == 0 and num_updates > 0:
                valid_losses, _ = utils.validate(fine_tune_args,
                                                 trainer,
                                                 combined_fairseq_task,
                                                 epoch_itr,
                                                 valid_subsets,
                                                 train_progress=progress)
                utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                      valid_losses[0])

            if num_updates >= max_update:
                break

        # log end-of-epoch stats
        stats = utils.get_training_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
            stats[k + '_std'] = meter.std
        progress.print(stats,
                       tag=fine_tune_args.train_subset,
                       step=stats['num_updates'])

        # reset training meters
        for k in [
                'train_loss',
                'train_nll_loss',
                'wps',
                'ups',
                'wpb',
                'bsz',
                'gnorm',
                'clip',
        ]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        # Evaluate on validation split
        if epoch_itr.epoch % fine_tune_args.validate_interval == 0:
            valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                             combined_fairseq_task, epoch_itr,
                                             valid_subsets)
        # save checkpoint
        if epoch_itr.epoch % fine_tune_args.save_interval == 0:
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                  valid_losses[0])
        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
    if batch_info is None:
        # Handle the original train function
        batch_info = []
    # Evaluate on validation split
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)
示例#24
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    if args.enable_parallel_backward_allred_opt and update_freq > 1:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1'
        )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    #begin = time.time()
    #inside = 0
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        #newbegin = time.time()
        #print("iter time", newbegin - begin, inside, (newbegin - begin - inside)*1000)
        #begin = newbegin
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample,
                               update_params=False,
                               last_step=(i == len(itr) - 1))
            continue
        else:
            log_output = trainer.train_step(sample,
                                            update_params=True,
                                            last_step=(i == len(itr) - 1))

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        if args.profile is not None and i == args.profile:
            import sys
            sys.exit()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break
        #end = time.time()
        #inside = end - begin

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
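Example #24 buffers mini-batches with update_params=False and only steps the optimizer on the last batch of each update_freq group. A hedged sketch of the accumulation pattern this implies, using purely illustrative names rather than the project's actual trainer API:

def accumulate_and_step(model, optimizer, criterion, samples):
    """Accumulate gradients over a group of samples, then take one optimizer step."""
    optimizer.zero_grad()
    total_loss = 0.0
    for sample in samples:
        loss = criterion(model(**sample['net_input']), sample['target'])
        # scale so the summed gradient matches an average over the group
        (loss / len(samples)).backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss / len(samples)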
Example #25
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['train_generate_loss'] = AverageMeter()
        self.meters['train_generate_nll_loss'] = AverageMeter()
        self.meters['train_predict_loss'] = AverageMeter()
        self.meters['train_predict_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['valid_generate_loss'] = AverageMeter()
        self.meters['valid_generate_nll_loss'] = AverageMeter()
        self.meters['valid_predict_loss'] = AverageMeter()
        self.meters['valid_predict_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        self.meters['src_tokens'] = []
        self.meters['target_tokens'] = []
        self.meters['select_retrive_tokens'] = []
        self.meters['loss_weight'] = []
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
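init_meters mixes AverageMeter with TimeMeter ('wps', 'ups') and StopwatchMeter ('train_wall'). A minimal rate meter in the same spirit, assuming a simple events-per-second definition (the actual fairseq TimeMeter differs in details):

import time

class TimeMeter(object):
    """Measures the average rate of events per second (e.g. words per second)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0  # number of events observed since reset

    def update(self, val=1):
        self.n += val

    @property
    def avg(self):
        elapsed = time.time() - self.start
        return self.n / elapsed if elapsed > 0 else 0.0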
Example #26
File: train.py  Project: nilesh-c/kgqa
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        if "target" not in valid_sub_split:
            task.load_dataset(valid_sub_split, combine=False, epoch=0)
        # task.load_dataset("target_" + valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    assert isinstance(model, XlmrTransformerEncoderDecoder)
    # encoder = model.encoder

    # args.task = 'semparse_classification'
    # adv_task = tasks.setup_task(args, xlmr=task.xlmr)
    # assert isinstance(adv_task, SemparseClassificationTask)
    #
    # # Build adversarial language critic model and criterion (WGAN-GP)
    # adv_model = adv_task.build_model(args)
    # adv_criterion = adv_task.build_criterion(args)
    # print(adv_model)
    # print('| model {}, criterion {}'.format(args.arch, adv_criterion.__class__.__name__))
    # print('| num. model params: {} (num. trained: {})'.format(
    #     sum(p.numel() for p in adv_model.parameters()),
    #     sum(p.numel() for p in adv_model.parameters() if p.requires_grad),
    # ))

    # Build

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    # adv_trainer = Trainer(args, adv_task, adv_model, adv_criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    grad_meter = AverageMeter()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    if args.plot_features:
        valid_losses, accuracy = validate(args, trainer, task, epoch_itr,
                                          valid_subsets)
        plot_features(args,
                      trainer,
                      task,
                      epoch_itr,
                      valid_subsets,
                      accuracy=accuracy)

    if args.validate_first:
        validate(args, trainer, task, epoch_itr, valid_subsets)

    while (lr > args.min_lr and (epoch_itr.epoch < max_epoch or
                                 (epoch_itr.epoch == max_epoch
                                  and epoch_itr._next_epoch_itr is not None))
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr, grad_meter)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses, accuracy = validate(args, trainer, task, epoch_itr,
                                              valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #27
def eval_tune_loss(args, trainer, task, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    # Initialize dataloader
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions()),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args=args,
        iterator=itr,
        epoch=extra_state["epoch"],
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple",
    )

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    "loss", "nll_loss", "ntokens", "nsentences", "sample_size"
            ]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    extra_state["tune_eval"]["loss"] = stats["valid_loss"]
    extra_state["tune_eval"]["perplexity"] = stats["valid_ppl"]

    if (extra_state["tune_eval"]["lowest_loss"] is None
            or extra_state["tune_eval"]["loss"] <
            extra_state["tune_eval"]["lowest_loss"]):
        extra_state["tune_eval"]["lowest_loss"] = extra_state["tune_eval"][
            "loss"]
        extra_state["tune_eval"]["num_since_best"] = 0
    else:
        extra_state["tune_eval"]["num_since_best"] += 1

    stop_due_to_tune_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["tune_eval"]["num_since_best"] >
            args.stop_no_best_validate_loss):
        stop_due_to_tune_loss = True
        print(
            f"Stopping training due to eval tune loss stagnation - last best "
            f"eval tune loss of {extra_state['tune_eval']['lowest_loss']} "
            f"(current loss: {extra_state['tune_eval']['loss']}) "
            f"was {extra_state['tune_eval']['num_since_best']} validations ago."
        )
    return extra_state, stop_due_to_tune_loss
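The tail of eval_tune_loss implements patience-based early stopping on the tune loss. The same bookkeeping, pulled out into a helper that follows the extra_state layout used above (a sketch, not the project's actual API):

def update_tune_eval_state(extra_state, current_loss, patience):
    """Track the best tune loss and decide whether to stop on stagnation."""
    tune_eval = extra_state['tune_eval']
    if tune_eval['lowest_loss'] is None or current_loss < tune_eval['lowest_loss']:
        tune_eval['lowest_loss'] = current_loss
        tune_eval['num_since_best'] = 0
    else:
        tune_eval['num_since_best'] += 1
    # patience < 0 disables early stopping, matching --stop-no-best-validate-loss
    stop = patience >= 0 and tune_eval['num_since_best'] > patience
    return extra_state, stop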
Example #28
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'loss_sen_piece', 'nll_loss_sen_piece',
                    'overall_loss', 'overall_nll_loss', 'ntokens',
                    'ntokens_sen_piece', 'nsentences', 'sample_size',
                    'sample_size_sen_piece', 'sample_size_overall'
            ]:
                continue  # these are already logged above
            # chain the checks so each value is counted once, most specific key first
            if 'overall_loss' in k:
                extra_meters[k].update(v, log_output['sample_size_overall'] / 2.0)
            elif 'loss_sen_piece' in k:
                extra_meters[k].update(v, log_output['sample_size_sen_piece'])
            elif 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses, valid_losses_sen_piece, valid_overall_losses = validate(
                args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0],
                            valid_losses_sen_piece[0], valid_overall_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'train_loss_sen_piece',
            'train_nll_loss_sen_piece',
            'train_overall_loss',
            'train_overall_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #29
File: train.py  Project: zxw866/fairseq
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
        else:
            log_output = trainer.train_step(sample, update_params=True)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #30
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    valid_losses_sen_piece = []
    valid_overall_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in [
                'valid_loss', 'valid_nll_loss', 'valid_loss_sen_piece',
                'valid_nll_loss_sen_piece', 'valid_overall_loss',
                'valid_overall_nll_loss'
        ]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'loss_sen_piece',
                        'nll_loss_sen_piece', 'overall_loss',
                        'overall_nll_loss', 'ntokens', 'ntokens_sen_piece',
                        'nsentences', 'sample_size', 'sample_size_sen_piece',
                        'sample_size_overall'
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
        valid_losses_sen_piece.append(stats['valid_loss_sen_piece'])
        valid_overall_losses.append(stats['valid_overall_loss'])
    return valid_losses, valid_losses_sen_piece, valid_overall_losses
Example #31
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        if callable(getattr(trainer.criterion, 'set_valid_tgt_dataset', None)):
            trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt)

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size', 'word_count', 'char_count'
                ]:
                    continue
                if k == 'word_error':
                    extra_meters['wer'].update(
                        float(v) / log_output['word_count'] * 100,
                        log_output['word_count'])
                elif k == 'char_error':
                    extra_meters['cer'].update(
                        float(v) / log_output['char_count'] * 100,
                        log_output['char_count'])
                else:
                    extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric])
    return valid_losses
Example #32
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=False,  # edited 5-5-2020 for not shuffling
    )
    print('iterator in training:')
    print(itr.iterable.batch_sampler)

    itr = iterators.GroupedIterator(itr, update_freq)
    print('iterator in training-2:')
    print(itr)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )
    print('iterator in training-2**:')
    print(itr)
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    print('iterator in training-3:')
    print(epoch_itr.iterations_in_epoch)

    print('progress:')
    print(progress)

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        #Christine (6-5-2020)
        print('samples:')
        #print(samples)
        dtype = samples[0]['net_input']['src_tokens'].dtype
        #print('dtype:')
        #print(dtype)
        deleted_batches = samples[0]['deleted']
        task.initiate_memory(i, deleted_batches, trainer, dtype)
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

        if itr:
            print('iterator is not empty')
        else:
            print('iterator is empty')

    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()