Example #1
    def test_grouped_iterator(self):
        # test correctness
        x = list(range(10))
        itr = iterators.GroupedIterator(x, 1)
        self.assertEqual(list(itr),
                         [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
        itr = iterators.GroupedIterator(x, 4)
        self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
        itr = iterators.GroupedIterator(x, 5)
        self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

        # test the GroupIterator also works correctly as a CountingIterator
        x = list(range(30))
        ref = list(iterators.GroupedIterator(x, 3))
        itr = iterators.GroupedIterator(x, 3)
        self.test_counting_iterator_index(ref, itr)
Example #2
    def test_grouped_iterator(self):
        # test correctness
        x = list(range(10))
        itr = iterators.GroupedIterator(x, 1)
        self.assertEqual(list(itr),
                         [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]])
        itr = iterators.GroupedIterator(x, 4)
        self.assertEqual(list(itr), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]])
        itr = iterators.GroupedIterator(x, 5)
        self.assertEqual(list(itr), [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

        # test CountingIterator functionality
        x = list(range(30))
        ref = list(iterators.GroupedIterator(x, 3))
        itr = iterators.GroupedIterator(x, 3)
        self.test_counting_iterator(ref, itr)
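
The two tests above assert plain chunking behaviour. As a minimal illustration (not the fairseq implementation), the grouping they check can be reproduced with a small generator that collects consecutive items into lists of a fixed size, with the last list possibly shorter:

def group_into_chunks(iterable, chunk_size):
    """Yield consecutive items in lists of chunk_size; the last list may be shorter."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk

assert list(group_into_chunks(range(10), 4)) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]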
Example #3
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function("train_step-%d" % i):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
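
A pattern that recurs in almost every train() example below is the per-epoch selection of update_freq: args.update_freq is a list indexed by the 1-based epoch number, and epochs beyond the end of the list fall back to its last entry. A minimal sketch of that selection in isolation:

def select_update_freq(update_freq, epoch):
    """Pick the gradient-accumulation factor for a 1-based epoch from a per-epoch list."""
    if epoch <= len(update_freq):
        return update_freq[epoch - 1]
    return update_freq[-1]

# e.g. with update_freq = [4, 2, 1]: epoch 1 -> 4, epoch 2 -> 2, epoch 5 -> 1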
Example #4
def test_nmt(args, trainer, task, epoch_itr):
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    valid_subsets = ['valid']
    max_update = args.max_update or math.inf
    num_samples = 0
    for samples in progress:
        for i, sample in enumerate(samples):
            total_loss = trainer.valid_step(sample)
            num_samples += 1
            #num_updates = trainer.get_num_updates()
         
    return num_samples / total_loss
Example #5
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
Example #6
def _estimate_diagonal_fisher(args, trainer, epoch_itr, n_steps):
    """Estimate the diagonal empirical fisher information matrix"""
    # Iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=True)
    itr = iterators.GroupedIterator(itr, 1, bottomless=True)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        0,
        no_progress_bar='simple',
    )
    progress.log_interval = n_steps // 10
    # Initialize the Fisher
    FIM = {
        name: th.zeros_like(p)
        for name, p in trainer.model.named_parameters()
    }
    # Iterate
    for i, samples in enumerate(islice(progress, n_steps)):
        # Forward backward
        trainer.train_step(samples, update_params=False, clip_grad=False)
        # Get gradients
        for name, p in trainer.model.named_parameters():
            FIM[name].add_(p.grad.detach()**2)
        # Log progress
        progress.log({"step": i})
    # Normalize
    FIM = {name: F / n_steps for name, F in FIM.items()}
    return FIM
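
The helper above accumulates squared per-parameter gradients over n_steps forward/backward passes and divides by n_steps, the usual estimator for the diagonal of the empirical Fisher information. A standalone sketch of the same computation with plain PyTorch, where model, loss_fn and data_loader are hypothetical placeholders rather than fairseq objects:

import torch

def estimate_diagonal_fisher(model, loss_fn, data_loader, n_steps):
    """Average of squared per-parameter gradients over n_steps batches."""
    fim = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    for step, (inputs, targets) in enumerate(data_loader):
        if step >= n_steps:
            break
        model.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fim[name] += p.grad.detach() ** 2
    return {name: f / n_steps for name, f in fim.items()}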
Example #7
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )

    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets, end_of_epoch)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
Example #8
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for samples in progress:
        if hasattr(trainer.criterion, 'set_num_updates'):
            trainer.criterion.set_num_updates(trainer.get_num_updates())

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
Example #9
    def test_grouped_iterator_skip_remainder_batch(self):
        reference = [1, 2, 3, 4, 5, 6, 7, 8, 9]
        itr1 = _get_epoch_batch_itr(reference, 3, False)
        grouped_itr1 = iterators.GroupedIterator(itr1, 2, True)
        self.assertEqual(len(grouped_itr1), 1)

        itr2 = _get_epoch_batch_itr(reference, 3, False)
        grouped_itr2 = iterators.GroupedIterator(itr2, 2, False)
        self.assertEqual(len(grouped_itr2), 2)

        itr3 = _get_epoch_batch_itr(reference, 3, True)
        grouped_itr3 = iterators.GroupedIterator(itr3, 2, True)
        self.assertEqual(len(grouped_itr3), 1)

        itr4 = _get_epoch_batch_itr(reference, 3, True)
        grouped_itr4 = iterators.GroupedIterator(itr4, 2, False)
        self.assertEqual(len(grouped_itr4), 1)

        itr5 = _get_epoch_batch_itr(reference, 5, True)
        grouped_itr5 = iterators.GroupedIterator(itr5, 2, True)
        self.assertEqual(len(grouped_itr5), 0)

        itr6 = _get_epoch_batch_itr(reference, 5, True)
        grouped_itr6 = iterators.GroupedIterator(itr6, 2, False)
        self.assertEqual(len(grouped_itr6), 1)
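
Judging from the assertions above, the third positional argument to GroupedIterator controls whether a trailing group smaller than the chunk size is dropped. Extending the earlier chunking sketch (again not the fairseq implementation):

def group_into_chunks(iterable, chunk_size, skip_remainder_batch=False):
    """Group items into lists of chunk_size; optionally drop a short final group."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk and not skip_remainder_batch:
        yield chunk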
Example #10
def downstream_train_pytorch(args, trainer, task, epoch_itr, train_prefix):
    """Fine-tune PyTorch classifier on downstream training set for one epoch"""
    task.split = 'train'
    num_updates = trainer.get_num_updates()

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    progress = maybe_wrap_neptune_logging(progress, args)

    # Task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:

            # Train for one step
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_ft_train_stats(agg.get_smoothed_values())
            progress.log(stats, tag=train_prefix, step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_ft_train_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag=train_prefix, step=num_updates, log=False)
    except:
        progress.print(stats, tag=train_prefix, step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters(train_prefix)
Example #11
  def initialize_loader_for_epoch(args, epoch_itr):
    if epoch_itr.epoch <= len(args.update_freq):
      update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
      update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=False, shuffle=(epoch_itr.epoch >= args.curriculum))
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')
    return progress
Example #12
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    task.split = 'train'

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = maybe_wrap_neptune_logging(
        progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        ),
        args=args,
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    with metrics.aggregate() as agg:
        for samples in progress:
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = get_training_stats(agg.get_smoothed_values())
            progress.log(stats, tag='train', step=num_updates)

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_training_stats(agg.get_smoothed_values())
    try:
        progress.print(stats, tag='train', step=num_updates, log=False)
    except:
        progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
Example #13
def fisher(args, trainer, epoch_itr):

    if args.no_fisher:
        # Keep training code untouched, and make the Fisher values 1s.
        for n, p in trainer.model.named_parameters():
            trainer.fisher[n] = torch.ones(p.shape, device=p.device)

        for n, _ in trainer.model.named_parameters():
            trainer.fisher[n] = torch.autograd.Variable(trainer.fisher[n], requires_grad=False)

        return

    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    for n, p in trainer.model.named_parameters():
        trainer.fisher[n] = 0 * p.data

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        trainer.fisher_step(samples)

    for n, _ in trainer.model.named_parameters():
        trainer.fisher[n] = trainer.fisher[n] / len(progress)
        trainer.fisher[n] = torch.autograd.Variable(trainer.fisher[n], requires_grad=False)

    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #14
    def start_epoch(self) -> bool:
        if not (self.trainer.get_lr() > self.args.min_lr and self.epoch_itr.epoch < self.max_epoch and
                self.trainer.get_num_updates() < self.max_update):
            self._done = True
            return False

        args = self.args
        update_freq = args.update_freq[self.epoch_itr.epoch - 1] \
            if self.epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
        # Initialize data iterator
        itr = self.epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(self.epoch_itr.epoch >= args.curriculum),
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        # meters in the epoch
        self.extra_meters = collections.defaultdict(lambda: AverageMeter())
        # enumerate
        self.itr = enumerate(itr, start=self.epoch_itr.iterations_in_epoch)
        return True
Example #15
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = 1

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=False,  # TODO: changed
    )

    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        progress.log(stats, tag='train', step=stats['num_updates'])

    stats = get_training_stats(trainer)
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip'
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #16
def train_nmt(args, trainer, task, epoch_itr):
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    valid_subsets = ['valid']
    max_update = args.max_update or math.inf
    for samples in progress:
        with fmetrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            num_updates = trainer.get_num_updates()
            if log_output is None:
                continue

            # log mid-epoch stats
            #stats = get_training_stats('train_inner')
            #progress.log(stats, tag='train', step=num_updates)

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
    return log_output
Example #17
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        try:
            log_output = trainer.train_step(samples)
        except FloatingPointError as e:
            if "Minimum loss scale reached" in str(e):
                print(f'Check samples: len={len(samples)}')
                for ik, s in enumerate(samples):
                    if s is None:
                        print(f'[{ik}]: None')
                    else:
                        for k, v in s.items():
                            if isinstance(v, torch.Tensor):
                                print(f'[{ik}][{k}]: {v.size()}')

            raise e

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #18
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch."""

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)

    distributed_utils.barrier(args, "train_%d" % trainer.get_num_updates())
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg

        stats['progress'] = round(
            i / num_batches * args.distributed_world_size *
            args.update_freq[-1], 3)
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
            distributed_utils.barrier(
                args, "train_val_%d" % trainer.get_num_updates())

        if num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #19
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)

            except ResetTrainerException:
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None

                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
Example #20
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    if type(
            task
    ) is tasks.factored_translation.FactoredTranslationTask:  # factored
        if args.factors_to_freeze is not None:
            factors_to_freeze = list({
                x
                for lang_pair in [args.factors_to_freeze]
                for x in lang_pair.split(',')
            })
            if epoch_itr.epoch == args.freeze_factors_epoch:
                for factor in factors_to_freeze:
                    print('Freezing', factor)
                    for param in trainer.get_model(
                    ).encoder.encoders[factor].parameters():
                        param.requires_grad = False

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #21
def train(args,
          trainer,
          task,
          epoch_itr,
          generator=None,
          filtered_maxpos_indices=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    # data selection: reset epoch iter to filter out unselected data
    filter_data = epoch_itr.epoch % args.select_by_dds_epoch == 0
    if filter_data and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

    # if args.update_language_sampling > 0 and args.select_by_dds_epoch < 0 and (not args.data_actor_step_update):
    #     num_reset = len(epoch_itr.frozen_batches) // (args.update_language_sampling*args.update_freq[0]+1)
    #     datasize = args.update_language_sampling*args.update_freq[0]+1
    #     if num_reset * datasize < len(epoch_itr.frozen_batches):
    #         num_reset += 1
    # else:
    #     num_reset = 1
    #     datasize = -1
    # for reset_idx in range(num_reset):
    #     print("resetting at step", reset_idx)
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
        offset=0,
        datasize=-1,
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        #print(samples)

        # if args.extra_data_actor == 'ave_emb':
        #     update_actor = (i % args.extra_update_language_sampling == 0)
        # elif args.data_actor_step_update:
        #     update_actor = (i % args.update_language_sampling == 0)
        # elif args.data_actor == 'lan' and args.data_actor_step_update:
        #     update_actor = (i % args.update_language_sampling == 0)
        # else:
        #     update_actor = False
        # update sampling distribution
        # if args.update_language_sampling > 0 and i % args.update_language_sampling == 0 and args.data_actor != 'ave_emb' and not args.data_actor_step_update:
        #     if args.data_actor_multilin:
        #         trainer.update_language_sampler_multilin(args, epoch=epoch_itr.epoch)
        #     else:
        #         trainer.update_language_sampler(args)

        if (epoch_itr.epoch > args.select_by_dds_epoch
                and args.select_by_dds_epoch > 0):
            update_actor = False
        update_actor = False
        log_output = trainer.train_step(samples, update_actor=update_actor)
        if log_output is None:
            continue

        # update the data selector
        if args.select_by_dds_epoch > 0 and args.update_data_selector > 0 and i % args.update_data_selector == 0:
            trainer.update_data_selector(args)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return epoch_itr
Example #22
def train(args, trainer, task, epoch_itr, epoch_aux_itr, fim=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
    print(update_freq)
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # Auxiliary iterator
    aux_itr = epoch_aux_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    aux_itr = iterators.GroupedIterator(aux_itr, update_freq, bottomless=True)

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Record gradients from auxiliary data
        aux_samples = next(aux_itr)
        trainer.train_step(aux_samples, update_params=False)
        # Fisher
        if hasattr(trainer.optimizer, "save_auxiliary"):
            trainer.optimizer.save_auxiliary()
        else:
            print("Warning, the optimizer is ignoring the auxiliary gradients")
        # Take a step on the primary task
        log_output = trainer.train_step(samples, apply_ewc=args.ewc > 0)

        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, None)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
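
Examples #6 and #22 pass bottomless=True for iterators that are drawn from with next() once per primary training step; the name and usage suggest the grouped iterator restarts rather than raising StopIteration when the underlying data runs out. A hedged sketch of such behaviour, where make_grouped_iterator is a hypothetical factory that rebuilds the grouped iterator from scratch:

def bottomless(make_grouped_iterator):
    """Yield groups forever by rebuilding the source iterator whenever it is exhausted."""
    while True:
        exhausted = True
        for chunk in make_grouped_iterator():
            exhausted = False
            yield chunk
        if exhausted:
            return  # stop rather than spin forever on an empty source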
Example #23
def train(args, trainer, task, epoch_itr, experiment=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(args,
                                               itr,
                                               epoch_itr.epoch,
                                               no_progress_bar="simple")

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(",")
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    "loss", "nll_loss", "ntokens", "nsentences", "sample_size"
            ]:
                continue  # these are already logged above
            if "loss" in k or k == "accuracy":
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag="train", step=stats["num_updates"])
        if experiment:
            experiment.log_metrics(stats,
                                   step=stats["num_updates"],
                                   prefix="mid_epoch_train")

        # ignore the first mini-batch in words-per-second and updates-per-second calculation
        if i == 0:
            trainer.get_meter("wps").reset()
            trainer.get_meter("ups").reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag="train", step=stats["num_updates"])
    if experiment:
        experiment.log_metrics(stats,
                               prefix="end_of_epoch_train",
                               step=stats["num_updates"])

    # reset training meters
    for k in [
            "train_loss",
            "train_nll_loss",
            "wps",
            "ups",
            "wpb",
            "bsz",
            "gnorm",
            "clip",
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #24
def train(args, trainer, task, epoch_itr, force_refine_step=None):
    """Train the model for one epoch."""

    # Update parameters every N batches
    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(args, "progressive") and args.progressive:
        task.dataset("train").set_random_refine_step(
            args.refinetot, force_refine_step=force_refine_step)
    last_samples = None
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if samples is None or len(samples) == 0:
            sys.stderr.write("Empty sample detected\n")
            sys.stderr.flush()
            samples = last_samples
        else:
            last_samples = samples
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue
        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args,
                                    trainer,
                                    task,
                                    epoch_itr,
                                    valid_subsets,
                                    force_refine_step=force_refine_step)
            # if distributed_utils.is_master(args):
            #     print("saving:", trainer.get_num_updates())
            #     nsml.save(str(trainer.get_num_updates()))
            if not hasattr(checkpoint_utils.save_checkpoint,
                           'best') or is_better(
                               valid_losses[0],
                               checkpoint_utils.save_checkpoint.best):
                if distributed_utils.is_master(args):
                    print("saving checkpoint ...")
                    sys.stdout.flush()
                    if HAS_NSML:
                        nsml.save("best")
                    else:
                        torch.save({"model": trainer.get_model().state_dict()},
                                   "/tmp/best.pt")
                    if HAS_WANDB:
                        wandb.save("/tmp/best.pt")
                    sys.stdout.flush()
                checkpoint_utils.save_checkpoint.best = valid_losses[0]

        if args.decoder_wise_training and update_num_to_refine_step(
                num_updates) != force_refine_step:
            if HAS_NSML:
                nsml.load("best")
            else:
                # Retrieve the model
                if distributed_utils.is_master(args):
                    state = torch.load("/tmp/best.pt", map_location="cpu")
                    trainer.model.load_state_dict(state["model"])
                # Sync
                assert isinstance(trainer.model,
                                  parallel.DistributedDataParallel)
                if isinstance(trainer.model, parallel.DistributedDataParallel):
                    trainer.model._sync_params()

            checkpoint_utils.save_checkpoint.best = 0.
            force_refine_step = update_num_to_refine_step(num_updates)
            trainer.criterion.pool.clear()
            print("| Start refinement step:", force_refine_step)

        if num_updates >= max_update:
            break

        if hasattr(args, "progressive") and args.progressive:
            task.dataset("train").set_random_refine_step(
                args.refinetot, force_refine_step=force_refine_step)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr, epoch_aux_itr):
    """Train the model for one epoch."""

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    # Auxiliary iterator
    aux_itr = epoch_aux_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    aux_itr = iterators.GroupedIterator(aux_itr,
                                        update_freq,
                                        restart_when_done=True)

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Record gradients from auxiliary data
        aux_samples = next(aux_itr)
        trainer.train_step(aux_samples, update_params=False)
        # if hasattr(trainer.optimizer, "save_constraints"):
        trainer.optimizer.save_constraints()

        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args,
          trainer,
          task,
          epoch_itr,
          model,
          experiment_path,
          total_samples=None,
          last_epoch_num=0,
          restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False

    conf = {
        "encoder": [{
            "self_attn": []
        } for i in range(args.encoder_layers)],
        "decoder": [{
            "self_attn": [],
            "enc_attn": []
        } for i in range(args.decoder_layers)]
    }
    attentions = {
        "decoder": [{
            "self_attn": []
        } for i in range(args.decoder_layers)]
    }

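    # `batch_regression` decays linearly from 1.0 towards 0.0 as more samples are
    # consumed; 160239 and 50 appear to be a hard-coded dataset size and epoch
    # count (note that the in-loop update below uses 40 instead of 50).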
    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples,
                                            batch_num=batch_regression)

            if log_output is None:  # OOM, overflow, ...
                continue
        total_samples += model.decoder.layers[0].self_attn.bsz
        batch_regression = 1.0 - (
            total_samples / (160239 * 40)
        )  # need to find more generic way to find total samples and epoch num.
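        # A more generic variant (assumption: the epoch iterator exposes its
        # dataset and args.max_epoch is set) could be, e.g.:
        #     total_expected = len(epoch_itr.dataset) * args.max_epoch
        #     batch_regression = max(0.0, 1.0 - total_samples / total_expected)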

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    if args.head_confidence_method is not None:

        conf = convert_confs(conf, args)

        path = args.save_dir.replace("checkpoints",
                                     "confs") + "-method={0}".format(
                                         args.head_confidence_method)
        os.makedirs(path, mode=0o775, exist_ok=True)
        with open(os.path.join(path, "epoch-{0}.pkl".format(epoch_itr.epoch)),
                  'wb') as fd:
            pickle.dump(conf, fd, protocol=3)
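        # The per-epoch confidence dumps can later be reloaded for analysis, e.g.
        # with pickle.load(open(os.path.join(path, "epoch-1.pkl"), "rb")).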

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf

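        # Apply dynamic multi-head reallocation separately to encoder self-attention,
        # decoder self-attention and decoder-encoder attention; `restore` and
        # `last_epoch_num` carry the per-attention-type swap state across epochs.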
        restore['enc_self_attn'], last_epoch_num[
            'enc_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[0]),
                                           "encoder",
                                           "self_attn",
                                           restore['enc_self_attn'],
                                           int(args.dynamic_swap_frequency[0]),
                                           last_epoch_num['enc_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[0]),
                                           conf[0],
                                           num_heads,
                                           head_dim,
                                           args.encoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[0],
                                           rest=int(args.dynamic_rest[0]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[0]))

        restore['dec_self_attn'], last_epoch_num[
            'dec_self_attn'] = dynamic_mhr(model,
                                           int(args.start_dynamic_mhr[1]),
                                           "decoder",
                                           "self_attn",
                                           restore['dec_self_attn'],
                                           int(args.dynamic_swap_frequency[1]),
                                           last_epoch_num['dec_self_attn'],
                                           epoch_itr.epoch + 1,
                                           int(args.dynamic_max_switches[1]),
                                           conf[1],
                                           num_heads,
                                           head_dim,
                                           args.decoder_layers,
                                           local_only=False,
                                           d_type=args.dynamic_type[1],
                                           rest=int(args.dynamic_rest[1]),
                                           end_epoch=int(
                                               args.dynamic_end_epoch[1]))
        restore['dec_enc_attn'], last_epoch_num['dec_enc_attn'] = dynamic_mhr(
            model,
            int(args.start_dynamic_mhr[2]),
            "decoder",
            "encoder_attn",
            restore['dec_enc_attn'],
            int(args.dynamic_swap_frequency[2]),
            last_epoch_num['dec_enc_attn'],
            epoch_itr.epoch + 1,
            int(args.dynamic_max_switches[2]),
            conf[2],
            num_heads,
            head_dim,
            args.decoder_layers,
            local_only=False,
            d_type=args.dynamic_type[2],
            rest=int(args.dynamic_rest[2]),
            end_epoch=int(args.dynamic_end_epoch[2]))

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
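
The variant above threads extra bookkeeping (total_samples, restore, last_epoch_num) through its return value. Below is a minimal sketch of a caller that carries this state across epochs, assuming a fairseq-style outer loop; the initial dictionaries and the max_epoch handling are illustrative assumptions, not part of the original code.

# Hypothetical driver loop (assumptions: fairseq-style Trainer with
# get_train_iterator; restore/last_epoch_num start empty per attention type).
total_samples = 0
restore = {'enc_self_attn': None, 'dec_self_attn': None, 'dec_enc_attn': None}
last_epoch_num = {'enc_self_attn': 0, 'dec_self_attn': 0, 'dec_enc_attn': 0}
while epoch_itr.next_epoch_idx <= args.max_epoch:
    valid_losses, should_stop, total_samples, restore, last_epoch_num = train(
        args, trainer, task, epoch_itr, model, experiment_path,
        total_samples=total_samples,
        last_epoch_num=last_epoch_num,
        restore=restore,
    )
    if should_stop:
        break
    epoch_itr = trainer.get_train_iterator(epoch_itr.next_epoch_idx)
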
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq) else
                   cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )
    progress.update_config(_flatten_config(cfg))
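    # Log the flattened training configuration once so the full set of
    # hyper-parameters is recorded alongside the training metrics.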

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
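    # Initialize num_updates before the loop so the end-of-epoch logging below
    # still works even if the iterator yields no batches.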
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(cfg, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
Example #28
0
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')
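        # On the first mini-batch of the epoch, dump the attention output norms
        # collected by any module exposing `selfattn_norm` / `endeattn_norm`.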
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)
        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
Example #29
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples, epoch_itr=epoch_itr)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #30
0
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""

    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue
            if '_cls' in k or '_reg' in k or '_num' in k or '_acc' in k:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg

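        # Per-property stats: each property logs either a classification or a
        # regression loss (plus accuracy for classification properties). A
        # separate loop variable is used so the outer batch index `i` is not
        # clobbered before the first-batch check below.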
        for prop_idx in range(args.num_props):
            loss_log_key = ('%d_cls' % prop_idx
                            if prop_idx in args.cls_index else '%d_reg' % prop_idx)
            sample_num = log_output.get('%d_num' % prop_idx, 0)
            extra_meters[loss_log_key].update(log_output.get(loss_log_key, 0),
                                              sample_num)
            stats[loss_log_key] = extra_meters[loss_log_key].avg
            if prop_idx in args.cls_index:
                cls_acc_key = '%d_acc' % prop_idx
                extra_meters[cls_acc_key].update(
                    log_output.get(cls_acc_key, 0), sample_num)
                stats[cls_acc_key] = extra_meters[cls_acc_key].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()

        # Write Tensorboard.
        if num_updates % args.log_per_iter == 0:
            for k, v in stats.items():
                if any(x in k for x in ['loss', 'ppl', 'ac', 'reg', 'lr']):
                    trainer.summary_writer.scalar_summary(
                        'train/' + k, float(v), num_updates)
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()