Esempio n. 1
0
 def evaluate(self,
              validation_dset: Dataset,
              eval_averaging: AnyStr = 'binary',
              return_labels_logits: bool = False,
              sequence_modeling: bool = False):
     """
     Build a ClassificationEvaluator for the given dataset and run it once
     on ``self.model``.

     :param validation_dset: dataset to evaluate on
     :param eval_averaging: forwarded to the evaluator as ``averaging``
     :param return_labels_logits: forwarded to the evaluator's ``evaluate``
     :param sequence_modeling: forwarded to the evaluator constructor
     :return: whatever ``ClassificationEvaluator.evaluate`` returns
     """
     # Without a tokenizer there is no pad token to look up, so use 0.
     pad_token_id = (0 if self.tokenizer is None
                     else self.tokenizer.pad_token_id)

     evaluator = ClassificationEvaluator(validation_dset,
                                         self.device,
                                         num_labels=self.num_labels[0],
                                         averaging=eval_averaging,
                                         pad_token_id=pad_token_id,
                                         sequence_modeling=sequence_modeling,
                                         multi_gpu=self.multi_gpu,
                                         ensemble_edu=self.ensemble_edu,
                                         ensemble_sent=self.ensemble_sent)
     results = evaluator.evaluate(self.model,
                                  return_labels_logits=return_labels_logits)
     return results
def train(model: torch.nn.Module,
          train_dl: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: ClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: str = "local",
          split: str = '') -> float:
    """
    Train ``model`` with a per-example weighted cross-entropy loss, evaluating
    after every epoch and checkpointing whenever validation F1 improves.

    The checkpoint is written to ``{model_dir}/model_{split}.pth``. If no
    epoch ever improves on the initial F1 of 0, the weights from the last
    epoch are saved instead so that the checkpoint file always exists.

    :param model: model called as ``model(input_ids, attention_mask=masks)``
        and returning a 1-tuple ``(logits,)``
    :param train_dl: yields batches whose first four tensors are
        input ids, attention masks, labels and per-example loss weights
    :param optimizer: optimizer stepped once per batch
    :param scheduler: learning-rate scheduler stepped once per batch
    :param validation_evaluator: its ``evaluate(model)`` must return
        ``((val_loss, acc, P, R, F1), extra)``
    :param n_epochs: maximum number of epochs
    :param device: device the batch tensors are moved to
    :param log_interval: unused; kept for interface compatibility
    :param patience: number of consecutive non-improving epochs after which
        training stops
    :param model_dir: directory the checkpoint is written to
    :param split: suffix used in the checkpoint file name
    :return: the best validation F1 observed
    """
    best_f1 = 0.0
    patience_counter = 0
    checkpoint_saved = False
    checkpoint_path = f'{model_dir}/model_{split}.pth'
    # reduction='none' keeps one loss value per example so that the
    # per-example weights from the batch can be applied before averaging.
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

    # Main loop
    for ep in range(n_epochs):
        # Re-enable training mode every epoch in case the evaluator switched
        # the model to eval mode.
        model.train()

        # Training loop
        for batch in tqdm(train_dl):
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids, masks, labels, weights = batch[:4]

            (logits, ) = model(input_ids, attention_mask=masks)
            loss = loss_fn(logits.view(-1, 2), labels.view(-1))
            loss = (loss * weights).mean()

            loss.backward()
            optimizer.step()
            scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)

        print(f"Validation F1: {F1}")

        # Saving the best model and early stopping
        if F1 > best_f1:
            best_f1 = F1
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()

    if not checkpoint_saved:
        # F1 never rose above 0: still write a checkpoint so that code which
        # later loads the file does not fail.
        print("No good weights found, saving weights from last epoch")
        torch.save(model.state_dict(), checkpoint_path)

    return best_f1
Esempio n. 3
0
 def evaluate(self,
              validation_dset: Dataset,
              eval_averaging: AnyStr = 'binary'):
     """
     Build an MLM-mode ClassificationEvaluator for the given dataset and run
     it once on ``self.model``.

     :param validation_dset: dataset to evaluate on
     :param eval_averaging: accepted for interface compatibility; not
         forwarded to the evaluator here
     :return: whatever ``ClassificationEvaluator.evaluate`` returns
     """
     # Without a tokenizer there is no pad token to look up, so use 0.
     pad_token_id = (0 if self.tokenizer is None
                     else self.tokenizer.pad_token_id)

     evaluator = ClassificationEvaluator(validation_dset,
                                         self.device,
                                         pad_token_id=pad_token_id,
                                         mlm=True,
                                         multi_gpu=self.multi_gpu)
     results = evaluator.evaluate(self.model)
     return results
