def save_checkpoint(self, filename, extra_state):
    """Write the full training state to ``filename`` (master worker only).

    Workers for which ``distributed_utils.is_master(self.args)`` is false
    return without doing anything. On the master, ``extra_state`` is updated
    in place with a ``'train_meters'`` entry before everything is handed to
    ``utils.save_state``.
    """
    if not distributed_utils.is_master(self.args):
        # only save one checkpoint
        return

    extra_state['train_meters'] = self.meters
    training_state = (
        self.args,
        self.model,
        self.criterion,
        self.optimizer,
        self.lr_scheduler,
        self._num_updates,
        self._optim_history,
        extra_state,
    )
    utils.save_state(filename, *training_state)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Save the current training state to every checkpoint file that is due.

    Does nothing when ``args.no_save`` is set or on non-master workers.

    Args:
        args: parsed options; reads ``no_save``, ``save_dir``,
            ``no_epoch_checkpoints``, ``save_interval``,
            ``save_interval_updates`` and ``keep_interval_updates``.
        trainer: object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: epoch iterator providing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation loss, or ``None`` if validation was
            skipped.

    The best loss seen so far is kept on the function object as
    ``save_checkpoint.best``.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints and
        epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0 and
        updates % args.save_interval_updates == 0
    )
    # evaluated before ``best`` is refreshed below, hence the strict comparison
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)

    extra_state = {}
    if hasattr(save_checkpoint, 'best'):
        # ``best`` does not exist until a validation loss has been seen;
        # reading it unconditionally raised AttributeError in that case
        extra_state['best'] = save_checkpoint.best
    extra_state['train_iterator'] = epoch_itr.state_dict()
    extra_state['val_loss'] = val_loss

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        stale = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in stale[args.keep_interval_updates:]:
            # tolerate files that have already been removed
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and collect the tracked metric.

    For every name in ``subsets`` the trainer's validation iterator is
    consumed without shuffling under a fresh root metrics aggregator, the
    resulting stats are printed through the progress bar, and the entry stored
    under ``args.best_checkpoint_metric`` is appended to the returned list.
    """
    seed = args.fixed_validation_seed
    if seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)

    losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        batches = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            batches = utils.tpu_data_loader(batches)

        bar_options = dict(
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
        )
        # only the master worker writes tensorboard logs
        if distributed_utils.is_master(args):
            bar_options["tensorboard_logdir"] = args.tensorboard_logdir
        else:
            bar_options["tensorboard_logdir"] = None
        if args.no_progress_bar:
            bar_options["default_log_format"] = "simple"
        else:
            bar_options["default_log_format"] = "tqdm"
        bar = progress_bar.progress_bar(batches, **bar_options)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in bar:
                trainer.valid_step(batch)

        # log validation stats
        valid_stats = get_valid_stats(args, trainer, aggregator.get_smoothed_values())
        bar.print(valid_stats, tag=subset, step=trainer.get_num_updates())

        losses.append(valid_stats[args.best_checkpoint_metric])
    return losses
def build_progress_bar(args, iterator, epoch=None, prefix=None, default='tqdm', no_progress_bar='none'):
    """Wrap ``iterator`` in the progress bar selected by ``args.log_format``.

    When ``args.log_format`` is ``None`` it is filled in (on ``args`` itself)
    with ``no_progress_bar`` if ``args.no_progress_bar`` is set, else with
    ``default``. A ``'tqdm'`` format is replaced by ``'simple'`` when stderr is
    not a TTY. On the master worker the bar is additionally wrapped for
    tensorboard logging if ``args.tensorboard_logdir`` is set.

    Raises:
        ValueError: if the resolved log format is not one of ``'json'``,
            ``'none'``, ``'simple'`` or ``'tqdm'``.
    """
    if args.log_format is None:
        if args.no_progress_bar:
            args.log_format = no_progress_bar
        else:
            args.log_format = default

    if args.log_format == 'tqdm' and not sys.stderr.isatty():
        args.log_format = 'simple'

    # factories are lazy so only the selected bar class is ever looked up
    factories = {
        'json': lambda: json_progress_bar(iterator, epoch, prefix, args.log_interval),
        'none': lambda: noop_progress_bar(iterator, epoch, prefix),
        'simple': lambda: simple_progress_bar(iterator, epoch, prefix, args.log_interval),
        'tqdm': lambda: tqdm_progress_bar(iterator, epoch, prefix),
    }
    make_bar = factories.get(args.log_format)
    if make_bar is None:
        raise ValueError('Unknown log format: {}'.format(args.log_format))
    bar = make_bar()

    if args.tensorboard_logdir and distributed_utils.is_master(args):
        bar = tensorboard_log_wrapper(bar, args.tensorboard_logdir)

    return bar
def validate_save_and_evaluate_bleu(
    args,
    trainer,
    dataset,
    extra_state: Dict[str, Any],
    do_validate: bool,
    do_save: bool,
    do_eval_bleu: bool,
) -> Tuple[Optional[float], Optional[float], Optional[float], bool]:
    """Optionally validate, checkpoint and score BLEU in one step.

    Returns ``(val_loss, val_ppl, val_bleu, should_stop)``. Each of the first
    three is ``None`` when its stage did not run; ``should_stop`` is true when
    either validation or BLEU evaluation asked to stop. When validation runs,
    ``extra_state["val_loss"]`` is updated in place.
    """
    val_loss = val_ppl = val_bleu = None
    stop_due_to_val_loss = stop_due_to_val_bleu = False

    # evaluate on validate set
    if do_validate:
        validation_result = validate(
            args=args,
            trainer=trainer,
            dataset=dataset,
            subset=args.valid_subset,
            epoch=extra_state["epoch"],
        )
        val_loss, val_ppl, stop_due_to_val_loss = validation_result
        extra_state["val_loss"] = val_loss

    should_checkpoint = do_save and distributed_utils.is_master(args)
    if should_checkpoint:
        # save checkpoint
        save_checkpoint(trainer=trainer, args=args, extra_state=extra_state)
        # NOTE: BLEU evaluation is kept nested under the save branch
        if do_eval_bleu:
            bleu_result = evaluate_bleu(
                args=args,
                dataset=dataset,
                epoch=extra_state["epoch"],
                offset=extra_state["batch_offset"],
            )
            val_bleu, stop_due_to_val_bleu = bleu_result

    should_stop = stop_due_to_val_loss or stop_due_to_val_bleu
    return (val_loss, val_ppl, val_bleu, should_stop)
def main(args, init_distributed=False):
    """Set up the task, model and trainer, then run the epoch training loop.

    The loop runs while the learning rate stays above ``args.min_lr`` and
    neither ``args.max_epoch`` nor ``args.max_update`` has been reached. Each
    pass trains one epoch, optionally validates, steps the LR scheduler with
    the first validation loss, optionally saves a checkpoint and fetches the
    train iterator for the next epoch.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for split_name in args.valid_subset.split(','):
        task.load_dataset(split_name, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    total_params = sum(p.numel() for p in model.parameters())
    trained_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('| num. model params: {} (num. trained: {})'.format(total_params, trained_params))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    stopwatch = StopwatchMeter()
    stopwatch.start()
    valid_subsets = args.valid_subset.split(',')
    while (
        lr > args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        run_validation = (
            not args.disable_validation
            and epoch_itr.epoch % args.validate_interval == 0
        )
        if run_validation:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # sharded data: get train iterator for next epoch
        data_is_sharded = ':' in getattr(args, 'data', '')
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=data_is_sharded)
    stopwatch.stop()
    print('| done training in {:.1f} seconds'.format(stopwatch.sum))
def validate_metric(args, trainer, task, epoch_itr, subsets):
    # Run the model over each validation subset and dump, per sentence, the
    # reference labels, predicted labels and per-token scores to text files in
    # args.save_dir. Only the master worker does any work; nothing is returned.
    #
    # NOTE(review): the source was whitespace-mangled; the nesting below is a
    # reconstruction and should be confirmed against the original file.

    # when training with distributed trainer, only one of them (the one args.distributed_rank == 0) is working ...
    print('args.distributed_rank', args.distributed_rank)
    print('args.distributed_world_size', args.distributed_world_size)
    if not distributed_utils.is_master(args):
        return
    # NOTE: this string follows other statements, so it is a no-op expression,
    # not the function's docstring. It also says losses are returned, which
    # the code does not do.
    """Evaluate the model on the validation set(s) and return the losses."""
    for subset in subsets:
        # Template path "<save_dir>/placeholder.<subset>.txt"; 'placeholder'
        # is later replaced by an output index.
        model_output_placeholder = os.path.join(
            args.save_dir, '{}.{}.txt'.format('placeholder', subset))
        # One open output file per model output stream, created lazily below.
        model_output_file_list = []
        # fout = open(model_output_file, 'w', encoding='utf8')
        # # firstly, output dictionary information
        # fout.write('%d\n'%len(task.target_dictionary))
        # for i in range(len(task.target_dictionary)):
        #     fout.write('{}\t{}\n'.format(task.target_dictionary[i], i))
        # fout.flush()

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            # max_positions=trainer.get_model().max_positions(),
            max_positions=None,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=1,
            shard_id=0,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args,
            itr,
            epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple')
        # Running count of samples seen; checked against sample ids below.
        cnt = 0
        for sample in progress:
            preds = []
            scores = []
            trainer.model.eval()
            sample = utils.move_to_cuda(sample)
            # net_output = trainer.model(args.lam1, args.lam2, args.transpose_method, **sample['net_input'])
            with torch.no_grad():
                net_output = trainer.model(**sample['net_input'])
            # probs = trainer.model.get_normalized_probs(net_output, log_probs=False)
            # _, pred = probs.max(2)

            # net_output[0] / net_output[1] are used as predictions / scores;
            # each may be a list (several output streams) or a single item.
            if isinstance(net_output[0], list):
                if len(model_output_file_list) < len(net_output[0]):
                    for idx, sub_net_output in enumerate(net_output[0]):
                        model_output_file_list.append(
                            init_output_file(
                                model_output_placeholder.replace(
                                    'placeholder', str(idx)),
                                task.target_dictionary))
                for sub_net_output, sub_score in zip(net_output[0],
                                                     net_output[1]):
                    preds.append(sub_net_output)
                    scores.append(sub_score)
            else:
                if len(model_output_file_list) == 0:
                    model_output_file_list.append(
                        init_output_file(
                            model_output_placeholder.replace(
                                'placeholder', '1'),
                            task.target_dictionary))
                preds.append(net_output[0])
                scores.append(net_output[1])

            if sample.get('target', None) is not None:
                target = trainer.model.get_targets(sample, net_output)
                # Truncate the reference to the prediction length.
                if target.size(1) > preds[0].size(1):
                    target = target[:, :preds[0].size(1)]
            else:
                # No reference: build a stand-in that is 0 where the first
                # prediction is 0 and 1 elsewhere.
                # NOTE(review): presumably index 0 marks padding here, and
                # torch.where presumably belongs to this else-branch - confirm.
                target = torch.ones_like(preds[0])
                target = torch.where(preds[0] == 0,
                                     torch.zeros_like(preds[0]), target.int())

            assert len(preds) == len(scores) == len(model_output_file_list)
            for pred, score, fout in zip(preds, scores,
                                         model_output_file_list):
                for i in range(pred.size(0)):
                    labels = []
                    pred_labels = []
                    # Never filled (the append below is commented out), so the
                    # 'Predicted Distri' line is always written empty.
                    pred_dists = []
                    pred_scores = []
                    for j in range(pred.size(1)):
                        # Stop at the first pad position of the reference.
                        if target[i, j] != task.target_dictionary.pad():
                            labels.append(task.target_dictionary[target[i, j]])
                            pred_labels.append(task.target_dictionary[pred[i, j]])
                            pred_scores.append(
                                str(round(score[i, j].item(), 5)))
                            # pred_dists.append( ' '.join( map(lambda x: str(x.item()), probs[i, j]) ) )d
                        else:
                            break
                    fout.write('True Labels:\t%s\n' % ' '.join(labels))
                    fout.write('Predicted Labels:\t%s\n' %
                               ' '.join(pred_labels))
                    fout.write('Score:\t%s\n' % ' '.join(pred_scores))
                    fout.write('Predicted Distri:\t%s\n' %
                               ' | '.join(pred_dists))
                    fout.flush()
            # Checks that the first id of each batch equals the number of
            # samples already processed.
            assert cnt == sample['id'][0]
            cnt += sample['id'].shape[0]
        # NOTE(review): files are not closed if an exception is raised above;
        # whether the message is printed per file or once is ambiguous.
        for fout in model_output_file_list:
            fout.close()
            utils.xprintln('valid metric %s done!' % fout.name)
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses.

    Args:
        cfg: full configuration; the ``common``, ``dataset``, ``optimization``,
            ``checkpoint`` and ``distributed_training`` groups are read.
        trainer: trainer whose ``train_step`` is called for each group of
            batches.
        task: task object forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator supplying this epoch's batches.

    Returns:
        ``(valid_losses, should_stop)`` from the last ``validate_and_save``
        call. If the epoch yields no batches, ``valid_losses`` is ``[None]``
        and ``should_stop`` is ``False``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    # one update_freq entry per epoch; later epochs reuse the last entry
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq)
                   else cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(cfg.distributed_training)
                            else None),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project
                       if distributed_utils.is_master(cfg.distributed_training)
                       else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(cfg.distributed_training)
                         else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    # Defined up front so the return below cannot hit an unbound name when the
    # iterator yields no batches.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses.

    Args:
        args: parsed options; reads ``fix_batches_to_gpus``, ``curriculum``,
            ``update_freq``, ``valid_subset`` and the logging options.
        trainer: trainer whose ``train_step`` is called for each group of
            batches.
        task: task object forwarded to ``validate_and_save``.
        epoch_itr: epoch iterator supplying this epoch's batches.

    Returns:
        ``(valid_losses, should_stop)`` from the last ``validate_and_save``
        call. If the epoch yields no batches, ``valid_losses`` is ``[None]``
        and ``should_stop`` is ``False``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # one update_freq entry per epoch; later epochs reuse the last entry
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq)
                   else args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    # Defined up front so the return below cannot hit an unbound name when the
    # iterator yields no batches.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Track the best validation score and write the checkpoints that are due.

    The running best score is stored on the function object as
    ``save_checkpoint.best`` and is updated on every call with a non-``None``
    ``val_loss``, including on workers that do not save. Saving is skipped
    when ``args.no_save`` is set or on non-master workers.

    Args:
        args: parsed options (save/keep intervals, ``save_dir``,
            ``best_checkpoint_metric``, ``maximize_best_checkpoint_metric``).
        trainer: object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: epoch iterator providing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation score, or ``None`` if validation was
            skipped.
    """
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        return a >= b if args.maximize_best_checkpoint_metric else a <= b

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # ``best`` was already refreshed above, so a new best compares equal to it
    is_new_best = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best))

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0)
    checkpoint_conds["checkpoint_best.pt"] = is_new_best
    if val_loss is not None:
        # "{:.2f}" cannot format None, so build this file name only when a
        # validation score exists (it was previously formatted unconditionally
        # and raised TypeError whenever validation was skipped)
        best_name = "checkpoint.best_{}_{:.2f}.pt".format(
            args.best_checkpoint_metric, val_loss)
        checkpoint_conds[best_name] = (
            args.keep_best_checkpoints > 0 and is_new_best)
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss
    }
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # write once, then copy to the remaining target names
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            PathManager.copy(checkpoints[0], cp, overwrite=True)

        write_timer.stop()
        logger.info(
            "| saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)"
            .format(checkpoints[0], epoch, updates, val_loss,
                    write_timer.sum))

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint_\d+_(\d+)\.pt")
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir,
                                       pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            args.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                args.best_checkpoint_metric))
        if not args.maximize_best_checkpoint_metric:
            # descending order puts the largest score first; flip it when
            # smaller is better
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[args.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write the checkpoints that are due, prune old ones and report timing.

    Does nothing when ``args.no_save`` is set or on non-master workers. The
    best validation loss seen so far is kept on the function object as
    ``save_checkpoint.best``.

    Args:
        args: parsed options (save/keep intervals and ``save_dir``).
        trainer: object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: epoch iterator providing ``epoch``, ``end_of_epoch()`` and
            ``state_dict()``.
        val_loss: latest validation loss, or ``None`` if validation was
            skipped.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    write_timer = StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0)
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0)
    # evaluated before ``best`` is refreshed below, hence the strict comparison
    checkpoint_conds['checkpoint_best.pt'] = (
        val_loss is not None and
        (not hasattr(save_checkpoint, 'best')
         or val_loss < save_checkpoint.best))
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    prev_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        save_checkpoint.best = min(val_loss, prev_best)
    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state.update({'best': save_checkpoint.best})

    # Kept under its own name: the pruning steps below build unrelated path
    # lists, and the final message must name a file written by this call.
    saved_paths = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in saved_paths:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        update_checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in update_checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0 and epoch > 1:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        epoch_checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt')
        for old_chk in epoch_checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    write_timer.stop()
    # 'checkpoint_last.pt' is always selected, so saved_paths is never empty
    print(
        '| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'
        .format(saved_paths[0], epoch, updates, write_timer.sum))
def main(args, init_distributed=False):
    """Load the latest checkpoint and score each validation subset.

    Sets up the task, model and trainer, restores the latest checkpoint, then
    runs ``validate`` once and prints sacreBLEU, CVPR BLEU and ROUGE for every
    subset in ``args.valid_subset``. No training is performed.

    Args:
        args: parsed command-line options.
        init_distributed: when true, initialise distributed training and store
            the resulting rank in ``args.distributed_rank``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch,
                                                criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    valid_subsets = args.valid_subset.split(',')
    tokenize = sacrebleu.DEFAULT_TOKENIZER if not args.eval_tokenized_bleu else 'none'
    # validate() here yields per-subset hypothesis and reference lists
    hyps, refs = validate(args, trainer, task, epoch_itr, valid_subsets)
    for h, r, split in zip(hyps, refs, valid_subsets):
        assert len(h) == len(r)
        sacrebleu_score = sacrebleu.corpus_bleu(h, [r], tokenize=tokenize)
        bleu = compute_cvpr_bleu(h, r)
        rouge_score = rouge.rouge(h, r)
        print('{} set has {} samples,\n'
              'sacrebleu: {},\n'
              'CVPR BLEU scripts: {}\n'
              'CVPR ROUGE: {}'.format(split, len(h), sacrebleu_score, bleu,
                                      rouge_score))
        # one-line summary: ROUGE-L F-score (as a percentage) then BLEU values
        print('performance: {:.2f} {}'.format(
            rouge_score['rouge_l/f_score'] * 100,
            ' '.join([str(b) for b in bleu])))
