Example #1
0
    def test_reshape_and_flatten(self):
        """Checks that reshape_vector_as and flat_concat invert each other.

        A small CNN is built and evaluated once on a noise batch before its
        weights are read. A random (10, num_wts) matrix is then split into
        per-weight pieces with tensor_utils.reshape_vector_as. Every piece
        must match its weight's shape after the first axis, and
        tensor_utils.flat_concat must reassemble the original matrix.
        """
        input_shape = (64, 28, 28, 1)
        label_shape = (64, 10)
        cnn = classifier.CNN([20, 10],  # conv_dims
                             [5, 5],    # conv_sizes
                             [100],     # dense_sizes
                             10,        # n_classes
                             onehot=True)
        batches = dataset_utils.get_supervised_batch_noise_iterator(
            input_shape, label_shape)
        inputs, labels = batches.next()
        _, _ = cnn.get_loss(inputs, labels)

        weights = cnn.weights
        total_size = sum([tf.size(wt) for wt in weights])
        flat = tf.random.normal((10, total_size))
        pieces = tensor_utils.reshape_vector_as(weights, flat)

        for idx, piece in enumerate(pieces):
            self.assertEqual(piece.shape[1:], weights[idx].shape)

        rebuilt = tensor_utils.flat_concat(pieces)
        self.assertAllClose(flat, rebuilt)
    def get_hvp(v):
        """Multiplies a flat vector by the Hessian of loss_function.

        v is reshaped to concat_vec.shape and split into tensors shaped like
        the first entry of each element of vec. calculate_influence.hvp is
        then applied once, or twice (feeding the first product back in) when
        approx_params['squared'] is set. The result is flattened with
        tensor_utils.flat_concat.

        Args:
          v (vector): flat vector to multiply with H; presumably
            (n * p,)-shaped -- TODO confirm against concat_vec.
        Returns:
          The flattened Hessian-vector product (Hessian applied twice when
          approx_params['squared'] is truthy).
        """
        def apply_hessian(direction):
            return calculate_influence.hvp(
                direction,
                itr,
                loss_function,
                gradient_function,
                map_gradient_function,
                n_samples=approx_params['hvp_samples'])

        reshaped = tf.reshape(v, concat_vec.shape)
        shaped = tensor_utils.reshape_vector_as([el[0] for el in vec],
                                                reshaped)
        product = apply_hessian(shaped)
        if approx_params['squared']:
            # Apply the Hessian a second time to the first product.
            product = apply_hessian(product)
        return tensor_utils.flat_concat(product)
  def hessian_vector_product(vec):
    """Estimates H(model) * vec by averaging repeated HVP samples.

    Args:
      vec (vector): a (possibly batched) flat vector; it is split into
        tensors shaped like model.weights before the product is taken.
    Returns:
      The mean over n_samples of the flattened ci.hvp results, i.e. a
      (possibly batched) estimate of H(model) * vec.
    """
    vec_like_weights = tensor_utils.reshape_vector_as(model.weights, vec)

    def one_sample():
      return tensor_utils.flat_concat(
          ci.hvp(vec_like_weights, itr, loss_fn, grad_fn, map_grad_fn))

    # Starting sum() at 0. mirrors an explicit accumulator initialised to 0.
    total = sum((one_sample() for _ in range(n_samples)), 0.)
    return total / float(n_samples)