def train(model: torch.nn.Module,
          train_dl: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: ClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          model_dir: AnyStr = "local") -> float:
    """
    Train ``model`` using the loss the model itself computes from the labels,
    evaluating after every epoch and checkpointing whenever validation F1
    improves.

    The checkpoint is written to ``{model_dir}/model.pth``. If no epoch ever
    improves on the initial F1 of 0, the weights from the last epoch are
    saved instead so that the checkpoint file always exists.

    :param model: model called as
        ``model(input_ids, attention_mask=masks, labels=labels)`` and
        returning ``(loss, logits)``
    :param train_dl: yields batches whose first three tensors are
        input ids, attention masks and labels
    :param optimizer: optimizer stepped once per batch
    :param scheduler: learning-rate scheduler stepped once per batch
    :param validation_evaluator: its ``evaluate(model)`` must return
        ``((val_loss, acc, P, R, F1), extra)``
    :param n_epochs: maximum number of epochs
    :param device: device the batch tensors are moved to
    :param log_interval: unused; kept for interface compatibility
    :param patience: number of consecutive non-improving epochs after which
        training stops
    :param model_dir: directory the checkpoint is written to
    :return: the best validation F1 observed
    """
    best_f1 = 0
    patience_counter = 0
    checkpoint_saved = False
    checkpoint_path = f'{model_dir}/model.pth'

    # Main loop
    for ep in range(n_epochs):
        # Re-enable training mode every epoch in case the evaluator switched
        # the model to eval mode.
        model.train()

        # Training loop
        for batch in tqdm(train_dl):
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids, masks, labels = batch[:3]

            loss, logits = model(input_ids,
                                 attention_mask=masks,
                                 labels=labels)

            loss.backward()
            optimizer.step()
            scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)

        # Saving the best model and early stopping
        if F1 > best_f1:
            best_f1 = F1
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()

    if not checkpoint_saved:
        # F1 never rose above 0: still write a checkpoint so that code which
        # later loads the file does not fail.
        print("No good weights found, saving weights from last epoch")
        torch.save(model.state_dict(), checkpoint_path)

    return best_f1
Esempio n. 5
0
        # Load the best weights saved for this domain.
        # NOTE(review): torch.load unpickles the file; only point
        # args.pretrained_model at checkpoints from a trusted source.
        model.load_state_dict(
            torch.load(f'{args.pretrained_model}/model_{domain}.pth'))

        # Calculate the best attention weights with a grid search
        weights = attention_grid_search(model, validation_evaluator, n_epochs,
                                        seed)
        # Accessing `.module` presumes the model is wrapped (e.g. by
        # torch.nn.DataParallel) -- TODO confirm against where the model is built.
        model.module.weights = weights
        print(f"Best weights: {weights}")
        # Write the selected weights to a text file under args.model_dir, in a
        # sub-directory named after the wandb run directory.
        with open(
                f'{args.model_dir}/{Path(wandb.run.dir).name}/weights_{domain}.txt',
                'wt') as f:
            f.write(str(weights))

        # Evaluate on the test set. With return_labels_logits and return_votes
        # both enabled, the result is unpacked here as
        # (metrics, plots, (labels, logits), votes).
        evaluator = ClassificationEvaluator(test_dset,
                                            device,
                                            use_domain=False)
        (loss, acc, P, R,
         F1), plots, (labels, logits), votes = evaluator.evaluate(
             model,
             plot_callbacks=[plot_label_distribution],
             return_labels_logits=True,
             return_votes=True)
        print(f"{domain} F1: {F1}")
        print(f"{domain} Accuracy: {acc}")
        print()

        # Record the per-domain test metrics in the wandb run summary.
        wandb.run.summary[f"{domain}-P"] = P
        wandb.run.summary[f"{domain}-R"] = R
        wandb.run.summary[f"{domain}-F1"] = F1
        wandb.run.summary[f"{domain}-Acc"] = acc
            # For every index in subsets[1], write which underlying dataset it
            # falls in and its position inside that dataset, one
            # "dataset_idx,sample_idx" pair per line.
            # NOTE(review): `g` is presumably a text file opened outside this
            # excerpt -- confirm against the surrounding code.
            for idx in subsets[1].indices:
                # bisect_right over the cumulative sizes yields the position of
                # the first dataset whose cumulative size exceeds idx.
                dataset_idx = bisect.bisect_right(dset.cumulative_sizes, idx)
                if dataset_idx == 0:
                    sample_idx = idx
                else:
                    # Subtract the combined size of all preceding datasets to
                    # get the offset local to this dataset.
                    sample_idx = idx - dset.cumulative_sizes[dataset_idx - 1]
                g.write(f'{dataset_idx},{sample_idx}\n')

        # The first subset is used for training (shuffled, batched with
        # collate_batch_transformer) ...
        train_ds = subsets[0]
        train_dl = DataLoader(train_ds,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=collate_batch_transformer)

        # ... and the second subset for validation.
        val_ds = subsets[1]
        validation_evaluator = ClassificationEvaluator(val_ds, device)

        # Create the model
        model = BertForSequenceClassification.from_pretrained(
            bert_model, config=bert_config).to(device)
        if args.pretrained_model is not None:
            # Warm-start from a checkpoint, but skip every entry whose key
            # contains "classifier" so those parameters keep the values they
            # got from from_pretrained above.
            # NOTE(review): torch.load unpickles the file; only use trusted
            # checkpoints.
            weights = {
                k: v
                for k, v in torch.load(args.pretrained_model).items()
                if "classifier" not in k
            }
            # Overlay the kept tensors on the model's current state dict so the
            # dict passed to load_state_dict still contains every key.
            model_dict = model.state_dict()
            model_dict.update(weights)
            model.load_state_dict(model_dict)

        # Create the optimizer
