def setup_train(i, corpus, args):
    """Set up training. Handles CPU, single-GPU, and distributed training.

    Args:
        i: The process index. Since there is one process per GPU, this is
            also the GPU index. For single-GPU or CPU training this is 0.
        corpus: The corpus for training.
        args: Arguments from argparse and main().
    """
    args.device = torch.device(args.device.type, i)
    # Find this process's rank among all processes.
    args.rank = args.node_rank * args.gpu_per_node + i
    log = Logger(i, args.tensorboard_dir)
    log.train_add_text('arguments', str(args))
    log.valid_add_text('arguments', str(args))

    if args.dist:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
        # Bind to the local GPU index, not the global rank: on a multi-node
        # setup the rank can exceed the number of GPUs on this node.
        torch.cuda.set_device(i)

    # Initialize model.
    log("| Loading model...")
    model = get_model(corpus.vocab, args)
    model.to(args.device)
    args.total_param = count_param(model)
    if hasattr(model, 'layer_pool'):
        args.layer_param = sum(
            count_param(layer) for layer in model.layer_pool)
    elif hasattr(model, 'layer'):
        args.layer_param = count_param(model.layer)
    string = f"| Model:\n{model}\n"
    string += f"| Total parameters: {args.total_param}\n"
    string += (f"| Parameters without embedding and pre-softmax linear: "
               f"{args.layer_param}")
    log(string)
    log.train_add_text('arguments', string)
    log.valid_add_text('arguments', string)

    # Create optimizer and scheduler.
    optimizer, scheduler = get_optimizer_scheduler(model, args)
    if args.fp16:
        print("| Floating point 16 precision setting:\n", end='')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
    if args.dist:
        model = DistributedDataParallel(model,
                                        device_ids=[i],
                                        find_unused_parameters=True)

    resume_step = 0
    resume_epoch = 0
    if args.checkpoint is not None:
        log("| Loading checkpoint...")
        if args.fp16:
            resume_step, resume_epoch = load_checkpoint(
                args.checkpoint, args.device, model, optimizer, scheduler,
                amp)
        else:
            resume_step, resume_epoch = load_checkpoint(
                args.checkpoint, args.device, model, optimizer, scheduler)

        def update_dropout(module):
            # Override the checkpoint's dropout rates with the current args.
            # Set the attribute on `module`, not `model`, so that every
            # submodule visited by model.apply() is updated.
            if hasattr(module, 'dropout'):
                module.dropout = args.dropout
            if hasattr(module, 'attn_dropout'):
                module.attn_dropout = args.attn_dropout

        model.apply(update_dropout)
    else:
        model.apply(reset_parameters)  # Initialize parameters.

    # Get DataLoader.
    log("| Processing data...")
    train_loader = get_loader(corpus.train, corpus.vocab, args)
    if args.valid is not None:
        valid_loader = get_eval_loader(corpus.valid, corpus.vocab, args)

    log(f"| Training on {socket.gethostname()} with rank {args.rank}.", True)

    def train(step, epoch, best_loss):
        model.train()
        optimizer.zero_grad()
        train_loader.dataset.set_seed(epoch)
        log.init_epoch(step, epoch, train_loader.dataset.total_target)
        epoch_loss = 0
        epoch_num_target = 0
        for batch_num, batch in enumerate(train_loader):
            # TODO debug
            f = batch['feature'].data.numpy()
            t = batch['target'].data.numpy()
            n = batch['num_target']
            vocab = corpus.vocab
            # TODO print out data to test
            # feat = np.transpose(f)
            # for data in feat:
            #     print(vocab.to_text(data))
            # continue
            # TODO test dataloading

            num_target = sum(batch['num_target'])
            epoch_num_target += num_target
            log.num_target += num_target
            log.batch_size += len(batch['num_target'])
            try:
                feature = batch['feature'].to(args.device)
                target = batch['target'].to(args.device)
                assert (target != vocab.pad_idx
                        ).sum() == num_target  # TODO remove debug check
                loss = model(feature, target)
                assert loss.dtype == torch.float32  # TODO remove debug check
                batch_loss = loss.item()
                epoch_loss += batch_loss
                log.loss += batch_loss
                # The model returns a summed loss: normalize by the number of
                # target tokens, then by the gradient accumulation factor.
                loss = loss / num_target
                loss = loss / args.step_freq
                if args.fp16:
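                    # Under apex amp, backward() must run on the scaled loss
                    # so fp16 gradients do not underflow; amp unscales them
                    # again when the scale_loss context exits.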
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    log.oom += 1
                    print(f"== Rank {args.rank}: Training out of memory. "
                          "Skipping this batch. ==")
                    # Release memory.
                    if 'scaled_loss' in locals():
                        del scaled_loss
                    if 'loss' in locals():
                        del loss
                    if 'feature' in locals():
                        del feature
                    if 'target' in locals():
                        del target
                    for param in model.parameters():
                        if param.grad is not None:
                            param.grad = None
                    if args.cuda:
                        torch.cuda.empty_cache()
                else:
                    raise

            if (batch_num + 1) % args.step_freq == 0:
                step, epoch, best_loss = update(step, epoch, best_loss)
            if args.max_step is not None and step >= args.max_step:
                break
        # Flush remaining batches that don't fill a whole update step.
        if not args.trim_step and (batch_num + 1) % args.step_freq != 0:
            step, epoch, best_loss = update(step, epoch, best_loss)
        log.end_epoch(step, epoch)
        return step, epoch_loss / epoch_num_target, best_loss

    def update(step, epoch, best_loss):
        loss_scale = 1
        if args.fp16:
            loss_scale = amp._amp_state.loss_scalers[0]._loss_scale
        # Calculate the norm of the gradients, for logging.
        if args.log_norm:
            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                norm = param.grad.data.float().norm().item() / loss_scale
                log.train_add_scalar('norm/' + name, norm, step)
        # Clip gradients. With fp16, clip the fp32 master gradients that amp
        # keeps inside the optimizer.
        if args.fp16:
            norm = torch.nn.utils.clip_grad_norm_(
                amp.master_params(optimizer), args.clip_norm)
        else:
            norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                  args.clip_norm)
        log.norm += norm
        log.clip_norm += min(args.clip_norm, norm)

        optimizer.step()
        optimizer.zero_grad()
        step += 1
        if scheduler is not None:
            if step < args.warmup_step:
                # Linear warmup.
                warmup_lr = args.lr * step / args.warmup_step
                optimizer.param_groups[0]['lr'] = warmup_lr
            else:
                scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        log.train(step, lr, loss_scale)

        if args.step_per_save != 0 and step % args.step_per_save == 0:
            if i == 0:
                path = os.path.join(args.checkpoint_dir,
                                    f'checkpoint-{epoch}-{step}.pt')
                save_checkpoint(path, step, epoch, model, optimizer,
                                scheduler, amp if args.fp16 else None)
                copyfile(
                    path,
                    os.path.join(args.checkpoint_dir, 'checkpoint_last.pt'))
            if args.dist:
                dist.barrier()

        if args.step_per_valid != 0 and step % args.step_per_valid == 0:
            # Eval on validation data.
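            # All ranks run evaluation; validate() saves the best checkpoint
            # on rank 0 and, under dist, ends with a barrier to keep the
            # ranks in sync.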
            if args.valid is not None:
                best_loss = validate(best_loss)
        return step, epoch, best_loss

    def evaluate():
        model.eval()
        total_loss = 0
        total_target = 0
        total = valid_loader.dataset.total_target
        if i == 0:
            progress = tqdm(desc="Evaluating", total=total, unit=' token')
        for batch in valid_loader:
            # TODO debug
            f = batch['feature'].data.numpy()
            t = batch['target'].data.numpy()
            n = batch['num_target']
            vocab = corpus.vocab
            # TODO print out data to test
            # feat = np.transpose(f)
            # for data in feat:
            #     print(vocab.to_text(data))
            # continue
            # TODO test dataloading

            num_target = sum(batch['num_target'])
            total_target += num_target
            feature = batch['feature'].to(args.device)
            target = batch['target'].to(args.device)
            loss = model(feature, target)
            total_loss += loss.item()
            if i == 0:
                progress.update(num_target)
        if i == 0:
            progress.close()
        return total_loss / total_target

    def validate(best_loss):
        with torch.no_grad():
            loss = evaluate()
        log.valid(loss, step, epoch)
        if i == 0 and best_loss > loss:
            best_loss = loss
            best_path = os.path.join(args.checkpoint_dir,
                                     'checkpoint_best.pt')
            save_checkpoint(best_path, step, epoch, model, optimizer,
                            scheduler, amp if args.fp16 else None)
        if args.dist:
            dist.barrier()
        # The loss is in nats; divide by log(2) to report bits, and take
        # 2 ** bits for perplexity.
        log.valid_add_scalar('best loss', best_loss / math.log(2), step)
        log.valid_add_scalar('best ppl', 2**(best_loss / math.log(2)), step)
        return best_loss

    step = resume_step
    best_loss = math.inf
    # Start from epoch 1, or resume from the next epoch.
    for epoch in itertools.count(resume_epoch + 1):
        # Train on training data.
        step, loss, best_loss = train(step, epoch, best_loss)
        if args.max_step is not None and step >= args.max_step:
            break

        if args.epoch_per_valid != 0 and epoch % args.epoch_per_valid == 0:
            # Eval on validation data.
            if args.valid is not None:
                if args.dist:
                    dist.barrier()
                best_loss = validate(best_loss)

        # Saving checkpoint.
        if args.epoch_per_save != 0 and epoch % args.epoch_per_save == 0:
            if i == 0:
                path = os.path.join(args.checkpoint_dir,
                                    f'checkpoint-{epoch}-{step}.pt')
                save_checkpoint(path, step, epoch, model, optimizer,
                                scheduler, amp if args.fp16 else None)
                copyfile(
                    path,
                    os.path.join(args.checkpoint_dir, 'checkpoint_last.pt'))
            if args.dist:
                dist.barrier()

        # Delete old checkpoints.
        if i == 0 and (args.keep_step is not None
                       or args.keep_epoch is not None):
            for filename in os.listdir(args.checkpoint_dir):
                if re.match(r'checkpoint-\d+-\d+\.pt', filename):
                    file_epoch, file_step = re.split(r'[-.]', filename)[1:3]
                    # Use elif so a file matching both rules is not removed
                    # twice (the second remove would raise FileNotFoundError).
                    if args.keep_step is not None and int(file_step) <= (
                            step - args.keep_step):
                        os.remove(os.path.join(args.checkpoint_dir, filename))
                    elif args.keep_epoch is not None and int(file_epoch) <= (
                            epoch - args.keep_epoch):
                        os.remove(os.path.join(args.checkpoint_dir, filename))
        if args.dist:
            dist.barrier()

        if args.max_epoch is not None and epoch >= args.max_epoch:
            break
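

# A minimal launch sketch, not part of the original module, showing how
# setup_train() is typically driven. The entry-point name `main_train` and
# the environment-variable defaults are illustrative assumptions; the
# project's real entry point is the main() referenced in the docstring.
import torch.multiprocessing as mp


def main_train(corpus, args):
    if args.dist:
        # init_method='env://' in dist.init_process_group reads these.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '29500')
        # Spawn one process per GPU on this node; each process receives its
        # index as the first argument, i.e. setup_train(i, corpus, args).
        mp.spawn(setup_train, args=(corpus, args), nprocs=args.gpu_per_node)
    else:
        # CPU or single GPU: run in-process with index 0.
        setup_train(0, corpus, args)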