Exemplo n.º 1
0
def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    eval_program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    # Model
    model_type="transformer",
    n_decoder_layers=4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type = "nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    train_augmentations = [
        {"fn": "sample_lines", "line_length_pct": 0.5},
        {"fn": "insert_var_declaration", "prob": 0.5},
        {"fn": "rename_variable", "prob": 0.5},
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    pad_id = sp.PieceToId("[PAD]")
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")

    # Load checkpoint
    if resume_path:
        logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}")
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]

        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    params = model.module.decoder.parameters() if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9)
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction='sum')
                loss = loss / X.size(0)  # Average over num sequences, not target sequence lengths
                                        # Thus, minimize bits per sequence.
            elif loss_type == "nll_token":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id,)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]}, step=global_step
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type)
        logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}")
        wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step)

        # Save checkpoint
        if save_every and epoch % save_every == 0 or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}")
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )
Exemplo n.º 2
0
def test(
    checkpoint_file: str,
    test_filepath: str = CSNJS_TEST_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    batch_size=8,
    model_type="transformer",
    n_decoder_layers=4,
    d_model=512,
    use_cuda: bool = True,
):
    wandb.init(name=checkpoint_file, config=locals(), project="f1_eval", entity="ml4code")
    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)

    # Create test dataset and dataloader
    logger.info(f"Test data path {test_filepath}")
    test_dataset = get_csnjs_dataset(test_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Test dataset size: {len(test_filepath)}")
    test_loader = javascript_dataloader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=0,
        augmentations=[],
    )

    pad_id = sp.PieceToId("[PAD]")
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")
    if use_cuda:
        model = model.cuda()

    # Load checkpoint
    checkpoint = torch.load(checkpoint_file)
    pretrained_state_dict = checkpoint["model_state_dict"]
    print("CHECKPOINT", checkpoint_file)
    from pprint import pprint
    print("KEYS", checkpoint["model_state_dict"].keys())
    try:
        model.load_state_dict(pretrained_state_dict)
    except RuntimeError as e:
        logger.error(e)
        logger.error("Keys in checkpoint: " + str(list(pretrained_state_dict.keys())))
        raise e
    logger.info(f"Loaded state dict from {checkpoint_file}")

    # Evaluate NLL
    model.eval()
    with torch.no_grad():
        test_nll = calculate_nll(model, test_loader, sp, use_cuda=use_cuda, logger_fn=wandb.log)
    logger.info(f"NLL: {test_nll:.5f}")

    # Make metric
    metric = F1MetricMethodName()
    model.eval()
    with torch.no_grad():
        precision, recall, f1, sample_generations = calculate_f1_metric(metric, model, test_loader, sp, use_cuda=use_cuda, logger_fn=wandb.log)
    logger.info(f"NLL: {test_nll:.5f}")
    logger.info(f"Precision: {precision:.5f}%")
    logger.info(f"Recall: {recall:.5f}%")
    logger.info(f"F1: {f1:.5f}%")

    df_generations = pd.DataFrame(sample_generations)
    df_generations.to_pickle(os.path.join(wandb.run.dir, "sample_generations.pickle.gz"))
    wandb.save(os.path.join(wandb.run.dir, "sample_generations.pickle.gz"))