def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    eval_program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    # Model
    model_type="transformer",
    n_decoder_layers=4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type="nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    train_augmentations = [
        {"fn": "sample_lines", "line_length_pct": 0.5},
        {"fn": "insert_var_declaration", "prob": 0.5},
        {"fn": "rename_variable", "prob": 0.5},
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")

    # Load pretrained encoder weights from a checkpoint, remapping keys from the
    # pretraining module (encoder_q / encoder_k / encoder) to this model's encoder
    if resume_path:
        logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}")
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    params = model.module.decoder.parameters() if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9)
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction="sum")
                # Average over num sequences, not target sequence lengths. Thus, minimize bits per sequence.
                loss = loss / X.size(0)
            elif loss_type == "nll_token":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]}, step=global_step
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type)
        logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}")
        wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step)

        # Save checkpoint (periodically, and whenever evaluation loss improves)
        if (save_every and epoch % save_every == 0) or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}")
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )
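

# ---------------------------------------------------------------------------
# Illustrative sketch only (not called by train/test): how the two loss_type
# options above differ, assuming the decoder logits have the
# [batch, tgt_len, vocab] layout that the training loop transposes before
# calling F.cross_entropy. The function name and toy dimensions here are
# hypothetical, for documentation purposes.
# ---------------------------------------------------------------------------
def _loss_type_example(pad_id: int = 0):
    B, T, V = 2, 5, 11  # batch size, target length, vocabulary size
    logits = torch.randn(B, T, V)
    targets = torch.randint(1, V, (B, T))
    # "nll_token": mean negative log-likelihood per (non-pad) target token.
    nll_token = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=pad_id)
    # "nll_sequence": token NLLs are summed, then divided by the number of
    # sequences, so the objective is mean NLL per sequence regardless of length.
    nll_sequence = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index=pad_id, reduction="sum") / B
    return nll_token, nll_sequence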


def test(
    checkpoint_file: str,
    test_filepath: str = CSNJS_TEST_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    batch_size=8,
    model_type="transformer",
    n_decoder_layers=4,
    d_model=512,
    use_cuda: bool = True,
):
    """Evaluate a checkpoint on the test set: NLL and precision/recall/F1 of generated names."""
    wandb.init(name=checkpoint_file, config=locals(), project="f1_eval", entity="ml4code")
    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)

    # Create test dataset and dataloader
    logger.info(f"Test data path {test_filepath}")
    test_dataset = get_csnjs_dataset(test_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Test dataset size: {len(test_dataset)}")
    test_loader = javascript_dataloader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=0,
        augmentations=[],
    )

    # Create model
    pad_id = sp.PieceToId("[PAD]")
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")
    if use_cuda:
        model = model.cuda()

    # Load checkpoint
    checkpoint = torch.load(checkpoint_file)
    pretrained_state_dict = checkpoint["model_state_dict"]
    logger.debug(f"Loading checkpoint {checkpoint_file}; keys: {list(pretrained_state_dict.keys())}")
    try:
        model.load_state_dict(pretrained_state_dict)
    except RuntimeError as e:
        logger.error(e)
        logger.error("Keys in checkpoint: " + str(list(pretrained_state_dict.keys())))
        raise e
    logger.info(f"Loaded state dict from {checkpoint_file}")

    # Evaluate NLL
    model.eval()
    with torch.no_grad():
        test_nll = calculate_nll(model, test_loader, sp, use_cuda=use_cuda, logger_fn=wandb.log)
    logger.info(f"NLL: {test_nll:.5f}")

    # Compute precision/recall/F1 of generated method names
    metric = F1MetricMethodName()
    model.eval()
    with torch.no_grad():
        precision, recall, f1, sample_generations = calculate_f1_metric(metric, model, test_loader, sp, use_cuda=use_cuda, logger_fn=wandb.log)
    logger.info(f"NLL: {test_nll:.5f}")
    logger.info(f"Precision: {precision:.5f}%")
    logger.info(f"Recall: {recall:.5f}%")
    logger.info(f"F1: {f1:.5f}%")

    # Save sample generations alongside the run for inspection
    df_generations = pd.DataFrame(sample_generations)
    df_generations.to_pickle(os.path.join(wandb.run.dir, "sample_generations.pickle.gz"))
    wandb.save(os.path.join(wandb.run.dir, "sample_generations.pickle.gz"))
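

# ---------------------------------------------------------------------------
# Hypothetical command-line entry point (assumption: this module is run
# directly as a script; the surrounding project may already expose train/test
# through its own CLI, in which case this block is redundant).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import fire

    fire.Fire({"train": train, "test": test})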