def calculate_influence_ood(params):
    """Calculates influence functions for pre-trained model with OOD classes.

    A window of `n_test_infl` examples is sliced from each of the saved
    validation, test and OOD tensors. Every example is paired with each
    in-distribution label in turn, and conjugate gradient is run once per
    (example, label) pair. The per-pair outputs are concatenated, split back
    into validation / test / OOD groups and written with `utils.save_tensors`.
    Nothing is returned.

    Args:
        params (dict): contains a number of params - as loaded from flags.
            Should contain:
            seed (int) - random seed for Tensorflow and Numpy initialization.
            training_results_dir (str) - parent directory of the pre-trained
                model.
            clf_name (str) - the name of the pre-trained model's directory.
            n_test_infl (int) - number of examples to run influence functions
                for.
            start_ix_test_infl (int) - index to start loading examples from.
            cg_maxiter (int) - max number of iterations for conjugate
                gradient.
            squared (bool) - whether to calculate squared Hessian directly.
            tol (float) - tolerance for conjugate gradient.
            lam (float) - L2 regularization amount for Hessian.
            hvp_samples (int) - number of samples to take in HVP estimation.
            output_dir (str) - where results should be written - defaults to
                training_results_dir/clf_name/influence_results.
            tname (str) - extra string to add to saved tensor names; can
                be ''.
            preloaded_model (model or None) - if None, we should load the
                model ourselves. Otherwise, preloaded_model is the model we
                are interested in.
            preloaded_itr (Iterator or None) - if None, load the data iterator
                ourselves; otherwise, use preloaded_itr as the data iterator.
    """
    tf.set_random_seed(params['seed'])
    np.random.seed(params['seed'])

    # Load a trained classifier. params.json is read even when a model is
    # preloaded, because the class layout below comes from it.
    modeldir = os.path.join(params['training_results_dir'], params['clf_name'])
    param_file = os.path.join(modeldir, 'params.json')
    model_params = utils.load_json(param_file)
    if params['preloaded_model'] is None:
        ckpt_path = os.path.join(modeldir, 'ckpts/bestmodel-1')
        cnn_args = {
            'conv_dims': [
                int(x) for x in model_params['conv_dims'].split(',')],
            'conv_sizes': [
                int(x) for x in model_params['conv_sizes'].split(',')],
            'dense_sizes': [
                int(x) for x in model_params['dense_sizes'].split(',')],
            'n_classes': model_params['n_classes'],
            'onehot': True,
        }
        model = utils.load_model(ckpt_path, classifier.CNN, cnn_args)
    else:
        model = params['preloaded_model']

    # Load validation/test/OOD examples.
    tensordir = os.path.join(modeldir, 'tensors')
    validation_x = utils.load_tensor(
        os.path.join(tensordir, 'valid_x_infl.npy'))
    test_x = utils.load_tensor(os.path.join(tensordir, 'test_x_infl.npy'))
    ood_x = utils.load_tensor(os.path.join(tensordir, 'ood_x_infl.npy'))

    # Get in- and out-of-distribution classes.
    n_labels = model_params['n_classes']
    all_classes = range(n_labels)
    ood_classes = ([int(x) for x in model_params['ood_classes'].split(',')]
                   if 'ood_classes' in model_params else [])
    ind_classes = [x for x in all_classes if x not in ood_classes]

    # Label noise setting used when (re)loading the training iterator.
    label_noise = (model_params['label_noise']
                   if 'label_noise' in model_params else 0.)

    # We only look at a portion of each set for computational reasons.
    ninfl = params['n_test_infl']
    start_ix = params['start_ix_test_infl']
    end_ix = start_ix + ninfl
    xinfl_validation = validation_x[start_ix:end_ix]
    xinfl_test = test_x[start_ix:end_ix]
    xinfl_ood = ood_x[start_ix:end_ix]

    # We want to rotate through all the label options: for each
    # in-distribution label, a block of `ninfl` one-hot rows; that layout is
    # then repeated once per split (validation, test, OOD).
    y_all = tf.concat(
        [tf.one_hot(tf.fill((ninfl,), lab), depth=n_labels)
         for lab in ind_classes],
        axis=0)
    y_all = tf.concat([y_all, y_all, y_all], axis=0)
    # Each split's examples are tiled once per in-distribution label so that
    # row i of x_all lines up with row i of y_all.
    xinfl_validation_all = tf.concat(
        [xinfl_validation for _ in ind_classes], axis=0)
    xinfl_test_all = tf.concat([xinfl_test for _ in ind_classes], axis=0)
    xinfl_ood_all = tf.concat([xinfl_ood for _ in ind_classes], axis=0)
    x_all = tf.concat(
        [xinfl_validation_all, xinfl_test_all, xinfl_ood_all], axis=0)

    cg_approx_params = {
        'maxiter': params['cg_maxiter'],
        'squared': params['squared'],
        'tol': params['tol'],
        'hvp_samples': params['hvp_samples'],
    }

    # These depend only on the model and the regularization strength, so they
    # are built once instead of once per example.
    loss_function = calculate_influence.make_loss_fn(model, params['lam'])
    gradient_function = calculate_influence.make_grad_fn(model)
    map_gradient_function = calculate_influence.make_map_grad_fn(model)

    # Here we run conjugate gradient one example at a time, collecting
    # the following outputs.
    # H^{-1}g
    infl_value = []
    # gH^{-1}g, stored as the element-wise product of H^{-1}g and g
    # (tf.multiply); no reduction over parameters is applied here.
    infl_laplace = []
    # H^{-2}g
    infl_deriv = []
    # g
    grads = []
    # When calculating H^{-1}g with conjugate gradient, Scipy returns a flag
    # denoting the optimization's success.
    warning_flags = []
    # When calculating H^{-2}g with conjugate gradient, Scipy returns a flag
    # denoting the optimization's success.
    warning_flags_deriv = []

    for i in range(x_all.shape[0]):
        logging.info('Example {:d}'.format(i))
        s = time.time()
        xi = tf.expand_dims(x_all[i], 0)
        yi = tf.expand_dims(y_all[i], 0)
        # NOTE(review): without a preloaded iterator the training data is
        # reloaded for every example - presumably to start from a fresh
        # iterator each time; confirm before hoisting this out of the loop.
        if params['preloaded_itr'] is None:
            itr_train, _, _, _ = (
                dataset_utils.load_dataset_ood_supervised_onehot(
                    ind_classes, ood_classes, label_noise=label_noise))
        else:
            itr_train = params['preloaded_itr']

        # First solve: H^{-1}g for this (example, label) pair.
        infl_value_i, grads_i, warning_flag_i = get_parameter_influence(
            model, xi, yi, itr_train,
            approx_params=cg_approx_params,
            damping=params['lam'])
        t = time.time()
        logging.info('IHVP calculation took {:.3f} seconds'.format(t - s))

        infl_laplace_i = tf.multiply(infl_value_i, grads_i)

        # Second solve: feed H^{-1}g (reshaped like the model weights) back
        # through the inverse-HVP routine to obtain H^{-2}g.
        infl_value_wtshape = tensor_utils.reshape_vector_as(
            model.weights, infl_value_i)
        s = time.time()
        infl_deriv_i, warning_flag_deriv_i = get_ihvp_conjugate_gradient(
            infl_value_wtshape, itr_train, loss_function, gradient_function,
            map_gradient_function, approx_params=cg_approx_params)
        t = time.time()
        logging.info(
            'Second IHVP calculation took {:.3f} seconds'.format(t - s))

        infl_value.append(infl_value_i)
        infl_laplace.append(infl_laplace_i)
        infl_deriv.append(infl_deriv_i)
        grads.append(grads_i)
        warning_flags.append(tf.expand_dims(warning_flag_i, 0))
        warning_flags_deriv.append(tf.expand_dims(warning_flag_deriv_i, 0))

    infl_value = tf.concat(infl_value, axis=0)
    infl_laplace = tf.concat(infl_laplace, axis=0)
    infl_deriv = tf.concat(infl_deriv, axis=0)
    grads = tf.concat(grads, axis=0)
    warning_flags = tf.concat(warning_flags, axis=0)
    warning_flags_deriv = tf.concat(warning_flags_deriv, axis=0)

    # Rows were stacked as [validation | test | OOD], each group holding
    # ninfl * len(ind_classes) entries; split them back apart by position.
    n_per_split = ninfl * len(ind_classes)
    res = {}
    for infl_res, nm in [(infl_value, 'infl'),
                         (infl_deriv, 'deriv'),
                         (infl_laplace, 'laplace'),
                         (grads, 'grads'),
                         (warning_flags, 'warnflags'),
                         (warning_flags_deriv, 'warnflags_deriv')]:
        res['valid_{}'.format(nm)] = infl_res[:n_per_split]
        res['test_{}'.format(nm)] = infl_res[n_per_split:2 * n_per_split]
        res['ood_{}'.format(nm)] = infl_res[2 * n_per_split:]

    # Save the results of these calculations.
    if params['output_dir']:
        resdir = utils.make_subdir(params['output_dir'], 'influence_results')
    else:
        resdir = utils.make_subdir(modeldir, 'influence_results')
    tensor_name_template = '{}{}-inv_hvp-cg-ix{:d}-ninfl{:d}' + (
        '_squared' if params['squared'] else '')
    infl_tensors = [
        (tensor_name_template.format(params['tname'], label, start_ix, ninfl),
         tensor)
        for label, tensor in res.items()]
    utils.save_tensors(infl_tensors, resdir)
