Esempio n. 1
0
def convert_scores_to_numpy(scores):
    """Convert per-example scores into a flat (1-D) numpy array.

    :param scores: one of
        * a non-empty list whose first element is a torch tensor — the list is
          stacked with ``torch.stack`` and flattened,
        * a torch tensor — flattened,
        * a numpy array — flattened.
    :return: 1-D numpy array of scores.
    :raises ValueError: if ``scores`` is an empty list or has an unsupported type.
    """
    if isinstance(scores, list):
        if not scores:
            # Without this guard ``scores[0]`` raises IndexError instead of the
            # ValueError this function documents for unconvertible input.
            raise ValueError("Cannot convert an empty list of scores to a numpy array.")
        if isinstance(scores[0], torch.Tensor):
            return utils.to_numpy(torch.stack(scores).flatten())
    if isinstance(scores, torch.Tensor):
        return utils.to_numpy(scores.flatten())
    if isinstance(scores, np.ndarray):
        return scores.flatten()
    raise ValueError("Cannot convert scores to a numpy array.")
Esempio n. 2
0
def ce_gradient_pair_scatter(model,
                             data_loader,
                             d1=0,
                             d2=1,
                             max_num_examples=2000,
                             plt=None):
    """Scatter-plot two coordinates of the cross-entropy gradient w.r.t. the logits.

    The gradient is computed as ``softmax(pred) - one_hot(label)`` for the first
    ``min(len(data_loader.dataset), max_num_examples)`` examples, and coordinate
    ``d1`` is plotted against coordinate ``d2``.

    :param model: model with ``eval()`` and ``num_classes``; its outputs must
        contain the key 'pred'.
    :param data_loader: only ``data_loader.dataset`` is used; the label of item
        ``idx`` is read as ``dataset[idx][1]``.
    :param d1: logit coordinate on the x axis.
    :param d2: logit coordinate on the y axis.
    :param max_num_examples: cap on the number of evaluated examples.
    :param plt: plotting module; defaults to ``matplotlib.pyplot``.
    :return: ``(fig, plt)``.
    """
    if plt is None:
        plt = matplotlib.pyplot
    model.eval()

    pred = utils.apply_on_dataset(model=model,
                                  dataset=data_loader.dataset,
                                  output_keys_regexp='pred',
                                  max_num_examples=max_num_examples,
                                  description='grad-pair-scatter:pred')['pred']
    n_examples = min(len(data_loader.dataset), max_num_examples)
    labels = [data_loader.dataset[idx][1] for idx in range(n_examples)]
    labels = torch.tensor(labels, dtype=torch.long)
    labels = F.one_hot(labels, num_classes=model.num_classes).float()
    labels = utils.to_cpu(labels)
    grad_wrt_logits = torch.softmax(pred, dim=-1) - labels
    grad_wrt_logits = utils.to_numpy(grad_wrt_logits)

    fig, ax = plt.subplots(1, figsize=(5, 5))
    # Draw on the axes created above rather than pyplot's implicit "current"
    # axes, so the plot always lands in the returned figure.
    ax.scatter(grad_wrt_logits[:, d1], grad_wrt_logits[:, d2])
    ax.set_xlabel(str(d1))
    ax.set_ylabel(str(d2))
    ax.set_title('Two coordinates of grad wrt to logits')
    return fig, plt
Esempio n. 3
0
def ce_gradient_norm_histogram(model,
                               data_loader,
                               tensorboard,
                               epoch,
                               name,
                               max_num_examples=5000):
    """Log a histogram of squared cross-entropy gradient norms to TensorBoard.

    For the first ``min(len(dataset), max_num_examples)`` examples, the gradient
    w.r.t. the logits ``softmax(pred) - one_hot(label)`` is formed and its sum
    of squares over the last dimension (a *squared* L2 norm) is logged under
    ``name`` at step ``epoch``. A ``ValueError`` raised by ``add_histogram`` is
    printed and swallowed.
    """
    model.eval()
    dataset = data_loader.dataset

    outputs = utils.apply_on_dataset(model=model,
                                     dataset=dataset,
                                     output_keys_regexp='pred',
                                     description='grad-histogram:pred',
                                     max_num_examples=max_num_examples)
    logits = outputs['pred']

    count = min(len(dataset), max_num_examples)
    targets = torch.tensor([dataset[k][1] for k in range(count)],
                           dtype=torch.long)
    one_hot = utils.to_cpu(
        F.one_hot(targets, num_classes=model.num_classes).float())

    residual = torch.softmax(logits, dim=-1) - one_hot
    # NOTE: no square root is taken — these are squared norms.
    squared_norms = utils.to_numpy((residual**2).sum(dim=-1))

    try:
        tensorboard.add_histogram(tag=name,
                                  values=squared_norms,
                                  global_step=epoch)
    except ValueError as err:
        print("Tensorboard histogram error: {}".format(err))
