示例#1
0
def compet_meta_cl(model, meta_learning_task, meta_learning_args,
                   meta_learning_criterion, fine_tune_args):
    """Meta-train while rebuilding the trainer/task before every epoch.

    Each epoch, ``modify_trainer`` derives a new trainer, epoch iterator and
    task from an untouched deep copy of ``meta_learning_task``. It uses the
    fraction setting ``meta_learning_args.cl_frac`` and the current update
    count. The epoch counter is carried over onto the fresh iterator.

    The model is validated once before training and after every epoch.
    A checkpoint is saved every ``save_interval`` epochs. The learning rate
    is stepped with the first validation loss.

    The loop stops when the learning rate no longer exceeds
    ``meta_learning_args.min_lr``, or when the epoch/update limits returned
    by ``prepare_meta_task`` are reached. ``fine_tune_args`` is accepted but
    not used.
    """
    epoch_itr, trainer, epoch_limit, update_limit, valid_subsets = prepare_meta_task(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)
    # Pristine copy: every epoch's task is re-derived from this one.
    pristine_task = copy.deepcopy(meta_learning_task)
    frac_type = meta_learning_args.cl_frac
    assert frac_type is not None
    current_task = meta_learning_task
    lr = trainer.get_lr()

    print("| [Meta-Train Epoch] First validation ")
    maybe_validate(meta_epoch_itr=epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=trainer,
                   meta_learning_task=current_task,
                   valid_subsets=valid_subsets)

    while (lr > meta_learning_args.min_lr
           and epoch_itr.epoch < epoch_limit
           and trainer.get_num_updates() < update_limit):
        # modify_trainer hands back a fresh iterator; restore the epoch count.
        resume_epoch = int(epoch_itr.epoch)
        trainer, epoch_itr, current_task = modify_trainer(
            meta_learning_args, pristine_task, trainer, frac_type,
            trainer.get_num_updates(), update_limit)
        epoch_itr.epoch = resume_epoch

        step = trainer.get_num_updates()
        train_split = current_task.dataset(meta_learning_args.train_subset)
        print('|[Meta-Train Epoch] {} Cur step: {}/{}, task_num: {}'.format(
            epoch_itr.epoch, step, update_limit, len(train_split.meta_tasks)))

        utils.train(args=meta_learning_args,
                    trainer=trainer,
                    task=current_task,
                    epoch_itr=epoch_itr,
                    is_curriculum=meta_learning_args.is_curriculum)

        print("| [Meta-Train Epoch] validation start")
        valid_losses, _ = maybe_validate(meta_epoch_itr=epoch_itr,
                                         meta_learning_args=meta_learning_args,
                                         meta_trainer=trainer,
                                         meta_learning_task=current_task,
                                         valid_subsets=valid_subsets)
        should_save = epoch_itr.epoch % meta_learning_args.save_interval == 0
        primary_loss = valid_losses[0]
        if should_save:
            utils.save_checkpoint(meta_learning_args, trainer, epoch_itr,
                                  primary_loss)
        # The LR schedule only looks at the first validation subset's loss.
        lr = trainer.lr_step(epoch_itr.epoch, primary_loss)
        print("| [Meta-Train Epoch END] ")
