def get_inverse_hvp(v, itr_train, loss_fn, grad_fn, map_grad_fn, scale=10,
                    damping=0.0, num_samples=1, recursion_depth=10000,
                    print_iter=1):
  """Calculates an HVP, getting the inverse Hessian using the LiSSA method.

  Args:
    v (list of Tensors): the vector in the HVP - the gradient on the test
      example, one tensor per model parameter.
    itr_train (Iterator): iterator for getting batches to estimate H.
    loss_fn (function): a function which returns a vector of per-example
      losses.
    grad_fn (function): a function which takes the gradient of a scalar loss.
    map_grad_fn (function): a function which takes the gradient of each
      element of a vector of losses.
    scale (number): scales eigenvalues of Hessian to be <= 1.
    damping (number): in [0, 1], LiSSA parameter - higher is more stable.
    num_samples (int): how many times to estimate the HVP.
    recursion_depth (int): how many steps in LiSSA optimization.
    print_iter (int): how frequently to print LiSSA updates.

  Returns:
    inverse_hvp (list of Tensors): the estimated product of the inverse
      Hessian (of the loss defined by loss_fn) and v.
  """
  value_constraints = [
      (scale, 'scale', lambda x: x > 0, 'greater than 0'),
      (damping, 'damping', lambda x: 0 <= x <= 1,
       'between 0 and 1 inclusive'),
      (num_samples, 'num_samples', lambda x: x > 0, 'greater than 0'),
      (recursion_depth, 'recursion_depth', lambda x: x > 0, 'greater than 0')
  ]
  for var, vname, constraint, msg in value_constraints:
    if not constraint(var):
      raise ValueError('{} should be {}'.format(vname, msg))

  inverse_hvp = [tf.zeros_like(b) for b in v]
  for _ in range(num_samples):
    cur_estimate = v
    logging.info('cur estimate: %s', str([c.shape for c in cur_estimate]))
    for j in range(recursion_depth):
      old_estimate = cur_estimate
      hessian_vector_val = hvp(cur_estimate, itr_train, loss_fn, grad_fn,
                               map_grad_fn)
      # LiSSA update: cur = v + (1 - damping) * cur - (H @ cur) / scale.
      cur_estimate = [a + (1 - damping) * b - c / scale
                      for (a, b, c) in zip(v, cur_estimate,
                                           hessian_vector_val)]
      if (j % print_iter == 0) or (j == recursion_depth - 1):
        logging.info(
            'Recursion at depth %d: norm is %.8lf, diff is %.8lf', j,
            np.linalg.norm(tensor_utils.flat_concat(cur_estimate).numpy()),
            np.linalg.norm(
                tensor_utils.flat_concat(cur_estimate).numpy() -
                tensor_utils.flat_concat(old_estimate).numpy()))
    inverse_hvp = [a + b / scale
                   for (a, b) in zip(inverse_hvp, cur_estimate)]
  inverse_hvp = [a / num_samples for a in inverse_hvp]
  return inverse_hvp
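
# Illustrative sketch only: one plausible way to call get_inverse_hvp for a
# batch of test examples, assuming the make_loss_fn / make_grad_fn /
# make_map_grad_fn / get_loss_grads helpers from this module and a trained
# classifier `clf`. The LiSSA hyperparameter values are placeholders, not
# recommended settings.
def example_lissa_ihvp(clf, test_batch, itr_train):
  loss_fn = make_loss_fn(clf, None)
  grad_fn = make_grad_fn(clf)
  map_grad_fn = make_map_grad_fn(clf)
  x, y = test_batch
  test_grads = get_loss_grads(x, y, loss_fn, map_grad_fn)
  # scale / damping / recursion_depth are problem dependent; these values
  # are assumptions for the sake of the example.
  return get_inverse_hvp(test_grads, itr_train, loss_fn, grad_fn,
                         map_grad_fn, scale=25, damping=0.01,
                         num_samples=4, recursion_depth=1000)
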
def get_parameter_influence(model, x, y, itr, approx_params=None,
                            damping=None):
  """Estimate the influence of test examples (x, y) on the parameters of model.

  Args:
    model (Classifier): a classification model whose parameters we are
      interested in.
    x (tensor): the input data whose influence we are interested in.
    y (tensor): the target data whose influence we are interested in.
    itr (Iterator): an iterator of data we will use to estimate the Hessian.
    approx_params (dict, optional): parameters for the conjugate gradient
      optimization.
    damping (float, optional): the amount of L2-regularization to add to the
      parameters of model (only used for conjugate gradient).

  Returns:
    ihvp_result (tensor): the HVP of the inverse Hessian of model (possibly
      with some L2-regularization) with the gradient of (x, y) w.r.t the
      parameters of model.
    concat_grads (tensor): the gradients of (x, y) w.r.t the model parameters.
    warning_flag (int): a flag representing if this optimization terminated
      successfully, returned by Scipy.
  """
  loss_function = calculate_influence.make_loss_fn(model, damping)
  gradient_function = calculate_influence.make_grad_fn(model)
  map_gradient_function = calculate_influence.make_map_grad_fn(model)
  grads = calculate_influence.get_loss_grads(x, y, loss_function,
                                             map_gradient_function)
  concat_grads = tensor_utils.flat_concat(grads)
  ihvp_result, warning_flag = get_ihvp_conjugate_gradient(
      grads, itr, loss_function, gradient_function, map_gradient_function,
      approx_params)
  return ihvp_result, concat_grads, warning_flag
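
# Illustrative sketch only: a hypothetical call to get_parameter_influence.
# The approx_params keys mirror those read by get_ihvp_conjugate_gradient
# below ('hvp_samples', 'squared', 'maxiter', 'tol'); the specific values,
# the iterator, and the damping strength are assumptions.
def example_parameter_influence(model, test_batch, itr):
  x, y = test_batch
  approx_params = {'hvp_samples': 1, 'squared': False,
                   'maxiter': 100, 'tol': 1e-5}
  ihvp, concat_grads, warning_flag = get_parameter_influence(
      model, x, y, itr, approx_params=approx_params, damping=0.01)
  if warning_flag != 0:
    logging.info('Conjugate gradient did not report clean convergence.')
  return ihvp, concat_grads
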
def test_reshape_and_flatten(self):
  x_shape = (64, 28, 28, 1)
  y_shape = (64, 10)
  conv_dims = [20, 10]
  conv_sizes = [5, 5]
  dense_sizes = [100]
  n_classes = 10
  model = classifier.CNN(conv_dims, conv_sizes, dense_sizes, n_classes,
                         onehot=True)
  itr = dataset_utils.get_supervised_batch_noise_iterator(x_shape, y_shape)
  x, y = itr.next()
  _, _ = model.get_loss(x, y)
  w = model.weights
  num_wts = sum([tf.size(wt) for wt in w])
  v = tf.random.normal((10, num_wts))
  v_as_wts = tensor_utils.reshape_vector_as(w, v)
  for i in range(len(v_as_wts)):
    self.assertEqual(v_as_wts[i].shape[1:], w[i].shape)
  v_as_vec = tensor_utils.flat_concat(v_as_wts)
  self.assertAllClose(v, v_as_vec)
def test_largest_and_smallest_eigenvalue_estimation_correct(self):
  tf.compat.v1.random.set_random_seed(0)
  x_shape = (10, 5)
  y_shape = (10, 1)
  conv_dims = []
  conv_sizes = []
  dense_sizes = [5]
  n_classes = 3
  model = classifier.CNN(conv_dims, conv_sizes, dense_sizes, n_classes)
  itr = dataset_utils.get_supervised_batch_noise_iterator(x_shape, y_shape)
  x, y = itr.next()
  _, _ = model.get_loss(x, y)
  loss_fn = ci.make_loss_fn(model, None)
  grad_fn = ci.make_grad_fn(model)
  map_grad_fn = ci.make_map_grad_fn(model)
  with tf.GradientTape(persistent=True) as tape:
    # First compute the exact Hessian using training data from itr.
    with tf.GradientTape() as tape_inner:
      loss = tf.reduce_mean(loss_fn(x, y))
    grads = grad_fn(loss, tape_inner)
    concat_grads = tf.concat([tf.reshape(w, [-1, 1]) for w in grads], 0)
    hessian_mapped = map_grad_fn(concat_grads, tape)
  # hessian_mapped is a list of n_params x model-shaped tensors, so we can
  # flat_concat it into the full (n_params, n_params) Hessian.
  hessian = tensor_utils.flat_concat(hessian_mapped)
  eigs, _ = tf.linalg.eigh(hessian)
  largest_ev, smallest_ev = eigs[-1], eigs[0]
  # The estimates should match the exact extreme eigenvalues of the
  # explicitly computed Hessian, up to a loose tolerance.
  est_largest_ev = eigenvalues.estimate_largest_ev(
      model, 1000, itr, loss_fn, grad_fn, map_grad_fn, burnin=100)
  est_smallest_ev = eigenvalues.estimate_smallest_ev(
      largest_ev, model, 1000, itr, loss_fn, grad_fn, map_grad_fn,
      burnin=100)
  self.assertAllClose(largest_ev, est_largest_ev, 0.5)
  self.assertAllClose(smallest_ev, est_smallest_ev, 0.5)
def get_influence_on_test_loss(test_batch, train_batch, itr_train, clf,
                               approx_params=None, lam=None):
  """Calculate influence of examples in train_batch on examples in test_batch.

  Args:
    test_batch (tuple of tensors): (x, y) examples to calculate influence on.
    train_batch (tuple of tensors): (x, y) examples to calculate influence of.
    itr_train (Iterator): where to sample data for estimating Hessian from.
    clf (Classifier): the classifier we are interested in influence for.
    approx_params (dict): optional parameters for LiSSA optimization.
    lam (float): optional L2-regularization for Hessian estimation.

  Returns:
    predicted_loss_diffs (tensor): n_test x n_train tensor, with entry (i, j)
      the influence of train example j on test example i.
  """
  loss_fn = make_loss_fn(clf, None)
  reg_loss_fn = make_loss_fn(clf, lam)
  grad_fn = make_grad_fn(clf)
  map_grad_fn = make_map_grad_fn(clf)

  x, y = test_batch
  test_grad_loss_no_reg_val = get_loss_grads(x, y, loss_fn, map_grad_fn)

  if approx_params is None:
    approx_params = {}
  inverse_hvp = get_inverse_hvp(test_grad_loss_no_reg_val, itr_train,
                                reg_loss_fn, grad_fn, map_grad_fn,
                                **approx_params)

  xtr, ytr = train_batch
  train_grad_loss_val = get_loss_grads(xtr, ytr, loss_fn, map_grad_fn)
  predicted_loss_diffs = tf.matmul(
      tensor_utils.flat_concat(inverse_hvp),
      tensor_utils.flat_concat(train_grad_loss_val),
      transpose_b=True)
  return predicted_loss_diffs
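
# Illustrative sketch only: computing an n_test x n_train influence matrix
# with get_influence_on_test_loss. The batch sources, the LiSSA settings in
# approx_params (forwarded as keyword arguments to get_inverse_hvp), and the
# regularization strength are assumptions for the sake of the example.
def example_influence_matrix(clf, itr_test, itr_train):
  test_batch = itr_test.next()
  train_batch = itr_train.next()
  approx_params = {'scale': 25, 'damping': 0.01, 'num_samples': 4,
                   'recursion_depth': 1000}
  return get_influence_on_test_loss(test_batch, train_batch, itr_train, clf,
                                    approx_params=approx_params, lam=0.01)
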
def hessian_vector_product(vec):
  """Takes the Hessian-vector product (HVP) of vec with model.

  Args:
    vec (vector): a (possibly batched) vector.

  Returns:
    flat_v_hvp: a (possibly batched) HVP H(model) @ vec, averaged over
      n_samples Hessian estimates.
  """
  weight_shaped_vector = tensor_utils.reshape_vector_as(model.weights, vec)
  v_hvp_total = 0.
  for _ in range(n_samples):
    v_hvp = ci.hvp(weight_shaped_vector, itr, loss_fn, grad_fn, map_grad_fn)
    flat_v_hvp = tensor_utils.flat_concat(v_hvp)
    v_hvp_total += flat_v_hvp
  return v_hvp_total / float(n_samples)
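
# Illustrative sketch only: generic power iteration built on top of a closure
# like hessian_vector_product above. This is a standard way to estimate the
# largest eigenvalue from HVPs alone; it is not necessarily how
# eigenvalues.estimate_largest_ev is implemented, and `hvp_fn`, `dim`, and
# `num_steps` are hypothetical names and values.
def example_power_iteration(hvp_fn, dim, num_steps=100):
  v = tf.random.normal((1, dim))
  v /= tf.linalg.norm(v)
  for _ in range(num_steps):
    hv = hvp_fn(v)              # batched HVP, shape (1, dim)
    v = hv / tf.linalg.norm(hv)
  # Rayleigh quotient v^T H v / v^T v gives the eigenvalue estimate.
  return tf.reduce_sum(v * hvp_fn(v)) / tf.reduce_sum(v * v)
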
def get_ihvp_conjugate_gradient(vec, itr, loss_function, gradient_function,
                                map_gradient_function, approx_params):
  """Calculate the inverse HVP of vec with the Hessian of loss_function.

  Note: HVP stands for Hessian-vector product.
  Let n be the number of examples in this batch. Let p be the number of
  parameters in the model which defines loss_function. Let the model's
  parameter output (model.weights) have shapes [s0, s1, ... s_k]. Then vec is
  a list of tensors with shapes [(n, s0), (n, s1) ... (n, s_k)].

  Uses the Scipy implementation of conjugate gradient descent.

  Args:
    vec (list of tensors): the vector in our HVP, shape described above.
    itr (Iterator): an iterator of data we will use to estimate the Hessian.
    loss_function (function): a function which returns a vector of
      per-example losses.
    gradient_function (function): a function which takes the gradient of a
      scalar loss.
    map_gradient_function (function): a function which takes the gradient of
      each element of a vector of losses.
    approx_params (dict): parameters for conjugate gradient optimization;
      must contain 'hvp_samples', 'squared', 'maxiter' and 'tol'.

  Returns:
    conjugate_gradient_result (tensor): a (n x p) tensor containing the
      desired IHVP.
    warning_flag (int): the Scipy termination flag for the optimization.
  """
  # Reshape/concatenate our input so concat_vec has shape (n, p).
  concat_vec = tensor_utils.flat_concat(vec)

  def get_hvp(v):
    """Get the Hessian-vector product of v and Hessian in loss_function.

    Args:
      v (vector): a (n * p,)-shaped vector we want to multiply with H.

    Returns:
      hvp (tensor): an (n, p)-shaped tensor representing Hv.
    """
    v = tf.reshape(v, concat_vec.shape)
    v = tensor_utils.reshape_vector_as([el[0] for el in vec], v)
    v_hvp = calculate_influence.hvp(
        v, itr, loss_function, gradient_function, map_gradient_function,
        n_samples=approx_params['hvp_samples'])
    if approx_params['squared']:
      # Optionally apply the Hessian a second time, giving (H^2) v.
      v_hvp = calculate_influence.hvp(
          v_hvp, itr, loss_function, gradient_function,
          map_gradient_function, n_samples=approx_params['hvp_samples'])
    return tensor_utils.flat_concat(v_hvp)

  def objective(v):
    flat_v_hvp = get_hvp(v)
    v = tf.reshape(v, concat_vec.shape)
    # Quadratic objective 0.5 * v^T H v - vec^T v, minimized at v = H^{-1} vec.
    objective_value = (
        0.5 * tf.reduce_sum(tf.multiply(v, flat_v_hvp), axis=1) -
        tf.reduce_sum(tf.multiply(concat_vec, v), axis=1))
    logging.info('Evaluating objective: obj = {:.3f}'.format(
        tf.reduce_mean(objective_value).numpy()))
    return tf.reduce_mean(objective_value)

  def objective_gradient(v):
    flat_v_hvp = get_hvp(v)
    grads = flat_v_hvp - concat_vec
    logging.info('Evaluating gradients: norm(grads) = {:.3f}'.format(
        tf.linalg.norm(grads).numpy()))
    return tf.reshape(grads, [-1])

  def hessian_vector_product(_, v):
    s = time.time()
    hvp = get_hvp(v)
    t = time.time()
    logging.info('Evaluating Hessian: norm(hvp) = {:.3f} ({:.3f} seconds)'
                 .format(tf.linalg.norm(hvp).numpy(), t - s))
    return tf.reshape(hvp, [-1])

  def callback(v):
    hvp = get_hvp(v)
    err = tf.reduce_mean(
        tf.math.reduce_euclidean_norm(hvp - concat_vec, axis=1))
    logging.info('Current error is: {:.4f}; Obj = {:.4f}, vnorm = {:.4f}'
                 .format(err.numpy(), objective(v).numpy(),
                         tf.math.reduce_euclidean_norm(v).numpy()))

  x_init = tf.reshape(concat_vec, [-1])
  conjugate_gradient_result, warning_flag = conjugate_gradient_optimize(
      objective, x_init, objective_gradient, hessian_vector_product,
      callback, maxiter=approx_params['maxiter'], tol=approx_params['tol'])
  conjugate_gradient_result = tf.reshape(conjugate_gradient_result,
                                         concat_vec.shape)
  return conjugate_gradient_result, warning_flag
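
# Illustrative sketch only: calling get_ihvp_conjugate_gradient directly,
# mirroring the wiring in get_parameter_influence above. `vec` is built in
# the expected shape (a list of per-example gradients, one (n, s_i)-shaped
# tensor per parameter tensor); setting 'squared': True would make get_hvp
# apply the Hessian twice so conjugate gradient solves against H^2 instead
# of H. The concrete approx_params values are assumptions.
def example_cg_ihvp(model, x, y, itr):
  loss_function = calculate_influence.make_loss_fn(model, None)
  gradient_function = calculate_influence.make_grad_fn(model)
  map_gradient_function = calculate_influence.make_map_grad_fn(model)
  vec = calculate_influence.get_loss_grads(x, y, loss_function,
                                           map_gradient_function)
  approx_params = {'hvp_samples': 1, 'squared': False,
                   'maxiter': 100, 'tol': 1e-5}
  ihvp, warning_flag = get_ihvp_conjugate_gradient(
      vec, itr, loss_function, gradient_function, map_gradient_function,
      approx_params)
  return ihvp, warning_flag
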