Esempio n. 4
0
def pred_gradient_pair_scatter(model,
                               data_loader,
                               d1=0,
                               d2=1,
                               max_num_examples=2000,
                               plt=None):
    """Scatter-plot two coordinates of the model's 'grad_pred' output.

    :param model: model with ``eval()``; its outputs must contain 'grad_pred'.
    :param data_loader: only ``data_loader.dataset`` is used.
    :param d1: coordinate on the x axis.
    :param d2: coordinate on the y axis.
    :param max_num_examples: cap on the number of evaluated examples.
    :param plt: plotting module; defaults to ``matplotlib.pyplot``.
    :return: ``(fig, plt)``.
    """
    if plt is None:
        plt = matplotlib.pyplot
    model.eval()
    grad_pred = utils.apply_on_dataset(
        model=model,
        dataset=data_loader.dataset,
        output_keys_regexp='grad_pred',
        max_num_examples=max_num_examples,
        description='grad-pair-scatter:grad_pred')['grad_pred']
    grad_pred = utils.to_numpy(grad_pred)
    fig, ax = plt.subplots(1, figsize=(5, 5))
    # Draw on the axes created above rather than pyplot's implicit "current"
    # axes, so the plot always lands in the returned figure.
    ax.scatter(grad_pred[:, d1], grad_pred[:, d2])
    ax.set_xlabel(str(d1))
    ax.set_ylabel(str(d2))
    ax.set_title('Two coordinates of grad wrt to logits')
    return fig, plt
Esempio n. 5
0
def plot(quantities, data_X, data_Y, half, t, top_percent=10):
    """Scatter the first ``half`` 2-D points, highlighting the highest-scoring ones.

    :param quantities: list of tensors with one importance score per point
        (stacked and flattened); presumably ``len(quantities) == half`` — TODO confirm.
    :param data_X: point features; columns 0 and 1 are plotted.
    :param data_Y: labels; classes 0 and 1 get markers 'o' and '*'.
    :param half: number of leading points of ``data_X``/``data_Y`` to plot.
    :param t: time value shown in the title.
    :param top_percent: percentage of points (by score) drawn in red.
    :return: ``(fig, ax)``.
    """
    q = utils.to_numpy(torch.stack(quantities).flatten())
    order = np.argsort(q)
    n_points = len(quantities)
    top_cnt = int(top_percent / 100.0 * n_points)
    # Slice from an explicit start: ``order[-top_cnt:]`` would select *every*
    # index when top_cnt == 0 (i.e. fewer than 100 / top_percent points).
    indices = order[max(len(order) - top_cnt, 0):]

    fig, ax = plt.subplots()
    # Vectorized membership test (the original ``i in indices`` was O(n) per point).
    is_top = np.isin(np.arange(n_points), indices)
    colors = np.where(is_top, 'red', 'gray')
    markers_list = ['o', '*']
    for class_idx in range(2):
        mask = (data_Y[:half] == class_idx)
        ax.scatter(data_X[:half][mask][:, 0],
                   data_X[:half][mask][:, 1],
                   alpha=0.6,
                   s=10,
                   c=colors[mask],
                   marker=markers_list[class_idx])
    ax.set_title(f"top {top_percent}% most important samples at time {t}")

    return fig, ax
Esempio n. 6
0
def plot_confusion_matrix(Q, plt=None):
    """Render the square matrix ``Q`` as an image with a colorbar.

    Columns are labelled 'observed' and rows 'true', with one tick per class
    (``Q.shape[0]`` classes).

    :param Q: square matrix (tensor or array accepted by ``utils.to_numpy``).
    :param plt: plotting module; defaults to ``matplotlib.pyplot``.
    :return: ``(fig, plt)``.
    """
    plt = matplotlib.pyplot if plt is None else plt
    class_ticks = range(Q.shape[0])

    fig, ax = plt.subplots(1, figsize=(5, 5))
    image = ax.imshow(utils.to_numpy(Q))
    fig.colorbar(image)

    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks(class_ticks)
    ax.set_xlabel('observed')
    ax.set_ylabel('true')
    return fig, plt