def calculate_influence_ood(params):
  """Calculates influence functions for pre-trained model with OOD classes.

  A slice of n_test_infl examples is taken from each of the saved
  validation, test and OOD tensors. Each example is paired with every
  in-distribution label in turn. For each (example, label) pair, conjugate
  gradient is used to compute H^{-1}g, and then a second inverse-HVP is
  applied to H^{-1}g. The results are split back into validation/test/OOD
  parts and saved as tensors.

  Args:
    params (dict): contains a number of params - as loaded from flags.
    Should contain:
      seed (int) - random seed for Tensorflow and Numpy initialization.
      training_results_dir (str) - parent directory of the pre-trained model.
      clf_name (str) - the name of the pre-trained model's directory.
      n_test_infl (int) - number of examples to run influence functions for.
      start_ix_test_infl (int) - index to start loading examples from.
      cg_maxiter (int) - max number of iterations for conjugate gradient.
      squared (bool) - whether to calculate squared Hessian directly.
      tol (float) - tolerance for conjugate gradient.
      lam (float) - L2 regularization amount for Hessian.
      hvp_samples (int) - number of samples to take in HVP estimation.
      output_dir (str) - if set (truthy), results are written to
        output_dir/influence_results; otherwise to
        training_results_dir/clf_name/influence_results.
      tname (str) - extra string to add to saved tensor names; can be ''.
      preloaded_model (model or None) - if None, we should load the model
        ourselves. Otherwise, preloaded_model is the model we are interested in.
      preloaded_itr (Iterator or None) - if None, the training data iterator
        is loaded ourselves (once per example); otherwise, preloaded_itr is
        used as the data iterator for every example.

  Raises:
    ValueError: if fewer than n_test_infl examples exist from index
      start_ix_test_infl in the validation, test or OOD tensors. The label
      tensor assumes exactly n_test_infl examples per set, so a short slice
      would otherwise pair examples with the wrong labels.
  """

  tf.set_random_seed(params['seed'])
  np.random.seed(params['seed'])

  # Load a trained classifier.
  modeldir = os.path.join(params['training_results_dir'], params['clf_name'])
  param_file = os.path.join(modeldir, 'params.json')
  model_params = utils.load_json(param_file)

  if params['preloaded_model'] is None:
    ckpt_path = os.path.join(modeldir, 'ckpts/bestmodel-1')
    cnn_args = {'conv_dims':
                    [int(x) for x in model_params['conv_dims'].split(',')],
                'conv_sizes':
                    [int(x) for x in model_params['conv_sizes'].split(',')],
                'dense_sizes':
                    [int(x) for x in model_params['dense_sizes'].split(',')],
                'n_classes': model_params['n_classes'], 'onehot': True}
    model = utils.load_model(ckpt_path, classifier.CNN, cnn_args)
  else:
    model = params['preloaded_model']

  # Load validation/test/OOD examples.
  tensordir = os.path.join(modeldir, 'tensors')
  validation_x = utils.load_tensor(os.path.join(tensordir, 'valid_x_infl.npy'))
  test_x = utils.load_tensor(os.path.join(tensordir, 'test_x_infl.npy'))
  ood_x = utils.load_tensor(os.path.join(tensordir, 'ood_x_infl.npy'))

  # Get in- and out-of-distribution classes.
  n_labels = model_params['n_classes']
  ood_classes = ([int(x) for x in model_params['ood_classes'].split(',')]
                 if 'ood_classes' in model_params else [])
  ind_classes = [x for x in range(n_labels) if x not in ood_classes]
  n_ind = len(ind_classes)

  # Label noise used when loading the training iterator below.
  # Assumes load_json returns a dict (it is indexed and tested with `in` above).
  label_noise = model_params.get('label_noise', 0.)

  # We only look at a portion of the test set for computational reasons.
  ninfl = params['n_test_infl']
  start_ix = params['start_ix_test_infl']
  end_ix = start_ix + ninfl
  xinfl_validation = validation_x[start_ix: end_ix]
  xinfl_test = test_x[start_ix: end_ix]
  xinfl_ood = ood_x[start_ix: end_ix]

  # y_all below is built with exactly ninfl rows per label and set. A shorter
  # slice (start_ix + ninfl past the end of a tensor) would silently pair
  # examples with the wrong labels, so fail loudly instead.
  for set_name, xinfl in (('validation', xinfl_validation),
                          ('test', xinfl_test),
                          ('ood', xinfl_ood)):
    if xinfl.shape[0] != ninfl:
      raise ValueError(
          'Expected {:d} {} examples starting at index {:d}, found {}.'.format(
              ninfl, set_name, start_ix, xinfl.shape[0]))

  # We want to rotate through all the label options.
  y_all = tf.concat([tf.one_hot(tf.fill((ninfl,), lab), depth=n_labels)
                     for lab in ind_classes], axis=0)
  # One copy of the label block for each set: validation, test, OOD.
  y_all = tf.concat([y_all, y_all, y_all], axis=0)

  xinfl_validation_all = tf.concat([xinfl_validation] * n_ind, axis=0)
  xinfl_test_all = tf.concat([xinfl_test] * n_ind, axis=0)
  xinfl_ood_all = tf.concat([xinfl_ood] * n_ind, axis=0)
  x_all = tf.concat([xinfl_validation_all, xinfl_test_all, xinfl_ood_all],
                    axis=0)

  cg_approx_params = {'maxiter': params['cg_maxiter'],
                      'squared': params['squared'],
                      'tol': params['tol'],
                      'hvp_samples': params['hvp_samples']}

  # Here we run conjugate gradient one example at a time, collecting
  # the following outputs.

  # H^{-1}g
  infl_value = []
  # Elementwise product g * H^{-1}g.
  infl_laplace = []
  # H^{-2}g (second inverse-HVP applied to H^{-1}g).
  infl_deriv = []
  # g
  grads = []
  # When calculating H^{-1}g with conjugate gradient, Scipy returns a flag
  # denoting the optimization's success.
  warning_flags = []
  # When calculating H^{-2}g with conjugate gradient, Scipy returns a flag
  # denoting the optimization's success.
  warning_flags_deriv = []

  for i in range(x_all.shape[0]):
    logging.info('Example {:d}'.format(i))
    s = time.time()
    xi = tf.expand_dims(x_all[i], 0)
    yi = tf.expand_dims(y_all[i], 0)
    # NOTE(review): without a preloaded iterator, the training data is
    # reloaded for every example (presumably so each example starts from a
    # fresh iterator); this is costly -- confirm whether it is intended.
    if params['preloaded_itr'] is None:
      itr_train, _, _, _ = dataset_utils.load_dataset_ood_supervised_onehot(
          ind_classes, ood_classes, label_noise=label_noise)
    else:
      itr_train = params['preloaded_itr']
    infl_value_i, grads_i, warning_flag_i = get_parameter_influence(
        model, xi, yi, itr_train,
        approx_params=cg_approx_params,
        damping=params['lam'])
    t = time.time()
    logging.info('IHVP calculation took {:.3f} seconds'.format(t - s))
    infl_laplace_i = tf.multiply(infl_value_i, grads_i)

    infl_value_wtshape = tensor_utils.reshape_vector_as(model.weights,
                                                        infl_value_i)
    # NOTE(review): these depend only on model and params['lam']; they could
    # likely be built once outside the loop if make_*_fn capture no state
    # that changes after the first model call -- confirm before hoisting.
    loss_function = calculate_influence.make_loss_fn(model, params['lam'])
    gradient_function = calculate_influence.make_grad_fn(model)
    map_gradient_function = calculate_influence.make_map_grad_fn(model)
    s = time.time()
    infl_deriv_i, warning_flag_deriv_i = get_ihvp_conjugate_gradient(
        infl_value_wtshape, itr_train,
        loss_function, gradient_function, map_gradient_function,
        approx_params=cg_approx_params)
    t = time.time()
    logging.info('Second IHVP calculation took {:.3f} seconds'.format(t - s))
    infl_value.append(infl_value_i)
    infl_laplace.append(infl_laplace_i)
    infl_deriv.append(infl_deriv_i)
    grads.append(grads_i)
    warning_flags.append(tf.expand_dims(warning_flag_i, 0))
    warning_flags_deriv.append(tf.expand_dims(warning_flag_deriv_i, 0))

  infl_value = tf.concat(infl_value, axis=0)
  infl_laplace = tf.concat(infl_laplace, axis=0)
  infl_deriv = tf.concat(infl_deriv, axis=0)
  grads = tf.concat(grads, axis=0)
  warning_flags = tf.concat(warning_flags, axis=0)
  warning_flags_deriv = tf.concat(warning_flags_deriv, axis=0)

  # Rows are ordered [validation | test | OOD], each n_per_set long.
  n_per_set = ninfl * n_ind
  res = {}
  for infl_res, nm in [(infl_value, 'infl'),
                       (infl_deriv, 'deriv'),
                       (infl_laplace, 'laplace'),
                       (grads, 'grads'),
                       (warning_flags, 'warnflags'),
                       (warning_flags_deriv, 'warnflags_deriv')]:
    res['valid_{}'.format(nm)] = infl_res[:n_per_set]
    res['test_{}'.format(nm)] = infl_res[n_per_set: 2 * n_per_set]
    res['ood_{}'.format(nm)] = infl_res[2 * n_per_set:]

  # Save the results of these calculations.
  if params['output_dir']:
    resdir = utils.make_subdir(params['output_dir'], 'influence_results')
  else:
    resdir = utils.make_subdir(modeldir, 'influence_results')
  tensor_name_template = '{}{}-inv_hvp-cg-ix{:d}-ninfl{:d}' + (
      '_squared' if params['squared'] else '')
  infl_tensors = [
      (tensor_name_template.format(params['tname'], label, start_ix, ninfl),
       tensor) for label, tensor in res.items()]
  utils.save_tensors(infl_tensors, resdir)