import gc
import random
from pathlib import Path
from typing import AnyStr, List

import numpy as np
import torch
import wandb
from sklearn.model_selection import ParameterSampler
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DistilBertForSequenceClassification

# Project-local classes used below (MultiDatasetClassificationEvaluator,
# MultiDistilBertClassifier, VanillaBert,
# MultiViewTransformerNetworkProbabilitiesAdversarial) are assumed to be
# importable from the surrounding codebase.


def attention_grid_search(
        model: torch.nn.Module,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        seed: int,
):
    """Randomly samples attention-view weightings and keeps the one with the
    best validation F1."""
    best_weights = model.module.weights
    # Initial evaluation with the model's current weights
    (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
    best_f1 = F1
    print(f"Initial validation F1: {F1}")

    # Create the grid search: one integer weight in [0, 10] per attention view
    param_dict = {k: list(range(0, 11)) for k in range(1, 6)}
    grid_search_params = ParameterSampler(param_dict, n_iter=n_epochs,
                                          random_state=seed)

    for d in grid_search_params:
        weights = np.array([v for k, v in sorted(d.items(), key=lambda x: x[0])])
        if weights.sum() == 0:
            continue  # an all-zero draw cannot be normalised; skip it
        weights = weights / weights.sum()
        model.module.weights = weights

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Weights: {weights}\tValidation F1: {F1}")
        if F1 > best_f1:
            best_weights = weights
            best_f1 = F1

        # Log to wandb
        wandb.log({
            'Validation accuracy': acc,
            'Validation Precision': P,
            'Validation Recall': R,
            'Validation F1': F1,
            'Validation loss': val_loss
        })
        gc.collect()
    return best_weights
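# Illustrative sketch (not part of the original script): ParameterSampler draws
# one value per key, so each sample is a dict such as
# {1: 3, 2: 0, 3: 7, 4: 10, 5: 2}; sorting by key and normalising turns it into
# a weight vector over the five attention views, exactly as
# attention_grid_search does above.
def _demo_parameter_sampler():
    sampler = ParameterSampler(
        {k: list(range(0, 11)) for k in range(1, 6)}, n_iter=3, random_state=0)
    for sample in sampler:
        vec = np.array([v for _, v in sorted(sample.items())])
        if vec.sum() > 0:
            print(vec / vec.sum())  # a normalised 5-way weight vector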
def train(
        model: torch.nn.Module,
        train_dls: List[DataLoader],
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''
):
    # Early stopping tracks validation F1 (an earlier version used validation loss)
    best_f1 = 0.0
    patience_counter = 0
    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                # Visit the domain dataloaders in a fresh random order each pass
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                    if len(batches) == 0:
                        continue

                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        domains = batch[3]

                        loss, logits, alpha = model(
                            input_ids,
                            attention_mask=masks,
                            domains=domains,
                            labels=labels,
                            ret_alpha=True
                        )
                        loss = loss.mean() / gradient_accumulation

                        if i % log_interval == 0:
                            # (Per-view alpha logging was disabled in favour of
                            # the loss alone.)
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation F1: {F1}")

        # Save the best checkpoint and handle early stopping
        if F1 > best_f1:
            best_f1 = F1
            torch.save(model.state_dict(),
                       f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
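# Minimal sketch (illustrative only, not from the source) of the
# multi-dataloader round-robin pattern used in train() above: shuffle the
# domain order on every pass and mark a domain finished once its iterator is
# exhausted, so unevenly sized loaders all drain fully.
def _demo_round_robin():
    iters = [iter(range(3)), iter(range(5)), iter(range(2))]
    finished = [0] * len(iters)
    while sum(finished) < len(iters):
        order = list(range(len(iters)))
        random.shuffle(order)
        for d in order:
            try:
                item = next(iters[d])
            except StopIteration:
                finished[d] = 1
                continue
            print(f"domain {d}: item {item}")  # one training step goes here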
validation_evaluator = MultiDatasetClassificationEvaluator(val_ds, device)

# Create the model
bert = DistilBertForSequenceClassification.from_pretrained(
    bert_model, config=bert_config).to(device)
multi_xformer = MultiDistilBertClassifier(
    bert_model, bert_config, n_domains=len(train_dls) - 1).to(device)
if args.pretrained_multi_xformer is not None:
    multi_xformer.load_state_dict(
        torch.load(f"{args.pretrained_multi_xformer}/model_{domain}.pth"))
    (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(multi_xformer)
    print(f"Validation acc multi-xformer: {acc}")

shared_bert = VanillaBert(bert).to(device)
if args.pretrained_bert is not None:
    shared_bert.load_state_dict(
        torch.load(f"{args.pretrained_bert}/model_{domain}.pth"))
    (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(shared_bert)
    print(f"Validation acc shared bert: {acc}")

model = torch.nn.DataParallel(
    MultiViewTransformerNetworkProbabilitiesAdversarial(
        multi_xformer,
        shared_bert,
        supervision_layer=args.supervision_layer)).to(device)
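# Hedged wiring sketch: the hyperparameter values and the AdamW/constant-
# schedule choice below are assumptions for illustration, not taken from the
# source. It shows how the pieces above fit together: build an optimizer and
# scheduler, run the main training loop, then grid-search the view weights.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
scheduler = LambdaLR(optimizer, lambda step: 1.0)  # constant-LR placeholder

train(
    model,
    train_dls,
    optimizer,
    scheduler,
    validation_evaluator,
    n_epochs=5,
    device=device,
    gradient_accumulation=1,
    domain_name=domain,
)
model.module.weights = attention_grid_search(
    model, validation_evaluator, n_epochs=100, seed=1000)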
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        class_weights: List,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    # Early stopping tracks validation accuracy
    best_acc = 0.0
    patience_counter = 0
    epoch_counter = 0
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(class_weights).to(device))

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            # batch = (input_ids, masks, labels, domains); the task labels are
            # unused here, since the classifier's target is the domain id
            input_ids = batch[0]
            masks = batch[1]
            domains = batch[3]

            logits = model(input_ids, attention_mask=masks)[0]
            loss = loss_fn(logits, domains) / gradient_accumulation
            loss.backward()

            # Step only after accumulating `gradient_accumulation` batches
            if (i + 1) % gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Save the best checkpoint and handle early stopping
        if acc > best_acc:
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
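# Hedged example (not from the source): one common way to build the
# `class_weights` argument above is inverse-frequency weighting over the
# domain ids. `domain_labels` (one integer domain id per training example) is
# an assumed input, since how it is collected depends on the dataset objects.
def _make_domain_class_weights(domain_labels, n_domains):
    counts = np.bincount(np.asarray(domain_labels), minlength=n_domains)
    counts = np.maximum(counts, 1)  # guard against empty domains
    # Weight each class by N / (K * count_k), so rarer domains weigh more
    return (counts.sum() / (len(counts) * counts)).tolist()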