Example 1
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        x0 = torch.linspace(start=-5, end=5, steps=100,
                            device=device).view(-1, 1)  # vector

        if num_val_tasks == 0:
            from matplotlib import pyplot as plt
            import matplotlib
            matplotlib.rcParams['xtick.labelsize'] = 16
            matplotlib.rcParams['ytick.labelsize'] = 16
            matplotlib.rcParams['axes.labelsize'] = 18

            num_stds = 2
            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            if datasubset == 'sine':
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                    noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                x_t, y_t, slope, intercept = data_generator.generate_line_data(
                    noise_flag=True)
                y0 = slope * x0 + intercept

            # stack the list of sampled predictions so torch.mean / torch.std
            # below operate over the sample dimension
            y_preds = torch.stack(get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))
            # LOAD MAML DATA
            maml_folder = '{0:s}/MAML_mixed_sine_line'.format(dst_folder_root)
            # two-stage format: fill in the shot count now, leave a '{0:d}'
            # placeholder for the checkpoint index
            maml_filename = 'MAML_mixed_{0:d}shot_{1:s}.pt'.format(
                num_training_samples_per_class, '{0:d}')

            # find the index of the first missing MAML checkpoint file
            i = 1
            maml_checkpoint_filename = os.path.join(maml_folder,
                                                    maml_filename.format(i))
            while os.path.exists(maml_checkpoint_filename):
                i += 1
                maml_checkpoint_filename = os.path.join(
                    maml_folder, maml_filename.format(i))
            print(maml_checkpoint_filename)
            # load the latest checkpoint that does exist (index i - 1)
            maml_checkpoint = torch.load(
                os.path.join(maml_folder, maml_filename.format(i - 1)),
                map_location=lambda storage, loc: storage.cuda(gpu_id))
            theta_maml = maml_checkpoint['theta']
            y_pred_maml = get_task_prediction_maml(x_t=x_t,
                                                   y_t=y_t,
                                                   x_v=x0,
                                                   meta_params=theta_maml)
            # PLOT
            _, ax = plt.subplots(figsize=(5, 5))
            y_top = torch.squeeze(
                torch.mean(y_preds, dim=0) +
                num_stds * torch.std(y_preds, dim=0))
            y_bottom = torch.squeeze(
                torch.mean(y_preds, dim=0) -
                num_stds * torch.std(y_preds, dim=0))

            ax.fill_between(x=torch.squeeze(x0).cpu().numpy(),
                            y1=y_bottom.cpu().detach().numpy(),
                            y2=y_top.cpu().detach().numpy(),
                            alpha=0.25,
                            color='C3',
                            zorder=0,
                            label='VAMPIRE')
            ax.plot(x0.cpu().numpy(),
                    y0.cpu().numpy(),
                    color='C7',
                    linestyle='-',
                    linewidth=3,
                    zorder=1,
                    label='Ground truth')
            ax.plot(x0.cpu().numpy(),
                    y_pred_maml.cpu().detach().numpy(),
                    color='C2',
                    linestyle='--',
                    linewidth=3,
                    zorder=2,
                    label='MAML')
            ax.scatter(x=x_t.cpu().numpy(),
                       y=y_t.cpu().numpy(),
                       color='C0',
                       marker='^',
                       s=300,
                       zorder=3,
                       label='Data')
            plt.xticks([-5, -2.5, 0, 2.5, 5])
            plt.savefig(fname='img/mixed_sine_temp.svg', format='svg')
            return 0
        else:
            from scipy.special import erf

            quantiles = np.arange(start=0., stop=1.1, step=0.1)
            cal_data = []

            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            for _ in range(num_val_tasks):
                binary_flag = np.random.binomial(n=1, p=p_sine)
                if (binary_flag == 0):
                    # generate sinusoidal data (note: with this convention the
                    # sine branch fires with probability 1 - p_sine)
                    x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                        noise_flag=True)
                    y0 = amp * torch.sin(x0 + phase)
                else:
                    # generate line data
                    x_t, y_t, slope, intercept = data_generator.generate_line_data(
                        noise_flag=True)
                    y0 = slope * x0 + intercept
                y0 = y0.view(1, -1).cpu().numpy()  # row vector

                y_preds = torch.stack(
                    get_task_prediction(x_t=x_t, y_t=y_t,
                                        x_v=x0))  # K x len(x0)

                y_preds_np = torch.squeeze(y_preds,
                                           dim=-1).detach().cpu().numpy()

                y_preds_quantile = np.quantile(a=y_preds_np,
                                               q=quantiles,
                                               axis=0,
                                               keepdims=False)

                # evaluate the ground-truth Gaussian-noise CDF at the predicted
                # quantiles; a perfectly calibrated model reproduces `quantiles`
                std = data_generator.noise_std
                cal_temp = (1 + erf(
                    (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
                cal_temp_avg = np.mean(a=cal_temp,
                                       axis=1)  # average for a task
                cal_data.append(cal_temp_avg)
            return cal_data
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (
            num_total_samples_per_class -
            num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = itertools.combinations(all_class_names,
                                                r=num_classes_per_task)

        if train_flag:
            all_task_names = list(all_task_names)
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data, embedding_data, class_labels,
                num_total_samples_per_class, num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred_v = torch.stack(y_pred_v)
            y_pred_v = sm_loss(y_pred_v)
            y_pred = torch.mean(input=y_pred_v, dim=0, keepdim=False)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct,
                                 dim=0).item() / total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if (task_count >= num_val_tasks):
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
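A note on the data generator: every example here constructs a DataGenerator that is not shown. The sketch below reconstructs a minimal version from its call sites: the constructor arguments, the return signatures of generate_sinusoidal_data and generate_line_data, and the noise_std attribute are all used in Example 1, while the sampling ranges for amplitude, phase, slope, and intercept and the default noise level are assumptions.

import torch

class DataGenerator:
    """Minimal sketch of the regression task generator (ranges are assumed)."""

    def __init__(self, num_samples, device, noise_std=0.3):
        self.num_samples = num_samples
        self.device = device
        self.noise_std = noise_std  # read by the calibration code in Example 1

    def _sample_x(self):
        # inputs drawn uniformly from the plotting range [-5, 5] used above
        return torch.empty(self.num_samples, 1,
                           device=self.device).uniform_(-5, 5)

    def generate_sinusoidal_data(self, noise_flag=False):
        amp = float(torch.empty(1).uniform_(0.1, 5.0))
        phase = float(torch.empty(1).uniform_(0.0, 3.1416))
        x = self._sample_x()
        y = amp * torch.sin(x + phase)
        if noise_flag:
            y = y + self.noise_std * torch.randn_like(y)
        return x, y, amp, phase

    def generate_line_data(self, noise_flag=False):
        slope = float(torch.empty(1).uniform_(-3.0, 3.0))
        intercept = float(torch.empty(1).uniform_(-3.0, 3.0))
        x = self._sample_x()
        y = slope * x + intercept
        if noise_flag:
            y = y + self.noise_std * torch.randn_like(y)
        return x, y, slope, intercept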
Example 2
def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)
        # create a dummy dataset and sampler so the DataLoader yields one batch
        # of "class labels" per task (see the standalone sketch after this example)
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion
    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulated loss over the sampled networks, driving the meta-update gradient step
        num_meta_updates_count = 0

        meta_loss_avg_print = 0  # running loss average, for printing

        meta_loss_avg_save = []  # meta losses to save

        task_count = 0  # counts tasks to decide when a minibatch is complete and a meta-update should run

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purpose
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print /
                                                  num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count,
                                                       meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))

                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
        print()
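All of the meta_train variants build their task loop on the same trick: a DataLoader over class names (or dummy placeholders for regression) whose batches act as the class labels of one few-shot task. A self-contained demonstration with made-up class names:

import torch

# made-up class names; in meta_train these are all_class_train.keys()
classes = ['cat', 'dog', 'fish', 'bird', 'frog']
sampler = torch.utils.data.RandomSampler(data_source=classes,
                                         replacement=False)
loader = torch.utils.data.DataLoader(dataset=classes,
                                     batch_size=2,  # num_classes_per_task
                                     sampler=sampler,
                                     drop_last=True)
for class_labels in loader:
    # default collation of a batch of strings yields a list such as
    # ['dog', 'cat']; each batch serves as one N-way task's class labels
    print(class_labels)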
Example 3
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)

        # create a dummy dataset and sampler so the DataLoader yields one batch
        # of "class labels" per task
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0  # counts tasks to decide when a minibatch is complete and a meta-update should run
        meta_loss = 0  # accumulated loss over the sampled networks, driving the meta-update gradient step
        num_meta_updates_count = 0
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0

        kl_loss = 0
        kl_loss_avg_print = 0

        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)
                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                loss_NLL_v += loss_NLL.item()

                # clamp the NLL loss at 1 (assigning through .data keeps the
                # autograd graph intact)
                if (loss_NLL.item() > 1):
                    loss_NLL.data = torch.tensor([1.], device=device)

                # if (discriminator_loss.item() > d_loss_const):
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss

                # clamp the KL term at zero before it enters the regularizer
                if (KL_loss.item() < 0):
                    KL_loss.data = torch.tensor([0.], device=device)

                # per-task PAC-Bayes-style regularizer built from the KL term
                Ri = torch.sqrt((KL_loss + ri_const) /
                                (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri

                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purpose
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # add R0: an L2-norm regularizer over the meta-parameters theta
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const) /
                                    (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta (retain the graph: d_loss.backward()
                    # below reuses parts of it)
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(),
                                                   max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(),
                        max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(),
                        max_norm=clip_grad_value)
                    op_discriminator.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        loss_NLL_avg_save.append(loss_NLL_avg_print /
                                                 num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print /
                                                num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print /
                                               num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count, loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1], d_loss_avg_save[-1]))
                        num_meta_updates_count = 0

                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))

                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        # train_accs = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
            # scheduler.step()
        print()
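Example 3's meta-loss adds two square-root regularizers on top of the NLL: a per-task term Ri built from the task's KL divergence, and a per-minibatch term R0 built from the L2 norms of the meta-parameters theta. The standalone sketch below isolates just these two computations; the constants and sizes (ri_const, r0_const, L2_regularization, total_validation_samples, num_tasks_per_minibatch) are hypothetical stand-ins for the globals used above.

import torch

# hypothetical values standing in for the globals of Example 3
ri_const, r0_const, L2_regularization = 1e-8, 1e-8, 1e-3
total_validation_samples, num_tasks_per_minibatch = 75, 10

KL_loss = torch.tensor(2.5)  # one task's KL divergence (clamped at 0 above)
theta = {'w': torch.randn(64, 64), 'b': torch.randn(64)}

# per-task regularizer, as added to meta_loss for every task
Ri = torch.sqrt((KL_loss + ri_const) /
                (2 * (total_validation_samples - 1)))

# per-minibatch regularizer over the meta-parameter norms
R0 = sum(p.norm(2) for p in theta.values())
R0 = torch.sqrt((L2_regularization * R0 + r0_const) /
                (2 * (num_tasks_per_minibatch - 1)))
print(Ri.item(), R0.item())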
Example 4
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device
        )

        # create a dummy dataset and sampler so the DataLoader yields one batch
        # of "class labels" per task
        all_class = [0]*100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(all_class.keys()), replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store information of each epoch for monitoring purposes
        meta_loss_saved = [] # meta losses to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0 # counts tasks to decide when a minibatch is complete and a meta-update should run
        meta_loss = 0 # accumulated loss over the sampled networks, driving the meta-update gradient step
        num_meta_updates_count = 0

        meta_loss_avg_print = 0 # running loss average, for printing

        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = [] # meta losses to save
        kl_loss_avg_save = []
        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device
                    )
                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_base_dropout
                )

                meta_loss = meta_loss + loss_NLL

                kl_loss = kl_loss + KL_loss
                task_count = task_count + 1

                if torch.isnan(meta_loss).item():
                    sys.exit('NaN error')

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss/num_tasks_per_minibatch
                    kl_loss = kl_loss/num_tasks_per_minibatch

                    # accumulate for printing purpose
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()
                    
                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print/num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0

                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0
                        
                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))

                        meta_loss_avg_save = []
                        kl_loss_avg_save = []
                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
            # scheduler.step()
        print()
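Every variant saves a checkpoint dict keyed by the state needed to resume training. Below is a sketch of reloading the checkpoint written by Example 4; the save directory, the concrete filename, the learning rate, and the optimizer class (Adam) are assumptions, while the dict keys and the '{datasource}_{N}way_{K}shot_{epoch}.pt' naming pattern come from the code above.

import os
import torch

dst_folder = './checkpoints'  # assumed save directory
checkpoint_filename = 'sine_line_5way_5shot_100.pt'  # illustrative name
checkpoint = torch.load(os.path.join(dst_folder, checkpoint_filename),
                        map_location=lambda storage, loc: storage)  # to CPU

theta = checkpoint['theta']  # dict of meta-parameter tensors
# rebuild the optimizer over the same tensors, then restore its state
# (Adam with lr=1e-3 is an assumption; the original class is not shown)
op_theta = torch.optim.Adam(theta.values(), lr=1e-3)
op_theta.load_state_dict(checkpoint['op_theta'])
meta_loss_saved = checkpoint['meta_loss']  # saved loss history
kl_loss_saved = checkpoint['kl_loss']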
Example 5
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        from scipy.special import erf
        
        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

        cal_avg = 0

        data_generator = DataGenerator(num_samples=num_training_samples_per_class, device=device)
        for _ in range(num_val_tasks):
            binary_flag = np.random.binomial(n=1, p=p_sine)
            if (binary_flag == 0):
                # generate sinusoidal data (note: with this convention the
                # sine branch fires with probability 1 - p_sine)
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp*torch.sin(x0 + phase)
            else:
                # generate line data
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope*x0 + intercept
            y0 = y0.view(1, -1).cpu().numpy()
            
            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)

            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            # evaluate the ground-truth Gaussian-noise CDF at the predictions;
            # a well-calibrated model yields values uniform on [0, 1]
            std = data_generator.noise_std
            cal_temp = (1 + erf((y_preds_np - y0)/(np.sqrt(2)*std)))/2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        cal_avg = cal_avg / num_val_tasks
        return cal_avg
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class - num_training_samples_per_class)*num_classes_per_task
        
        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')
        
        all_class_names = list(all_class_data.keys())
        all_task_names = list(itertools.combinations(all_class_names, r=num_classes_per_task))

        if train_flag:
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)
            
            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred = sm_loss(y_pred_v)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item()/total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
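The sine_line branch of Example 5 scores calibration by pushing prediction errors through the Gaussian-noise CDF: if the noise model is correct, (1 + erf((y_pred - y_true) / (sqrt(2) * std))) / 2 is uniformly distributed on [0, 1], so its mean should sit near 0.5. A self-contained numerical check of that identity on synthetic data:

import numpy as np
from scipy.special import erf

rng = np.random.default_rng(0)
std = 0.3  # ground-truth noise standard deviation
y_true = np.sin(np.linspace(-5, 5, 100))
# simulate predictions whose errors follow the assumed noise model
y_pred = y_true + std * rng.standard_normal((10, 100))  # 10 particles

# Gaussian CDF of the errors, exactly as computed in Examples 1 and 5
cal = (1 + erf((y_pred - y_true) / (np.sqrt(2) * std))) / 2
print(cal.mean())  # close to 0.5 for a well-calibrated predictor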