Code Example #1
File: train.py Project: pmichel31415/fairseq
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)

            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses
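
Note: every example on this page logs through an AverageMeter with reset()/update()/avg. A minimal sketch of such a meter is shown below purely as an illustration; it is an assumption for readability, not fairseq's actual fairseq.meters.AverageMeter.

class AverageMeter:
    """Tracks a running (optionally weighted) average of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted running sum
        self.count = 0   # total weight seen so far
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count
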
Code Example #2
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss."""

    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple')

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    return stats['valid_loss']
Code Example #3
File: trainer.py Project: SCUZPP/ENAS
    def __init__(self, args, task, model, criterion, dummy_batch):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args
        self.task = task

        # copy model and criterion to current device
        self.criterion = criterion.cuda()
        if args.fp16:
            self._model = model.half().cuda()
        else:
            self._model = model.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()  # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

        self._dummy_batch = dummy_batch
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        self._wrapped_model = None
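
Note: the rate meters above ('wps', 'ups', 'wall') expose an update()/avg/elapsed_time/reset() interface. A rough sketch of such a rate meter follows; this is an illustrative assumption, not the actual fairseq.meters implementation.

import time

class TimeMeter:
    """Measures a rate (e.g. words per second) since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0  # total units (e.g. tokens) counted so far

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        # units per second since the last reset
        return self.n / max(self.elapsed_time, 1e-6)
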
Code Example #4
File: train.py Project: Novemser/sum
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions)
    loss_meter = AverageMeter()
    rouge_greedy_meter = AverageMeter()
    rouge_sampled_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss, mean_rouge_greedy, mean_rouge_sampled = trainer.valid_step(
                sample, criterion)
            loss_meter.update(loss, ntokens)
            rouge_greedy_meter.update(mean_rouge_greedy, 1)
            rouge_sampled_meter.update(mean_rouge_sampled, 1)
            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f}'.format(loss_meter.avg)),
                    ('ROUGE-L/f (greedy)',
                     '{:.4f}'.format(rouge_greedy_meter.avg)),
                    ('ROUGE-L/f (sampled)',
                     '{:.4f}'.format(rouge_sampled_meter.avg))
                ]))

        val_loss = loss_meter.avg
        t.write(
            desc +
            ' | valid loss {:2.2f} | valid ppl {:3.2f} | ROUGE-L (greedy): {:.4f} | ROUGE-L (sampled): {:.4f}'
            .format(val_loss, math.pow(2, val_loss), rouge_greedy_meter.avg,
                    rouge_sampled_meter.avg))

    # return the average validation loss
    return val_loss
Code Example #5
    def __init__(self, args, task, model, criterion, allreduce_communicators=None):
        super().__init__(args, task, model, criterion, allreduce_communicators)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()
        # FIXME: Add more meters

        self.grad_denom = 1.0

        assert (not self.args.enable_parallel_backward_allred_opt), "--distributed-weight-update cannot be combined with --enable-parallel-backward-allred-opt"
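
Note: this example (and example #11 below) constructs a DynamicLossScaler with init_scale=2.**7. The usual dynamic-scaling policy is sketched below under assumed hyperparameter names (scale_factor, scale_window); it is an illustration, not the scaler class these trainers import.

class DynamicLossScaler:
    """Grow the loss scale while training is stable; shrink it on overflow."""

    def __init__(self, init_scale=2.**7, scale_factor=2.0, scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # overflow-free steps before growing
        self._since_overflow = 0

    def update_scale(self, overflow):
        if overflow:
            # gradients overflowed in FP16: back off and restart the counter
            self.loss_scale /= self.scale_factor
            self._since_overflow = 0
        else:
            self._since_overflow += 1
            if self._since_overflow % self.scale_window == 0:
                # long stretch without overflow: safe to scale up again
                self.loss_scale *= self.scale_factor
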
Code Example #6
File: train.py Project: longhuei/Crosslingual-GCN
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        itr = data_utils.get_epoch_iterator(
            task,
            task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=None,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions()),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            num_workers=args.num_workers,
            seed=args.seed,
            epoch=epoch_itr.epoch).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric].avg
                            if args.best_checkpoint_metric == 'loss'
                            else stats[args.best_checkpoint_metric])
    return valid_losses
Code Example #7
File: trainer.py Project: srbutler/MaskGAN.pytorch
    def rollout_critic(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'critic-rollout')
        for rollout in pbar:
            loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(loss.item())

        self.opt.step()
        self.logger.log("critic/loss", self.step, meter.avg)
Code Example #8
    def __init__(self, args, task, model, criterion):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        # copy model and criterion to current device
        self.task = task
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        self.meters['wall'] = TimeMeter()  # wall time in seconds

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None

        self._last_step = False
        if self.args.enable_parallel_backward_allred_opt and not self.args.distributed_world_size > 1:
            raise RuntimeError(
                '--enable-parallel-backward-allred-opt is only meant for distributed training'
            )
        if self.args.enable_parallel_backward_allred_opt and not self.args.fp16:
            raise RuntimeError(
                '--enable-parallel-backward-allred-opt only works with FP16 training'
            )
Code Example #9
File: train.py Project: wanchaol/translate
def setup_epoch(args, epoch_itr, trainer):
    """Sets up data and progress meters for one epoch."""
    # Initialize dataloader, starting at batch_offset
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar="simple"
    )

    # reset training meters
    for k in ["train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz", "clip"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    return itr, progress, extra_meters
Code Example #10
def valid_loop_fn(
    args, device, trainer, progress, loader, last_batch_index
):
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(loader):
        if i == last_batch_index:
            # last batches are of different size, will cause recompilations
            break
        log_output = trainer.valid_step(sample)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue
            extra_meters[k].update(v)
    stats = get_valid_stats(trainer, args)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    return stats
Code Example #11
    def __init__(self, args, task, model, criterion):
        super().__init__(args, task, model, criterion)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

        self.grad_denom = 1.0

        if self.args.enable_parallel_backward_allred_opt:
            import numpy as np
            self._reduction_stream = torch.cuda.Stream()

            self._flat_grads_parallel = torch.tensor([], dtype=torch.float16).cuda()
            self._grads_info = []
            grads_size = 0
            p_offset = 0
            for p_i, p in enumerate([p for p in self.model.parameters() if p.requires_grad]):
                p_grads_size = np.prod(list(p.size()))
                grads_size += p_grads_size
                # register hooks
                def wrapper(param, param_i, param_grads_size, param_offset):
                    def allreduce_hook(grad):
                        self._do_allreduce(param_i, param_grads_size, param_offset, grad)

                    if param.requires_grad:
                        param.register_hook(allreduce_hook)
                # print(p_i, p.size(), p_grads_size, p_offset)
                self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
            self._flat_grads_parallel.resize_(grads_size)
            # print(grads_size, len(self._flat_grads_parallel), self._flat_grads_parallel.dtype, self._flat_grads_parallel.get_device())

            self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold
            print("| parallel all-reduce ENABLED. all-reduce threshold: " + str(self._allreduce_flush_min_threshold))
            self._grads_generated = [False]*len(self._grads_info)
            self._allreduce_processed_idx = len(self._grads_info)-1

            if self.args.enable_parallel_backward_allred_opt_correctness_check:
                self._num_grads_generated = 0
                self._all_grads_generated = False
                self._allreduce_schedule = []
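
Note: the hook registration above interleaves all-reduce with the backward pass. The plain-PyTorch pattern behind it is sketched below; register_grad_hooks and on_grad are illustrative names, not part of this trainer.

def register_grad_hooks(model, on_grad):
    """Invoke on_grad(name, grad) as each parameter's gradient is produced.

    In example #11 the callback would copy the gradient into a flat FP16
    buffer and launch an asynchronous all-reduce once enough gradients have
    accumulated; here it is left abstract.
    """
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue

        def make_hook(param_name):
            # bind the name now so every hook does not close over the last one
            def hook(grad):
                on_grad(param_name, grad)
            return hook

        p.register_hook(make_hook(name))
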
Code Example #12
    def __init__(self, args, task, model, criterion):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        # copy model and criterion to current device
        self.task = task
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        self.meters['wall'] = TimeMeter()  # wall time in seconds

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
        if self.args.use_ema:
            print('Use ema.')
            from fairseq.utils import EMA
            self._ema = EMA()
            self._backup = {}
            self._init_ema()
Code Example #13
File: train.py Project: warut-vijit/translate
def setup_epoch(args, epoch, batch_offset, trainer, dataset):
    """Sets up data and progress meters for one epoch."""
    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions,
            trainer.get_model().max_encoder_positions()),
        min(args.max_target_positions,
            trainer.get_model().max_decoder_positions()),
    )

    # Initialize dataloader, starting at batch_offset
    itr = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch,
                                               no_progress_bar="simple")
    itr = itertools.islice(progress, batch_offset, None)

    # reset training meters
    for k in [
            "train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz", "clip"
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    return itr, progress, extra_meters
Code Example #14
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    progress = initialize_loader_for_epoch(args, epoch_itr)
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second and updates-per-second calculation
        reset_perf_training_meters(trainer, i)

        num_updates = trainer.get_num_updates()
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    reset_training_meters(trainer)
Code Example #15
File: trainer.py Project: ywang07/fairseq
    def __init__(self, args, task, model, criterion):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        # copy model and criterion to current device
        self.task = task
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # initialize meters
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        self.meters['wall'] = TimeMeter()  # wall time in seconds

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None

        self.prev_teacher_models = None  # used as the teacher models for KD training
        self.prev_teacher_val_losses = OrderedDict()
        self.kd_teacher_weights = None
Code Example #16
File: sim_mt_env.py Project: xpertasks/simulNMT
    def start_epoch(self) -> bool:
        if not (self.trainer.get_lr() > self.args.min_lr and self.epoch_itr.epoch < self.max_epoch and
                self.trainer.get_num_updates() < self.max_update):
            self._done = True
            return False

        args = self.args
        update_freq = args.update_freq[self.epoch_itr.epoch - 1] \
            if self.epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
        # Initialize data iterator
        itr = self.epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(self.epoch_itr.epoch >= args.curriculum),
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        # meters in the epoch
        self.extra_meters = collections.defaultdict(lambda: AverageMeter())
        # enumerate
        self.itr = enumerate(itr, start=self.epoch_itr.iterations_in_epoch)
        return True
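
Note: iterators.GroupedIterator above chunks the underlying iterator into groups of update_freq samples. A minimal stand-in with the same grouping behaviour is sketched here for illustration; grouped is an assumed name, not the fairseq class.

import itertools

def grouped(iterable, chunk_size):
    """Yield lists of up to chunk_size consecutive items from iterable."""
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, chunk_size))
        if not chunk:
            return
        yield chunk
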
Code Example #17
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions)
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg))

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # return the average validation loss
    return val_loss
Code Example #18
    def __init__(self, args, model):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args,
                                               self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.args, self.optimizer)
        self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        self._buffered_stats = defaultdict(lambda: [])
        self._num_updates = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()
        self.avg_loss_meter = AverageMeter()
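
Note: this trainer builds a torch.cuda.amp.GradScaler; a typical mixed-precision update with that scaler is sketched below. The function name and the (inputs, targets) signature are illustrative assumptions, not this project's train_step.

from torch.cuda import amp

def amp_train_step(model, criterion, optimizer, scaler, inputs, targets, use_amp=True):
    """One mixed-precision update using the standard GradScaler recipe."""
    optimizer.zero_grad()
    with amp.autocast(enabled=use_amp):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid FP16 underflow
    scaler.step(optimizer)         # unscales grads; skips the step on overflow
    scaler.update()                # adjust the scale for the next iteration
    return loss.detach()
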
Code Example #19
File: train.py Project: warut-vijit/translate
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple")

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if ("validate" not in extra_state
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {
            "lowest_loss": val_loss,
            "num_since_best": 0
        }
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] >
            args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} (current loss: {val_loss})"
            f"was {extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
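
Note: the stagnation check above (and the variant in example #26) implements patience-based early stopping. The core bookkeeping is sketched below; should_stop_early, state and patience are illustrative names standing in for extra_state["validate"] and --stop-no-best-validate-loss.

def should_stop_early(val_loss, state, patience):
    """Return True once val_loss has not improved for more than `patience` validations."""
    if 'lowest_loss' not in state or val_loss < state['lowest_loss']:
        state['lowest_loss'] = val_loss
        state['num_since_best'] = 0
    else:
        state['num_since_best'] += 1
    return patience >= 0 and state['num_since_best'] > patience
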
Code Example #20
File: task6_train.py Project: Silent-Zebra/reproduce
def estimate_head_importance(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    if args.n_pruning_steps > 0:
        itr = islice(itr, args.n_pruning_steps)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )
    # Initialize meters
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    # Initialize head importance scores
    encoder_layers = trainer.args.encoder_layers
    decoder_layers = trainer.args.decoder_layers
    encoder_heads = trainer.args.encoder_attention_heads
    decoder_heads = trainer.args.decoder_attention_heads
    device = next(trainer.model.parameters()).device
    head_importance = {
        "encoder_self": torch.zeros(encoder_layers, encoder_heads).to(device),
        "encoder_decoder": torch.zeros(decoder_layers,
                                       decoder_heads).to(device),
        "decoder_self": torch.zeros(decoder_layers, decoder_heads).to(device),
    }
    # Denominators to normalize properly
    denoms = {
        attn_type: val.clone()
        for attn_type, val in head_importance.items()
    }
    head_stats = {
        attn_type: [{} for _ in range(val.size(0))]
        for attn_type, val in head_importance.items()
    }
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Compute gradients
        log_output = trainer.prune_step(samples)
        # Retrieve importance scores for the encoder
        for layer in range(encoder_layers):
            self_attn_variables = trainer.model.encoder.layers[
                layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_self"][layer] += importance
            denoms["encoder_self"][layer] += denom
            # Stats
            aggregate_stats(head_stats["encoder_self"][layer],
                            batch_head_stats(self_attn_variables)[0])
        # Retrieve importance scores for the decoder
        for layer in range(decoder_layers):
            # Self attention
            self_attn_variables = trainer.model.decoder.layers[
                layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["decoder_self"][layer] += importance
            denoms["decoder_self"][layer] += denom
            aggregate_stats(
                head_stats["decoder_self"][layer],
                batch_head_stats(self_attn_variables, triu_masking=True)[0])
            # Encoder attention
            encoder_attn_variables = trainer.model.decoder.layers[
                layer].encoder_attn_variables
            importance, denom = batch_head_importance(encoder_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_decoder"][layer] += importance
            denoms["encoder_decoder"][layer] += denom
            aggregate_stats(head_stats["encoder_decoder"][layer],
                            batch_head_stats(encoder_attn_variables)[0])
        # log mid-epoch stats
        stats = get_pruning_stats(trainer)
        for k, v in log_output.items():
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)
        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()
    # log end-of-epoch stats
    stats = get_pruning_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)
    # Normalize by type
    for attn_type in denoms:
        head_importance[attn_type] /= denoms[attn_type]
    # Normalize head stats
    for attn_type in denoms:
        for layer in range(len(head_stats[attn_type])):
            for key in head_stats[attn_type][layer]:
                head_stats[attn_type][layer][key] /= denoms[attn_type].mean(
                ).cpu()
    # Normalize by layer
    if args.normalize_by_layer:
        for layer in range(encoder_layers):
            for attn_type, importance in head_importance.items():
                head_importance[attn_type][layer] /= torch.sqrt(
                    torch.sum(importance[layer]**2))
    return {k: v.cpu() for k, v in head_importance.items()}, head_stats
Code Example #21
File: trainer_dtn.py Project: wangyong1122/dtn
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()

        for domain in ['all'] + args.valid_domains:
            self.meters['valid_loss_' + domain] = AverageMeter()
            self.meters['valid_nll_loss_' + domain] = AverageMeter()
            self.meters['valid_bleu_' + domain] = AverageMeter()

        self.meters['wps'] = TimeMeter()  # words per second
        self.meters['ups'] = TimeMeter()  # updates per second
        self.meters['wpb'] = AverageMeter()  # words per batch
        self.meters['bsz'] = AverageMeter()  # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()  # % of updates clipped
        self.meters['oom'] = AverageMeter()  # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()  # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
Code Example #22
File: train.py Project: sk210892/fairseq
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
        else:
            log_output = trainer.train_step(sample, update_params=True)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
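
Note: the --update-freq branch above buffers gradients for N batches before a single parameter update. The equivalent plain-PyTorch accumulation loop is sketched here; accumulate_and_step and loss_fn are illustrative names, not trainer.train_step.

def accumulate_and_step(optimizer, loss_fn, batches, update_freq):
    """Accumulate gradients over update_freq batches, then take one optimizer step."""
    optimizer.zero_grad()
    for i, batch in enumerate(batches):
        # divide so the accumulated gradient matches one large-batch update
        loss = loss_fn(batch) / update_freq
        loss.backward()
        if (i + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()
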
Code Example #23
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters['train_loss'] = AverageMeter()
        self.meters['train_distribution_loss'] = AverageMeter()
        self.meters['train_label_loss'] = AverageMeter()
        self.meters['train_label_acc'] = AverageMeter()
        self.meters['train_nll_loss'] = AverageMeter()
        self.meters['valid_loss'] = AverageMeter()
        self.meters['valid_nll_loss'] = AverageMeter()
        self.meters['copy_alpha'] = AverageMeter()
        self.meters['wps'] = TimeMeter()       # words per second
        self.meters['ups'] = TimeMeter()       # updates per second
        self.meters['wpb'] = AverageMeter()    # words per batch
        self.meters['bsz'] = AverageMeter()    # sentences per batch
        self.meters['gnorm'] = AverageMeter()  # gradient norm
        self.meters['clip'] = AverageMeter()   # % of updates clipped
        self.meters['oom'] = AverageMeter()    # out of memory
        if args.fp16:
            self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
        self.meters['wall'] = TimeMeter()      # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
Code Example #24
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        try:
            log_output = trainer.train_step(samples)
        except FloatingPointError as e:
            if "Minimum loss scale reached" in str(e):
                print(f'Check samples: len={len(samples)}')
                for ik, s in enumerate(samples):
                    if s is None:
                        print(f'[{ik}]: None')
                    else:
                        for k, v in s.items():
                            if isinstance(v, torch.Tensor):
                                print(f'[{ik}][{k}]: {v.size()}')

            raise e

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Code Example #25
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        valid_batch_size=args.valid_batch_size,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                    ('wps', '{:5d}'.format(round(wps_meter.avg))),
                    ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                    ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                    ('lr', lr),
                    ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                    ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
                ]))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
Code Example #26
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss."""
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch,
        prefix=f'valid on \'{subset}\' subset',
        no_progress_bar='simple')

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats['valid_loss']
    val_ppl = stats['valid_ppl']
    if not hasattr(validate, 'lowest_loss') or val_loss < validate.lowest_loss:
        validate.lowest_loss = val_loss
        validate.num_since_best = 0
    elif not hasattr(validate, 'num_since_best'):
        validate.num_since_best = 1
    else:
        validate.num_since_best += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and validate.num_since_best > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f'Stopping training due to validation score stagnation - last best '
            f'validation loss of {validate.lowest_loss} (current loss: {val_loss}) '
            f'was {validate.num_since_best} validations ago.')
    return val_loss, val_ppl, stop_due_to_val_loss
Code Example #27
def train(args, trainer, dataset, epoch, batch_offset):
    """Train the model for one epoch."""

    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
        min(args.max_target_positions, trainer.get_model().max_decoder_positions())
    )

    # Initialize dataloader, starting at batch_offset
    itr = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
    itr = itertools.islice(progress, batch_offset, None)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(itr, start=batch_offset):
        log_output = trainer.train_step(sample)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue  # these are already logged above
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # save mid-epoch checkpoints
        if i == batch_offset:
            # ignore the first mini-batch in words-per-second calculation
            trainer.get_meter('wps').reset()
        if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
            save_checkpoint(trainer, args, epoch, i + 1)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)
Code Example #28
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch."""

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)

    distributed_utils.barrier(args, "train_%d" % trainer.get_num_updates())
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg

        stats['progress'] = round(
            i / num_batches * args.distributed_world_size *
            args.update_freq[-1], 3)
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
            distributed_utils.barrier(
                args, "train_val_%d" % trainer.get_num_updates())

        if summary_writer is not None and num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Code Example #29
File: train.py Project: nlpofwhat/OR-NMT
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples, epoch_itr=epoch_itr)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Code Example #30
def validate(args,
             trainer,
             task,
             epoch_itr,
             subsets,
             test_bleu=False,
             summary_writer=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []

    distributed_utils.barrier(args, "validate1_%d" % trainer.get_num_updates())
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=8,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args,
                itr,
                epoch_itr.epoch,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple')
            return progress

        progress = get_itr()

        num_dataset = task.dataset(subset).num_dataset

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue
                extra_meters[k].update(v)

        bleu_scorers = [
            bleu.Scorer(task.target_dictionary.pad(),
                        task.target_dictionary.eos(),
                        task.target_dictionary.unk())
            for _ in range(num_dataset)
        ] if test_bleu else None

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        if bleu_scorers is not None:
            # test bleu
            print("| test bleu.")
            sample_size = [0 for _ in range(num_dataset)]
            bleu_scores = [0 for _ in range(num_dataset)]
            progress = get_itr()

            tgt_str_files = []
            hypo_str_files = []
            for ds_id in range(num_dataset):
                tgt_str_path = task.dataset(
                    subset).dataset_names[ds_id] + '.tgt.txt'
                hypo_str_path = task.dataset(
                    subset).dataset_names[ds_id] + '.hypo.txt'
                tgt_str_files.append(
                    open(os.path.join(args.save_dir, tgt_str_path),
                         'w',
                         encoding='utf-8'))
                hypo_str_files.append(
                    open(os.path.join(args.save_dir, hypo_str_path),
                         'w',
                         encoding='utf-8'))

            def print_to_file(dataset_id, tgt_str, hypo_str):
                tgt_str_files[dataset_id].write(tgt_str + '\n')
                hypo_str_files[dataset_id].write(hypo_str + '\n')

            for sample in progress:
                trainer.test_bleu_step(sample, bleu_scorers, print_to_file)
                if 'dataset_id' in sample:
                    for ds_id in range(num_dataset):
                        sample_size[ds_id] += (
                            sample['dataset_id'] == ds_id).int().sum().item()
                elif 'id' in sample:
                    sample_size[0] += len(sample['id'])

            for f in tgt_str_files + hypo_str_files:
                f.close()

            distributed_utils.barrier(
                args, "validate2_%d" % trainer.get_num_updates())
            for ds_id in range(num_dataset):
                try:
                    bleu_scores[ds_id] = bleu_scorers[ds_id].score(
                    ) * sample_size[ds_id]
                except Exception as e:
                    bleu_scores[ds_id] = 0

            sample_size = torch.Tensor(sample_size).cuda()
            bleu_scores = torch.Tensor(bleu_scores).cuda()
            if args.distributed_world_size > 1:
                all_reduce(sample_size)
                all_reduce(bleu_scores)

            bleu_dict = {}
            for ds_id in range(num_dataset):
                if sample_size[ds_id].item() > 0:
                    name = "bleu_" + task.dataset(subset).dataset_names[ds_id]
                    bleu_dict[name] = stats[name] = bleu_scores[ds_id].item(
                    ) / sample_size[ds_id].item()
                    try:
                        train_ds_id = task.dataset(
                            'train').dataset_names.index(
                                task.dataset(subset).dataset_names[ds_id])
                        task.dataset('train').student_scores[
                            train_ds_id] = bleu_dict[name]
                    except ValueError:
                        pass
            output_path = os.path.join(args.save_dir, 'val_bleu.json')
            with open(output_path, 'w') as bleu_file:
                json.dump(bleu_dict, bleu_file)

        progress.print(stats)
        if summary_writer is not None:
            summary_writer.log_stats('val/' + subset, stats,
                                     trainer.get_num_updates())

        valid_losses.append(stats['valid_loss'])
    return valid_losses