Example #1
def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    eval_program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    # Model
    model_type="transformer",
    n_decoder_layers=4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type = "nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    train_augmentations = [
        {"fn": "sample_lines", "line_length_pct": 0.5},
        {"fn": "insert_var_declaration", "prob": 0.5},
        {"fn": "rename_variable", "prob": 0.5},
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")

    # Load checkpoint
    if resume_path:
        logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}")
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]

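        # Remap the pretrained (e.g. MoCo) checkpoint keys: strip the "encoder_q."/"encoder_k."/"encoder."
        # prefix so they match this model's encoder, and skip projection-head keys unless resume_project is set.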
        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    params = model.module.decoder.parameters() if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9)
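    # LR schedule: linear warmup for 5,000 steps then linear decay over the full run
    # (len(train_loader) * num_epochs steps), or a constant LR (LambdaLR factor 1.0) when warmup is disabled.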
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
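            # Teacher forcing: the decoder consumes Y[:, :-1] and the loss below compares
            # its predictions against the shifted targets Y[:, 1:].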
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction='sum')
                loss = loss / X.size(0)  # Average over num sequences, not target sequence lengths
                                        # Thus, minimize bits per sequence.
            elif loss_type == "nll_token":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id,)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]}, step=global_step
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type)
        logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}")
        wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step)

        # Save checkpoint
        if (save_every and epoch % save_every == 0) or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}")
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )
Example #2
            logger.info(
                "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated"
                .format(data_type, len(data_loader),
                        finetune_config["batch_size"]))

    # optimizer and losses
    writer = SummaryWriter(save_model_dir)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=finetune_config["lr"],
                                  weight_decay=finetune_config["weight_decay"],
                                  amsgrad=True)
    optimizer.zero_grad()
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                len(data_loaders["train"]),
                                                train_config["epochs"] *
                                                len(data_loaders["train"]),
                                                min_percent=0.0001)

    best_reg_losses = {"train": INF, "dev": INF, "test": INF}
    best_reg_epochs = {"train": -1, "dev": -1, "test": -1}

    for epoch in range(finetune_config["epochs"]):
        for data_type, data_loader in data_loaders.items():

            if data_type == "train":
                mean_reg_loss, mean_bp_loss = train(model,
                                                    optimizer,
                                                    scheduler,
                                                    data_type,
                                                    data_loader,
Example #3
    def train(self):
        epoch_loss_history = list()
        min_validation_loss = sys.float_info.max
        patience_cnt = self.config.patience

        self.config.n_gpu = torch.cuda.device_count()

        t_total = len(self.train_data_loader) * self.config.n_epoch
        cur_step = 0

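        # AdamW parameter groups: apply weight decay to all parameters except
        # biases and LayerNorm weights (a common Transformer fine-tuning convention).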
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            self.config.weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                           lr=self.config.learning_rate)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=t_total)

        if self.config.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model).to(
                self.config.device)
        else:
            self.model = self.model.to(self.config.device)

        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            batch_loss_history = list()
            self.model.train()

            epoch_lm_loss = 0.0
            epoch_conv_loss = 0.0
            epoch_batch_loss = 0.0

            for batch_i, (input_utterances, input_utterances_mask,
                          target_utterance, target_utterance_mask,
                          input_user_ids, target_user_ids) in enumerate(
                              tqdm(self.train_data_loader, ncols=80)):

                input_utterances = torch.LongTensor(input_utterances).to(
                    self.config.device)
                input_utterances_mask = torch.LongTensor(
                    input_utterances_mask).to(self.config.device)
                target_utterance = torch.LongTensor(target_utterance).to(
                    self.config.device)
                target_utterance_mask = torch.LongTensor(
                    target_utterance_mask).to(self.config.device)

                user_available = input_user_ids[0] is not None

                if user_available:
                    input_user_ids = torch.LongTensor(input_user_ids).to(
                        self.config.device)
                    target_user_ids = torch.LongTensor(target_user_ids).to(
                        self.config.device)

                self.optimizer.zero_grad()
                self.model.zero_grad()

                loss_fn = torch.nn.CrossEntropyLoss(
                    ignore_index=self.config.pad_id)

                target = target_utterance[..., :-1].contiguous()
                gt_target = target_utterance[..., 1:].contiguous()
                target_mask = target_utterance_mask[..., :-1].contiguous()

                if user_available:
                    target_user_ids = target_user_ids[..., :-1].contiguous()
                else:
                    input_user_ids = None
                    target_user_ids = None

                lm_output, conv_output = self.model(target, target_mask,
                                                    input_utterances,
                                                    input_utterances_mask,
                                                    target_user_ids,
                                                    input_user_ids)

                # 1. Calculate Language Model Loss
                outputs = lm_output[..., :-1, :].contiguous()
                labels = input_utterances[..., 1:].contiguous()
                lm_loss = loss_fn(outputs.view(-1, outputs.size(-1)),
                                  labels.view(-1))

                # 2. Calculate Conv Loss
                conv_loss = loss_fn(conv_output.view(-1, conv_output.size(-1)),
                                    gt_target.view(-1))

                # 3. Total Loss
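                # The auxiliary language-model loss is down-weighted (factor 0.2)
                # relative to the main conversation (response generation) loss.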
                batch_loss = lm_loss * 0.2 + conv_loss

                assert not isnan(batch_loss.item())

                if self.config.n_gpu > 1:
                    batch_loss = batch_loss.mean()

                epoch_lm_loss = (batch_i * epoch_lm_loss +
                                 lm_loss.item()) / (batch_i + 1)
                epoch_conv_loss = (batch_i * epoch_conv_loss +
                                   conv_loss.item()) / (batch_i + 1)
                epoch_batch_loss = (batch_i * epoch_batch_loss +
                                    batch_loss.item()) / (batch_i + 1)

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item():.3f}'
                    )
                    self.writer.add_scalar('Train/lm_loss', lm_loss.item(),
                                           cur_step)
                    self.writer.add_scalar('Train/conv_loss', conv_loss.item(),
                                           cur_step)
                    self.writer.add_scalar('Train/loss', batch_loss.item(),
                                           cur_step)
                    self.writer.add_scalar('Train/learning_rate',
                                           self.scheduler.get_lr()[0],
                                           cur_step)

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.optimizer.step()
                self.scheduler.step()
                cur_step += 1

            epoch_loss_history.append(epoch_batch_loss)
            self.epoch_loss = epoch_batch_loss

            print(f'Epoch {epoch_i+1} loss average: {epoch_batch_loss:.3f}')

            print('\n<Validation>...')
            val_loss, val_lm_loss, val_conv_loss = self.evaluate()
            self.validation_loss = val_loss

            if epoch_i % self.config.plot_every_epoch == 0:
                self.writer.add_scalar('Val/lm_loss', val_lm_loss, epoch_i + 1)
                self.writer.add_scalar('Val/conv_loss', val_conv_loss,
                                       epoch_i + 1)
                self.writer.add_scalar('Val/loss', val_loss, epoch_i + 1)

            self.save_model(epoch_i)

            if min_validation_loss > self.validation_loss:
                min_validation_loss = self.validation_loss
            else:
                patience_cnt -= 1

            if patience_cnt < 0:
                print(f'\nEarly stop at {epoch_i}')
                return epoch_loss_history

        self.save_model(self.config.n_epoch)

        return epoch_loss_history
Example #4
def train(
    run_name: str,
    # Data
    train_filepath: str,
    eval_filepath: str,
    type_vocab_filepath: str,
    spm_filepath: str,
    num_workers=1,
    max_seq_len=1024,
    max_eval_seq_len=1024,
    run_dir=RUN_DIR,
    # Model
    resume_path: str = "",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    pretrain_resume_project: bool = False,
    no_output_attention: bool = False,
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    # Output layer hparams
    d_out_projection: int = 512,
    n_hidden_output: int = 1,
    # Optimization
    num_epochs: int = 100,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    adam_eps: float = 1e-6,
    weight_decay: float = 0,
    warmup_steps: int = 5000,
    num_steps: int = 200000,
    # Augmentations
    subword_regularization_alpha: float = 0,
    sample_lines_prob: float = 0,
    sample_lines_prob_keep_line: float = 0.9,
    # Loss
    ignore_any_loss: bool = False,
    # Computational
    use_cuda: bool = True,
    seed: int = 1,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if run_dir != RUN_DIR:
        run_dir = Path(run_dir)
    run_dir = run_dir / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name,
               config=config,
               job_type="training",
               project="type_prediction",
               entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(
        ), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    id_to_target, target_to_id = load_type_vocab(type_vocab_filepath)
    no_type_id = target_to_id["O"]
    assert no_type_id == 0  # Just a sense check since O is the first line in the vocab file
    any_id = target_to_id["$any$"]

    collate_fn = get_collate_fn(pad_id, no_type_id)

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    if sample_lines_prob > 0:
        augmentations = [
            {
                "fn": "sample_lines",
                "options": {
                    "prob": sample_lines_prob,
                    "prob_keep_line": sample_lines_prob_keep_line
                }
            },
        ]
        program_mode = "augmentation"
    else:
        augmentations = None
        program_mode = "identity"
    train_dataset = DeepTyperDataset(
        train_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_seq_len,
        subword_regularization_alpha=subword_regularization_alpha,
        augmentations=augmentations,
        program_mode=program_mode,
    )
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               drop_last=True,
                                               collate_fn=collate_fn)

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = DeepTyperDataset(
        eval_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_eval_seq_len,
        subword_regularization_alpha=0,
        split_source_targets_by_tab=eval_filepath.endswith(".json"),
    )
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)

    # Create model
    model = TypeTransformer(
        n_tokens=sp.GetPieceSize(),
        n_output_tokens=len(id_to_target),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        d_out_projection=d_out_projection,
        n_hidden_output=n_hidden_output,
    )
    logger.info(
        f"Created TypeTransformer {encoder_type} with {count_parameters(model)} params"
    )

    # Load pretrained checkpoint
    if pretrain_resume_path:
        assert not resume_path
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        checkpoint = torch.load(pretrain_resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        output_state_dict = {}
        assert pretrain_resume_encoder_name in [
            "encoder_k", "encoder_q", "encoder"
        ]

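        # Strip the pretrained encoder prefix so keys match model.encoder; when
        # pretrain_resume_project is set, also remap the pretraining projection-head
        # weights onto this model's output layer.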
        for key, value in pretrained_state_dict.items():
            if key.startswith(pretrain_resume_encoder_name +
                              ".") and "project_layer" not in key:
                remapped_key = key[len(pretrain_resume_encoder_name + "."):]
                logger.debug(
                    f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}"
                )
                encoder_state_dict[remapped_key] = value
            if key.startswith(
                    pretrain_resume_encoder_name + "."
            ) and "project_layer.0." in key and pretrain_resume_project:
                remapped_key = key[len(pretrain_resume_encoder_name +
                                       ".project_layer."):]
                logger.debug(
                    f"Remapping checkpoint project key {key} to output key {remapped_key}. Value mean: {value.mean().item()}"
                )
                output_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict)
        # TODO: check for head key rather than output for MLM
        model.output.load_state_dict(output_state_dict, strict=False)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(adam_beta1, adam_beta2),
                                 eps=adam_eps,
                                 weight_decay=weight_decay)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps,
                                                num_steps)
    epoch = 0
    global_step = 0
    min_eval_metric = float("inf")

    if resume_path:
        assert not pretrain_resume_path
        logger.info(f"Resuming training from checkpoint {resume_path}")
        checkpoint = torch.load(resume_path)
        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]
        min_eval_metric = checkpoint["min_eval_metric"]

    # Eval metric history
    max_eval_metrics = {}

    # Evaluate initial metrics
    logger.info(
        f"Evaluating model after epoch {epoch} ({global_step} steps)...")
    eval_metric, eval_metrics = _evaluate(
        model,
        eval_loader,
        sp,
        target_to_id=target_to_id,
        use_cuda=use_cuda,
        no_output_attention=no_output_attention)
    for metric, value in eval_metrics.items():
        logger.info(
            f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}"
        )
        max_eval_metrics[metric] = value
    eval_metrics["epoch"] = epoch
    wandb.log(eval_metrics, step=global_step)
    wandb.log({k + "_max": v
               for k, v in max_eval_metrics.items()},
              step=global_step)

    for epoch in tqdm.trange(epoch + 1,
                             num_epochs + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, lengths, output_attn, labels in pbar:
            if use_cuda:
                X, lengths, output_attn, labels = X.cuda(), lengths.cuda(
                ), output_attn.cuda(), labels.cuda()
            optimizer.zero_grad()
            if no_output_attention:
                logits = model(X, lengths, None)  # BxLxVocab
            else:
                logits = model(X, lengths, output_attn)  # BxLxVocab
            if ignore_any_loss:
                # Don't train with $any$ type
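                # Remap $any$ labels to the no-type id, which is also the cross-entropy
                # ignore_index, so these positions contribute no gradient.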
                labels_ignore_any = labels.clone()
                labels_ignore_any[labels_ignore_any == any_id] = no_type_id
                loss = F.cross_entropy(logits.transpose(1, 2),
                                       labels_ignore_any,
                                       ignore_index=no_type_id)
            else:
                loss = F.cross_entropy(logits.transpose(1, 2),
                                       labels,
                                       ignore_index=no_type_id)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Compute accuracy in training batch
            (corr1_any,
             corr5_any), num_labels_any = accuracy(logits,
                                                   labels,
                                                   topk=(1, 5),
                                                   ignore_idx=(no_type_id, ))
            acc1_any, acc5_any = corr1_any / num_labels_any * 100, corr5_any / num_labels_any * 100
            (corr1, corr5), num_labels = accuracy(logits,
                                                  labels,
                                                  topk=(1, 5),
                                                  ignore_idx=(no_type_id,
                                                              any_id))
            acc1, acc5 = corr1 / num_labels * 100, corr5 / num_labels * 100

            # Log loss
            global_step += 1
            wandb.log(
                {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "train/acc@1": acc1,
                    "train/acc@5": acc5,
                    "train/acc@1_any": acc1_any,
                    "train/acc@5_any": acc5_any,
                    "lr": scheduler.get_last_lr()[0],
                },
                step=global_step,
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(
            f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        eval_metric, eval_metrics = _evaluate(
            model,
            eval_loader,
            sp,
            target_to_id=target_to_id,
            use_cuda=use_cuda,
            no_output_attention=no_output_attention)
        for metric, value in eval_metrics.items():
            logger.info(
                f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}"
            )
            max_eval_metrics[metric] = max(value, max_eval_metrics[metric])
        eval_metrics["epoch"] = epoch
        wandb.log(eval_metrics, step=global_step)
        wandb.log({k + "_max": v
                   for k, v in max_eval_metrics.items()},
                  step=global_step)

        # Save checkpoint
        if (save_every and epoch % save_every == 0) or eval_metric < min_eval_metric:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_metric": eval_metric,
                "min_eval_metric": min_eval_metric,
            }
            if eval_metric < min_eval_metric:
                logger.info(
                    f"New best evaluation metric: prev {min_eval_metric:.4f} > new {eval_metric:.4f}"
                )
                min_eval_metric = eval_metric
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            logger.info("Done.")
Example #5
    def train(self):
        epoch_loss_history = list()
        min_validation_loss = sys.float_info.max
        patience_cnt = self.config.patience

        self.config.n_gpu = torch.cuda.device_count()

        t_total = len(self.train_data_loader) * self.config.n_epoch
        cur_step = 0

        no_decay = ['bias', 'ln']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            self.config.weight_decay
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                           lr=self.config.learning_rate)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=t_total)

        if self.config.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model).to(
                self.config.device)
        else:
            self.model = self.model.to(self.config.device)

        for epoch_i in range(self.epoch_i, self.config.n_epoch):
            self.epoch_i = epoch_i
            self.model.train()

            epoch_loss = 0.0

            for batch_i, batch in enumerate(
                    tqdm(self.train_data_loader, ncols=80)):
                self.optimizer.zero_grad()
                self.model.zero_grad()

                batch = tuple(t.to(self.config.device) for t in batch)

                input_ids, position_ids, token_ids, label_ids, user_ids, user_mask = batch

                inputs = {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'token_type_ids': token_ids,
                    'user_mask': user_mask
                }

                outputs = self.model(**inputs)

                loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

                if self.config.users and self.config.reversed:
                    lm_logits, user_outputs = outputs
                    user_ids = user_ids.view(-1)
                    user_ids_mask = user_ids != -1
                    user_ids = user_ids[user_ids_mask]
                    user_loss = loss_fn(
                        user_outputs.view(-1, user_outputs.size(-1)),
                        user_ids.view(-1))
                    output_loss = loss_fn(
                        lm_logits.view(-1, lm_logits.size(-1)),
                        label_ids.view(-1))
                    batch_loss = user_loss + output_loss
                else:
                    batch_loss = loss_fn(outputs.view(-1, outputs.size(-1)),
                                         label_ids.view(-1))
                    user_loss = None
                    output_loss = None

                assert not isnan(batch_loss.item())

                if self.config.n_gpu > 1:
                    batch_loss = batch_loss.mean()

                epoch_loss = (batch_i * epoch_loss +
                              batch_loss.item()) / (batch_i + 1)

                if batch_i % self.config.print_every == 0:
                    tqdm.write(
                        f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item():.3f}'
                    )
                    self.writer.add_scalar('Train/loss', batch_loss.item(),
                                           cur_step)
                    if self.config.users and self.config.reversed:
                        self.writer.add_scalar('Train/user_loss',
                                               user_loss.item(), cur_step)
                        self.writer.add_scalar('Train/lmloss',
                                               output_loss.item(), cur_step)
                    self.writer.add_scalar('Train/learning_rate',
                                           self.scheduler.get_lr()[0],
                                           cur_step)

                batch_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.config.clip)
                self.optimizer.step()
                self.scheduler.step()
                cur_step += 1

            if epoch_i == 0:
                if self.config.users and not self.config.reversed:
                    self.vocab.save_pretrained(self.config.save_path)

            epoch_loss_history.append(epoch_loss)
            self.epoch_loss = epoch_loss

            print(f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}')

            print('\n<Validation>...')
            val_loss, output_loss, user_loss = self.evaluate()
            self.validation_loss = val_loss

            if epoch_i % self.config.plot_every_epoch == 0:
                self.writer.add_scalar('Val/loss', self.validation_loss,
                                       epoch_i + 1)
                if output_loss is not None and user_loss is not None:
                    self.writer.add_scalar('Val/lmloss', output_loss.item(),
                                           epoch_i + 1)
                    self.writer.add_scalar('Val/user_loss', user_loss.item(),
                                           epoch_i + 1)

            if min_validation_loss > self.validation_loss:
                min_validation_loss = self.validation_loss
            else:
                patience_cnt -= 1
                self.save_model(epoch_i)

            if patience_cnt < 0:
                print(f'\nEarly stop at {epoch_i}')
                self.save_model(epoch_i)
                return epoch_loss_history

        self.save_model(self.config.n_epoch)

        return epoch_loss_history
Example #6
def train(
    run_name: str,
    # Data
    train_filepath: str = "data/codeclone/train_data.json",
    eval_filepath: str = "data/codeclone/valid_data.json",
    spm_filepath: str = "data/codesearchnet_javascript/csnjs_8k_9995p_unigram_url.model",
    num_workers=1,
    max_seq_len=1024,
    max_eval_seq_len=1024,
    run_dir=RUN_DIR,
    balance_negatives=False,
    # Model
    resume_path: str = "",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    critic_type: str = "bilinear_identity",
    critic_bilinear_rank: int = None,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 100,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    adam_eps: float = 1e-6,
    weight_decay: float = 0,
    warmup_steps: int = 5000,
    num_steps: int = 200000,
    # Evaluation
    save_every_steps: int = 5000,
    score_every_steps: int = 10,  # Interval for train ROC AUC, AP score
    evaluate_every_steps: int = 1250,
    # Augmentations
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    seed: int = 1,
):
    """Train model"""
    assert save_every_steps % evaluate_every_steps == 0

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if run_dir != RUN_DIR:
        run_dir = Path(run_dir)
    run_dir = run_dir / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="clone_detection", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")
    pad_collate = get_pad_collate(pad_id)

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_programs = CloneProgramsDataset(train_filepath, sp, subword_regularization_alpha)
    train_positives = ClonePositivesDataset(train_programs)
    train_negatives = CloneNegativesDataset(train_programs)
    train_dataset = torch.utils.data.ConcatDataset([train_positives, train_negatives])
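    # Optionally counteract the skewed positive/negative pair ratio by sampling each
    # class with inverse-frequency weights instead of plain shuffling.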
    if balance_negatives:
        positive_weight = 1 / len(train_programs)
        negative_weight = 1 / len(train_negatives)
        weights = torch.cat([torch.zeros(len(train_programs)) + positive_weight, torch.zeros(len(train_negatives)) + negative_weight])
        sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
    else:
        sampler = None
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=not balance_negatives,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=pad_collate,
        sampler=sampler,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_programs = CloneProgramsDataset(eval_filepath, sp, subword_regularization_alpha)
    eval_positives = ClonePositivesDataset(eval_programs)
    eval_negatives = CloneNegativesDataset(eval_programs)
    eval_dataset = torch.utils.data.ConcatDataset([eval_positives, eval_negatives])
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_collate
    )

    # Create model
    model = CloneModel(
        n_tokens=sp.GetPieceSize(),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        critic_type=critic_type,
        bilinear_rank=critic_bilinear_rank,
    )
    logger.info(f"Created CloneModel {encoder_type}, {critic_type} with {count_parameters(model)} params")

    # Load pretrained checkpoint
    if pretrain_resume_path:
        assert not resume_path
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        checkpoint = torch.load(pretrain_resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        # output_state_dict = {}
        assert pretrain_resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]

        for key, value in pretrained_state_dict.items():
            if key.startswith(pretrain_resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(pretrain_resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(adam_beta1, adam_beta2), eps=adam_eps, weight_decay=weight_decay)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_steps)
    epoch = 0
    global_step = 0
    min_eval_metric = float("inf")

    if resume_path:
        assert not pretrain_resume_path
        logger.info(f"Resuming training from checkpoint {resume_path}")
        checkpoint = torch.load(resume_path)
        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]
        min_eval_metric = checkpoint["min_eval_metric"]

    # Eval metric history
    max_eval_metrics = {}

    # Initial evaluation
    logger.info(f"Evaluating model initially ({global_step} steps)...")
    eval_metric, eval_metrics = _evaluate(model, eval_loader, sp, pad_id=pad_id, use_cuda=use_cuda)
    for metric, value in eval_metrics.items():
        logger.info(f"Evaluation {metric} initial ({global_step} steps): {value:.4f}")
        max_eval_metrics[metric] = value
    eval_metrics["epoch"] = epoch
    wandb.log(eval_metrics, step=global_step)
    wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

    for epoch in tqdm.trange(epoch + 1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
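        # Keep the (pretrained) encoder in eval mode so its dropout/normalization layers
        # stay fixed while the rest of the model runs in train mode.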
        model.module.encoder.eval()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, lengths, labels in pbar:
            if use_cuda:
                X, lengths, labels = X.cuda(), lengths.cuda(), labels.cuda()
            optimizer.zero_grad()
            similarity = model(X, lengths)  # B
            loss = F.binary_cross_entropy_with_logits(similarity, labels.float())
            loss.backward()
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                train_metrics = {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "lr": scheduler.get_last_lr()[0],
                    "train/labels_mean": labels.float().mean().item(),
                }

                # Compute scores in training batch
                if global_step % score_every_steps == 0:
                    y_true = labels.cpu().numpy()
                    y_scores = similarity.cpu().numpy()
                    roc_auc = roc_auc_score(y_true, y_scores)
                    ap_score = average_precision_score(y_true, y_scores)
                    train_metrics["train/roc_auc_score"] = roc_auc
                    train_metrics["train/ap_score"] = ap_score
                    pbar.set_description(f"epoch {epoch} loss {loss.item():.4f} roc_auc {roc_auc:.4f} ap {ap_score:.4f}")

                # Log loss
                global_step += 1
                wandb.log(train_metrics, step=global_step)

                # Evaluate
                if evaluate_every_steps and global_step % evaluate_every_steps == 0:
                    logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
                    eval_metric, eval_metrics = _evaluate(model, eval_loader, sp, pad_id=pad_id, use_cuda=use_cuda)
                    for metric, value in eval_metrics.items():
                        logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
                        max_eval_metrics[metric] = max(value, max_eval_metrics[metric])
                    eval_metrics["epoch"] = epoch
                    wandb.log(eval_metrics, step=global_step)
                    wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

                    # Save checkpoint
                    if (save_every_steps and global_step % save_every_steps == 0) or eval_metric < min_eval_metric:
                        checkpoint = {
                            "model_state_dict": model.module.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                            "epoch": epoch,
                            "global_step": global_step,
                            "config": config,
                            "eval_metric": eval_metric,
                            "min_eval_metric": min_eval_metric,
                        }
                        if eval_metric < min_eval_metric:
                            logger.info(f"New best evaluation metric: prev {min_eval_metric:.4f} > new {eval_metric:.4f}")
                            min_eval_metric = eval_metric
                            model_file = run_dir / "ckpt_best.pth"
                        else:
                            model_file = run_dir / f"ckpt_ep{epoch:04d}_step{global_step}.pth"
                        logger.info(f"Saving checkpoint to {model_file}...")
                        torch.save(checkpoint, str(model_file.resolve()))
                        logger.info("Done.")
Example #7
def run():
    dfx = pd.read_csv(config.TRAINING_FILE,
                      nrows=config.NROWS).dropna().reset_index(drop=True)
    # dfx.sentiment = dfx.sentiment.apply(
    #     lambda x: 1 if x =='positive' else 0
    # )
    print('Data Loaded')
    df_train, df_valid = model_selection.train_test_split(
        dfx, test_size=0.5, random_state=42, stratify=dfx.sentiment.values)
    print('Data split into train data and validation data')
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    train_dataset = dataset.TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values)

    print('Train data preprocessed and made into Tweet Dataset Object')

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4)

    print('Train dataloader created')
    valid_dataset = dataset.TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values)
    print('Valid data preprocessed and made into Tweet Dataset Object')
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=config.VALID_BATCH_SIZE, num_workers=1)
    print('Valid dataloader created')
    device = config.DEVICE
    conf = transformers.RobertaConfig.from_pretrained(
        f'{config.PATH}roberta-base-config.json')
    conf.output_hidden_states = False

    model = Roberta(conf)
    model.to(device)
    print('Model Object created')

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.001
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    num_train_steps = int(
        len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
    scheduler = utils.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    best_jaccard = 0
    print('Starting Training....')
    for epoch in range(config.EPOCHS):
        engine.train_fn(train_data_loader, model, optimizer, device, scheduler)
        jaccard = engine.eval_fn(valid_data_loader, model, device)

        print(f'Jaccard Score : {jaccard}')
        if jaccard > best_jaccard:
            torch.save(model.state_dict(), config.MODEL_PATH)
            best_jaccard = jaccard
Example #8
def pretrain_worker(gpu, ngpus_per_node, config):
    chief_node = gpu == 0
    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"],
                   config=config,
                   job_type="training",
                   project=project,
                   entity="ml4code")

    if gpu is not None:
        logger.info("Use GPU: {} for training".format(gpu))

    if config["dist_url"] == "env://" and config["rank"] == -1:
        config["rank"] = int(os.environ["RANK"])
    # For multiprocessing distributed training, rank needs to be the
    # global rank among all the processes
    config["rank"] = config["rank"] * ngpus_per_node + gpu
    dist.init_process_group(backend=config["dist_backend"],
                            init_method=config["dist_url"],
                            world_size=config["world_size"],
                            rank=config["rank"])

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    mask_id = sp.PieceToId("[MASK]")

    def pad_collate(batch):
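        # In contrastive mode each batch element is a pair of augmented views of the
        # same program; flatten to [2B] sequences for padding, then regroup to [B, 2, T] below.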
        B = len(batch)
        if config["program_mode"] == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            X = batch

        # Create tensor of sequence lengths, [B] or [2B]
        lengths = torch.tensor([len(x) for x in X], dtype=torch.long)

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if config["program_mode"] == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            lengths = torch.reshape(lengths, (2, B)).transpose(0, 1)
            assert lengths.shape == (B, 2)
        return X, lengths, None

    # Create model
    if config["loss_mode"] == "infonce":
        model = CodeMoCo(sp.GetPieceSize(),
                         pad_id=pad_id,
                         d_model=config["d_model"],
                         encoder_config=dict(
                             encoder_type=config["encoder_type"],
                             lstm_project_mode=config["lstm_project_mode"],
                             n_encoder_layers=config["n_encoder_layers"]))
        logger.info(
            f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(sp.GetPieceSize(),
                        pad_id=pad_id,
                        encoder_type=config["encoder_type"],
                        n_encoder_layers=config["n_encoder_layers"])
        logger.info(
            f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(
            f"Created CodeContrastiveMLM model with {count_parameters(model)} params"
        )
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.cuda(gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
        config["num_workers"] = int(
            (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[gpu])
    else:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"],
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"],
                                            config["num_steps"])

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate,
        # num_workers=config["num_workers"],
        num_workers=0,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler,
    )

    # Train
    global_step = 0
    for epoch in tqdm.trange(1,
                             config["num_epochs"] + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model,
                                              batch,
                                              use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # MLM: randomly mask/replace tokens, drawing replacement tokens
                # from vocabulary ids 8 through 7999
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp,
                    model,
                    batch,
                    mask_id=mask_id,
                    pad_id=pad_id,
                    vocab_start_idx=0,
                    vocab_end_idx=7999,
                    use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()

            global_step += 1
            pbar.set_description(
                f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}"
            )

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]),
                          step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config[
                        "save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(
                        config["run_dir"],
                        f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                    )
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
Example #9
0
def pretrain(
    run_name: str,
    #
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_length=1024,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",  # infonce, mlm, or hybrid
    min_alternatives=1,
    #
    # Model
    resume_path: str = "",
    encoder_type: str = "transformer",
    lstm_project_mode: str = "hidden",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    n_head: int = 8,
    #
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    weight_decay: float = 0,
    adam_betas=(0.9, 0.98),
    warmup_steps: int = 5000,
    num_steps: int = 600000,
    #
    # Horovod
    use_adasum: bool = False,
    fp16_allreduce: bool = False,
    gradient_predivide_factor: float = 1.0,
    #
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    hvd.init()

    logger.info("L:", n_encoder_layers, type(n_encoder_layers))
    logger.info("H:", d_model, type(d_model))
    logger.info("A:", n_head, type(n_head))
    run_name = str(run_name)  # support numerical run ids
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    slurm_job_hostname = os.environ.get("SLURM_JOB_NODELIST")
    config = locals()
    logger.info(f"Config = \n{config}")
    logger.info("Training configuration: {}".format(config))
    logger.info(
        f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    logger.info(f"CUDA_DEVICE_ORDER = '{os.environ.get('CUDA_DEVICE_ORDER')}'")

    assert program_mode in ["contrastive", "identity", "augmentation"]
    assert loss_mode in ("infonce", "mlm", "hybrid")
    assert not (program_mode == "contrastive" and loss_mode == "mlm")
    assert not (program_mode != "contrastive" and
                (loss_mode == "hybrid" or loss_mode == "infonce"))
    assert not use_cuda or torch.cuda.is_available()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    config["run_dir"] = str(run_dir.resolve())
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle") or train_filepath.endswith(".gz")

    # Setup distributed
    gpu = hvd.local_rank()
    ngpus_per_node = 1
    chief_node = gpu == 0
    assert gpu is not None
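    # NOTE: hvd.local_rank() is per node, so in a multi-node job each node's
    # local rank 0 passes this check; hvd.rank() == 0 would select a single
    # global chief.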

    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"],
                   config=config,
                   job_type="training",
                   project=project,
                   entity="ml4code")

    logger.info("Use GPU: {} for training".format(gpu))
    torch.cuda.set_device(gpu)
    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get("num_workers", 0) > 0 and hasattr(mp, "_supports_context")
            and mp._supports_context
            and "forkserver" in mp.get_all_start_methods()):
        kwargs["multiprocessing_context"] = "forkserver"

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    logger.info("pad_id {}", pad_id)
    assert pad_id == 0  # hard coded in pad_collate
    mask_id = sp.PieceToId("[MASK]")

    # Create model
    if config["loss_mode"] == "infonce":
        # TODO(ajay): Support n_head argument, check how d_model is being used (why not in encoder config dict?)
        model = CodeMoCo(
            sp.GetPieceSize(),
            pad_id=pad_id,
            d_model=config["d_model"],
            encoder_config=dict(
                encoder_type=config["encoder_type"],
                lstm_project_mode=config["lstm_project_mode"],
                n_encoder_layers=config["n_encoder_layers"],
            ),
        )
        logger.info(
            f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            encoder_type=config["encoder_type"],
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
        )
        logger.info(
            f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
            use_horovod=True,
        )
        logger.info(
            f"Created CodeContrastiveMLM model with {count_parameters(model)} params"
        )
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    # config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
    # config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # define optimizer
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not config["use_adasum"] else 1
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if config["use_adasum"] and hvd.nccl_built():
        lr_scaler = hvd.local_size()
    # Horovod: scale learning rate by lr_scaler.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"] * lr_scaler,
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"],
                                            config["num_steps"])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
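    # Broadcasting from rank 0 ensures every worker starts from identical
    # parameters and optimizer state; when resuming, each rank also loads the
    # same checkpoint file below.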

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if config[
        "fp16_allreduce"] else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Adasum if config["use_adasum"] else hvd.Average,
        gradient_predivide_factor=config["gradient_predivide_factor"],
    )
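    # DistributedOptimizer hooks into backward/step so gradients are
    # allreduced across workers (averaged, or combined with Adasum) before
    # each parameter update.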

    # Load checkpoint
    if config["resume_path"]:
        logger.info(f"Loading parameters from {config['resume_path']}")
        # remap tensors saved on cuda:0 onto this worker's GPU
        map_location = {"cuda:%d" % 0: "cuda:%d" % hvd.rank()}
        checkpoint = torch.load(config["resume_path"],
                                map_location=map_location)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        start_global_step = checkpoint["global_step"]
    else:
        start_epoch = 1
        start_global_step = 0

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"],
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate_contrastive
        if config["program_mode"] == "contrastive" else pad_collate,
        drop_last=True,
        sampler=train_sampler,
        **kwargs,
    )
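    # shuffle must stay False when a sampler is passed; DistributedSampler
    # shuffles each rank's shard itself, reseeded every epoch via
    # train_sampler.set_epoch(epoch) below.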

    # Train
    global_step = 0
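    # When resuming, replay the scheduler up to the checkpointed step so the
    # learning rate continues from where pretraining left off.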
    while global_step < start_global_step:
        sched.step()
        global_step += 1
    for epoch in tqdm.trange(start_epoch,
                             config["num_epochs"] + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model,
                                              batch,
                                              use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # randomly mask tokens, sampling replacement ids from the
                # range [8, 7999]
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp,
                    model,
                    batch,
                    mask_id=mask_id,
                    pad_id=pad_id,
                    vocab_start_idx=0,
                    vocab_end_idx=7999,
                    use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()

            global_step += 1
            pbar.set_description(
                f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}"
            )

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]),
                          step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config[
                        "save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(
                        config["run_dir"],
                        f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                    )
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")