def get_inverse_hvp(v, itr_train, loss_fn, grad_fn, map_grad_fn,
                    scale=10, damping=0.0, num_samples=1,
                    recursion_depth=10000, print_iter=1):
  """Calculates an HVP, getting the inverse Hessian using the LISSA method.

  Args:
    v (list of Tensors): the vector in the HVP - the per-example gradients of
                         the loss on the test batch.
    itr_train (Iterator): iterator for getting batches to estimate H.
    loss_fn (function): a function which returns a vector of per-example losses.
    grad_fn (function): a function which takes the gradient of a scalar loss.
    map_grad_fn (function): a function which takes the gradient of each element
                            of a vector of losses.
    scale (number): scales eigenvalues of Hessian to be <= 1.
    damping (number): in [0, 1], LISSA damping parameter - higher is more
                      stable.
    num_samples (int): how many times to estimate the HVP.
    recursion_depth (int): how many steps in LISSA optimization.
    print_iter (int): how frequently to print LISSA updates.
  Returns:
    inverse_hvp (list of Tensors): the estimated product of the inverse
                                   Hessian of the model and v, with the same
                                   shapes as v.
  """

  value_constraints = [(scale, 'scale', lambda x: x > 0, 'greater than 0'),
                       (damping, 'damping', lambda x: 0 <= x <= 1,
                        'between 0 and 1 inclusive'),
                       (num_samples, 'num_samples',
                        lambda x: x > 0, 'greater than 0'),
                       (recursion_depth, 'recursion_depth',
                        lambda x: x > 0, 'greater than 0')]
  for var, vname, constraint, msg in value_constraints:
    if not constraint(var):
      raise ValueError('{} should be {}'.format(vname, msg))

  inverse_hvp = [tf.zeros_like(b) for b in v]
  for _ in range(num_samples):
    cur_estimate = v
    logging.info('cur estimate: %s', str([c.shape for c in cur_estimate]))

    for j in range(recursion_depth):
      old_estimate = cur_estimate
      hessian_vector_val = hvp(cur_estimate, itr_train,
                               loss_fn, grad_fn, map_grad_fn)

      cur_estimate = [a + (1-damping) * b - c / scale for (a, b, c)
                      in zip(v, cur_estimate, hessian_vector_val)]

      if (j % print_iter == 0) or (j == recursion_depth - 1):
        logging.info('Recursion at depth %d: norm is %.8lf, diff is %.8lf',
                     j,
                     np.linalg.norm(
                         tensor_utils.flat_concat(cur_estimate).numpy()),
                     np.linalg.norm(
                         tensor_utils.flat_concat(cur_estimate).numpy() -
                         tensor_utils.flat_concat(old_estimate).numpy()))
    inverse_hvp = [a + b / scale
                   for (a, b) in zip(inverse_hvp, cur_estimate)]

  inverse_hvp = [a / num_samples for a in inverse_hvp]
  return inverse_hvp
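# The loop above is the LISSA / Neumann-series estimator: with damping d and
# scale s, the fixed point of u <- v + (1 - d) * u - H u / s is
# u = s * (H + d * s * I)^{-1} v, so dividing by scale recovers the (damped)
# inverse-HVP. A minimal standalone sketch of the same recursion on a tiny
# explicit Hessian (NumPy only; the matrix and vector here are hypothetical
# toy values, not produced by the functions above):
import numpy as np

H = np.array([[2.0, 0.3],
              [0.3, 1.0]])  # small SPD stand-in for the model Hessian
v = np.array([1.0, -1.0])   # stand-in for the test-example gradient

scale, damping, depth = 10.0, 0.0, 1000
u = v.copy()
for _ in range(depth):
  # Same update as get_inverse_hvp, with an exact H @ u in place of the
  # stochastic Hessian-vector product.
  u = v + (1.0 - damping) * u - H @ u / scale

print(u / scale)              # LISSA estimate of H^{-1} v, ~[0.681, -1.204]
print(np.linalg.solve(H, v))  # exact H^{-1} v for comparison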
def get_parameter_influence(model, x, y, itr,
                            approx_params=None, damping=None):
  """Estimate the influence of test examples (x, y) on the parameters of model.

  Args:
    model (Classifier): a classification model whose parameters we are
      interested in.
    x (tensor): the input data whose influence we are interested in.
    y (tensor): the target data whose influence we are interested in.
    itr (Iterator): an iterator of data we will use to estimate the Hessian.
    approx_params (dict, optional): parameters for running LiSSA.
    damping (float, optional): the amount of L2-regularization to add
      to the parameters of model (only used for conjugate gradient).
  Returns:
    ihvp_result (tensor): the product of the inverse Hessian of model (possibly
      with some L2-regularization) and the gradient of the loss at (x, y)
      w.r.t. the parameters of model.
    concat_grads (tensor): the gradients of the loss at (x, y) w.r.t. the model
      parameters, flattened and concatenated.
    warning_flag (int): the Scipy convergence flag indicating whether the
                        optimization terminated successfully.
  """
  loss_function = calculate_influence.make_loss_fn(model, damping)
  gradient_function = calculate_influence.make_grad_fn(model)
  map_gradient_function = calculate_influence.make_map_grad_fn(model)
  grads = calculate_influence.get_loss_grads(x, y, loss_function,
                                             map_gradient_function)
  concat_grads = tensor_utils.flat_concat(grads)
  ihvp_result, warning_flag = get_ihvp_conjugate_gradient(
      grads, itr, loss_function, gradient_function, map_gradient_function,
      approx_params)
  return ihvp_result, concat_grads, warning_flag
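# A hedged usage sketch for get_parameter_influence. The approx_params dict is
# forwarded to get_ihvp_conjugate_gradient below, which reads the keys
# 'hvp_samples', 'squared', 'maxiter' and 'tol'; `model`, `x_test`, `y_test`
# and `train_itr` are hypothetical placeholders, and the values shown are
# illustrative, not recommended settings.
cg_params = {
    'hvp_samples': 1,   # batches averaged per Hessian-vector product
    'squared': False,   # if True, the Hessian is applied twice (H^2 in place of H)
    'maxiter': 100,     # conjugate-gradient iteration cap
    'tol': 1e-5,        # conjugate-gradient convergence tolerance
}
ihvp, test_grads, warning_flag = get_parameter_influence(
    model, x_test, y_test, train_itr,
    approx_params=cg_params, damping=0.01)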
    def test_reshape_and_flatten(self):
        x_shape = (64, 28, 28, 1)
        y_shape = (64, 10)
        conv_dims = [20, 10]
        conv_sizes = [5, 5]
        dense_sizes = [100]
        n_classes = 10
        model = classifier.CNN(conv_dims,
                               conv_sizes,
                               dense_sizes,
                               n_classes,
                               onehot=True)
        itr = dataset_utils.get_supervised_batch_noise_iterator(
            x_shape, y_shape)
        x, y = itr.next()
        _, _ = model.get_loss(x, y)

        w = model.weights
        num_wts = sum([tf.size(x) for x in w])
        v = tf.random.normal((10, num_wts))
        v_as_wts = tensor_utils.reshape_vector_as(w, v)

        for i in range(len(v_as_wts)):
            self.assertEqual(v_as_wts[i].shape[1:], w[i].shape)

        v_as_vec = tensor_utils.flat_concat(v_as_wts)
        self.assertAllClose(v, v_as_vec)
    def test_largest_and_smallest_eigenvalue_estimation_correct(self):
        tf.compat.v1.random.set_random_seed(0)
        x_shape = (10, 5)
        y_shape = (10, 1)
        conv_dims = []
        conv_sizes = []
        dense_sizes = [5]
        n_classes = 3
        model = classifier.CNN(conv_dims, conv_sizes, dense_sizes, n_classes)
        itr = dataset_utils.get_supervised_batch_noise_iterator(
            x_shape, y_shape)
        x, y = itr.next()
        _, _ = model.get_loss(x, y)

        loss_fn = ci.make_loss_fn(model, None)
        grad_fn = ci.make_grad_fn(model)
        map_grad_fn = ci.make_map_grad_fn(model)

        with tf.GradientTape(persistent=True) as tape:
            # First estimate the Hessian using training data from itr.
            with tf.GradientTape() as tape_inner:
                loss = tf.reduce_mean(loss_fn(x, y))
            grads = grad_fn(loss, tape_inner)
            concat_grads = tf.concat([tf.reshape(w, [-1, 1]) for w in grads],
                                     0)
            hessian_mapped = map_grad_fn(concat_grads, tape)
            # hessian_mapped is a list of n_params x model-shaped tensors
            # should just be able to flat_concat it
            hessian = tensor_utils.flat_concat(hessian_mapped)
        eigs, _ = tf.linalg.eigh(hessian)
        largest_ev, smallest_ev = eigs[-1], eigs[0]

        # The exact eigenvalues depend on the random data, so compare the
        # estimates against the eigenvalues computed directly from the Hessian,
        # using a loose tolerance.
        est_largest_ev = eigenvalues.estimate_largest_ev(model,
                                                         1000,
                                                         itr,
                                                         loss_fn,
                                                         grad_fn,
                                                         map_grad_fn,
                                                         burnin=100)
        est_smallest_ev = eigenvalues.estimate_smallest_ev(largest_ev,
                                                           model,
                                                           1000,
                                                           itr,
                                                           loss_fn,
                                                           grad_fn,
                                                           map_grad_fn,
                                                           burnin=100)
        self.assertAllClose(largest_ev, est_largest_ev, 0.5)
        self.assertAllClose(smallest_ev, est_smallest_ev, 0.5)
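# estimate_largest_ev and estimate_smallest_ev are only exercised above, not
# shown. Given that estimate_smallest_ev takes the largest eigenvalue as its
# first argument, a plausible reading is power iteration plus a shifted power
# iteration; a minimal standalone sketch on an explicit symmetric matrix
# (hypothetical helper names - the functions in `eigenvalues` presumably work
# through Hessian-vector products rather than an explicit Hessian):
import numpy as np

def power_iteration_largest(a, n_iters=1000):
  # Repeatedly apply `a` and renormalize; the Rayleigh quotient converges to
  # the largest-magnitude eigenvalue.
  v = np.random.normal(size=a.shape[0])
  for _ in range(n_iters):
    v = a @ v
    v /= np.linalg.norm(v)
  return v @ a @ v

def power_iteration_smallest(a, largest_ev, n_iters=1000):
  # Power iteration on (largest_ev * I - a) recovers a's smallest eigenvalue.
  shifted = largest_ev * np.eye(a.shape[0]) - a
  return largest_ev - power_iteration_largest(shifted, n_iters)

a = np.array([[2.0, 0.3], [0.3, 1.0]])
l_max = power_iteration_largest(a)          # ~2.08
l_min = power_iteration_smallest(a, l_max)  # ~0.92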
def get_influence_on_test_loss(test_batch,
                               train_batch,
                               itr_train,
                               clf,
                               approx_params=None,
                               lam=None):
    """Calculate influence of examples in train_batch on examples in test_batch.

    Args:
      test_batch (tensor): examples to calculate influence on.
      train_batch (tensor): examples to calculate influence of.
      itr_train (Iterator): where to sample data for estimating the Hessian.
      clf (Classifier): the classifier we are interested in influence for.
      approx_params (dict): optional parameters for LISSA optimization.
      lam (float): optional L2-regularization for Hessian estimation.
    Returns:
      predicted_loss_diffs (tensor): an n_test x n_train tensor whose entry
        (i, j) is the influence of train example j on test example i.
    """

    loss_fn = make_loss_fn(clf, None)
    reg_loss_fn = make_loss_fn(clf, lam)
    grad_fn = make_grad_fn(clf)
    map_grad_fn = make_map_grad_fn(clf)

    x, y = test_batch
    test_grad_loss_no_reg_val = get_loss_grads(x, y, loss_fn, map_grad_fn)

    if approx_params is None:
        approx_params = {}
    inverse_hvp = get_inverse_hvp(test_grad_loss_no_reg_val, itr_train,
                                  reg_loss_fn, grad_fn, map_grad_fn,
                                  **approx_params)

    xtr, ytr = train_batch
    train_grad_loss_val = get_loss_grads(xtr, ytr, loss_fn, map_grad_fn)
    predicted_loss_diffs = tf.matmul(
        tensor_utils.flat_concat(inverse_hvp),
        tensor_utils.flat_concat(train_grad_loss_val),
        transpose_b=True)
    return predicted_loss_diffs
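# A hedged usage sketch for get_influence_on_test_loss. Here approx_params is
# unpacked directly into get_inverse_hvp, so the valid keys are its keyword
# arguments; `test_batch`, `train_batch`, `train_itr` and `clf` are
# hypothetical placeholders, and the values shown are illustrative only.
lissa_params = {
    'scale': 25,              # should upper-bound the largest Hessian eigenvalue
    'damping': 0.01,          # small damping stabilizes the recursion
    'num_samples': 1,         # independent LISSA estimates to average
    'recursion_depth': 1000,  # LISSA iterations per estimate
    'print_iter': 100,        # how often to log progress
}
influences = get_influence_on_test_loss(test_batch, train_batch, train_itr,
                                        clf, approx_params=lissa_params,
                                        lam=0.01)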
  def hessian_vector_product(vec):
    """Takes the Hessian-vector product (HVP) of vec with model.

    Args:
      vec (vector): a (possibly batched) vector.
    Returns:
      flat_v_hvp: the (possibly batched) flattened product H(model) * vec,
        averaged over n_samples Hessian estimates.
    """
    weight_shaped_vector = tensor_utils.reshape_vector_as(model.weights, vec)

    v_hvp_total = 0.
    for _ in range(n_samples):
      v_hvp = ci.hvp(weight_shaped_vector, itr,
                     loss_fn, grad_fn, map_grad_fn)
      flat_v_hvp = tensor_utils.flat_concat(v_hvp)
      v_hvp_total += flat_v_hvp

    return v_hvp_total / float(n_samples)
def get_ihvp_conjugate_gradient(vec, itr, loss_function, gradient_function,
                                map_gradient_function, approx_params):
  """Calculate the inverse HVP of vec with the Hessian of loss_function.

  Note: HVP stands for Hessian-vector product.
  Let n be the number of examples in this batch.
  Let p be the number of parameters in the model which defines loss_function.
  Let the model's parameter output (model.weights) have shapes
    [s0, s1, ... s_k].
  Then vec is a list of tensors with shapes [(n, s0), (n, s1) ... (n, s_k)].

  Uses the Scipy implementation of conjugate gradient descent.

  Args:
    vec (list of tensors): the vector in our HVP, shape described above.
    itr (Iterator): an iterator of data we will use to estimate the Hessian.
    loss_function (function): a function which returns a vector of per-example
                              losses.
    gradient_function (function): a function which takes the gradient of a
                                  scalar loss.
    map_gradient_function (function): a function which takes the gradient of
                                      each element of a vector of losses.
    approx_params (dict): parameters for conjugate gradient optimization.
  Returns:
    conjugate_gradient_result (tensor): an (n x p) tensor containing the
                                        desired IHVP.
    warning_flag (int): the Scipy convergence flag indicating whether the
                        optimization terminated successfully.
  """
  # Reshape/concatenate our input so concat_vec has shape (n, p).
  concat_vec = tensor_utils.flat_concat(vec)
  def get_hvp(v):
    """Get the Hessian-vector product of v and Hessian in loss_function.

    Args:
      v (vector): a (n * p,)-shaped vector we want to multiply with H.
    Returns:
      hvp (vector): a (n, p)-shaped matrix representing Hv.
    """
    v = tf.reshape(v, concat_vec.shape)
    v = tensor_utils.reshape_vector_as([el[0] for el in vec], v)
    v_hvp = calculate_influence.hvp(v, itr,
                                    loss_function, gradient_function,
                                    map_gradient_function,
                                    n_samples=approx_params['hvp_samples'])
    if approx_params['squared']:
      v_hvp = calculate_influence.hvp(v_hvp, itr,
                                      loss_function, gradient_function,
                                      map_gradient_function,
                                      n_samples=approx_params['hvp_samples'])
    return tensor_utils.flat_concat(v_hvp)

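  # Conjugate gradient solves H u = concat_vec (row-wise) by minimizing the
  # quadratic 0.5 * u^T H u - concat_vec^T u, whose gradient is
  # H u - concat_vec and whose minimizer is H^{-1} concat_vec. The callbacks
  # below expose this objective, its gradient, and its Hessian-vector products
  # in the flattened form expected by the Scipy routine.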
  def objective(v):
    flat_v_hvp = get_hvp(v)
    v = tf.reshape(v, concat_vec.shape)
    objective_value = (0.5 * tf.reduce_sum(tf.multiply(v, flat_v_hvp), axis=1)
                       - tf.reduce_sum(tf.multiply(concat_vec, v), axis=1))
    logging.info('Evaluating objective: obj = {:.3f}'.
                 format(tf.reduce_mean(objective_value).numpy()))
    return tf.reduce_mean(objective_value)

  def objective_gradient(v):
    flat_v_hvp = get_hvp(v)
    grads = flat_v_hvp - concat_vec
    logging.info('Evaluating gradients: norm(grads) = {:.3f}'.
                 format(tf.linalg.norm(grads).numpy()))
    return tf.reshape(grads, [-1])

  def hessian_vector_product(_, v):
    s = time.time()
    hvp = get_hvp(v)
    t = time.time()
    logging.info('Evaluating Hessian: norm(hvp) = {:.3f} ({:.3f} seconds)'.
                 format(tf.linalg.norm(hvp).numpy(), t - s))
    return tf.reshape(hvp, [-1])

  def callback(v):
    hvp = get_hvp(v)
    err = tf.reduce_mean(tf.math.reduce_euclidean_norm(hvp - concat_vec,
                                                       axis=1))
    logging.info('Current error is: {:.4f}; Obj = {:.4f}, vnorm = {:.4f}'
                 .format(err.numpy(), objective(v).numpy(),
                         tf.math.reduce_euclidean_norm(v).numpy()))

  x_init = tf.reshape(concat_vec, [-1])
  conjugate_gradient_result, warning_flag = conjugate_gradient_optimize(
      objective, x_init, objective_gradient, hessian_vector_product,
      callback, maxiter=approx_params['maxiter'],
      tol=approx_params['tol'])
  conjugate_gradient_result = tf.reshape(conjugate_gradient_result,
                                         concat_vec.shape)
  return conjugate_gradient_result, warning_flag