def evaluate(eval_iter, split, train_step=-1):
    """Evaluate the global ``model`` on ``eval_iter`` and log the mean loss.

    The model is switched to eval mode and its target/extended/memory lengths
    are changed to the evaluation configuration for the duration of the pass.
    The training lengths and training mode are restored afterwards, including
    when the pass is interrupted by an exception.

    If ``split == 'val'`` and the loss improves on the module-level
    ``best_val_loss``, a checkpoint with suffix ``'best'`` is written and
    ``best_val_loss`` is updated.

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting was reconstructed from syntax.
    NOTE(review): ``optimizer`` is neither a parameter nor assigned here, so
    it must be resolvable at module level when a checkpoint is saved.

    Args:
        eval_iter: iterable yielding ``(data, target, seq_len)`` batches.
        split: label used in the log line and in the ``log_tb`` tags.
        train_step: step number reported in the log line.
    """
    global best_val_loss
    eval_start_time = time.time()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    # Have to unwrap twice: DDP & FP16
    model_to_reset = model.module.module if args.fp16 else model.module

    total_len, total_loss = 0, 0.
    try:
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        if args.mem_len == 0:
            model_to_reset.reset_length(
                args.eval_tgt_len,
                args.ext_len + args.tgt_len - args.eval_tgt_len,
                args.mem_len)
        else:
            model_to_reset.reset_length(
                args.eval_tgt_len,
                args.ext_len,
                args.mem_len + args.tgt_len - args.eval_tgt_len)

        # Evaluation
        with torch.no_grad():
            mems = tuple()
            bar = tqdm.tqdm(eval_iter, leave=False, desc="Eval")
            for i, (data, target, seq_len) in enumerate(bar):
                # A non-positive max_eval_steps means "no limit".
                if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                    break
                ret = model(data, target, *mems)
                loss, mems = ret[0], ret[1:]
                loss = loss.mean()
                bar.set_description(f'Eval loss {loss:.2f}')
                # Weight each batch loss by its sequence length.
                total_loss += seq_len * loss.float().item()
                total_len += seq_len
    finally:
        # Switch back to the training mode, even if evaluation was
        # interrupted, so the caller never resumes with eval lengths.
        model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
        model.train()

    # Log all the things.
    mean_loss = total_loss / total_len
    logger.info('-' * 100)
    log_str = (
        f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | '
        + f'time: {time.time() - eval_start_time:5.2f}s '
        + f'| {split} loss {mean_loss:5.2f}')
    if args.dataset in ['enwik8', 'text8']:
        log_str += f' | bpc {mean_loss / math.log(2):9.5f}'
    else:
        log_str += f' | {split} ppl {math.exp(mean_loss):9.3f}'
    logger.info(log_str)
    logger.info('-' * 100)
    log_tb(f'learning/{split}_loss', mean_loss)
    log_tb(f'learning/{split}_ppl', math.exp(mean_loss))

    # Update checkpoint if validation loss improved. Compare against None
    # explicitly so a stored best loss of 0.0 is not mistaken for "unset".
    if split == 'val' and (best_val_loss is None or mean_loss < best_val_loss):
        logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        best_val_loss = mean_loss
