def train(args, model, device, train_loader, optimizer, epoch, event_writer):
    model.train()
    tqdm_bar = tqdm.tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(tqdm_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            step = batch_idx * len(data) + (epoch - 1) * len(train_loader.dataset)
            log_lamb_rs(optimizer, event_writer, step)
            event_writer.add_scalar('loss', loss.item(), step)
            tqdm_bar.set_description(
                f'Train epoch {epoch} Loss: {loss.item():.6f}')
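# The helper above expects an args namespace with log_interval, a model already
# on `device`, a DataLoader, an optimizer, and a TensorBoard writer. A minimal,
# hypothetical wiring sketch follows; MNIST, Net() and Lamb() are assumptions
# for illustration only and are not defined in this file.
def example_main():
    import torch
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import datasets, transforms
    from types import SimpleNamespace

    args = SimpleNamespace(log_interval=10)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=64, shuffle=True)
    model = Net().to(device)                       # Net: assumed model class
    optimizer = Lamb(model.parameters(), lr=1e-3)  # Lamb: assumed optimizer class
    event_writer = SummaryWriter('./runs')
    for epoch in range(1, 3):
        train(args, model, device, train_loader, optimizer, epoch, event_writer)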
def main_loop():
    util.cancel_shutdown()
    losses = []

    args = g.args

    if not args.local:
        g.logger.info(
            f'Distributed initializing process group with '
            f'{args.dist_backend}, {args.dist_url}, {util.get_world_size()}')
        dist.init_process_group(
            backend=args.dist_backend,
            # init_method=args.dist_url,
            # world_size=util.get_world_size()
        )
        assert util.get_world_size() == dist.get_world_size()
        g.logger.info(
            f"Distributed: success ({args.local_rank}/{dist.get_world_size()})")

    g.logger.info("creating new model")
    g.state = TrainState(args)
    g.state.model = MemTransformerLM(g.ntokens, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner,
                                     args.dropout, args.dropatt,
                                     tie_weight=args.tied,
                                     d_embed=args.d_embed,
                                     div_val=args.div_val,
                                     tie_projs=g.tie_projs,
                                     pre_lnorm=args.pre_lnorm,
                                     tgt_len=args.tgt_len,
                                     ext_len=args.ext_len,
                                     mem_len=args.mem_len,
                                     cutoffs=g.cutoffs,
                                     same_length=args.same_length,
                                     attn_type=args.attn_type,
                                     clamp_len=args.clamp_len,
                                     sample_softmax=args.sample_softmax,
                                     freeze_below=args.freeze_below)
    g.state.model.to(g.device)
    optimizer_setup(g.state)
    if args.checkpoint:
        if args.checkpoint_secondary:
            g.logger.info("restoring extra checkpoint")
            util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                         args.checkpoint_secondary,
                                         args.optim_state_dict)
        g.logger.info(f"Restoring model from {args.checkpoint}" +
                      (f" and optimizer from {args.optim_state_dict}"
                       if args.optim_state_dict else ""))
        util.restore_from_checkpoint(g.state.model, g.state.optimizer,
                                     args.checkpoint, args.optim_state_dict)
    else:
        g.state.model.apply(weights_init)
        # ensure embedding init is not overridden by out_layer in case of
        # weight sharing
        g.state.model.word_emb.apply(weights_init)

    model: MemTransformerLM = g.state.model
    optimizer = g.state.optimizer

    if g.state.args.fp16:
        model = FP16_Module(model)
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=g.state.args.static_loss_scale,
            dynamic_loss_scale=g.state.args.dynamic_loss_scale,
            dynamic_loss_args={'init_scale': 2 ** 16},
            verbose=False)

    # log model info
    # n_all_param = sum([p.nelement() for p in model.parameters()])
    # log_tb('sizes/params', n_all_param)
    # n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
    # log_tb('sizes/non_emb_params', n_nonemb_param)
    # g.logger.info('params %s non_emb_params %s', n_all_param, n_nonemb_param)

    # scheduler
    if args.scheduler == 'cosine':
        # Divide by 1e6 for numerical stability.
        g.state.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.max_tokens // 1e6, eta_min=args.eta_min)
    elif args.scheduler == 'finder':
        g.state.scheduler: LRFinder = LRFinder(optimizer, args.max_tokens,
                                               init_value=args.lr / 1e3)
    else:
        assert args.scheduler == 'constant'
        g.state.scheduler = util.NoOp()

    # Setup distributed model
    if args.local:
        model = nn.DataParallel(model, dim=1)
    else:
        # Uncomment find_unused_parameters and upgrade to torch 1.1 for
        # adaptive embedding.
        model = DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank)  # , find_unused_parameters=True)

    if util.get_global_rank() == 0:
        if not args.test:
            wandb.config.update(vars(args))
            # wandb.watch(model)

    g.event_writer.add_text('args', str(args))  # TODO: replace with log_tb

    accumulated_loss = 0
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in itertools.count(start=g.state.last_epoch):
            print(f"epoch -- {epoch}, token_count -- {g.state.token_count}")
            model.train()

            log_tb('sizes/batch_size', args.batch_size)
            log_tb('sizes/seq_size', args.tgt_len)

            if g.state.partial_epoch:
                # reuse previously loaded tr_iter and states
                assert g.state.tr_iter is not None
                assert g.state.mems is not None
            else:
                g.state.tr_iter = g.corpus.get_dist_iterator(
                    'train', rank=util.get_global_rank(),
                    max_rank=util.get_world_size(),
                    bsz=args.batch_size, bptt=args.tgt_len, device=g.device,
                    ext_len=args.ext_len, skip_files=g.args.skip_files)
                g.state.mems = tuple()
            g.state.last_epoch = epoch

            log_start_time = time.time()
            tokens_per_epoch = 0
            for batch, (data, target, seq_len) in enumerate(g.state.tr_iter):
                # assert seq_len == data.shape[0]
                # for i in range(1, data.shape[0]):
                #     assert torch.all(torch.eq(data[i], target[i - 1]))
                #     break
                # print(g.state.token_count, data)

                if g.state.train_step % args.eval_interval == 0:
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-1',
                                     generate_text=False, reset_mems_interval=1)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-2',
                                     generate_text=False, reset_mems_interval=2)
                    evaluate_and_log(model, g.va_iter, 'val_short-mem-3',
                                     generate_text=False, reset_mems_interval=3)
                    evaluate_and_log(model, g.va_iter, 'val')
                    if g.va_custom_iter:
                        evaluate_and_log(g.state.model, g.va_custom_iter,
                                         g.args.valid_custom,
                                         generate_text=False)

                batch_total = torch.tensor(data.shape[1]).to(g.device)
                if args.local:  # TODO(y): factor out (need way to see if dist was inited)
                    batch_total = batch_total.sum()
                else:
                    batch_total = util.dist_sum_tensor(batch_total)  # global batch size
                batch_total = util.toscalar(batch_total)

                should_log = (g.state.train_step < args.verbose_log_steps) or \
                             (g.state.train_step + 1) % args.log_interval == 0

                model.zero_grad()

                ret = model(data, target, *g.state.mems)
                loss, g.state.mems = ret[0], ret[1:]

                loss: torch.Tensor = loss.float().mean().type_as(loss)
                with timeit('backwards', noop=not should_log):
                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()
                loss0 = util.toscalar(loss)
                util.record('loss', loss0)

                util.record('params', torch.sum(util.flat_param(model)).item())
                losses.append(loss0)
                accumulated_loss += loss0

                if args.fp16:
                    optimizer.clip_master_grads(args.clip)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                # step-wise learning rate annealing
                if hasattr(optimizer, 'overflow') and optimizer.overflow:
                    g.logger.info("skipped iteration")
                else:
                    if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                        # linear warmup stage
                        if g.state.token_count < args.warmup_tokens:
                            curr_lr = args.lr * float(g.state.token_count) / args.warmup_tokens
                            optimizer.param_groups[0]['lr'] = curr_lr
                        elif args.scheduler == 'cosine':
                            # Divide by 1e6 for numerical stability.
                            g.state.scheduler.step(g.state.token_count // 1000 // 1000)
                    else:
                        g.state.scheduler.step(g.state.token_count)

                optimizer.step()
                g.state.train_step += 1

                consumed_tokens = data.shape[0] * data.shape[1]
                world_size = int(os.environ.get("WORLD_SIZE", "8"))
                if world_size > 8:  # correction factor for multiple machines
                    consumed_tokens = consumed_tokens * (world_size // 8)
                tokens_per_epoch += consumed_tokens
                g.state.token_count += consumed_tokens
                g.token_count = g.state.token_count
                if g.state.token_count >= args.max_tokens:
                    g.state.partial_epoch = True
                    raise StopIteration  # break out of parent train loop

                if should_log:
                    elapsed_time = time.time() - log_start_time
                    elapsed_steps = g.state.train_step - g.state.last_log_step

                    # compute average loss over last logging interval
                    cur_loss = accumulated_loss / elapsed_steps
                    cur_loss_mean = util.dist_mean(cur_loss)
                    log_str = f'| epoch {epoch:3d} step {g.state.train_step:>8d} ' \
                              f'| {batch:>6d} batches ' \
                              f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                              f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                              f'| loss {cur_loss:5.2f}'
                    if args.dataset in ['enwik8', 'text8']:
                        log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
                    else:
                        log_str += f' | ppl {math.exp(cur_loss):9.3f}'
                    g.logger.info(log_str)
                    log_tb('learning/epoch', epoch)
                    log_tb('_loss', cur_loss_mean)  # the most important thing
                    log_tb('learning/loss', cur_loss_mean)
                    log_tb('learning/ppl', math.exp(cur_loss_mean))

                    # currently step timings are not synchronized in multi-machine
                    # case (see #4). Can add torch.distributed.barrier() to get
                    # more accurate timings, but this may add slowness.
                    log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
                    current_lr = optimizer.param_groups[0]['lr']
                    log_tb('learning/lr', current_lr)

                    # 32 is the "canonical" batch size
                    linear_scaling_factor = batch_total / 32  # TODO(y): merge logic from master
                    log_tb('learning/base_lr', current_lr / linear_scaling_factor)
                    if args.optim == 'lamb':
                        log_lamb_rs(optimizer, g.event_writer, g.state.token_count)

                    time_per_batch = elapsed_time / elapsed_steps
                    time_per_sample = time_per_batch / args.batch_size
                    time_per_token = time_per_sample / args.tgt_len
                    log_tb('times/batches_per_sec', 1 / time_per_batch)
                    log_tb('times/samples_per_sec', 1 / time_per_sample)
                    log_tb('times/tokens_per_sec', 1 / time_per_token)

                    if str(g.device) == 'cuda':
                        log_tb("memory/allocated_gb", torch.cuda.memory_allocated() / 1e9)
                        log_tb("memory/max_allocated_gb", torch.cuda.max_memory_allocated() / 1e9)
                        log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                        log_tb("memory/max_cached_gb", torch.cuda.max_memory_cached() / 1e9)

                    accumulated_loss = 0
                    log_start_time = time.time()
                    g.state.last_log_step = g.state.train_step

            if args.checkpoint_each_epoch:
                g.logger.info(f'Saving checkpoint for epoch {epoch}')
                util.dist_save_checkpoint(model, optimizer, args.logdir,
                                          suffix=f'{epoch}')
            if tokens_per_epoch == 0:
                logging.info("Zero tokens in last epoch, breaking")
                break

            g.state.partial_epoch = False

    except KeyboardInterrupt:
        g.logger.info('-' * 100)
        g.logger.info('Exiting from training early')
    except StopIteration:
        pass

    return losses
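# util.dist_sum_tensor and util.dist_mean are not shown in this file. A
# plausible sketch of such helpers, built on torch.distributed collectives, is
# given below purely as an assumption about what they do (summing the global
# batch size, averaging a scalar across ranks); the real util module may differ.
import torch
import torch.distributed as dist


def dist_sum_tensor_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # sum the tensor in place across all ranks (e.g. to get the global batch size)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor


def dist_mean_sketch(value: float, device: str = 'cuda') -> float:
    # average a scalar across ranks, e.g. to report a globally averaged loss
    t = torch.tensor(float(value), device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()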
def train(va_iter, optimizer, scheduler):
    global global_token_count, event_writer, train_loss, best_val_loss, \
        train_step, last_log_step, epoch

    # Turn on training mode which enables dropout.
    model.train()

    log_tb('sizes/batch_size', args.batch_size)
    log_tb('sizes/seq_size', args.tgt_len)

    tr_iter = corpus.get_dist_iterator('train', global_rank, max_rank,
                                       args.batch_size, args.tgt_len,
                                       device=device, ext_len=args.ext_len)
    mems = tuple()
    log_start_time = time.time()
    for batch, (data, target, seq_len) in enumerate(tr_iter):
        assert seq_len == data.shape[0]
        for i in range(1, data.shape[0]):
            assert torch.all(torch.eq(data[i], target[i - 1]))
            break

        batch_total = torch.tensor(data.shape[1]).to(device)
        batch_total = batch_total.to(device)  # needed for NCCL sync
        if args.local:
            batch_total = batch_total.sum()
        else:
            batch_total = util.dist_sum_tensor(batch_total)  # global batch size
        batch_total = util.toscalar(batch_total)

        total_tokens = batch_total * seq_len
        should_log = train_step < args.verbose_log_steps or \
            train_step % args.log_interval == 0

        global_token_count += total_tokens
        model.zero_grad()

        ret = model(data, target, *mems)
        loss, mems = ret[0], ret[1:]
        loss = loss.float().mean().type_as(loss)
        with timeit('backwards', noop=not should_log):
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
        train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        train_step += 1

        # step-wise learning rate annealing
        if args.fp16 and optimizer.overflow:
            logger.info("skipped iteration")
        else:
            if args.scheduler in ['cosine', 'constant', 'dev_perf']:
                # linear warmup stage
                if global_token_count < args.warmup_tokens:
                    curr_lr = args.lr * global_token_count / args.warmup_tokens
                    optimizer.param_groups[0]['lr'] = curr_lr
                elif args.scheduler == 'cosine':
                    # Divide by 1e6 for numerical stability.
                    scheduler.step(global_token_count // 1e6)
            else:
                scheduler.step(global_token_count)

        if should_log:
            elapsed_time = time.time() - log_start_time
            elapsed_steps = train_step - last_log_step

            # compute average loss over last logging interval
            cur_loss = train_loss / elapsed_steps
            log_str = f'| epoch {epoch:3d} step {train_step:>8d} ' \
                      f'| {batch:>6d} batches ' \
                      f'| lr {optimizer.param_groups[0]["lr"]:.3g} ' \
                      f'| ms/batch {elapsed_time * 1000 / elapsed_steps:5.2f} ' \
                      f'| loss {cur_loss:5.2f}'
            if args.dataset in ['enwik8', 'text8']:
                log_str += f' | bpc {cur_loss / math.log(2):9.5f}'
            else:
                log_str += f' | ppl {math.exp(cur_loss):9.3f}'
            logger.info(log_str)

            log_tb('learning/epoch', epoch)
            log_tb('_loss', cur_loss)  # the most important thing
            log_tb('learning/loss', cur_loss)
            log_tb('learning/ppl', math.exp(cur_loss))

            # currently step timings are not synchronized in multi-machine
            # case (see #4). Can add torch.distributed.barrier() to get
            # more accurate timings, but this may add slowness.
            log_tb('times/step', 1000 * elapsed_time / elapsed_steps)
            current_lr = optimizer.param_groups[0]['lr']
            log_tb('learning/lr', current_lr)

            # 32 is the "canonical" batch size
            linear_scaling_factor = batch_total / 32
            log_tb('learning/base_lr', current_lr / linear_scaling_factor)

            if args.optim == 'lamb':
                log_lamb_rs(optimizer, event_writer, global_token_count)

            time_per_batch = elapsed_time / elapsed_steps
            time_per_sample = time_per_batch / args.batch_size
            time_per_token = time_per_sample / args.tgt_len
            log_tb('times/batches_per_sec', 1 / time_per_batch)
            log_tb('times/samples_per_sec', 1 / time_per_sample)
            log_tb('times/tokens_per_sec', 1 / time_per_token)

            if str(device) == 'cuda':
                log_tb("memory/allocated_gb", torch.cuda.memory_allocated() / 1e9)
                log_tb("memory/max_allocated_gb", torch.cuda.max_memory_allocated() / 1e9)
                log_tb("memory/cached_gb", torch.cuda.memory_cached() / 1e9)
                log_tb("memory/max_cached_gb", torch.cuda.max_memory_cached() / 1e9)

            train_loss = 0
            log_start_time = time.time()
            last_log_step = train_step

        if train_step % args.eval_interval == 0:
            evaluate_and_log(optimizer, va_iter, 'val', train_step)

        if global_token_count >= args.max_tokens:
            if args.eta_min == 0:
                raise StopIteration
            logger.info('End of schedule, staying at current LR')
            args.scheduler = 'constant'

    if args.checkpoint_each_epoch:
        logger.info(f'Saving checkpoint for epoch {epoch}')
        util.dist_save_checkpoint(model, optimizer, args.logdir,
                                  suffix=f'{epoch}')
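# Both training loops above implement the same learning-rate policy: a linear
# warmup over warmup_tokens followed by cosine annealing stepped in units of
# one million tokens. A standalone sketch of the resulting rate, using the
# closed form of CosineAnnealingLR and assuming the LR is driven only by this
# code path, is:
import math


def lr_at_token_count(token_count, base_lr, warmup_tokens, max_tokens, eta_min=0.0):
    if token_count < warmup_tokens:
        # linear warmup stage
        return base_lr * token_count / warmup_tokens
    # cosine annealing; "time" is measured in millions of tokens
    t = token_count // 10**6
    t_max = max_tokens // 10**6
    return eta_min + (base_lr - eta_min) * 0.5 * (1 + math.cos(math.pi * t / t_max))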