Example #1
def embed_coco(checkpoint,
               data_path,
               spm_filepath=DEFAULT_SPM_UNIGRAM_FILEPATH,
               encoder_type="transformer",
               n_encoder_layers=6):
    # read data
    df = pd.read_pickle(data_path)

    # read tokenizer
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    n_tokens = sp.GetPieceSize()
    pad_id = sp.PieceToId("[PAD]")

    # load model
    model = CodeMoCo(
        n_tokens=n_tokens,
        pad_id=pad_id,
        encoder_config=dict(encoder_type=encoder_type,
                            n_encoder_layers=n_encoder_layers,
                            project="hidden"),
    )
    state = torch.load(checkpoint)
    print(state["model_state_dict"].keys())
    model.load_state_dict(state["model_state_dict"])
    model.cuda()
    model.eval()

    out_rows = []
    with torch.no_grad():
        for row_idx in tqdm.tqdm(list(range(len(df))), desc="Table"):
            text = df.loc[row_idx]["code"]
            func_name = df.loc[row_idx]["func_name"]
            x_encoded = torch.LongTensor(sp.EncodeAsIds(text)).cuda()
            lens = torch.LongTensor([len(x_encoded)])
            try:
                embed_x = model.embed_x(x_encoded.unsqueeze(0),
                                        lens).cpu().numpy()
                out_rows.append(
                    dict(code=text, func_name=func_name, embedding=embed_x))
            except Exception as e:
                print("Error!", e)
                continue

    tsne_out_path = DATA_DIR / "tsne"
    tsne_out_path.mkdir(parents=True, exist_ok=True)
    print("writing output to ", tsne_out_path.resolve())
    with (tsne_out_path /
          "tsne_out_embedded_grouped_hidden.pickle").open("wb") as f:
        pickle.dump(out_rows, f)
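
A small follow-on sketch (not part of the example above): loading the pickle written by this embed_coco variant and projecting the embeddings with t-SNE. The relative path and the use of scikit-learn are assumptions for illustration.

# Sketch only: downstream use of the pickle written above (path and sklearn assumed).
import pickle
import numpy as np
from sklearn.manifold import TSNE

with open("tsne/tsne_out_embedded_grouped_hidden.pickle", "rb") as f:
    rows = pickle.load(f)  # list of dicts with keys: code, func_name, embedding

X = np.vstack([r["embedding"].reshape(1, -1) for r in rows])  # [N, D]
points = TSNE(n_components=2).fit_transform(X)                # [N, 2]
print(points.shape)
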
Example #2
def embed_coco(checkpoint, data_path, spm_filepath=DEFAULT_SPM_UNIGRAM_FILEPATH):
    with open(data_path, 'rb') as f:
        matches = pickle.load(f)
    negatives = matches['negatives']
    positive_samples = {k: v for k, v in matches.items() if k != "negatives"}

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    n_tokens = sp.GetPieceSize()
    pad_id = sp.PieceToId("[PAD]")

    model = CodeMoCo(n_tokens=n_tokens, pad_id=pad_id)
    model.load_state_dict(torch.load(checkpoint))
    model.cuda()
    model.eval()

    def make_dataset(items):
        # Tokenize each code string into a tensor of subword ids on the GPU
        return [torch.LongTensor(sp.EncodeAsIds(item)).cuda() for item in items]

    out_matches = {}
    out_negatives = []
    with torch.no_grad():
        for negative in tqdm.tqdm(make_dataset(negatives), desc='negatives'):
            x = negative.unsqueeze(0)
            out_negatives.append(model.embed(x).cpu().numpy())

        for match in positive_samples.keys():
            out_matches[match] = list()
            for positive in make_dataset(positive_samples[match]):  # embed each positive sample for this key
                x = positive.unsqueeze(0)
                out_matches[match].append(model.embed(x).cpu().numpy())
    tsne_out_path = (RUN_DIR / 'tsne')
    tsne_out_path.mkdir(parents=True, exist_ok=True)
    with (tsne_out_path / "moco_embed.pickle").open('wb') as f:
        pickle.dump((out_matches, out_negatives), f)
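
As a hedged illustration of how the (out_matches, out_negatives) pickle produced above could be consumed, the following compares each positive embedding against the negatives with cosine similarity; numpy-only, path assumed.

# Sketch only: cosine similarity of positives vs. negatives from the pickle above.
import pickle
import numpy as np

with open("tsne/moco_embed.pickle", "rb") as f:  # path assumed relative to RUN_DIR
    out_matches, out_negatives = pickle.load(f)

negatives = np.vstack([e.reshape(1, -1) for e in out_negatives])  # [N_neg, D]
negatives /= np.linalg.norm(negatives, axis=1, keepdims=True)

for name, embeds in out_matches.items():
    for e in embeds:
        q = e.reshape(-1)
        q /= np.linalg.norm(q)
        print(name, float((negatives @ q).max()))  # best-matching negative
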
Example #3
def pretrain_worker(gpu, ngpus_per_node, config):
    chief_node = gpu == 0
    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"],
                   config=config,
                   job_type="training",
                   project=project,
                   entity="ml4code")

    if gpu is not None:
        logger.info("Use GPU: {} for training".format(gpu))

    if config["dist_url"] == "env://" and config["rank"] == -1:
        config["rank"] = int(os.environ["RANK"])
    # For multiprocessing distributed training, rank needs to be the
    # global rank among all the processes
    config["rank"] = config["rank"] * ngpus_per_node + gpu
    dist.init_process_group(backend=config["dist_backend"],
                            init_method=config["dist_url"],
                            world_size=config["world_size"],
                            rank=config["rank"])

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    mask_id = sp.PieceToId("[MASK]")

    def pad_collate(batch):
        B = len(batch)
        if config["program_mode"] == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            X = batch

        # Create tensor of sequence lengths, [B] or [2B]
        lengths = torch.tensor([len(x) for x in X], dtype=torch.long)

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if config["program_mode"] == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            lengths = torch.reshape(lengths, (2, B)).transpose(0, 1)
            assert lengths.shape == (B, 2)
        return X, lengths, None

    # Create model
    if config["loss_mode"] == "infonce":
        model = CodeMoCo(sp.GetPieceSize(),
                         pad_id=pad_id,
                         d_model=config["d_model"],
                         encoder_config=dict(
                             encoder_type=config["encoder_type"],
                             lstm_project_mode=config["lstm_project_mode"],
                             n_encoder_layers=config["n_encoder_layers"]))
        logger.info(
            f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(sp.GetPieceSize(),
                        pad_id=pad_id,
                        encoder_type=config["encoder_type"],
                        n_encoder_layers=config["n_encoder_layers"])
        logger.info(
            f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(
            f"Created CodeContrastiveMLM model with {count_parameters(model)} params"
        )
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.cuda(gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
        config["num_workers"] = int(
            (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[gpu])
    else:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"],
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"],
                                            config["num_steps"])

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate,
        # num_workers=config["num_workers"],
        num_workers=0,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler,
    )

    # Train
    global_step = 0
    for epoch in tqdm.trange(1,
                             config["num_epochs"] + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model,
                                              batch,
                                              use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # randomly replace tokens with tokens sampled from the vocabulary (ids 8 to 7999)
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp,
                    model,
                    batch,
                    mask_id=mask_id,
                    pad_id=pad_id,
                    vocab_start_idx=0,
                    vocab_end_idx=7999,
                    use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()

            global_step += 1
            pbar.set_description(
                f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}"
            )

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]),
                          step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config[
                        "save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(
                        config["run_dir"],
                        f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                    )
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
Example #4
def pretrain(
    run_name: str,
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_sequence_length=1024,
    augment_window_crop_size=6,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_betas=(0.9, 0.98),
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    run_name = str(run_name)  # support numerical run ids
    slurm_job_id, slurm_job_hostname = (
        os.environ.get("SLURM_JOB_ID"),
        os.environ.get("SLURM_JOB_NODELIST"),
    )
    config = locals()
    logger.info("Training configuration: {}".format(config))
    logger.info(
        "CUDA_VISIBLE_DEVICES = '{}', CUDA_DEVICE_ORDER = '{}'".format(
            os.environ.get("CUDA_VISIBLE_DEVICES"), os.environ.get("CUDA_DEVICE_ORDER")
        )
    )

    assert not use_cuda or torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"
    assert loss_mode in ["infonce", "mlm"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    wandb.init(
        name=run_name, config=config, job_type="training", project="moco-pretrain", entity="ml4code",
    )

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle")

    def pad_collate(batch):
        B = len(batch)
        if program_mode == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            raise NotImplementedError()

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if program_mode == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
        return (X, None)

    train_dataset = PrecomputedDataset(
        train_filepath,
        min_alternatives=2,
        program_mode="contrastive",
        limit_size=limit_dataset_size,
        sp=sp,
        subword_regularization_alpha=subword_regularization_alpha,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate, num_workers=num_workers, drop_last=True,
    )

    # Create model
    if loss_mode == "infonce":
        model = CodeMoCo(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif loss_mode == "mlm":
        model = CodeMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=1e-9)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if loss_mode == "infonce":
                imgs, _ = batch
                if use_cuda:
                    imgs = imgs.cuda()
                imgs_k, imgs_q = imgs[:, 0, :], imgs[:, 1, :]
                output, target = model(imgs_q, imgs_k)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                train_metrics = {
                    "loss": loss,
                    "log": {
                        "pretrain/loss": loss.item(),
                        "pretrain/acc@1": acc1[0].item(),
                        "pretrain/acc@5": acc5[0].item(),
                        "pretrain/queue_ptr": model.module.queue_ptr.item(),
                    },
                }
                loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()

            # Log loss
            global_step += 1
            wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

            # Save checkpoint
            if save_every and global_step % save_every == 0:
                checkpoint = {
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step,
                    "config": config,
                }
                model_file = run_dir / f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                logger.info(f"Saving checkpoint to {model_file}...")
                torch.save(checkpoint, str(model_file.resolve()))
                # wandb.save(model_file)
                logger.info("Done.")
Example #5
def pretrain(
    run_name: str,
    #
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_length=1024,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",  # infonce, mlm, or hybrid
    min_alternatives=1,
    #
    # Model
    resume_path: str = "",
    encoder_type: str = "transformer",
    lstm_project_mode: str = "hidden",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    n_head: int = 8,
    #
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    weight_decay: float = 0,
    adam_betas=(0.9, 0.98),
    warmup_steps: int = 5000,
    num_steps: int = 600000,
    #
    # Horovod
    use_adasum: bool = False,
    fp16_allreduce: bool = False,
    gradient_predivide_factor: float = 1.0,
    #
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    hvd.init()

    logger.info("L:", n_encoder_layers, type(n_encoder_layers))
    logger.info("H:", d_model, type(d_model))
    logger.info("A:", n_head, type(n_head))
    run_name = str(run_name)  # support numerical run ids
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    slurm_job_hostname = os.environ.get("SLURM_JOB_NODELIST")
    config = locals()
    logger.info(f"Config = \n{config}")
    logger.info("Training configuration: {}".format(config))
    logger.info(
        f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    logger.info(f"CUDA_DEVICE_ORDER = '{os.environ.get('CUDA_DEVICE_ORDER')}'")

    assert program_mode in ["contrastive", "identity", "augmentation"]
    assert loss_mode == "infonce" or loss_mode == "mlm" or loss_mode == "hybrid"
    assert not (program_mode == "contrastive" and loss_mode == "mlm")
    assert not (program_mode != "contrastive" and
                (loss_mode == "hybrid" or loss_mode == "infonce"))
    assert not use_cuda or torch.cuda.is_available()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    config["run_dir"] = str(run_dir.resolve())
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle") or train_filepath.endswith(".gz")

    # Setup distributed
    gpu = hvd.local_rank()
    ngpus_per_node = 1
    chief_node = gpu == 0
    assert gpu is not None

    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"],
                   config=config,
                   job_type="training",
                   project=project,
                   entity="ml4code")

    logger.info("Use GPU: {} for training".format(gpu))
    torch.cuda.set_device(gpu)
    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get("num_workers", 0) > 0 and hasattr(mp, "_supports_context")
            and mp._supports_context
            and "forkserver" in mp.get_all_start_methods()):
        kwargs["multiprocessing_context"] = "forkserver"

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    logger.info("pad_id {}", pad_id)
    assert pad_id == 0  # hard coded in pad_collate
    mask_id = sp.PieceToId("[MASK]")

    # Create model
    if config["loss_mode"] == "infonce":
        # TODO(ajay): Support n_head argument, check how d_model is being used (why not in encoder config dict?)
        model = CodeMoCo(
            sp.GetPieceSize(),
            pad_id=pad_id,
            d_model=config["d_model"],
            encoder_config=dict(
                encoder_type=config["encoder_type"],
                lstm_project_mode=config["lstm_project_mode"],
                n_encoder_layers=config["n_encoder_layers"],
            ),
        )
        logger.info(
            f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            encoder_type=config["encoder_type"],
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
        )
        logger.info(
            f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
            use_horovod=True,
        )
        logger.info(
            f"Created CodeContrastiveMLM model with {count_parameters(model)} params"
        )
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    # config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
    # config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # define optimizer
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not config["use_adasum"] else 1
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if config["use_adasum"] and hvd.nccl_built():
        lr_scaler = hvd.local_size()
    # Horovod: scale learning rate by lr_scaler.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"] * lr_scaler,
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"],
                                            config["num_steps"])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if config[
        "fp16_allreduce"] else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Adasum if config["use_adasum"] else hvd.Average,
        gradient_predivide_factor=config["gradient_predivide_factor"],
    )

    # Load checkpoint
    if config["resume_path"]:
        logger.info(f"Loading parameters from {config['resume_path']}")
        # configure map_location properly
        map_location = {"cuda:%d" % 0: "cuda:%d" % hvd.rank()}
        checkpoint = torch.load(config["resume_path"],
                                map_location=map_location)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        start_global_step = checkpoint["global_step"]
    else:
        start_epoch = 1
        start_global_step = 0

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"],
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate_contrastive
        if config["program_mode"] == "contrastive" else pad_collate,
        drop_last=True,
        sampler=train_sampler,
        **kwargs,
    )

    # Train
    global_step = 0
    while global_step < start_global_step:
        sched.step()
        global_step += 1
    for epoch in tqdm.trange(start_epoch,
                             config["num_epochs"] + 1,
                             desc="training",
                             unit="epoch",
                             leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model,
                                              batch,
                                              use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # randomly replace tokens with tokens sampled from the vocabulary (ids 8 to 7999)
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp,
                    model,
                    batch,
                    mask_id=mask_id,
                    pad_id=pad_id,
                    vocab_start_idx=0,
                    vocab_end_idx=7999,
                    use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()

            global_step += 1
            pbar.set_description(
                f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}"
            )

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]),
                          step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config[
                        "save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(
                        config["run_dir"],
                        f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                    )
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")