def evaluate_and_log(optimizer, eval_iter, split, train_step=-1):
    """Evaluate the global ``model``, log the mean loss, and checkpoint on improvement.

    The wrapped ``MemTransformerLM`` is switched to the evaluation lengths,
    ``evaluate`` is called to obtain ``(total_loss, total_len)``, and the
    training lengths and training mode are restored afterwards, including
    when evaluation raises.

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting was reconstructed from syntax.

    Args:
        optimizer: optimizer passed to ``util.dist_save_checkpoint`` when a
            new best validation loss is found.
        eval_iter: evaluation iterator forwarded to ``evaluate``.
        split: label used in the log line and in the ``log_tb`` tags; the
            value ``'val'`` enables best-checkpoint saving.
        train_step: step number reported in the log line.
    """
    global best_val_loss
    eval_start_time = time.time()

    # Have to unwrap DDP & FP16, if using.
    def unwrap(module):
        if isinstance(module, MemTransformerLM):
            return module
        return unwrap(module.module)

    model_to_reset = unwrap(model)

    try:
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        if args.mem_len == 0:
            model_to_reset.reset_length(
                args.eval_tgt_len,
                args.ext_len + args.tgt_len - args.eval_tgt_len,
                args.mem_len)
        else:
            model_to_reset.reset_length(
                args.eval_tgt_len,
                args.ext_len,
                args.mem_len + args.tgt_len - args.eval_tgt_len)

        total_loss, total_len = evaluate(model, eval_iter, split, args.max_eval_steps)
    finally:
        # Switch back to the training mode, even if evaluation was
        # interrupted, so training never resumes with eval lengths.
        model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
        model.train()

    # Log all the things.
    mean_loss = total_loss / total_len
    logger.info('-' * 100)
    log_str = (
        f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | '
        + f'time: {time.time() - eval_start_time:5.2f}s '
        + f'| {split} loss {mean_loss:5.2f}')
    if args.dataset in ['enwik8', 'text8']:
        log_str += f' | bpc {mean_loss / math.log(2):9.5f}'
    else:
        log_str += f' | {split} ppl {math.exp(mean_loss):9.3f}'
    logger.info(log_str)
    logger.info('-' * 100)
    log_tb(f'learning/{split}_loss', mean_loss)
    log_tb(f'learning/{split}_ppl', math.exp(mean_loss))

    # Update checkpoint if validation loss improved. Compare against None
    # explicitly so a stored best loss of 0.0 is not mistaken for "unset".
    if split == 'val' and (best_val_loss is None or mean_loss < best_val_loss):
        logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        best_val_loss = mean_loss
