def calc_all_grad_then_test(config, model, train_loader, test_loader):
    """Calculates the influence function by first calculating
    all grad_z, all s_test and then loading them to calc the influence"""

    outdir = Path(config["outdir"])
    s_test_outdir = outdir.joinpath("s_test/")
    s_test_outdir.mkdir(exist_ok=True)
    grad_z_outdir = outdir.joinpath("grad_z/")
    grad_z_outdir.mkdir(exist_ok=True)

    influence_results = {}

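    # s_test approximates the inverse-Hessian-vector product
    # H^-1 * grad_theta L(z_test) for each test sample via stochastic
    # recursion (cf. Koh & Liang 2017, "Understanding Black-box Predictions
    # via Influence Functions"); `damp`, `scale`, `recursion_depth` and
    # `r_averaging` control that approximation.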
    calc_s_test(
        model,
        test_loader,
        train_loader,
        s_test_outdir,
        config["gpu"],
        config["damp"],
        config["scale"],
        config["recursion_depth"],
        config["r_averaging"],
        config["test_start_index"],
    )
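
    # grad_z is the gradient of the training loss w.r.t. the model
    # parameters, computed and saved separately for every training sample.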
    calc_grad_z(model, train_loader, grad_z_outdir, config["gpu"],
                config["test_start_index"])

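    # With all grad_z and s_test vectors precomputed on disk, the influence
    # of a training sample on a test sample is essentially a dot product of
    # its grad_z vector with that test sample's s_test vector.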
    train_dataset_len = len(train_loader.dataset)
    influences, harmful, helpful = calc_influence_function(train_dataset_len)

    influence_results["influences"] = influences
    influence_results["harmful"] = harmful
    influence_results["helpful"] = helpful
    influences_path = outdir.joinpath("influence_results.json")
    save_json(influence_results, influences_path)
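

# A minimal usage sketch of the precompute-then-load path above (hypothetical
# values; `model` and the data loaders come from the surrounding project):
#
#     config = {
#         "outdir": "outdir", "gpu": -1, "damp": 0.01, "scale": 25,
#         "recursion_depth": 500, "r_averaging": 1, "test_start_index": 0,
#     }
#     calc_all_grad_then_test(config, model, train_loader, test_loader)
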
def calc_img_wise(config, model, train_loader, test_loader):
    """Calculates the influence function one test point at a time. Calcualtes
    the `s_test` and `grad_z` values on the fly and discards them afterwards.

    Arguments:
        config: dict, contains the configuration from cli params"""
    influences_meta = copy.deepcopy(config)
    test_sample_num = config["test_sample_num"]
    test_start_index = config["test_start_index"]
    outdir = Path(config["outdir"])

    # If calculating the influence for a subset of the whole dataset,
    # calculate it evenly for the same number of samples from all classes.
    # `test_start_index` is `False` when it hasn't been set by the user.
    # Since it can also legitimately be `0`, it is compared with
    # `is not False` instead of relying on truthiness.
    if test_sample_num and test_start_index is not False:
        test_dataset_iter_len = test_sample_num * config["num_classes"]
        _, sample_list = get_dataset_sample_ids(test_sample_num, test_loader,
                                                config["num_classes"],
                                                test_start_index)
    else:
        test_dataset_iter_len = len(test_loader.dataset)
        # Keep `sample_list` defined so the metadata dump below doesn't fail.
        sample_list = []

    # Log the run parameters and save the metadata conf file
    logging.info(f"Running on: {test_sample_num} images per class.")
    logging.info(f"Starting at img number: {test_start_index} in each class.")
    influences_meta["test_sample_index_list"] = sample_list
    influences_meta_fn = (f"influences_results_meta_{test_start_index}-"
                          f"{test_sample_num}.json")
    influences_meta_path = outdir.joinpath(influences_meta_fn)
    save_json(influences_meta, influences_meta_path)

    influences = {}
    # Main loop for calculating the influence function one test sample per
    # iteration.
    for j in range(test_dataset_iter_len):
        # If we calculate evenly per class, choose the test img indices
        # from the sample_list instead. Compare with `is not False` so that
        # a `test_start_index` of `0` still selects this branch, matching
        # the check above.
        if test_sample_num and test_start_index is not False:
            if j >= len(sample_list):
                logging.warning(
                    "The test sample id is out of range of the defined test"
                    " set. Stopping the run early.")
                break
            i = sample_list[j]
        else:
            i = j

        start_time = time.time()
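        # `calc_influence_single` computes the s_test vector for test sample
        # `i` and the grad_z values on the fly, returning the influence of
        # every training sample on this test sample together with the
        # training-sample ids ranked from most harmful to most helpful.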
        influence, harmful, helpful, _ = calc_influence_single(
            model,
            train_loader,
            test_loader,
            test_id_num=i,
            gpu=config["gpu"],
            recursion_depth=config["recursion_depth"],
            r=config["r_averaging"],
        )
        end_time = time.time()

        # `influences` (plural) accumulates the per-test-sample results for
        # the whole run and is distinct from the single-sample `influence`
        # list returned above.
        influences[str(i)] = {}
        _, label = test_loader.dataset[i]
        influences[str(i)]["label"] = label
        influences[str(i)]["num_in_dataset"] = j
        influences[str(i)]["time_calc_influence_s"] = end_time - start_time
        # Convert the influence tensors to plain Python floats so they can
        # be serialised to JSON.
        infl = [x.cpu().numpy().tolist() for x in influence]
        influences[str(i)]["influence"] = infl
        # Keep only the 500 most harmful/helpful training ids to limit the
        # size of the result files.
        influences[str(i)]["harmful"] = harmful[:500]
        influences[str(i)]["helpful"] = helpful[:500]

        # Checkpoint the accumulated results to disk after every test sample.
        tmp_influences_path = outdir.joinpath(
            f"influence_results_tmp_{test_start_index}_{test_sample_num}"
            f"_last-i_{i}.json")
        save_json(influences, tmp_influences_path)
        display_progress("Test samples processed: ", j, test_dataset_iter_len)

    logging.info(f"The results for this run are:")
    logging.info("Influences: ")
    logging.info(influence[:3])
    logging.info("Most harmful img IDs: ")
    logging.info(harmful[:3])
    logging.info("Most helpful img IDs: ")
    logging.info(helpful[:3])

    influences_path = outdir.joinpath(f"influence_results_{test_start_index}_"
                                      f"{test_sample_num}.json")
    save_json(influences, influences_path)
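

# A minimal usage sketch (hypothetical; `get_model` and `get_loaders` stand
# in for whatever the surrounding project provides, and the config values
# are illustrative rather than recommended defaults):
#
#     if __name__ == "__main__":
#         logging.basicConfig(level=logging.INFO)
#         config = {
#             "outdir": "outdir", "gpu": -1, "num_classes": 10,
#             "test_sample_num": 1, "test_start_index": 0,
#             "recursion_depth": 500, "r_averaging": 1,
#         }
#         Path(config["outdir"]).mkdir(exist_ok=True)
#         model = get_model()
#         train_loader, test_loader = get_loaders()
#         calc_img_wise(config, model, train_loader, test_loader)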