def main(args, init_distributed=False):
    # Entry point: builds task/model/trainer, optionally pretrains or
    # evaluates the data actor against LASER scores, then runs the epoch
    # training loop.
    #
    # NOTE(review): the source was whitespace-mangled; the nesting below is a
    # reconstruction and should be confirmed against the original file.
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch,
                                            criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr, filtered_maxpos_indices = checkpoint_utils.load_checkpoint(
        args, trainer)

    # pretrain data actor
    # only the language actor model can be pretrained
    if args.pretrain_laser and args.pretrain_data_actor and args.data_actor == 'ave':
        # pretrain the agent with LASER score
        # epoch_itr, indices = trainer.get_train_iterator(1)
        # NOTE(review): `path` is assigned but not used in this branch; the
        # score file is passed by bare name.
        path = '/home/wtan12/multiDDS/'
        trainer.pretrain_LASER('en-ps.laser-score', epoch_itr)

    if args.compare_laser:
        # Score every training sample with the data actor, write the scores to
        # a file and compare them with the LASER score file, then return
        # without training.
        epoch_itr, indices = trainer.get_train_iterator(1)
        print('Number of Indices: ', len(indices))
        # sample id -> data-actor score
        scores = collections.defaultdict(float)
        # compare with laser label using R^2 Score, only used after model is trained
        # itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=False, shuffle=False)
        data_actor = trainer.data_actor
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=False,
            offset=0,
            datasize=-1,
        )
        for i, sample in enumerate(itr):
            sample = trainer._prepare_sample(sample)
            # keep only the first value of the prepared sample mapping
            sample = list(sample.values())[0]
            score = data_actor(sample).cpu().detach().numpy().tolist()
            indices = sample['id'].data.cpu().numpy().ravel().tolist()
            for k, v in zip(indices, score):
                scores[k] = float(v[0])

        # list of (sample id, score) pairs ordered by sample id
        scores = sorted(scores.items(), key=lambda x: x[0])
        print('Number of Indices in Scoring file: ', len(scores))
        # NOTE(review): hard-coded, machine-specific directory.
        path = '/home/wtan12/multiDDS/'
        with open(path + 'en-ps.laser-score', 'r') as r:
            data = r.read()
        laser_score = []
        for i, item in enumerate(data.split('\n')):
            laser_score.append(item)
        # drop the last split item (presumably the empty string after a
        # trailing newline - confirm)
        laser_score.pop()
        # Despite the name, this accumulates the sum of squared differences
        # between LASER scores and data-actor scores.
        r2 = 0.0
        with open(path + 'en-ps.dds_score', 'w') as f:
            for k, v in scores:
                f.write(str(v) + '\n')
                # the sample id is used as the line index into the LASER file
                truth = float(laser_score[k])
                r2 += (truth - v)**2
        print('R2 Score compared to LASER file: ', r2)
        return

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    if args.eval_bleu:
        generator = task.build_generator(args)
        # evaluating BLEU forces the best-checkpoint metric to be maximized
        args.maximize_best_checkpoint_metric = True
    else:
        generator = None
    while lr > args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # train for one epoch
        epoch_itr = train(args, trainer, task, epoch_itr, generator,
                          filtered_maxpos_indices)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets, generator)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if ':' in getattr(args, 'data', ''):
            # sharded data: get train iterator for next epoch
            epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)[0]
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Validate on each subset, dump predictions and return the tracked metric.

    Predictions returned by ``trainer.valid_step`` accumulate across subsets.
    After each subset they are written, one per line, to
    ``<cfg.criterion.save_predictions><cuda device index>.txt``. When
    ``trainer.is_data_parallel_master`` is truthy the per-device files are
    concatenated into ``<cfg.criterion.save_predictions>.txt``. The value
    stored under ``cfg.checkpoint.best_checkpoint_metric`` is collected for
    every subset and the list is returned.
    """
    seed = cfg.dataset.fixed_validation_seed
    if seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)

    losses = []
    predictions = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        batches = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False,
            set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            batches = utils.tpu_data_loader(batches)

        bar_options = dict(
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
        )
        # tensorboard / wandb logging is enabled on the master worker only
        if distributed_utils.is_master(cfg.distributed_training):
            bar_options["tensorboard_logdir"] = cfg.common.tensorboard_logdir
        else:
            bar_options["tensorboard_logdir"] = None
        if cfg.common.no_progress_bar:
            bar_options["default_log_format"] = "simple"
        else:
            bar_options["default_log_format"] = "tqdm"
        if distributed_utils.is_master(cfg.distributed_training):
            bar_options["wandb_project"] = cfg.common.wandb_project
        else:
            bar_options["wandb_project"] = None
        bar_options["wandb_run_name"] = os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir))
        bar = progress_bar.progress_bar(batches, **bar_options)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in bar:
                batch_predictions, _ = trainer.valid_step(batch)
                predictions.extend(batch_predictions)

        # dump this worker's predictions, one per line
        device_path = (cfg.criterion.save_predictions +
                       str(torch.cuda.current_device()) + ".txt")
        with open(device_path, "w") as device_file:
            for line in predictions:
                device_file.write(line)
                device_file.write("\n")

        if trainer.is_data_parallel_master:
            # concatenate the per-device files in device-index order
            with open(cfg.criterion.save_predictions + ".txt", "w") as merged:
                for device_idx in range(torch.cuda.device_count()):
                    part_path = (cfg.criterion.save_predictions +
                                 str(device_idx) + ".txt")
                    with open(part_path, "r") as part:
                        merged.write(part.read())

        # log validation stats
        valid_stats = get_valid_stats(cfg, trainer, aggregator.get_smoothed_values())
        bar.print(valid_stats, tag=subset, step=trainer.get_num_updates())

        losses.append(valid_stats[cfg.checkpoint.best_checkpoint_metric])
    return losses
def is_data_parallel_master(self):
    """Return ``distributed_utils.is_master`` for this object's distributed-training config."""
    dist_cfg = self.cfg.distributed_training
    return distributed_utils.is_master(dist_cfg)