def main_loop():
    """Build the model and optimizer, then train until the token budget is spent.

    Steps, in order: cancel a pending shutdown, optionally initialize the
    distributed process group, create ``g.state`` with a fresh
    ``MemTransformerLM``, restore from checkpoint(s) or apply
    ``weights_init``, wrap for FP16 / data parallelism, create the scheduler,
    and run the epoch loop. The loop ends on ``KeyboardInterrupt``, when
    ``g.state.token_count`` reaches ``args.max_tokens`` (signalled internally
    with ``StopIteration``), or when an epoch yields zero tokens.

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting (in particular the position of
    ``g.event_writer.add_text``, ``optimizer.step()`` and the custom
    validation call) was reconstructed from syntax and should be confirmed
    against version control.

    Returns:
        list: the scalar loss of every training step executed in this call.
    """
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            # init_method=args.dist_url,
            # world_size=util.get_world_size()
        )
        assert (util.get_world_size() == dist.get_world_size())
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(
        g.ntokens, args.n_layer, args.n_head, args.d_model, args.d_head,
        args.d_inner, args.dropout, args.dropatt,
        tie_weight=args.tied,
        d_embed=args.d_embed,
        div_val=args.div_val,
        tie_projs=g.tie_projs,
        pre_lnorm=args.pre_lnorm,
        tgt_len=args.tgt_len,
        ext_len=args.ext_len,
        mem_len=args.mem_len,
        cutoffs=g.cutoffs,
        same_length=args.same_length,
        attn_type=args.attn_type,
        clamp_len=args.clamp_len,
        sample_softmax=args.sample_softmax,
        freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)

    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info("restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        # Build the message in two steps: a conditional expression binds
        # looser than `+`, so the one-line form logged an empty string
        # whenever no optimizer state dict was given.
        restore_msg = f"Restoring model from {args.checkpoint}"
        if args.optim_state_dict:
            restore_msg += f" and optimizer from {args.optim_state_dict}"
        g.logger.info(restore_msg)
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)
    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of weight sharing
        g.state.model.word_emb.apply(weights_init)

    model = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2**16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler = LRFinder(optimizer, args.max_tokens,
                                     init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train',
                    rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size,
                    bptt=args.tgt_len,
                    device=g.device,
                    ext_len=args.ext_len,
                    skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break

                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-1',
                                     generate_text=False, reset_mems_interval=1)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-2',
                                     generate_text=False, reset_mems_interval=2)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-3',
                                     generate_text=False, reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model, g.va_custom_iter,
                                         g.args.valid_custom, generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(
                                g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count // 1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    # Mark the epoch as unfinished so a later call can resume
                    # with the same iterator and memories.
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']

                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr', current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer, g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len

                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb", torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb", torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb", torch.cuda.max_memory_cached() / 1e9)

                    # Start a new logging interval.
                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                # Routed through g.logger like every other message in this
                # function (previously went to the root `logging` module).
                g.logger.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
def evaluate_and_log(model: torch.nn.Module, eval_iter, split: str, generate_text: bool = True,
                     reset_mems_interval: int = None):
    """Evaluate ``model`` on ``eval_iter``, log the metrics, and checkpoint on a new best.

    The unwrapped model is switched to the evaluation lengths, ``evaluate``
    is called, and the training lengths and training mode are restored. The
    per-process loss is written to the text log; the values returned by
    ``util.dist_mean`` are written through ``log_tb``. Which extra metrics
    are reported depends on ``args.dataset``. When ``split == 'val'`` and the
    mean loss improves on ``g.state.best_val_loss``, a checkpoint with suffix
    ``'best'`` is saved.

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting was reconstructed from syntax.

    Args:
        model: the (possibly wrapped) model to evaluate.
        eval_iter: evaluation iterator forwarded to ``evaluate``.
        split: label used in the log line and in the ``log_tb`` tags.
        generate_text: accepted for call compatibility; not read in this body.
        reset_mems_interval: forwarded unchanged to ``evaluate``.
    """
    cfg = g.args
    train_state = g.state
    opt = g.state.optimizer
    started_at = time.time()

    core_model = util.unwrap_model(model)

    # Without memory the extended context absorbs the tgt_len difference;
    # with memory the memory length absorbs it and ext_len is unchanged.
    if g.args.mem_len == 0:
        eval_ext_len = cfg.ext_len + cfg.tgt_len - cfg.eval_tgt_len
        eval_mem_len = cfg.mem_len
    else:
        eval_ext_len = cfg.ext_len
        eval_mem_len = cfg.mem_len + cfg.tgt_len - cfg.eval_tgt_len
    core_model.reset_length(cfg.eval_tgt_len, eval_ext_len, eval_mem_len)

    # Collect the raw metrics.
    metrics = evaluate(model, eval_iter, split, cfg.max_eval_steps,
                       reset_mems_interval=reset_mems_interval)
    total_loss = metrics["total_loss"]
    accuracy_top1 = metrics["accuracy_top1"]
    accuracy_top5 = metrics["accuracy_top5"]
    MRR = metrics["MRR_top5"]
    total_len = metrics["total_len"]

    # Back to the training configuration.
    core_model.reset_length(cfg.tgt_len, cfg.ext_len, cfg.mem_len)
    model.train()

    # Reduce across processes; the call order is kept fixed on purpose.
    loss = total_loss / total_len
    mean_loss = util.dist_mean(loss)
    mean_accuracy_top1 = util.dist_mean(accuracy_top1)
    mean_accuracy_top5 = util.dist_mean(accuracy_top5)
    mean_MRR = util.dist_mean(MRR)

    separator = '-' * 100
    g.logger.info(separator)
    pieces = [
        f'| Eval {g.state.train_step // cfg.eval_interval:3d} at step {g.state.train_step:>8d} | ',
        f'time: {time.time() - started_at:5.2f}s ',
        f'| {split} loss {loss:5.2f}',
    ]
    log_tb(f'learning/{split}_loss', mean_loss)

    dataset = cfg.dataset
    if dataset in ['enwik8', 'text8']:
        pieces.append(f' | bpc {loss / math.log(2):9.5f}')
        log_tb(f'learning/{split}_bpc', mean_loss / math.log(2))
    elif dataset == 'git':
        pieces.append(f' | accuracy@1 {accuracy_top1:.2f} ')
        pieces.append(f'| accuracy@5 {accuracy_top5:.2f} ')
        pieces.append(f'| MRR@5 {MRR:.2f}')
        log_tb(f'learning/{split}_acc@1', mean_accuracy_top1)
        log_tb(f'learning/{split}_acc@5', mean_accuracy_top5)
        log_tb(f'learning/{split}_MRR@5', mean_MRR)
    else:
        pieces.append(f' | {split} ppl {math.exp(loss):9.3f}')
        log_tb(f'learning/{split}_ppl', math.exp(mean_loss))
    g.logger.info(''.join(pieces))
    g.logger.info(separator)

    # Only the 'val' split can produce a new best checkpoint.
    if split != 'val':
        return
    # Nothing to do when a best loss is recorded and this one is not lower.
    if train_state.best_val_loss and not mean_loss < train_state.best_val_loss:
        return
    g.logger.info('Saving checkpoint for new best loss')
    util.dist_save_checkpoint(model, opt, cfg.logdir, suffix='best')
    train_state.best_val_loss = mean_loss
def main():
    """Entry point: build model/optimizer/scheduler, train, then evaluate.

    Steps, in order: cancel a pending shutdown (skipped when
    ``args.local_rank > 0``), optionally initialize the distributed process
    group, create the ``MemTransformerLM`` and log its parameter counts,
    create the optimizer and scheduler, initialize or restore weights, wrap
    the model for FP16 / data parallelism, train epoch after epoch until
    ``KeyboardInterrupt`` or ``StopIteration``, evaluate on the validation
    iterator, reload ``model-best.pt`` from ``args.logdir`` if it exists, and
    evaluate on the test iterator.

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting was reconstructed from syntax.
    """
    global global_token_count, event_writer, train_step, train_loss, last_log_step, \
        best_val_loss, epoch, model

    # skip shutdown when rank is explicitly set + not zero rank
    if args.local_rank <= 0:
        os.system('shutdown -c')

    if not args.local:
        logger.info(
            f'Distributed initializing process group with {args.dist_backend}, {args.dist_url}, {util.get_world_size()}'
        )
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=util.get_world_size())
        assert (util.get_world_size() == dist.get_world_size())
        logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})"
        )

    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
                             args.d_head, args.d_inner, args.dropout, args.dropatt,
                             tie_weight=args.tied,
                             d_embed=args.d_embed,
                             div_val=args.div_val,
                             tie_projs=tie_projs,
                             pre_lnorm=args.pre_lnorm,
                             tgt_len=args.tgt_len,
                             ext_len=args.ext_len,
                             mem_len=args.mem_len,
                             cutoffs=cutoffs,
                             same_length=args.same_length,
                             attn_type=args.attn_type,
                             clamp_len=args.clamp_len,
                             sample_softmax=args.sample_softmax)

    # log model info
    n_all_param = sum(p.nelement() for p in model.parameters())
    log_tb('sizes/params', n_all_param)
    n_nonemb_param = sum(p.nelement() for p in model.layers.parameters())
    log_tb('sizes/non_emb_params', n_nonemb_param)
    logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # optimizer
    optim_name = args.optim.lower()
    if optim_name == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
    elif optim_name == 'lamb':
        optimizer = Lamb(model.parameters(), lr=args.lr, weight_decay=args.wd)
    else:
        assert optim_name == 'adam'
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # scheduler
    # Default to None so the name is always bound: previously the 'constant'
    # setting left `scheduler` unassigned and the later
    # `train(va_iter, optimizer, scheduler)` call raised UnboundLocalError.
    scheduler = None
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        scheduler = LRFinder(optimizer, args.max_tokens, init_value=args.lr / 1e3)

    model.apply(weights_init)
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(weights_init)

    if args.checkpoint:
        # Only the process with global rank 0 reads the checkpoint file here.
        if global_rank == 0:
            util.restore_from_checkpoint(model=model, checkpoint_fn=args.checkpoint)

    model = model.to(device)
    if args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(optimizer,
                                   static_loss_scale=args.static_loss_scale,
                                   dynamic_loss_scale=args.dynamic_loss_scale,
                                   dynamic_loss_args={'init_scale': 2**16},
                                   verbose=False)

    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if global_rank == 0:
        event_writer = SummaryWriter(args.logdir)
        event_writer.add_text('args', str(args))

    # test checkpoint writing
    # NOTE(review): `epoch` is read here before the training loop assigns it;
    # this relies on a module-level value existing — confirm.
    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{0}')

    # Loop over epochs.
    train_step = 0
    train_loss = 0
    last_log_step = 0
    best_val_loss = None
    va_iter, te_iter = [
        corpus.get_dist_iterator(split, global_rank, max_rank,
                                 args.batch_size * 2, args.tgt_len,
                                 device=device, ext_len=args.ext_len)
        for split in ('valid', 'test')
    ]

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=1):
            train(va_iter, optimizer, scheduler)
    except KeyboardInterrupt:
        logger.info('-' * 100)
        logger.info('Exiting from training early')
    except StopIteration:
        # Raised by train() to signal the end of the schedule.
        pass

    # Eval one more time.
    evaluate_and_log(optimizer, va_iter, 'val', train_step=-1)

    # Load the best saved model.
    logger.info("Loading best checkpoint")
    model_file = os.path.join(args.logdir, 'model-best.pt')
    if os.path.exists(model_file):
        with open(model_file, 'rb') as model_f:
            with timeit('load'):
                if args.local:
                    model = torch.load(model_f)
                else:
                    model = torch.load(
                        model_f,
                        map_location=lambda storage, loc: storage.cuda(args.local_rank))
                    model = DistributedDataParallel(
                        model,
                        device_ids=[args.local_rank],
                        output_device=args.local_rank)
    else:
        logger.warn('no model file, using current model for loss')

    # Run on test data.
    evaluate_and_log(optimizer, te_iter, 'test', -1)
