Example #1
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    while (
        lr > args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #2
def main(args, init_distributed=False):
    utils.import_user_module(args)

    try:
        from fairseq.fb_pathmgr import fb_pathmgr
        global fb_pathmgr_registered
        if not fb_pathmgr_registered:
            fb_pathmgr.register()
            fb_pathmgr_registered = True
    except (ModuleNotFoundError, ImportError):
        pass

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max input frames per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

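    # a criterion that exposes set_train_tgt_dataset() needs the target side
    # of the training set; hand it over before training starts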
    if callable(getattr(trainer.criterion, 'set_train_tgt_dataset', None)):
        trainer.criterion.set_train_tgt_dataset(task.dataset(args.train_subset).tgt)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    valid_losses = [None]
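    # keep training while the LR is above the floor (or warmup is still in
    # progress), epochs remain (including a pending sharded iterator for the
    # final epoch), and the update cap has not been hit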
    while (
        (lr >= args.min_lr or trainer.get_num_updates() <= getattr(args, 'warmup_updates', 0))
        and (
            epoch_itr.epoch < max_epoch or (
                epoch_itr.epoch == max_epoch
                and epoch_itr._next_epoch_itr is not None
            )
        )
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use first validation WER to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        reload_dataset = len(args.train_feat_files) > 1
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
Example #3
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
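    # group `update_freq` batches together; the trainer accumulates gradients
    # over each group before taking an optimizer step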
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)

    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #4
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    adv_criterion = task.build_adversarial_criterion(args)
    adv = task.build_adversary(args, model)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = AdversarialTrainer(args, task, model, criterion, adv_criterion,
                                 adv)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr, filtered_maxpos_indices = checkpoint_utils.load_checkpoint(
        args, trainer)

    # pretrain data actor
    if args.pretrain_data_actor and args.data_actor == 'lan' and args.data_actor_step_update:
        trainer.pretrain_data_actor()

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    if args.eval_bleu:
        gen_args = copy.deepcopy(args)
        gen_args.sample = False
        gen_args.beam = 5
        gen_args.batch_size = 32
        generator = task.build_generator(gen_args)
        args.maximize_best_checkpoint_metric = True
    else:
        generator = None
    while (
        lr > args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        epoch_itr = train(args, trainer, task, epoch_itr, generator,
                          filtered_maxpos_indices)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)[0]
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
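    # dump the train/dev gradient dot products recorded by the trainer,
    # keyed by example index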
    for idx in sorted(trainer.idx_to_dev_grad_dotprod.keys()):
        print(idx)
        str_dotprod = [str(i) for i in trainer.idx_to_dev_grad_dotprod[idx]]
        print(" ".join(str_dotprod))
Example #5
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    if hasattr(trainer.criterion, 'set_epoch'):
        trainer.criterion.set_epoch(epoch_itr.epoch)
    for samples in progress:
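        # aggregate this step's logging outputs under the 'train_inner' scope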
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
Example #6
def main_tpu(args):

    def prepare_task(args, xla_device):
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str

    def train_loop_fn(device, trainer, loader, last_batch_index):
        """
        This is the main training loop. It trains for 1 epoch.
        """

        def print_training_update(trainer, progress, args, i):
            stats = get_training_stats(trainer, args=args)
            stats['now'] = now()
            progress.log(stats, tag='train', step=trainer.get_num_updates())
            progress.print_mid_epoch(i+1, force=True)

        stats, log_output, skip_stat_keys = None, None, {'clip'}
        max_update = args.max_update or math.inf
        for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
            if i == last_batch_index:
                # last batches are incomplete
                break
            log_output = trainer.train_step(samples)
            reset_perf_training_meters(trainer, i, ignore_index=10)
            if (not (i % args.log_steps)) or (i == last_batch_index-1):
                step_args = trainer, progress, args, i
                xm.add_step_closure(print_training_update, args=step_args)
            num_updates = trainer.get_num_updates()
            if (
                not args.disable_validation
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0
            ):
                vloss = validate_subset(
                    args, device, trainer, task, epoch_itr, valid_subsets[0]
                )
                checkpoint_utils.save_checkpoint(
                    args, trainer, epoch_itr, vloss.item(),
                    epoch=epoch, end_of_epoch=False,
                )
            if num_updates >= max_update:
                break


    def valid_loop_fn(
        args, device, trainer, progress, loader, last_batch_index
    ):
        extra_meters = collections.defaultdict(lambda: AverageMeter())
        for i, sample in enumerate(loader):
            if i == last_batch_index:
                # last batches are of different size, will cause recompilations
                break
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)
        stats = get_valid_stats(trainer, args)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        return stats

    def validate_subset(args, device, trainer, task, epoch_itr, subset):
        xm.master_print('Validating the subset "{}", {}'.format(subset, now()))
        # Initialize data iterator
        # we're not sharding the validation set
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on {} \'{}\' subset'.format(device, subset),
            no_progress_bar='simple'
        )
        para_loader = pl.ParallelLoader(progress, [xla_device])
        reset_validation_loss_meters(trainer)
        stats = valid_loop_fn(
            args, device, trainer, progress,
            para_loader.per_device_loader(xla_device), len(progress) - 1
        )
        progress_bar.progress_bar_print(
            progress, stats, step=trainer.get_num_updates(), force=True,
            tag='validate-{}'.format(subset), flush_writer=True,
        )
        xm.master_print('Validated the subset "{}", {}'.format(subset, now()))
        return stats['loss'].avg

    def validate_subsets(args, device, trainer, task, epoch_itr, subsets):
        valid_losses = {
            subset: validate_subset(
                args, device, trainer, task, epoch_itr, subset
            )
            for subset in subsets
        }
        return valid_losses

    def keep_training(lr, epoch_itr, trainer):
        # Train until the learning rate gets too small
        max_epoch = args.max_epoch or math.inf
        max_update = args.max_update or math.inf
        lr, n_updates = trainer.get_lr(), trainer.get_num_updates()
        return (
            lr > args.min_lr
            and epoch_itr.epoch < max_epoch
            and n_updates < max_update
        )

    if xu.getenv_as('XLA_USE_BF16', bool, False):
        xm.master_print(
            'WARNING: bfloat16 is enabled. Note that fairseq meters such as '
            'loss will accumulate the numerator, and increment the denominator.'
            ' Due to lack of precision in higher numbers in bfloat16, these '
            'meters will report invalid values after a while.',
            fd=sys.stderr
        )

    xm.master_print('Args', fd=sys.stderr)
    for key, val in args.__dict__.items():
        xm.master_print('\t{} {}'.format(key, val), fd=sys.stderr)
    # `xla_device` is `torch.device` and `device` is `str`
    xla_device = xm.xla_device()
    task, trainer, model, epoch_itr, lr, valid_subsets, device = prepare_task(
        args, xla_device)

    train_meter = StopwatchMeter()
    train_meter.start()
    while keep_training(lr, epoch_itr, trainer):
        # TRAINING
        epoch = epoch_itr.epoch + 1
        xm.master_print('Epoch {} begin {}'.format(epoch, now()))
        progress = initialize_loader_for_epoch(
            args, epoch_itr, prefix='training on {}'.format(device),
        )
        skip_stat_keys = {'clip'}
        if args.suppress_loss_report:
            skip_stat_keys.update({'loss', 'nll_loss', 'gnorm'})
        progress.set_keys_to_skip_mid_epoch(skip_stat_keys)
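        # ParallelLoader preloads batches onto the XLA device in the background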
        para_loader = pl.ParallelLoader(progress, [xla_device])
        train_loop_fn(
            device, trainer, para_loader.per_device_loader(xla_device),
            len(progress) - 1
        )
        training_stats = get_training_stats(trainer, args=args)
        tloss = training_stats['loss'].avg.item()
        progress_bar.progress_bar_print(
            progress, training_stats, tag='train', force=True,
            step=trainer.get_num_updates(), log_xla_metrics=True,
            flush_writer=True,
        )
        xm.master_print('Epoch {} end {}'.format(epoch_itr.epoch, now()))
        if args.metrics_debug:
            xm.master_print(met.metrics_report())
        reset_training_meters(trainer)

        # VALIDATION
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate_subsets(
                args, device, trainer, task, epoch_itr, valid_subsets
            )

            # only use average first validation loss to update learning rate
            vloss = valid_losses[valid_subsets[0]].item()
            xm.master_print('old learning rate: {}'.format(lr))
            lr = trainer.lr_step(epoch_itr.epoch, vloss)
            xm.master_print('new learning rate: {}'.format(lr))
            if args.metrics_debug:
                xm.master_print(met.metrics_report())
        else:
            vloss = None

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(
                args, trainer, epoch_itr, vloss,
                epoch=epoch, end_of_epoch=True,
            )

    train_meter.stop()
    xm.master_print('| done training in {:.1f} seconds'.format(train_meter.sum))
    assert_on_losses(args, train_loss=tloss, valid_loss=vloss)
Example #7
def train(args,
          trainer,
          task,
          epoch_itr,
          generator=None,
          filtered_maxpos_indices=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    # data selection: reset epoch iter to filter out unselected data
    if epoch_itr.epoch == args.select_by_dds_epoch and args.select_by_dds_epoch > 0:
        epoch_itr, _ = trainer.get_filtered_train_iterator(
            epoch_itr.epoch, filtered_maxpos_indices=filtered_maxpos_indices)

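    # when the language sampling distribution is refreshed periodically, split
    # the epoch into num_reset chunks of datasize batches and rebuild the
    # iterator for each chunk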
    if args.update_language_sampling > 0 and args.select_by_dds_epoch < 0 and (
            not args.data_actor_step_update):
        num_reset = len(epoch_itr.frozen_batches) // (
            args.update_language_sampling * args.update_freq[0] + 1)
        datasize = args.update_language_sampling * args.update_freq[0] + 1
        if num_reset * datasize < len(epoch_itr.frozen_batches):
            num_reset += 1
    else:
        num_reset = 1
        datasize = -1

    for reset_idx in range(num_reset):
        print("resetting at step", reset_idx)
        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= args.curriculum),
            offset=reset_idx *
            (args.update_language_sampling * args.update_freq[0] + 1),
            datasize=datasize,
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            no_progress_bar='simple',
        )

        for i, samples in enumerate(progress,
                                    start=epoch_itr.iterations_in_epoch):
            if args.extra_data_actor == 'ave_emb':
                update_actor = (i % args.extra_update_language_sampling == 0)
            elif args.data_actor_step_update:
                # also covers args.data_actor == 'lan' with step-level updates
                update_actor = (i % args.update_language_sampling == 0)
            else:
                update_actor = False
            if (epoch_itr.epoch > args.select_by_dds_epoch
                    and args.select_by_dds_epoch > 0):
                update_actor = False
            log_output = trainer.train_step(samples, update_actor=update_actor)
            if log_output is None:
                continue

            # update sampling distribution
            if (args.update_language_sampling > 0
                    and i % args.update_language_sampling == 0
                    and args.data_actor != 'ave_emb'
                    and not args.data_actor_step_update):
                if args.data_actor_multilin:
                    trainer.update_language_sampler_multilin(
                        args, epoch=epoch_itr.epoch)
                else:
                    trainer.update_language_sampler(args)
            # log mid-epoch stats
            stats = get_training_stats(trainer)
            for k, v in log_output.items():
                if k in [
                        'loss', 'nll_loss', 'ntokens', 'nsentences',
                        'sample_size'
                ]:
                    continue  # these are already logged above
                if 'loss' in k or k == 'accuracy':
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats, tag='train', step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if (not args.disable_validation and args.save_interval_updates > 0
                    and num_updates % args.save_interval_updates == 0
                    and num_updates > 0):
                valid_losses = validate(args, trainer, task, epoch_itr,
                                        valid_subsets, generator)
                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                                 valid_losses[0])

            if num_updates >= max_update:
                break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
            'train_loss',
            'train_nll_loss',
            'wps',
            'ups',
            'wpb',
            'bsz',
            'gnorm',
            'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return epoch_itr
Example #8
def main(args, init_distributed=False):
    utils.import_user_module(args)
    utils.handle_save_path(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(f"| Configs: {args}")

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(f"| Model: {args.arch} \n| Criterion: {criterion.__class__.__name__}")

    # Log architecture
    if args.train_subtransformer:
        print(" \n\n\t\tWARNING!!! Training one single SubTransformer\n\n")
        print(f"| SubTransformer Arch: {utils.get_subtransformer_config(args)} \n")
    else:
        print(" \n\n\t\tWARNING!!! Training SuperTransformer\n\n")
        print(f"| SuperTransformer Arch: {model} \n")

    # Log model size
    if args.train_subtransformer:
        print(f"| SubTransformer size (without embedding weights): {model.get_sampled_params_numel(utils.get_subtransformer_config(args))}")
        embed_size = args.decoder_embed_dim_subtransformer * len(task.tgt_dict)
        print(f"| Embedding layer size: {embed_size} \n")

    else:
        model_s = 0
        # model.state_dict() would include two extra entries (encoder.version
        # and decoder.version), so don't count those
        for name, param in model.named_parameters():
            if 'embed' not in name:
                model_s += param.numel()
        print(f"| SuperTransformer model size (without embedding weights): {model_s}")

        print(f"| Embedding layer size: {sum(p.numel() for p in model.parameters() if p.requires_grad) - model_s} \n")

    # specify the length of the dummy input for profile
    # for iwslt, the average length is 23, for wmt, that is 30
    dummy_sentence_length_dict = {'iwslt': 23, 'wmt': 30}
    if 'iwslt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['iwslt']
    elif 'wmt' in args.arch:
        dummy_sentence_length = dummy_sentence_length_dict['wmt']
    else:
        raise NotImplementedError

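    # token id 2 is fairseq's default EOS; 7 stands in for an ordinary token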
    dummy_src_tokens = [2] + [7] * (dummy_sentence_length - 1)
    dummy_prev = [7] * (dummy_sentence_length - 1) + [2]

    # profile the overall FLOPs number
    if args.profile_flops:
        import torchprofile
        config_subtransformer = utils.get_subtransformer_config(args)
        model.set_sample_config(config_subtransformer)
        model.profile(mode=True)
        macs = torchprofile.profile_macs(model, args=(
            torch.tensor([dummy_src_tokens], dtype=torch.long),
            torch.tensor([30]),
            torch.tensor([dummy_prev], dtype=torch.long),
        ))
        model.profile(mode=False)

        last_layer_macs = config_subtransformer['decoder']['decoder_embed_dim'] * dummy_sentence_length * len(task.tgt_dict)

        print(f"| Total FLOPs: {macs * 2}")
        print(f"| Last layer FLOPs: {last_layer_macs * 2}")
        print(f"| Total FLOPs without last layer: {(macs - last_layer_macs) * 2} \n")
        exit(0)

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print(f"| Training on {args.distributed_world_size} GPUs")
    print(f"| Max tokens per GPU = {args.max_tokens} and max sentences per GPU = {args.max_sentences} \n")

    # Measure model latency, the program will exit after profiling latency
    if args.latcpu or args.latgpu:
        utils.measure_latency(args, model, dummy_src_tokens, dummy_prev)
        exit(0)

    # Load the latest checkpoint if one is available and restore the corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Evaluate the SubTransformer
    if args.validate_subtransformer:
        config = utils.get_subtransformer_config(args)
        trainer.set_sample_config(config)
        valid_loss = validate(args, trainer, task, epoch_itr, ['valid'], 'SubTransformer')
        print(f"| SubTransformer validation loss: {valid_loss}")

    # Loop boundaries
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()

    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    represent_configs = utils.get_represent_configs(args)

    # Main training loop
    while (
        lr > args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            for k, v in represent_configs.items():
                trainer.set_sample_config(config=v)
                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, sampled_arch_name=k)
        else:
            valid_losses = [None]

        # update the best loss and get current lr; the real lr scheduling is done in trainer.train_step()
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint epoch level
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

    train_meter.stop()
    print('| Done training in {:.1f} seconds'.format(train_meter.sum))
Example #9
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch,
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf

    represent_configs = utils.get_represent_configs(args)

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if args.train_subtransformer:
            # training one SubTransformer only
            configs = [utils.get_subtransformer_config(args)]
        else:
            # training SuperTransformer by randomly sampling SubTransformers
            configs = [utils.sample_configs(
                utils.get_all_choices(args),
                reset_rand_seed=True,
                rand_seed=trainer.get_num_updates(),
                super_decoder_num_layer=args.decoder_layers,
            )]

        log_output = trainer.train_step(samples, configs=configs)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = utils.get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg

        utils.log_arch_info(stats, configs[0])

        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            for k, v in represent_configs.items():
                trainer.set_sample_config(config=v)
                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, sampled_arch_name=k)

            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = utils.get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
Example #10
def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Stopping conditions (and an additional one based on validation loss later
    # on)
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    if (
        cfg.optimization.stop_time_hours > 0
        and training_time_hours > cfg.optimization.stop_time_hours
    ):
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )

    do_save = (
        (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
        or should_stop
        or (
            cfg.checkpoint.save_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.checkpoint.save_interval_updates == 0
            and num_updates >= cfg.dataset.validate_after_updates
        )
    )
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
        or should_stop
        or (
            cfg.dataset.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.dataset.validate_interval_updates == 0
        )
    ) and not cfg.dataset.disable_validation

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    should_stop |= should_stop_early(cfg, valid_losses[0])

    # Save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop
Example #11
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max input frames per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    while (lr > args.min_lr
           and (epoch_itr.epoch < max_epoch
                # allow resuming training from the final checkpoint
                or epoch_itr._next_epoch_itr is not None)
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        reload_dataset = len(args.train_feat_files) > 1
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch,
                                               load_dataset=reload_dataset)
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
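
A note on should_stop_early(), which Examples #10 and #11 call but none of the
snippets define: fairseq tracks the best validation metric seen so far and
stops once it has not improved for args.patience consecutive validations. A
minimal sketch consistent with how it is used here (assuming an args.patience
option; not the verbatim library code):

def should_stop_early(args, valid_loss):
    # no validation signal yet, or patience disabled: never stop
    if valid_loss is None or getattr(args, 'patience', 0) <= 0:
        return False

    def is_better(a, b):
        # higher is better when maximizing the checkpoint metric (e.g. BLEU)
        return a > b if args.maximize_best_checkpoint_metric else a < b

    # keep state on the function itself, mirroring a module-level cache
    prev_best = getattr(should_stop_early, 'best', None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    should_stop_early.num_runs += 1
    return should_stop_early.num_runs >= args.patience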
Example #12
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),  # shuffling causes errors for multitask learning with cls_indices
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue
        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')