def compet_meta_cl(model, meta_learning_task, meta_learning_args, meta_learning_criterion, fine_tune_args):
    """Meta-train with a competence-based curriculum over meta-tasks.

    Same outer loop as the plain Reptile driver, except that before every
    epoch ``modify_trainer`` rebuilds the trainer/task pair so only a
    fraction (``cl_frac``) of the meta-tasks is visible to training.
    """
    (meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update,
     valid_subsets) = prepare_meta_task(
         model=model,
         meta_learning_task=meta_learning_task,
         meta_learning_args=meta_learning_args,
         meta_learning_criterion=meta_learning_criterion)

    # Keep a pristine copy of the full task set; modify_trainer carves the
    # per-epoch curriculum subset out of this copy.
    pristine_meta_task = copy.deepcopy(meta_learning_task)
    curriculum_frac = meta_learning_args.cl_frac
    assert curriculum_frac is not None

    lr = meta_trainer.get_lr()

    # Evaluate on validation split once before any training.
    print("| [Meta-Train Epoch] First validation ")
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)

    while (lr > meta_learning_args.min_lr
           and meta_epoch_itr.epoch < max_meta_epoch
           and meta_trainer.get_num_updates() < max_meta_update):
        # modify_trainer hands back a fresh epoch iterator; restore the epoch
        # counter so save-interval and lr-schedule bookkeeping stay continuous.
        epoch_before = int(meta_epoch_itr.epoch)
        meta_trainer, meta_epoch_itr, meta_learning_task = modify_trainer(
            meta_learning_args, pristine_meta_task, meta_trainer,
            curriculum_frac, meta_trainer.get_num_updates(), max_meta_update)
        meta_epoch_itr.epoch = epoch_before

        print('|[Meta-Train Epoch] {} Cur step: {}/{}, task_num: {}'.format(
            meta_epoch_itr.epoch, meta_trainer.get_num_updates(),
            max_meta_update,
            len(meta_learning_task.dataset(
                meta_learning_args.train_subset).meta_tasks)))

        # Train the model for one epoch on the curriculum subset.
        utils.train(args=meta_learning_args,
                    trainer=meta_trainer,
                    task=meta_learning_task,
                    epoch_itr=meta_epoch_itr,
                    is_curriculum=meta_learning_args.is_curriculum)

        # Evaluate on validation split.
        print("| [Meta-Train Epoch] validation start")
        valid_losses, _ = maybe_validate(
            meta_epoch_itr=meta_epoch_itr,
            meta_learning_args=meta_learning_args,
            meta_trainer=meta_trainer,
            meta_learning_task=meta_learning_task,
            valid_subsets=valid_subsets)

        # Periodic checkpoint.
        if meta_epoch_itr.epoch % meta_learning_args.save_interval == 0:
            utils.save_checkpoint(meta_learning_args, meta_trainer,
                                  meta_epoch_itr, valid_losses[0])

        # Only the first validation loss drives the learning-rate schedule.
        lr = meta_trainer.lr_step(meta_epoch_itr.epoch, valid_losses[0])

    print("| [Meta-Train Epoch END] ")