def train(va_iter, optimizer, scheduler):
    """Train the global ``model`` for one pass over the training iterator.

    For every batch: run forward/backward, clip gradients, step the
    optimizer, update the learning rate (linear warmup, then the scheduler),
    periodically log throughput/loss statistics, and periodically run
    validation. Progress is kept in module-level counters
    (``global_token_count``, ``train_step``, ``train_loss``,
    ``last_log_step``).

    NOTE(review): this body was re-indented from a whitespace-collapsed
    source; statement nesting was reconstructed from syntax.

    Args:
        va_iter: validation iterator passed to ``evaluate_and_log``.
        optimizer: optimizer (an FP16 wrapper when ``args.fp16`` is set, as it
            is then expected to provide ``backward``, ``clip_master_grads``
            and ``overflow``).
        scheduler: learning-rate scheduler; only stepped for the 'cosine'
            setting after warmup and for settings outside
            'cosine'/'constant'/'dev_perf'.

    Raises:
        StopIteration: when ``global_token_count`` reaches ``args.max_tokens``
            and ``args.eta_min == 0``; used as the end-of-schedule signal.
    """
    global global_token_count, event_writer, train_loss, best_val_loss, \
        train_step, last_log_step, epoch

    # Turn on training mode which enables dropout.
    model.train()

    log_tb('sizes/batch_size', args.batch_size)
    log_tb('sizes/seq_size', args.tgt_len)

    tr_iter = corpus.get_dist_iterator('train', global_rank, max_rank,
                                       args.batch_size, args.tgt_len,
                                       device=device, ext_len=args.ext_len)
    mems = tuple()
    log_start_time = time.time()
    for batch, (data, target, seq_len) in enumerate(tr_iter):
        assert seq_len == data.shape[0]
        # Sanity check on the first shifted row only (the loop exits after
        # one iteration): target row i-1 must equal data row i.
        for i in range(1, data.shape[0]):
            assert torch.all(torch.eq(data[i], target[i - 1]))
            break

        # Must live on `device`: needed for NCCL sync.
        batch_total = torch.tensor(data.shape[1]).to(device)
        if args.local:
            batch_total = batch_total.sum()
        else:
            batch_total = util.dist_sum_tensor(batch_total)  # global batch size
        batch_total = util.toscalar(batch_total)

        total_tokens = batch_total * seq_len
        should_log = train_step < args.verbose_log_steps or train_step % args.log_interval == 0

        global_token_count += total_tokens
        model.zero_grad()
        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.float().mean().type_as(loss)
        with timeit('backwards', noop=not should_log):
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
        train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()

        train_step += 1

        # step-wise learning rate annealing
        if args.fp16 and optimizer.overflow:
            logger.info("skipped iteration")
        else:
            if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                # linear warmup stage
                if global_token_count < args.warmup_tokens:
                    curr_lr = args.lr * global_token_count / args.warmup_tokens
                    optimizer.param_groups[0]['lr'] = curr_lr
                elif args.scheduler == 'cosine':
                    # Divide by 1e6 for numerical stability.
                    scheduler.step(global_token_count // 1e6)
            else:
                scheduler.step(global_token_count)

        if should_log:
            elapsed_time = time.time() - log_start_time
            elapsed_steps = train_step - last_log_step

            # compute average loss over last logging interval
            cur_loss = train_loss / elapsed_steps
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} | {batch:>6d} batches | lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                      f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} | loss {cur_loss:5.2f}'
            if args.dataset in ['enwik8', 'text8']:
                log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
            else:
                log_str += f' | ppl {math.exp(cur_loss):9.3f}'
            logger.info(log_str)
            log_tb('learning/epoch', epoch)
            log_tb('_loss', cur_loss)  # the most important thing
            log_tb('learning/loss', cur_loss)
            log_tb('learning/ppl', math.exp(cur_loss))

            # currently step timings are not synchronized in multi-machine
            # case (see #4). Can add torch.distributed.barrier() to get
            # more accurate timings, but this may add slowness.
            log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
            current_lr = optimizer.param_groups[0]['lr']

            log_tb('learning/lr', current_lr)

            # 32 is the "canonical" batch size
            linear_scaling_factor = batch_total / 32
            log_tb('learning/base_lr', current_lr / linear_scaling_factor)
            # Case-insensitive, matching how main() selects the optimizer
            # with args.optim.lower().
            if args.optim.lower() == 'lamb':
                log_lamb_rs(optimizer, event_writer, global_token_count)

            time_per_batch = elapsed_time / elapsed_steps
            time_per_sample = time_per_batch / args.batch_size
            time_per_token = time_per_sample / args.tgt_len

            log_tb('times/batches_per_sec', 1 / time_per_batch)
            log_tb('times/samples_per_sec', 1 / time_per_sample)
            log_tb('times/tokens_per_sec', 1 / time_per_token)

            if str(device) == 'cuda':
                log_tb("memory/allocated_gb", torch.cuda.memory_allocated() / 1e9)
                log_tb("memory/max_allocated_gb", torch.cuda.max_memory_allocated() / 1e9)
                log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                log_tb("memory/max_cached_gb", torch.cuda.max_memory_cached() / 1e9)

            # Start a new logging interval.
            train_loss = 0
            log_start_time = time.time()
            last_log_step = train_step

        if train_step % args.eval_interval == 0:
            evaluate_and_log(optimizer, va_iter, 'val', train_step)

        if global_token_count >= args.max_tokens:
            if args.eta_min == 0:
                raise StopIteration
            logger.info('End of schedule, staying at current LR')
            args.scheduler = 'constant'

    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix=f'{epoch}')