def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # return the average validation loss
    return val_loss
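# NOTE: AverageMeter is used throughout this file but not defined here.
# Below is a minimal sketch of the helper these functions assume (a running
# weighted average with .val/.sum/.count/.avg, matching usages such as
# `loss_meter.update(loss, ntokens)` and `nll_loss_meter.count > 0`); the
# project's actual implementation may differ in detail.
class AverageMeter(object):
    """Computes and stores a running (optionally weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # record the latest value and fold it into the weighted average
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count > 0 else 0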
def rollout_generator(self, num_rollouts, samples):
    masked, unmasked, lengths, mask = samples
    batch_size, seq_len = samples[0].size()
    meter = AverageMeter()
    ppl_meter = defaultdict(lambda: AverageMeter())

    self.opt.zero_grad()
    pbar = _tqdm(num_rollouts, 'generator-rollout')
    for rollout in pbar:
        loss, generated, ppl = self.model(masked, lengths, mask, unmasked,
                                          tag="g-step")
        loss = loss.sum() / batch_size
        loss.backward()
        meter.update(-1 * loss.item())
        # for key in ppl:
        #     ppl[key] = ppl[key].sum() / batch_size
        #     ppl_meter[key].update(ppl[key].item())
    self.opt.step()

    self.logger.log("generator/advantage", self.step, meter.avg)
    # for key in ppl_meter:
    #     self.logger.log("ppl/{}".format(key), ppl_meter[key].avg)
    self.debug('train', samples, generated)
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""
    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                           round(wps_meter.elapsed_time), round(wps_meter.avg),
                           round(wpb_meter.avg), round(bsz_meter.avg),
                           lr, clip_meter.avg * 100, gnorm_meter.avg))
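# NOTE: TimeMeter is the other undefined helper assumed above for the
# words-per-second statistics. A minimal sketch consistent with its usage
# here (`update(ntokens)`, `.avg`, `.elapsed_time`, `.reset()`); the real
# class may differ.
import time


class TimeMeter(object):
    """Computes an average rate (e.g. words/s) over elapsed wall-clock time."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        # guard against a zero-length interval right after reset()
        return self.n / max(self.elapsed_time, 1e-9)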
def rollout_critic(self, num_rollouts, samples):
    masked, unmasked, lengths, mask = samples
    batch_size, seq_len = samples[0].size()
    meter = AverageMeter()

    self.opt.zero_grad()
    pbar = _tqdm(num_rollouts, 'critic-rollout')
    for rollout in pbar:
        loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
        loss = loss.sum() / batch_size
        loss.backward()
        meter.update(loss.item())
    self.opt.step()

    self.logger.log("critic/loss", self.step, meter.avg)
def validate(args, epoch, trainer, dataset, max_positions, subset):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
            loss_dict = trainer.valid_step(sample)
            ntokens = sum(s['ntokens'] for s in sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
        ] + [(k, meter.avg) for k, meter in extra_meters.items()]))

    # return the average validation loss
    return loss_meter.avg
def rollout_discriminator(self, num_rollouts, samples):
    masked, unmasked, lengths, mask = samples
    real, fake = AverageMeter(), AverageMeter()
    batch_size, seq_len = samples[0].size()

    self.opt.zero_grad()
    pbar = _tqdm(num_rollouts, 'discriminator-rollout')
    for rollout in pbar:
        real_loss = self.model(masked, lengths, mask, unmasked,
                               tag="d-step", real=True)
        real_loss = real_loss.sum() / batch_size

        # sample fake sequences from the generator without tracking gradients
        with torch.no_grad():
            net_output = self.model(masked, lengths, mask, unmasked, tag="g-step")
            generated = net_output[1]

        fake_loss = self.model(masked, lengths, mask, generated,
                               tag="d-step", real=False)
        fake_loss = fake_loss.sum() / batch_size

        loss = (real_loss + fake_loss) / 2
        loss.backward()
        real.update(real_loss.item())
        fake.update(fake_loss.item())
    self.opt.step()

    self.logger.log("discriminator/real", self.step, real.avg)
    self.logger.log("discriminator/fake", self.step, fake.avg)
    self.logger.log("discriminator", self.step, real.avg + fake.avg)
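# NOTE: a hedged sketch, not part of the original code, of how the three
# rollout_* methods above might be alternated in one adversarial epoch.
# The `trainer` object, the batch iterator, and the per-phase rollout
# counts are assumptions for illustration only.
def adversarial_epoch(trainer, batch_iterator,
                      d_rollouts=1, g_rollouts=1, c_rollouts=1):
    """Alternate discriminator, generator, and critic updates over the data."""
    for samples in batch_iterator:
        # sharpen the discriminator on real vs. generated sequences first
        trainer.rollout_discriminator(d_rollouts, samples)
        # then update the generator against the current discriminator
        trainer.rollout_generator(g_rollouts, samples)
        # finally refresh the critic baseline used for the generator's advantage
        trainer.rollout_critic(c_rollouts, samples)
        trainer.step += 1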
def validate(args, epoch, trainer, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
            ] + extra_postfix), refresh=False)

        val_loss = loss_meter.avg
        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, get_perplexity(val_loss))
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)

    # return the average validation loss
    return val_loss
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.eval_dataloader(
        subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(loss_meter.avg)),
        ] + [(k, meter.avg) for k, meter in extra_meters.items()]))

    # return the average validation loss
    return loss_meter.avg
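# NOTE: get_perplexity() is referenced by several functions in this file but
# never defined here. A plausible minimal sketch, assuming losses are in
# base-2 (other functions compute ppl as math.pow(2, loss)) and guarding
# against overflow for large losses; the real helper may differ.
import math


def get_perplexity(loss):
    """Convert a base-2 cross-entropy loss into perplexity."""
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')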
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
    itr = dataset.dataloader(
        subset,
        batch_size=None,
        max_tokens=args.max_tokens,
        max_positions=args.max_positions)
    loss_meter = AverageMeter()
    rouge_greedy_meter = AverageMeter()
    rouge_sampled_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss, mean_rouge_greedy, mean_rouge_sampled = trainer.valid_step(
                sample, criterion)
            loss_meter.update(loss, ntokens)
            rouge_greedy_meter.update(mean_rouge_greedy, 1)
            rouge_sampled_meter.update(mean_rouge_sampled, 1)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
                ('ROUGE-L/f (greedy)', '{:.4f}'.format(rouge_greedy_meter.avg)),
                ('ROUGE-L/f (sampled)', '{:.4f}'.format(rouge_sampled_meter.avg)),
            ]))

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
                ' | ROUGE-L (greedy): {:.4f} | ROUGE-L (sampled): {:.4f}'.format(
                    val_loss, math.pow(2, val_loss),
                    rouge_greedy_meter.avg, rouge_sampled_meter.avg))

    # return the average validation loss
    return val_loss
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
    """Train the model for one epoch."""
    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()   # sentences per batch
    wpb_meter = AverageMeter()   # words per batch
    wps_meter = TimeMeter()      # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d}'.format(epoch)
    trainer.set_seed(args.seed + epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + extra_postfix), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
            loss_meter.avg, get_perplexity(loss_meter.avg))
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
            round(wps_meter.elapsed_time), round(wps_meter.avg),
            round(wpb_meter.avg))
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
            round(bsz_meter.avg), lr, clip_meter.avg * 100)
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""
    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()   # sentences per batch
    wpb_meter = AverageMeter()   # words per batch
    wps_meter = TimeMeter()      # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [(k, meter.avg) for k, meter in extra_meters.items()]))
def validate(val_loader, r, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    n = r.num_iterations(loader_size=len(val_loader))
    if args.num_minibatches is not None:
        n = min(n, args.num_minibatches)
    r.eval(n)
    if not is_first_stage():
        val_loader = None
    r.set_loader(val_loader)

    end = time.time()
    epoch_start_time = time.time()

    if args.no_input_pipelining:
        num_warmup_minibatches = 0
    else:
        num_warmup_minibatches = r.num_warmup_minibatches

    if args.verbose_frequency > 0:
        print("Letting in %d warm-up minibatches" % num_warmup_minibatches)
        print("Running validation for %d minibatches" % n)

    with torch.no_grad():
        for i in range(num_warmup_minibatches):
            r.run_forward()

        for i in range(n - num_warmup_minibatches):
            # perform forward pass
            r.run_forward()
            r.run_ack()

            if is_last_stage():
                output, target, loss, num_tokens = \
                    r.output, r.target, r.loss.item(), r.num_tokens()

                # measure accuracy and record loss
                # prec1, prec5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss, output.size(0))
                # top1.update(prec1[0], output.size(0))
                # top5.update(prec5[0], output.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    print('Test: [{0}][{1}/{2}]\t'
                          'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Memory: {memory:.3f} ({cached_memory:.3f})\t'
                          'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                              epoch, i, n, batch_time=batch_time, loss=losses,
                              memory=(float(torch.cuda.memory_allocated()) / 10**9),
                              cached_memory=(float(torch.cuda.memory_cached()) / 10**9)))
                    import sys
                    sys.stdout.flush()

        if is_last_stage():
            print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
                top1=top1, top5=top5))

        for i in range(num_warmup_minibatches):
            r.run_ack()

        # wait for all helper threads to complete
        r.wait()

        print('Epoch %d: %.3f seconds' % (epoch, time.time() - epoch_start_time))
        print("Epoch start time: %.3f, epoch end time: %.3f" %
              (epoch_start_time, time.time()))

    # note: top1 is never updated while the accuracy computation above is
    # commented out, so this returns 0 in that configuration
    return top1.avg
def train(train_loader, r, optimizer, epoch, lr_scheduler):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    n = 10000
    # n = r.num_iterations(loader_size=len(train_loader))
    if args.num_minibatches is not None:
        n = min(n, args.num_minibatches)
    # accumulation = 32
    accumulation = n
    n -= (n % accumulation)
    assert n % accumulation == 0
    r.train(n)
    if not is_first_stage():
        train_loader = None
    r.set_loader(train_loader)

    end = time.time()
    epoch_start_time = time.time()

    if args.no_input_pipelining:
        num_warmup_minibatches = 0
    else:
        num_warmup_minibatches = r.num_warmup_minibatches

    if args.verbose_frequency > 0:
        print("Letting in %d warm-up minibatches" % num_warmup_minibatches)
        print("Running training for %d minibatches" % n)

    r.set_loss_scale(4 / accumulation)
    total_updates = n // accumulation
    for t in range(total_updates):
        # start num_warmup_minibatches forward passes
        for i in range(num_warmup_minibatches):
            r.run_forward()

        for i in range(accumulation - num_warmup_minibatches):
            end = time.time()
            # perform forward pass
            r.run_forward()

            if is_last_stage():
                # measure accuracy and record loss
                output, target, loss, num_tokens = \
                    r.output, r.target, r.loss.item(), r.num_tokens()
                # print(loss, num_tokens)
                losses.update(loss / num_tokens / math.log(2), num_tokens)

            # perform backward pass
            r.run_backward()

            if is_last_stage():
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                epoch_time = (end - epoch_start_time) / 3600.0
                full_epoch_time = (epoch_time / float(accumulation * t + i + 1)) * float(n)

                if (t * accumulation + i + num_warmup_minibatches) % args.print_freq == 0:
                    print('Stage: [{0}] Epoch: [{1}][{2}/{3}]\t'
                          'Time({timestamp}): {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Epoch time [hr]: {epoch_time:.3f} ({full_epoch_time:.3f})\t'
                          'Memory: {memory:.3f} ({cached_memory:.3f})\t'
                          'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                              args.stage, epoch, accumulation * t + i + 1, n,
                              timestamp=time.time(), batch_time=batch_time,
                              epoch_time=epoch_time, full_epoch_time=full_epoch_time,
                              loss=losses,
                              memory=(float(torch.cuda.memory_allocated()) / 10**9),
                              cached_memory=(float(torch.cuda.memory_cached()) / 10**9)))
                    import sys
                    sys.stdout.flush()
            else:
                if i + num_warmup_minibatches == accumulation - 1 and \
                        (t * accumulation + i + num_warmup_minibatches) % args.print_freq < accumulation:
                    print('Stage: [{0}] Epoch: [{1}][{2}/{3}]\t'
                          'Memory: {memory:.3f} ({cached_memory:.3f})'.format(
                              args.stage, epoch, accumulation * t + i + 1, n,
                              memory=(float(torch.cuda.memory_allocated()) / 10**9),
                              cached_memory=(float(torch.cuda.memory_cached()) / 10**9)))
                    import sys
                    sys.stdout.flush()

            # if i == 500 and args.local_rank == 0:
            #     subprocess.Popen(['python', 'usage.py', 'gpu.log'])

        # finish remaining backward passes
        for i in range(num_warmup_minibatches):
            r.run_backward()

        # optimizer.step()
        if args.fp16:
            r.zero_grad()
        else:
            optimizer.zero_grad()
        num_updates = epoch * total_updates + t + 1
        lr_scheduler.step_update(num_updates)

    # wait for all helper threads to complete
    r.wait()

    print("Epoch %d: %.3f seconds" % (epoch, time.time() - epoch_start_time))
    print("Epoch start time: %.3f, epoch end time: %.3f" %
          (epoch_start_time, time.time()))
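# NOTE: an illustration, not from the original code, of the gradient
# accumulation arithmetic in train() above: n is truncated to a multiple of
# `accumulation`, so every optimizer update sees exactly `accumulation`
# forward/backward passes. The values below use the commented-out setting
# accumulation = 32 (the live code sets accumulation = n, i.e. one update
# per epoch).
n, accumulation = 10000, 32
n -= n % accumulation               # 10000 -> 9984
total_updates = n // accumulation   # 312 optimizer updates per epoch
assert total_updates * accumulation == n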
class DDPTrainer():
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model):
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args
        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        self._buffered_stats = defaultdict(lambda: [])
        self._num_updates = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()
        self.avg_loss_meter = AverageMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            utils.save_state(
                filename,
                self.args,
                self.get_model(),
                self.criterion,
                self.optimizer,
                self.lr_scheduler,
                self._num_updates,
                self._optim_history,
                extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            # self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is the last batch, the forward pass is skipped on some workers.
        # A batch with sample_size 0 is not accounted for in the weighted loss.
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
        }
        sample_size = sample['ntokens'] if sample is not None else 0
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list(
                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd))))

            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                'loss': sum(log.get('loss', 0) for log in logging_outputs)
                        / ntokens / math.log(2),
            }
            self.avg_loss_meter.update(info_log_data['loss'])
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1,
            }
            DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
            DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

            self.clear_buffered_stats()

    def _forward(self, sample):
        loss = None
        oom = 0
        try:
            if sample is not None:
                with amp.autocast(enabled=self.args.amp):
                    # calculate loss and sample size
                    logits, _ = self.model(**sample['net_input'])
                    target = sample['target']
                    probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                    loss = self.criterion(probs, target)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory in worker {}, skipping batch'
                      .format(self.args.distributed_rank), force=True)
                oom = 1
                loss = None
            else:
                raise e
        return loss, oom

    def _backward(self, loss):
        oom = 0
        if loss is not None:
            try:
                self.scaler.scale(loss).backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory in worker {}, skipping batch'
                          .format(self.args.distributed_rank), force=True)
                    oom = 1
                    self.zero_grad()
                else:
                    raise e
        return oom

    def _opt(self):
        # take an optimization step
        self.scaler.step(self.optimizer.optimizer)
        self.scaler.update()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()
        # forward pass
        sample = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, logging_outputs = zip(
                *distributed_utils.all_gather_list((loss, logging_output)))
        else:
            losses = [loss]
            logging_outputs = [logging_output]

        weight = sum(log.get('ntokens', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)
        return scaled_loss

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter."""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameter updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if not sample:
            return None
        return utils.move_to_cuda(sample)
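# NOTE: a hedged usage sketch for DDPTrainer above, not part of the original
# code. The argument namespace, the batch iterables, and the checkpoint
# naming are assumptions for illustration only.
def run_training(args, model, train_batches, valid_batches, max_epochs):
    trainer = DDPTrainer(args, model)
    for epoch in range(max_epochs):
        # one pass over the training data; each call buffers stats and
        # updates parameters when update_params=True
        for sample in train_batches:
            trainer.train_step(sample, update_params=True)

        # average the per-batch validation losses reported by valid_step()
        val_loss = sum(trainer.valid_step(s) for s in valid_batches) / len(valid_batches)

        # anneal the learning rate on the validation loss and checkpoint
        trainer.lr_step(epoch, val_loss)
        trainer.save_checkpoint('checkpoint{}.pt'.format(epoch), {'epoch': epoch})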
def main():
    args = parser.parse_args()

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    model = TransformerModel.build_model(args, task).cuda()
    criterion = task.build_criterion(args).cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=eval(args.adam_betas),
                                 eps=args.adam_eps,
                                 weight_decay=args.weight_decay)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])
    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=(args.max_source_positions, args.max_target_positions),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=1,
        num_shards=1,
        shard_id=0,
    )

    losses = AverageMeter()
    encoder_layer_forward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_forward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]
    encoder_layer_backward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_backward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]

    def measure_hook(forward, backward):
        def hook(module, input, output):
            for i, layer in enumerate(module.layer):
                if len(input) == 2:
                    x, _ = input
                else:
                    x, = input
                x = x.detach().clone().requires_grad_()

                # warm-up
                for _ in range(5):
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    torch.autograd.backward(out, out)

                starter = torch.cuda.Event(enable_timing=True)
                ender = torch.cuda.Event(enable_timing=True)
                for _ in range(50):
                    starter.record()
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    ender.record()
                    torch.cuda.synchronize()
                    forward[i].update(starter.elapsed_time(ender))

                    starter.record()
                    torch.autograd.backward(out, out)
                    ender.record()
                    torch.cuda.synchronize()
                    backward[i].update(starter.elapsed_time(ender))
        return hook

    for layer in model.encoder.layers:
        layer.register_forward_hook(
            measure_hook(encoder_layer_forward, encoder_layer_backward))
    for layer in model.decoder.layers:
        layer.register_forward_hook(
            measure_hook(decoder_layer_forward, decoder_layer_backward))

    embed_forward = AverageMeter()
    embed_backward = AverageMeter()

    def embed_hook(module, input, output):
        tokens, _ = input
        # warm-up
        for _ in range(5):
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            torch.autograd.backward(x, x)

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            ender.record()
            torch.cuda.synchronize()
            embed_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(x, x)
            ender.record()
            torch.cuda.synchronize()
            embed_backward.update(starter.elapsed_time(ender))

    model.encoder.register_forward_hook(embed_hook)

    linear_forward = AverageMeter()
    linear_backward = AverageMeter()

    def linear_hook(module, input, output):
        _, encode_out = input
        encode_out = encode_out.detach().clone().requires_grad_()
        # warm-up
        for _ in range(5):
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            torch.autograd.backward(out, out)

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            ender.record()
            torch.cuda.synchronize()
            linear_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(out, out)
            ender.record()
            torch.cuda.synchronize()
            linear_backward.update(starter.elapsed_time(ender))

    model.decoder.register_forward_hook(linear_hook)

    itr = epoch_itr.next_epoch_itr()
    max_positions = (args.max_source_positions, args.max_target_positions)
    for i, sample in enumerate(itr):
        sample = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
        sample = utils.move_to_cuda(sample)
        loss, _, logging_output = criterion(model, sample)
        num_tokens = logging_output['ntokens']
        losses.update(loss.item() / num_tokens / math.log(2), num_tokens)

        if i % 100 == 0:
            print('Loss: {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
            print('Time: {forward_time.avg:.3f} ({backward_time.avg:.3f}) '
                  '{forward_time_decoder.avg:.3f} ({backward_time_decoder.avg:.3f})'
                  .format(forward_time=encoder_layer_forward[0],
                          backward_time=encoder_layer_backward[0],
                          forward_time_decoder=decoder_layer_forward[-1],
                          backward_time_decoder=decoder_layer_backward[-1]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break

    stat = {i: {} for i in range(len(decoder_layer_forward))}
    for i, (f, b) in enumerate(zip(encoder_layer_forward, encoder_layer_backward)):
        stat[i]['encoder'] = {}
        stat[i]['encoder']['forward'] = f.avg
        stat[i]['encoder']['backward'] = b.avg
    for i, (f, b) in enumerate(zip(decoder_layer_forward, decoder_layer_backward)):
        stat[i]['decoder'] = {}
        stat[i]['decoder']['forward'] = f.avg
        stat[i]['decoder']['backward'] = b.avg
    stat['embed'] = {}
    stat['embed']['forward'] = embed_forward.avg
    stat['embed']['backward'] = embed_backward.avg
    stat['linear'] = {}
    stat['linear']['forward'] = linear_forward.avg
    stat['linear']['backward'] = linear_backward.avg

    with open('time.json', 'w') as file:
        json.dump(stat, file, indent=4)
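# NOTE: a small sketch, not part of the original code, showing how the
# time.json profile written by main() above could be aggregated afterwards.
# The key layout follows the `stat` dict built there (per-layer indices plus
# 'embed' and 'linear' entries, each with 'forward'/'backward' averages in
# milliseconds from cuda Event timing).
import json

with open('time.json') as f:
    stat = json.load(f)

# sum the average forward/backward latencies over all measured components
total = {'forward': 0.0, 'backward': 0.0}
for key, component in stat.items():
    if key in ('embed', 'linear'):
        parts = [component]
    else:
        # per-layer entries hold 'encoder' and/or 'decoder' sub-dicts
        parts = list(component.values())
    for part in parts:
        total['forward'] += part['forward']
        total['backward'] += part['backward']

print('total forward: {:.3f} ms | total backward: {:.3f} ms'.format(
    total['forward'], total['backward']))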
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    from fairseq.sequence_scorer import SequenceScorer
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    avg_ranks = AverageMeter()
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        all_ents = []
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            if 'ents' in sample:
                all_ents.extend(sample['ents'])
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                              file=output_file)
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(
                                lambda x: '{:.4f}'.format(x),
                                # convert from base e to base 2
                                hypo['positional_scores'].div_(math.log(2)).tolist(),
                            ))), file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(['{}-{}'.format(src_idx, tgt_idx)
                                          for src_idx, tgt_idx in alignment])),
                                file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                      file=output_file)

                        if getattr(args, 'score_reference', False):
                            print('R-{}\t{}'.format(
                                sample_id, '{:.4f}'.format(hypo['avg_ranks'])),
                                file=output_file)

                    # Score only the top hypothesis
                    if getattr(args, 'score_reference', False):
                        avg_ranks.update(hypo['avg_ranks'])
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk
                            # replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s '
                '({:.2f} sentences/s, {:.2f} tokens/s)'.format(
                    num_sentences, gen_timer.n, gen_timer.sum,
                    num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))
    if getattr(args, 'score_reference', False):
        logger.info('Average rank of reference={:.4f}, Entropy={:.4f}'.format(
            avg_ranks.avg, torch.cat(all_ents, dim=0).mean()))

    return scorer
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""
    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            # each sample group carries id, src_tokens, input_tokens,
            # input_positions, target, src_positions and ntokens
            aggregate_res = trainer.train_step(sample, criterion)
            mixed_loss = aggregate_res.loss
            ml_loss = aggregate_res.ml_loss
            rl_loss = aggregate_res.rl_loss
            grad_norm = aggregate_res.grad_norm
            mean_rouge_greedy = aggregate_res.mean_rouge_greedy
            mean_rouge_sampled = aggregate_res.mean_rouge_sampled
            mean_sum_log_prob = aggregate_res.mean_sum_log_prob

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(ml_loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(ml_loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]))

            if args.enable_rl:
                fmt_other = 'mixed_loss: {:^10.4f} | ml_loss: {:^10.4f}'
                fmt_other += ' | rl_loss: {:^10.4f} | mean_rouge_greedy: {:^10.4f}'
                fmt_other += ' | mean_rouge_sampled: {:^10.4f} | mean_sum_log_prob: {:^10.4f}'
                print(fmt_other.format(mixed_loss, ml_loss, rl_loss,
                                       mean_rouge_greedy, mean_rouge_sampled,
                                       mean_sum_log_prob))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                           round(wps_meter.elapsed_time), round(wps_meter.avg),
                           round(wpb_meter.avg), round(bsz_meter.avg),
                           lr, clip_meter.avg * 100, gnorm_meter.avg))