def plot_single_combined_plot(input_folder):
    """Render the seed-averaged validation-accuracy curve for one experiment folder.

    Args:
        input_folder: pathlib.Path to a folder of per-seed results; the
            combined plot is written inside it as ``combined_plot.png``.
    """
    averaged_updates, averaged_val_accs, num_seeds = get_combined_seeds_lists(
        input_folder)

    visualization.plot_jasons_lineplot(
        averaged_updates,
        averaged_val_accs,
        "updates",
        "val acc",
        f"average plot for {input_folder.name} over {num_seeds} seeds",
        input_folder.joinpath("combined_plot.png"),
    )
# Ejemplo n.º 2
# 0
def train_eval_model(cfg):
    """Train a triplet-loss model on AP data with periodic evaluation.

    Runs ``cfg.total_updates`` optimizer steps on triplet minibatches,
    evaluates every ``cfg.eval_interval`` updates, appends each eval to a
    CSV log, and finally saves training-loss and validation-accuracy plots
    under ``plots{cfg.seed_num}/{cfg.exp_id}/``.

    Args:
        cfg: experiment configuration (reads seed_num, exp_id, total_updates,
            eval_interval, hard_negative_mining, train_nc, among others).
    """
    # load data
    (train_sentence_to_label, train_label_to_sentences,
     train_label_to_sentences_aug, test_sentence_to_label,
     train_sentence_to_encoding,
     test_sentence_to_encoding) = dataloader.load_ap_data(cfg)

    # initialize model
    model, loss_fn, optimizer, device = initialize_model(cfg)

    # train the model
    iter_bar = tqdm(range(cfg.total_updates))
    update_num_list = []
    train_loss_list = []
    val_acc_list = []

    Path(f"plots{cfg.seed_num}/{cfg.exp_id}").mkdir(parents=True,
                                                    exist_ok=True)
    mb_size = 64
    # with semi-hard mining, mb_size is grown so that roughly this many
    # triplets per batch stay "activated" (contribute non-zero loss)
    target_activated_examples = 64
    avg_percent_activated = 1.0
    percent_activated_list = []

    # BUG FIX: the log file was opened but never closed; the context manager
    # guarantees it is flushed and closed even if training raises.
    with open(f"plots{cfg.seed_num}/{cfg.exp_id}/logs.csv", "w") as writer:

        for update_num in iter_bar:

            anchor, pos, neg = dataloader.generate_triplet_batch(
                train_label_to_sentences_aug,
                train_sentence_to_encoding,
                device,
                mb_size=mb_size)

            model.train()
            model.zero_grad()

            logits = model(anchor, pos, neg)
            train_loss, percent_activated = loss_fn(*logits)

            train_loss.backward()
            optimizer.step()
            percent_activated_list.append(percent_activated)

            if update_num % cfg.eval_interval == 0:

                val_acc = eval_model(
                    model,
                    device,
                    train_sentence_to_label,
                    train_label_to_sentences,
                    train_sentence_to_encoding,
                    test_sentence_to_label,
                    test_sentence_to_encoding,
                )
                avg_percent_activated = (sum(percent_activated_list) /
                                         len(percent_activated_list))

                iter_bar_str = (
                    f"update {update_num}/{cfg.total_updates}: " +
                    f"mb_train_loss={float(train_loss):.4f}, " +
                    f"val_acc={float(val_acc):.4f}, " +
                    f"percent_activated={float(avg_percent_activated):.3f}, " +
                    f"mb_size={mb_size}")
                iter_bar.set_description(iter_bar_str)
                update_num_list.append(update_num)
                val_acc_list.append(val_acc)
                # BUG FIX: store a plain float rather than the loss tensor so
                # the list neither retains autograd references nor confuses
                # the plotting helper.
                train_loss_list.append(float(train_loss))
                writer.write(
                    f"{update_num},{val_acc:.4f},{float(train_loss):.4f}\n")
                percent_activated_list = []
                if cfg.hard_negative_mining == 'semi-hard':
                    # enlarge the minibatch so ~target_activated_examples
                    # triplets remain active; capped at 2000 for memory
                    mb_size = min(
                        int(target_activated_examples / avg_percent_activated),
                        2000)

    visualization.plot_jasons_lineplot(
        update_num_list, train_loss_list, 'updates', 'training loss',
        f"{cfg.exp_id} n_train_c={cfg.train_nc} max_val_acc={max(val_acc_list):.3f}",
        f"plots{cfg.seed_num}/{cfg.exp_id}/train_loss.png")
    visualization.plot_jasons_lineplot(
        update_num_list, val_acc_list, 'updates', 'validation accuracy',
        f"{cfg.exp_id} n_train_c={cfg.train_nc} max_val_acc={max(val_acc_list):.3f}",
        f"plots{cfg.seed_num}/{cfg.exp_id}/val_acc{max(val_acc_list):.3f}.png")
