def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(loss_meter.avg)),
        ] + [
            (k, meter.avg) for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
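# The helpers used above (`AverageMeter`, `get_perplexity`) are defined elsewhere
# in fairseq; a minimal sketch of what they are assumed to do, following the usual
# convention that the reported loss is a base-2 log-likelihood per token, so
# perplexity is simply 2 ** loss (not the actual fairseq implementations):
import math


class AverageMeter(object):
    """Tracks a running weighted average (e.g. loss per token)."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_perplexity(loss):
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')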
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', ) extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf if callable(getattr(trainer.criterion, 'set_epoch', None)): trainer.criterion.set_epoch(epoch_itr.epoch) for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): if callable(getattr(trainer.criterion, 'set_num_updates', None)): trainer.criterion.set_num_updates(trainer.get_num_updates()) log_output = trainer.train_step(samples) if log_output is None: continue # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue # these are already logged above if 'loss' in k or k == 'accuracy': extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats, tag='train', step=stats['num_updates']) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() num_updates = trainer.get_num_updates() if (not args.disable_validation and args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0): valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag='train', step=stats['num_updates']) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
def validate(args, trainer, task, epoch_itr, subsets, sampled_arch_name):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens_valid,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=args.required_batch_size_multiple,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
                num_workers=args.num_workers,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args, itr, epoch_itr.epoch,
                prefix='valid on \'{}\' subset'.format(subset),
            )
            return progress

        progress = get_itr()

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = utils.get_valid_stats(trainer, args, extra_meters)
        stats[sampled_arch_name + '_loss'] = deepcopy(stats['loss'])
        stats[sampled_arch_name + '_nll_loss'] = deepcopy(stats['nll_loss'])
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric]
        )
    return valid_losses
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple",
    )

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if ("validate" not in extra_state
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {"lowest_loss": val_loss, "num_since_best": 0}
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} "
            f"(current loss: {val_loss}) was "
            f"{extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
def init_meters(self, args):
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    if args.sep_optim:
        self.meters['dec_gnorm'] = AverageMeter()  # gradient norm for decoder
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    if args.fp16:
        self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
    self.meters['wall'] = TimeMeter()      # wall time in seconds
    self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
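# For completeness, a rough sketch of the time-based meters referenced above,
# assuming the usual fairseq.meters semantics (the real classes also support
# init values and repeated start/stop cycles):
import time


class TimeMeter(object):
    """Computes an average rate (items per second) since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / max(self.elapsed_time, 1e-6)


class StopwatchMeter(object):
    """Accumulates wall-clock time between explicit start()/stop() calls."""

    def __init__(self):
        self.sum = 0.0
        self._start = None

    def start(self):
        self._start = time.time()

    def stop(self):
        if self._start is not None:
            self.sum += time.time() - self._start
            self._start = None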
def validate(args, trainer, task, subsets): """Evaluate the model on the validation set(s) and return the losses.""" if args.fixed_validation_seed is not None: # set fixed seed for every validation utils.set_torch_seed(args.fixed_validation_seed) valid_losses = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple') # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) cnt = 0 if args.distributed_world_size > 1: fout_hidden = './tmp/hidden_{}.h5'.format(args.distributed_rank) fout_target = './tmp/target_{}.h5'.format(args.distributed_rank) else: fout_hidden = './tmp/hidden.h5' fout_target = './tmp/target.h5' fout_hidden, hidden_list = open_h5(fout_hidden, 1024) fout_target, target_list = open_h5(fout_target, 1) for sample in progress: record, log_output = trainer.valid_step(sample) hidden_list.append(record[0].cpu().numpy().astype('float16')) target_list.append(record[1].cpu().numpy().astype('float16')) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue extra_meters[k].update(v) cnt += 1 if (cnt > 10): break # log validation stats fout_hidden.close() fout_target.close() stats = get_valid_stats(trainer, args, extra_meters) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) return valid_losses
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" if args.fixed_validation_seed is not None: # set fixed seed for every validation utils.set_torch_seed(args.fixed_validation_seed) valid_losses = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix="valid on '{}' subset".format(subset), no_progress_bar="simple", ) # reset validation loss meters for k in ["valid_loss", "valid_nll_loss"]: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) for sample in progress: log_output = trainer.valid_step(sample) for k, v in log_output.items(): if k in [ "loss", "nll_loss", "ntokens", "nsentences", "sample_size" ]: continue extra_meters[k].update(v) # log validation stats stats = get_valid_stats(trainer, args, extra_meters) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats[args.best_checkpoint_metric].avg if args. best_checkpoint_metric == "loss" else stats[args.best_checkpoint_metric]) return valid_losses
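# `utils.set_torch_seed` above is assumed to behave roughly like this minimal
# helper: fixing the seed before each validation makes any sampling performed
# inside valid_step reproducible across validation runs.
import torch


def set_torch_seed(seed):
    # seed both the CPU generator and all GPU generators
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)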
# Update parameters every N batches
update_freq = args_transformer.update_freq[epoch_itr.epoch - 1] \
    if epoch_itr.epoch <= len(args_transformer.update_freq) \
    else args_transformer.update_freq[-1]

# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
    fix_batches_to_gpus=args_transformer.fix_batches_to_gpus,
    shuffle=(epoch_itr.epoch >= args_transformer.curriculum),
)
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
    args_transformer, itr, epoch_itr.epoch, no_progress_bar='simple',
)

extra_meters = collections.defaultdict(lambda: AverageMeter())

for count, ind in enumerate(inds):
    # Get the coordinates of the loss value being calculated
    coord = coords[count]
    dx = directions[0]
    dy = directions[1]
    changes = [d0 * coord[0] + d1 * coord[1] for (d0, d1) in zip(dx, dy)]

    new_states = copy.deepcopy(states)
    assert len(new_states) == len(changes)
    for (k, v), d in zip(new_states.items(), changes):
        d = torch.tensor(d)
        v.add_(d.type(v.type()))

    # load the perturbed weights into the model
    model.load_state_dict(new_states)
def validate(args, trainer, task, epoch_itr, subsets, force_refine_step=None): """Evaluate the model on the validation set(s) and return the losses.""" valid_random = np.random.RandomState(3) valid_task_random = np.random.RandomState(3) if not hasattr(task, "random"): task.random = None task_random_bak = task.random task.random = valid_task_random valid_losses = [] for subset in subsets: # Initialize data iterator dataset = task.dataset(subset) if hasattr(dataset, "random"): random_bak = dataset.random else: random_bak = None dataset.random = valid_random set_valid_tokens(task, dataset, trainer, args) itr = task.get_batch_iterator( dataset=dataset, max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple') # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) if hasattr(args, "progressive") and args.progressive: dataset.set_random_refine_step(args.refinetot, force_refine_step=force_refine_step) for sample in progress: if trainer._oom_batch is None: trainer._oom_batch = sample if sample is None or len(sample) == 0: sys.stderr.write("empty valid sample detected\n") sys.stderr.flush() break log_output = trainer.valid_step(sample) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue extra_meters[k].update(v) if hasattr(args, "progressive") and args.progressive: dataset.set_random_refine_step( args.refinetot, force_refine_step=force_refine_step) # log validation stats stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats[args.best_checkpoint_metric] if type(stats[args.best_checkpoint_metric]) == float else stats[args.best_checkpoint_metric].avg) dataset.random = random_bak if HAS_WANDB and distributed_utils.is_master(args): stat_dict = {} for k, v in stats.items(): if isinstance(v, AverageMeter): stat_dict[k] = v.val else: stat_dict[k] = v wandb.log(stat_dict, step=trainer.get_num_updates()) task.random = task_random_bak return valid_losses
def train(args, trainer, task, epoch_itr, force_refine_step=None): """Train the model for one epoch.""" # Update parameters every N batches def is_better(a, b): return a > b if args.maximize_best_checkpoint_metric else a < b update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', ) extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf if hasattr(args, "progressive") and args.progressive: task.dataset("train").set_random_refine_step( args.refinetot, force_refine_step=force_refine_step) last_samples = None last_best_update = 0 for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): if samples is None or len(samples) == 0: sys.stderr.write("Empty sample detected\n") sys.stderr.flush() samples = last_samples else: last_samples = samples log_output = trainer.train_step(samples) if log_output is None: continue # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue # these are already logged above if 'loss' in k: extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats, tag='train', step=stats['num_updates']) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() num_updates = trainer.get_num_updates() if (not args.disable_validation and args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0): valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets, force_refine_step=force_refine_step) # if distributed_utils.is_master(args): # print("saving:", trainer.get_num_updates()) # nsml.save(str(trainer.get_num_updates())) if not hasattr(checkpoint_utils.save_checkpoint, 'best') or is_better( valid_losses[0], checkpoint_utils.save_checkpoint.best): last_best_update = num_updates if distributed_utils.is_master(args): print("saving checkpoint ...") sys.stdout.flush() if getattr(args, "save_path", False) and len(args.save_path) > 0: if not os.path.exists(args.save_path): os.mkdir(args.save_path) torch.save({"model": trainer.get_model().state_dict()}, "{}/best.pt".format(args.save_path)) elif HAS_NSML: nsml.save("best") else: torch.save({"model": trainer.get_model().state_dict()}, "/tmp/best.pt") if HAS_WANDB: wandb.save("/tmp/best.pt") sys.stdout.flush() checkpoint_utils.save_checkpoint.best = valid_losses[0] if args.decoder_wise_training and update_num_to_refine_step( num_updates) != force_refine_step: if HAS_NSML: nsml.load("best") else: # Retrieve the model if distributed_utils.is_master(args): state = torch.load("/tmp/best.pt", map_location="cpu") trainer.model.load_state_dict(state["model"]) # Sync assert isinstance(trainer.model, parallel.DistributedDataParallel) if isinstance(trainer.model, parallel.DistributedDataParallel): trainer.model._sync_params() checkpoint_utils.save_checkpoint.best = 0. 
force_refine_step = update_num_to_refine_step(num_updates) trainer.criterion.pool.clear() print("| Start refinement step:", force_refine_step) if num_updates >= max_update: break if args.early_stop and num_updates - last_best_update >= 3000: if distributed_utils.is_master(args): print("early stop") setattr(args, "early_stopping", True) break if hasattr(args, "progressive") and args.progressive: task.dataset("train").set_random_refine_step( args.refinetot, force_refine_step=force_refine_step) # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag='train', step=stats['num_updates']) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
def train(args, trainer, task, epoch_itr, shuffling_seeds): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr() progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple') # update parameters every N batches if epoch_itr.epoch <= len(args.update_freq): update_freq = args.update_freq[epoch_itr.epoch - 1] else: update_freq = args.update_freq[-1] if args.enable_parallel_backward_allred_opt and update_freq > 1: raise RuntimeError( '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1' ) extra_meters = collections.defaultdict(lambda: AverageMeter()) first_valid = args.valid_subset.split(',')[0] max_update = args.max_update or math.inf num_batches = len(epoch_itr) if args.time_step: begin = time.time() end = time.time() count = 0 #profile_count = 13 profile_count = 10000000000 for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch): if args.time_step: start_step = time.time() if i < num_batches - 1 and (i + 1) % update_freq > 0: # buffer updates according to --update-freq trainer.train_step(sample, update_params=False, last_step=(i == len(itr) - 1)) continue else: log_output = trainer.train_step(sample, update_params=True, last_step=(i == len(itr) - 1)) if args.time_step: end_step = time.time() #if count > 10 and sample['target'].size(0) > 248 : seqs = sample['target'].size(0) srclen = sample['net_input']['src_tokens'].size(1) tgtlen = sample['target'].size(1) srcbatch = srclen * seqs tgtbatch = tgtlen * seqs #print("ITER {}> Seqs: {} SrcLen: {} TgtLen: {} Src Batch: {} Tgt Batch {}".format( count, seqs, srclen, tgtlen, srcbatch, tgtbatch)) print("ITER {}> Seqs: {} SrcLen: {} TgtLen: {} Total Time: {:.3} Step Time: {:.3} Load Time: {:.3}".format( \ count, \ sample['target'].size(0), \ sample['net_input']['src_tokens'].size(1), \ sample['target'].size(1), \ (end_step-begin)*1000.0, \ (end_step-start_step)*1000.0, \ (start_step-end)*1000.0)) count += 1 begin = time.time() # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'sample_size']: continue # these are already logged above if 'loss' in k: extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() if args.profile is not None and i == args.profile: import sys sys.exit() num_updates = trainer.get_num_updates() if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0: valid_losses = validate(args, trainer, task, epoch_itr, [first_valid], shuffling_seeds) save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break if args.time_step: end = time.time() # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip' ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
def train(args, epoch, batch_offset, trainer, dataset, max_positions): """Train the model for one epoch.""" seed = args.seed + epoch torch.manual_seed(seed) trainer.set_seed(seed) itr = dataset.train_dataloader( args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions, seed=seed, epoch=epoch, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch <= args.curriculum)) loss_meter = AverageMeter() nll_loss_meter = AverageMeter() bsz_meter = AverageMeter() # sentences per batch wpb_meter = AverageMeter() # words per batch wps_meter = TimeMeter() # words per second clip_meter = AverageMeter() # % of updates clipped extra_meters = collections.defaultdict(lambda: AverageMeter()) lr = trainer.get_lr() with utils.build_progress_bar(args, itr, epoch) as t: for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset): loss_dict = trainer.train_step(sample) loss = loss_dict['loss'] del loss_dict['loss'] # don't include in extra_meters or extra_postfix ntokens = sum(s['ntokens'] for s in sample) if 'nll_loss' in loss_dict: nll_loss = loss_dict['nll_loss'] nll_loss_meter.update(nll_loss, ntokens) nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample) loss_meter.update(loss, nsentences if args.sentence_avg else ntokens) bsz_meter.update(nsentences) wpb_meter.update(ntokens) wps_meter.update(ntokens) clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0) extra_postfix = [] for k, v in loss_dict.items(): extra_meters[k].update(v) extra_postfix.append((k, extra_meters[k].avg)) t.log(collections.OrderedDict([ ('loss', loss_meter), ('wps', round(wps_meter.avg)), ('wpb', round(wpb_meter.avg)), ('bsz', round(bsz_meter.avg)), ('lr', lr), ('clip', '{:.0%}'.format(clip_meter.avg)), ] + extra_postfix)) if i == 0: # ignore the first mini-batch in words-per-second calculation wps_meter.reset() if args.save_interval > 0 and (i + 1) % args.save_interval == 0: save_checkpoint(trainer, args, epoch, i + 1) t.print(collections.OrderedDict([ ('train loss', round(loss_meter.avg, 2)), ('train ppl', get_perplexity(nll_loss_meter.avg if nll_loss_meter.count > 0 else loss_meter.avg)), ('s/checkpoint', round(wps_meter.elapsed_time)), ('words/s', round(wps_meter.avg)), ('words/batch', round(wpb_meter.avg)), ('bsz', round(bsz_meter.avg)), ('lr', lr), ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)), ] + [ (k, meter.avg) for k, meter in extra_meters.items() ]))
def estimate_head_importance(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus) if args.n_pruning_steps > 0: itr = islice(itr, args.n_pruning_steps) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar='simple', ) # Inititalize meters extra_meters = collections.defaultdict(lambda: AverageMeter()) # Initialize head importance scores encoder_layers = trainer.args.encoder_layers decoder_layers = trainer.args.decoder_layers encoder_heads = trainer.args.encoder_attention_heads decoder_heads = trainer.args.decoder_attention_heads device = next(trainer.model.parameters()).device head_importance = { "encoder_self": torch.zeros(encoder_layers, encoder_heads).to(device), "encoder_decoder": torch.zeros(decoder_layers, decoder_heads).to(device), "decoder_self": torch.zeros(decoder_layers, decoder_heads).to(device), } # Denominators to normalize properly denoms = { attn_type: val.clone() for attn_type, val in head_importance.items() } head_stats = { attn_type: [{} for _ in range(val.size(0))] for attn_type, val in head_importance.items() } for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): # Compute gradients log_output = trainer.prune_step(samples) # Retrieve importance scores for the encoder for layer in range(encoder_layers): self_attn_variables = trainer.model.encoder.layers[ layer].self_attn_variables importance, denom = batch_head_importance(self_attn_variables, one_minus=args.one_minus) head_importance["encoder_self"][layer] += importance denoms["encoder_self"][layer] += denom # Stats aggregate_stats(head_stats["encoder_self"][layer], batch_head_stats(self_attn_variables)[0]) # Retrieve importance scores for the decoder for layer in range(decoder_layers): # Self attention self_attn_variables = trainer.model.decoder.layers[ layer].self_attn_variables importance, denom = batch_head_importance(self_attn_variables, one_minus=args.one_minus) head_importance["decoder_self"][layer] += importance denoms["decoder_self"][layer] += denom aggregate_stats( head_stats["decoder_self"][layer], batch_head_stats(self_attn_variables, triu_masking=True)[0]) # Encoder attention encoder_attn_variables = trainer.model.decoder.layers[ layer].encoder_attn_variables importance, denom = batch_head_importance(encoder_attn_variables, one_minus=args.one_minus) head_importance["encoder_decoder"][layer] += importance denoms["encoder_decoder"][layer] += denom aggregate_stats(head_stats["encoder_decoder"][layer], batch_head_stats(encoder_attn_variables)[0]) # log mid-epoch stats stats = get_pruning_stats(trainer) for k, v in log_output.items(): extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() # log end-of-epoch stats stats = get_pruning_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) # Normalize by type for attn_type in denoms: head_importance[attn_type] /= denoms[attn_type] # Normalize head stats for attn_type in denoms: for layer in range(len(head_stats[attn_type])): for key in head_stats[attn_type][layer]: head_stats[attn_type][layer][key] /= denoms[attn_type].mean( ).cpu() # Normalize by layer if args.normalize_by_layer: for layer in range(encoder_layers): for attn_type, importance in head_importance.items(): head_importance[attn_type][layer] /= torch.sqrt( 
torch.sum(importance[layer]**2)) return {k: v.cpu() for k, v in head_importance.items()}, head_stats
def main(): args = parser.parse_args() # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(args) model = TransformerModel.build_model(args, task).cuda() criterion = task.build_criterion(args).cuda() optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=eval(args.adam_betas), eps=args.adam_eps, weight_decay=args.weight_decay) # Load dataset splits load_dataset_splits(task, ['train', 'valid']) epoch_itr = data.EpochBatchIterator( dataset=task.dataset(args.train_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid, max_positions=(args.max_source_positions, args.max_target_positions), ignore_invalid_inputs=True, required_batch_size_multiple=8, seed=1, num_shards=1, shard_id=0, ) losses = AverageMeter() encoder_layer_forward = [ AverageMeter() for _ in range(len(model.encoder.layers[0].layer)) ] decoder_layer_forward = [ AverageMeter() for _ in range(len(model.decoder.layers[0].layer)) ] encoder_layer_backward = [ AverageMeter() for _ in range(len(model.encoder.layers[0].layer)) ] decoder_layer_backward = [ AverageMeter() for _ in range(len(model.decoder.layers[0].layer)) ] def measure_hook(forward, backward): def hook(module, input, output): for i, layer in enumerate(module.layer): if len(input) == 2: x, _ = input else: x, = input x = x.detach().clone().requires_grad_() # warm-up for _ in range(5): if isinstance(layer, nn.MultiheadAttention): out, _ = layer(x, x, x) else: out = layer(x) torch.autograd.backward(out, out) starter, ender = torch.cuda.Event( enable_timing=True), torch.cuda.Event(enable_timing=True) for _ in range(50): starter.record() if isinstance(layer, nn.MultiheadAttention): out, _ = layer(x, x, x) else: out = layer(x) ender.record() torch.cuda.synchronize() forward[i].update(starter.elapsed_time(ender)) starter.record() torch.autograd.backward(out, out) ender.record() torch.cuda.synchronize() backward[i].update(starter.elapsed_time(ender)) return hook for layer in model.encoder.layers: layer.register_forward_hook( measure_hook(encoder_layer_forward, encoder_layer_backward)) for layer in model.decoder.layers: layer.register_forward_hook( measure_hook(decoder_layer_forward, decoder_layer_backward)) embed_forward = AverageMeter() embed_backward = AverageMeter() def embed_hook(module, input, output): tokens, _ = input # warm-up for _ in range(5): x = module.embed_scale * module.embed_tokens(tokens) x += module.embed_positions(tokens) torch.autograd.backward(x, x) starter, ender = torch.cuda.Event( enable_timing=True), torch.cuda.Event(enable_timing=True) for _ in range(50): starter.record() x = module.embed_scale * module.embed_tokens(tokens) x += module.embed_positions(tokens) ender.record() torch.cuda.synchronize() embed_forward.update(starter.elapsed_time(ender)) starter.record() torch.autograd.backward(x, x) ender.record() torch.cuda.synchronize() embed_backward.update(starter.elapsed_time(ender)) model.encoder.register_forward_hook(embed_hook) linear_forward = AverageMeter() linear_backward = AverageMeter() def linear_hook(module, input, output): _, encode_out = input encode_out = encode_out.detach().clone().requires_grad_() # warm-up for _ in range(5): x = encode_out.transpose(0, 1) out = F.linear(x, module.embed_out) torch.autograd.backward(out, out) starter, ender = torch.cuda.Event( enable_timing=True), torch.cuda.Event(enable_timing=True) for _ in range(50): starter.record() x = encode_out.transpose(0, 1) out = F.linear(x, module.embed_out) ender.record() torch.cuda.synchronize() 
linear_forward.update(starter.elapsed_time(ender)) starter.record() torch.autograd.backward(out, out) ender.record() torch.cuda.synchronize() linear_backward.update(starter.elapsed_time(ender)) model.decoder.register_forward_hook(linear_hook) itr = epoch_itr.next_epoch_itr() max_positions = (args.max_source_positions, args.max_target_positions) for i, sample in enumerate(itr): sample = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) sample = utils.move_to_cuda(sample) loss, _, logging_output = criterion(model, sample) num_tokens = logging_output['ntokens'] losses.update(loss.item() / num_tokens / math.log(2), num_tokens) if i % 100 == 0: print('Loss: {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses)) print( 'Time: {forward_time.avg:.3f} ({backward_time.avg:.3f})' '{forward_time_decoder.avg:.3f} ({backward_time_decoder.avg:.3f})' .format(forward_time=encoder_layer_forward[0], backward_time=encoder_layer_backward[0], forward_time_decoder=decoder_layer_forward[-1], backward_time_decoder=decoder_layer_backward[-1])) loss.backward() optimizer.step() optimizer.zero_grad() break stat = {i: {} for i in range(len(decoder_layer_forward))} for i, (f, b) in enumerate(zip(encoder_layer_forward, encoder_layer_backward)): stat[i]['encoder'] = {} stat[i]['encoder']['forward'] = f.avg stat[i]['encoder']['backward'] = b.avg for i, (f, b) in enumerate(zip(decoder_layer_forward, decoder_layer_backward)): stat[i]['decoder'] = {} stat[i]['decoder']['forward'] = f.avg stat[i]['decoder']['backward'] = b.avg stat['embed'] = {} stat['embed']['forward'] = embed_forward.avg stat['embed']['backward'] = embed_backward.avg stat['linear'] = {} stat['linear']['forward'] = linear_forward.avg stat['linear']['backward'] = linear_backward.avg with open('time.json', 'w') as file: json.dump(stat, file, indent=4)
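# The hooks above all follow the same CUDA-event timing pattern; a self-contained
# sketch of that pattern (the module and input sizes below are arbitrary examples,
# not values taken from the script above):
import torch
import torch.nn as nn


def time_forward_backward(module, x, warmup=5, iters=50):
    """Return average forward/backward time in milliseconds using CUDA events."""
    x = x.detach().clone().requires_grad_()
    for _ in range(warmup):
        out = module(x)
        torch.autograd.backward(out, out)

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    fwd, bwd = 0.0, 0.0
    for _ in range(iters):
        starter.record()
        out = module(x)
        ender.record()
        torch.cuda.synchronize()  # elapsed_time is only valid after a sync
        fwd += starter.elapsed_time(ender)

        starter.record()
        torch.autograd.backward(out, out)
        ender.record()
        torch.cuda.synchronize()
        bwd += starter.elapsed_time(ender)
    return fwd / iters, bwd / iters


# Example: time a single linear layer on dummy activations.
if torch.cuda.is_available():
    layer = nn.Linear(512, 512).cuda()
    dummy = torch.randn(128, 512, device='cuda')
    print(time_forward_backward(layer, dummy))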
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus): """Train the model for one epoch.""" seed = args.seed + epoch torch.manual_seed(seed) trainer.set_seed(seed) itr = dataset.train_dataloader( args.train_subset, num_workers=args.workers, max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions, seed=seed, epoch=epoch, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch <= args.curriculum)) loss_meter = AverageMeter() bsz_meter = AverageMeter() # sentences per batch wpb_meter = AverageMeter() # words per batch wps_meter = TimeMeter() # words per second clip_meter = AverageMeter() # % of updates clipped extra_meters = collections.defaultdict(lambda: AverageMeter()) lr = trainer.get_lr() with utils.build_progress_bar(args, itr, epoch) as t: for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset): loss_dict = trainer.train_step(sample) loss = loss_dict['loss'] del loss_dict['loss'] # don't include in extra_meters or extra_postfix ntokens = sum(s['ntokens'] for s in sample) nsentences = sum(s['src_tokens'].size(0) for s in sample) loss_meter.update(loss, nsentences if args.sentence_avg else ntokens) bsz_meter.update(nsentences) wpb_meter.update(ntokens) wps_meter.update(ntokens) clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0) extra_postfix = [] for k, v in loss_dict.items(): extra_meters[k].update(v) extra_postfix.append((k, extra_meters[k].avg)) t.log(collections.OrderedDict([ ('loss', loss_meter), ('wps', round(wps_meter.avg)), ('wpb', round(wpb_meter.avg)), ('bsz', round(bsz_meter.avg)), ('lr', lr), ('clip', '{:.0%}'.format(clip_meter.avg)), ] + extra_postfix)) if i == 0: # ignore the first mini-batch in words-per-second calculation wps_meter.reset() if args.save_interval > 0 and (i + 1) % args.save_interval == 0: save_checkpoint(trainer, args, epoch, i + 1) t.print(collections.OrderedDict([ ('train loss', round(loss_meter.avg, 2)), ('train ppl', get_perplexity(loss_meter.avg)), ('s/checkpoint', round(wps_meter.elapsed_time)), ('words/s', round(wps_meter.avg)), ('words/batch', round(wpb_meter.avg)), ('bsz', round(bsz_meter.avg)), ('lr', lr), ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)), ] + [ (k, meter.avg) for k, meter in extra_meters.items() ]))
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches update_freq = args.update_freq[epoch_itr.epoch - 1] \ if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1] # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) itr_groups = OrderedDict() progress_len = 0 for domain in itr.keys(): itr_groups[domain] = iterators_dtn.GroupedIteratorDtn( itr[domain], update_freq) progress_len = progress_len + len(itr_groups[domain]) progress = progress_bar.build_progress_bar( args, range(progress_len), epoch_itr.epoch, no_progress_bar='simple', ) extra_meters = collections.defaultdict(lambda: AverageMeter()) first_valid = args.valid_subset.split(',')[0] valid_select = args.valid_select[0] max_update = args.max_update or math.inf def sample_a_training_set(prob=None): train_idx = np.random.choice(np.arange(len(prob)), p=prob) return train_idx train_domains = args.train_domains if args.random_select: lens_map = OrderedDict() count = 0.0 for domain in train_domains: lens_map[domain] = len(itr_groups[domain]) count += lens_map[domain] prob_map = OrderedDict() for domain in train_domains: prob_map[domain] = lens_map[domain] / count keys = [] probs = [] for key, value in prob_map.items(): keys.append(key) probs.append(value) probs_new = [] probs_norm = 0.0 for prob in probs: probs_norm += prob**args.random_select_factor for prob in probs: probs_new.append(prob**args.random_select_factor / probs_norm) for i, _ in enumerate(progress, start=epoch_itr.iterations_in_epoch): if args.random_select: train_idx = sample_a_training_set(probs_new) domain = keys[train_idx] else: domain = train_domains[i % len(train_domains)] samples = next(itr_groups[domain]) # domain, samples = sample_a_training_set(itr_groups) log_output = trainer.train_step(samples, domain=domain) if log_output is None: continue # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue # these are already logged above if 'loss' in k: extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg stats['dataset'] = domain progress.log(stats, tag='train', step=stats['num_updates']) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() num_updates = trainer.get_num_updates() if args.validate_interval_updates > 0 and num_updates % args.validate_interval_updates == 0 and num_updates > 0: valid_losses, valid_bleus = validate(args, trainer, task, epoch_itr, [first_valid]) save_checkpoint(args, trainer, epoch_itr, valid_losses, valid_bleus, valid_select) if num_updates >= max_update: break # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag='train', step=stats['num_updates']) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip', ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
def __init__(self, args, task, model, criterion, allreduce_communicators=None): super().__init__(args, task, model, criterion, allreduce_communicators) # convert model to FP16 (but keep criterion FP32) self.model.half() # dynamically scale loss to reduce overflow self.scaler = DynamicLossScaler(init_scale=2.**7) self.meters['loss_scale'] = AverageMeter() self.grad_denom = 1.0 if self.args.enable_parallel_backward_allred_opt: import numpy as np self._flat_grads_parallel = torch.tensor( [], dtype=torch.float16).cuda() self._grads_info = [] grads_size = 0 p_offset = 0 for p_i, p in enumerate( [p for p in self.model.parameters() if p.requires_grad]): p_grads_size = np.prod(list(p.size())) grads_size += p_grads_size # register hooks def wrapper(param, param_i, param_grads_size, param_offset): def allreduce_hook(grad): self._do_allreduce(param_i, param_grads_size, param_offset, grad) if param.requires_grad: param.register_hook(allreduce_hook) # print(p_i, p.size(), p_grads_size, p_offset) self._grads_info.append({ "param_grads_size": p_grads_size, "param_offset": p_offset }) wrapper(p, p_i, p_grads_size, p_offset) p_offset += p_grads_size self._flat_grads_parallel.resize_(grads_size) # print(grads_size, len(self._flat_grads_parallel), self._flat_grads_parallel.dtype, self._flat_grads_parallel.get_device()) self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold print("| parallel all-reduce ENABLED. all-reduce threshold: " + str(self._allreduce_flush_min_threshold)) self._grads_generated = [False] * len(self._grads_info) self._allreduce_processed_idx = len(self._grads_info) - 1 self._num_allreduce_sent = 0 print("| # of parallel all-reduce cuda streams: " + str(self.args.parallel_backward_allred_cuda_nstreams)) if allreduce_communicators: self._allreduce_groups = allreduce_communicators[0] self._allreduce_streams = allreduce_communicators[1] else: raise RuntimeError( 'Moved communicator init before RUN_START (invalid code path)' ) self._allreduce_groups = [ torch.distributed.new_group() for _ in range( self.args.parallel_backward_allred_cuda_nstreams) ] self._allreduce_streams = [ torch.cuda.Stream() for _ in range( self.args.parallel_backward_allred_cuda_nstreams) ] if self.args.enable_parallel_backward_allred_opt_correctness_check: self._num_grads_generated = 0 self._all_grads_generated = False self._allreduce_schedule = []
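# `DynamicLossScaler` above is assumed to implement the standard dynamic
# loss-scaling recipe for FP16 training: multiply the loss by a large scale
# before backward, unscale the gradients before the optimizer step, and shrink
# the scale whenever an inf/nan gradient is detected. A rough sketch of that
# behaviour (not the actual fairseq class):
class SimpleDynamicLossScaler(object):

    def __init__(self, init_scale=2.**7, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._updates_since_overflow = 0

    def update_scale(self, overflow):
        if overflow:
            # overflow: shrink the scale and start counting again
            self.loss_scale /= self.scale_factor
            self._updates_since_overflow = 0
        else:
            self._updates_since_overflow += 1
            # a long stretch without overflow: try a larger scale
            if self._updates_since_overflow % self.scale_window == 0:
                self.loss_scale *= self.scale_factor

    @staticmethod
    def has_overflow(grad_norm):
        return grad_norm == float('inf') or grad_norm != grad_norm  # inf or nan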
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" valid_losses = OrderedDict() valid_bleus = OrderedDict() valid_select = task.args.valid_select[0] assert len(subsets) == 1 for subset in subsets: # Initialize data iterator valid_loss_all = [] valid_nll_loss_all = [] valid_bleu_all = [] for k in ['valid_loss', 'valid_nll_loss', 'valid_bleu']: meter = trainer.get_meter(k + '_all') meter.reset() for domain, data_valid in task.dataset(subset).items(): itr = task.get_batch_iterator_valid( dataset=data_valid, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset \'{}\' domain'.format( subset, domain), no_progress_bar='simple') # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss', 'valid_bleu']: meter = trainer.get_meter(k + '_' + domain) meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) src_target_hypo_strs = [] for sample in progress: log_output, src_target_hypo_str = trainer.valid_step( sample, domain=domain) src_target_hypo_strs.extend(src_target_hypo_str) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size' ]: continue extra_meters[k].update(v) src_target_hypo_strs_filter = [] for sents in src_target_hypo_strs: for sent in sents: if sent is None or len(sent) == 0: continue src_target_hypo_strs_filter.append(sent) src_target_hypo_strs_filter = sorted(src_target_hypo_strs_filter, key=lambda elem: int(elem[0]), reverse=False) if args.valid_decoding_path is not None: with open( os.path.join( args.valid_decoding_path, domain, 'decoding_{}.txt'.format(args.distributed_rank)), 'w') as f: for sent in src_target_hypo_strs_filter: if len(sent) == 0: continue f.write(sent[-1] + '\n') num_ref = args.num_ref[domain] ref_path = [] for i in range(int(num_ref)): ref_path.append( os.path.join(args.valid_decoding_path, domain, 'valid.tok.' 
+ args.target_lang + str(i))) valid_decoding_path = os.path.join( args.valid_decoding_path, domain, 'decoding_{}.txt'.format(args.distributed_rank)) with open(valid_decoding_path) as out_file: out_file.seek(0) res = subprocess.check_output( 'perl %s/multi-bleu.perl %s' % (args.multi_bleu_path, ' '.join(ref_path)), stdin=out_file, shell=True).decode("utf-8") trainer.get_meter('valid_bleu_' + domain).update( float(res.split(',')[0].split('=')[1]), 1.0) stats = get_valid_stats(trainer, domain=domain, valid_select=valid_select) for k in ['loss', 'nll_loss', 'bleu']: stats[k] = stats[k].avg for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=os.path.join(subset, domain), step=trainer.get_num_updates()) valid_losses.update({domain: stats['loss']}) valid_bleus.update({domain: stats['bleu']}) valid_loss_all.append(stats['loss']) valid_nll_loss_all.append(stats['nll_loss']) valid_bleu_all.append(stats['bleu']) trainer.get_meter('valid_loss_all').update(np.mean(valid_loss_all), 1.0) trainer.get_meter('valid_nll_loss_all').update( np.mean(valid_nll_loss_all), 1.0) trainer.get_meter('valid_bleu_all').update(np.mean(valid_bleu_all), 1.0) stats = get_valid_stats(trainer, domain='all', valid_select=valid_select) for k in ['loss', 'nll_loss', 'bleu']: stats[k] = stats[k].avg progress = progress_bar.build_progress_bar( args, [0], epoch_itr.epoch, prefix='valid on \'{}\' subset \'{}\' domain'.format( subset, 'all'), no_progress_bar='simple') progress.print(stats, tag=os.path.join(subset, 'all'), step=trainer.get_num_updates()) valid_losses.update({'all': stats['loss']}) valid_bleus.update({'all': stats['bleu']}) return valid_losses, valid_bleus
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Update parameters every N batches update_freq = (args.update_freq[epoch_itr.epoch - 1] if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]) # Initialize data iterator itr = epoch_itr.next_epoch_itr( fix_batches_to_gpus=args.fix_batches_to_gpus, shuffle=(epoch_itr.epoch >= args.curriculum), ) itr = iterators.GroupedIterator(itr, update_freq) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, no_progress_bar="simple", ) extra_meters = collections.defaultdict(lambda: AverageMeter()) valid_subsets = args.valid_subset.split(",") max_update = args.max_update or math.inf for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): log_output = trainer.train_step(samples) if log_output is None: continue # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in [ "loss", "nll_loss", "ntokens", "nsentences", "sample_size" ]: continue # these are already logged above if "loss" in k or k == "accuracy": extra_meters[k].update(v, log_output["sample_size"]) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats, tag="train", step=stats["num_updates"]) # ignore the first mini-batch in words-per-second and updates-per-second calculation if i == 0: trainer.get_meter("wps").reset() trainer.get_meter("ups").reset() num_updates = trainer.get_num_updates() if (not args.disable_validation and args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0): valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag="train", step=stats["num_updates"]) # reset training meters for k in [ "train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz", "gnorm", "clip", ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
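# `iterators.GroupedIterator(itr, update_freq)` used above yields *lists* of
# `update_freq` mini-batches, so a single trainer.train_step(samples) call can
# accumulate gradients over several batches before the optimizer steps once.
# A minimal sketch of that grouping, assuming nothing beyond the standard
# Python iterator protocol:
def grouped(iterable, chunk_size):
    """Yield successive lists of up to `chunk_size` items from `iterable`."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk  # last, possibly smaller, group


# e.g. with update_freq = 2, batches [b0, b1, b2, b3, b4] become
# [[b0, b1], [b2, b3], [b4]]; each inner list is one parameter update.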
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" valid_losses = [] for subset in subsets: # Initialize data iterator itr = data.EpochBatchIterator( dataset=task.dataset(subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid, max_positions=trainer.get_model().max_positions(), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple') # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) predicted_results, gold_clusters = None, None if 'gap_bert' in args.task: predicted_results, gold_clusters = collections.defaultdict( dict), {} for sample in progress: log_output = trainer.valid_step(sample) if 'gap_bert' in args.task: for threshold, predicted_dict in log_output[ 'predicted_results'].items(): len_before = len(predicted_results[threshold]) predicted_results[threshold].update(predicted_dict) assert len_before + len(predicted_dict) == len( predicted_results[threshold]) len_before = len(gold_clusters) gold_clusters.update(log_output['gold_clusters']) assert len_before + len( log_output['gold_clusters']) == len(gold_clusters) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'sample_size', 'predicted_results', 'gold_clusters' ]: continue extra_meters[k].update(v) # log validation stats stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg if 'gap_bert' in args.task: best_f1, best_mf1, best_ff1 = float('-inf'), None, None best_threshold = None for idx, (k, predicted_result) in enumerate( predicted_results.items()): scores = trainer.criterion.coref_evaluator.eval( gold_clusters, predicted_result) masculine_score = scores[1] _, _, _, mf1 = masculine_score feminine_score = scores[2] _, _, _, ff1 = feminine_score overall_score = scores[0] _, _, _, f1 = overall_score if f1 > best_f1: best_f1, best_mf1, best_ff1 = f1, mf1, ff1 best_threshold = k if idx == 0: continue if idx == 1: stats['valid_loss'] = -f1 else: stats['valid_loss'] = min(-f1, stats['valid_loss']) if not args.no_train: if hasattr(save_checkpoint, 'best'): if stats['valid_loss'] < save_checkpoint.best: stats['best'] = -1 * stats['valid_loss'] stats['best_threshold'] = best_threshold save_checkpoint.best_threshold = best_threshold else: stats['best'] = -1 * save_checkpoint.best stats[ 'best_threshold'] = save_checkpoint.best_threshold else: stats['best'] = best_f1 stats['best_threshold'] = best_threshold save_checkpoint.best_threshold = best_threshold stats['f@%.2f' % best_threshold] = best_f1 stats['mf@%.2f' % best_threshold] = best_mf1 stats['ff@%.2f' % best_threshold] = best_ff1 progress.print(stats) valid_losses.append(stats['valid_loss']) return valid_losses
def train(args, trainer, dataset, epoch, batch_offset): """Train the model for one epoch.""" # Set seed based on args.seed and the epoch number so that we get # reproducible results when resuming from checkpoints seed = args.seed + epoch torch.manual_seed(seed) # The max number of positions can be different for train and valid # e.g., RNNs may support more positions at test time than seen in training max_positions_train = ( min(args.max_source_positions, trainer.get_model().max_encoder_positions()), min(args.max_target_positions, trainer.get_model().max_decoder_positions()) ) # Initialize dataloader, starting at batch_offset itr = dataset.train_dataloader( args.train_subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=max_positions_train, seed=seed, epoch=epoch, sample_without_replacement=args.sample_without_replacement, sort_by_source_size=(epoch <= args.curriculum), shard_id=args.distributed_rank, num_shards=args.distributed_world_size, ) progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple') itr = itertools.islice(progress, batch_offset, None) # reset training meters for k in ['train_loss', 'train_nll_loss', 'src_train_loss', 'src_train_nll_loss', 'reg_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) for i, sample in enumerate(itr, start=batch_offset): log_output = trainer.train_step(sample) # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'reg_loss', 'src_loss']: continue # these are already logged above extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # save mid-epoch checkpoints if i == batch_offset: # ignore the first mini-batch in words-per-second calculation trainer.get_meter('wps').reset() if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0: save_checkpoint(trainer, args, epoch, i + 1) # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats)
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" valid_losses = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple' ) # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) preds, targets, all_results = [], [], [] for sample in progress: log_output = trainer.valid_step(sample) for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size', 'targets', 'preds', 'starts', 'ends']: continue extra_meters[k].update(v) if 'targets' in log_output: preds.append(log_output['preds']) targets.append(log_output['targets']) if 'starts' in log_output: for i in range(len(sample['id'])): indice = sample['id'][i].tolist() start = log_output['starts'][i].cpu().tolist() end = log_output['ends'][i].cpu().tolist() unique_id = task.features[indice].unique_id result = SquadResult(unique_id, start, end) all_results.append(result) if len(preds) > 0: preds = torch.cat(preds, 0).cpu().numpy() targets = torch.cat(targets, 0).cpu().numpy() else: preds = None targets = None if len(all_results) > 0: results = task.compute_predictions_logits(all_results) for k, v in results.items(): print("({}, {})".format(k, v)) exit() # log validation stats stats = get_valid_stats(trainer, args, extra_meters, preds, targets) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append( stats[args.best_checkpoint_metric].avg if args.best_checkpoint_metric == 'loss' else stats[args.best_checkpoint_metric] ) return valid_losses
def baseline_with_meta_evaluation(model, meta_learning_task, meta_learning_args, meta_learning_criterion, fine_tune_args):
    meta_epoch_itr, meta_trainer, max_meta_epoch, max_meta_update, valid_subsets = prepare_meta_task(
        model=model, meta_learning_task=meta_learning_task, meta_learning_args=meta_learning_args,
        meta_learning_criterion=meta_learning_criterion)
    # Combine and do fine-tuning on combined data
    meta_train = meta_learning_task.dataset(meta_learning_args.train_subset)
    combined_fairseq_task = combine_data(meta_train=meta_train, fine_tune_args=fine_tune_args)
    # Fine-tune using the combined task
    criterion = combined_fairseq_task.build_criterion(fine_tune_args)
    import math
    from fairseq.trainer import Trainer
    combined_fairseq_task.load_dataset(fine_tune_args.train_subset)
    train_dataset = combined_fairseq_task.dataset(fine_tune_args.train_subset)
    # Make a dummy batch to (i) warm the caching allocator and (ii) act as a placeholder for
    # DistributedDataParallel when there's an uneven number of batches per worker.
    max_positions = utils.resolve_max_positions(
        combined_fairseq_task.max_positions(),
        model.max_positions(),
    )
    dummy_batch = train_dataset.get_dummy_batch(
        num_tokens=fine_tune_args.max_tokens, max_positions=max_positions)
    oom_batch = combined_fairseq_task.dataset(
        fine_tune_args.train_subset).get_dummy_batch(1, max_positions)
    # Create a trainer for training the model
    trainer = Trainer(fine_tune_args, combined_fairseq_task, model, criterion, dummy_batch, oom_batch)
    epoch_itr = utils.create_epoch_iterator(task=combined_fairseq_task, dataset=train_dataset,
                                            args=fine_tune_args, max_positions=max_positions)
    max_epoch = fine_tune_args.max_epoch or math.inf
    max_update = fine_tune_args.max_update or math.inf
    # Do SGD on this task
    valid_subsets = fine_tune_args.valid_subset.split(',')
    lr = trainer.get_lr()
    batch_info = []
    # Always validate once before training
    valid_losses, _ = utils.validate(fine_tune_args, trainer, combined_fairseq_task, epoch_itr, valid_subsets)
    while lr > fine_tune_args.min_lr and epoch_itr.epoch < max_epoch and trainer.get_num_updates() < max_update:
        # Train the model for one epoch
        import collections
        import math
        from fairseq.data import iterators
        from fairseq import progress_bar
        from fairseq.meters import AverageMeter, ConcatentateMeter, BleuMeter

        # Update parameters every N batches
        update_freq = fine_tune_args.update_freq[epoch_itr.epoch - 1] \
            if epoch_itr.epoch <= len(fine_tune_args.update_freq) else fine_tune_args.update_freq[-1]

        # Initialize data iterator
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=fine_tune_args.fix_batches_to_gpus,
            shuffle=(epoch_itr.epoch >= fine_tune_args.curriculum),
        )
        itr = iterators.GroupedIterator(itr, update_freq)
        progress = progress_bar.build_progress_bar(
            fine_tune_args, itr, epoch_itr.epoch, no_progress_bar='simple',
        )

        extra_meters = collections.defaultdict(lambda: AverageMeter())
        extra_meters['strings'] = ConcatentateMeter()
        extra_meters['bleu_stats'] = BleuMeter()
        valid_subsets = fine_tune_args.valid_subset.split(',')
        max_update = fine_tune_args.max_update or math.inf
        for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
            log_output = trainer.train_step(samples)
            if log_output is None:
                continue

            # log mid-epoch stats
            stats = utils.get_training_stats(trainer)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue  # these are already logged above
                if 'loss' in k:
                    extra_meters[k].update(v, log_output['sample_size'])
                else:
                    extra_meters[k].update(v)
                stats[k] = extra_meters[k].avg
            progress.log(stats, tag=fine_tune_args.train_subset, step=stats['num_updates'])

            # ignore the first mini-batch in words-per-second calculation
            if i == 0:
                trainer.get_meter('wps').reset()

            num_updates = trainer.get_num_updates()
            if fine_tune_args.save_interval_updates > 0 and num_updates % fine_tune_args.save_interval_updates == 0 and num_updates > 0:
                valid_losses, _ = utils.validate(fine_tune_args, trainer, combined_fairseq_task, epoch_itr,
                                                 valid_subsets, train_progress=progress)
                utils.save_checkpoint(fine_tune_args, trainer, epoch_itr, valid_losses[0])

            if num_updates >= max_update:
                break

        # log end-of-epoch stats
        stats = utils.get_training_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
            stats[k + '_std'] = meter.std
        progress.print(stats, tag=fine_tune_args.train_subset, step=stats['num_updates'])

        # reset training meters
        for k in [
            'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
        ]:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()

        # Evaluate on validation split
        if epoch_itr.epoch % fine_tune_args.validate_interval == 0:
            valid_losses, _ = utils.validate(fine_tune_args, trainer, combined_fairseq_task, epoch_itr, valid_subsets)

        # save checkpoint
        if epoch_itr.epoch % fine_tune_args.save_interval == 0:
            utils.save_checkpoint(fine_tune_args, trainer, epoch_itr, valid_losses[0])

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        if batch_info is None:
            # Handle the original train function
            batch_info = []

    # Evaluate on validation split
    maybe_validate(meta_epoch_itr=meta_epoch_itr, meta_learning_args=meta_learning_args,
                   meta_trainer=meta_trainer, meta_learning_task=meta_learning_task,
                   valid_subsets=valid_subsets)
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr() progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple') # update parameters every N batches if epoch_itr.epoch <= len(args.update_freq): update_freq = args.update_freq[epoch_itr.epoch - 1] else: update_freq = args.update_freq[-1] if args.enable_parallel_backward_allred_opt and update_freq > 1: raise RuntimeError( '--enable-parallel-backward-allred-opt is incompatible with --update-freq > 1' ) extra_meters = collections.defaultdict(lambda: AverageMeter()) first_valid = args.valid_subset.split(',')[0] max_update = args.max_update or math.inf num_batches = len(epoch_itr) #begin = time.time() #inside = 0 for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch): #newbegin = time.time() #print("iter time", newbegin - begin, inside, (newbegin - begin - inside)*1000) #begin = newbegin if i < num_batches - 1 and (i + 1) % update_freq > 0: # buffer updates according to --update-freq trainer.train_step(sample, update_params=False, last_step=(i == len(itr) - 1)) continue else: log_output = trainer.train_step(sample, update_params=True, last_step=(i == len(itr) - 1)) # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'sample_size']: continue # these are already logged above if 'loss' in k: extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() if args.profile is not None and i == args.profile: import sys sys.exit() num_updates = trainer.get_num_updates() if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0: valid_losses = validate(args, trainer, task, epoch_itr, [first_valid]) save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break #end = time.time() #inside = end - begin # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip' ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
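# The --update-freq handling above buffers gradients for update_freq - 1 batches
# and only steps the optimizer on the last one. A self-contained PyTorch sketch
# of the same idea (model, optimizer and batches are hypothetical placeholders,
# not the trainer API used above):
import torch
import torch.nn.functional as F

def accumulate_train_steps(model, optimizer, batches, update_freq):
    """Apply one optimizer step for every `update_freq` mini-batches."""
    optimizer.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = F.mse_loss(model(x), y)
        # scale so the accumulated gradient matches one large batch
        (loss / update_freq).backward()
        if (i + 1) % update_freq == 0:
            optimizer.step()
            optimizer.zero_grad()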
def init_meters(self, args): self.meters = OrderedDict() self.meters['train_loss'] = AverageMeter() self.meters['train_nll_loss'] = AverageMeter() self.meters['train_generate_loss'] = AverageMeter() self.meters['train_generate_nll_loss'] = AverageMeter() self.meters['train_predict_loss'] = AverageMeter() self.meters['train_predict_nll_loss'] = AverageMeter() self.meters['valid_loss'] = AverageMeter() self.meters['valid_nll_loss'] = AverageMeter() self.meters['valid_generate_loss'] = AverageMeter() self.meters['valid_generate_nll_loss'] = AverageMeter() self.meters['valid_predict_loss'] = AverageMeter() self.meters['valid_predict_nll_loss'] = AverageMeter() self.meters['wps'] = TimeMeter() # words per second self.meters['ups'] = TimeMeter() # updates per second self.meters['wpb'] = AverageMeter() # words per batch self.meters['bsz'] = AverageMeter() # sentences per batch self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['oom'] = AverageMeter() # out of memory self.meters['src_tokens'] = [] # self.meters['target_tokens'] = [] # self.meters['select_retrive_tokens'] = [] # self.meters['loss_weight'] = [] # if args.fp16: self.meters['loss_scale'] = AverageMeter() # dynamic loss scale self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
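# 'wps' and 'ups' above are TimeMeters, i.e. rate meters that divide an
# accumulated count by elapsed wall time. A rough sketch of that behaviour
# (an assumption for illustration, not the fairseq class):
import time

class RateMeter:
    """Counts events and reports a rate per second since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, n=1):
        self.n += n

    @property
    def avg(self):
        elapsed = time.time() - self.start
        return self.n / elapsed if elapsed > 0 else 0.0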
def main(args, init_distributed=False):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    print(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        if "target" not in valid_sub_split:
            task.load_dataset(valid_sub_split, combine=False, epoch=0)
            # task.load_dataset("target_" + valid_sub_split, combine=False, epoch=0)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    print(model)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    assert isinstance(model, XlmrTransformerEncoderDecoder)

    # encoder = model.encoder
    # args.task = 'semparse_classification'
    # adv_task = tasks.setup_task(args, xlmr=task.xlmr)
    # assert isinstance(adv_task, SemparseClassificationTask)
    #
    # # Build adversarial language critic model and criterion (WGAN-GP)
    # adv_model = adv_task.build_model(args)
    # adv_criterion = adv_task.build_criterion(args)
    # print(adv_model)
    # print('| model {}, criterion {}'.format(args.arch, adv_criterion.__class__.__name__))
    # print('| num. model params: {} (num. trained: {})'.format(
    #     sum(p.numel() for p in adv_model.parameters()),
    #     sum(p.numel() for p in adv_model.parameters() if p.requires_grad),
    # ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    # adv_trainer = Trainer(args, adv_task, adv_model, adv_criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    grad_meter = AverageMeter()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    if args.plot_features:
        valid_losses, accuracy = validate(args, trainer, task, epoch_itr, valid_subsets)
        plot_features(args, trainer, task, epoch_itr, valid_subsets, accuracy=accuracy)
    if args.validate_first:
        validate(args, trainer, task, epoch_itr, valid_subsets)
    while (
        lr > args.min_lr
        and (epoch_itr.epoch < max_epoch
             or (epoch_itr.epoch == max_epoch and epoch_itr._next_epoch_itr is not None))
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr, grad_meter)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses, accuracy = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        reload_dataset = ':' in getattr(args, 'data', '')
        # sharded data: get train iterator for next epoch
        epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))
def eval_tune_loss(args, trainer, task, subset, extra_state): """Evaluate the model on the validation set and return the average loss.""" # Initialize dataloader itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions()), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args=args, iterator=itr, epoch=extra_state["epoch"], prefix=f"valid on '{subset}' subset", no_progress_bar="simple", ) # reset validation loss meters for k in ["valid_loss", "valid_nll_loss"]: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) for sample in progress: log_output = trainer.valid_step(sample) # log mid-validation stats stats = get_valid_stats(trainer) for k, v in log_output.items(): if k in [ "loss", "nll_loss", "ntokens", "nsentences", "sample_size" ]: continue if "loss" in k: extra_meters[k].update(v, log_output["sample_size"]) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # log validation stats stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) extra_state["tune_eval"]["loss"] = stats["valid_loss"] extra_state["tune_eval"]["perplexity"] = stats["valid_ppl"] if (extra_state["tune_eval"]["lowest_loss"] is None or extra_state["tune_eval"]["loss"] < extra_state["tune_eval"]["lowest_loss"]): extra_state["tune_eval"]["lowest_loss"] = extra_state["tune_eval"][ "loss"] extra_state["tune_eval"]["num_since_best"] = 0 else: extra_state["tune_eval"]["num_since_best"] += 1 stop_due_to_tune_loss = False if (args.stop_no_best_validate_loss >= 0 and extra_state["tune_eval"]["num_since_best"] > args.stop_no_best_validate_loss): stop_due_to_tune_loss = True print( f"Stopping training due to eval tune loss stagnation - last best " f"eval tune loss of {extra_state['tune_eval']['lowest_loss']} " f"(current loss: {extra_state['tune_eval']['loss']}) " f"was {extra_state['tune_eval']['num_since_best']} validations ago." ) return extra_state, stop_due_to_tune_loss
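# The stagnation check in eval_tune_loss() is a patience rule: stop once the tune
# loss has gone several consecutive validations without a new best. A
# self-contained sketch of that rule (hypothetical helper mirroring the
# extra_state["tune_eval"] bookkeeping, not the function's actual API):
def should_stop(loss_history, patience):
    """Return True if the last `patience` losses never improved on the earlier best."""
    if patience <= 0 or len(loss_history) <= patience:
        return False
    best_so_far = min(loss_history[:-patience])
    return min(loss_history[-patience:]) >= best_so_far

# e.g. should_stop([4.1, 3.8, 3.7, 3.75, 3.9], patience=2) -> True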
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in [
                'loss', 'nll_loss', 'loss_sen_piece', 'nll_loss_sen_piece',
                'overall_loss', 'overall_nll_loss', 'ntokens', 'ntokens_sen_piece',
                'nsentences', 'sample_size', 'sample_size_sen_piece', 'sample_size_overall'
            ]:
                continue  # these are already logged above
            # weight each extra loss by the sample size it was computed over,
            # checking the most specific key first so each meter is updated once
            if 'overall_loss' in k:
                extra_meters[k].update(v, log_output['sample_size_overall'] / 2.0)
            elif 'loss_sen_piece' in k:
                extra_meters[k].update(v, log_output['sample_size_sen_piece'])
            elif 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses, valid_losses_sen_piece, valid_overall_losses = validate(
                args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0],
                            valid_losses_sen_piece[0], valid_overall_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'train_loss_sen_piece', 'train_nll_loss_sen_piece',
        'train_overall_loss', 'train_overall_nll_loss', 'wps', 'ups', 'wpb', 'bsz',
        'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr): """Train the model for one epoch.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr() progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple') # update parameters every N batches if epoch_itr.epoch <= len(args.update_freq): update_freq = args.update_freq[epoch_itr.epoch - 1] else: update_freq = args.update_freq[-1] extra_meters = collections.defaultdict(lambda: AverageMeter()) first_valid = args.valid_subset.split(',')[0] max_update = args.max_update or math.inf num_batches = len(epoch_itr) for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch): if i < num_batches - 1 and (i + 1) % update_freq > 0: # buffer updates according to --update-freq trainer.train_step(sample, update_params=False) continue else: log_output = trainer.train_step(sample, update_params=True) # log mid-epoch stats stats = get_training_stats(trainer) for k, v in log_output.items(): if k in ['loss', 'nll_loss', 'sample_size']: continue # these are already logged above if 'loss' in k: extra_meters[k].update(v, log_output['sample_size']) else: extra_meters[k].update(v) stats[k] = extra_meters[k].avg progress.log(stats) # ignore the first mini-batch in words-per-second calculation if i == 0: trainer.get_meter('wps').reset() num_updates = trainer.get_num_updates() if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0: valid_losses = validate(args, trainer, task, epoch_itr, [first_valid]) save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) if num_updates >= max_update: break # log end-of-epoch stats stats = get_training_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) # reset training meters for k in [ 'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip' ]: meter = trainer.get_meter(k) if meter is not None: meter.reset()
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" valid_losses = [] valid_losses_sen_piece = [] valid_overall_losses = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple') # reset validation loss meters for k in [ 'valid_loss', 'valid_nll_loss', 'valid_loss_sen_piece', 'valid_nll_loss_sen_piece', 'valid_overall_loss', 'valid_overall_nll_loss' ]: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) for sample in progress: log_output = trainer.valid_step(sample) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'loss_sen_piece', 'nll_loss_sen_piece', 'overall_loss', 'overall_nll_loss', 'ntokens', 'ntokens_sen_piece', 'nsentences', 'sample_size', 'sample_size_sen_piece', 'sample_size_overall' ]: continue extra_meters[k].update(v) # log validation stats stats = get_valid_stats(trainer) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats) valid_losses.append(stats['valid_loss']) valid_losses_sen_piece.append(stats['valid_loss_sen_piece']) valid_overall_losses.append(stats['valid_overall_loss']) return valid_losses, valid_losses_sen_piece, valid_overall_losses
def validate(args, trainer, task, epoch_itr, subsets): """Evaluate the model on the validation set(s) and return the losses.""" if args.fixed_validation_seed is not None: # set fixed seed for every validation utils.set_torch_seed(args.fixed_validation_seed) valid_losses = [] for subset in subsets: # Initialize data iterator itr = task.get_batch_iterator( dataset=task.dataset(subset), max_tokens=args.max_tokens_valid, max_sentences=args.max_sentences_valid, max_positions=utils.resolve_max_positions( task.max_positions(), trainer.get_model().max_positions(), ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) progress = progress_bar.build_progress_bar( args, itr, epoch_itr.epoch, prefix='valid on \'{}\' subset'.format(subset), no_progress_bar='simple') # reset validation loss meters for k in ['valid_loss', 'valid_nll_loss']: meter = trainer.get_meter(k) if meter is not None: meter.reset() extra_meters = collections.defaultdict(lambda: AverageMeter()) if callable(getattr(trainer.criterion, 'set_valid_tgt_dataset', None)): trainer.criterion.set_valid_tgt_dataset(task.dataset(subset).tgt) for sample in progress: log_output = trainer.valid_step(sample) for k, v in log_output.items(): if k in [ 'loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size', 'word_count', 'char_count' ]: continue if k == 'word_error': extra_meters['wer'].update( float(v) / log_output['word_count'] * 100, log_output['word_count']) elif k == 'char_error': extra_meters['cer'].update( float(v) / log_output['char_count'] * 100, log_output['char_count']) else: extra_meters[k].update(v) # log validation stats stats = get_valid_stats(trainer, args, extra_meters) for k, meter in extra_meters.items(): stats[k] = meter.avg progress.print(stats, tag=subset, step=trainer.get_num_updates()) valid_losses.append(stats[args.best_checkpoint_metric].avg if args. best_checkpoint_metric == 'loss' else stats[args.best_checkpoint_metric]) return valid_losses
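# The word_error / char_error handling above turns accumulated error counts into
# corpus-level WER and CER percentages, weighting each batch by its word or
# character count. A tiny sketch of that aggregation (hypothetical helper with
# made-up numbers in the example):
def corpus_error_rate(batch_errors, batch_counts):
    """Aggregate per-batch error counts into a corpus-level error rate in percent."""
    total_errors = sum(batch_errors)
    total_count = sum(batch_counts)
    return 100.0 * total_errors / total_count if total_count > 0 else 0.0

# e.g. corpus_error_rate([3, 5], [40, 60]) == 8.0  (8 errors over 100 words)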
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=False,  # edited 5-5-2020 for not shuffling
    )
    print('iterator in training:')
    print(itr.iterable.batch_sampler)
    itr = iterators.GroupedIterator(itr, update_freq)
    print('iterator in training-2:')
    print(itr)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )
    print('iterator in training-2**:')
    print(itr)

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    print('iterator in training-3:')
    print(epoch_itr.iterations_in_epoch)
    print('progress:')
    print(progress)
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Christine (6-5-2020)
        # print('samples:')
        # print(samples)
        dtype = samples[0]['net_input']['src_tokens'].dtype
        # print('dtype:')
        # print(dtype)
        deleted_batches = samples[0]['deleted']
        task.initiate_memory(i, deleted_batches, trainer, dtype)
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0 and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    if itr:
        print('iterator is not empty')
    else:
        print('iterator is empty')

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in [
        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()