Code example #1
def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'
    
    if uncertainty_flag:
        from scipy.special import erf
        cal_avg = 0
    else:
        from matplotlib import pyplot as plt


    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

    for _ in range(num_val_tasks):
        # flip a coin: 0 -> 'sine', 1 -> 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if (binary_flag == 0):
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
            y0 = slope * x0 + intercept
        
        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.view(1, -1).cpu().numpy()  # row vector

        w_task = adapt_to_task(x=x_t, y=y_t, w0=theta)
        y_pred = predict_label_score(x=x0, w=w_task)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy() # convert to numpy array

        if uncertainty_flag:
            cal_temp = (1 + erf((y_pred - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        else:
            plt.figure(figsize=(4, 4))
            plt.subplot(111)

            plt.scatter(x_t.cpu().numpy(), y_t.cpu().numpy(), marker='^', label='Training data')
            plt.plot(x0.cpu().numpy(), y_pred, linewidth=1, linestyle='-', label='Prediction')
            plt.plot(x0.cpu().numpy(), np.squeeze(y0), linewidth=1, linestyle='--', label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()
    
    if uncertainty_flag:
        print('Average calibration \'score\' = {0}'.format(cal_avg / num_val_tasks))
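
The core quantity in this snippet is the Gaussian CDF of the prediction error, (1 + erf((y_pred - y0) / (sqrt(2) * std))) / 2: if the predictions are consistent with the data-generating noise, this value is roughly uniform over the query points, so its per-task average sits near 0.5. Below is a minimal, self-contained sketch of that computation on synthetic arrays; the function name calibration_score and the toy data are illustrative only, not part of the snippet above.

import numpy as np
from scipy.special import erf

def calibration_score(y_pred, y_true, std):
    # Gaussian CDF of the prediction error under noise level `std`,
    # averaged over query points (one value per row of y_pred)
    cal = (1 + erf((y_pred - y_true) / (np.sqrt(2) * std))) / 2
    return np.mean(cal, axis=1)

rng = np.random.default_rng(0)
y_true = np.sin(np.linspace(-5, 5, 100)).reshape(1, -1)
y_pred = y_true + rng.normal(scale=0.1, size=y_true.shape)
print(calibration_score(y_pred, y_true, std=0.1))  # roughly 0.5 for an unbiased model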
Code example #2
File: bmaml_versa.py  Project: lazyotter/ARC
    def meta_validation(self,
                        datasubset,
                        num_val_tasks,
                        return_uncertainty=False):
        x0 = torch.linspace(start=-5, end=5, steps=100,
                            device=self.device).view(-1, 1)  # vector

        quantiles = np.arange(start=0., stop=1.1, step=0.1)
        cal_data = []

        data_generator = DataGenerator(
            num_samples=self.num_training_samples_per_class,
            device=self.device)
        for _ in range(num_val_tasks):

            # generate sinusoidal data
            x_t, y_t, amp, phase, slope = data_generator.generate_sinusoidal_data(
                noise_flag=True)
            y0 = amp * torch.sin(slope * x0 + phase)
            y0 = y0.view(1, -1).cpu().numpy()  # row vector

            y_preds = torch.stack(
                self.get_task_prediction(x_t=x_t, y_t=y_t,
                                         x_v=x0))  # K x len(x0)

            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            y_preds_quantile = np.quantile(a=y_preds_np,
                                           q=quantiles,
                                           axis=0,
                                           keepdims=False)

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf(
                (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
            cal_data.append(cal_temp_avg)
        return cal_data
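
meta_validation returns one length-11 array per task: the ground-truth CDF evaluated at the predicted quantiles 0.0, 0.1, ..., 1.0. A hypothetical way to summarize that output (not part of bmaml_versa.py) is to average the arrays over tasks and compare the result against the nominal quantile levels, which yields a reliability curve:

import numpy as np

quantiles = np.arange(start=0., stop=1.1, step=0.1)

# stand-in for the list returned by meta_validation(); each entry has shape (11,)
rng = np.random.default_rng(0)
cal_data = [np.clip(quantiles + rng.normal(scale=0.02, size=quantiles.shape), 0., 1.)
            for _ in range(10)]

observed = np.mean(np.stack(cal_data), axis=0)  # empirical coverage per quantile level
for q, obs in zip(quantiles, observed):
    print('expected {0:.1f} -> observed {1:.3f}'.format(q, obs))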
Code example #3
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        x0 = torch.linspace(start=-5, end=5, steps=100,
                            device=device).view(-1, 1)  # vector

        if num_val_tasks == 0:
            from matplotlib import pyplot as plt
            import matplotlib
            matplotlib.rcParams['xtick.labelsize'] = 16
            matplotlib.rcParams['ytick.labelsize'] = 16
            matplotlib.rcParams['axes.labelsize'] = 18

            num_stds = 2
            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            if datasubset == 'sine':
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                    noise_flag=True)
                y0 = amp * torch.sin(x0 + phase)
            else:
                x_t, y_t, slope, intercept = data_generator.generate_line_data(
                    noise_flag=True)
                y0 = slope * x0 + intercept

            # get_task_prediction() returns a list of sampled predictions (see the
            # quantile branch below); stack them so torch.mean/torch.std can reduce over models
            y_preds = torch.stack(
                get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0))  # K x len(x0)
            '''LOAD MAML DATA'''
            maml_folder = '{0:s}/MAML_mixed_sine_line'.format(dst_folder_root)
            maml_filename = 'MAML_mixed_{0:d}shot_{1:s}.pt'.format(
                num_training_samples_per_class, '{0:d}')

            i = 1
            maml_checkpoint_filename = os.path.join(maml_folder,
                                                    maml_filename.format(i))
            while (os.path.exists(maml_checkpoint_filename)):
                i = i + 1
                maml_checkpoint_filename = os.path.join(
                    maml_folder, maml_filename.format(i))
            print(maml_checkpoint_filename)
            maml_checkpoint = torch.load(
                os.path.join(maml_folder, maml_filename.format(i - 1)),
                map_location=lambda storage, loc: storage.cuda(gpu_id))
            theta_maml = maml_checkpoint['theta']
            y_pred_maml = get_task_prediction_maml(x_t=x_t,
                                                   y_t=y_t,
                                                   x_v=x0,
                                                   meta_params=theta_maml)
            '''PLOT'''
            _, ax = plt.subplots(figsize=(5, 5))
            y_top = torch.squeeze(
                torch.mean(y_preds, dim=0) +
                num_stds * torch.std(y_preds, dim=0))
            y_bottom = torch.squeeze(
                torch.mean(y_preds, dim=0) -
                num_stds * torch.std(y_preds, dim=0))

            ax.fill_between(x=torch.squeeze(x0).cpu().numpy(),
                            y1=y_bottom.cpu().detach().numpy(),
                            y2=y_top.cpu().detach().numpy(),
                            alpha=0.25,
                            color='C3',
                            zorder=0,
                            label='VAMPIRE')
            ax.plot(x0.cpu().numpy(),
                    y0.cpu().numpy(),
                    color='C7',
                    linestyle='-',
                    linewidth=3,
                    zorder=1,
                    label='Ground truth')
            ax.plot(x0.cpu().numpy(),
                    y_pred_maml.cpu().detach().numpy(),
                    color='C2',
                    linestyle='--',
                    linewidth=3,
                    zorder=2,
                    label='MAML')
            ax.scatter(x=x_t.cpu().numpy(),
                       y=y_t.cpu().numpy(),
                       color='C0',
                       marker='^',
                       s=300,
                       zorder=3,
                       label='Data')
            plt.xticks([-5, -2.5, 0, 2.5, 5])
            plt.savefig(fname='img/mixed_sine_temp.svg', format='svg')
            return 0
        else:
            from scipy.special import erf

            quantiles = np.arange(start=0., stop=1.1, step=0.1)
            cal_data = []

            data_generator = DataGenerator(
                num_samples=num_training_samples_per_class, device=device)
            for _ in range(num_val_tasks):
                binary_flag = np.random.binomial(n=1, p=p_sine)
                if (binary_flag == 0):
                    # generate sinusoidal data
                    x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                        noise_flag=True)
                    y0 = amp * torch.sin(x0 + phase)
                else:
                    # generate line data
                    x_t, y_t, slope, intercept = data_generator.generate_line_data(
                        noise_flag=True)
                    y0 = slope * x0 + intercept
                y0 = y0.view(1, -1).cpu().numpy()  # row vector

                y_preds = torch.stack(
                    get_task_prediction(x_t=x_t, y_t=y_t,
                                        x_v=x0))  # K x len(x0)

                y_preds_np = torch.squeeze(y_preds,
                                           dim=-1).detach().cpu().numpy()

                y_preds_quantile = np.quantile(a=y_preds_np,
                                               q=quantiles,
                                               axis=0,
                                               keepdims=False)

                # ground truth cdf
                std = data_generator.noise_std
                cal_temp = (1 + erf(
                    (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
                cal_temp_avg = np.mean(a=cal_temp,
                                       axis=1)  # average for a task
                cal_data.append(cal_temp_avg)
            return cal_data
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (
            num_total_samples_per_class -
            num_training_samples_per_class) * num_classes_per_task

        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')

        all_class_names = list(all_class_data.keys())
        all_task_names = itertools.combinations(all_class_names,
                                                r=num_classes_per_task)

        if train_flag:
            all_task_names = list(all_task_names)
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data, embedding_data, class_labels,
                num_total_samples_per_class, num_training_samples_per_class,
                device)

            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred_v = torch.stack(y_pred_v)
            y_pred_v = sm_loss(y_pred_v)
            y_pred = torch.mean(input=y_pred_v, dim=0, keepdim=False)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct,
                                 dim=0).item() / total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if (task_count >= num_val_tasks):
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
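
For the classification branch, the prediction is an ensemble average: the per-model outputs are stacked, passed through a softmax (sm_loss is assumed to be a softmax module here), averaged over models, and the arg-max class is compared against the validation labels. A standalone toy version of that step, with synthetic logits and labels:

import torch

torch.manual_seed(0)
K, num_samples, num_classes = 5, 8, 3               # ensemble size, query points, classes
logits = torch.randn(K, num_samples, num_classes)   # K sampled models

probs = torch.softmax(logits, dim=-1).mean(dim=0)   # averaged predictive distribution
prob_pred, labels_pred = torch.max(probs, dim=1)    # confidence and predicted class
y_v = torch.randint(low=0, high=num_classes, size=(num_samples,))
accuracy = (labels_pred == y_v).float().mean().item()
print(accuracy, prob_pred)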
Code example #4
def validate_regression(uncertainty_flag, num_val_tasks=1):
    assert datasource == 'sine_line'

    if uncertainty_flag:
        from scipy.special import erf
        quantiles = np.arange(start=0., stop=1.1, step=0.1)
        filename = 'VAMPIRE_calibration_{0:s}_{1:d}shot_{2:d}.csv'.format(
            datasource, num_training_samples_per_class, resume_epoch)
        outfile = open(file=os.path.join('csv', filename), mode='w')
        wr = csv.writer(outfile, quoting=csv.QUOTE_NONE)
    else:  # visualization
        from matplotlib import pyplot as plt
        num_stds_plot = 2

    data_generator = DataGenerator(num_samples=num_training_samples_per_class)
    std = data_generator.noise_std

    x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

    for _ in range(num_val_tasks):
        # flip a coin: 0 -> 'sine', 1 -> 'line'
        binary_flag = np.random.binomial(n=1, p=p_sine)
        if (binary_flag == 0):
            # generate sinusoidal data
            x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(
                noise_flag=True)
            y0 = amp * torch.sin(x0 + phase)
        else:
            # generate line data
            x_t, y_t, slope, intercept = data_generator.generate_line_data(
                noise_flag=True)
            y0 = slope * x0 + intercept

        x_t = torch.tensor(x_t, dtype=torch.float, device=device)
        y_t = torch.tensor(y_t, dtype=torch.float, device=device)
        y0 = y0.view(1, -1).cpu().numpy()  # row vector

        q = adapt_to_task(x=x_t, y=y_t, theta0=theta)
        y_pred = predict(x=x0, q=q, num_models=Lv)
        y_pred = torch.squeeze(y_pred, dim=-1).detach().cpu().numpy(
        )  # convert to numpy array Lv x len(x0)

        if uncertainty_flag:
            # each column in y_pred represents a distribution for that x0-value at that column
            # hence, we calculate the quantile along axis 0
            y_preds_quantile = np.quantile(a=y_pred,
                                           q=quantiles,
                                           axis=0,
                                           keepdims=False)

            # ground truth cdf
            cal_temp = (1 + erf(
                (y_preds_quantile - y0) / (np.sqrt(2) * std))) / 2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)  # average for a task
            wr.writerow(cal_temp_avg)
        else:
            y_mean = np.mean(a=y_pred, axis=0)
            y_std = np.std(a=y_pred, axis=0)
            y_top = y_mean + num_stds_plot * y_std
            y_bottom = y_mean - num_stds_plot * y_std

            plt.figure(figsize=(4, 4))

            plt.scatter(x_t.cpu().numpy(),
                        y_t.cpu().numpy(),
                        marker='^',
                        label='Training data')
            plt.fill_between(x=torch.squeeze(x0).cpu().numpy(),
                             y1=y_bottom,
                             y2=y_top,
                             alpha=0.25,
                             zorder=0,
                             label='Prediction')
            plt.plot(x0.cpu().numpy(), np.squeeze(y0), linewidth=1, linestyle='--', label='Ground-truth')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.legend()
            plt.tight_layout()
            plt.show()
    if uncertainty_flag:
        outfile.close()
        print('Reliability data is stored at {0:s}'.format(
            os.path.join('csv', filename)))
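
The visualization branch plots a predictive band from the Lv sampled predictions: the mean curve plus or minus num_stds_plot standard deviations. The sketch below reproduces only that plotting logic on synthetic predictions (the arrays are made up; only the mean/std band construction mirrors the code above):

import numpy as np
from matplotlib import pyplot as plt

rng = np.random.default_rng(0)
x0 = np.linspace(-5, 5, 100)
Lv = 16
y_pred = np.sin(x0) + rng.normal(scale=0.2, size=(Lv, x0.size))  # Lv x len(x0)

y_mean = np.mean(y_pred, axis=0)
y_std = np.std(y_pred, axis=0)
plt.fill_between(x0, y_mean - 2 * y_std, y_mean + 2 * y_std, alpha=0.25, label='Prediction')
plt.plot(x0, np.sin(x0), linestyle='--', label='Ground-truth')
plt.legend()
plt.tight_layout()
plt.show()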
Code example #5
def meta_validation(datasubset, num_val_tasks, return_uncertainty=False):
    if datasource == 'sine_line':
        from scipy.special import erf
        
        x0 = torch.linspace(start=-5, end=5, steps=100, device=device).view(-1, 1)

        cal_avg = 0

        data_generator = DataGenerator(num_samples=num_training_samples_per_class, device=device)
        for _ in range(num_val_tasks):
            binary_flag = np.random.binomial(n=1, p=p_sine)
            if (binary_flag == 0):
                # generate sinusoidal data
                x_t, y_t, amp, phase = data_generator.generate_sinusoidal_data(noise_flag=True)
                y0 = amp*torch.sin(x0 + phase)
            else:
                # generate line data
                x_t, y_t, slope, intercept = data_generator.generate_line_data(noise_flag=True)
                y0 = slope*x0 + intercept
            y0 = y0.view(1, -1).cpu().numpy()
            
            y_preds = get_task_prediction(x_t=x_t, y_t=y_t, x_v=x0)

            y_preds_np = torch.squeeze(y_preds, dim=-1).detach().cpu().numpy()

            # ground truth cdf
            std = data_generator.noise_std
            cal_temp = (1 + erf((y_preds_np - y0)/(np.sqrt(2)*std)))/2
            cal_temp_avg = np.mean(a=cal_temp, axis=1)
            cal_avg = cal_avg + cal_temp_avg
        cal_avg = cal_avg / num_val_tasks
        return cal_avg
    else:
        accuracies = []
        corrects = []
        probability_pred = []

        total_validation_samples = (num_total_samples_per_class - num_training_samples_per_class)*num_classes_per_task
        
        if datasubset == 'train':
            all_class_data = all_class_train
            embedding_data = embedding_train
        elif datasubset == 'val':
            all_class_data = all_class_val
            embedding_data = embedding_val
        elif datasubset == 'test':
            all_class_data = all_class_test
            embedding_data = embedding_test
        else:
            sys.exit('Unknown datasubset for validation')
        
        all_class_names = list(all_class_data.keys())
        all_task_names = list(itertools.combinations(all_class_names, r=num_classes_per_task))

        if train_flag:
            random.shuffle(all_task_names)

        task_count = 0
        for class_labels in all_task_names:
            x_t, y_t, x_v, y_v = get_task_image_data(
                all_class_data,
                embedding_data,
                class_labels,
                num_total_samples_per_class,
                num_training_samples_per_class,
                device)
            
            y_pred_v = get_task_prediction(x_t, y_t, x_v, y_v=None)
            y_pred = sm_loss(y_pred_v)

            prob_pred, labels_pred = torch.max(input=y_pred, dim=1)
            correct = (labels_pred == y_v)
            corrects.extend(correct.detach().cpu().numpy())

            accuracy = torch.sum(correct, dim=0).item()/total_validation_samples
            accuracies.append(accuracy)

            probability_pred.extend(prob_pred.detach().cpu().numpy())

            task_count += 1
            if not train_flag:
                print(task_count)
            if task_count >= num_val_tasks:
                break
        if not return_uncertainty:
            return accuracies, all_task_names
        else:
            return corrects, probability_pred
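
When return_uncertainty=True, the classification branch returns per-sample correctness flags and the corresponding predicted confidences. One hypothetical downstream use (not shown in the original code) is an expected-calibration-error style summary: bin the predictions by confidence and compare the empirical accuracy with the mean confidence in each bin.

import numpy as np

def expected_calibration_error(corrects, confidences, num_bins=10):
    # weighted average gap between accuracy and confidence over equal-width bins
    corrects = np.asarray(corrects, dtype=float)
    confidences = np.asarray(confidences, dtype=float)
    edges = np.linspace(0., 1., num_bins + 1)
    ece = 0.
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            ece += mask.mean() * abs(corrects[mask].mean() - confidences[mask].mean())
    return ece

rng = np.random.default_rng(0)
confidences = rng.uniform(0.5, 1.0, size=1000)
corrects = rng.random(1000) < confidences   # synthetic, roughly calibrated predictions
print(expected_calibration_error(corrects, confidences))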