def embed_augmented(
    # Data
    data_filepath: str,
    output_dir: str,
    spm_filepath: str,
    num_workers=1,
    max_seq_len=-1,
    min_alternatives=2,
    # Model
    encoder_type: str = "lstm",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    pretrain_resume_project: bool = False,
    # no_output_attention: bool = False,
    n_encoder_layers: int = 2,
    d_model: int = 512,
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    config = locals()
    logger.info(f"Config: {config}")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")
    mask_id = sp.PieceToId("[MASK]")

    # Create model
    if encoder_type == "lstm":
        encoder = CodeEncoderLSTM(
            n_tokens=sp.GetPieceSize(),
            d_model=d_model,
            d_rep=256,
            n_encoder_layers=n_encoder_layers,
            dropout=0.1,
            pad_id=pad_id,
            project=False,
        )
        encoder.config["project"] = "hidden"
        logger.info(f"Created CodeEncoderLSTM with {count_parameters(encoder)} params")
    elif encoder_type == "transformer":
        encoder = CodeEncoder(sp.GetPieceSize(), d_model, 256, 8, n_encoder_layers, 2048, 0.1, "relu", True, pad_id, project=False)
        logger.info(f"Created CodeEncoder with {count_parameters(encoder)} params")

    # Load pretrained checkpoint
    if pretrain_resume_path:
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, "
            f"pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        checkpoint = torch.load(pretrain_resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        for key in pretrained_state_dict.keys():
            print("Pretrained state dict:", key)
        for key in encoder.state_dict().keys():
            print("Encoder state dict:", key)

        encoder_state_dict = {}
        assert pretrain_resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        for key, value in pretrained_state_dict.items():
            if key.startswith(pretrain_resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(pretrain_resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        encoder.load_state_dict(encoder_state_dict)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Parallelize across GPUs
    encoder = nn.DataParallel(encoder)
    encoder = encoder.cuda() if use_cuda else encoder

    # Load batches consisting of augmented variants of the same program
    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")

    def pad_collate(batch):
        assert len(batch) == 1
        X = batch[0]
        B = len(X)

        # Create tensor of sequence lengths, [B] or [2B]
        lengths = torch.tensor([len(x) for x in X], dtype=torch.long)

        # Create padded tensor for batch, [B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        return X, lengths

    dataset = PrecomputedDataset(
        data_filepath,
        min_alternatives=min_alternatives,
        program_mode="all_alternatives",
        limit_size=-1,
        sp=sp,
        subword_regularization_alpha=subword_regularization_alpha,
        max_length=max_seq_len,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=pad_collate,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=False,
    )

    representations = []
    encoder.eval()
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        # Evaluate metrics
        logger.info("Evaluating encoder...")
        pbar = tqdm.tqdm(loader, desc="evaluate")
        for X, lengths in pbar:
            rep = encoder(X.cuda(), lengths.cuda(), None)  # [B, n_layers*n_directions*d_model]
            if encoder_type == "transformer":
                assert len(rep.shape) == 3
                rep = rep.mean(dim=0)  # rep is [T, B, dimension], so take mean across sequence
            rep = rep.cpu().numpy()
            X = X.cpu().numpy()
            print("rep", type(rep), "X", type(X))
            print("rep", rep.shape, "X", X.shape)
            representations.append((X, rep))

            if len(representations) and len(representations) % 100 == 0:
                path = os.path.join(output_dir, f"tokens_and_embeddings_{len(representations):06d}.pth")
                logger.info(f"Saving representations to {path}")
                # with open(path, "wb") as f:
                #     pickle.dump(representations, f)
                # torch.save(path, representations)
                torch.save(representations, path)
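
# --- Usage sketch (assumption, not part of the original script) ---
# embed_augmented periodically writes the cumulative list of (token_ids, embedding) pairs with
# torch.save, so the newest tokens_and_embeddings_*.pth shard is a superset of the earlier ones.
# The helper below is a hypothetical example of reloading the latest shard and stacking the
# program embeddings for downstream analysis; it is not defined anywhere in this repository.
def load_augmented_embeddings(output_dir: str):
    """Load the most recent tokens_and_embeddings_*.pth shard written by embed_augmented."""
    import glob

    shards = sorted(glob.glob(os.path.join(output_dir, "tokens_and_embeddings_*.pth")))
    if not shards:
        raise FileNotFoundError(f"No embedding shards found in {output_dir}")
    pairs = torch.load(shards[-1])  # list of (token ids [B, T], embeddings [B, D]) numpy arrays
    tokens = [t for t, _ in pairs]  # keep as a list; T varies across batches
    reps = np.concatenate([r for _, r in pairs], axis=0)  # [N, D]
    return tokens, reps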

def pretrain_worker(gpu, ngpus_per_node, config):
    chief_node = gpu == 0
    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"], config=config, job_type="training", project=project, entity="ml4code")

    if gpu is not None:
        logger.info("Use GPU: {} for training".format(gpu))

    if config["dist_url"] == "env://" and config["rank"] == -1:
        config["rank"] = int(os.environ["RANK"])
    # For multiprocessing distributed training, rank needs to be the
    # global rank among all the processes
    config["rank"] = config["rank"] * ngpus_per_node + gpu
    dist.init_process_group(
        backend=config["dist_backend"], init_method=config["dist_url"], world_size=config["world_size"], rank=config["rank"]
    )

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    mask_id = sp.PieceToId("[MASK]")

    def pad_collate(batch):
        B = len(batch)
        if config["program_mode"] == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            X = batch

        # Create tensor of sequence lengths, [B] or [2B]
        lengths = torch.tensor([len(x) for x in X], dtype=torch.long)

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if config["program_mode"] == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            lengths = torch.reshape(lengths, (2, B)).transpose(0, 1)
            assert lengths.shape == (B, 2)

        return X, lengths, None

    # Create model
    if config["loss_mode"] == "infonce":
        model = CodeMoCo(
            sp.GetPieceSize(),
            pad_id=pad_id,
            d_model=config["d_model"],
            encoder_config=dict(
                encoder_type=config["encoder_type"],
                lstm_project_mode=config["lstm_project_mode"],
                n_encoder_layers=config["n_encoder_layers"],
            ),
        )
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(sp.GetPieceSize(), pad_id=pad_id, encoder_type=config["encoder_type"], n_encoder_layers=config["n_encoder_layers"])
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeContrastiveMLM model with {count_parameters(model)} params")
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.cuda(gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
        config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    else:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["lr"], betas=config["adam_betas"], eps=1e-6, weight_decay=config["weight_decay"]
    )
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"], config["num_steps"])

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"],
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate,
        # num_workers=config["num_workers"],
        num_workers=0,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler,
    )

    # Train
    global_step = 0
    for epoch in tqdm.trange(1, config["num_epochs"] + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model, batch, use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # replace tokens randomly with tokens from _ (8)
                train_metrics = training_step_mlm(
                    sp, model, batch, pad_id=pad_id, mask_id=mask_id, vocab_start_idx=8, vocab_end_idx=7999, use_cuda=config["use_cuda"]
                )
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp, model, batch, mask_id=mask_id, pad_id=pad_id, vocab_start_idx=0, vocab_end_idx=7999, use_cuda=config["use_cuda"]
                )
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()
            global_step += 1
            pbar.set_description(f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}")

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config["save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(config["run_dir"], f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth")
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
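
# --- Sketch of the MLM corruption that training_step_mlm likely performs (assumption) ---
# pretrain_worker above calls training_step_mlm(...) with pad_id, mask_id and a
# [vocab_start_idx, vocab_end_idx] range, but that helper is defined elsewhere. The function
# below is a hypothetical illustration of the standard BERT-style masking those arguments
# suggest: select ~15% of non-pad positions as prediction targets, replace 80% of them with
# [MASK], 10% with a random in-range token, and keep 10% unchanged. The real helper may differ.
def mask_mlm_sketch(ids, pad_id, mask_id, vocab_start_idx, vocab_end_idx, mask_prob=0.15):
    ids = ids.clone()  # [B, T] token ids
    labels = ids.clone()

    # Sample prediction targets among non-pad tokens.
    target = (ids != pad_id) & (torch.rand_like(ids, dtype=torch.float) < mask_prob)
    labels[~target] = -100  # ignored by F.cross_entropy(..., ignore_index=-100)

    # 80% of targets -> [MASK]
    replace_mask = target & (torch.rand_like(ids, dtype=torch.float) < 0.8)
    ids[replace_mask] = mask_id

    # 10% of targets -> random token drawn from [vocab_start_idx, vocab_end_idx]
    random_mask = target & ~replace_mask & (torch.rand_like(ids, dtype=torch.float) < 0.5)
    random_tokens = torch.randint_like(ids, vocab_start_idx, vocab_end_idx + 1)
    ids[random_mask] = random_tokens[random_mask]

    # Remaining 10% of targets keep their original token.
    return ids, labels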

def pretrain(
    run_name: str,
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_sequence_length=1024,
    augment_window_crop_size=6,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_betas=(0.9, 0.98),
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    run_name = str(run_name)  # support numerical run ids
    slurm_job_id, slurm_job_hostname = (
        os.environ.get("SLURM_JOB_ID"),
        os.environ.get("SLURM_JOB_NODELIST"),
    )
    config = locals()
    logger.info("Training configuration: {}".format(config))
    logger.info(
        "CUDA_VISIBLE_DEVICES = '{}', CUDA_DEVICE_ORDER = '{}'".format(
            os.environ.get("CUDA_VISIBLE_DEVICES"), os.environ.get("CUDA_DEVICE_ORDER")
        )
    )

    assert not use_cuda or torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"
    assert loss_mode in ["infonce", "mlm"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    wandb.init(
        name=run_name,
        config=config,
        job_type="training",
        project="moco-pretrain",
        entity="ml4code",
    )

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle")

    def pad_collate(batch):
        B = len(batch)
        if program_mode == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            raise NotImplementedError()

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if program_mode == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)

        return (X, None)

    train_dataset = PrecomputedDataset(
        train_filepath,
        min_alternatives=2,
        program_mode="contrastive",
        limit_size=limit_dataset_size,
        sp=sp,
        subword_regularization_alpha=subword_regularization_alpha,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=pad_collate,
        num_workers=num_workers,
        drop_last=True,
    )

    # Create model
    if loss_mode == "infonce":
        model = CodeMoCo(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif loss_mode == "mlm":
        model = CodeMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    params = model.parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, eps=1e-9)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if loss_mode == "infonce":
                imgs, _ = batch
                if use_cuda:
                    imgs = imgs.cuda()
                imgs_k, imgs_q = imgs[:, 0, :], imgs[:, 1, :]
                output, target = model(imgs_q, imgs_k)
                loss = F.cross_entropy(output, target)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                train_metrics = {
                    "loss": loss,
                    "log": {
                        "pretrain/loss": loss.item(),
                        "pretrain/acc@1": acc1[0].item(),
                        "pretrain/acc@5": acc5[0].item(),
                        "pretrain/queue_ptr": model.module.queue_ptr.item(),
                    },
                }
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()

            # Log loss
            global_step += 1
            wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

            # Save checkpoint
            if save_every and global_step % save_every == 0:
                checkpoint = {
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step,
                    "config": config,
                }
                model_file = run_dir / f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth"
                logger.info(f"Saving checkpoint to {model_file}...")
                torch.save(checkpoint, str(model_file.resolve()))
                # wandb.save(model_file)
                logger.info("Done.")
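
# --- Sketch of the InfoNCE training step used by the distributed loops (assumption) ---
# pretrain_worker and the Horovod pretrain() below delegate the contrastive update to
# training_step(model, batch, use_cuda=...), which is not defined in this file. The sketch
# below mirrors the inline InfoNCE step of the pretrain() function directly above, assuming
# the (X [B, 2, T], lengths [B, 2], None) batch layout produced by pad_collate; the real
# helper's signature and logged metrics may differ.
def training_step_sketch(model, batch, use_cuda=True):
    imgs, _lengths, _ = batch
    if use_cuda:
        imgs = imgs.cuda(non_blocking=True)
    imgs_k, imgs_q = imgs[:, 0, :], imgs[:, 1, :]  # two augmented views of each program
    output, target = model(imgs_q, imgs_k)  # MoCo logits against the key queue, and targets
    loss = F.cross_entropy(output, target)
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    return {
        "loss": loss,
        "log": {
            "pretrain/loss": loss.item(),
            "pretrain/acc@1": acc1[0].item(),
            "pretrain/acc@5": acc5[0].item(),
        },
    }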

def pretrain(
    run_name: str,
    #
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_length=1024,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",  # infonce, mlm, or hybrid
    min_alternatives=1,
    #
    # Model
    resume_path: str = "",
    encoder_type: str = "transformer",
    lstm_project_mode: str = "hidden",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    n_head: int = 8,
    #
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    weight_decay: float = 0,
    adam_betas=(0.9, 0.98),
    warmup_steps: int = 5000,
    num_steps: int = 600000,
    #
    # Horovod
    use_adasum: bool = False,
    fp16_allreduce: bool = False,
    gradient_predivide_factor: float = 1.0,
    #
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    hvd.init()
    logger.info("L: {} {}", n_encoder_layers, type(n_encoder_layers))
    logger.info("H: {} {}", d_model, type(d_model))
    logger.info("A: {} {}", n_head, type(n_head))

    run_name = str(run_name)  # support numerical run ids
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    slurm_job_hostname = os.environ.get("SLURM_JOB_NODELIST")
    config = locals()
    logger.info(f"Config = \n{config}")
    logger.info("Training configuration: {}".format(config))
    logger.info(f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    logger.info(f"CUDA_DEVICE_ORDER = '{os.environ.get('CUDA_DEVICE_ORDER')}'")

    assert program_mode in ["contrastive", "identity", "augmentation"]
    assert loss_mode == "infonce" or loss_mode == "mlm" or loss_mode == "hybrid"
    assert not (program_mode == "contrastive" and loss_mode == "mlm")
    assert not (program_mode != "contrastive" and (loss_mode == "hybrid" or loss_mode == "infonce"))
    assert not use_cuda or torch.cuda.is_available()

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    config["run_dir"] = str(run_dir.resolve())
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle") or train_filepath.endswith(".gz")

    # Setup distributed
    gpu = hvd.local_rank()
    ngpus_per_node = 1
    chief_node = gpu == 0
    assert gpu is not None

    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"], config=config, job_type="training", project=project, entity="ml4code")

    logger.info("Use GPU: {} for training".format(gpu))
    torch.cuda.set_device(gpu)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (
        kwargs.get("num_workers", 0) > 0
        and hasattr(mp, "_supports_context")
        and mp._supports_context
        and "forkserver" in mp.get_all_start_methods()
    ):
        kwargs["multiprocessing_context"] = "forkserver"

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    logger.info("pad_id {}", pad_id)
    assert pad_id == 0  # hard coded in pad_collate
    mask_id = sp.PieceToId("[MASK]")

    # Create model
    if config["loss_mode"] == "infonce":
        # TODO(ajay): Support n_head argument, check how d_model is being used (why not in encoder config dict?)
        model = CodeMoCo(
            sp.GetPieceSize(),
            pad_id=pad_id,
            d_model=config["d_model"],
            encoder_config=dict(
                encoder_type=config["encoder_type"],
                lstm_project_mode=config["lstm_project_mode"],
                n_encoder_layers=config["n_encoder_layers"],
            ),
        )
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            encoder_type=config["encoder_type"],
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
        )
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
            use_horovod=True,
        )
        logger.info(f"Created CodeContrastiveMLM model with {count_parameters(model)} params")
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    # config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
    # config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # define optimizer
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not config["use_adasum"] else 1
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if config["use_adasum"] and hvd.nccl_built():
        lr_scaler = hvd.local_size()
    # Horovod: scale learning rate by lr_scaler.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["lr"] * lr_scaler, betas=config["adam_betas"], eps=1e-6, weight_decay=config["weight_decay"]
    )
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"], config["num_steps"])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if config["fp16_allreduce"] else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Adasum if config["use_adasum"] else hvd.Average,
        gradient_predivide_factor=config["gradient_predivide_factor"],
    )

    # Load checkpoint
    if config["resume_path"]:
        logger.info(f"Loading parameters from {config['resume_path']}")
        # configure map_location properly
        map_location = {"cuda:%d" % 0: "cuda:%d" % hvd.rank()}
        checkpoint = torch.load(config["resume_path"], map_location=map_location)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        start_global_step = checkpoint["global_step"]
    else:
        start_epoch = 1
        start_global_step = 0

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"],
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate_contrastive if config["program_mode"] == "contrastive" else pad_collate,
        drop_last=True,
        sampler=train_sampler,
        **kwargs,
    )

    # Train
    global_step = 0
    while global_step < start_global_step:
        sched.step()
        global_step += 1

    for epoch in tqdm.trange(start_epoch, config["num_epochs"] + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model, batch, use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # replace tokens randomly with tokens from _ (8)
                train_metrics = training_step_mlm(
                    sp, model, batch, pad_id=pad_id, mask_id=mask_id, vocab_start_idx=8, vocab_end_idx=7999, use_cuda=config["use_cuda"]
                )
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(
                    sp, model, batch, mask_id=mask_id, pad_id=pad_id, vocab_start_idx=0, vocab_end_idx=7999, use_cuda=config["use_cuda"]
                )
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()
            global_step += 1
            pbar.set_description(f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}")

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config["save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(config["run_dir"], f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth")
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
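
# --- Sketch of the module-level contrastive collate function referenced above (assumption) ---
# The Horovod pretrain() selects pad_collate_contrastive (or pad_collate) as the DataLoader
# collate_fn, but neither module-level function appears in this file; the assert
# `pad_id == 0  # hard coded in pad_collate` indicates they hard-code the padding value.
# The sketch below reproduces the logic of the nested pad_collate defined inside
# pretrain_worker, with padding_value fixed to 0; the real definition may differ.
def pad_collate_contrastive_sketch(batch):
    B = len(batch)
    X1, X2 = zip(*batch)  # two augmented variants per program
    X = X1 + X2

    lengths = torch.tensor([len(x) for x in X], dtype=torch.long)  # [2B]
    X = pad_sequence(X, batch_first=True, padding_value=0)  # [2B, T]

    # Reshape to [B, 2, T] so each row holds a (key, query) pair of views.
    T = X.size(-1)
    X = torch.reshape(X, (2, B, -1)).transpose(0, 1)
    assert X.shape == (B, 2, T)
    lengths = torch.reshape(lengths, (2, B)).transpose(0, 1)  # [B, 2]
    return X, lengths, None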