def run(params):
    """Calculates influence functions for a pre-loaded model.

    Loads the classifier checkpoint and the saved train/test tensors,
    estimates the influence of the first `n_train_infl` training examples on
    the loss of the first `n_test_infl` test examples, saves the resulting
    tensor, and writes two PDFs of (test image, train image) pairs: the
    pairs with the largest absolute influence overall, and, for each of the
    first few test images, the training image with the highest influence.
    Nothing is returned.

    params should contain:
        seed (int): random seed for Tensorflow and Numpy initialization.
        training_results_dir (str): the parent directory of the pre-trained
            model.
        clf_name (str): the name of the pre-trained model's directory.
        n_test_infl (int): number of test examples to run influence functions
            for.
        n_train_infl (int): number of train examples to run influence
            functions for.
        lissa_recursion_depth (int): how long to run LiSSA for.
        output_dir (str): where results should be written - defaults to
            training_results_dir/clf_name/influence_results.

    Args:
        params (dict): contains a number of params - as loaded from flags.
    """
    tf.set_random_seed(params['seed'])
    np.random.seed(params['seed'])

    # Load a trained classifier.
    modeldir = os.path.join(params['training_results_dir'], params['clf_name'])
    ckpt_path = os.path.join(modeldir, 'ckpts/bestmodel-1')
    param_file = os.path.join(modeldir, 'params.json')
    clf_params = utils.load_json(param_file)
    cnn_args = {
        'conv_dims': [int(x) for x in clf_params['conv_dims'].split(',')],
        'conv_sizes': [int(x) for x in clf_params['conv_sizes'].split(',')],
        'dense_sizes': [int(x) for x in clf_params['dense_sizes'].split(',')],
        # NOTE(review): the class count is fixed at 10 here rather than read
        # from params.json - confirm this matches every checkpoint passed in.
        'n_classes': 10,
        'onehot': True,
    }
    clf = utils.load_model(ckpt_path, classifier.CNN, cnn_args)

    # Load train/test examples.
    tensordir = os.path.join(modeldir, 'tensors')
    train_x = utils.load_tensor(os.path.join(tensordir, 'train_x_infl.npy'))
    train_y = utils.load_tensor(os.path.join(tensordir, 'train_y_infl.npy'))
    test_x = utils.load_tensor(os.path.join(tensordir, 'test_x_infl.npy'))
    test_y = utils.load_tensor(os.path.join(tensordir, 'test_y_infl.npy'))

    # Calculate influence functions for this model.
    if params['output_dir']:
        resdir = utils.make_subdir(params['output_dir'], 'influence_results')
    else:
        resdir = utils.make_subdir(modeldir, 'influence_results')

    itr_train, _, _ = dataset_utils.load_dataset_supervised_onehot()
    bx, by = test_x[:params['n_test_infl']], test_y[:params['n_test_infl']]
    # This loss is evaluated on the selected test batch.
    test_loss, _ = clf.get_loss(bx, by)
    logging.info('Current loss: {:.3f}'.format(tf.reduce_mean(test_loss)))

    bx_tr, by_tr = (train_x[:params['n_train_infl']],
                    train_y[:params['n_train_infl']])
    predicted_loss_diffs = get_influence_on_test_loss(
        (bx, by), (bx_tr, by_tr), itr_train, clf,
        approx_params={
            'scale': 100,
            'damping': 0.1,
            'num_samples': 1,
            'recursion_depth': params['lissa_recursion_depth'],
            'print_iter': 10,
        },
        lam=0.01)
    utils.save_tensors([('predicted_loss_diffs', predicted_loss_diffs)],
                       resdir)

    # d is indexed as d[test_index, train_index] below.
    d = predicted_loss_diffs.numpy()
    df = d.flatten()
    # (value, flat index) pairs ordered by decreasing absolute influence.
    dfsort = sorted(zip(df, range(len(df))),
                    key=lambda x: abs(x[0]), reverse=True)
    # Ordering note: df[:10] == d[0,:10]

    # Upper bound on the number of image pairs written to each PDF.
    n_img_pairs = 10

    # Loop over the highest influence pairs.
    imgs_to_save = []
    titles = []
    top_pairs = dfsort[:n_img_pairs]
    for inf_val, idx in top_pairs:
        i, tr_i = convert_index(idx, d)
        imgs_to_save += [bx[i, :, :, 0], bx_tr[tr_i, :, :, 0]]
        titles += [
            'Test Image: inf={:.8f}'.format(inf_val),
            'Train Image: inf={:.8f}'.format(inf_val),
        ]
    # Pass the number of pairs actually collected; this equals n_img_pairs
    # whenever at least that many (test, train) pairs exist.
    utils.save_images(os.path.join(resdir, 'most_influential.pdf'),
                      imgs_to_save, len(top_pairs), titles=titles)

    # For each test image, find the training image with the highest
    # influence. Do at most n_img_pairs pairs from the top of the test
    # tensor, but never more rows than were actually scored: indexing
    # d[i, :] past d.shape[0] would raise IndexError when n_test_infl is
    # smaller than n_img_pairs.
    imgs_to_save = []
    titles = []
    n_by_test = min(n_img_pairs, d.shape[0])
    for i in range(n_by_test):
        infl_i = d[i, :]
        max_train_ind = np.argmax(infl_i)
        max_infl_val = infl_i[max_train_ind]
        imgs_to_save += [bx[i, :, :, 0], bx_tr[max_train_ind, :, :, 0]]
        titles += [
            'Test Image {:d}: inf={:.8f}'.format(i, max_infl_val),
            'Train Image {:d}: inf={:.8f}'.format(max_train_ind,
                                                  max_infl_val),
        ]
    utils.save_images(
        os.path.join(resdir, 'most_influential_by_test_img.pdf'),
        imgs_to_save, n_by_test, titles=titles)