示例#2
0
def fairseq_reptile(model, meta_learning_task, meta_learning_args,
                    meta_learning_criterion, fine_tune_args):
    """Meta-train on a fixed ``meta_learning_task`` epoch by epoch.

    The model is validated once up front. Each epoch then:

    - trains via ``utils.train``;
    - validates with ``maybe_validate``;
    - checkpoints every ``save_interval`` epochs;
    - steps the LR scheduler with the first validation loss.

    The loop stops when the learning rate no longer exceeds
    ``meta_learning_args.min_lr``, or when the epoch/update limits returned
    by ``prepare_meta_task`` are reached. ``fine_tune_args`` is accepted but
    not used.
    """
    epoch_itr, trainer, epoch_limit, update_limit, valid_subsets = prepare_meta_task(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)
    lr = trainer.get_lr()

    print("| [Meta-Train Epoch] First validation ")
    maybe_validate(meta_epoch_itr=epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)

    while (lr > meta_learning_args.min_lr
           and epoch_itr.epoch < epoch_limit
           and trainer.get_num_updates() < update_limit):
        print("|[Meta-Train Epoch] ", epoch_itr.epoch)
        utils.train(args=meta_learning_args,
                    trainer=trainer,
                    task=meta_learning_task,
                    epoch_itr=epoch_itr,
                    is_curriculum=meta_learning_args.is_curriculum)

        print("| [Meta-Train Epoch] validation start")
        valid_losses, _ = maybe_validate(meta_epoch_itr=epoch_itr,
                                         meta_learning_args=meta_learning_args,
                                         meta_trainer=trainer,
                                         meta_learning_task=meta_learning_task,
                                         valid_subsets=valid_subsets)
        should_save = epoch_itr.epoch % meta_learning_args.save_interval == 0
        primary_loss = valid_losses[0]
        if should_save:
            utils.save_checkpoint(meta_learning_args, trainer, epoch_itr,
                                  primary_loss)
        # The LR schedule only looks at the first validation subset's loss.
        lr = trainer.lr_step(epoch_itr.epoch, primary_loss)
        print("|[Meta-Train Epoch END] ", epoch_itr.epoch)
示例#3
0
def baseline_with_meta_evaluation(model, meta_learning_task,
                                  meta_learning_args, meta_learning_criterion,
                                  fine_tune_args):
    """Fine-tune on meta-train + fine-tune data, then run meta evaluation.

    The meta-train split of ``meta_learning_task`` is merged with the
    fine-tuning data via ``combine_data``. A fairseq ``Trainer`` is then run
    on the combined task until one of these holds:

    - the learning rate no longer exceeds ``fine_tune_args.min_lr``;
    - ``max_epoch`` is reached;
    - ``max_update`` is reached.

    Finally, ``maybe_validate`` runs once on the meta-learning task, using the
    validation subsets returned by ``prepare_meta_task``.

    Args:
        model: model trained by the fine-tuning ``Trainer``; also passed to
            ``prepare_meta_task``.
        meta_learning_task: task whose ``train_subset`` split is merged into
            the fine-tuning data and which is evaluated at the end.
        meta_learning_args: meta-learning arguments (``train_subset`` is read).
        meta_learning_criterion: criterion forwarded to ``prepare_meta_task``.
        fine_tune_args: fairseq-style arguments for the fine-tuning run.
    """
    import math
    from fairseq.trainer import Trainer

    # Keep the meta-task validation subsets under their own name: the
    # fine-tuning loop has its own subsets and must not shadow these.
    meta_epoch_itr, meta_trainer, _, _, meta_valid_subsets = prepare_meta_task(
        model=model,
        meta_learning_task=meta_learning_task,
        meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)

    # Combine the meta-train split with the fine-tuning data into one task.
    meta_train = meta_learning_task.dataset(meta_learning_args.train_subset)
    combined_fairseq_task = combine_data(meta_train=meta_train,
                                         fine_tune_args=fine_tune_args)
    criterion = combined_fairseq_task.build_criterion(fine_tune_args)
    combined_fairseq_task.load_dataset(fine_tune_args.train_subset)
    train_dataset = combined_fairseq_task.dataset(fine_tune_args.train_subset)

    # Dummy batches serve two purposes: (i) they warm the caching allocator,
    # and (ii) they act as a placeholder for DistributedDataParallel when
    # workers have an uneven number of batches.
    max_positions = utils.resolve_max_positions(
        combined_fairseq_task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = train_dataset.get_dummy_batch(
        num_tokens=fine_tune_args.max_tokens, max_positions=max_positions)
    oom_batch = train_dataset.get_dummy_batch(1, max_positions)

    trainer = Trainer(fine_tune_args, combined_fairseq_task, model, criterion,
                      dummy_batch, oom_batch)
    epoch_itr = utils.create_epoch_iterator(task=combined_fairseq_task,
                                            dataset=train_dataset,
                                            args=fine_tune_args,
                                            max_positions=max_positions)
    max_epoch = fine_tune_args.max_epoch or math.inf
    max_update = fine_tune_args.max_update or math.inf
    fine_tune_valid_subsets = fine_tune_args.valid_subset.split(',')

    lr = trainer.get_lr()
    # Always validate once before training.
    valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                     combined_fairseq_task, epoch_itr,
                                     fine_tune_valid_subsets)
    while (lr > fine_tune_args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # Mid-epoch validations (if any) refresh valid_losses, so a skipped
        # end-of-epoch validation still steps the LR with the latest loss.
        valid_losses = _baseline_fine_tune_epoch(
            fine_tune_args, trainer, combined_fairseq_task, epoch_itr,
            fine_tune_valid_subsets, max_update, valid_losses)

        if epoch_itr.epoch % fine_tune_args.validate_interval == 0:
            valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                             combined_fairseq_task, epoch_itr,
                                             fine_tune_valid_subsets)
        if epoch_itr.epoch % fine_tune_args.save_interval == 0:
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                  valid_losses[0])
        # Only the first validation loss drives the learning-rate schedule.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

    # Evaluate the fine-tuned model on the meta-learning validation split.
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=meta_valid_subsets)


