Example #1
import gc

import numpy as np
import torch
import wandb
from sklearn.model_selection import ParameterSampler


def attention_grid_search(
    model: torch.nn.Module,
    validation_evaluator: MultiDatasetClassificationEvaluator,
    n_epochs: int,
    seed: int,
):
    best_weights = model.module.weights
    # Initial evaluation with the model's current attention weights
    (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
    best_f1 = F1
    print(f"Initial validation F1: {F1}")
    # Create the grid search
    param_dict = {
        1: list(range(0, 11)),
        2: list(range(0, 11)),
        3: list(range(0, 11)),
        4: list(range(0, 11)),
        5: list(range(0, 11))
    }
    grid_search_params = ParameterSampler(param_dict,
                                          n_iter=n_epochs,
                                          random_state=seed)
    for d in grid_search_params:
        weights = [v for k, v in sorted(d.items(), key=lambda x: x[0])]
        # Skip the degenerate draw where every sampled weight is zero, which
        # would otherwise divide by zero during normalization
        if sum(weights) == 0:
            continue
        weights = np.array(weights) / sum(weights)
        model.module.weights = weights
        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Weights: {weights}\tValidation F1: {F1}")

        if F1 > best_f1:
            best_weights = weights
            best_f1 = F1
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss
            })

        gc.collect()

    return best_weights
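
A minimal usage sketch, assuming a DataParallel-wrapped model whose module exposes a writable weights attribute; the names model and validation_evaluator are stand-ins, not part of the snippet:

# Hypothetical driver code: run the search, then apply the best weights found
best_weights = attention_grid_search(model,
                                     validation_evaluator,
                                     n_epochs=100,
                                     seed=1000)
model.module.weights = best_weights
(val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
print(f"Best grid-search F1: {F1}")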
Example #2
import gc
import random
from pathlib import Path
from typing import AnyStr, List

import torch
import wandb
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm


def train(
        model: torch.nn.Module,
        train_dls: List[DataLoader],
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''
):
    best_f1 = 0.0
    patience_counter = 0

    epoch_counter = 0
    # Total number of batches across all domain dataloaders (for the progress bar)
    total = sum(len(dl) for dl in train_dls)

    # Main loop
    while epoch_counter < n_epochs:
        # One iterator per domain; each step draws from a randomly ordered
        # set of domains until every iterator is exhausted
        dl_iters = [iter(dl) for dl in train_dls]
        dl_idx = list(range(len(dl_iters)))
        finished = [0] * len(dl_iters)
        i = 0
        with tqdm(total=total, desc="Training") as pbar:
            while sum(finished) < len(dl_iters):
                random.shuffle(dl_idx)
                for d in dl_idx:
                    domain_dl = dl_iters[d]
                    batches = []
                    # Draw up to `gradient_accumulation` batches from this
                    # domain; mark it finished once its iterator runs out
                    try:
                        for j in range(gradient_accumulation):
                            batches.append(next(domain_dl))
                    except StopIteration:
                        finished[d] = 1
                        if len(batches) == 0:
                            continue
                    optimizer.zero_grad()
                    for batch in batches:
                        model.train()
                        batch = tuple(t.to(device) for t in batch)
                        input_ids = batch[0]
                        masks = batch[1]
                        labels = batch[2]
                        domains = batch[3]

                        loss, logits, alpha = model(input_ids,
                                                    attention_mask=masks,
                                                    domains=domains,
                                                    labels=labels,
                                                    ret_alpha=True)
                        # Average across replicas and scale for accumulation
                        loss = loss.mean() / gradient_accumulation
                        if i % log_interval == 0:
                            wandb.log({"Loss": loss.item()})

                        loss.backward()
                        i += 1
                        pbar.update(1)

                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation f1: {F1}")

        # Save the best model and reset early stopping
        if F1 > best_f1:
            best_f1 = F1
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_{domain_name}.pth')
            patience_counter = 0
            # Log to wandb
            wandb.log({
                'Validation accuracy': acc,
                'Validation Precision': P,
                'Validation Recall': R,
                'Validation F1': F1,
                'Validation loss': val_loss})
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
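
A hedged sketch of driving train; the AdamW optimizer and the linear warmup schedule from transformers are assumptions for illustration, not taken from the snippet:

# Hypothetical setup; hyperparameters are illustrative only
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

n_epochs = 5
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
total_steps = sum(len(dl) for dl in train_dls) * n_epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=200,
                                            num_training_steps=total_steps)
train(model, train_dls, optimizer, scheduler, validation_evaluator,
      n_epochs, device='cuda', gradient_accumulation=2, domain_name=domain)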
Example #3
        validation_evaluator = MultiDatasetClassificationEvaluator(
            val_ds, device)

        # Create the model
        bert = DistilBertForSequenceClassification.from_pretrained(
            bert_model, config=bert_config).to(device)
        multi_xformer = MultiDistilBertClassifier(
            bert_model, bert_config,
            n_domains=len(train_dls) - 1).to(device)
        if args.pretrained_multi_xformer is not None:
            multi_xformer.load_state_dict(
                torch.load(
                    f"{args.pretrained_multi_xformer}/model_{domain}.pth"))
            (val_loss, acc, P, R,
             F1), _ = validation_evaluator.evaluate(multi_xformer)
            print(f"Validation acc multi-xformer: {acc}")

        shared_bert = VanillaBert(bert).to(device)
        if args.pretrained_bert is not None:
            shared_bert.load_state_dict(
                torch.load(f"{args.pretrained_bert}/model_{domain}.pth"))
            (val_loss, acc, P, R,
             F1), _ = validation_evaluator.evaluate(shared_bert)
            print(f"Validation acc shared bert: {acc}")

        model = torch.nn.DataParallel(
            MultiViewTransformerNetworkProbabilitiesAdversarial(
                multi_xformer,
                shared_bert,
                supervision_layer=args.supervision_layer)).to(device)
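
Because this snippet wraps the network in torch.nn.DataParallel, downstream code (e.g. attention_grid_search in Example #1) reaches the underlying network through model.module. A hedged sketch of restoring a checkpoint saved by train in Example #2; model_dir and domain are stand-ins:

# Hypothetical restore step; the path mirrors the torch.save call in Example #2
state = torch.load(f"{model_dir}/{Path(wandb.run.dir).name}/model_{domain}.pth")
model.load_state_dict(state)
(val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)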
Example #4
import gc
from pathlib import Path
from typing import AnyStr, List

import torch
import wandb
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm


def train_domain_classifier(
        model: torch.nn.Module,
        train_dl: DataLoader,
        optimizer: torch.optim.Optimizer,
        scheduler: LambdaLR,
        validation_evaluator: MultiDatasetClassificationEvaluator,
        n_epochs: int,
        device: AnyStr,
        class_weights: List,
        log_interval: int = 1,
        patience: int = 10,
        model_dir: str = "wandb_local",
        gradient_accumulation: int = 1,
        domain_name: str = ''):
    best_acc = 0.0
    patience_counter = 0

    epoch_counter = 0
    # Class-weighted loss to compensate for imbalanced domain sizes
    loss_fn = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(class_weights).to(device))

    # Main loop
    while epoch_counter < n_epochs:
        for i, batch in enumerate(tqdm(train_dl)):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            masks = batch[1]
            # batch[2] holds the task labels, which are unused here; the
            # classifier is trained to predict the domain id in batch[3]
            domains = batch[3]

            logits = model(input_ids, attention_mask=masks)[0]
            loss = loss_fn(logits, domains) / gradient_accumulation

            loss.backward()
            # Step only once every `gradient_accumulation` batches
            if (i + 1) % gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

        gc.collect()

        # Inline evaluation
        (val_loss, acc, P, R, F1), _ = validation_evaluator.evaluate(model)
        print(f"Validation acc: {acc}")

        # Save the best model and reset early stopping
        if acc > best_acc:
            best_acc = acc
            torch.save(
                model.state_dict(),
                f'{model_dir}/{Path(wandb.run.dir).name}/model_domainclassifier_{domain_name}.pth'
            )
            patience_counter = 0
        else:
            patience_counter += 1
            # Stop training once we have lost patience
            if patience_counter == patience:
                break

        gc.collect()
        epoch_counter += 1
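
The class_weights argument feeds directly into the CrossEntropyLoss above. A hedged sketch of computing it with scikit-learn; domain_labels is a hypothetical array holding the domain id of every training example:

# Hypothetical preprocessing; `domain_labels` is a stand-in name
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(domain_labels),
                                     y=domain_labels)
train_domain_classifier(model, train_dl, optimizer, scheduler,
                        validation_evaluator, n_epochs=5, device='cuda',
                        class_weights=list(class_weights))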