def fairseq_reptile(model, meta_learning_task, meta_learning_args, meta_learning_criterion, fine_tune_args):
    """Run the plain Reptile meta-training outer loop.

    Repeatedly: train one meta-epoch, validate, checkpoint on the save
    interval, and step the lr scheduler with the first validation loss —
    until the lr floor, the epoch cap, or the update cap is hit.
    """
    (meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update,
     valid_subsets) = prepare_meta_task(
         model=model,
         meta_learning_task=meta_learning_task,
         meta_learning_args=meta_learning_args,
         meta_learning_criterion=meta_learning_criterion)

    lr = meta_trainer.get_lr()

    # Evaluate on validation split before any training happens.
    print("| [Meta-Train Epoch] First validation ")
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)

    while (lr > meta_learning_args.min_lr
           and meta_epoch_itr.epoch < max_meta_epoch
           and meta_trainer.get_num_updates() < max_meta_update):
        # One pass over the meta-training tasks.
        print("|[Meta-Train Epoch] ", meta_epoch_itr.epoch)
        utils.train(args=meta_learning_args,
                    trainer=meta_trainer,
                    task=meta_learning_task,
                    epoch_itr=meta_epoch_itr,
                    is_curriculum=meta_learning_args.is_curriculum)

        # Evaluate on validation split.
        print("| [Meta-Train Epoch] validation start")
        valid_losses, _ = maybe_validate(
            meta_epoch_itr=meta_epoch_itr,
            meta_learning_args=meta_learning_args,
            meta_trainer=meta_trainer,
            meta_learning_task=meta_learning_task,
            valid_subsets=valid_subsets)

        # Periodic checkpoint.
        if meta_epoch_itr.epoch % meta_learning_args.save_interval == 0:
            utils.save_checkpoint(meta_learning_args, meta_trainer,
                                  meta_epoch_itr, valid_losses[0])

        # Only the first validation loss drives the learning-rate schedule.
        lr = meta_trainer.lr_step(meta_epoch_itr.epoch, valid_losses[0])

    print("|[Meta-Train Epoch END] ", meta_epoch_itr.epoch)
def baseline_with_meta_evaluation(model, meta_learning_task, meta_learning_args, meta_learning_criterion, fine_tune_args):
    """Baseline: pool all meta-training data, fine-tune on the pooled task,
    then run the meta-validation protocol once at the end.

    Fixes over the previous version:
      * the ``collections``/``math``/fairseq imports that were re-executed on
        every epoch of the ``while`` loop are hoisted to the function top;
      * the dead ``batch_info`` bookkeeping (initialized to ``[]`` and never
        used, followed by an unreachable ``is None`` guard) is removed;
      * loop-invariant ``valid_subsets``/``max_update`` are computed once;
      * the inlined per-epoch training loop is factored into
        ``_fine_tune_epoch`` (behavior unchanged).
    """
    import math
    from fairseq.trainer import Trainer

    (meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update,
     valid_subsets) = prepare_meta_task(
         model=model,
         meta_learning_task=meta_learning_task,
         meta_learning_args=meta_learning_args,
         meta_learning_criterion=meta_learning_criterion)

    # Combine and do fine-tuning on combined data.
    meta_train = meta_learning_task.dataset(meta_learning_args.train_subset)
    combined_fairseq_task = combine_data(meta_train=meta_train,
                                         fine_tune_args=fine_tune_args)

    # Fine-tune using the combined task.
    criterion = combined_fairseq_task.build_criterion(fine_tune_args)
    combined_fairseq_task.load_dataset(fine_tune_args.train_subset)
    train_dataset = combined_fairseq_task.dataset(fine_tune_args.train_subset)

    # Dummy batches (i) warm the caching allocator and (ii) act as a
    # placeholder for DistributedDataParallel when there's an uneven number
    # of batches per worker.
    max_positions = utils.resolve_max_positions(
        combined_fairseq_task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = train_dataset.get_dummy_batch(
        num_tokens=fine_tune_args.max_tokens, max_positions=max_positions)
    oom_batch = combined_fairseq_task.dataset(
        fine_tune_args.train_subset).get_dummy_batch(1, max_positions)

    # Create a trainer for training the model.
    trainer = Trainer(fine_tune_args, combined_fairseq_task, model, criterion,
                      dummy_batch, oom_batch)
    epoch_itr = utils.create_epoch_iterator(task=combined_fairseq_task,
                                            dataset=train_dataset,
                                            args=fine_tune_args,
                                            max_positions=max_positions)

    max_epoch = fine_tune_args.max_epoch or math.inf
    max_update = fine_tune_args.max_update or math.inf

    # Do SGD on this task.
    # NOTE(review): this rebinding clobbers the meta-level ``valid_subsets``
    # returned by prepare_meta_task, so the final maybe_validate() below sees
    # the fine-tune validation subsets — confirm this is intended.
    valid_subsets = fine_tune_args.valid_subset.split(',')
    lr = trainer.get_lr()

    # Always validate once before training.
    valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                     combined_fairseq_task, epoch_itr,
                                     valid_subsets)

    while (lr > fine_tune_args.min_lr and epoch_itr.epoch < max_epoch
           and trainer.get_num_updates() < max_update):
        # Train the model for one epoch; mid-epoch validations (if any)
        # refresh valid_losses.
        valid_losses = _fine_tune_epoch(fine_tune_args, trainer,
                                        combined_fairseq_task, epoch_itr,
                                        valid_subsets, max_update,
                                        valid_losses)

        # Evaluate on validation split.
        if epoch_itr.epoch % fine_tune_args.validate_interval == 0:
            valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                             combined_fairseq_task, epoch_itr,
                                             valid_subsets)

        # Save checkpoint.
        if epoch_itr.epoch % fine_tune_args.save_interval == 0:
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                  valid_losses[0])

        # Only use first validation loss to update the learning rate.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

    # Evaluate on validation split with the meta-learning protocol.
    maybe_validate(meta_epoch_itr=meta_epoch_itr,
                   meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer,
                   meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)


