Example No. 1
def attention_grid_search(
    model: torch.nn.Module,
    validation_evaluator: MultiDatasetClassificationEvaluator,
    n_epochs: int,
    seed: int,
):
    best_weights = model.module.weights
    # Initial evaluation to establish the baseline F1
    (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
    best_f1 = F1
    print(f"Initial validation F1: {F1}")
    # Create the grid search
    param_dict = {
        1: list(range(0, 11)),
        2: list(range(0, 11)),
        3: list(range(0, 11)),
        4: list(range(0, 11)),
        5: list(range(0, 11))
    }
    grid_search_params = ParameterSampler(param_dict,
                                          n_iter=n_epochs,
                                          random_state=seed)
    for d in grid_search_params:
        weights = [v for k, v in sorted(d.items(), key=lambda x: x[0])]
        weights = np.array(weights) / sum(weights)
        model.module.weights = weights
        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Weights: {weights}\tValidation F1: {F1}")

        if F1 > best_f1:
            best_weights = weights
            best_f1 = F1
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })

        gc.collect()

    return best_weights
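A minimal usage sketch for attention_grid_search follows; the dataset list, device, model and wandb project name are assumptions and not part of the original snippet.

# Hypothetical driver for attention_grid_search; every concrete name below is an assumption.
import wandb

wandb.init(project="multi-domain-attention")  # assumed project name
validation_evaluator = MultiDatasetClassificationEvaluator(val_datasets, device)
best_weights = attention_grid_search(
    model=model,                   # DataParallel-wrapped model exposing model.module.weights
    validation_evaluator=validation_evaluator,
    n_epochs=100,                  # number of weight combinations sampled by ParameterSampler
    seed=1000,
)
model.module.weights = best_weights  # keep the best-scoring combination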
Example No. 2
                Subset(dset_choices[d], subset_indices[d][0]),
                Subset(dset_choices[d], subset_indices[d][1])
            ] for d in subset_indices]

        train_dls = [
            DataLoader(subset[0],
                       batch_size=8,
                       shuffle=True,
                       collate_fn=collate_batch_transformer)
            for subset in subsets
        ]

        val_ds = [subset[1] for subset in subsets]
        # for vds in val_ds:
        #     print(vds.indices)
        validation_evaluator = MultiDatasetClassificationEvaluator(
            val_ds, device)

        bert = DistilBertForSequenceClassification.from_pretrained(
            bert_model, config=bert_config).to(device)
        # Create the model
        init_weights = None
        shared_bert = VanillaBert(bert).to(device)

        multi_xformer = MultiDistilBertClassifier(
            bert_model,
            bert_config,
            n_domains=len(train_dls),
            init_weights=init_weights).to(device)

        model = torch.nn.DataParallel(
            MultiViewTransformerNetworkSelectiveWeight(multi_xformer,
Example No. 3
def train(
        model: torch.nn.Module,
        train_dls: List[DataLoader],
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''
):
    #best_loss = float('inf')
    best_f1 = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        # Testing with random domains to see if any effect
                        #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
                        domains = batch[3]

                        loss, logits, alpha = model(input_ids,
                                                    attention_mask=masks,
                                                    domains=domains,
                                                    labels=labels,
                                                    ret_alpha=True)
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            # wandb.log({
                            #     "Loss": loss.item(),
                            #     "alpha0": alpha[:,0].cpu(),
                            #     "alpha1": alpha[:, 1].cpu(),
                            #     "alpha2": alpha[:, 2].cpu(),
                            #     "alpha_shared": alpha[:, 3].cpu()
                            # })
                            wandb.log({
                                "Loss": loss.item()
                            })

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation f1: {F1}")

        #torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if F1 > best_f1:
            best_model = model.state_dict()
            #best_loss = val_loss
            best_f1 = F1
            #wandb.run.summary['best_validation_loss'] = best_loss
            torch.save(model.state_dict(), f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss})
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
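The sketch below shows one plausible way to drive train; the optimizer, scheduler and every concrete hyperparameter are assumptions layered on top of the signature above.

# Hypothetical setup for train(); all hyperparameter values are assumptions.
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
total_steps = sum(len(dl) for dl in train_dls) * 5
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=200,
                                            num_training_steps=total_steps)
train(model,
      train_dls,
      optimizer,
      scheduler,
      validation_evaluator,
      n_epochs=5,
      device=device,
      model_dir="wandb_local",
      gradient_accumulation=2,
      domain_name="books")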
Example No. 4
                       shuffle=True,
                       collate_fn=collate_batch_transformer)
            for subset in subsets
        ]
        # Add test data for domain adversarial training
        train_dls += [
            DataLoader(test_dset,
                       batch_size=batch_size,
                       shuffle=True,
                       collate_fn=collate_batch_transformer)
        ]

        val_ds = [subset[1] for subset in subsets]
        # for vds in val_ds:
        #     print(vds.indices)
        validation_evaluator = MultiDatasetClassificationEvaluator(
            val_ds, device)

        # Create the model
        bert = DistilBertForSequenceClassification.from_pretrained(
            bert_model, config=bert_config).to(device)
        multi_xformer = MultiDistilBertClassifier(
            bert_model,
            bert_config,
            n_domains=len(train_dls) - 1).to(device)
        if args.pretrained_multi_xformer is not None:
            multi_xformer.load_state_dict(
                torch.load(
                    f"{args.pretrained_multi_xformer}/model_{domain}.pth"))
            (val_loss, acc, P, R,
             F1), _ = validation_evaluator.evaluate(multi_xformer)
            print(f"Validation acc multi-xformer: {acc}")
Example No. 5
def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        class_weights: List,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    #best_loss = float('inf')
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    total = len(train_dl)
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(class_weights).to(device))

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            labels = batch[2]
            # Testing with random domains to see if any effect
            #domains = torch.tensor(np.random.randint(0, 16, batch[3].shape)).to(device)
            domains = batch[3]

            logits = model(input_ids, attention_mask=masks)[0]
            loss = loss_fn(logits, domains)
            loss = loss / gradient_accumulation

            #if i % gradient_accumulation == 0:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Saving the best model and early stopping
        #if val_loss < best_loss:
        if acc > best_acc:
            best_model = model.state_dict()
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
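A rough sketch of driving train_domain_classifier is shown next; the class weights reuse the inverse-size weighting that appears in Example No. 7 below, while the remaining names (subsets, shared_bert, domain_classifier_train_dl, validation_evaluator, device) are assumed to exist.

# Hypothetical setup for train_domain_classifier; the surrounding objects are assumptions.
set_sizes = [len(subset[0]) for subset in subsets]
class_weights = [1 - (size / sum(set_sizes)) for size in set_sizes]

optimizer = torch.optim.AdamW(shared_bert.parameters(), lr=3e-5)
train_domain_classifier(
    shared_bert,
    domain_classifier_train_dl,
    optimizer,
    scheduler=None,                # the function tolerates a missing scheduler
    validation_evaluator=validation_evaluator,
    n_epochs=2,
    device=device,
    class_weights=class_weights,
    domain_name="books",
)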
Example No. 6
                Subset(dset_choices[d], subset_indices[d][1])
            ] for d in subset_indices]

        train_dls = [
            DataLoader(subset[0],
                       batch_size=batch_size,
                       shuffle=True,
                       collate_fn=collate_batch_transformer)
            for subset in subsets
        ]

        val_ds = [subset[1] for subset in subsets]
        # for vds in val_ds:
        #     print(vds.indices)
        validation_evaluators = [
            MultiDatasetClassificationEvaluator([vds], device)
            for vds in val_ds
        ] + [
            MultiDatasetClassificationEvaluator(
                val_ds, device, use_domain=False)
        ]

        # Create the model
        bert = DistilBertForSequenceClassification.from_pretrained(
            bert_model, config=bert_config).to(device)

        model = torch.nn.DataParallel(
            MultiViewTransformerNetworkAveragingIndividuals(
                bert_model, bert_config, len(train_dls))).to(device)
        # (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        # print(f"Validation acc starting: {acc}")
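Example No. 6 builds one evaluator per validation set plus a combined evaluator with use_domain=False; a small hypothetical loop over them might look like this (the metric unpacking follows the other examples on this page).

# Hypothetical evaluation sweep over the per-domain and combined evaluators above.
for idx, evaluator in enumerate(validation_evaluators):
    (val_loss, acc, P, R, F1), _ = evaluator.evaluate(model)
    name = f"domain {idx}" if idx < len(val_ds) else "all domains"
    print(f"{name}: acc={acc:.4f}, F1={F1:.4f}")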
Example No. 7
                Subset(dset_choices[d], subset_indices[d][1])
            ] for d in subset_indices]

        train_dls = [
            DataLoader(subset[0],
                       batch_size=batch_size,
                       shuffle=True,
                       collate_fn=collate_batch_transformer)
            for subset in subsets
        ]

        val_ds = [subset[1] for subset in subsets]
        # for vds in val_ds:
        #     print(vds.indices)
        validation_evaluators = [
            MultiDatasetClassificationEvaluator([vds], device)
            for vds in val_ds
        ]

        # 1) Create a domain classifier with BERT
        shared_bert_config = DistilBertConfig.from_pretrained(
            bert_model, num_labels=len(train_dls))
        bert = DistilBertForSequenceClassification.from_pretrained(
            bert_model, config=shared_bert_config).to(device)
        shared_bert = VanillaBert(bert).to(device)
        set_sizes = [len(subset[0]) for subset in subsets]
        weights = [1 - (len(subset[0]) / sum(set_sizes)) for subset in subsets]
        domain_classifier_train_dset = ConcatDataset(
            [subset[0] for subset in subsets])
        domain_classifier_train_dl = DataLoader(
            domain_classifier_train_dset,