def save_checkpoint_bleu(args, trainer, epoch_itr, valid_losses, valid_bleus, valid_select, begin):
    """Save epoch / interval / best-BLEU / last checkpoints and prune old ones.

    The best BLEU seen so far is remembered on the function object itself as
    ``save_checkpoint_bleu.best_bleu``.

    Args:
        args: namespace providing ``no_save``, ``save_dir``, ``save_interval``,
            ``save_interval_updates``, ``no_epoch_checkpoints``,
            ``keep_interval_updates`` and ``keep_last_epochs``.
        trainer: object providing ``get_num_updates()`` and
            ``save_checkpoint(path, extra_state)``.
        epoch_itr: object providing ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        valid_losses: mapping domain -> validation loss; must contain every
            key of ``valid_bleus``.
        valid_bleus: mapping domain -> validation BLEU.
        valid_select: domain whose BLEU decides ``checkpoint_best_bleu.pt``.
            When it is absent from ``valid_bleus`` the best-BLEU bookkeeping
            is skipped.
        begin: when truthy the call is treated as an end-of-epoch save.

    Returns:
        None.
    """
    if args.no_save or not distributed_utils.is_master(args):
        return

    epoch = epoch_itr.epoch
    end_of_epoch = True if begin else epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    # Look the selected score up once. Previously it was used as an eagerly
    # evaluated getattr() default, which raised KeyError whenever
    # ``valid_select`` was missing, despite the guards around it.
    has_selected = valid_select in valid_bleus
    selected_bleu = valid_bleus[valid_select] if has_selected else None

    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
        end_of_epoch and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
        not end_of_epoch and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds['checkpoint_best_bleu.pt'] = (
        has_selected
        and (not hasattr(save_checkpoint_bleu, 'best_bleu')
             or selected_bleu > save_checkpoint_bleu.best_bleu)
    )
    checkpoint_conds['checkpoint_last.pt'] = True  # keep this last so that it's a symlink

    if has_selected:
        prev_best_bleu = getattr(save_checkpoint_bleu, 'best_bleu', selected_bleu)
        save_checkpoint_bleu.best_bleu = max(selected_bleu, prev_best_bleu)

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
    }
    for domain in valid_bleus:
        extra_state['valid_loss_' + domain] = valid_losses[domain]
        extra_state['valid_bleu_' + domain] = valid_bleus[domain]
    if hasattr(save_checkpoint_bleu, 'best_bleu'):
        extra_state['best_bleu'] = save_checkpoint_bleu.best_bleu

    checkpoints = [
        os.path.join(args.save_dir, fn)
        for fn, cond in checkpoint_conds.items() if cond
    ]
    for cp in checkpoints:
        trainer.save_checkpoint(cp, extra_state)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(
            args.save_dir, pattern=r'checkpoint(\d+)\.pt')
        for old_chk in checkpoints[args.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def main(args, init_distributed=False):
    """Set up task, model and trainer from ``args`` and run the epoch loop.

    The loop runs while the learning rate stays above ``args.min_lr`` and the
    next epoch index does not exceed ``args.max_epoch``; it also ends when
    ``should_stop_early`` fires or ``args.max_update`` updates were done.
    """
    utils.import_user_module(args)

    batch_size_given = args.max_tokens is not None or args.max_sentences is not None
    assert batch_size_given, \
        'Must specify batch size either with --max-tokens or --max-sentences'
    metrics.reset()

    # Initialize CUDA and distributed training
    use_cuda = torch.cuda.is_available() and not args.cpu
    if use_cuda:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Validation data is loaded now; training data is loaded later, based on
    # the latest checkpoint
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    n_params = sum(p.numel() for p in model.parameters())
    n_trained = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('num. model params: {} (num. trained: {})'.format(n_params, n_trained))

    # (optionally) Configure quantization
    quantizer = None
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses = train(args, trainer, task, epoch_itr, max_update)
        first_loss = valid_losses[0]
        if should_stop_early(args, first_loss) or trainer.get_num_updates() >= max_update:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, first_loss)

        # sharded data: get train iterator for next epoch
        reload_dataset = os.pathsep in getattr(args, 'data', '')
        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            load_dataset=reload_dataset,
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
def validate(args, trainer, task, epoch_itr, subsets, prune=-1):
    """Evaluate the model on the validation set(s) and return the losses.

    While validating, rank 0 also writes a template file into
    ``args.save_dir`` unless ``args.eval_mode == 'time'``. The file is now
    closed even when a validation step raises.

    Args:
        args: run configuration namespace.
        trainer: provides ``get_model()``, ``valid_step()`` and ``get_num_updates()``.
        task: provides the batch iterator and the template/lambda writers.
        epoch_itr: only its ``epoch`` attribute is read.
        subsets: names of the validation subsets to evaluate.
        prune: when > 0, passed to the model's ``set_prune_index`` and the
            resulting index map is installed on the task.

    Returns:
        A list with ``stats[args.best_checkpoint_metric]`` for each subset.
    """
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # not write templates for time profiling
        write_template_flag = args.eval_mode != 'time'

        fout = None
        try:
            # only one worker deals with the template file in DDP
            if args.distributed_rank == 0 and write_template_flag:
                print('write template files')
                if args.eval_mode == 'none':
                    template_name = 'templates_{}_{}.txt'.format(
                        epoch_itr.epoch, trainer.get_num_updates())
                else:
                    template_name = 'templates_eval_{}.txt'.format(subset)
                fout = open(os.path.join(args.save_dir, template_name), 'w')

                if prune <= 0:
                    task.write_lambda(fout, trainer.get_model())

            # create a new root metrics aggregator so validation metrics
            # don't pollute other aggregators (e.g., train meters)
            with metrics.aggregate(new_root=True) as agg:
                for sample in progress:
                    trainer.valid_step(sample, split=subset)
                    # fout is None here when eval_mode == 'time'
                    if args.distributed_rank == 0:
                        task.write_template(sample, trainer.get_model(), fout)
        finally:
            # release the template file even if a step failed
            if fout is not None:
                fout.close()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
