def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=8,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats)

        valid_losses.append(stats['valid_loss'])
    return valid_losses
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss."""
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple',
    )

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    return stats['valid_loss']
def __init__(self, args, task, model, criterion, dummy_batch):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    self.args = args
    self.task = task

    # copy model and criterion to current device
    self.criterion = criterion.cuda()
    if args.fp16:
        self._model = model.half().cuda()
    else:
        self._model = model.cuda()

    # initialize meters
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    if args.fp16:
        self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
    self.meters['wall'] = TimeMeter()             # wall time in seconds
    self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    self._dummy_batch = dummy_batch
    self._num_updates = 0
    self._optim_history = None
    self._optimizer = None
    self._wrapped_model = None
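# A minimal sketch (assumed, not the fairseq implementation) of the meter
# interfaces the trainers above rely on: AverageMeter keeps a running
# weighted mean, TimeMeter a units-per-second rate since its last reset.
import time


class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum, self.count = 0.0, 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


class TimeMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.init = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.init

    @property
    def avg(self):
        # units (e.g. words) per second since the last reset
        return self.n / self.elapsed_time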
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
    )
    loss_meter = AverageMeter()
    rouge_greedy_meter = AverageMeter()
    rouge_sampled_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss, mean_rouge_greedy, mean_rouge_sampled = trainer.valid_step(
                sample, criterion)
            loss_meter.update(loss, ntokens)
            rouge_greedy_meter.update(mean_rouge_greedy, 1)
            rouge_sampled_meter.update(mean_rouge_sampled, 1)
            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
                ('ROUGE-L/f (greedy)', '{:.4f}'.format(rouge_greedy_meter.avg)),
                ('ROUGE-L/f (sampled)', '{:.4f}'.format(rouge_sampled_meter.avg)),
            ]))

        val_loss = loss_meter.avg
        t.write(
            desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
            ' | ROUGE-L (greedy): {:.4f} | ROUGE-L (sampled): {:.4f}'.format(
                val_loss, math.pow(2, val_loss),
                rouge_greedy_meter.avg, rouge_sampled_meter.avg))

    # return the validation loss
    return val_loss
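# Hedged sketch of data.skip_group_enumerator as used above (the real helper
# lives in the early fairseq data module; this only illustrates the assumed
# semantics): consecutive samples are grouped into lists of `ngpus` items,
# one shard per device, and an initial `offset` of groups can be skipped
# when resuming mid-epoch.
def skip_group_enumerator(itr, ngpus, offset=0):
    group, i = [], 0
    for sample in itr:
        group.append(sample)
        if len(group) == ngpus:
            if i >= offset:
                yield i, group
            group = []
            i += 1
    if group and i >= offset:
        yield i, group  # trailing partial group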
def __init__(self, args, task, model, criterion, allreduce_communicators=None):
    super().__init__(args, task, model, criterion, allreduce_communicators)

    # convert model to FP16 (but keep criterion FP32)
    self.model.half()

    # dynamically scale loss to reduce overflow
    self.scaler = DynamicLossScaler(init_scale=2.**7)
    self.meters['loss_scale'] = AverageMeter()
    # FIXME: Add more meters
    self.grad_denom = 1.0

    assert not self.args.enable_parallel_backward_allred_opt, \
        "--distributed-weight-update cannot be combined with --enable-parallel-backward-allred-opt"
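# Hedged sketch of the dynamic loss-scaling policy DynamicLossScaler is
# assumed to implement: halve the scale on gradient overflow (inf/NaN),
# double it after a long enough overflow-free window. The constants are
# illustrative, not necessarily fairseq's.
class DynamicLossScaler:
    def __init__(self, init_scale=2.**7, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._ok_steps = 0  # updates since the last overflow

    @staticmethod
    def has_overflow(grad_norm):
        return grad_norm == float('inf') or grad_norm != grad_norm  # inf or NaN

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._ok_steps = 0
        else:
            self._ok_steps += 1
            if self._ok_steps % self.scale_window == 0:
                self.loss_scale *= self.scale_factor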
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    for subset in subsets:
        itr = data_utils.get_epoch_iterator(
            task,
            task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=None,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            num_workers=args.num_workers,
            seed=args.seed,
            epoch=epoch_itr.epoch,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on \'{}\' subset'.format(subset),
            no_progress_bar='simple',
        )

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        # log validation stats
        stats = get_valid_stats(trainer, args, extra_meters)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args.best_checkpoint_metric].avg
            if args.best_checkpoint_metric == 'loss'
            else stats[args.best_checkpoint_metric]
        )
    return valid_losses
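# A minimal sketch of the resolve_max_positions helper used above, assuming
# every constraint is None, a number, or a tuple of the same arity: it keeps
# the elementwise minimum so batching respects both the task's and the
# model's position limits.
def resolve_max_positions(*constraints):
    resolved = None
    for c in constraints:
        if c is None:
            continue
        if resolved is None:
            resolved = c
        elif isinstance(c, tuple):
            resolved = tuple(min(a, b) for a, b in zip(resolved, c))
        else:
            resolved = min(resolved, c)
    return resolved

# e.g. resolve_max_positions((1024, 1024), (512, 2048)) == (512, 1024)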
def rollout_critic(self, num_rollouts, samples):
    masked, unmasked, lengths, mask = samples
    batch_size, seq_len = samples[0].size()
    meter = AverageMeter()

    pbar = _tqdm(num_rollouts, 'critic-rollout')
    for rollout in pbar:
        # clear grads every rollout, since step() also runs every rollout
        # (the original zeroed only once, silently accumulating gradients)
        self.opt.zero_grad()
        loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
        loss = loss.sum() / batch_size
        loss.backward()
        meter.update(loss.item())
        self.opt.step()

    self.logger.log("critic/loss", self.step, meter.avg)
def __init__(self, args, task, model, criterion):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    self.args = args

    # copy model and criterion to current device
    self.task = task
    self.model = model.cuda()
    self.criterion = criterion.cuda()

    # initialize meters
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    self.meters['wall'] = TimeMeter()      # wall time in seconds

    self._buffered_stats = defaultdict(lambda: [])
    self._flat_grads = None
    self._num_updates = 0
    self._optim_history = None
    self._optimizer = None
    self._last_step = False

    if self.args.enable_parallel_backward_allred_opt and self.args.distributed_world_size <= 1:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt is only meant for distributed training')
    if self.args.enable_parallel_backward_allred_opt and not self.args.fp16:
        raise RuntimeError(
            '--enable-parallel-backward-allred-opt only works with FP16 training')
def setup_epoch(args, epoch_itr, trainer):
    """Sets up data and progress meters for one epoch."""
    # Initialize dataloader, starting at batch_offset
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar="simple")

    # reset training meters
    for k in ["train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz", "clip"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    return itr, progress, extra_meters
def valid_loop_fn(args, device, trainer, progress, loader, last_batch_index):
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(loader):
        if i == last_batch_index:
            # last batches are of different size, will cause recompilations
            break
        log_output = trainer.valid_step(sample)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue
            extra_meters[k].update(v)
    stats = get_valid_stats(trainer, args)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    return stats
def __init__(self, args, task, model, criterion):
    super().__init__(args, task, model, criterion)

    # convert model to FP16 (but keep criterion FP32)
    self.model.half()

    # dynamically scale loss to reduce overflow
    self.scaler = DynamicLossScaler(init_scale=2.**7)
    self.meters['loss_scale'] = AverageMeter()
    self.grad_denom = 1.0

    if self.args.enable_parallel_backward_allred_opt:
        import numpy as np

        self._reduction_stream = torch.cuda.Stream()

        self._flat_grads_parallel = torch.tensor([], dtype=torch.float16).cuda()
        self._grads_info = []
        grads_size = 0
        p_offset = 0
        for p_i, p in enumerate([p for p in self.model.parameters() if p.requires_grad]):
            p_grads_size = np.prod(list(p.size()))
            grads_size += p_grads_size

            # register hooks; the wrapper pins the loop variables for the closure
            def wrapper(param, param_i, param_grads_size, param_offset):
                def allreduce_hook(grad):
                    self._do_allreduce(param_i, param_grads_size, param_offset, grad)

                if param.requires_grad:
                    param.register_hook(allreduce_hook)

            self._grads_info.append({
                "param_grads_size": p_grads_size,
                "param_offset": p_offset,
            })
            wrapper(p, p_i, p_grads_size, p_offset)
            p_offset += p_grads_size
        self._flat_grads_parallel.resize_(grads_size)

        self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold
        print("| parallel all-reduce ENABLED. all-reduce threshold: "
              + str(self._allreduce_flush_min_threshold))
        self._grads_generated = [False] * len(self._grads_info)
        self._allreduce_processed_idx = len(self._grads_info) - 1

        if self.args.enable_parallel_backward_allred_opt_correctness_check:
            self._num_grads_generated = 0
            self._all_grads_generated = False
            self._allreduce_schedule = []
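# Minimal sketch of the per-parameter hook mechanics the constructor above
# builds on: PyTorch invokes a registered hook with each parameter's gradient
# during backward, so grads can be copied into one flat buffer (and
# all-reduced) while the rest of backward is still running. The model and
# buffer here are illustrative, not the trainer's actual state.
import torch

model = torch.nn.Linear(4, 2)
flat = torch.zeros(sum(p.numel() for p in model.parameters()))

offset = 0
for p in model.parameters():
    def make_hook(n=p.numel(), off=offset):  # default args pin current values
        def hook(grad):
            flat[off:off + n].copy_(grad.detach().reshape(-1))
        return hook
    p.register_hook(make_hook())
    offset += p.numel()

model(torch.randn(3, 4)).sum().backward()  # hooks fill `flat` with gradients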
def __init__(self, args, task, model, criterion):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    self.args = args

    # copy model and criterion to current device
    self.task = task
    self.model = model.cuda()
    self.criterion = criterion.cuda()

    # initialize meters
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    self.meters['wall'] = TimeMeter()      # wall time in seconds

    self._buffered_stats = defaultdict(lambda: [])
    self._flat_grads = None
    self._num_updates = 0
    self._optim_history = None
    self._optimizer = None

    if self.args.use_ema:
        print('Use ema.')
        from fairseq.utils import EMA
        self._ema = EMA()
        self._backup = {}
        self._init_ema()
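# The EMA helper imported above comes from a fairseq fork whose API is not
# shown; this is a hedged sketch of the usual pattern (shadow parameters
# updated as an exponential moving average, swapped in for evaluation).
import torch


class EMA:
    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, model):
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    def step(self, model):
        # shadow <- decay * shadow + (1 - decay) * param
        with torch.no_grad():
            for name, p in model.named_parameters():
                if name in self.shadow:
                    self.shadow[name].mul_(self.decay).add_(p, alpha=1 - self.decay)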
def setup_epoch(args, epoch, batch_offset, trainer, dataset):
    """Sets up data and progress meters for one epoch."""
    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
        min(args.max_target_positions, trainer.get_model().max_decoder_positions()),
    )

    # Initialize dataloader, starting at batch_offset
    itr = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar="simple")
    itr = itertools.islice(progress, batch_offset, None)

    # reset training meters
    for k in ["train_loss", "train_nll_loss", "wps", "ups", "wpb", "bsz", "clip"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    return itr, progress, extra_meters
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    progress = initialize_loader_for_epoch(args, epoch_itr)
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second and updates-per-second calculation
        reset_perf_training_meters(trainer, i)

        num_updates = trainer.get_num_updates()
        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])
    reset_training_meters(trainer)
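# Hedged sketch of the outer loop that typically drives train() and
# validate() above. Names follow the surrounding snippets; trainer.lr_step
# (epoch-level LR annealing on the validation loss) is an assumption about
# the trainer API.
import math


def train_loop(args, trainer, task, epoch_itr):
    valid_subsets = args.valid_subset.split(',')
    max_epoch = args.max_epoch or math.inf
    while trainer.get_lr() > args.min_lr and epoch_itr.epoch < max_epoch:
        train(args, trainer, task, epoch_itr)  # one pass over the data
        valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])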
def __init__(self, args, task, model, criterion):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    self.args = args

    # copy model and criterion to current device
    self.task = task
    self.model = model.cuda()
    self.criterion = criterion.cuda()

    # initialize meters
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    self.meters['wall'] = TimeMeter()      # wall time in seconds

    self._buffered_stats = defaultdict(lambda: [])
    self._flat_grads = None
    self._num_updates = 0
    self._optim_history = None
    self._optimizer = None

    self.prev_teacher_models = None  # models used to perform KD training
    self.prev_teacher_val_losses = OrderedDict()
    self.kd_teacher_weights = None
def start_epoch(self) -> bool:
    if not (self.trainer.get_lr() > self.args.min_lr
            and self.epoch_itr.epoch < self.max_epoch
            and self.trainer.get_num_updates() < self.max_update):
        self._done = True
        return False

    args = self.args
    update_freq = args.update_freq[self.epoch_itr.epoch - 1] \
        if self.epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = self.epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(self.epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)

    # meters in the epoch
    self.extra_meters = collections.defaultdict(lambda: AverageMeter())

    # enumerate
    self.itr = enumerate(itr, start=self.epoch_itr.iterations_in_epoch)
    return True
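# Hedged sketch of the iterators.GroupedIterator semantics relied on above:
# it chunks the underlying iterator into lists of `chunk_size` samples so a
# single optimizer step can span several mini-batches (--update-freq style
# gradient accumulation).
def grouped_iterator(itr, chunk_size):
    chunk = []
    for sample in itr:
        chunk.append(sample)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk  # last, possibly smaller, group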
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
    )
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg))

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # return the validation loss
    return val_loss
def __init__(self, args, model):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')

    self.args = args

    self.model = model.cuda()
    self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
    self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
    self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
    self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

    if self.args.distributed_world_size > 1:
        self.model = DDP(model)

    self._buffered_stats = defaultdict(lambda: [])
    self._num_updates = 0
    self._optim_history = None
    self.throughput_meter = TimeMeter()
    self.avg_loss_meter = AverageMeter()
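# Hedged sketch of how a train step would use the GradScaler constructed
# above; this is the standard torch.cuda.amp pattern, with a stand-in loss
# computation rather than this trainer's actual criterion call.
import torch
from torch.cuda import amp


def amp_train_step(model, criterion, optimizer, scaler, src, tgt):
    optimizer.zero_grad()
    with amp.autocast():
        loss = criterion(model(src), tgt)
    scaler.scale(loss).backward()  # backprop through the scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step on overflow
    scaler.update()                # grow/shrink the scale for the next step
    return loss.detach()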
def validate(args, trainer, dataset, subset, extra_state):
    """Evaluate the model on the validation set and return the average loss."""
    epoch = extra_state["epoch"]
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix=f"valid on '{subset}' subset",
        no_progress_bar="simple",
    )

    # reset validation loss meters
    for k in ["valid_loss", "valid_nll_loss"]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss"]:
                continue
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats["valid_loss"]
    val_ppl = stats["valid_ppl"]

    if ("validate" not in extra_state
            or val_loss < extra_state["validate"]["lowest_loss"]):
        extra_state["validate"] = {"lowest_loss": val_loss, "num_since_best": 0}
    else:
        extra_state["validate"]["num_since_best"] += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and extra_state["validate"]["num_since_best"] > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f"Stopping training due to validation score stagnation - last best "
            f"validation loss of {extra_state['validate']['lowest_loss']} "
            f"(current loss: {val_loss}) was "
            f"{extra_state['validate']['num_since_best']} validations ago."
        )
    return val_loss, val_ppl, stop_due_to_val_loss
def estimate_head_importance(args, trainer, task, epoch_itr):
    """Estimate the importance of each attention head over (part of) one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    if args.n_pruning_steps > 0:
        itr = islice(itr, args.n_pruning_steps)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # Initialize meters
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    # Initialize head importance scores
    encoder_layers = trainer.args.encoder_layers
    decoder_layers = trainer.args.decoder_layers
    encoder_heads = trainer.args.encoder_attention_heads
    decoder_heads = trainer.args.decoder_attention_heads
    device = next(trainer.model.parameters()).device
    head_importance = {
        "encoder_self": torch.zeros(encoder_layers, encoder_heads).to(device),
        "encoder_decoder": torch.zeros(decoder_layers, decoder_heads).to(device),
        "decoder_self": torch.zeros(decoder_layers, decoder_heads).to(device),
    }
    # Denominators to normalize properly
    denoms = {
        attn_type: val.clone()
        for attn_type, val in head_importance.items()
    }
    head_stats = {
        attn_type: [{} for _ in range(val.size(0))]
        for attn_type, val in head_importance.items()
    }

    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        # Compute gradients
        log_output = trainer.prune_step(samples)

        # Retrieve importance scores for the encoder
        for layer in range(encoder_layers):
            self_attn_variables = trainer.model.encoder.layers[layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_self"][layer] += importance
            denoms["encoder_self"][layer] += denom
            # Stats
            aggregate_stats(head_stats["encoder_self"][layer],
                            batch_head_stats(self_attn_variables)[0])

        # Retrieve importance scores for the decoder
        for layer in range(decoder_layers):
            # Self attention
            self_attn_variables = trainer.model.decoder.layers[layer].self_attn_variables
            importance, denom = batch_head_importance(self_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["decoder_self"][layer] += importance
            denoms["decoder_self"][layer] += denom
            aggregate_stats(head_stats["decoder_self"][layer],
                            batch_head_stats(self_attn_variables, triu_masking=True)[0])
            # Encoder attention
            encoder_attn_variables = trainer.model.decoder.layers[layer].encoder_attn_variables
            importance, denom = batch_head_importance(encoder_attn_variables,
                                                      one_minus=args.one_minus)
            head_importance["encoder_decoder"][layer] += importance
            denoms["encoder_decoder"][layer] += denom
            aggregate_stats(head_stats["encoder_decoder"][layer],
                            batch_head_stats(encoder_attn_variables)[0])

        # log mid-epoch stats
        stats = get_pruning_stats(trainer)
        for k, v in log_output.items():
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

    # log end-of-epoch stats
    stats = get_pruning_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # Normalize by type
    for attn_type in denoms:
        head_importance[attn_type] /= denoms[attn_type]
    # Normalize head stats
    for attn_type in denoms:
        for layer in range(len(head_stats[attn_type])):
            for key in head_stats[attn_type][layer]:
                head_stats[attn_type][layer][key] /= denoms[attn_type].mean().cpu()
    # Normalize by layer
    if args.normalize_by_layer:
        for layer in range(encoder_layers):
            for attn_type, importance in head_importance.items():
                head_importance[attn_type][layer] /= torch.sqrt(
                    torch.sum(importance[layer] ** 2))

    return {k: v.cpu() for k, v in head_importance.items()}, head_stats
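# Hedged sketch of one plausible batch_head_importance, following Michel et
# al. (2019): a head's importance is the absolute inner product of its output
# with the loss gradient, accumulated over positions and sentences. The
# (seq_len, batch, num_heads * head_dim) layout is an assumption, as is the
# denominator matching the `denoms` bookkeeping above.
import torch


def head_importance_from(ctx, ctx_grad, num_heads):
    seq_len, bsz, dim = ctx.shape
    head_dim = dim // num_heads
    prod = (ctx * ctx_grad).view(seq_len, bsz, num_heads, head_dim)
    importance = prod.sum(dim=-1).abs().sum(dim=(0, 1))  # (num_heads,)
    denom = seq_len * bsz
    return importance, denom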
def init_meters(self, args):
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    for domain in ['all'] + args.valid_domains:
        self.meters['valid_loss_' + domain] = AverageMeter()
        self.meters['valid_nll_loss_' + domain] = AverageMeter()
        self.meters['valid_bleu_' + domain] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    if args.fp16:
        self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
    self.meters['wall'] = TimeMeter()             # wall time in seconds
    self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr()
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')

    # update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        if i < num_batches - 1 and (i + 1) % update_freq > 0:
            # buffer updates according to --update-freq
            trainer.train_step(sample, update_params=False)
            continue
        else:
            log_output = trainer.train_step(sample, update_params=True)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def init_meters(self, args):
    self.meters = OrderedDict()
    self.meters['train_loss'] = AverageMeter()
    self.meters['train_distribution_loss'] = AverageMeter()
    self.meters['train_label_loss'] = AverageMeter()
    self.meters['train_label_acc'] = AverageMeter()
    self.meters['train_nll_loss'] = AverageMeter()
    self.meters['valid_loss'] = AverageMeter()
    self.meters['valid_nll_loss'] = AverageMeter()
    self.meters['copy_alpha'] = AverageMeter()
    self.meters['wps'] = TimeMeter()       # words per second
    self.meters['ups'] = TimeMeter()       # updates per second
    self.meters['wpb'] = AverageMeter()    # words per batch
    self.meters['bsz'] = AverageMeter()    # sentences per batch
    self.meters['gnorm'] = AverageMeter()  # gradient norm
    self.meters['clip'] = AverageMeter()   # % of updates clipped
    self.meters['oom'] = AverageMeter()    # out of memory
    if args.fp16:
        self.meters['loss_scale'] = AverageMeter()  # dynamic loss scale
    self.meters['wall'] = TimeMeter()             # wall time in seconds
    self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        try:
            log_output = trainer.train_step(samples)
        except FloatingPointError as e:
            if "Minimum loss scale reached" in str(e):
                # dump the offending batch shapes before re-raising
                print(f'Check samples: len={len(samples)}')
                for ik, s in enumerate(samples):
                    if s is None:
                        print(f'[{ik}]: None')
                    else:
                        for k, v in s.items():
                            if isinstance(v, torch.Tensor):
                                print(f'[{ik}][{k}]: {v.size()}')
            raise e
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""
    itr = dataset.dataloader(
        args.train_subset,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        valid_batch_size=args.valid_batch_size,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
    )
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(fmt.format(
            loss_meter.avg, math.pow(2, loss_meter.avg),
            round(wps_meter.elapsed_time), round(wps_meter.avg),
            round(wpb_meter.avg), round(bsz_meter.avg),
            lr, clip_meter.avg * 100, gnorm_meter.avg))
def validate(args, trainer, dataset, subset, epoch):
    """Evaluate the model on the validation set and return the average loss."""
    # Initialize dataloader
    max_positions_valid = (
        trainer.get_model().max_encoder_positions(),
        trainer.get_model().max_decoder_positions(),
    )
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions_valid,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(
        args, itr, epoch,
        prefix=f'valid on \'{subset}\' subset',
        no_progress_bar='simple',
    )

    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    for sample in progress:
        log_output = trainer.valid_step(sample)

        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

    # log validation stats
    stats = get_valid_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    val_loss = stats['valid_loss']
    val_ppl = stats['valid_ppl']

    # track the best validation loss across calls via function attributes
    if not hasattr(validate, 'lowest_loss') or val_loss < validate.lowest_loss:
        validate.lowest_loss = val_loss
        validate.num_since_best = 0
    elif not hasattr(validate, 'num_since_best'):
        validate.num_since_best = 1
    else:
        validate.num_since_best += 1

    stop_due_to_val_loss = False
    if (args.stop_no_best_validate_loss >= 0
            and validate.num_since_best > args.stop_no_best_validate_loss):
        stop_due_to_val_loss = True
        print(
            f'Stopping training due to validation score stagnation - last best '
            f'validation loss of {validate.lowest_loss} (current loss: {val_loss}) '
            f'was {validate.num_since_best} validations ago.')
    return val_loss, val_ppl, stop_due_to_val_loss
def train(args, trainer, dataset, epoch, batch_offset):
    """Train the model for one epoch."""
    # Set seed based on args.seed and the epoch number so that we get
    # reproducible results when resuming from checkpoints
    seed = args.seed + epoch
    torch.manual_seed(seed)

    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
        min(args.max_target_positions, trainer.get_model().max_decoder_positions()),
    )

    # Initialize dataloader, starting at batch_offset
    itr = dataset.train_dataloader(
        args.train_subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions_train,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum),
        shard_id=args.distributed_rank,
        num_shards=args.distributed_world_size,
    )
    progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
    itr = itertools.islice(progress, batch_offset, None)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(itr, start=batch_offset):
        log_output = trainer.train_step(sample)

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss']:
                continue  # these are already logged above
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        if i == batch_offset:
            # ignore the first mini-batch in words-per-second calculation
            trainer.get_meter('wps').reset()

        # save mid-epoch checkpoints
        if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
            save_checkpoint(trainer, args, epoch, i + 1)

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)
def train(args, trainer, task, epoch_itr, summary_writer=None):
    """Train the model for one epoch."""
    # Update parameters every N batches
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(',')[0]
    max_update = args.max_update or math.inf
    num_batches = len(epoch_itr)
    distributed_utils.barrier(args, "train_%d" % trainer.get_num_updates())
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        stats['progress'] = round(
            i / num_batches * args.distributed_world_size * args.update_freq[-1], 3)
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
            distributed_utils.barrier(args, "train_val_%d" % trainer.get_num_updates())

        # summary_writer defaults to None, so guard before logging
        if summary_writer is not None and num_updates % args.log_interval == 0:
            summary_writer.log_stats('train', stats, num_updates)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches
    update_freq = args.update_freq[epoch_itr.epoch - 1] \
        if epoch_itr.epoch <= len(args.update_freq) else args.update_freq[-1]

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples, epoch_itr=epoch_itr)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k or k == 'accuracy':
                extra_meters[k].update(v, log_output['sample_size'])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats, tag='train', step=stats['num_updates'])

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter('wps').reset()

        num_updates = trainer.get_num_updates()
        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0 and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats, tag='train', step=stats['num_updates'])

    # reset training meters
    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
def validate(args, trainer, task, epoch_itr, subsets, test_bleu=False,
             summary_writer=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    distributed_utils.barrier(args, "validate1_%d" % trainer.get_num_updates())
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=8,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args, itr, epoch_itr.epoch,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple',
            )
            return progress

        progress = get_itr()
        num_dataset = task.dataset(subset).num_dataset

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)

        bleu_scorers = [
            bleu.Scorer(task.target_dictionary.pad(),
                        task.target_dictionary.eos(),
                        task.target_dictionary.unk())
            for _ in range(num_dataset)
        ] if test_bleu else None

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        if bleu_scorers is not None:
            # test BLEU
            print("| test bleu.")
            sample_size = [0 for _ in range(num_dataset)]
            bleu_scores = [0 for _ in range(num_dataset)]
            progress = get_itr()

            tgt_str_files = []
            hypo_str_files = []
            for ds_id in range(num_dataset):
                tgt_str_path = task.dataset(subset).dataset_names[ds_id] + '.tgt.txt'
                hypo_str_path = task.dataset(subset).dataset_names[ds_id] + '.hypo.txt'
                tgt_str_files.append(
                    open(os.path.join(args.save_dir, tgt_str_path), 'w', encoding='utf-8'))
                hypo_str_files.append(
                    open(os.path.join(args.save_dir, hypo_str_path), 'w', encoding='utf-8'))

            def print_to_file(dataset_id, tgt_str, hypo_str):
                tgt_str_files[dataset_id].write(tgt_str + '\n')
                hypo_str_files[dataset_id].write(hypo_str + '\n')

            for sample in progress:
                trainer.test_bleu_step(sample, bleu_scorers, print_to_file)
                if 'dataset_id' in sample:
                    for ds_id in range(num_dataset):
                        sample_size[ds_id] += (
                            sample['dataset_id'] == ds_id).int().sum().item()
                elif 'id' in sample:
                    sample_size[0] += len(sample['id'])

            for f in tgt_str_files + hypo_str_files:
                f.close()

            distributed_utils.barrier(args, "validate2_%d" % trainer.get_num_updates())

            for ds_id in range(num_dataset):
                try:
                    bleu_scores[ds_id] = bleu_scorers[ds_id].score() * sample_size[ds_id]
                except Exception:
                    bleu_scores[ds_id] = 0
            sample_size = torch.Tensor(sample_size).cuda()
            bleu_scores = torch.Tensor(bleu_scores).cuda()
            if args.distributed_world_size > 1:
                all_reduce(sample_size)
                all_reduce(bleu_scores)

            bleu_dict = {}
            for ds_id in range(num_dataset):
                if sample_size[ds_id].item() > 0:
                    name = "bleu_" + task.dataset(subset).dataset_names[ds_id]
                    bleu_dict[name] = stats[name] = (
                        bleu_scores[ds_id].item() / sample_size[ds_id].item())
                    try:
                        train_ds_id = task.dataset('train').dataset_names.index(
                            task.dataset(subset).dataset_names[ds_id])
                        task.dataset('train').student_scores[train_ds_id] = bleu_dict[name]
                    except ValueError:
                        pass
            output_path = os.path.join(args.save_dir, 'val_bleu.json')
            json.dump(bleu_dict, open(output_path, 'w'))

        progress.print(stats)
        if summary_writer is not None:
            summary_writer.log_stats('val/' + subset, stats,
                                     trainer.get_num_updates())

        valid_losses.append(stats['valid_loss'])
    return valid_losses