def eval(  # NOTE: shadows the builtin `eval`; name kept for CLI compatibility
    # Data
    eval_filepath: str,
    type_vocab_filepath: str,
    spm_filepath: str,
    num_workers=1,
    max_seq_len=-1,
    # Model
    resume_path: str = "",
    no_output_attention: bool = False,
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    # Output layer hparams
    d_out_projection: int = 512,
    n_hidden_output: int = 1,
    # Optimization
    batch_size=16,
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    """Evaluate a TypeTransformer checkpoint on a labeled dataset.

    Loads model weights from ``resume_path``, runs ``_evaluate`` over the
    dataset at ``eval_filepath``, and logs each resulting metric.

    Args:
        eval_filepath: Path to the evaluation data; a ``.json`` suffix switches
            the dataset to tab-separated source/target parsing.
        type_vocab_filepath: Type vocabulary file ("O" must be the first line).
        spm_filepath: SentencePiece model file.
        num_workers: DataLoader worker processes.
        max_seq_len: Maximum token length per example (-1 for unlimited,
            per DeepTyperDataset's convention — confirm against its docs).
        resume_path: Checkpoint to evaluate (required in practice).
        no_output_attention: If True, the model is called without the
            output-attention tensor.
        encoder_type / n_encoder_layers / d_model / d_out_projection /
        n_hidden_output: Architecture hparams; must match the checkpoint.
        batch_size: Evaluation batch size.
        subword_regularization_alpha: SentencePiece sampling alpha (0 disables).
        use_cuda: Run on GPU (asserts CUDA availability).
        seed: RNG seed for torch/numpy/random.
    """
    # Seed every RNG source so subword sampling and iteration are reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    config = locals()
    logger.info(f"Config: {config}")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    id_to_target, target_to_id = load_type_vocab(type_vocab_filepath)
    no_type_id = target_to_id["O"]
    assert no_type_id == 0  # Just a sense check since O is the first line in the vocab file
    collate_fn = get_collate_fn(pad_id, no_type_id)

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = DeepTyperDataset(
        eval_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_seq_len,
        subword_regularization_alpha=subword_regularization_alpha,
        split_source_targets_by_tab=eval_filepath.endswith(".json"),
    )
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

    # Create model
    model = TypeTransformer(
        n_tokens=sp.GetPieceSize(),
        n_output_tokens=len(id_to_target),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        d_out_projection=d_out_projection,
        n_hidden_output=n_hidden_output,
    )
    logger.info(f"Created TypeTransformer {encoder_type} with {count_parameters(model)} params")
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model

    model.eval()
    with torch.no_grad():
        # Load checkpoint.
        # BUGFIX: pass map_location so a checkpoint saved on GPU can still be
        # loaded when evaluating with --use_cuda False on a CPU-only host.
        logger.info(f"Loading parameters from {resume_path}")
        checkpoint = torch.load(resume_path, map_location=None if use_cuda else "cpu")
        model.module.load_state_dict(checkpoint["model_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]

        # Evaluate metrics
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        _, eval_metrics = _evaluate(
            model,
            eval_loader,
            sp,
            target_to_id=target_to_id,
            use_cuda=use_cuda,
            no_output_attention=no_output_attention,
        )
        for metric, value in eval_metrics.items():
            logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
def train(
    run_name: str,
    # Data
    train_filepath: str,
    eval_filepath: str,
    type_vocab_filepath: str,
    spm_filepath: str,
    num_workers=1,
    max_seq_len=1024,
    max_eval_seq_len=1024,
    run_dir=RUN_DIR,
    # Model
    resume_path: str = "",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    pretrain_resume_project: bool = False,
    no_output_attention: bool = False,
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    # Output layer hparams
    d_out_projection: int = 512,
    n_hidden_output: int = 1,
    # Optimization
    num_epochs: int = 100,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    adam_eps: float = 1e-6,
    weight_decay: float = 0,
    warmup_steps: int = 5000,
    num_steps: int = 200000,
    # Augmentations
    subword_regularization_alpha: float = 0,
    sample_lines_prob: float = 0,
    sample_lines_prob_keep_line: float = 0.9,
    # Loss
    ignore_any_loss: bool = False,
    # Computational
    use_cuda: bool = True,
    seed: int = 1,
):
    """Train a TypeTransformer type-prediction model.

    Sets up data/model/optimizer, optionally warm-starts from a pretraining or
    a full training checkpoint, then runs the epoch loop: train, evaluate, log
    to wandb, and save periodic + best checkpoints under ``run_dir/run_name``.

    Key behaviors visible in this function:
      * ``sample_lines_prob > 0`` enables the "sample_lines" augmentation.
      * ``ignore_any_loss`` remaps ``$any$`` labels to the no-type id so they
        are excluded from the cross-entropy loss.
      * ``eval_metric`` is treated as lower-is-better; the best checkpoint is
        written to ``ckpt_best.pth``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if run_dir != RUN_DIR:
        run_dir = Path(run_dir)
    run_dir = run_dir / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="type_prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    id_to_target, target_to_id = load_type_vocab(type_vocab_filepath)
    no_type_id = target_to_id["O"]
    assert no_type_id == 0  # Just a sense check since O is the first line in the vocab file
    any_id = target_to_id["$any$"]
    collate_fn = get_collate_fn(pad_id, no_type_id)

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    if sample_lines_prob > 0:
        augmentations = [
            {
                "fn": "sample_lines",
                "options": {"prob": sample_lines_prob, "prob_keep_line": sample_lines_prob_keep_line},
            },
        ]
        program_mode = "augmentation"
    else:
        augmentations = None
        program_mode = "identity"
    train_dataset = DeepTyperDataset(
        train_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_seq_len,
        subword_regularization_alpha=subword_regularization_alpha,
        augmentations=augmentations,
        program_mode=program_mode,
    )
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Create eval dataset and dataloader (no subword regularization at eval time)
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = DeepTyperDataset(
        eval_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_eval_seq_len,
        subword_regularization_alpha=0,
        split_source_targets_by_tab=eval_filepath.endswith(".json"),
    )
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

    # Create model
    model = TypeTransformer(
        n_tokens=sp.GetPieceSize(),
        n_output_tokens=len(id_to_target),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        d_out_projection=d_out_projection,
        n_hidden_output=n_hidden_output,
    )
    logger.info(f"Created TypeTransformer {encoder_type} with {count_parameters(model)} params")

    # Load pretrained checkpoint: copy the chosen pretrained encoder's weights
    # into this model's encoder (and optionally its projection into the output head).
    if pretrain_resume_path:
        assert not resume_path
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        # BUGFIX: map_location lets a GPU-saved pretraining checkpoint load on CPU-only hosts.
        checkpoint = torch.load(pretrain_resume_path, map_location=None if use_cuda else "cpu")
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        output_state_dict = {}
        assert pretrain_resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        prefix = pretrain_resume_encoder_name + "."
        for key, value in pretrained_state_dict.items():
            # Strip the "<encoder_name>." prefix; skip the contrastive projection layer.
            if key.startswith(prefix) and "project_layer" not in key:
                remapped_key = key[len(prefix):]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            # Optionally reuse the first projection layer as the output head's first layer.
            if key.startswith(prefix) and "project_layer.0." in key and pretrain_resume_project:
                remapped_key = key[len(prefix + "project_layer."):]
                logger.debug(
                    f"Remapping checkpoint project key {key} to output key {remapped_key}. Value mean: {value.mean().item()}"
                )
                output_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict)
        # TODO: check for head key rather than output for MLM
        model.output.load_state_dict(output_state_dict, strict=False)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(adam_beta1, adam_beta2),
        eps=adam_eps,
        weight_decay=weight_decay,
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_steps)
    epoch = 0
    global_step = 0
    min_eval_metric = float("inf")

    if resume_path:
        assert not pretrain_resume_path
        logger.info(f"Resuming training from checkpoint {resume_path}")
        # BUGFIX: map_location lets a GPU-saved checkpoint load on CPU-only hosts.
        checkpoint = torch.load(resume_path, map_location=None if use_cuda else "cpu")
        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # BUGFIX: restore the LR schedule too, otherwise warmup restarts on resume.
        # Guarded for backward compatibility with checkpoints that lack the key.
        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]
        min_eval_metric = checkpoint["min_eval_metric"]

    # Eval metric history (running max per metric, for wandb "_max" series)
    max_eval_metrics = {}

    # Evaluate initial metrics
    logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
    eval_metric, eval_metrics = _evaluate(
        model, eval_loader, sp, target_to_id=target_to_id, use_cuda=use_cuda, no_output_attention=no_output_attention
    )
    for metric, value in eval_metrics.items():
        logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
        max_eval_metrics[metric] = value
    eval_metrics["epoch"] = epoch
    wandb.log(eval_metrics, step=global_step)
    wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

    for epoch in tqdm.trange(epoch + 1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, lengths, output_attn, labels in pbar:
            if use_cuda:
                X, lengths, output_attn, labels = X.cuda(), lengths.cuda(), output_attn.cuda(), labels.cuda()
            optimizer.zero_grad()
            if no_output_attention:
                logits = model(X, lengths, None)  # BxLxVocab
            else:
                logits = model(X, lengths, output_attn)  # BxLxVocab
            if ignore_any_loss:
                # Don't train with $any$ type: remap it to the ignored no-type id.
                labels_ignore_any = labels.clone()
                labels_ignore_any[labels_ignore_any == any_id] = no_type_id
                loss = F.cross_entropy(logits.transpose(1, 2), labels_ignore_any, ignore_index=no_type_id)
            else:
                loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=no_type_id)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Compute accuracy in training batch: once counting $any$ labels,
            # once excluding them.
            (corr1_any, corr5_any), num_labels_any = accuracy(
                logits, labels, topk=(1, 5), ignore_idx=(no_type_id,)
            )
            acc1_any, acc5_any = corr1_any / num_labels_any * 100, corr5_any / num_labels_any * 100
            (corr1, corr5), num_labels = accuracy(logits, labels, topk=(1, 5), ignore_idx=(no_type_id, any_id))
            acc1, acc5 = corr1 / num_labels * 100, corr5 / num_labels * 100

            # Log loss
            global_step += 1
            wandb.log(
                {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "train/acc@1": acc1,
                    "train/acc@5": acc5,
                    "train/acc@1_any": acc1_any,
                    "train/acc@5_any": acc5_any,
                    "lr": scheduler.get_last_lr()[0],
                },
                step=global_step,
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        eval_metric, eval_metrics = _evaluate(
            model,
            eval_loader,
            sp,
            target_to_id=target_to_id,
            use_cuda=use_cuda,
            no_output_attention=no_output_attention,
        )
        for metric, value in eval_metrics.items():
            logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
            max_eval_metrics[metric] = max(value, max_eval_metrics[metric])
        eval_metrics["epoch"] = epoch
        wandb.log(eval_metrics, step=global_step)
        wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

        # Save checkpoint.
        # BUGFIX: decide best-ness and update min_eval_metric BEFORE building the
        # checkpoint dict, so the saved "min_eval_metric" reflects this epoch's
        # result (previously a resumed best checkpoint carried a stale minimum).
        is_best = eval_metric < min_eval_metric
        if is_best:
            logger.info(f"New best evaluation metric: prev {min_eval_metric:.4f} > new {eval_metric:.4f}")
            min_eval_metric = eval_metric
        if (save_every and epoch % save_every == 0) or is_best:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                # Persist the LR schedule so resume continues it (see resume path).
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_metric": eval_metric,
                "min_eval_metric": min_eval_metric,
            }
            model_file = run_dir / "ckpt_best.pth" if is_best else run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            logger.info("Done.")