def _fine_tune_epoch(fine_tune_args, trainer, combined_fairseq_task,
                     epoch_itr, valid_subsets, max_update, valid_losses):
    """Fine-tune ``trainer`` for one epoch on the combined task.

    Mirrors fairseq's train loop: logs per-step stats through the progress
    bar, optionally validates + checkpoints every ``save_interval_updates``
    updates, prints end-of-epoch stats, and resets the training meters.

    Returns the most recent validation losses (the incoming ``valid_losses``
    unless a mid-epoch validation ran).
    """
    import collections
    from fairseq.data import iterators
    from fairseq import progress_bar
    from fairseq.meters import AverageMeter, ConcatentateMeter, BleuMeter

    # Update parameters every N batches (per-epoch schedule; last entry wins).
    update_freq = fine_tune_args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(fine_tune_args.update_freq) else fine_tune_args.update_freq[-1]

    # Initialize data iterator.
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=fine_tune_args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= fine_tune_args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        fine_tune_args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    extra_meters['strings'] = ConcatentateMeter()
    extra_meters['bleu_stats'] = BleuMeter()

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # Log mid-epoch stats.
        stats = utils.get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                    'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size'
            ]:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag=fine_tune_args.train_subset,
                     step=stats['num_updates'])

        # Ignore the first mini-batch in words-per-second calculation.
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (fine_tune_args.save_interval_updates > 0
                and num_updates % fine_tune_args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses, _ = utils.validate(fine_tune_args, trainer,
                                             combined_fairseq_task, epoch_itr,
                                             valid_subsets,
                                             train_progress=progress)
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr,
                                  valid_losses[0])

        if num_updates >= max_update:
            break

    # Log end-of-epoch stats.
    stats = utils.get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
        stats[k + '_std'] = meter.std
    progress.print(stats, tag=fine_tune_args.train_subset,
                   step=stats['num_updates'])

    # Reset training meters.
    for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
            'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    return valid_losses
def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
    """Save a checkpoint from an async worker.

    Delegates to ``utils.save_checkpoint`` with this object's model,
    optimizer, and lr-scheduler state plus the given epoch/offset/loss.

    NOTE(review): ``rank`` and ``device_id`` are unused here — presumably
    required by the async calling convention; confirm against callers. This
    ``save_checkpoint`` signature also differs from the 4-argument form used
    elsewhere in this file — verify both resolve to the intended helper.
    """
    utils.save_checkpoint(args, epoch, batch_offset, self.model,
                          self.optimizer, self.lr_scheduler, val_loss)