Esempio n. 7
0
def process_results(vectors,
                    quantities,
                    meta,
                    exp_name,
                    output_dir,
                    train_data,
                    plt=None,
                    **kwargs):
    """Save importance results and summary plots for one experiment.

    A random suffix ``-h<number>`` is appended to ``exp_name``. When
    ``output_dir`` is not None, everything is written under
    ``<output_dir>/<exp_name>``:

    * the results dict (via ``save_results_dict``) as ``results.pkl``;
    * a histogram of the scores as ``histogram.pdf``;
    * the 10 lowest / highest scoring training examples as
      ``least-important.pdf`` / ``most-important.pdf``.

    :param vectors: importance vectors for ``make_importance_result_dict``.
    :param quantities: list of per-example score tensors.
    :param meta: metadata stored in the results dict.
    :param exp_name: experiment name (suffixed with a random tag).
    :param output_dir: root output directory, or None to skip saving.
    :param train_data: dataset the example images are drawn from.
    :param plt: plotting module; a non-interactive one is imported if None.
    :param kwargs: forwarded to the plotting helpers.
    """
    if plt is None:
        _, plt = import_matplotlib(agg=True)

    # Use a private time-seeded generator for the run suffix instead of
    # re-seeding numpy's global RNG, which would silently reset the random
    # stream of every other user of ``np.random`` in the process.
    rng = np.random.RandomState(int(time.time()))
    exp_name += '-h' + str(rng.randint(1000000))

    exp_dir = None if output_dir is None else os.path.join(output_dir, exp_name)

    results_dict = make_importance_result_dict(importance_vectors=vectors,
                                               importance_measures=quantities,
                                               meta=meta)

    # save the results
    if exp_dir is not None:
        save_results_dict(results_dict=results_dict,
                          file_path=os.path.join(exp_dir, 'results.pkl'))

    # save the histogram image and extreme examples
    quantities = utils.to_numpy(torch.stack(quantities).flatten())

    save_name = None if exp_dir is None else os.path.join(exp_dir, 'histogram.pdf')
    plot_histogram_of_informativeness(informativeness_scores=quantities,
                                      plt=plt,
                                      save_name=save_name,
                                      **kwargs)

    order = np.argsort(quantities)
    for indices, file_name in ((order[:10], 'least-important.pdf'),
                               (order[-10:], 'most-important.pdf')):
        plot_examples_from_dataset(data=train_data,
                                   indices=indices,
                                   plt=plt,
                                   n_rows=1,
                                   **kwargs)
        if exp_dir is not None:
            plt.savefig(os.path.join(exp_dir, file_name))
Esempio n. 8
0
def compute_exp_matrix(t, eta, ntk, continuous):
    """Compute exp(-t * eta * ntk), or its discrete counterpart (I - eta * ntk)^t.

    :param t: time (continuous) or number of steps (discrete). If None, t is
        treated as infinity and a zero matrix is returned (assuming ntk is
        invertible, so the exponential decays to zero).
    :param eta: learning rate.
    :param ntk: square kernel matrix (torch tensor).
    :param continuous: if True, use the matrix exponential (computed with
        ``scipy.linalg.expm``); otherwise use the matrix power.
    :return: float32 tensor on ``ntk.device`` with the same shape as ``ntk``.
    """
    if t is None:
        return torch.zeros_like(ntk)

    if not continuous:
        # discrete gradient-descent dynamics: (I - eta * ntk)^t
        step_matrix = torch.eye(ntk.shape[0], dtype=torch.float,
                                device=ntk.device) - eta * ntk
        return torch.matrix_power(step_matrix, t)

    dense_exp = scipy.linalg.expm(-eta * t * utils.to_numpy(ntk))
    return torch.tensor(dense_exp, device=ntk.device, dtype=torch.float)
def estimate_transition(load_from, data_loader, device='cpu', batch_size=256):
    """ Estimates the label noise matrix. The code is adapted from the original implementation.
    Source: https://github.com/giorgiop/loss-correction/.

    For each class ``i`` an anchor example is chosen: the example with the
    highest predicted probability of class ``i`` after discarding the top 3%
    (values at or above the 97th percentile) as outliers. Row ``i`` of T is the
    predicted class distribution at that anchor; rows are then normalized to sum to 1.

    :param load_from: checkpoint path passed to ``utils.load``; must not be None.
    :param data_loader: only ``data_loader.dataset`` is used.
    :param device: device for the loaded model and the returned matrix.
    :param batch_size: batch size used when evaluating the dataset.
    :return: float32 tensor of shape ``(num_classes, num_classes)`` on ``device``.
    """
    assert load_from is not None
    model = utils.load(load_from, methods=methods, device=device)
    pred = utils.apply_on_dataset(model=model,
                                  dataset=data_loader.dataset,
                                  batch_size=batch_size,
                                  cpu=True,
                                  description="Estimating transition matrix",
                                  output_keys_regexp='pred')['pred']
    pred = torch.softmax(pred, dim=1)
    pred = utils.to_numpy(pred)

    c = model.num_classes
    T = np.zeros((c, c))
    filter_outlier = True

    # find a 'perfect example' for each class
    for i in range(c):
        if not filter_outlier:
            idx_best = np.argmax(pred[:, i])
        else:
            thresh = np.percentile(pred[:, i], 97, interpolation='higher')
            # Copy: ``pred[:, i]`` is a view, so zeroing it in place would
            # overwrite column i of ``pred``, corrupting entry T[k, i] for every
            # later class k read below.
            robust_eta = pred[:, i].copy()
            robust_eta[robust_eta >= thresh] = 0.0
            idx_best = np.argmax(robust_eta)

        T[i] = pred[idx_best]

    # row normalize
    T /= T.sum(axis=1, keepdims=True)

    T = torch.tensor(T, dtype=torch.float).to(device)
    print(T)

    return T