def main(cfg: DictConfig) -> None:
    """Build task, model, criterion and trainer from ``cfg`` and train.

    An ``argparse.Namespace`` is accepted too and converted with
    ``convert_namespace_to_omegaconf``. Training proceeds epoch by epoch while
    the learning rate is above ``cfg.optimization.min_lr`` and the next epoch
    index does not exceed ``cfg.optimization.max_epoch``, or until ``train``
    reports that it should stop.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)
    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in cfg.dataset.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    # the format string used to end in a stray ")" ("criterion: {})")
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        cfg.distributed_training.distributed_world_size))
    logger.info("max tokens per GPU = {} and batch size per GPU = {}".format(
        cfg.dataset.max_tokens,
        cfg.dataset.batch_size,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'):
    """Run ``trainer.valid_iw_step`` over the validation subset(s).

    Returns ``[0]`` without evaluating when ``mode`` is ``'none'`` or
    ``'time'`` or when ``args.criterion`` is ``'lm_baseline'``. Otherwise the
    per-subset stats are only logged (the append to the result list is
    commented out), so the returned list is empty.

    Args:
        args: run configuration namespace.
        trainer: provides ``get_model()``, ``valid_iw_step()`` and ``get_num_updates()``.
        task: provides datasets, the batch iterator and the index-map setters.
        epoch_itr: unused; kept for signature parity with ``validate``.
        subsets: names of the validation subsets to evaluate.
        prune: when > 0, a prune index is installed for each subset's pass and
            reset again afterwards.
        mode: forwarded to ``trainer.valid_iw_step``.
    """
    # `args.criterion` was misspelled `args.criterior`, which raised
    # AttributeError for every mode other than 'none' / 'time'.
    if mode == 'none' or mode == 'time' or args.criterion == 'lm_baseline':
        return [0]

    # top k instead of sampling to approximate sum of prototypes for evaluation
    for subset in subsets:
        task.dataset(subset).set_sampling(False)

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir
                                if distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_iw_step(sample, mode=mode)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag='valid_iw', step=trainer.get_num_updates())

        # valid_losses.append(stats[args.best_checkpoint_metric])

        if prune > 0:
            # undo the pruning installed above
            trainer.get_model().reset_prune_index()
            task.reset_index_map()

    return valid_losses
def _sample_and_score(bart, args, split, bsz=8):
    """Generate hypotheses for one data split with ``bart`` and return its ROUGE scores.

    Reads ``../data/<split>_sent_trans_cons_label.source`` (and the matching
    ``_sent_c99_label.source`` second view), writes one hypothesis per line to
    ``./<split>_best_multi_attn_<args.lr_weight>_.hypo`` and scores that file
    against ``../data/<split>_sent_trans_cons_label.target``.

    The source files are indexed at position 0, so they must be non-empty.
    """
    hyp_path = './' + split + '_best_multi_attn_' + str(args.lr_weight) + '_.hypo'
    ref_path = '../data/' + split + '_sent_trans_cons_label.target'

    def generate(batch, batch2):
        # every batch, including the trailing partial one, runs without autograd
        with torch.no_grad():
            if args.multi_views:
                return bart.sample(batch, sentences2=batch2, balance=True, beam=4,
                                   lenpen=2.0, max_len_b=100, min_len=5,
                                   no_repeat_ngram_size=3)
            return bart.sample(batch, beam=4, lenpen=2.0, max_len_b=100, min_len=5,
                               no_repeat_ngram_size=3)

    def emit(batch, batch2, fout):
        for hypothesis in generate(batch, batch2):
            fout.write(hypothesis + '\n')
            fout.flush()

    with open('../data/' + split + '_sent_trans_cons_label.source') as source, \
            open('../data/' + split + '_sent_c99_label.source') as source2, \
            open(hyp_path, 'wt', encoding='utf-8') as fout:
        s1 = source.readlines()
        s2 = source2.readlines()

        slines = [s1[0].strip()]
        slines2 = [s2[0].strip()]
        for i in tqdm(range(1, len(s1))):
            # flush as soon as a full batch has been collected
            if len(slines) == bsz:
                emit(slines, slines2, fout)
                slines = []
                slines2 = []
            slines.append(s1[i].strip())
            slines2.append(s2[i].strip())
        if slines:
            emit(slines, slines2, fout)

    # strip only the line terminator; slicing off the last character would
    # corrupt a final line that has no trailing newline
    with open(hyp_path, 'r') as f:
        hypothesis = [line.rstrip('\n') for line in f]
    with open(ref_path, 'r') as f:
        reference = [line.rstrip('\n') for line in f]

    rouge = Rouge()
    return rouge.get_scores(hypothesis, reference, avg=True)


def main(args, init_distributed=False):
    """Train the model and, after every epoch, decode and ROUGE-score val and test.

    Each epoch: train, optionally validate, step the LR scheduler, wrap the
    trainer's model in a ``BARTHubInterface`` to generate and score the val
    split, save a checkpoint on the save interval, generate and score the
    test split, then check for early stopping.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    print(args.multi_views)

    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        bart = BARTHubInterface(args, task, trainer.model).cuda()
        bart.eval()

        print("Test on val set: ")
        print("Val", _sample_and_score(bart, args, 'val'))

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        print("Test on testing set: ")
        print('Test', _sample_and_score(bart, args, 'test'))

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
def main(args, init_distributed=False):
    """Either run a one-off evaluation (``args.eval_mode != 'none'``) or train.

    In evaluation mode ``validate`` (skipped for ``'entropy'``) and
    ``validate_iw`` are run once under ``torch.no_grad()`` and the function
    returns. Otherwise the usual epoch loop runs: train, optionally validate,
    step the LR scheduler, save a checkpoint on the save interval and check
    for early stopping.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info(
        'max tokens per GPU = {} and max sentences per GPU = {}'.format(
            args.max_tokens,
            args.max_sentences,
        ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')

    if args.eval_mode != 'none':
        # evaluation-only run: no training happens below
        start_val_time = time.time()
        with torch.no_grad():
            if args.eval_mode != 'entropy':
                _ = validate(args, trainer, task, epoch_itr, valid_subsets, args.prune_num)

            print('elapsed time (seconds): {}'.format(time.time() - start_val_time))

            _ = validate_iw(args, trainer, task, epoch_itr, valid_subsets,
                            args.prune_num, mode=args.eval_mode)
        return

    while (lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch
           and trainer.get_num_updates() < max_update):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info(
                'early stop since valid performance hasn\'t improved for last {} runs'
                .format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )

    # Stop the stopwatch BEFORE reading it: the message used to be logged
    # while the meter was still running, as the other entry points in this
    # file avoid by stopping first.
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))

    # _ = validate_iw(args, trainer, task, epoch_itr, valid_subsets)
def is_data_parallel_master(self):
    """Return what ``distributed_utils.is_master`` reports for ``self.args``."""
    run_args = self.args
    return distributed_utils.is_master(run_args)
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses.

    Args:
        args: run configuration namespace.
        trainer: provides ``train_step``, ``get_num_updates`` and ``get_model``.
        task: its ``begin_epoch`` hook is called before the first step.
        epoch_itr: epoch iterator to draw batches from.
        max_update: stop once this many updates have been performed.
        model: module whose sub-modules are scanned for ``selfattn_norm`` /
            ``endeattn_norm`` on the first step of the epoch. Defaults to
            ``trainer.get_model()`` when omitted.

    Returns:
        The list returned by the last ``validate_and_save`` call, or
        ``[None]`` if no training step completed.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    # `model` defaults to None, which used to crash on `model.modules()`
    if model is None:
        model = trainer.get_model()

    valid_subsets = args.valid_subset.split(',')
    # defined up front so the end-of-epoch code works even when every step
    # was skipped (or the iterator was empty)
    num_updates = trainer.get_num_updates()
    valid_losses = [None]
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if i == 0:
            # once per epoch, report the attention norms exposed by the modules
            print('epoch: ', epoch_itr.epoch)
            endeattn_norm = []
            selfattn_norm = []
            for m in model.modules():
                if hasattr(m, 'selfattn_norm') and m.selfattn_norm is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if hasattr(m, 'endeattn_norm') and m.endeattn_norm is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
def main(args):
    """Build task, model, criterion and trainer from ``args`` and train.

    Epochs are run while the learning rate is above ``args.min_lr`` and the
    next epoch index does not exceed ``args.max_epoch``, or until ``train``
    reports that training should stop.
    """
    utils.import_user_module(args)

    batch_size_given = args.max_tokens is not None or args.max_sentences is not None
    assert batch_size_given, "Must specify batch size either with --max-tokens or --max-sentences"
    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Validation data is loaded now; training data is loaded later, based on
    # the latest checkpoint
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion, criterion.__class__.__name__))
    n_params = sum(p.numel() for p in model.parameters())
    n_trained = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("num. model params: {} (num. trained: {})".format(n_params, n_trained))

    # (optionally) Configure quantization
    quantizer = None
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        first_loss = valid_losses[0]
        lr = trainer.lr_step(epoch_itr.epoch, first_loss)

        next_idx = epoch_itr.next_epoch_idx
        epoch_itr = trainer.get_train_iterator(
            next_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def main(args):
    """Train the model while recording per-epoch loss/perplexity to CSV and plots.

    Besides the usual setup and epoch loop, every epoch that yields both
    train and validation stats appends a row to ``train_logs.csv`` in
    ``args.save_dir`` and to ``train_logs_backup.csv`` in
    ``args.jason_log_dir``, and re-renders the loss and perplexity line plots
    in ``args.jason_log_dir``.
    """
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"
    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)
        checkpoint_utils.verify_checkpoint_directory(args.jason_log_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    ##### begin jason #####
    updates_list = []
    train_ppl_list = []
    train_loss_list = []
    val_ppl_list = []
    val_loss_list = []
    train_uid_loss_list = []
    val_uid_loss_list = []

    # Both CSVs share one header. train_logs.csv used to declare only five
    # columns although every row carries seven values.
    csv_header = 'updates,train_loss,train_ppl,val_loss,val_ppl,train_uid_loss,val_uid_loss'
    plot_prefix = args.jason_log_dir.split('/')[-1]

    # The backup log is appended to (never truncated). It is written with
    # plain file I/O instead of shelling out to `touch` / `echo >>`, which
    # interpolated the path unquoted into a shell command.
    backup_writefile = os.path.join(args.jason_log_dir, 'train_logs_backup.csv')
    with open(backup_writefile, 'a') as backup_writer:
        backup_writer.write(csv_header + '\n')
    ##### end jason #####

    # the context manager guarantees train_logs.csv is flushed and closed
    with open(os.path.join(args.save_dir, 'train_logs.csv'), 'w') as log_writer:
        log_writer.write(csv_header + '\n')

        while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
            # train for one epoch
            valid_losses, should_stop, train_stats, valid_stats = train(args, trainer, task, epoch_itr)
            print("hello", valid_stats, train_stats)

            ##### begin jason #####
            if train_stats and valid_stats:
                updates_list.append(train_stats['num_updates'])
                train_loss_list.append(train_stats['loss'])
                train_ppl_list.append(train_stats['ppl'])
                val_loss_list.append(valid_stats['loss'])
                val_ppl_list.append(valid_stats['ppl'])
                # -1 marks "no uid loss reported"
                if 'uid_loss' not in train_stats:
                    train_stats['uid_loss'] = -1
                    valid_stats['uid_loss'] = -1
                train_uid_loss_list.append(train_stats['uid_loss'])
                val_uid_loss_list.append(valid_stats['uid_loss'])

                log_line = f"{train_stats['num_updates']},{train_stats['loss']},{train_stats['ppl']},{valid_stats['loss']},{valid_stats['ppl']},{train_stats['uid_loss']},{valid_stats['uid_loss']}"
                log_writer.write(f"{log_line}\n")
                with open(backup_writefile, 'a') as backup_writer:
                    backup_writer.write(log_line + '\n')

                best_val_loss = min(val_loss_list)
                best_val_loss_idx = val_loss_list.index(best_val_loss)
                updates_to_best_val_loss = updates_list[best_val_loss_idx]
                train_loss_at_best_val_loss = train_loss_list[best_val_loss_idx]

                jasons_vis.plot_jasons_lineplot(
                    x_list=updates_list,
                    y_list_list=[train_loss_list, val_loss_list, train_uid_loss_list, val_uid_loss_list],
                    y_labels_list=['train', 'dev', 'train uid', 'dev uid'],
                    x_ax_label="Updates",
                    y_ax_label="Loss",
                    title=f"dev_l={best_val_loss} updates={updates_to_best_val_loss} train_l={train_loss_at_best_val_loss}",
                    output_png_path=os.path.join(args.jason_log_dir, f"{plot_prefix}_loss.png"),
                )

                # NOTE(review): the title labels the best validation *loss*
                # as "best_val_ppl" -- confirm whether perplexity was meant.
                jasons_vis.plot_jasons_lineplot(
                    x_list=updates_list,
                    y_list_list=[train_ppl_list, val_ppl_list],
                    y_labels_list=['train', 'dev'],
                    x_ax_label="Updates",
                    y_ax_label="Perplexity",
                    title=f" best_val_ppl={best_val_loss} " + args.jason_log_dir[:20],
                    output_png_path=os.path.join(args.jason_log_dir, f"{plot_prefix}_perplexity.png"),
                )
            ##### end jason #####

            if should_stop:
                break

            # only use first validation loss to update the learning rate
            lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

            epoch_itr = trainer.get_train_iterator(
                epoch_itr.next_epoch_idx,
                # sharded data: get train iterator for next epoch
                load_dataset=task.has_sharded_data("train"),
                # don't cache epoch iterators for sharded datasets
                disable_iterator_cache=task.has_sharded_data("train"),
            )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def main(cfg: FairseqConfig) -> None:
    """Run a complete training job described by ``cfg``.

    A legacy ``argparse.Namespace`` is converted to an OmegaConf config first.
    The function then seeds the RNGs, builds the task, model, criterion and
    trainer, restores the latest checkpoint, and trains one epoch at a time
    until the epoch limit is passed, the learning rate falls to
    ``cfg.optimization.stop_min_lr`` or below, or ``train`` reports that
    training should stop.

    Returns early (after logging the error) when asynchronous checkpoint
    writing is requested but ``iopath`` cannot be imported.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        logging_cfg = OmegaConf.to_container(cfg.job_logging_cfg)
        logging.config.dictConfig(logging_cfg)

    # At least one of the two batch-size knobs has to be set.
    assert not (
        cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    seed = cfg.common.seed
    np.random.seed(seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print the full configuration.
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        # Fail fast: async checkpoint writes need iopath.
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`")
            return

    # Task setup (translation, language modeling, ...).
    task = tasks.setup_task(cfg.task)

    # Only validation data is loaded here; the training split is loaded
    # further down, driven by the latest checkpoint.
    for split_name in cfg.dataset.valid_subset.split(","):
        task.load_dataset(split_name, combine=False, epoch=1)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Model and criterion.
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info(
        "num. model params: {:,} (num. trained: {:,})".format(
            total_params, trainable_params
        )
    )

    # Quantization is optional.
    quantizer = (
        quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
        if cfg.common.quantization_config_path is not None
        else None
    )

    # Trainer: the Megatron variant is used for any model-parallel size
    # other than 1 (and does not receive the quantizer).
    if cfg.common.model_parallel_size != 1:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    else:
        trainer = Trainer(cfg, task, model, criterion, quantizer)

    world_size = cfg.distributed_training.distributed_world_size
    logger.info("training on {} devices (GPUs/TPUs)".format(world_size))
    logger.info(
        "max tokens per GPU = {} and batch size per GPU = {}".format(
            cfg.dataset.max_tokens, cfg.dataset.batch_size
        )
    )

    # Restore the newest checkpoint (if any) together with its train
    # iterator; epoch iterators are not cached for sharded datasets.
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    epoch_limit = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()
    stopwatch = meters.StopwatchMeter()
    stopwatch.start()

    while epoch_itr.next_epoch_idx <= epoch_limit:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})")
            break

        # One epoch of training.
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # The learning rate schedule only looks at the first validation loss.
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: reload the dataset for the next epoch
            load_dataset=task.has_sharded_data("train"),
            # sharded data: do not cache epoch iterators
            disable_iterator_cache=task.has_sharded_data("train"),
        )

    stopwatch.stop()
    logger.info("done training in {:.1f} seconds".format(stopwatch.sum))

    # With async checkpointing, block until ioPath has flushed every write.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish.")
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed options; this function reads ``fix_batches_to_gpus``,
            ``curriculum``, ``update_freq``, ``log_format``, ``log_interval``,
            ``tensorboard_logdir``, ``no_progress_bar``, ``valid_subset`` and
            ``max_update``.
        trainer: object providing ``train_step``, ``get_num_updates`` and
            ``get_model``.
        task: task object; ``begin_epoch`` is called once before the loop.
        epoch_itr: epoch iterator supplying the batches for this epoch.

    Returns:
        True when training should end (early stopping triggered or
        ``max_update`` reached), False otherwise.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Per-epoch update frequency; epochs beyond the list reuse the last entry.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False

    # Bind the step counter before the loop: if the iterator is empty, or every
    # step is skipped via ``continue``, the end-of-epoch log below would
    # otherwise raise UnboundLocalError.
    num_updates = trainer.get_num_updates()

    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)
            except ResetTrainerException:
                # Drop the wrapped criterion/model and the optimizer, then
                # retry the same batch once.
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None
                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
def main(args):
    """Set up a CUDA training run and perform a single metric validation pass.

    The function prints the arguments, requires CUDA, builds the task,
    datasets, model, criterion and trainer, optionally loads a pretrained
    document model, then advances the training iterator once and, when the
    validation interval matches, runs ``validate_metric`` on the master
    process for ``extractive_summarization*`` tasks. The actual training call
    and the checkpoint loading are commented out in this version.

    Raises:
        NotImplementedError: if CUDA is not available.
    """
    # we should not do this!
    # (disabled code, previously kept in a bare string literal)
    #   if args.max_tokens is None:
    #       args.max_tokens = 6000
    utils.xpprint(args)

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    utils.xprintln('setup task done!')

    # Load dataset splits
    load_dataset_splits(args, task, ['train'])
    valid_dataset = args.valid_subset.split(',')
    load_dataset_splits(args, task, valid_dataset, shuffle=False)
    utils.xprintln('load dataset done!')

    if args.task.startswith('extractive_summarization'):
        if distributed_utils.is_master(args):
            # NOTE(review): the evaluation pool, the parameter dicts and
            # make_params below are not referenced again in this function as
            # shown - presumably consumed by code that was removed; confirm.
            from sum_eval import MultiProcSumEval
            sum_eval_pool = MultiProcSumEval(args.ncpu_eval)
            sum_valid_pool_params = dict(
                article_file=args.raw_valid + '.article',
                summary_file=args.raw_valid + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_test_pool_params = dict(
                article_file=args.raw_test + '.article',
                summary_file=args.raw_test + '.summary',
                entity_map_file=None,
                length=-1,
                eval_type='predict',
                topk=args.topk_sent_eval,
                rerank=False,
                with_m=False,
                cmd='-a -c 95 -m -n 4 -w 1.2',
                trigram_block=args.trigram_block,
            )
            sum_pool_params = dict(valid=sum_valid_pool_params, test=sum_test_pool_params)

            def make_params(default_dict, result_file, out_rouge_file, rerank=False, with_m=False):
                # Copy the defaults and override the per-call fields.
                para_dict = dict(default_dict)
                para_dict['result_file'] = result_file
                para_dict['out_rouge_file'] = out_rouge_file
                para_dict['rerank'] = rerank
                para_dict['with_m'] = with_m
                return para_dict

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(
        sum(p.numel() for p in model.parameters())))
    # print(model)
    import sys
    sys.stdout.flush()

    # if summarization try to load pretrained model
    # if args.task.startswith('extractive_summarization') or args.task == 'pretrain_document_modeling':
    #     # assume this is a single GPU program
    if args.init_from_pretrained_doc_model:
        task.load_pretrained_model(model, args.pretrained_doc_model_path)
    sys.stdout.flush()

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Initialize dataloader
    # NOTE(review): max_positions is computed but not read afterwards here.
    max_positions = trainer.get_model().max_positions()
    epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=False)

    # Load the latest checkpoint if one is available
    # load_checkpoint(args, trainer, epoch_itr)

    # make sure training from a different checkpoint will use different random seed
    cur_dataset = task.dataset('train')
    if hasattr(cur_dataset, 'rng'):
        print('epoch ', epoch_itr.epoch)
        cur_dataset.rng = numpy.random.RandomState(args.seed + epoch_itr.epoch)

    # Train until the learning rate gets too small
    # NOTE(review): max_epoch, max_update, lr and valid_losses are assigned but
    # never read below, and train_meter is started but not stopped in this
    # function as shown.
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_losses = [None]
    valid_subsets = args.valid_subset.split(',')

    # range(10, 9, -1) yields only 10, so this body runs exactly once; the
    # loop variable itself is unused.
    for alpha in range(10, 9, -1):
        # train for one epoch
        # train(args, trainer, task, epoch_itr)
        epoch_itr.next_epoch_itr()

        if epoch_itr.epoch % args.validate_interval == 0:
            if args.task.startswith('extractive_summarization'):
                if distributed_utils.is_master(args):
                    validate_metric(args, trainer, task, epoch_itr, valid_subsets)
def setup_training_state(args, trainer, task, epoch_itr):
    """Create the save directory and restore whatever training state exists.

    Checkpoint precedence:

    1. ``--restore-file`` resolved under ``--save-dir`` (an absolute
       ``--restore-file`` is used as is, because ``os.path.join`` discards the
       directory part in that case);
    2. ``--pretrained-checkpoint-file``, whose training state is restored only
       when ``--load-pretrained-checkpoint-state`` is set;
    3. individual models from ``--multi-model-restore-files`` when no
       checkpoint file was found.

    Preferring the save-dir checkpoint means an interrupted run resumes from
    its own checkpoints rather than starting over from the other run's
    pretrained checkpoint, while that other run's directory stays untouched.

    Returns:
        A tuple ``(extra_state, epoch_itr, checkpoint_manager)``;
        ``checkpoint_manager`` is None on non-master workers.
    """
    os.makedirs(args.save_dir, exist_ok=True)

    restore_state = True
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and os.path.isfile(
        args.pretrained_checkpoint_file
    ):
        restore_state = args.load_pretrained_checkpoint_state
        checkpoint_path = args.pretrained_checkpoint_file
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    use_individual_models = (
        not os.path.isfile(checkpoint_path) and args.multi_model_restore_files
    )
    if use_individual_models:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)
    else:
        _, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    # The current run starts its own clock.
    extra_state["start_time"] = time.time()

    # Log the state with the progress history collapsed to its last entry
    # (avoids log spam), then put the full history back.
    full_progress = extra_state["training_progress"]
    if len(full_progress) > 0:
        abbreviated_progress = ["...truncated...", full_progress[-1]]
    else:
        abbreviated_progress = []
    extra_state["training_progress"] = abbreviated_progress
    print(f"| extra_state: {extra_state}")
    extra_state["training_progress"] = full_progress

    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]
    # With no partial epoch to resume, step back one epoch:
    # epoch_itr.next_epoch_itr() will increment it again.
    resume_epoch = epoch - 1 if batch_offset == 0 else epoch
    epoch_itr.load_state_dict(
        {"epoch": resume_epoch, "iterations_in_epoch": batch_offset}
    )

    # Only the master worker gets a checkpoint manager.
    if not distributed_utils.is_master(args):
        return extra_state, epoch_itr, None

    checkpoint_manager = checkpoint.CheckpointManager(
        num_avg_checkpoints=args.num_avg_checkpoints,
        auto_clear_checkpoints=args.auto_clear_checkpoints,
        log_verbose=args.log_verbose,
        checkpoint_files=extra_state["checkpoint_files"],
    )
    return extra_state, epoch_itr, checkpoint_manager
def save_checkpoint(args, trainer, epoch_itr, val_loss):
    """Write the checkpoints due at this point and prune old ones.

    The best validation value seen so far is kept on the function object as
    ``save_checkpoint.best`` and is updated on every worker, before the
    early return for ``--no-save`` / non-master workers.

    On the master worker the first due checkpoint is written through
    ``trainer.save_checkpoint`` and every other due name is produced by
    copying that file. Afterwards, interval and epoch checkpoints beyond
    ``--keep-interval-updates`` / ``--keep-last-epochs`` are deleted.

    Args:
        args: parsed command-line options.
        trainer: provides ``get_num_updates`` and ``save_checkpoint``.
        epoch_itr: provides ``epoch``, ``end_of_epoch()`` and ``state_dict()``.
        val_loss: latest validation value, or None when none is available.
    """
    from fairseq import distributed_utils, meters

    previous_best = getattr(save_checkpoint, 'best', val_loss)
    if val_loss is not None:
        pick_best = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = pick_best(val_loss, previous_best)

    if args.no_save or not distributed_utils.is_master(args):
        return

    def is_better(a, b):
        # Ties count as "better" in both directions.
        if args.maximize_best_checkpoint_metric:
            return a >= b
        return a <= b

    timer = meters.StopwatchMeter()
    timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    due_epoch = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    due_interval = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    due_best = val_loss is not None and (
        not hasattr(save_checkpoint, 'best')
        or is_better(val_loss, save_checkpoint.best)
    )
    due_last = not args.no_last_checkpoints

    # Insertion order decides which file is written and which are copies.
    checkpoint_conds = collections.OrderedDict([
        ('checkpoint{}.pt'.format(epoch), due_epoch),
        ('checkpoint_{}_{}.pt'.format(epoch, updates), due_interval),
        ('checkpoint_best.pt', due_best),
        ('checkpoint_last.pt', due_last),
    ])

    extra_state = {
        'train_iterator': epoch_itr.state_dict(),
        'val_loss': val_loss,
    }
    if hasattr(save_checkpoint, 'best'):
        extra_state['best'] = save_checkpoint.best

    checkpoints = [
        os.path.join(args.save_dir, filename)
        for filename, wanted in checkpoint_conds.items()
        if wanted
    ]
    if checkpoints:
        primary, *duplicates = checkpoints
        trainer.save_checkpoint(primary, extra_state)
        for duplicate in duplicates:
            shutil.copyfile(primary, duplicate)

        timer.stop()
        print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
            primary, epoch, updates, timer.sum))

    def remove_beyond(pattern, keep):
        # checkpoint_paths yields matches sorted in descending order, so
        # everything after the first ``keep`` entries is the oldest.
        stale_paths = checkpoint_paths(args.save_dir, pattern=pattern)[keep:]
        for stale in stale_paths:
            if os.path.lexists(stale):
                os.remove(stale)

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old interval checkpoints
        remove_beyond(r'checkpoint_\d+_(\d+)\.pt', args.keep_interval_updates)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints
        remove_beyond(r'checkpoint(\d+)\.pt', args.keep_last_epochs)
def main(args):
    """Build the task, model and trainer, seed the model from a pretrained
    BART/RoBERTa checkpoint according to ``args.arch``, then train.

    Parameter initialisation depends on a substring of ``args.arch``:

    * ``'bartsv'``: load BART weights except the three embedding matrices,
      then copy the first ``bart_vocab_size`` embedding rows from BART and
      optionally fill the remaining rows with composed embeddings;
    * ``'bart'``: load BART weights (optionally dropping decoder embedding
      keys) and optionally overwrite decoder embeddings with composed ones;
    * ``'roberta'``: optionally overwrite decoder embeddings with composed
      ones built from the RoBERTa encoder embeddings.

    After that the function restores the latest checkpoint and trains epoch
    by epoch while the learning rate stays above ``args.min_lr`` and the epoch
    limit has not been passed.

    Raises:
        ValueError: if ``args.arch`` contains none of the substrings above.
    """
    import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion, criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # breakpoint()

    # ========== initialize the model with pretrained BART parameters ==========
    # for shared embeddings and subtoken split for amr nodes
    # The 'bartsv' test must come first: any arch containing 'bartsv' also
    # contains 'bart'.
    if 'bartsv' in args.arch:
        if args.initialize_with_bart:
            logger.info(
                '-' * 10 + ' initializing model parameters with pretrained BART model ' + '-' * 10)

            # Work on a deep copy so the pretrained model's own state dict is
            # not modified by the deletions below.
            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            # treat the embedding initialization separately later, as the size different
            logger.info(
                '-' * 10 + ' delay encoder embeddings, decoder input and output embeddings initialization ' + '-' * 10)
            ignore_keys = set([
                'encoder.embed_tokens.weight',
                'decoder.embed_tokens.weight',
                'decoder.output_projection.weight'
            ])
            for k in ignore_keys:
                del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 + ' do not initialize with BART encoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 + ' do not initialize with BART decoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            # strict=False because keys were removed from the state dict above.
            model.load_state_dict(new_state_dict, strict=False, args=args)

            # initialize the Bart part embeddings
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            # NOTE we need to prune the pretrained BART embeddings, especially for bart.base
            bart_embed_weight = task.bart.model.encoder.embed_tokens.weight.data[: bart_vocab_size]
            assert len(bart_embed_weight) == bart_vocab_size

            # The same BART embedding rows go into the encoder input, decoder
            # input and decoder output matrices.
            with torch.no_grad():
                model.encoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.embed_tokens.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)
                model.decoder.output_projection.weight[:bart_vocab_size].copy_(
                    bart_embed_weight)

        # NOTE(review): this section reads bart_vocab_size, which is assigned
        # in the initialize_with_bart section above - TODO confirm the two
        # options are always enabled together.
        if args.bart_emb_init_composition:
            logger.info(
                '-' * 10 + ' initialize extended target embeddings with compositional embeddings '
                'from BART vocabulary ' + '-' * 10)

            # breakpoint()

            # Dictionary entries at indices >= bart_vocab_size.
            symbols = [
                task.target_dictionary[idx]
                for idx in range(bart_vocab_size, len(task.target_dictionary))
            ]
            mapper = MapAvgEmbeddingBART(task.bart, task.bart.model.decoder.embed_tokens)
            comp_embed_weight, map_all = mapper.map_avg_embeddings(
                symbols, transform=transform_action_symbol, add_noise=False)
            assert len(comp_embed_weight) == len(symbols)

            with torch.no_grad():
                model.encoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.embed_tokens.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)
                model.decoder.output_projection.weight[bart_vocab_size:].copy_(
                    comp_embed_weight)

    elif 'bart' in args.arch:
        if args.initialize_with_bart:
            logger.info(
                '-' * 10 + ' initializing model parameters with pretrained BART model ' + '-' * 10)

            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            if not args.bart_emb_decoder:
                logger.info('-' * 10 + ' build a separate decoder dictionary embedding ' + '-' * 10)
                # Drop the decoder embedding keys that must not come from BART.
                if not args.bart_emb_decoder_input:
                    ignore_keys = set([
                        'decoder.embed_tokens.weight',
                        'decoder.output_projection.weight'
                    ])
                else:
                    logger.info(
                        '-' * 10 + ' use BART dictionary embedding for target input ' + '-' * 10)
                    ignore_keys = set(['decoder.output_projection.weight'])
                for k in ignore_keys:
                    del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info(
                    '-' * 10 + ' do not initialize with BART encoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info(
                    '-' * 10 + ' do not initialize with BART decoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of BART vocabulary here'
            logger.info(
                '-' * 10 + ' initialize target embeddings with compositional embeddings from BART vocabulary ' + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart, task.bart.model.decoder.embed_tokens, task.target_dictionary)
            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    elif 'roberta' in args.arch:
        # initialize the target embeddings with average of subtoken embeddings in BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, 'should not use the compositional embeddings on top of RoBERTa vocabulary here'
            logger.info(
                '-' * 10 + ' initialize target embeddings with compositional embeddings from RoBERTa vocabulary ' + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,  # NOTE here "bart" means roberta
                task.bart.model.encoder.sentence_encoder.embed_tokens,
                task.target_dictionary)
            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(
                        composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(
                        composite_embed.embedding_weight)

    else:
        # Architectures without 'bartsv', 'bart' or 'roberta' in the name are
        # rejected.
        raise ValueError
    # ==========================================================================

    # breakpoint()

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        # The quantizer is not passed to the Megatron trainer.
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(
        args.distributed_world_size))
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.batch_size))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        # (the unpacking expects train() to return a
        # (valid_losses, should_stop) pair)
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))