def _baseline_fine_tune_epoch(args, trainer, task, epoch_itr, valid_subsets,
                              max_update, valid_losses):
    """Train ``trainer`` for one epoch; return the latest validation losses.

    Batches are grouped by ``update_freq`` and per-update stats are logged.
    Every ``save_interval_updates`` updates, the model is validated and a
    checkpoint is saved. The epoch stops early once ``max_update`` updates
    are reached.

    ``valid_losses`` is returned unchanged unless a mid-epoch validation ran,
    in which case the newest losses are returned.
    """
    import collections
    from fairseq.data import iterators
    from fairseq import progress_bar
    from fairseq.meters import AverageMeter, ConcatentateMeter, BleuMeter

    # Update parameters every N batches.
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(AverageMeter)
    extra_meters['strings'] = ConcatentateMeter()
    extra_meters['bleu_stats'] = BleuMeter()

    # Keys already reported by utils.get_training_stats.
    already_logged = {'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'}
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # Log mid-epoch stats.
        stats = utils.get_training_stats(trainer)
        for k, v in log_output.items():
            if k in already_logged:
                continue
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag=args.train_subset, step=stats['num_updates'])

        # Ignore the first mini-batch in the words-per-second calculation.
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (args.save_interval_updates > 0 and num_updates > 0
                and num_updates % args.save_interval_updates == 0):
            valid_losses, _ = utils.validate(args, trainer, task, epoch_itr,
                                             valid_subsets,
                                             train_progress=progress)
            utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # Log end-of-epoch stats.
    stats = utils.get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
        stats[k + '_std'] = meter.std
    progress.print(stats, tag=args.train_subset, step=stats['num_updates'])

    # Reset training meters so the next epoch starts fresh.
    for k in ('train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
              'gnorm', 'clip'):
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    return valid_losses
示例#4
0
 def _async_save_checkpoint(self, rank, device_id, args, epoch,
                            batch_offset, val_loss):
     """Save this object's model, optimizer and LR scheduler state.

     Forwards ``args``, ``epoch``, ``batch_offset`` and ``val_loss``, together
     with ``self.model``, ``self.optimizer`` and ``self.lr_scheduler``, to
     ``utils.save_checkpoint``.
     """
     model, optimizer, scheduler = self.model, self.optimizer, self.lr_scheduler
     utils.save_checkpoint(args, epoch, batch_offset, model, optimizer,
                           scheduler, val_loss)