Esempio n. 10
0
def pred_gradient_norm_histogram(model,
                                 data_loader,
                                 tensorboard,
                                 epoch,
                                 name,
                                 max_num_examples=5000):
    """Log a histogram of squared norms of the model's 'grad_pred' output.

    Runs ``model`` on up to ``max_num_examples`` items of ``data_loader.dataset``,
    sums the squared 'grad_pred' values over the last dimension (a squared L2
    norm, no square root), and logs them to ``tensorboard`` under ``name`` at
    step ``epoch``. A ``ValueError`` from ``add_histogram`` is printed, not raised.
    """
    model.eval()
    collected = utils.apply_on_dataset(model=model,
                                       dataset=data_loader.dataset,
                                       output_keys_regexp='grad_pred',
                                       description='grad-histogram:grad_pred',
                                       max_num_examples=max_num_examples)
    squared_entries = collected['grad_pred']**2
    histogram_values = utils.to_numpy(squared_entries.sum(dim=-1))

    try:
        tensorboard.add_histogram(tag=name,
                                  values=histogram_values,
                                  global_step=epoch)
    except ValueError as err:
        print("Tensorboard histogram error: {}".format(err))
Esempio n. 11
0
    def compute_jacobian(self,
                         model,
                         dataset,
                         cpu=True,
                         description="",
                         output_key='pred',
                         max_num_examples=2**30,
                         num_workers=0,
                         seed=42,
                         **kwargs):
        """Compute per-example Jacobians of model outputs w.r.t. its parameters.

        Iterates over the first ``min(len(dataset), max_num_examples)`` examples
        one at a time (batch size 1, no shuffling). For each example and each
        output coordinate of ``outputs[output_key][0]``, the gradient of that
        coordinate w.r.t. ``model.parameters()`` is computed and stored
        according to ``self.projection``:

        * ``'none'``: the raw gradient tensor of each parameter;
        * ``'random-subset'``: the flattened gradient restricted to
          ``self._random_subset_proj_indices[k]`` and rescaled by
          ``sqrt(numel / n_select)``;
        * ``'very-sparse'``: the flattened gradient multiplied by
          ``self._very_sparse_proj_matrix[k].T``.

        Any other value of ``self.projection`` stores nothing, so the returned
        dict is empty.

        :param model: model whose ``forward(inputs=..., labels=..., loader=..., **kwargs)``
            returns a mapping containing ``output_key``.
        :param dataset: dataset yielding ``(inputs, labels)`` pairs.
        :param cpu: if True, gradients are moved to CPU with ``utils.to_cpu``.
        :param description: tqdm progress-bar label.
        :param output_key: key of the model output to differentiate.
        :param max_num_examples: cap on the number of examples processed.
        :param num_workers: DataLoader workers; if > 0, torch's multiprocessing
            sharing strategy and start method are changed process-wide.
        :param seed: passed to ``np.random.seed`` (seeds the global numpy RNG).
        :param kwargs: forwarded to ``model.forward``.
        :return: dict mapping parameter name to a tensor stacked along dim 0 with
            ``n_examples * n_outputs`` rows, ordered example-major (all outputs of
            example 0, then all outputs of example 1, ...).
        """
        # Seeds numpy's global RNG; presumably so the random projection
        # indices/matrices prepared below are reproducible — TODO confirm.
        np.random.seed(seed)
        model.eval()

        if num_workers > 0:
            torch.multiprocessing.set_sharing_strategy('file_system')
            torch.multiprocessing.set_start_method('spawn', force=True)

        n_examples = min(len(dataset), max_num_examples)
        # Batch size 1 so each step yields a single example's outputs.
        loader = DataLoader(dataset=Subset(dataset, range(n_examples)),
                            batch_size=1,
                            shuffle=False,
                            num_workers=num_workers)

        jacobians = defaultdict(list)

        # loop over the dataset
        n_outputs = None
        for inputs_batch, labels_batch in tqdm(loader, desc=description):
            # The model's forward expects lists of inputs/labels.
            if isinstance(inputs_batch, torch.Tensor):
                inputs_batch = [inputs_batch]
            if not isinstance(labels_batch, list):
                labels_batch = [labels_batch]

            with torch.set_grad_enabled(True):
                outputs = model.forward(inputs=inputs_batch,
                                        labels=labels_batch,
                                        loader=loader,
                                        **kwargs)
                # [0] selects the single example of this batch.
                preds = outputs[output_key][0]
            n_outputs = preds.shape[-1]

            for output_idx in range(n_outputs):
                # Keep the graph alive until the last output coordinate has
                # been differentiated.
                retain_graph = (output_idx != n_outputs - 1)
                with torch.set_grad_enabled(True):
                    cur_jacobians = torch.autograd.grad(
                        preds[output_idx],
                        model.parameters(),
                        retain_graph=retain_graph)
                if cpu:
                    cur_jacobians = [utils.to_cpu(v) for v in cur_jacobians]

                if self.projection == 'none':
                    for (k, _), v in zip(model.named_parameters(),
                                         cur_jacobians):
                        jacobians[k].append(v)

                if self.projection == 'random-subset':
                    # Called on every iteration; presumably a no-op once the
                    # indices exist — verify in the helper.
                    self._prepare_random_subset_proj_indices(
                        model.named_parameters())

                    for (k, _), v in zip(model.named_parameters(),
                                         cur_jacobians):
                        v = v.flatten()
                        n_select = len(self._random_subset_proj_indices[k])
                        # Rescale by sqrt(numel / n_select) after subsampling.
                        v_proj = v[
                            self._random_subset_proj_indices[k]] * np.sqrt(
                                v.shape[0] / n_select)
                        jacobians[k].append(v_proj)

                if self.projection == 'very-sparse':
                    # Called on every iteration; presumably a no-op once the
                    # matrix exists — verify in the helper.
                    self._prepare_very_sparse_proj_matrix(
                        model.named_parameters())
                    for (k, _), v in zip(model.named_parameters(),
                                         cur_jacobians):
                        # now that the projection matrix is ready, we can project v into the smaller subspace
                        v = v.flatten()
                        # Projection is done in numpy, then converted back to
                        # a tensor with v's dtype and device.
                        v_proj = self._very_sparse_proj_matrix[k].T.dot(
                            utils.to_numpy(v))
                        v_proj = torch.tensor(v_proj,
                                              dtype=v.dtype,
                                              device=v.device)
                        jacobians[k].append(v_proj)

        for k in jacobians:
            jacobians[k] = torch.stack(
                jacobians[k])  # n_samples * n_outputs x n_params
            assert len(jacobians[k]) == n_outputs * n_examples

        return jacobians
