def calculate_influence_ood(params):
  """Calculates influence functions for pre-trained model with OOD classes.

  Args:
    params (dict): contains a number of params - as loaded from flags.
      Should contain:
      seed (int) - random seed for Tensorflow and Numpy initialization.
      training_results_dir (str) - parent directory of the pre-trained model.
      clf_name (str) - the name of the pre-trained model's directory.
      n_test_infl (int) - number of examples to run influence functions for.
      start_ix_test_infl (int) - index to start loading examples from.
      cg_maxiter (int) - max number of iterations for conjugate gradient.
      squared (bool) - whether to calculate squared Hessian directly.
      tol (float) - tolerance for conjugate gradient.
      lam (float) - L2 regularization amount for Hessian.
      hvp_samples (int) - number of samples to take in HVP estimation.
      output_dir (str) - where results should be written - defaults to
        training_results_dir/clf_name/influence_results.
      tname (str) - extra string to add to saved tensor names; can be ''.
      preloaded_model (model or None) - if None, we should load the model
        ourselves. Otherwise, preloaded_model is the model we are interested in.
      preloaded_itr (Iterator or None) - if None, load the data iterator
        ourselves; otherwise, use preloaded_itr as the data iterator.
  """

  tf.set_random_seed(params['seed'])
  np.random.seed(params['seed'])

  # Load a trained classifier.
  modeldir = os.path.join(params['training_results_dir'], params['clf_name'])
  param_file = os.path.join(modeldir, 'params.json')
  model_params = utils.load_json(param_file)

  if params['preloaded_model'] is None:
    ckpt_path = os.path.join(modeldir, 'ckpts/bestmodel-1')
    cnn_args = {
        'conv_dims': [int(x) for x in model_params['conv_dims'].split(',')],
        'conv_sizes': [int(x) for x in model_params['conv_sizes'].split(',')],
        'dense_sizes': [int(x) for x in model_params['dense_sizes'].split(',')],
        'n_classes': model_params['n_classes'],
        'onehot': True
    }
    model = utils.load_model(ckpt_path, classifier.CNN, cnn_args)
  else:
    model = params['preloaded_model']

  # Load validation/test/OOD examples.
  tensordir = os.path.join(modeldir, 'tensors')
  validation_x = utils.load_tensor(os.path.join(tensordir, 'valid_x_infl.npy'))
  test_x = utils.load_tensor(os.path.join(tensordir, 'test_x_infl.npy'))
  ood_x = utils.load_tensor(os.path.join(tensordir, 'ood_x_infl.npy'))

  # Get in- and out-of-distribution classes.
  n_labels = model_params['n_classes']
  all_classes = range(n_labels)
  ood_classes = ([int(x) for x in model_params['ood_classes'].split(',')]
                 if 'ood_classes' in model_params else [])
  ind_classes = [x for x in all_classes if x not in ood_classes]
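  # For example, with n_classes=10 and ood_classes='8,9' in params.json,
  # ood_classes == [8, 9] and ind_classes == [0, 1, ..., 7].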

  # Read the label noise level used when loading the training data iterator.
  label_noise = (model_params['label_noise']
                 if 'label_noise' in model_params else 0.)

  # We only look at a portion of the test set for computational reasons.
  ninfl = params['n_test_infl']
  start_ix = params['start_ix_test_infl']
  end_ix = start_ix + ninfl
  xinfl_validation = validation_x[start_ix: end_ix]
  xinfl_test = test_x[start_ix: end_ix]
  xinfl_ood = ood_x[start_ix: end_ix]

  # We want to rotate through all the label options.
  y_all = tf.concat([tf.one_hot(tf.fill((ninfl,), lab), depth=n_labels)
                     for lab in ind_classes], axis=0)
  y_all = tf.concat([y_all, y_all, y_all], axis=0)

  xinfl_validation_all = tf.concat([xinfl_validation for _ in ind_classes],
                                   axis=0)
  xinfl_test_all = tf.concat([xinfl_test for _ in ind_classes], axis=0)
  xinfl_ood_all = tf.concat([xinfl_ood for _ in ind_classes], axis=0)
  x_all = tf.concat([xinfl_validation_all, xinfl_test_all, xinfl_ood_all],
                    axis=0)
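  # x_all lines up with y_all: each block (validation, test, OOD) repeats the
  # same ninfl examples once per in-distribution label, so every example is
  # paired with every candidate label.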

  cg_approx_params = {'maxiter': params['cg_maxiter'],
                      'squared': params['squared'],
                      'tol': params['tol'],
                      'hvp_samples': params['hvp_samples']}

  # Here we run conjugate gradient one example at a time, collecting
  # the following outputs.

  # H^{-1}g
  infl_value = []
  # g * H^{-1}g (elementwise; its sum gives g^T H^{-1} g)
  infl_laplace = []
  # H^{-2}g
  infl_deriv = []
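  # (Note: up to sign, H^{-2}g is the derivative of H^{-1}g with respect to
  # the damping lam, since d/dlam (H + lam*I)^{-1} g = -(H + lam*I)^{-2} g.)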
  # g
  grads = []
  # When calculating H^{-1}g with conjugate gradient, SciPy returns a flag
  # denoting the optimization's success.
  warning_flags = []
  # When calculating H^{-2}g with conjugate gradient, SciPy returns a flag
  # denoting the optimization's success.
  warning_flags_deriv = []

  for i in range(x_all.shape[0]):
    logging.info('Example {:d}'.format(i))
    s = time.time()
    xi = tf.expand_dims(x_all[i], 0)
    yi = tf.expand_dims(y_all[i], 0)
    if params['preloaded_itr'] is None:
      # No preloaded iterator: rebuild the training iterator for this example.
      itr_train, _, _, _ = dataset_utils.load_dataset_ood_supervised_onehot(
          ind_classes, ood_classes, label_noise=label_noise)
    else:
      itr_train = params['preloaded_itr']
    infl_value_i, grads_i, warning_flag_i = get_parameter_influence(
        model, xi, yi, itr_train,
        approx_params=cg_approx_params,
        damping=params['lam'])
    t = time.time()
    logging.info('IHVP calculation took {:.3f} seconds'.format(t - s))
    infl_laplace_i = tf.multiply(infl_value_i, grads_i)

    infl_value_wtshape = tensor_utils.reshape_vector_as(model.weights,
                                                        infl_value_i)
    loss_function = calculate_influence.make_loss_fn(model, params['lam'])
    gradient_function = calculate_influence.make_grad_fn(model)
    map_gradient_function = calculate_influence.make_map_grad_fn(model)
    s = time.time()
    infl_deriv_i, warning_flag_deriv_i = get_ihvp_conjugate_gradient(
        infl_value_wtshape, itr_train,
        loss_function, gradient_function, map_gradient_function,
        approx_params=cg_approx_params)
    t = time.time()
    logging.info('Second IHVP calculation took {:.3f} seconds'.format(t - s))
    infl_value.append(infl_value_i)
    infl_laplace.append(infl_laplace_i)
    infl_deriv.append(infl_deriv_i)
    grads.append(grads_i)
    warning_flags.append(tf.expand_dims(warning_flag_i, 0))
    warning_flags_deriv.append(tf.expand_dims(warning_flag_deriv_i, 0))

  infl_value = tf.concat(infl_value, axis=0)
  infl_laplace = tf.concat(infl_laplace, axis=0)
  infl_deriv = tf.concat(infl_deriv, axis=0)
  grads = tf.concat(grads, axis=0)
  warning_flags = tf.concat(warning_flags, axis=0)
  warning_flags_deriv = tf.concat(warning_flags_deriv, axis=0)

  res = {}
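  # Each concatenated tensor is laid out as [validation | test | OOD], with
  # ninfl * len(ind_classes) rows per block (matching x_all above), so we can
  # slice the per-split results back out.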
  for infl_res, nm in [(infl_value, 'infl'),
                       (infl_deriv, 'deriv'),
                       (infl_laplace, 'laplace'),
                       (grads, 'grads'),
                       (warning_flags, 'warnflags'),
                       (warning_flags_deriv, 'warnflags_deriv')]:
    res['valid_{}'.format(nm)] = infl_res[:ninfl * len(ind_classes)]
    res['test_{}'.format(nm)] = infl_res[
        ninfl * len(ind_classes): 2 * ninfl * len(ind_classes)]
    res['ood_{}'.format(nm)] = infl_res[2 * ninfl * len(ind_classes):]

  # Save the results of these calculations.
  if params['output_dir']:
    resdir = utils.make_subdir(params['output_dir'], 'influence_results')
  else:
    resdir = utils.make_subdir(modeldir, 'influence_results')
  tensor_name_template = '{}{}-inv_hvp-cg-ix{:d}-ninfl{:d}' + (
      '_squared' if params['squared'] else '')
  infl_tensors = [
      (tensor_name_template.format(params['tname'], label, start_ix, ninfl),
       res[label]) for label in res.keys()]
  utils.save_tensors(infl_tensors, resdir)


def run(params):
    """Calculates influence functions for a pre-trained model.

    Args:
      params (dict): contains a number of params - as loaded from flags.
        Should contain:
          seed (int) - random seed for Tensorflow and Numpy initialization.
          training_results_dir (str) - parent directory of the pre-trained
            model.
          clf_name (str) - the name of the pre-trained model's directory.
          n_test_infl (int) - number of test examples to run influence
            functions for.
          n_train_infl (int) - number of train examples to run influence
            functions for.
          lissa_recursion_depth (int) - how long to run LiSSA for.
          output_dir (str) - where results should be written - defaults to
            training_results_dir/clf_name/influence_results.
    """
    tf.set_random_seed(params['seed'])
    np.random.seed(params['seed'])

    # Load a trained classifier.
    modeldir = os.path.join(params['training_results_dir'], params['clf_name'])
    ckpt_path = os.path.join(modeldir, 'ckpts/bestmodel-1')
    param_file = os.path.join(modeldir, 'params.json')
    clf_params = utils.load_json(param_file)

    cnn_args = {
        'conv_dims': [int(x) for x in clf_params['conv_dims'].split(',')],
        'conv_sizes': [int(x) for x in clf_params['conv_sizes'].split(',')],
        'dense_sizes': [int(x) for x in clf_params['dense_sizes'].split(',')],
        'n_classes': 10,  # this pipeline assumes a 10-class dataset
        'onehot': True
    }
    clf = utils.load_model(ckpt_path, classifier.CNN, cnn_args)

    # Load train/test examples.
    tensordir = os.path.join(modeldir, 'tensors')
    train_x = utils.load_tensor(os.path.join(tensordir, 'train_x_infl.npy'))
    train_y = utils.load_tensor(os.path.join(tensordir, 'train_y_infl.npy'))
    test_x = utils.load_tensor(os.path.join(tensordir, 'test_x_infl.npy'))
    test_y = utils.load_tensor(os.path.join(tensordir, 'test_y_infl.npy'))

    # Calculate influence functions for this model.
    if params['output_dir']:
        resdir = utils.make_subdir(params['output_dir'], 'influence_results')
    else:
        resdir = utils.make_subdir(modeldir, 'influence_results')

    itr_train, _, _ = dataset_utils.load_dataset_supervised_onehot()
    bx, by = test_x[:params['n_test_infl']], test_y[:params['n_test_infl']]
    test_loss, _ = clf.get_loss(bx, by)
    logging.info('Current test loss: {:.3f}'.format(
        tf.reduce_mean(test_loss)))
    bx_tr, by_tr = (train_x[:params['n_train_infl']],
                    train_y[:params['n_train_infl']])
    predicted_loss_diffs = get_influence_on_test_loss(
        (bx, by), (bx_tr, by_tr),
        itr_train,
        clf,
        approx_params={
            'scale': 100,  # scales the Hessian so the LiSSA recursion converges
            'damping': 0.1,  # damping added to the Hessian for stability
            'num_samples': 1,  # independent LiSSA estimates to average
            'recursion_depth': params['lissa_recursion_depth'],
            'print_iter': 10  # log progress every 10 iterations
        },
        lam=0.01)
    utils.save_tensors([('predicted_loss_diffs', predicted_loss_diffs)],
                       resdir)

    # predicted_loss_diffs has one row per test example and one column per
    # train example.
    d = predicted_loss_diffs.numpy()
    df = d.flatten()
    # Sort (influence, flat index) pairs by influence magnitude, descending.
    dfsort = sorted(zip(df, range(len(df))),
                    key=lambda x: abs(x[0]),
                    reverse=True)

    # Ordering note: d is flattened row-major, so df[:10] == d[0, :10].
    # Loop over the highest influence pairs.
    imgs_to_save = []
    titles = []
    n_img_pairs = 10
    for inf_val, idx in dfsort[:n_img_pairs]:
        # Map the flat index back to a (test index, train index) pair in d.
        i, tr_i = convert_index(idx, d)
        imgs_to_save += [bx[i, :, :, 0], bx_tr[tr_i, :, :, 0]]
        titles += [
            'Test Image: inf={:.8f}'.format(inf_val),
            'Train Image: inf={:.8f}'.format(inf_val)
        ]
    utils.save_images(os.path.join(resdir, 'most_influential.pdf'),
                      imgs_to_save,
                      n_img_pairs,
                      titles=titles)

    # For each test image, find the training image with the highest influence
    imgs_to_save = []
    titles = []
    n_img_pairs = 10  # Do this many pairs from the top of the test tensor
    for i in range(n_img_pairs):
        infl_i = d[i, :]
        max_train_ind = np.argmax(infl_i)
        max_infl_val = infl_i[max_train_ind]
        imgs_to_save += [bx[i, :, :, 0], bx_tr[max_train_ind, :, :, 0]]
        titles += [
            'Test Image {:d}: inf={:.8f}'.format(i, max_infl_val),
            'Train Image {:d}: inf={:.8f}'.format(max_train_ind, max_infl_val)
        ]
    utils.save_images(os.path.join(resdir, 'most_influential_by_test_img.pdf'),
                      imgs_to_save,
                      n_img_pairs,
                      titles=titles)
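

# A minimal usage sketch for run(), with hypothetical values:
#   run({'seed': 0, 'training_results_dir': '/tmp/results', 'clf_name': 'cnn',
#        'n_test_infl': 10, 'n_train_infl': 100, 'lissa_recursion_depth': 50,
#        'output_dir': ''})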