Example #1
def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'
    
    if uncertainty_flag:
        from scipy.special import erf
        cal_avg = 0
    else:
        from matplotlib import pyplot as plt


    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

    for _ in range(num_val_tasks):
        # flip a coin: 0 - 'sine' or 1 - 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if (binary_flag == 0):
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
            y0 = slope * x0 + intercept
        
        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.cpu().numpy().reshape(1, -1)  # row vector

        w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
        y_pred = predict_label_score(x=x0, w=w_task)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy() # convert to numpy array

        if uncertainty_flag:
            cal_temp = (1 + erf((y_pred - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        else:
            plt.figure(figsize=(4, 4))
            plt.subplot(111)

            plt.scatter(x_t.cpu().numpy(), y_t.cpu().numpy(), marker='^', label='Training data')
            plt.plot(x0.cpu().numpy(), y_pred, linewidth=1, linestyle='-', label='Prediction')
            plt.plot(x0.cpu().numpy(), y0.reshape(-1), linewidth=1, linestyle='--', label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()
    
    if uncertainty_flag:
        print('Average calibration \'score\' = {0}'.format(cal_avg / num_val_tasks))
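The helper adapt_to_task used in Example #1 is not shown in the snippet. A minimal MAML-style sketch of such a gradient-based inner-loop adaptation, assuming the meta-parameters are a dict of tensors and net_forward(x, w) is a functional forward pass (both hypothetical names, not part of the original code), could look like:

import torch

def adapt_to_task_sketch(x, y, w0, net_forward, inner_lr=1e-3, num_inner_steps=5):
    """Hypothetical stand-in for adapt_to_task: a few gradient steps starting from the meta-parameters."""
    loss_fn = torch.nn.MSELoss()
    w = dict(w0)  # start from the meta-parameters theta
    for _ in range(num_inner_steps):
        loss = loss_fn(net_forward(x, w), y)
        # keep the graph so a meta-update could differentiate through the adaptation
        grads = torch.autograd.grad(loss, list(w.values()), create_graph=True)
        w = {key: value - inner_lr * grad
             for (key, value), grad in zip(w.items(), grads)}
    return w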
Example #2
    def meta_validation(self,
                        datasubset,
                        num_val_tasks,
                        return_uncertainty=False):
        x0 = torch.linspace(start=-5, end=5, steps=100,
                            device=self.device).view(-1, 1)  # vector

        from scipy.special import erf  # make sure erf used below is in scope for this snippet

        quantiles = np.arange(start=0., stop=1.1, step=0.1)
        cal_data = []

        data_generator = DataGenerator(
            num_samples=self.num_training_samples_per_class,
            device=self.device)
        for _ in range(num_val_tasks):

            # generate sinusoidal data
            x_t, y_t, amp, phase, slope = data_generator.generate_sinusoidal_data(
                noise_flag=True)
            y0 = amp * torch.sin(slope * x0 + phase)
            y0 = y0.view(1, -1).cpu().numpy()  # row vector

            y_preds = torch.stack(
                self.get_task_prediction(x_t=x_t, y_t=y_t,
                                         x_v=x0))  # K x len(x0)

            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            y_preds_quantile = np.quantile(a=y_preds_np,
                                           q=quantiles,
                                           axis=0,
                                           keepdims=False)

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf(
                (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
            cal_data.append(cal_temp_avg)
        return cal_data
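Each entry of cal_data above stores, for one task, the ground-truth CDF evaluated at the predictive quantiles. Averaging these entries over tasks gives a reliability curve that can be plotted against the nominal quantile levels; a minimal sketch (plot_reliability is a hypothetical helper, not part of the class above):

import numpy as np
from matplotlib import pyplot as plt

def plot_reliability(cal_data, quantiles=np.arange(start=0., stop=1.1, step=0.1)):
    """Reliability diagram: nominal quantile level vs. average observed coverage."""
    observed = np.mean(np.stack(cal_data), axis=0)  # average over validation tasks
    plt.plot([0, 1], [0, 1], linestyle='--', label='Perfect calibration')
    plt.plot(quantiles, observed, marker='o', label='Model')
    plt.xlabel('Predicted quantile')
    plt.ylabel('Observed CDF')
    plt.legend()
    plt.tight_layout()
    plt.show()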
Example #3
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        x0 = torch.linspace(start=-5, end=5, steps=100,
                            device=device).view(-1, 1)  # vector

        if num_val_tasks == 0:
            from matplotlib import pyplot as plt
            import matplotlib
            matplotlib.rcParams['xtick.labelsize'] = 16
            matplotlib.rcParams['ytick.labelsize'] = 16
            matplotlib.rcParams['axes.labelsize'] = 18

            num_stds = 2
            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            if datasubset == 'sine':
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                    noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                x_t, y_t, slope, intercept = data_generator.generate_line_data(
                    noise_flag=True)
                y0 = slope * x0 + intercept

            y_preds = torch.stack(get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))  # K x len(x0)
            '''LOAD MAML DATA'''
            maml_folder = '{0:s}/MAML_mixed_sine_line'.format(dst_folder_root)
            maml_filename = 'MAML_mixed_{0:d}shot_{1:s}.pt'.format(
                num_training_samples_per_class, '{0:d}')

            i = 1
            maml_checkpoint_filename = os.path.join(maml_folder,
                                                    maml_filename.format(i))
            while (os.path.exists(maml_checkpoint_filename)):
                i = i + 1
                maml_checkpoint_filename = os.path.join(
                    maml_folder, maml_filename.format(i))
            # the loop above stopped at the first missing index, so load checkpoint i - 1
            maml_checkpoint_filename = os.path.join(maml_folder,
                                                    maml_filename.format(i - 1))
            print(maml_checkpoint_filename)
            maml_checkpoint = torch.load(
                maml_checkpoint_filename,
                map_location=lambda storage, loc: storage.cuda(gpu_id))
            theta_maml = maml_checkpoint['theta']
            y_pred_maml = get_task_prediction_maml(x_t=x_t,
                                                   y_t=y_t,
                                                   x_v=x0,
                                                   meta_params=theta_maml)
            '''PLOT'''
            _, ax = plt.subplots(figsize=(5, 5))
            y_top = torch.squeeze(
                torch.mean(y_preds, dim=0) +
                num_stds * torch.std(y_preds, dim=0))
            y_bottom = torch.squeeze(
                torch.mean(y_preds, dim=0) -
                num_stds * torch.std(y_preds, dim=0))

            ax.fill_between(x=torch.squeeze(x0).cpu().numpy(),
                            y1=y_bottom.cpu().detach().numpy(),
                            y2=y_top.cpu().detach().numpy(),
                            alpha=0.25,
                            color='C3',
                            zorder=0,
                            label='VAMPIRE')
            ax.plot(x0.cpu().numpy(),
                    y0.cpu().numpy(),
                    color='C7',
                    linestyle='-',
                    linewidth=3,
                    zorder=1,
                    label='Ground truth')
            ax.plot(x0.cpu().numpy(),
                    y_pred_maml.cpu().detach().numpy(),
                    color='C2',
                    linestyle='--',
                    linewidth=3,
                    zorder=2,
                    label='MAML')
            ax.scatter(x=x_t.cpu().numpy(),
                       y=y_t.cpu().numpy(),
                       color='C0',
                       marker='^',
                       s=300,
                       zorder=3,
                       label='Data')
            plt.xticks([-5, -2.5, 0, 2.5, 5])
            plt.savefig(fname='img/mixed_sine_temp.svg', format='svg')
            return 0
        else:
            from scipy.special import erf

            quantiles = np.arange(start=0., stop=1.1, step=0.1)
            cal_data = []

            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            for _ in range(num_val_tasks):
                binary_flag = np.random.binomial(n=1, p=p_sine)
                if (binary_flag == 0):
                    # generate sinusoidal data
                    x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                        noise_flag=True)
                    y0 = amp * torch.sin(x0 + phase)
                else:
                    # generate line data
                    x_t, y_t, slope, intercept = data_generator.generate_line_data(
                        noise_flag=True)
                    y0 = slope * x0 + intercept
                y0 = y0.view(1, -1).cpu().numpy()  # row vector

                y_preds = torch.stack(
                    get_task_prediction(x_t=x_t, y_t=y_t,
                                        x_v=x0))  # K x len(x0)

                y_preds_np = torch.squeeze(y_preds,
                                           dim=-1).detach().cpu().numpy()

                y_preds_quantile = np.quantile(a=y_preds_np,
                                               q=quantiles,
                                               axis=0,
                                               keepdims=False)

                # ground truth cdf
                std = data_generator.noise_std
                cal_temp = (1 + erf(
                    (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
                cal_temp_avg = np.mean(a=cal_temp,
                                       axis=1)  # average for a task
                cal_data.append(cal_temp_avg)
            return cal_data
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (
            num_total_samples_per_class -
            num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = itertools.combinations(all_class_names,
                                                r=num_classes_per_task)

        if train_flag:
            all_task_names = list(all_task_names)
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data, embedding_data, class_labels,
                num_total_samples_per_class, num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred_v = torch.stack(y_pred_v)
            y_pred_v = sm_loss(y_pred_v)
            y_pred = torch.mean(input=y_pred_v, dim=0, keepdim=False)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct,
                                 dim=0).item() / total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if (task_count >= num_val_tasks):
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
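For the classification branch above, calling meta_validation with return_uncertainty=True yields per-sample correctness flags and predicted confidences. One common way to summarize these two lists is the expected calibration error (ECE); a hedged sketch, not part of the original code:

import numpy as np

def expected_calibration_error(corrects, confidences, num_bins=10):
    """Bin predictions by confidence and compare per-bin accuracy with mean confidence."""
    corrects = np.asarray(corrects, dtype=np.float64)
    confidences = np.asarray(confidences, dtype=np.float64)
    bin_edges = np.linspace(0., 1., num_bins + 1)
    ece = 0.
    for low, high in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (confidences > low) & (confidences <= high)
        if mask.any():
            ece += mask.mean() * abs(corrects[mask].mean() - confidences[mask].mean())
    return ece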
Example #4
def meta_train(train_subset=train_set):
    #region PREPARING DATALOADER
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)
        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        embedding = embedding_train
        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)
        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    #endregion
    print('Start to train...')
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store per-epoch information for monitoring purposes
        meta_loss_saved = []  # meta losses to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0  # accumulated loss of the ensemble networks, used for the meta-update gradient step
        num_meta_updates_count = 0

        meta_loss_avg_print = 0  # running average of the loss, for printing

        meta_loss_avg_save = []  # meta losses to save

        task_count = 0  # counts tasks to decide when a minibatch is complete and a meta-update is due

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)

                loss_NLL = get_task_prediction(x_t, y_t, x_v, y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss / num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print /
                                                  num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(task_count,
                                                       meta_loss_avg_save[-1]))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))

                        meta_loss_avg_save = []

                        # print('Saving loss...')
                        # val_accs, _ = meta_validation(
                        #     datasubset=val_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # val_acc = np.mean(val_accs)
                        # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        # val_accuracies.append(val_acc)

                        # train_accs, _ = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks,
                        #     return_uncertainty=False)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
        print()
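The "dummy sampler" pattern in Example #4 builds a DataLoader whose dataset is just a list (placeholder zeros for sine_line, class names for image data), so that each batch yields the class labels of one N-way task. A standalone sketch of the same pattern, with made-up class names for illustration:

import torch

class_names = ['cat', 'dog', 'bird', 'fish', 'frog', 'horse']
num_classes_per_task = 3

sampler = torch.utils.data.sampler.RandomSampler(data_source=class_names)
train_loader = torch.utils.data.DataLoader(dataset=class_names,
                                           batch_size=num_classes_per_task,
                                           sampler=sampler,
                                           drop_last=True)

# each batch is the set of class labels for one 3-way task
for class_labels in train_loader:
    print(class_labels)  # e.g. ['dog', 'frog', 'cat']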
Example #5
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(
            num_samples=num_total_samples_per_class,
            device=device
        )

        # create dummy sampler
        all_class = [0]*100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(all_class.keys()), replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True
        )
    
    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store per-epoch information for monitoring purposes
        meta_loss_saved = [] # meta losses to save
        kl_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0 # counts tasks to decide when a minibatch is complete and a meta-update is due
        meta_loss = 0 # accumulated loss of the ensemble networks, used for the meta-update gradient step
        num_meta_updates_count = 0

        meta_loss_avg_print = 0 # running average of the loss, for printing

        kl_loss = 0
        kl_loss_avg_print = 0

        meta_loss_avg_save = [] # meta losses to save
        kl_loss_avg_save = []

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class,
                        embedding,
                        class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class,
                        device
                    )
                loss_NLL, KL_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_base_dropout
                )

                meta_loss = meta_loss + loss_NLL

                kl_loss = kl_loss + KL_loss
                task_count = task_count + 1

                if torch.isnan(meta_loss).item():
                    sys.exit('nan')

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss/num_tasks_per_minibatch
                    kl_loss = kl_loss/num_tasks_per_minibatch

                    # accumulate for printing purposes
                    meta_loss_avg_print += meta_loss.item()
                    kl_loss_avg_print += kl_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    # torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=1)
                    op_theta.step()
                    
                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print/num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1],
                            kl_loss_avg_save[-1]
                        ))
                        num_meta_updates_count = 0

                        meta_loss_avg_print = 0
                        kl_loss_avg_print = 0
                        
                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))

                        meta_loss_avg_save = []
                        kl_loss_avg_save = []
                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        #     train_accs = meta_validation(
                        #         datasubset=train_set,
                        #         num_val_tasks=num_val_tasks)
                        #     train_acc = np.mean(train_accs)
                        #     train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        #     print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        #     train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1)% num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'kl_loss': kl_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
            # scheduler.step()
        print()
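In Example #5, get_task_prediction additionally returns a KL term, but its computation is not shown. For a mean-field Gaussian posterior regularized towards a Gaussian prior, this is typically the closed-form KL divergence between diagonal Gaussians; a sketch under that assumption (not the actual get_task_prediction code):

import torch

def kl_diag_gaussians(mean_q, log_std_q, mean_p, log_std_p):
    """KL( N(mean_q, std_q^2) || N(mean_p, std_p^2) ) for diagonal Gaussians, summed over dimensions."""
    var_q = (2 * log_std_q).exp()
    var_p = (2 * log_std_p).exp()
    kl = log_std_p - log_std_q + (var_q + (mean_q - mean_p) ** 2) / (2 * var_p) - 0.5
    return kl.sum()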
Example #6
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_total_samples_per_class,
                                       device=device)

        # create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)
    else:
        all_class = all_class_train
        # all_class.update(all_class_val)
        embedding = embedding_train
        # embedding.update(embedding_val)

        sampler = torch.utils.data.sampler.RandomSampler(data_source=list(
            all_class.keys()),
                                                         replacement=False)

        train_loader = torch.utils.data.DataLoader(
            dataset=list(all_class.keys()),
            batch_size=num_classes_per_task,
            sampler=sampler,
            drop_last=True)

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store per-epoch information for monitoring purposes
        loss_NLL_saved = []
        kl_loss_saved = []
        d_loss_saved = []
        val_accuracies = []
        train_accuracies = []

        task_count = 0  # counts tasks to decide when a minibatch is complete and a meta-update is due
        meta_loss = 0  # accumulated loss of the ensemble networks, used for the meta-update gradient step
        num_meta_updates_count = 0

        loss_NLL_v = 0
        loss_NLL_avg_print = 0

        kl_loss = 0
        kl_loss_avg_print = 0

        d_loss = 0
        d_loss_avg_print = 0

        loss_NLL_avg_save = []
        kl_loss_avg_save = []
        d_loss_avg_save = []

        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True)
                else:
                    x_t, y_t, x_v, y_v = get_task_image_data(
                        all_class, embedding, class_labels,
                        num_total_samples_per_class,
                        num_training_samples_per_class, device)
                loss_NLL, KL_loss, discriminator_loss = get_task_prediction(
                    x_t=x_t,
                    y_t=y_t,
                    x_v=x_v,
                    y_v=y_v,
                    p_dropout=p_dropout_base,
                    p_dropout_g=p_dropout_generator,
                    p_dropout_d=p_dropout_discriminator,
                    p_dropout_e=p_dropout_encoder)

                if torch.isnan(loss_NLL).item():
                    sys.exit('nan')

                loss_NLL_v += loss_NLL.item()

                if (loss_NLL.item() > 1):
                    loss_NLL.data = torch.tensor([1.], device=device)

                # if (discriminator_loss.item() > d_loss_const):
                #     discriminator_loss.data = torch.tensor([d_loss_const], device=device)

                kl_loss = kl_loss + KL_loss

                if (KL_loss.item() < 0):
                    KL_loss.data = torch.tensor([0.], device=device)

                Ri = torch.sqrt((KL_loss + ri_const) /
                                (2 * (total_validation_samples - 1)))
                meta_loss = meta_loss + loss_NLL + Ri

                d_loss = d_loss + discriminator_loss

                task_count = task_count + 1

                if (task_count % num_tasks_per_minibatch == 0):
                    # average over the number of tasks per minibatch
                    meta_loss = meta_loss / num_tasks_per_minibatch
                    loss_NLL_v /= num_tasks_per_minibatch
                    kl_loss = kl_loss / num_tasks_per_minibatch
                    d_loss = d_loss / num_tasks_per_minibatch

                    # accumulate for printing purposes
                    loss_NLL_avg_print += loss_NLL_v
                    kl_loss_avg_print += kl_loss.item()
                    d_loss_avg_print += d_loss.item()

                    # adding R0
                    R0 = 0
                    for key in theta.keys():
                        R0 += theta[key].norm(2)
                    R0 = torch.sqrt((L2_regularization * R0 + r0_const) /
                                    (2 * (num_tasks_per_minibatch - 1)))
                    meta_loss += R0

                    # optimize theta
                    op_theta.zero_grad()
                    meta_loss.backward(retain_graph=True)
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(),
                                                   max_norm=clip_grad_value)
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_encoder.values(),
                        max_norm=clip_grad_value)
                    op_theta.step()

                    # optimize the discriminator
                    op_discriminator.zero_grad()
                    d_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        parameters=w_discriminator.values(),
                        max_norm=clip_grad_value)
                    op_discriminator.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        loss_NLL_avg_save.append(loss_NLL_avg_print /
                                                 num_meta_updates_count)
                        kl_loss_avg_save.append(kl_loss_avg_print /
                                                num_meta_updates_count)
                        d_loss_avg_save.append(d_loss_avg_print /
                                               num_meta_updates_count)
                        print('{0:d}, {1:2.4f}, {2:1.4f}, {3:2.4e}'.format(
                            task_count, loss_NLL_avg_save[-1],
                            kl_loss_avg_save[-1], d_loss_avg_save[-1]))
                        num_meta_updates_count = 0

                        loss_NLL_avg_print = 0
                        kl_loss_avg_print = 0
                        d_loss_avg_print = 0

                    if (task_count % num_tasks_save_loss == 0):
                        loss_NLL_saved.append(np.mean(loss_NLL_avg_save))
                        kl_loss_saved.append(np.mean(kl_loss_avg_save))
                        d_loss_saved.append(np.mean(d_loss_avg_save))

                        loss_NLL_avg_save = []
                        kl_loss_avg_save = []
                        d_loss_avg_save = []

                        # if datasource != 'sine_line':
                        #     val_accs = meta_validation(
                        #         datasubset=val_set,
                        #         num_val_tasks=num_val_tasks)
                        #     val_acc = np.mean(val_accs)
                        #     val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                        #     print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                        #     val_accuracies.append(val_acc)

                        # train_accs = meta_validation(
                        #     datasubset=train_set,
                        #     num_val_tasks=num_val_tasks)
                        # train_acc = np.mean(train_accs)
                        # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                        # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                        # train_accuracies.append(train_acc)

                    # reset meta loss for the next minibatch of tasks
                    meta_loss = 0
                    kl_loss = 0
                    d_loss = 0
                    loss_NLL_v = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1) % num_epochs_save == 0):
            checkpoint = {
                'w_discriminator': w_discriminator,
                'theta': theta,
                'w_encoder': w_encoder,
                'w_encoder_2': w_encoder_2,
                'meta_loss': loss_NLL_saved,
                'kl_loss': kl_loss_saved,
                'd_loss': d_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict(),
                'op_discriminator': op_discriminator.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = ('{0:s}_{1:d}way_{2:d}shot_{3:d}.pt')\
                        .format(datasource,
                                num_classes_per_task,
                                num_training_samples_per_class,
                                epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder,
                                                checkpoint_filename))
            # scheduler.step()
        print()
Example #7
def meta_train():
    if datasource == 'sine_line':
        data_generator = DataGenerator(num_samples=num_samples_per_class)
        # create dummy sampler
        all_class_train = [0] * 10
    else:
        all_class_train, all_data_train = load_dataset(
            dataset_name=datasource,
            subset=train_set
        )
        all_class_val, all_data_val = load_dataset(
            dataset_name=datasource,
            subset=val_set
        )
        all_class_train.update(all_class_val)
        all_data_train.update(all_data_val)
        
    # initialize data loader
    train_loader = initialize_dataloader(
        all_classes=[class_label for class_label in all_class_train],
        num_classes_per_task=num_classes_per_task
    )

    for epoch in range(resume_epoch, resume_epoch + num_epochs):
        # variables used to store per-epoch information for monitoring purposes
        meta_loss_saved = [] # meta losses to save
        val_accuracies = []
        train_accuracies = []

        meta_loss = 0 # accumulated loss of the ensemble networks, used for the meta-update gradient step
        num_meta_updates_count = 0

        meta_loss_avg_print = 0 # running average of the loss, for printing

        meta_loss_avg_save = [] # meta losses to save

        task_count = 0 # counts tasks to decide when a minibatch is complete and a meta-update is due
        while (task_count < num_tasks_per_epoch):
            for class_labels in train_loader:
                if datasource == 'sine_line':
                    x_t, y_t, x_v, y_v = get_task_sine_line_data(
                        data_generator=data_generator,
                        p_sine=p_sine,
                        num_training_samples=num_training_samples_per_class,
                        noise_flag=True
                    )
                    x_t = torch.tensor(x_t, dtype=torch.float, device=device)
                    y_t = torch.tensor(y_t, dtype=torch.float, device=device)
                    x_v = torch.tensor(x_v, dtype=torch.float, device=device)
                    y_v = torch.tensor(y_v, dtype=torch.float, device=device)
                else:
                    x_t, y_t, x_v, y_v = get_train_val_task_data(
                        all_classes=all_class_train,
                        all_data=all_data_train,
                        class_labels=class_labels,
                        num_samples_per_class=num_samples_per_class,
                        num_training_samples_per_class=num_training_samples_per_class,
                        device=device
                    )

                w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
                y_pred = net.forward(x=x_v, w=w_task)
                
                loss_NLL = loss_fn(input=y_pred, target=y_v)

                if torch.isnan(loss_NLL).item():
                    sys.exit('NaN error')

                # accumulate meta loss
                meta_loss = meta_loss + loss_NLL

                task_count = task_count + 1

                if task_count % num_tasks_per_minibatch == 0:
                    meta_loss = meta_loss/num_tasks_per_minibatch

                    # accumulate into different variables for printing purposes
                    meta_loss_avg_print += meta_loss.item()

                    op_theta.zero_grad()
                    meta_loss.backward()
                    torch.nn.utils.clip_grad_norm_(parameters=theta.values(), max_norm=10)
                    op_theta.step()

                    # Printing losses
                    num_meta_updates_count += 1
                    if (num_meta_updates_count % num_meta_updates_print == 0):
                        meta_loss_avg_save.append(meta_loss_avg_print/num_meta_updates_count)
                        print('{0:d}, {1:2.4f}'.format(
                            task_count,
                            meta_loss_avg_save[-1]
                        ))

                        num_meta_updates_count = 0
                        meta_loss_avg_print = 0
                    
                    if (task_count % num_tasks_save_loss == 0):
                        meta_loss_saved.append(np.mean(meta_loss_avg_save))

                        meta_loss_avg_save = []

                        print('Saving loss...')
                        if datasource != 'sine_line':
                            val_accs = validate_classification(
                                all_classes=all_class_val,
                                all_data=all_data_val,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            val_acc = np.mean(val_accs)
                            val_ci95 = 1.96*np.std(val_accs)/np.sqrt(len(val_accs))
                            print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            val_accuracies.append(val_acc)

                            train_accs = validate_classification(
                                all_classes=all_class_train,
                                all_data=all_data_train,
                                num_val_tasks=100,
                                p_rand=0.5,
                                uncertainty=False,
                                csv_flag=False
                            )
                            train_acc = np.mean(train_accs)
                            train_ci95 = 1.96*np.std(train_accs)/np.sqrt(len(train_accs))
                            print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            train_accuracies.append(train_acc)
                    
                    # reset meta loss
                    meta_loss = 0

                if (task_count >= num_tasks_per_epoch):
                    break
        if ((epoch + 1)% num_epochs_save == 0):
            checkpoint = {
                'theta': theta,
                'meta_loss': meta_loss_saved,
                'val_accuracy': val_accuracies,
                'train_accuracy': train_accuracies,
                'op_theta': op_theta.state_dict()
            }
            print('SAVING WEIGHTS...')
            checkpoint_filename = 'Epoch_{0:d}.pt'.format(epoch + 1)
            print(checkpoint_filename)
            torch.save(checkpoint, os.path.join(dst_folder, checkpoint_filename))
        scheduler.step()
        print()
Example #8
def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'

    if uncertainty_flag:
        from scipy.special import erf
        quantiles = np.arange(start=0., stop=1.1, step=0.1)
        filename = 'VAMPIRE_calibration_{0:s}_{1:d}shot_{2:d}.csv'.format(
            datasource, num_training_samples_per_class, resume_epoch)
        outfile = open(file=os.path.join('csv', filename), mode='w')
        wr = csv.writer(outfile, quoting=csv.QUOTE_NONE)
    else:  # visualization
        from matplotlib import pyplot as plt
        num_stds_plot = 2

    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

    for _ in range(num_val_tasks):
        # flip a coin: 0 - 'sine' or 1 - 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if (binary_flag == 0):
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(
                noise_flag=True)
            y0 = slope * x0 + intercept

        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.cpu().numpy().reshape(1, -1)  # row vector

        q = adapt_to_task(x=x_t, y=y_t, theta0=theta)
        y_pred = predict(x=x0, q=q, num_models=Lv)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy()  # convert to numpy array, Lv x len(x0)

        if uncertainty_flag:
            # each column of y_pred holds the predictive samples for one x0 value,
            # so the quantiles are computed along axis 0
            y_preds_quantile = np.quantile(a=y_pred,
                                           q=quantiles,
                                           axis=0,
                                           keepdims=False)

            # ground truth cdf
            cal_temp = (1 + erf(
                (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
            wr.writerow(cal_temp_avg)
        else:
            y_mean = np.mean(a=y_pred, axis=0)
            y_std = np.std(a=y_pred, axis=0)
            y_top = y_mean + num_stds_plot * y_std
            y_bottom = y_mean - num_stds_plot * y_std

            plt.figure(figsize=(4, 4))

            plt.scatter(x_t.cpu().numpy(),
                        y_t.cpu().numpy(),
                        marker='^',
                        label='Training data')
            # y_bottom and y_top are already numpy arrays, so no .cpu()/.detach() is needed
            plt.fill_between(x=torch.squeeze(x0).cpu().numpy(),
                             y1=y_bottom,
                             y2=y_top,
                             alpha=0.25,
                             zorder=0,
                             label='Prediction')
            plt.plot(torch.squeeze(x0).cpu().numpy(),
                     y0.reshape(-1),
                     linewidth=1,
                     linestyle='--',
                     label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()
    if uncertainty_flag:
        outfile.close()
        print('Reliability data is stored at {0:s}'.format(
            os.path.join('csv', filename)))
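In Example #8, predict(x=x0, q=q, num_models=Lv) returns one prediction per sampled model, i.e. an Lv x len(x0) array. A hedged sketch of such Monte-Carlo prediction, assuming q stores a (mean, log_std) pair per parameter name and net_forward(x, w) is a functional forward pass (hypothetical names, not the original helpers):

import torch

def predict_sketch(x, q, net_forward, num_models=10):
    """Monte-Carlo prediction: draw weight samples from q and stack the forward passes."""
    preds = []
    for _ in range(num_models):
        # reparameterized sample from a diagonal Gaussian posterior
        w = {name: mean + torch.randn_like(mean) * log_std.exp()
             for name, (mean, log_std) in q.items()}
        preds.append(net_forward(x, w))
    return torch.stack(preds)  # shape: (num_models, len(x), ...)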
Example #9
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        from scipy.special import erf
        
        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

        cal_avg = 0

        data_generator = DataGenerator(num_samples=num_training_samples_per_class, device=device)
        for _ in range(num_val_tasks):
            binary_flag = np.random.binomial(n=1, p=p_sine)
            if (binary_flag == 0):
                # generate sinusoidal data
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp*torch.sin(x0 + phase)
            else:
                # generate line data
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope*x0 + intercept
            y0 = y0.view(1, -1).cpu().numpy()
            
            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)

            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf((y_preds_np - y0)/(np.sqrt(2)*std)))/2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        cal_avg = cal_avg / num_val_tasks
        return cal_avg
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class - num_training_samples_per_class)*num_classes_per_task
        
        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')
        
        all_class_names = list(all_class_data.keys())
        all_task_names = list(itertools.combinations(all_class_names, r=num_classes_per_task))

        if train_flag:
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)
            
            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred = sm_loss(y_pred_v)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item()/total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
Example #10
    def meta_train(self, train_subset='train'):
        data_generator = DataGenerator(
            num_samples=self.num_total_samples_per_class, device=self.device)

        #create dummy sampler
        all_class = [0] * 100
        sampler = torch.utils.data.sampler.RandomSampler(data_source=all_class)
        train_loader = torch.utils.data.DataLoader(
            dataset=all_class,
            batch_size=self.num_classes_per_task,
            sampler=sampler,
            drop_last=True)

        print('Start to train...')
        for epoch in range(0, self.num_epochs):
            #variables for monitoring
            meta_loss_saved = []
            val_accuracies = []
            train_accuracies = []

            meta_loss = 0  #accumulate the loss of many ensembling networks
            num_meta_updates_count = 0

            meta_loss_avg_print = 0
            #meta_mse_avg_print = 0

            meta_loss_avg_save = []
            #meta_mse_avg_save = []

            task_count = 0

            while (task_count < self.num_tasks_per_epoch):
                for class_labels in train_loader:
                    # class_labels from the dummy loader are not used for the sine task
                    x_t, y_t, x_v, y_v = get_task_sine_data(
                        data_generator=data_generator,
                        p_sine=self.p_sine,
                        num_training_samples=self.num_training_samples_per_class,
                        noise_flag=True)

                    chaser, leader, y_pred = self.get_task_prediction(
                        x_t, y_t, x_v, y_v)
                    loss_NLL = self.get_meta_loss(chaser, leader)

                    if torch.isnan(loss_NLL).item():
                        sys.exit('NaN error')

                    meta_loss = meta_loss + loss_NLL
                    #meta_mse = self.loss(y_pred, y_v)

                    task_count = task_count + 1

                    if task_count % self.num_tasks_per_minibatch == 0:
                        meta_loss = meta_loss / self.num_tasks_per_minibatch
                        #meta_mse = meta_mse/self.num_tasks_per_minibatch

                        # accumulate into different variables for printing purpose
                        meta_loss_avg_print += meta_loss.item()
                        #meta_mse_avg_print += meta_mse.item()

                        self.op_theta.zero_grad()
                        meta_loss.backward()
                        self.op_theta.step()

                        # Printing losses
                        num_meta_updates_count += 1
                        if (num_meta_updates_count %
                                self.num_meta_updates_print == 0):
                            meta_loss_avg_save.append(meta_loss_avg_print /
                                                      num_meta_updates_count)
                            #meta_mse_avg_save.append(meta_mse_avg_print/num_meta_updates_count)
                            print('{0:d}, {1:2.4f}'.format(
                                task_count, meta_loss_avg_save[-1]
                                #meta_mse_avg_save[-1]
                            ))

                            num_meta_updates_count = 0
                            meta_loss_avg_print = 0
                            #meta_mse_avg_print = 0

                        if (task_count % self.num_tasks_save_loss == 0):
                            meta_loss_saved.append(np.mean(meta_loss_avg_save))

                            meta_loss_avg_save = []
                            #meta_mse_avg_save = []

                            # print('Saving loss...')
                            # val_accs, _ = meta_validation(
                            #     datasubset=val_set,
                            #     num_val_tasks=num_val_tasks,
                            #     return_uncertainty=False)
                            # val_acc = np.mean(val_accs)
                            # val_ci95 = 1.96*np.std(val_accs)/np.sqrt(num_val_tasks)
                            # print('Validation accuracy = {0:2.4f} +/- {1:2.4f}'.format(val_acc, val_ci95))
                            # val_accuracies.append(val_acc)

                            # train_accs, _ = meta_validation(
                            #     datasubset=train_set,
                            #     num_val_tasks=num_val_tasks,
                            #     return_uncertainty=False)
                            # train_acc = np.mean(train_accs)
                            # train_ci95 = 1.96*np.std(train_accs)/np.sqrt(num_val_tasks)
                            # print('Train accuracy = {0:2.4f} +/- {1:2.4f}\n'.format(train_acc, train_ci95))
                            # train_accuracies.append(train_acc)

                        # reset meta loss
                        meta_loss = 0

                    if (task_count >= self.num_tasks_per_epoch):
                        break
            if ((epoch + 1) % self.num_epochs_save == 0):
                checkpoint = {
                    'theta': self.theta,
                    'meta_loss': meta_loss_saved,
                    'val_accuracy': val_accuracies,
                    'train_accuracy': train_accuracies,
                    'op_theta': self.op_theta.state_dict()
                }
                print('SAVING WEIGHTS...')
                checkpoint_filename = '{0:s}_{1:d}way_{2:d}shot_{3:d}.pt'.format(
                    'sine_line',
                    self.num_classes_per_task,
                    self.num_training_samples_per_class,
                    epoch + 1)
                print(checkpoint_filename)
                torch.save(checkpoint,
                           os.path.join(self.dst_folder, checkpoint_filename))
                print(checkpoint['meta_loss'])
            print()
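In Example #10, get_meta_loss(chaser, leader) compares the task-adapted "chaser" parameters against the further-adapted "leader" parameters, and the meta-update pushes the chaser towards the leader. A hedged sketch of a squared-distance loss over parameter dictionaries (an assumption; the actual helper is not shown in the snippet):

import torch

def get_meta_loss_sketch(chaser, leader):
    """Sum of squared distances between chaser and (detached) leader parameters."""
    loss = 0.
    for key in chaser:
        loss = loss + torch.sum((chaser[key] - leader[key].detach()) ** 2)
    return loss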