Esempio n. 12
0
def _plot_example_row(fig, gs, data, indices, rows, panel_label, label_dy):
    """Draw the dataset images at ``indices`` in grid rows ``rows`` (columns 3..), tagging the first with ``panel_label``."""
    for i, data_idx in enumerate(indices):
        ax = fig.add_subplot(gs[rows, 3 + i])
        if i == 0:
            x_pos = ax.get_position().x0
            y_pos = ax.get_position().y1
            fig.text(x_pos - 0.05, y_pos + label_dy, panel_label, size=28, weight='bold')

        x, _ = data[data_idx]
        x = revert_normalization(x, data)[0]
        x = utils.to_numpy(x)
        x = get_image(x)
        ax.imshow(x, vmin=0, vmax=1)
        ax.set_axis_off()


def plot_summary_of_informativeness(data,
                                    informativeness_scores,
                                    label_names=None,
                                    save_name=None,
                                    plt=None,
                                    is_label_one_hot=False,
                                    **kwargs):
    """Plot a three-panel summary of per-example informativeness.

    Panel A: per-class histograms of the scores. Panel B: the (up to) 10 least
    informative examples. Panel C: the (up to) 10 most informative examples.

    :param data: dataset of ``(x, y)`` pairs, also passed to ``revert_normalization``.
    :param informativeness_scores: scores accepted by ``convert_scores_to_numpy``
        (list of tensors, tensor, or np.ndarray), one per example.
    :param label_names: optional mapping from class id to display name.
    :param save_name: if not None, the figure is saved there with ``savefig``.
    :param plt: plotting module; a non-interactive one is imported if None.
    :param is_label_one_hot: if True, labels are converted with ``argmax``.
    :return: ``(fig, None)``.
    """
    if plt is None:
        _, plt = import_matplotlib(agg=True, use_style=False)

    informativeness_scores = convert_scores_to_numpy(informativeness_scores)

    fig = plt.figure(constrained_layout=True, figsize=(24, 5))
    gs = fig.add_gridspec(22, 13)
    ax_left = fig.add_subplot(gs[1:22, :3])

    ys = [torch.tensor(y) for x, y in data]
    if is_label_one_hot:
        ys = [torch.argmax(y) for y in ys]
    ys = np.array([y.item() for y in ys])

    for y in sorted(set(ys)):
        mask = (ys == y)
        label = str(y) if label_names is None else label_names[y]
        ax_left.hist(informativeness_scores[mask],
                     bins=30,
                     label=label,
                     alpha=0.6)
    ax_left.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    ax_left.set_xlabel('Informativeness of an example')
    ax_left.set_ylabel('Count')
    ax_left.legend()
    x_pos = ax_left.get_position().x0
    y_pos = ax_left.get_position().y1
    fig.text(x_pos - 0.05, y_pos + 0.065, 'A', size=28, weight='bold')

    order = np.argsort(informativeness_scores)
    # Show at most 10 examples per row; datasets with fewer examples used to
    # raise IndexError here.
    n_show = min(10, len(order))
    least_informative = order[:n_show]
    most_informative = order[len(order) - n_show:]

    _plot_example_row(fig, gs, data, least_informative,
                      rows=slice(1, 11), panel_label='B', label_dy=0.065)
    _plot_example_row(fig, gs, data, most_informative,
                      rows=slice(12, 22), panel_label='C', label_dy=0.042)

    if save_name is not None:
        savefig(fig, save_name)

    return fig, None