def train_eval_cl_gradual_model(cfg, seed_num):
    """Train a triplet model with a gradual (curriculum) augmentation schedule.

    Depending on ``cfg.curriculum_type`` the augmentation strength (alpha)
    is lowered stage by stage ("curriculum"), raised ("anti"), or resampled
    randomly every 50 updates ("random"). Evaluates every
    ``cfg.eval_interval`` updates against the ORIGINAL (un-augmented) data,
    logs to CSV, saves plots, and returns the best validation accuracy.

    Args:
        cfg: experiment configuration (stage boundaries/alphas, exp_id, ...).
        seed_num: random seed index; used in output filenames.

    Returns:
        float: maximum validation accuracy observed during training.
    """
    # load data
    (train_sentence_to_label_orig, train_label_to_sentences_orig, _,
     test_sentence_to_label, train_sentence_to_encoding_orig,
     test_sentence_to_encoding) = dataloader.load_ap_data_no_aug(cfg, seed_num)
    train_label_to_sentences_aug = train_label_to_sentences_orig.copy()
    train_sentence_to_label = train_sentence_to_label_orig.copy()
    train_sentence_to_encoding = train_sentence_to_encoding_orig.copy()

    # initialize model
    model, loss_fn, optimizer, device = triplet_methods.initialize_model(cfg)

    # train the model
    iter_bar = tqdm(range(cfg.total_updates + 1))
    update_num_list = []
    train_loss_list = []
    val_acc_list = []

    output_folder = f"plots/{cfg.exp_id}_nc{cfg.train_nc}_aug{cfg.aug_type}_norig{cfg.n_orig}_curr{cfg.curriculum_type}"
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    mb_size = 64
    # with semi-hard mining, mb_size is grown so that roughly this many
    # triplets per batch stay "activated" (contribute non-zero loss)
    target_activated_examples = 64
    avg_percent_activated = 1.0
    percent_activated_list = []

    # BUG FIX: the log file was opened but never closed; the context manager
    # guarantees it is flushed and closed even if training raises.
    with open(f"{output_folder}/s{seed_num}_logs.csv", "w") as writer:

        for update_num in iter_bar:

            # ------------------------------------------------- sampling strategies
            # At stage boundaries, reload the training data with the next
            # augmentation strength. Each schedule is an ordered list of
            # (boundary_update, alpha); only the FIRST matching boundary
            # fires, preserving the original if/elif semantics.
            if cfg.curriculum_type == "curriculum":
                # easy -> hard: alpha shrinks as training progresses
                schedule = [
                    (cfg.first_stage_updates, cfg.second_stage_alpha),
                    (cfg.second_stage_updates, cfg.third_stage_alpha),
                    (cfg.third_stage_updates, cfg.fourth_stage_alpha),
                    (cfg.fourth_stage_updates, cfg.fifth_stage_alpha),
                    (cfg.fifth_stage_updates, cfg.sixth_stage_alpha),
                ]
                for boundary, alpha in schedule:
                    if update_num == boundary:
                        train_sentence_to_label, train_label_to_sentences_aug, train_sentence_to_encoding = dataloader.reload_ap_cl_data(
                            train_sentence_to_label,
                            train_label_to_sentences_orig, cfg, alpha)
                        break

            elif cfg.curriculum_type == "random":
                # resample a random augmentation strength every 50 updates
                if update_num % 50 == 0:
                    random_alpha = random.choice([0.1, 0.2, 0.3, 0.4, 0.5])
                    train_sentence_to_label, train_label_to_sentences_aug, train_sentence_to_encoding = dataloader.reload_ap_cl_data(
                        train_sentence_to_label, train_label_to_sentences_orig,
                        cfg, random_alpha)

            elif cfg.curriculum_type == "anti":
                # hard -> easy: start at the strongest alpha, relax each stage
                schedule = [
                    (1, cfg.sixth_stage_alpha),
                    (cfg.first_stage_updates, cfg.fifth_stage_alpha),
                    (cfg.second_stage_updates, cfg.fourth_stage_alpha),
                    (cfg.third_stage_updates, cfg.third_stage_alpha),
                    (cfg.fourth_stage_updates, cfg.second_stage_alpha),
                ]
                matched = False
                for boundary, alpha in schedule:
                    if update_num == boundary:
                        train_sentence_to_label, train_label_to_sentences_aug, train_sentence_to_encoding = dataloader.reload_ap_cl_data(
                            train_sentence_to_label,
                            train_label_to_sentences_orig, cfg, alpha)
                        matched = True
                        break
                if not matched and update_num == cfg.fifth_stage_updates:
                    # final stage: drop augmentation and train on originals
                    train_sentence_to_label = train_sentence_to_label_orig
                    train_label_to_sentences_aug = train_label_to_sentences_orig
                    train_sentence_to_encoding = train_sentence_to_encoding_orig
            # ---------------------------------------------------------------------

            anchor, pos, neg = dataloader.generate_triplet_batch(
                train_label_to_sentences_aug,
                train_sentence_to_encoding,
                device,
                mb_size=mb_size)

            model.train()
            model.zero_grad()

            logits = model(anchor, pos, neg)
            train_loss, percent_activated = loss_fn(*logits)

            train_loss.backward()
            optimizer.step()
            percent_activated_list.append(percent_activated)

            if update_num % cfg.eval_interval == 0:

                # evaluate against the ORIGINAL (un-augmented) data
                val_acc = triplet_methods.eval_model(
                    model,
                    device,
                    train_sentence_to_label_orig,
                    train_label_to_sentences_orig,
                    train_sentence_to_encoding_orig,
                    test_sentence_to_label,
                    test_sentence_to_encoding,
                )
                avg_percent_activated = (sum(percent_activated_list) /
                                         len(percent_activated_list))

                iter_bar_str = (
                    f"update {update_num}/{cfg.total_updates}: " +
                    f"mb_train_loss={float(train_loss):.4f}, " +
                    f"val_acc={float(val_acc):.4f}, " +
                    f"percent_activated={float(avg_percent_activated):.3f}, " +
                    f"mb_size={mb_size}")
                iter_bar.set_description(iter_bar_str)
                update_num_list.append(update_num)
                val_acc_list.append(val_acc)
                # BUG FIX: store a plain float rather than the loss tensor so
                # the list neither retains autograd references nor confuses
                # the plotting helper.
                train_loss_list.append(float(train_loss))
                writer.write(
                    f"{update_num},{val_acc:.4f},{float(train_loss):.4f}\n")
                percent_activated_list = []
                if cfg.hard_negative_mining == 'semi-hard':
                    # enlarge the minibatch so ~target_activated_examples
                    # triplets remain active; capped at 2000 for memory
                    mb_size = min(
                        int(target_activated_examples / avg_percent_activated),
                        2000)

    visualization.plot_jasons_lineplot(
        update_num_list, train_loss_list, 'updates', 'training loss',
        f"{cfg.exp_id} n_c={cfg.train_nc} aug={cfg.aug_type} curr={cfg.curriculum_type} max_val_acc={max(val_acc_list):.3f}",
        f"{output_folder}/s{seed_num}_train_loss.png")
    visualization.plot_jasons_lineplot(
        update_num_list, val_acc_list, 'updates', 'validation accuracy',
        f"{cfg.exp_id} n_c={cfg.train_nc} aug={cfg.aug_type} curr={cfg.curriculum_type} max_val_acc={max(val_acc_list):.3f}",
        f"{output_folder}/s{seed_num}_val_acc{max(val_acc_list):.3f}.png")
    return max(val_acc_list)
