def train(
    run_name: str,
    # Data
    train_filepath: str = CSNJS_TRAIN_FILEPATH,
    eval_filepath: str = CSNJS_VALID_FILEPATH,
    spm_filepath: str = SPM_UNIGRAM_FILEPATH,
    program_mode="identity",
    eval_program_mode="identity",
    label_mode="identifier",
    num_workers=1,
    limit_dataset_size=-1,
    # Model
    model_type="transformer",
    n_decoder_layers=4,
    d_model: int = 512,
    resume_path: str = "",
    resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    resume_project: bool = False,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 50,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    use_lr_warmup: bool = True,
    loss_type="nll_token",  # nll_token or nll_sequence
    # Loss
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    auto_test: bool = True,
    seed: int = 0,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="identifier-prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    train_augmentations = [
        {"fn": "sample_lines", "line_length_pct": 0.5},
        {"fn": "insert_var_declaration", "prob": 0.5},
        {"fn": "rename_variable", "prob": 0.5},
    ]
    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_dataset = get_csnjs_dataset(train_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = javascript_dataloader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        augmentations=train_augmentations,
        sp=sp,
        program_mode=program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = get_csnjs_dataset(eval_filepath, label_mode=label_mode, limit_size=limit_dataset_size)
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = javascript_dataloader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        augmentations=[],
        sp=sp,
        program_mode=eval_program_mode,
        subword_regularization_alpha=subword_regularization_alpha,
    )

    # Create model
    if model_type == "transformer":
        model = TransformerModel(n_tokens=sp.GetPieceSize(), pad_id=pad_id, n_decoder_layers=n_decoder_layers, d_model=d_model)
        logger.info(f"Created TransformerModel with {count_parameters(model)} params")
    elif model_type == "lstm":
        model = Seq2SeqLSTM(n_tokens=sp.GetPieceSize(), pad_id=pad_id, d_model=d_model)
        logger.info(f"Created Seq2SeqLSTM with {count_parameters(model)} params")
    else:
        raise ValueError(f"Unsupported model_type {model_type}")

    # Load checkpoint
    if resume_path:
        logger.info(f"Resuming training from checkpoint {resume_path}, resume_encoder_name={resume_encoder_name}")
        checkpoint = torch.load(resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        for key, value in pretrained_state_dict.items():
            if key.startswith(resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(resume_encoder_name + ".") and "project_layer.0." in key and resume_project:
                remapped_key = key[len(resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint project key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict, strict=False)
        logger.info(f"Loaded state dict from {resume_path}")
        logger.info(f"Loaded keys: {encoder_state_dict.keys()}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    params = model.module.decoder.parameters() if train_decoder_only else model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr, betas=(adam_beta1, adam_beta2), eps=1e-9)
    if use_lr_warmup:
        scheduler = get_linear_schedule_with_warmup(optimizer, 5000, len(train_loader) * num_epochs)
    else:
        scheduler = LambdaLR(optimizer, lr_lambda=lambda x: 1.0)

    global_step = 0
    min_eval_loss = float("inf")
    for epoch in tqdm.trange(1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        if train_decoder_only:
            model.module.encoder.eval()
            model.module.decoder.train()
        else:
            model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, Y, X_lengths, Y_lengths in pbar:
            if use_cuda:
                X = X.cuda()
                Y = Y.cuda()
                X_lengths, Y_lengths = X_lengths.cuda(), Y_lengths.cuda()
            optimizer.zero_grad()
            # NOTE: X and Y are [B, max_seq_len] tensors (batch first)
            logits = model(X, Y[:, :-1], X_lengths, Y_lengths)
            if loss_type == "nll_sequence":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id, reduction="sum")
                # Average over num sequences, not target sequence lengths.
                # Thus, minimize bits per sequence.
                loss = loss / X.size(0)
            elif loss_type == "nll_token":
                loss = F.cross_entropy(logits.transpose(1, 2), Y[:, 1:], ignore_index=pad_id)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Log loss
            global_step += 1
            wandb.log(
                {"epoch": epoch, f"label-{label_mode}/train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]},
                step=global_step,
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        max_decode_len = 20 if label_mode == "identifier" else 200
        eval_loss = _evaluate(model, eval_loader, sp, use_cuda=use_cuda, max_decode_len=max_decode_len, loss_type=loss_type)
        logger.info(f"Evaluation loss after epoch {epoch} ({global_step} steps): {eval_loss:.4f}")
        wandb.log({"epoch": epoch, f"label-{label_mode}/eval_loss": eval_loss}, step=global_step)

        # Save checkpoint
        if (save_every and epoch % save_every == 0) or eval_loss < min_eval_loss:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_loss": eval_loss,
            }
            if eval_loss < min_eval_loss:
                logger.info(f"New best evaluation loss: prev {min_eval_loss:.4f} > new {eval_loss:.4f}")
                min_eval_loss = eval_loss
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            wandb.save(str(model_file.resolve()))
            logger.info("Done.")

    if auto_test:
        best_ckpt = run_dir / "ckpt_best.pth"
        test(
            str(best_ckpt.resolve()),
            CSNJS_TEST_FILEPATH,
            spm_filepath,
            program_mode,
            label_mode,
            num_workers,
            -1,
            n_decoder_layers=n_decoder_layers,
        )
            data_type, len(dataset.data)))
    logger.info(
        "data_loader (data_type: {:<5s}, len: {}, batch_size: {}) generated".format(
            data_type, len(data_loader), finetune_config["batch_size"]))

# optimizer and losses
writer = SummaryWriter(save_model_dir)
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=finetune_config["lr"],
                              weight_decay=finetune_config["weight_decay"],
                              amsgrad=True)
optimizer.zero_grad()
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    len(data_loaders["train"]),
    train_config["epochs"] * len(data_loaders["train"]),
    min_percent=0.0001)

best_reg_losses = {"train": INF, "dev": INF, "test": INF}
best_reg_epochs = {"train": -1, "dev": -1, "test": -1}

for epoch in range(finetune_config["epochs"]):
    for data_type, data_loader in data_loaders.items():
        if data_type == "train":
            mean_reg_loss, mean_bp_loss = train(model, optimizer, scheduler,
                                                data_type, data_loader,
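# min_percent is not an argument of the stock get_linear_schedule_with_warmup in
# the transformers library, so the call above presumably targets a project-local
# variant. A hypothetical sketch of such a schedule, built on LambdaLR: linear
# warmup to the base LR, then linear decay clamped at a floor of min_percent.
from torch.optim.lr_scheduler import LambdaLR

def linear_schedule_with_warmup_and_floor(optimizer, num_warmup_steps,
                                          num_training_steps, min_percent=0.0):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        decay = (num_training_steps - step) / max(
            1, num_training_steps - num_warmup_steps)
        return max(min_percent, decay)
    return LambdaLR(optimizer, lr_lambda)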
def train(self):
    epoch_loss_history = list()
    min_validation_loss = sys.float_info.max
    patience_cnt = self.config.patience

    self.config.n_gpu = torch.cuda.device_count()
    t_total = len(self.train_data_loader) * self.config.n_epoch
    cur_step = 0
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in self.model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': self.config.weight_decay
    }, {
        'params': [
            p for n, p in self.model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                       lr=self.config.learning_rate)
    self.scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=self.config.warmup_steps,
        num_training_steps=t_total)

    if self.config.n_gpu > 1:
        self.model = torch.nn.DataParallel(self.model).to(self.config.device)
    else:
        self.model = self.model.to(self.config.device)

    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        batch_loss_history = list()
        self.model.train()
        epoch_lm_loss = 0.0
        epoch_conv_loss = 0.0
        epoch_batch_loss = 0.0

        for batch_i, (input_utterances, input_utterances_mask,
                      target_utterance, target_utterance_mask, input_user_ids,
                      target_user_ids) in enumerate(
                          tqdm(self.train_data_loader, ncols=80)):
            input_utterances = torch.LongTensor(input_utterances).to(self.config.device)
            input_utterances_mask = torch.LongTensor(input_utterances_mask).to(self.config.device)
            target_utterance = torch.LongTensor(target_utterance).to(self.config.device)
            target_utterance_mask = torch.LongTensor(target_utterance_mask).to(self.config.device)

            user_available = input_user_ids[0] is not None
            if user_available:
                input_user_ids = torch.LongTensor(input_user_ids).to(self.config.device)
                target_user_ids = torch.LongTensor(target_user_ids).to(self.config.device)

            self.optimizer.zero_grad()
            self.model.zero_grad()

            loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_id)
            target, gt_target = target_utterance[..., :-1].contiguous(), target_utterance[..., 1:].contiguous()
            target_mask = target_utterance_mask[..., :-1].contiguous()
            if user_available:
                target_user_ids = target_user_ids[..., :-1].contiguous()
            else:
                input_user_ids = None
                target_user_ids = None

            lm_output, conv_output = self.model(target, target_mask,
                                                input_utterances,
                                                input_utterances_mask,
                                                target_user_ids, input_user_ids)

            # 1. Calculate Language Model Loss
            outputs, labels = lm_output[..., :-1, :].contiguous(), input_utterances[..., 1:].contiguous()
            lm_loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))

            # 2. Calculate Conv Loss
            conv_loss = loss_fn(conv_output.view(-1, conv_output.size(-1)), gt_target.view(-1))

            # 3. Total Loss
            batch_loss = lm_loss * 0.2 + conv_loss
            assert not isnan(batch_loss.item())

            if self.config.n_gpu > 1:
                batch_loss = batch_loss.mean()

            epoch_lm_loss = (batch_i * epoch_lm_loss + lm_loss.item()) / (batch_i + 1)
            epoch_conv_loss = (batch_i * epoch_conv_loss + conv_loss.item()) / (batch_i + 1)
            epoch_batch_loss = (batch_i * epoch_batch_loss + batch_loss.item()) / (batch_i + 1)

            if batch_i % self.config.print_every == 0:
                tqdm.write(f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item():.3f}')
                self.writer.add_scalar('Train/lm_loss', lm_loss.item(), cur_step)
                self.writer.add_scalar('Train/conv_loss', conv_loss.item(), cur_step)
                self.writer.add_scalar('Train/loss', batch_loss.item(), cur_step)
                self.writer.add_scalar('Train/learning_rate',
                                       self.scheduler.get_lr()[0], cur_step)

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)
            self.optimizer.step()
            self.scheduler.step()
            cur_step += 1

        epoch_loss_history.append(epoch_batch_loss)
        self.epoch_loss = epoch_batch_loss
        print(f'Epoch {epoch_i+1} loss average: {epoch_batch_loss:.3f}')

        print('\n<Validation>...')
        val_loss, val_lm_loss, val_conv_loss = self.evaluate()
        self.validation_loss = val_loss
        if epoch_i % self.config.plot_every_epoch == 0:
            self.writer.add_scalar('Val/lm_loss', val_lm_loss, epoch_i + 1)
            self.writer.add_scalar('Val/conv_loss', val_conv_loss, epoch_i + 1)
            self.writer.add_scalar('Val/loss', val_loss, epoch_i + 1)
        self.save_model(epoch_i)

        if min_validation_loss > self.validation_loss:
            min_validation_loss = self.validation_loss
        else:
            patience_cnt -= 1
            if patience_cnt < 0:
                print(f'\nEarly stop at {epoch_i}')
                return epoch_loss_history

    self.save_model(self.config.n_epoch)
    return epoch_loss_history
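# The epoch_* bookkeeping above relies on the incremental-mean identity
# mean_{i+1} = (i * mean_i + x_{i+1}) / (i + 1); a quick check that it matches
# the plain average (a sketch with made-up loss values):
losses = [0.9, 0.7, 0.65, 0.6]
running = 0.0
for batch_i, loss in enumerate(losses):
    running = (batch_i * running + loss) / (batch_i + 1)
assert abs(running - sum(losses) / len(losses)) < 1e-9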
def train(
    run_name: str,
    # Data
    train_filepath: str,
    eval_filepath: str,
    type_vocab_filepath: str,
    spm_filepath: str,
    num_workers=1,
    max_seq_len=1024,
    max_eval_seq_len=1024,
    run_dir=RUN_DIR,
    # Model
    resume_path: str = "",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    pretrain_resume_project: bool = False,
    no_output_attention: bool = False,
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    # Output layer hparams
    d_out_projection: int = 512,
    n_hidden_output: int = 1,
    # Optimization
    num_epochs: int = 100,
    save_every: int = 2,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    adam_eps: float = 1e-6,
    weight_decay: float = 0,
    warmup_steps: int = 5000,
    num_steps: int = 200000,
    # Augmentations
    subword_regularization_alpha: float = 0,
    sample_lines_prob: float = 0,
    sample_lines_prob_keep_line: float = 0.9,
    # Loss
    ignore_any_loss: bool = False,
    # Computational
    use_cuda: bool = True,
    seed: int = 1,
):
    """Train model"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if run_dir != RUN_DIR:
        run_dir = Path(run_dir)
    run_dir = run_dir / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="type_prediction", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")
    id_to_target, target_to_id = load_type_vocab(type_vocab_filepath)
    no_type_id = target_to_id["O"]
    assert no_type_id == 0  # Just a sanity check, since O is the first line in the vocab file
    any_id = target_to_id["$any$"]
    collate_fn = get_collate_fn(pad_id, no_type_id)

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    if sample_lines_prob > 0:
        augmentations = [
            {"fn": "sample_lines", "options": {"prob": sample_lines_prob, "prob_keep_line": sample_lines_prob_keep_line}},
        ]
        program_mode = "augmentation"
    else:
        augmentations = None
        program_mode = "identity"
    train_dataset = DeepTyperDataset(
        train_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_seq_len,
        subword_regularization_alpha=subword_regularization_alpha,
        augmentations=augmentations,
        program_mode=program_mode,
    )
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True, collate_fn=collate_fn
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_dataset = DeepTyperDataset(
        eval_filepath,
        type_vocab_filepath,
        spm_filepath,
        max_length=max_eval_seq_len,
        subword_regularization_alpha=0,
        split_source_targets_by_tab=eval_filepath.endswith(".json"),
    )
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn
    )

    # Create model
    model = TypeTransformer(
        n_tokens=sp.GetPieceSize(),
        n_output_tokens=len(id_to_target),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        d_out_projection=d_out_projection,
        n_hidden_output=n_hidden_output,
    )
    logger.info(f"Created TypeTransformer {encoder_type} with {count_parameters(model)} params")

    # Load pretrained checkpoint
    if pretrain_resume_path:
        assert not resume_path
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, "
            f"pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        checkpoint = torch.load(pretrain_resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        output_state_dict = {}
        assert pretrain_resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        for key, value in pretrained_state_dict.items():
            if key.startswith(pretrain_resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(pretrain_resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
            if key.startswith(pretrain_resume_encoder_name + ".") and "project_layer.0." in key and pretrain_resume_project:
                remapped_key = key[len(pretrain_resume_encoder_name + ".project_layer.") :]
                logger.debug(f"Remapping checkpoint project key {key} to output key {remapped_key}. Value mean: {value.mean().item()}")
                output_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict)
        # TODO: check for head key rather than output for MLM
        model.output.load_state_dict(output_state_dict, strict=False)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, betas=(adam_beta1, adam_beta2), eps=adam_eps, weight_decay=weight_decay
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_steps)

    epoch = 0
    global_step = 0
    min_eval_metric = float("inf")
    if resume_path:
        assert not pretrain_resume_path
        logger.info(f"Resuming training from checkpoint {resume_path}")
        checkpoint = torch.load(resume_path)
        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]
        min_eval_metric = checkpoint["min_eval_metric"]

    # Eval metric history
    max_eval_metrics = {}

    # Evaluate initial metrics
    logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
    eval_metric, eval_metrics = _evaluate(
        model, eval_loader, sp, target_to_id=target_to_id, use_cuda=use_cuda, no_output_attention=no_output_attention
    )
    for metric, value in eval_metrics.items():
        logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
        max_eval_metrics[metric] = value
    eval_metrics["epoch"] = epoch
    wandb.log(eval_metrics, step=global_step)
    wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

    for epoch in tqdm.trange(epoch + 1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, lengths, output_attn, labels in pbar:
            if use_cuda:
                X, lengths, output_attn, labels = X.cuda(), lengths.cuda(), output_attn.cuda(), labels.cuda()
            optimizer.zero_grad()
            if no_output_attention:
                logits = model(X, lengths, None)  # BxLxVocab
            else:
                logits = model(X, lengths, output_attn)  # BxLxVocab
            if ignore_any_loss:
                # Don't train with $any$ type
                labels_ignore_any = labels.clone()
                labels_ignore_any[labels_ignore_any == any_id] = no_type_id
                loss = F.cross_entropy(logits.transpose(1, 2), labels_ignore_any, ignore_index=no_type_id)
            else:
                loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index=no_type_id)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Compute accuracy in training batch
            (corr1_any, corr5_any), num_labels_any = accuracy(logits, labels, topk=(1, 5), ignore_idx=(no_type_id,))
            acc1_any, acc5_any = corr1_any / num_labels_any * 100, corr5_any / num_labels_any * 100
            (corr1, corr5), num_labels = accuracy(logits, labels, topk=(1, 5), ignore_idx=(no_type_id, any_id))
            acc1, acc5 = corr1 / num_labels * 100, corr5 / num_labels * 100

            # Log loss
            global_step += 1
            wandb.log(
                {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "train/acc@1": acc1,
                    "train/acc@5": acc5,
                    "train/acc@1_any": acc1_any,
                    "train/acc@5_any": acc5_any,
                    "lr": scheduler.get_last_lr()[0],
                },
                step=global_step,
            )
            pbar.set_description(f"epoch {epoch} loss {loss.item():.4f}")

        # Evaluate
        logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
        eval_metric, eval_metrics = _evaluate(
            model, eval_loader, sp, target_to_id=target_to_id, use_cuda=use_cuda, no_output_attention=no_output_attention
        )
        for metric, value in eval_metrics.items():
            logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
            max_eval_metrics[metric] = max(value, max_eval_metrics[metric])
        eval_metrics["epoch"] = epoch
        wandb.log(eval_metrics, step=global_step)
        wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

        # Save checkpoint
        if (save_every and epoch % save_every == 0) or eval_metric < min_eval_metric:
            checkpoint = {
                "model_state_dict": model.module.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch,
                "global_step": global_step,
                "config": config,
                "eval_metric": eval_metric,
                "min_eval_metric": min_eval_metric,
            }
            if eval_metric < min_eval_metric:
                logger.info(f"New best evaluation metric: prev {min_eval_metric:.4f} > new {eval_metric:.4f}")
                min_eval_metric = eval_metric
                model_file = run_dir / "ckpt_best.pth"
            else:
                model_file = run_dir / f"ckpt_ep{epoch:04d}.pth"
            logger.info(f"Saving checkpoint to {model_file}...")
            torch.save(checkpoint, str(model_file.resolve()))
            logger.info("Done.")
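# accuracy() is a project-local helper; a plausible sketch (an assumption, not
# necessarily the repo's implementation) that returns top-k correct counts over
# positions whose label is not in ignore_idx, matching the call sites above:
import torch

def accuracy(logits, labels, topk=(1,), ignore_idx=()):
    # logits: [B, L, V]; labels: [B, L]
    with torch.no_grad():
        keep = torch.ones_like(labels, dtype=torch.bool)
        for idx in ignore_idx:
            keep &= labels != idx
        labels_kept = labels[keep]                     # [N]
        logits_kept = logits[keep]                     # [N, V]
        _, pred = logits_kept.topk(max(topk), dim=-1)  # [N, max(topk)]
        correct = pred.eq(labels_kept.unsqueeze(-1))   # [N, max(topk)]
        corrs = tuple(correct[:, :k].any(dim=-1).sum().item() for k in topk)
        return corrs, labels_kept.numel()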
def train(self):
    epoch_loss_history = list()
    min_validation_loss = sys.float_info.max
    patience_cnt = self.config.patience

    self.config.n_gpu = torch.cuda.device_count()
    t_total = len(self.train_data_loader) * self.config.n_epoch
    cur_step = 0
    no_decay = ['bias', 'ln']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in self.model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': self.config.weight_decay
    }, {
        'params': [
            p for n, p in self.model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                       lr=self.config.learning_rate)
    self.scheduler = get_linear_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=self.config.warmup_steps,
        num_training_steps=t_total)

    if self.config.n_gpu > 1:
        self.model = torch.nn.DataParallel(self.model).to(self.config.device)
    else:
        self.model = self.model.to(self.config.device)

    for epoch_i in range(self.epoch_i, self.config.n_epoch):
        self.epoch_i = epoch_i
        self.model.train()
        epoch_loss = 0.0

        for batch_i, batch in enumerate(tqdm(self.train_data_loader, ncols=80)):
            self.optimizer.zero_grad()
            self.model.zero_grad()
            batch = tuple(t.to(self.config.device) for t in batch)
            input_ids, position_ids, token_ids, label_ids, user_ids, user_mask = batch
            inputs = {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'token_type_ids': token_ids,
                'user_mask': user_mask
            }
            outputs = self.model(**inputs)

            loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
            if self.config.users and self.config.reversed:
                lm_logits, user_outputs = outputs
                user_ids = user_ids.view(-1)
                user_ids_mask = user_ids != -1
                user_ids = user_ids[user_ids_mask]
                user_loss = loss_fn(user_outputs.view(-1, user_outputs.size(-1)),
                                    user_ids.view(-1))
                output_loss = loss_fn(lm_logits.view(-1, lm_logits.size(-1)),
                                      label_ids.view(-1))
                batch_loss = user_loss + output_loss
            else:
                batch_loss = loss_fn(outputs.view(-1, outputs.size(-1)),
                                     label_ids.view(-1))
                user_loss = None
                output_loss = None

            assert not isnan(batch_loss.item())
            if self.config.n_gpu > 1:
                batch_loss = batch_loss.mean()
            epoch_loss = (batch_i * epoch_loss + batch_loss.item()) / (batch_i + 1)

            if batch_i % self.config.print_every == 0:
                tqdm.write(f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item():.3f}')
                self.writer.add_scalar('Train/loss', batch_loss.item(), cur_step)
                if self.config.users and self.config.reversed:
                    self.writer.add_scalar('Train/user_loss', user_loss.item(), cur_step)
                    self.writer.add_scalar('Train/lmloss', output_loss.item(), cur_step)
                self.writer.add_scalar('Train/learning_rate',
                                       self.scheduler.get_lr()[0], cur_step)

            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip)
            self.optimizer.step()
            self.scheduler.step()
            cur_step += 1

        if epoch_i == 0:
            if self.config.users and not self.config.reversed:
                self.vocab.save_pretrained(self.config.save_path)

        epoch_loss_history.append(epoch_loss)
        self.epoch_loss = epoch_loss
        print(f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}')

        print('\n<Validation>...')
        val_loss, output_loss, user_loss = self.evaluate()
        self.validation_loss = val_loss
        if epoch_i % self.config.plot_every_epoch == 0:
            self.writer.add_scalar('Val/loss', self.validation_loss, epoch_i + 1)
            if output_loss is not None and user_loss is not None:
                self.writer.add_scalar('Val/lmloss', output_loss.item(), epoch_i + 1)
                self.writer.add_scalar('Val/user_loss', user_loss.item(), epoch_i + 1)

        if min_validation_loss > self.validation_loss:
            min_validation_loss = self.validation_loss
        else:
            patience_cnt -= 1

        self.save_model(epoch_i)
        if patience_cnt < 0:
            print(f'\nEarly stop at {epoch_i}')
            self.save_model(epoch_i)
            return epoch_loss_history

    self.save_model(self.config.n_epoch)
    return epoch_loss_history
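# Both solver classes above implement the same patience rule: the counter only
# shrinks on epochs without a new best validation loss and is never reset. A
# compact restatement of that logic (a sketch; many variants instead reset the
# counter on improvement):
def should_stop(val_losses, patience):
    best, budget = float("inf"), patience
    for loss in val_losses:
        if loss < best:
            best = loss
        else:
            budget -= 1
            if budget < 0:
                return True
    return False

assert should_stop([1.0, 0.9, 0.95, 0.94, 0.93], patience=2)
assert not should_stop([1.0, 0.9, 0.8, 0.7], patience=2)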
def train(
    run_name: str,
    # Data
    train_filepath: str = "data/codeclone/train_data.json",
    eval_filepath: str = "data/codeclone/valid_data.json",
    spm_filepath: str = "data/codesearchnet_javascript/csnjs_8k_9995p_unigram_url.model",
    num_workers=1,
    max_seq_len=1024,
    max_eval_seq_len=1024,
    run_dir=RUN_DIR,
    balance_negatives=False,
    # Model
    resume_path: str = "",
    pretrain_resume_path: str = "",
    pretrain_resume_encoder_name: str = "encoder_q",  # encoder_q, encoder_k, encoder
    encoder_type: str = "transformer",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    critic_type: str = "bilinear_identity",
    critic_bilinear_rank: int = None,
    # Optimization
    train_decoder_only: bool = False,
    num_epochs: int = 100,
    batch_size: int = 256,
    lr: float = 8e-4,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.98,
    adam_eps: float = 1e-6,
    weight_decay: float = 0,
    warmup_steps: int = 5000,
    num_steps: int = 200000,
    # Evaluation
    save_every_steps: int = 5000,
    score_every_steps: int = 10,  # Interval for train ROC AUC, AP score
    evaluate_every_steps: int = 1250,
    # Augmentations
    subword_regularization_alpha: float = 0,
    # Computational
    use_cuda: bool = True,
    seed: int = 1,
):
    """Train model"""
    assert save_every_steps % evaluate_every_steps == 0
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if run_dir != RUN_DIR:
        run_dir = Path(run_dir)
    run_dir = run_dir / run_name
    run_dir.mkdir(exist_ok=True, parents=True)
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")
    config = locals()
    logger.info(f"Config: {config}")
    wandb.init(name=run_name, config=config, job_type="training", project="clone_detection", entity="ml4code")

    if use_cuda:
        assert torch.cuda.is_available(), "CUDA not available. Check env configuration, or pass --use_cuda False"

    sp = spm.SentencePieceProcessor()
    sp.Load(spm_filepath)
    pad_id = sp.PieceToId("[PAD]")
    pad_collate = get_pad_collate(pad_id)

    # Create training dataset and dataloader
    logger.info(f"Training data path {train_filepath}")
    train_programs = CloneProgramsDataset(train_filepath, sp, subword_regularization_alpha)
    train_positives = ClonePositivesDataset(train_programs)
    train_negatives = CloneNegativesDataset(train_programs)
    train_dataset = torch.utils.data.ConcatDataset([train_positives, train_negatives])
    if balance_negatives:
        positive_weight = 1 / len(train_programs)
        negative_weight = 1 / len(train_negatives)
        weights = torch.cat([
            torch.zeros(len(train_programs)) + positive_weight,
            torch.zeros(len(train_negatives)) + negative_weight,
        ])
        sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
    else:
        sampler = None
    logger.info(f"Training dataset size: {len(train_dataset)}")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=not balance_negatives,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=pad_collate,
        sampler=sampler,
    )

    # Create eval dataset and dataloader
    logger.info(f"Eval data path {eval_filepath}")
    eval_programs = CloneProgramsDataset(eval_filepath, sp, subword_regularization_alpha)
    eval_positives = ClonePositivesDataset(eval_programs)
    eval_negatives = CloneNegativesDataset(eval_programs)
    eval_dataset = torch.utils.data.ConcatDataset([eval_positives, eval_negatives])
    logger.info(f"Eval dataset size: {len(eval_dataset)}")
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=pad_collate
    )

    # Create model
    model = CloneModel(
        n_tokens=sp.GetPieceSize(),
        pad_id=pad_id,
        encoder_type=encoder_type,
        n_encoder_layers=n_encoder_layers,
        d_model=d_model,
        critic_type=critic_type,
        bilinear_rank=critic_bilinear_rank,
    )
    logger.info(f"Created CloneModel {encoder_type}, {critic_type} with {count_parameters(model)} params")

    # Load pretrained checkpoint
    if pretrain_resume_path:
        assert not resume_path
        logger.info(
            f"Resuming training from pretraining checkpoint {pretrain_resume_path}, "
            f"pretrain_resume_encoder_name={pretrain_resume_encoder_name}"
        )
        checkpoint = torch.load(pretrain_resume_path)
        pretrained_state_dict = checkpoint["model_state_dict"]
        encoder_state_dict = {}
        assert pretrain_resume_encoder_name in ["encoder_k", "encoder_q", "encoder"]
        for key, value in pretrained_state_dict.items():
            if key.startswith(pretrain_resume_encoder_name + ".") and "project_layer" not in key:
                remapped_key = key[len(pretrain_resume_encoder_name + ".") :]
                logger.debug(f"Remapping checkpoint key {key} to {remapped_key}. Value mean: {value.mean().item()}")
                encoder_state_dict[remapped_key] = value
        model.encoder.load_state_dict(encoder_state_dict)
        logger.info(f"Loaded state dict from {pretrain_resume_path}")

    # Set up optimizer
    model = nn.DataParallel(model)
    model = model.cuda() if use_cuda else model
    wandb.watch(model, log="all")
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, betas=(adam_beta1, adam_beta2), eps=adam_eps, weight_decay=weight_decay
    )
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_steps)

    epoch = 0
    global_step = 0
    min_eval_metric = float("inf")
    if resume_path:
        assert not pretrain_resume_path
        logger.info(f"Resuming training from checkpoint {resume_path}")
        checkpoint = torch.load(resume_path)
        model.module.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        global_step = checkpoint["global_step"]
        min_eval_metric = checkpoint["min_eval_metric"]

    # Eval metric history
    max_eval_metrics = {}

    # Initial evaluation
    logger.info(f"Evaluating model initially ({global_step} steps)...")
    eval_metric, eval_metrics = _evaluate(model, eval_loader, sp, pad_id=pad_id, use_cuda=use_cuda)
    for metric, value in eval_metrics.items():
        logger.info(f"Evaluation {metric} initial ({global_step} steps): {value:.4f}")
        max_eval_metrics[metric] = value
    eval_metrics["epoch"] = epoch
    wandb.log(eval_metrics, step=global_step)
    wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

    for epoch in tqdm.trange(epoch + 1, num_epochs + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        model.train()
        model.module.encoder.eval()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for X, lengths, labels in pbar:
            if use_cuda:
                X, lengths, labels = X.cuda(), lengths.cuda(), labels.cuda()
            optimizer.zero_grad()
            similarity = model(X, lengths)  # B
            loss = F.binary_cross_entropy_with_logits(similarity, labels.float())
            loss.backward()
            optimizer.step()
            scheduler.step()

            with torch.no_grad():
                train_metrics = {
                    "epoch": epoch,
                    "train/loss": loss.item(),
                    "lr": scheduler.get_last_lr()[0],
                    "train/labels_mean": labels.float().mean().item(),
                }
                # Compute scores in training batch
                if global_step % score_every_steps == 0:
                    y_true = labels.cpu().numpy()
                    # detach: similarity still requires grad, and .numpy() would raise
                    y_scores = similarity.detach().cpu().numpy()
                    roc_auc = roc_auc_score(y_true, y_scores)
                    ap_score = average_precision_score(y_true, y_scores)
                    train_metrics["train/roc_auc_score"] = roc_auc
                    train_metrics["train/ap_score"] = ap_score
                    pbar.set_description(f"epoch {epoch} loss {loss.item():.4f} roc_auc {roc_auc:.4f} ap {ap_score:.4f}")

            # Log loss
            global_step += 1
            wandb.log(train_metrics, step=global_step)

            # Evaluate
            if evaluate_every_steps and global_step % evaluate_every_steps == 0:
                logger.info(f"Evaluating model after epoch {epoch} ({global_step} steps)...")
                eval_metric, eval_metrics = _evaluate(model, eval_loader, sp, pad_id=pad_id, use_cuda=use_cuda)
                for metric, value in eval_metrics.items():
                    logger.info(f"Evaluation {metric} after epoch {epoch} ({global_step} steps): {value:.4f}")
                    max_eval_metrics[metric] = max(value, max_eval_metrics[metric])
                eval_metrics["epoch"] = epoch
                wandb.log(eval_metrics, step=global_step)
                wandb.log({k + "_max": v for k, v in max_eval_metrics.items()}, step=global_step)

            # Save checkpoint
            if (save_every_steps and global_step % save_every_steps == 0) or eval_metric < min_eval_metric:
                checkpoint = {
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "global_step": global_step,
                    "config": config,
                    "eval_metric": eval_metric,
                    "min_eval_metric": min_eval_metric,
                }
                if eval_metric < min_eval_metric:
                    logger.info(f"New best evaluation metric: prev {min_eval_metric:.4f} > new {eval_metric:.4f}")
                    min_eval_metric = eval_metric
                    model_file = run_dir / "ckpt_best.pth"
                else:
                    model_file = run_dir / f"ckpt_ep{epoch:04d}_step{global_step}.pth"
                logger.info(f"Saving checkpoint to {model_file}...")
                torch.save(checkpoint, str(model_file.resolve()))
                logger.info("Done.")
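# The train-time scores come from scikit-learn; a minimal usage example of the two
# metrics on raw (unthresholded) similarity logits. Both are rank-based, so no
# sigmoid is needed before scoring:
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

y_true = np.array([1, 0, 1, 1, 0])
y_scores = np.array([2.3, -0.4, 0.9, 1.7, 0.1])
print(roc_auc_score(y_true, y_scores))            # threshold-free ranking quality
print(average_precision_score(y_true, y_scores))  # area under precision-recall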
def run():
    dfx = pd.read_csv(config.TRAINING_FILE, nrows=config.NROWS).dropna().reset_index(drop=True)
    # dfx.sentiment = dfx.sentiment.apply(
    #     lambda x: 1 if x == 'positive' else 0
    # )
    print('Data Loaded')

    df_train, df_valid = model_selection.train_test_split(
        dfx, test_size=0.5, random_state=42, stratify=dfx.sentiment.values)
    print('Data split into train data and validation data')
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    train_dataset = dataset.TweetDataset(
        tweet=df_train.text.values,
        sentiment=df_train.sentiment.values,
        selected_text=df_train.selected_text.values)
    print('Train data preprocessed and made into Tweet Dataset Object')
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4)
    print('Train dataloader created')

    valid_dataset = dataset.TweetDataset(
        tweet=df_valid.text.values,
        sentiment=df_valid.sentiment.values,
        selected_text=df_valid.selected_text.values)
    print('Valid data preprocessed and made into Tweet Dataset Object')
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=1)
    print('Valid dataloader created')

    device = config.DEVICE
    conf = transformers.RobertaConfig.from_pretrained(
        f'{config.PATH}roberta-base-config.json')
    conf.output_hidden_states = False
    model = Roberta(conf)
    model.to(device)
    print('Model Object created')

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_parameters = [{
        'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.001
    }, {
        'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    num_train_steps = int(len(df_train) / config.TRAIN_BATCH_SIZE * config.EPOCHS)
    optimizer = AdamW(optimizer_parameters, lr=3e-5)
    scheduler = utils.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps)

    best_jaccard = 0
    print('Starting Training....')
    for epoch in range(config.EPOCHS):
        engine.train_fn(train_data_loader, model, optimizer, device, scheduler)
        jaccard = engine.eval_fn(valid_data_loader, model, device)
        print(f'Jaccard Score : {jaccard}')
        if jaccard > best_jaccard:
            torch.save(model.state_dict(), config.MODEL_PATH)
            best_jaccard = jaccard
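# engine.eval_fn reports a Jaccard score; for this span-extraction task that is
# conventionally the word-level Jaccard between predicted and gold selected text.
# A sketch of the standard formula (an assumption, not necessarily this repo's helper):
def jaccard(str1: str, str2: str) -> float:
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a & b
    return float(len(c)) / (len(a) + len(b) - len(c))

assert jaccard("so sad for my friend", "sad") == 0.2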
def pretrain_worker(gpu, ngpus_per_node, config):
    chief_node = gpu == 0
    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"], config=config, job_type="training", project=project, entity="ml4code")

    if gpu is not None:
        logger.info("Use GPU: {} for training".format(gpu))

    if config["dist_url"] == "env://" and config["rank"] == -1:
        config["rank"] = int(os.environ["RANK"])
    # For multiprocessing distributed training, rank needs to be the
    # global rank among all the processes
    config["rank"] = config["rank"] * ngpus_per_node + gpu
    dist.init_process_group(backend=config["dist_backend"],
                            init_method=config["dist_url"],
                            world_size=config["world_size"],
                            rank=config["rank"])

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    mask_id = sp.PieceToId("[MASK]")

    def pad_collate(batch):
        B = len(batch)
        if config["program_mode"] == "contrastive":
            X1, X2 = zip(*batch)
            X = X1 + X2
        else:
            X = batch

        # Create tensor of sequence lengths, [B] or [2B]
        lengths = torch.tensor([len(x) for x in X], dtype=torch.long)

        # Create padded tensor for batch, [B, T] or [2B, T]
        X = pad_sequence(X, batch_first=True, padding_value=pad_id)

        if config["program_mode"] == "contrastive":
            # Reshape X to [B, 2, T]
            T = X.size(-1)
            X = torch.reshape(X, (2, B, -1))
            X = torch.transpose(X, 0, 1)
            assert X.shape == (B, 2, T)
            lengths = torch.reshape(lengths, (2, B)).transpose(0, 1)
            assert lengths.shape == (B, 2)
        return X, lengths, None

    # Create model
    if config["loss_mode"] == "infonce":
        model = CodeMoCo(sp.GetPieceSize(),
                         pad_id=pad_id,
                         d_model=config["d_model"],
                         encoder_config=dict(encoder_type=config["encoder_type"],
                                             lstm_project_mode=config["lstm_project_mode"],
                                             n_encoder_layers=config["n_encoder_layers"]))
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(sp.GetPieceSize(),
                        pad_id=pad_id,
                        encoder_type=config["encoder_type"],
                        n_encoder_layers=config["n_encoder_layers"])
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(sp.GetPieceSize(), pad_id=pad_id)
        logger.info(f"Created CodeContrastiveMLM model with {count_parameters(model)} params")
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    if gpu is not None:
        torch.cuda.set_device(gpu)
        model.cuda(gpu)
        # When using a single GPU per process and per
        # DistributedDataParallel, we need to divide the batch size
        # ourselves based on the total number of GPUs we have
        config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
        config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    else:
        model.cuda()
        # DistributedDataParallel will divide and allocate batch_size to all
        # available GPUs if device_ids are not set
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"],
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"], config["num_steps"])

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"])
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate,
        # num_workers=config["num_workers"],
        num_workers=0,
        drop_last=True,
        pin_memory=True,
        sampler=train_sampler,
    )

    # Train
    global_step = 0
    for epoch in tqdm.trange(1, config["num_epochs"] + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model, batch, use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # replace tokens randomly with tokens from vocab index 8 onward
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(sp,
                                                     model,
                                                     batch,
                                                     mask_id=mask_id,
                                                     pad_id=pad_id,
                                                     vocab_start_idx=0,
                                                     vocab_end_idx=7999,
                                                     use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()
            global_step += 1
            pbar.set_description(f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}")

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config["save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(config["run_dir"],
                                              f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth")
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
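# The contrastive branch of pad_collate is easy to get wrong; a small check
# (PyTorch only) that reshaping the concatenated [2B, T] batch to (2, B, T) and
# transposing yields row b = (view1_b, view2_b), as the asserts above expect:
import torch

B, T = 3, 4
x1 = torch.arange(B * T).reshape(B, T)        # first augmented views
x2 = 100 + torch.arange(B * T).reshape(B, T)  # second augmented views
X = torch.cat([x1, x2], dim=0)                # [2B, T], as built in pad_collate

X = torch.reshape(X, (2, B, -1)).transpose(0, 1)  # [B, 2, T]
assert torch.equal(X[1, 0], x1[1]) and torch.equal(X[1, 1], x2[1])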
def pretrain(
    run_name: str,
    # Data
    train_filepath: str = DEFAULT_CSNJS_TRAIN_FILEPATH,
    spm_filepath: str = DEFAULT_SPM_UNIGRAM_FILEPATH,
    num_workers=1,
    limit_dataset_size=-1,
    max_length=1024,
    subword_regularization_alpha: float = 0,
    program_mode="contrastive",
    loss_mode="infonce",  # infonce, mlm, or hybrid
    min_alternatives=1,
    # Model
    resume_path: str = "",
    encoder_type: str = "transformer",
    lstm_project_mode: str = "hidden",
    n_encoder_layers: int = 6,
    d_model: int = 512,
    n_head: int = 8,
    # Optimization
    num_epochs: int = 100,
    save_every: int = 1,
    batch_size: int = 256,
    lr: float = 8e-4,
    weight_decay: float = 0,
    adam_betas=(0.9, 0.98),
    warmup_steps: int = 5000,
    num_steps: int = 600000,
    # Horovod
    use_adasum: bool = False,
    fp16_allreduce: bool = False,
    gradient_predivide_factor: float = 1.0,
    # Computational
    use_cuda: bool = True,
    seed: int = 0,
):
    hvd.init()
    logger.info(f"L: {n_encoder_layers} ({type(n_encoder_layers)})")
    logger.info(f"H: {d_model} ({type(d_model)})")
    logger.info(f"A: {n_head} ({type(n_head)})")
    run_name = str(run_name)  # support numerical run ids
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    slurm_job_hostname = os.environ.get("SLURM_JOB_NODELIST")
    config = locals()
    logger.info(f"Training configuration: {config}")
    logger.info(f"CUDA_VISIBLE_DEVICES = '{os.environ.get('CUDA_VISIBLE_DEVICES')}'")
    logger.info(f"CUDA_DEVICE_ORDER = '{os.environ.get('CUDA_DEVICE_ORDER')}'")

    assert program_mode in ["contrastive", "identity", "augmentation"]
    assert loss_mode in ["infonce", "mlm", "hybrid"]
    assert not (program_mode == "contrastive" and loss_mode == "mlm")
    assert not (program_mode != "contrastive" and (loss_mode == "hybrid" or loss_mode == "infonce"))
    assert not use_cuda or torch.cuda.is_available()

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    run_dir = RUN_DIR / "{}_{}".format(run_name, int(time.time()))
    run_dir.mkdir(exist_ok=True, parents=True)
    config["run_dir"] = str(run_dir.resolve())
    logger.add(str((run_dir / "train.log").resolve()))
    logger.info(f"Saving logs, model checkpoints to {run_dir}")

    # Create training dataset and dataloader
    assert train_filepath.endswith(".pickle") or train_filepath.endswith(".gz")

    # Setup distributed
    gpu = hvd.local_rank()
    ngpus_per_node = 1
    chief_node = gpu == 0
    assert gpu is not None

    if chief_node:
        if config["loss_mode"] == "mlm":
            project = "bert-pretrain"
        elif config["loss_mode"] == "infonce":
            project = "moco-pretrain"
        elif config["loss_mode"] == "hybrid":
            project = "hybrid"
        wandb.init(name=config["run_name"], config=config, job_type="training", project=project, entity="ml4code")

    logger.info("Use GPU: {} for training".format(gpu))
    torch.cuda.set_device(gpu)

    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

    kwargs = {"num_workers": 1, "pin_memory": True}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
    if (kwargs.get("num_workers", 0) > 0 and hasattr(mp, "_supports_context") and mp._supports_context
            and "forkserver" in mp.get_all_start_methods()):
        kwargs["multiprocessing_context"] = "forkserver"

    sp = spm.SentencePieceProcessor()
    sp.Load(config["spm_filepath"])
    pad_id = sp.PieceToId("[PAD]")
    logger.info("pad_id {}", pad_id)
    assert pad_id == 0  # hard coded in pad_collate
    mask_id = sp.PieceToId("[MASK]")

    # Create model
    if config["loss_mode"] == "infonce":
        # TODO(ajay): Support n_head argument, check how d_model is being used (why not in encoder config dict?)
        model = CodeMoCo(
            sp.GetPieceSize(),
            pad_id=pad_id,
            d_model=config["d_model"],
            encoder_config=dict(
                encoder_type=config["encoder_type"],
                lstm_project_mode=config["lstm_project_mode"],
                n_encoder_layers=config["n_encoder_layers"],
            ),
        )
        logger.info(f"Created CodeMoCo model with {count_parameters(model)} params")
    elif config["loss_mode"] == "mlm":
        model = CodeMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            encoder_type=config["encoder_type"],
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
        )
        logger.info(f"Created CodeMLM model with {count_parameters(model)} params")
    elif config["loss_mode"] == "hybrid":
        model = CodeContrastiveMLM(
            sp.GetPieceSize(),
            pad_id=pad_id,
            n_encoder_layers=config["n_encoder_layers"],
            d_model=config["d_model"],
            n_head=config["n_head"],
            d_ff=4 * config["d_model"],
            use_horovod=True,
        )
        logger.info(f"Created CodeContrastiveMLM model with {count_parameters(model)} params")
    else:
        raise ValueError(f"Bad loss mode {config['loss_mode']}")

    assert config["use_cuda"]
    model.cuda()
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    # config["batch_size"] = int(config["batch_size"] / ngpus_per_node)
    # config["num_workers"] = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # define optimizer
    # By default, Adasum doesn't need scaling up learning rate.
    lr_scaler = hvd.size() if not config["use_adasum"] else 1
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if config["use_adasum"] and hvd.nccl_built():
        lr_scaler = hvd.local_size()
    # Horovod: scale learning rate by lr_scaler.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["lr"] * lr_scaler,
                                 betas=config["adam_betas"],
                                 eps=1e-6,
                                 weight_decay=config["weight_decay"])
    sched = get_linear_schedule_with_warmup(optimizer, config["warmup_steps"], config["num_steps"])

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Horovod: (optional) compression algorithm.
    compression = hvd.Compression.fp16 if config["fp16_allreduce"] else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        compression=compression,
        op=hvd.Adasum if config["use_adasum"] else hvd.Average,
        gradient_predivide_factor=config["gradient_predivide_factor"],
    )

    # Load checkpoint
    if config["resume_path"]:
        logger.info(f"Loading parameters from {config['resume_path']}")
        # configure map_location properly
        map_location = {"cuda:%d" % 0: "cuda:%d" % hvd.rank()}
        checkpoint = torch.load(config["resume_path"], map_location=map_location)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        start_global_step = checkpoint["global_step"]
    else:
        start_epoch = 1
        start_global_step = 0

    # Setup data
    train_dataset = PrecomputedDataset(
        config["train_filepath"],
        min_alternatives=config["min_alternatives"],
        program_mode=config["program_mode"],
        limit_size=config["limit_dataset_size"],
        sp=sp,
        subword_regularization_alpha=config["subword_regularization_alpha"],
        max_length=config["max_length"],
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        collate_fn=pad_collate_contrastive if config["program_mode"] == "contrastive" else pad_collate,
        drop_last=True,
        sampler=train_sampler,
        **kwargs,
    )

    # Train
    global_step = 0
    # Fast-forward the scheduler to the resumed step
    while global_step < start_global_step:
        sched.step()
        global_step += 1

    for epoch in tqdm.trange(start_epoch, config["num_epochs"] + 1, desc="training", unit="epoch", leave=False):
        logger.info(f"Starting epoch {epoch}\n")
        train_sampler.set_epoch(epoch)
        model.train()
        pbar = tqdm.tqdm(train_loader, desc=f"epoch {epoch}")
        for batch in pbar:
            optimizer.zero_grad()
            if config["loss_mode"] == "infonce":
                train_metrics = training_step(model, batch, use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "mlm":
                # replace tokens randomly with tokens from vocab index 8 onward
                train_metrics = training_step_mlm(sp,
                                                  model,
                                                  batch,
                                                  pad_id=pad_id,
                                                  mask_id=mask_id,
                                                  vocab_start_idx=8,
                                                  vocab_end_idx=7999,
                                                  use_cuda=config["use_cuda"])
            elif config["loss_mode"] == "hybrid":
                train_metrics = training_step_hybrid(sp,
                                                     model,
                                                     batch,
                                                     mask_id=mask_id,
                                                     pad_id=pad_id,
                                                     vocab_start_idx=0,
                                                     vocab_end_idx=7999,
                                                     use_cuda=config["use_cuda"])
            else:
                raise ValueError("Bad loss type")
            loss = train_metrics["loss"]
            loss.backward()
            optimizer.step()
            sched.step()
            global_step += 1
            pbar.set_description(f"epoch {epoch} gpu {gpu} step {global_step} loss {loss.item():.4f}")

            if chief_node:
                wandb.log(dict(lr=sched.get_last_lr()[0]))
                wandb.log(dict(epoch=epoch, **train_metrics["log"]), step=global_step)

                # Save checkpoint
                if config["save_every"] and global_step % config["save_every"] == 0:
                    checkpoint = {
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch,
                        "global_step": global_step,
                        "config": config,
                    }
                    model_file = os.path.join(config["run_dir"],
                                              f"ckpt_pretrain_ep{epoch:04d}_step{global_step:07d}.pth")
                    logger.info(f"Saving checkpoint to {model_file}...")
                    torch.save(checkpoint, model_file)
                    wandb.save(str(model_file))
                    logger.info("Done.")
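# Replaying sched.step() up to start_global_step is O(steps) but robust. Since
# get_linear_schedule_with_warmup returns a LambdaLR, an alternative design (a
# note, not the author's code) is to checkpoint and restore the scheduler state:
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=8e-4)
warmup = lambda step: min(1.0, step / 5000)
sched = LambdaLR(opt, lr_lambda=warmup)

state = sched.state_dict()                # records last_epoch, i.e. the step count

sched2 = LambdaLR(opt, lr_lambda=warmup)  # rebuild with the same lambda on resume
sched2.load_state_dict(state)             # restores the step count directly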