Esempio n. 13
0
def main():
    """Iteratively compute example-importance orders on a shrinking training set.

    On each iteration, importance scores are computed for the remaining
    training examples ('weights-plain' or 'predictions' measure). The full
    absolute order (previously excluded indices first, then remaining ones by
    ascending score) is pickled to
    ``<output_dir>/<exp_name>/iter<k>-<measure>.pkl``, and the 5% least
    important remaining examples (at least one) are excluded. Iteration stops
    once 95% of the training set has been excluded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--device',
                        '-d',
                        default='cuda',
                        help='specifies the main device')
    parser.add_argument('--seed', type=int, default=42)

    # data parameters
    parser.add_argument('--dataset', '-D', type=str, default='mnist4vs9')
    parser.add_argument('--data_augmentation',
                        '-A',
                        action='store_true',
                        dest='data_augmentation')
    parser.set_defaults(data_augmentation=False)
    parser.add_argument('--error_prob', '-n', type=float, default=0.0)
    parser.add_argument('--num_train_examples', type=int, default=None)
    parser.add_argument('--clean_validation',
                        action='store_true',
                        default=False)
    parser.add_argument('--resize_to_imagenet',
                        action='store_true',
                        dest='resize_to_imagenet')
    parser.set_defaults(resize_to_imagenet=False)
    parser.add_argument('--cache_dataset',
                        action='store_true',
                        dest='cache_dataset')
    parser.set_defaults(cache_dataset=False)

    # hyper-parameters
    parser.add_argument('--model_class',
                        '-m',
                        type=str,
                        default='ClassifierL2')

    parser.add_argument('--l2_reg_coef', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')

    parser.add_argument(
        '--output_dir',
        '-o',
        type=str,
        default='sample_info/results/data-summarization/orders/')
    parser.add_argument('--exp_name', '-E', type=str, required=True)

    # which measures to compute
    parser.add_argument('--which_measure',
                        '-w',
                        type=str,
                        required=True,
                        choices=['weights-plain', 'predictions'])

    # NTK arguments
    parser.add_argument('--t', '-t', type=int, default=None)
    parser.add_argument('--projection',
                        type=str,
                        default='none',
                        choices=['none', 'random-subset', 'very-sparse'])
    parser.add_argument('--cpu', dest='cpu', action='store_true')
    parser.set_defaults(cpu=False)
    parser.add_argument('--large_model_regime',
                        dest='large_model_regime',
                        action='store_true')
    parser.add_argument('--random_subset_n_select', type=int, default=2000)
    parser.set_defaults(large_model_regime=False)

    args = parser.parse_args()
    print(args)

    # Load data
    train_data, val_data, test_data, _ = load_data_from_arguments(
        args, build_loaders=False)

    if args.cache_dataset:
        train_data = CacheDatasetWrapper(train_data)
        val_data = CacheDatasetWrapper(val_data)
        test_data = CacheDatasetWrapper(test_data)

    with open(args.config, 'r') as f:
        architecture_args = json.load(f)

    model_class = getattr(methods, args.model_class)

    model = model_class(input_shape=train_data[0][0].shape,
                        architecture_args=architecture_args,
                        l2_reg_coef=args.l2_reg_coef,
                        device=args.device,
                        seed=args.seed)
    model.eval()
    print("Number of parameters: ", utils.get_num_parameters(model))

    iter_idx = 0
    exclude_indices = []

    while len(exclude_indices) / len(train_data) < 0.95:
        print(f"Computing the order for iteration {iter_idx}")

        # Prepare the needed terms
        cur_train_data = SubsetDataWrapper(train_data,
                                           exclude_indices=exclude_indices)
        n = len(cur_train_data)
        ret = prepare_needed_items(model=model,
                                   train_data=cur_train_data,
                                   test_data=val_data,
                                   projection=args.projection,
                                   cpu=args.cpu)

        quantities = None
        order_file_name = None

        # weights without SGD
        if args.which_measure == 'weights-plain':
            _, quantities = weight_stability(
                t=args.t,
                n=n,
                eta=args.lr / n,
                init_params=ret['init_params'],
                jacobians=ret['train_jacobians'],
                ntk=ret['ntk'],
                init_preds=ret['train_init_preds'],
                Y=ret['train_Y'],
                l2_reg_coef=n * args.l2_reg_coef,
                continuous=False,
                without_sgd=True,
                model=model,
                dataset=cur_train_data,
                large_model_regime=args.large_model_regime,
                return_change_vectors=False)

            order_file_name = f'iter{iter_idx}-weights.pkl'

        # test prediction
        if args.which_measure == 'predictions':
            _, quantities = test_pred_stability(
                t=args.t,
                n=n,
                eta=args.lr / n,
                ntk=ret['ntk'],
                test_train_ntk=ret['test_train_ntk'],
                train_init_preds=ret['train_init_preds'],
                test_init_preds=ret['test_init_preds'],
                train_Y=ret['train_Y'],
                l2_reg_coef=n * args.l2_reg_coef,
                continuous=False)

            order_file_name = f'iter{iter_idx}-predictions.pkl'

        # save the order: relative indices (into cur_train_data) are mapped
        # back to indices into the full train_data
        relative_order = np.argsort(
            utils.to_numpy(torch.stack(quantities).flatten()))
        absolute_order = [
            cur_train_data.include_indices[rel_idx]
            for rel_idx in relative_order
        ]
        absolute_order = exclude_indices + absolute_order
        file_path = os.path.join(args.output_dir, args.exp_name,
                                 order_file_name)
        utils.make_path(os.path.dirname(file_path))
        with open(file_path, 'wb') as f:
            pickle.dump(absolute_order, f)

        # Remove 5% of the remaining samples, but always at least one: once
        # fewer than 20 samples remain, int(0.05 * n) is 0 and the loop would
        # otherwise never terminate.
        exclude_count = max(1, int(0.05 * len(cur_train_data)))
        new_exclude_indices = [
            cur_train_data.include_indices[rel_idx]
            for rel_idx in relative_order[:exclude_count]
        ]
        exclude_indices.extend(new_exclude_indices)
        iter_idx += 1
        print(len(exclude_indices))
def weight_stability(t,
                     n,
                     eta,
                     init_params,
                     jacobians,
                     ntk,
                     init_preds,
                     Y,
                     continuous=False,
                     without_sgd=True,
                     l2_reg_coef=0.0,
                     large_model_regime=False,
                     model=None,
                     dataset=None,
                     return_change_vectors=True,
                     **kwargs):
    """Measure how much the weights at time ``t`` change when each example is removed.

    The weights ``w1`` obtained from all ``n`` examples (via
    ``get_weights_at_time_t``) are compared, for every example ``i``, with the
    weights ``w2`` obtained from the remaining ``n - 1`` examples using the
    correspondingly reduced NTK, Jacobians, initial predictions and targets, and
    learning rate ``eta * n / (n - 1)``.

    :param ntk: square kernel with ``n * n_outputs`` rows, where
        ``n_outputs = init_preds.shape[-1]``; rows of example ``i`` are
        ``i * n_outputs ... (i + 1) * n_outputs - 1``.
    :param l2_reg_coef: if > 0, ``l2_reg_coef * I`` is added to ``ntk``.
    :param without_sgd: if True, the change quantity is the squared distance
        ``||w1 - w2||^2`` (sum of squared parameter differences). Otherwise it
        is ``(w1 - w2)^T S^{-1} (w1 - w2)``, where ``S`` is the solution of the
        continuous Lyapunov equation with ``a = H`` (Jacobian Gram matrix plus
        L2 term) and ``q`` = the SGD noise covariance at ``w1``.
    :param large_model_regime: if True, Jacobians are not subset per example
        (``None`` is passed on); incompatible with ``without_sgd=False``.
    :param return_change_vectors: if False, per-example parameter differences
        are not kept (saves memory).
    :return: ``(change_vectors, change_quantities)`` — a list of per-example
        dicts of parameter differences (empty if ``return_change_vectors`` is
        False), and a list of per-example change quantities.
    :raises ValueError: if ``without_sgd`` is False in the large-model regime.
    """
    if l2_reg_coef > 0:
        ntk = ntk + l2_reg_coef * torch.eye(
            ntk.shape[0], dtype=torch.float, device=ntk.device)

    ntk_inv = torch.inverse(ntk)
    old_weights = get_weights_at_time_t(t=t,
                                        eta=eta,
                                        init_params=init_params,
                                        jacobians=jacobians,
                                        ntk=ntk,
                                        ntk_inv=ntk_inv,
                                        init_preds=init_preds,
                                        Y=Y,
                                        continuous=continuous,
                                        large_model_regime=large_model_regime,
                                        model=model,
                                        dataset=dataset,
                                        **kwargs)

    steady_state_inv_cov = None
    if not without_sgd:
        if large_model_regime:
            raise ValueError("SGD formula works only for small models")

        # compute the SGD noise covariance matrix at the end
        assert (model is not None) and (dataset is not None)
        with utils.SetTemporaryParams(model=model, params=old_weights):
            sgd_cov = get_sgd_covariance_full(model=model,
                                              dataset=dataset,
                                              cpu=False,
                                              **kwargs)
            # add small amount of isotropic Gaussian noise to make sgd_cov invertible
            sgd_cov += 1e-10 * torch.eye(
                sgd_cov.shape[0], device=sgd_cov.device, dtype=torch.float)

        jacobians_cat = [v.view((v.shape[0], -1)) for v in jacobians.values()]
        jacobians_cat = torch.cat(jacobians_cat,
                                  dim=1)  # (n_samples * n_outputs, n_params)
        H = torch.mm(jacobians_cat.T, jacobians_cat) + l2_reg_coef * torch.eye(
            jacobians_cat.shape[1], device=ntk.device, dtype=torch.float)
        with utils.Timing(description="Solving the Lyapunov equation"):
            steady_state_cov = solve_continuous_lyapunov(
                a=utils.to_numpy(H), q=utils.to_numpy(sgd_cov))
        steady_state_cov = torch.tensor(steady_state_cov,
                                        dtype=torch.float,
                                        device=ntk.device)
        # add small amount of isotropic Gaussian noise to make steady_state_cov invertible
        steady_state_cov += 1e-10 * torch.eye(steady_state_cov.shape[0],
                                              device=steady_state_cov.device,
                                              dtype=torch.float)
        # Parameter changes are moved to CPU below, so move the inverse once
        # here instead of on every loop iteration.
        steady_state_inv_cov = torch.inverse(steady_state_cov).cpu()

    change_vectors = []
    change_quantities = []
    n_outputs = init_preds.shape[-1]
    for sample_idx in tqdm(range(n)):
        example_indices = [i for i in range(n) if i != sample_idx]
        # NTK/Jacobian rows of the kept examples (n_outputs rows per example).
        example_output_indices = [
            i * n_outputs + o for i in example_indices for o in range(n_outputs)
        ]

        # Advanced indexing already returns a copy, so cloning the full NTK
        # first (once per example) is unnecessary.
        new_ntk = ntk[example_output_indices][:, example_output_indices]

        new_ntk_inv = misc.update_ntk_inv(ntk=ntk,
                                          ntk_inv=ntk_inv,
                                          keep_indices=example_output_indices)

        new_init_preds = init_preds[example_indices]
        new_Y = Y[example_indices]

        if not large_model_regime:
            new_jacobians = {
                k: v[example_output_indices] for k, v in jacobians.items()
            }
        else:
            new_jacobians = None

        new_dataset = Subset(dataset, example_indices)

        new_weights = get_weights_at_time_t(
            t=t,
            eta=eta * n / (n - 1),
            init_params=init_params,
            jacobians=new_jacobians,
            ntk=new_ntk,
            ntk_inv=new_ntk_inv,
            init_preds=new_init_preds,
            Y=new_Y,
            continuous=continuous,
            large_model_regime=large_model_regime,
            model=model,
            dataset=new_dataset,
            **kwargs)

        param_changes = {
            k: (new_weights[k] - old_weights[k]).cpu()  # to save GPU memory
            for k in old_weights.keys()
        }

        if return_change_vectors:
            change_vectors.append(param_changes)

        if without_sgd:
            total_change = 0.0
            for k in old_weights.keys():
                total_change += torch.sum(param_changes[k]**2)
        else:
            flat_changes = torch.cat(
                [v.flatten() for v in param_changes.values()], dim=0)
            total_change = torch.mm(
                flat_changes.view((1, -1)),
                torch.mm(steady_state_inv_cov, flat_changes.view(-1, 1)))

        change_quantities.append(total_change)

    return change_vectors, change_quantities
Esempio n. 15
0
def main():
    """Compare ground-truth per-example changes with the proposed estimates.

    Reads ground-truth, proposed and influence-function results for
    ``--exp_name``. For each of 'weights_diff' and 'pred_diff', the squared
    norms of the per-example vectors (restricted to the ground-truth mask) are
    scatter-plotted (ground truth vs. proposed) to
    ``<root_dir>/aggregated/<exp_name>/<key>-norm-scatter.pdf``, and correlation
    matrices are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    parser.add_argument('--root_dir',
                        '-r',
                        type=str,
                        default='sample_info/results/ground-truth/')
    parser.add_argument('--num_examples', '-n', type=int, required=True)
    args = parser.parse_args()
    print(args)

    # method name -> quantity key -> list of per-example vectors
    results = defaultdict(lambda: defaultdict(list))
    mask = read_ground_truth(args, results)
    read_proposed(args, results)
    influence_functions = read_influence_functions(args, results)

    for key in ('weights_diff', 'pred_diff'):
        norms = {}
        for method, per_method in results.items():
            kept = [
                torch.sum(vec**2)
                for idx, vec in enumerate(per_method[key])
                if mask[idx]
            ]
            norms[method] = utils.to_numpy(torch.stack(kept).flatten())

        fig, ax = plt.subplots()
        gt_norms = norms['ground-truth']
        lower, upper = np.min(gt_norms), np.max(gt_norms)
        ax.set_title(key)
        ax.scatter(gt_norms, norms['proposed'], label='gt-vs-proposed', s=5)
        ax.set_xlim(left=lower, right=upper)
        ax.set_ylim(bottom=lower, top=upper)
        ax.set_xlabel('ground_truth')
        ax.legend()
        fig.tight_layout()
        savefig(
            fig,
            os.path.join(args.root_dir, 'aggregated', args.exp_name,
                         f'{key}-norm-scatter.pdf'))

        print("Correlations of proposed:")
        print(np.corrcoef(gt_norms, norms['proposed']))

        if influence_functions:
            print("Correlations of influence functions:")
            print(np.corrcoef(gt_norms, norms['influence-functions']))