# Ejemplo n.º 4
# 0
def train_mlp(cfg):
    """Train an LR or MLP classifier on pre-computed sentence encodings.

    Builds train/test feature matrices from AP data, trains for
    ``cfg.num_epochs`` epochs with Adam plus exponential LR decay, computes
    validation accuracy after each epoch, and saves loss/accuracy plots
    under ``plots/{cfg.exp_id}/``.

    Args:
        cfg: experiment configuration (model, num_output_classes,
            learning_rate, weight_decay, decay_gamma, minibatch_size,
            num_epochs, seed_num, exp_id, train_nc, ...).
    """
    # load data
    (train_sentence_to_label, _, train_label_to_sentences,
     test_sentence_to_label, train_sentence_to_encoding,
     test_sentence_to_encoding) = dataloader.load_ap_data(cfg)
    train_x, train_y = mlp_dataloader.get_mlp_train_x_y(
        cfg, train_label_to_sentences, train_sentence_to_encoding)
    test_x, test_y = mlp_dataloader.get_mlp_test_x_y(
        cfg, test_sentence_to_label, test_sentence_to_encoding)

    if cfg.model == "LR":
        model = LR(num_classes=cfg.num_output_classes)
    else:
        model = MLP(num_classes=cfg.num_output_classes)

    optimizer = optim.Adam(params=model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=cfg.weight_decay
                           )  #wow, works for even large learning rates
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                 gamma=cfg.decay_gamma)

    num_minibatches_train = int(train_x.shape[0] / cfg.minibatch_size)
    train_loss_list = []
    val_acc_list = []

    # hoisted loop-invariant: the loss module is stateless across batches
    loss_fn = nn.CrossEntropyLoss()

    ######## training loop ########
    for epoch in range(1, cfg.num_epochs + 1):

        ######## training ########
        model.train(mode=True)

        # reshuffle each epoch; fixed random_state keeps runs reproducible
        train_x, train_y = shuffle(train_x, train_y, random_state=cfg.seed_num)

        for minibatch_num in range(num_minibatches_train):

            start_idx = minibatch_num * cfg.minibatch_size
            end_idx = start_idx + cfg.minibatch_size
            train_inputs = torch.from_numpy(
                train_x[start_idx:end_idx].astype(np.float32))
            # BUG FIX: np.long was removed in NumPy 1.24; np.int64 is the
            # integer dtype CrossEntropyLoss expects for targets.
            train_labels = torch.from_numpy(
                train_y[start_idx:end_idx].astype(np.int64))

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):

                train_outputs = model(train_inputs)
                train_loss = loss_fn(input=train_outputs,
                                     target=train_labels)
                train_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        # BUG FIX: the scheduler was constructed but never stepped, so
        # cfg.decay_gamma previously had no effect; decay once per epoch.
        scheduler.step()

        # store a plain float (last minibatch loss of the epoch) rather than
        # the loss tensor, so the list doesn't retain autograd references
        train_loss_list.append(float(train_loss))

        ######## validation ########
        model.train(mode=False)

        val_inputs = torch.from_numpy(test_x.astype(np.float32))

        # Feed forward.
        with torch.set_grad_enabled(mode=False):
            val_outputs = model(val_inputs)
            _, val_preds = torch.max(val_outputs, dim=1)
            val_acc = accuracy_score(test_y, val_preds)
            val_acc_list.append(val_acc)

    Path(f"plots/{cfg.exp_id}").mkdir(parents=True, exist_ok=True)
    visualization.plot_jasons_lineplot(
        None, train_loss_list, 'updates', 'training loss',
        f"{cfg.exp_id} n_train_c={cfg.train_nc} max_val_acc={max(val_acc_list):.3f}",
        f"plots/{cfg.exp_id}/train_loss.png")
    visualization.plot_jasons_lineplot(
        None, val_acc_list, 'updates', 'validation accuracy',
        f"{cfg.exp_id} n_train_c={cfg.train_nc} max_val_acc={max(val_acc_list):.3f}",
        f"plots/{cfg.exp_id}/val_acc{max(val_acc_list):.3f}.png")