Esempio n. 7
0
def train(model: torch.nn.Module,
          train_dl: DataLoader,
          optimizer: torch.optim.Optimizer,
          scheduler: LambdaLR,
          validation_evaluator: ClassificationEvaluator,
          n_epochs: int,
          device: AnyStr,
          log_interval: int = 1,
          patience: int = 10,
          neg_class_weight: float = None,
          model_dir: str = "local",
          split: str = '') -> float:
    """
    Train ``model`` with an (optionally class-weighted) cross-entropy loss,
    evaluating after every epoch and checkpointing whenever validation F1
    improves.

    The checkpoint is written to ``{model_dir}/model_{split}.pth``. If no
    epoch ever improves on the initial F1 of 0, the weights from the last
    epoch are saved instead so that the checkpoint file always exists.

    :param model: model called as ``model(input_ids, attention_mask=masks)``
        and returning a 1-tuple ``(logits,)``
    :param train_dl: yields batches whose first three tensors are
        input ids, attention masks and labels
    :param optimizer: optimizer stepped once per batch
    :param scheduler: learning-rate scheduler stepped once per batch
    :param validation_evaluator: its ``evaluate(model)`` must return
        ``((val_loss, acc, P, R, F1), extra)``
    :param n_epochs: maximum number of epochs
    :param device: device the batch tensors and class weights are moved to
    :param log_interval: unused; kept for interface compatibility
    :param patience: number of consecutive non-improving epochs after which
        training stops
    :param neg_class_weight: loss weight for class 0 (class 1 always has
        weight 1.0); ``None`` means an unweighted loss
    :param model_dir: directory the checkpoint is written to
    :param split: suffix used in the checkpoint file name
    :return: the best validation F1 observed
    """
    patience_counter = 0
    best_f1 = 0.0
    weights_found = False
    checkpoint_path = f'{model_dir}/model_{split}.pth'

    if neg_class_weight is None:
        # The default of None cannot be placed in a tensor; fall back to the
        # standard unweighted loss instead of crashing.
        loss_fn = torch.nn.CrossEntropyLoss()
    else:
        loss_fn = torch.nn.CrossEntropyLoss(
            weight=torch.tensor([neg_class_weight, 1.]).to(device))

    # Main loop
    for ep in range(n_epochs):
        # Re-enable training mode every epoch in case the evaluator switched
        # the model to eval mode.
        model.train()

        # Training loop
        for batch in tqdm(train_dl):
            optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            input_ids, masks, labels = batch[:3]

            (logits, ) = model(input_ids, attention_mask=masks)
            loss = loss_fn(logits.view(-1, 2), labels.view(-1))

            loss.backward()
            optimizer.step()
            scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)

        # Saving the best model and early stopping
        if F1 > best_f1:
            weights_found = True
            best_f1 = F1
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

    if not weights_found:
        print("No good weights found, saving weights from last epoch")
        # Save one just in case
        torch.save(model.state_dict(), checkpoint_path)

    gc.collect()
    return best_f1