def calc_all_grad_then_test(config, model, train_loader, test_loader):
    """Calculate the influence function by first calculating all grad_z and
    all s_test values and then loading them to calculate the influence.

    Arguments:
        config: dict, read keys: 'outdir', 'gpu', 'damp', 'scale',
            'recursion_depth', 'r_averaging', 'test_start_index'.
        model: model passed to `calc_s_test` and `calc_grad_z`.
        train_loader: DataLoader of the training data.
        test_loader: DataLoader of the test data.

    Returns:
        influence_results: dict with the keys 'influences', 'harmful' and
            'helpful'. The same dict is also saved to
            `<outdir>/influence_results.json`.
    """
    outdir = Path(config['outdir'])

    # `parents=True` also creates `outdir` when it does not exist yet, and
    # `exist_ok=True` replaces the racy exists()/mkdir() check.
    s_test_outdir = outdir.joinpath("s_test/")
    s_test_outdir.mkdir(parents=True, exist_ok=True)
    grad_z_outdir = outdir.joinpath("grad_z/")
    grad_z_outdir.mkdir(parents=True, exist_ok=True)

    calc_s_test(model, test_loader, train_loader, s_test_outdir, config['gpu'],
                config['damp'], config['scale'], config['recursion_depth'],
                config['r_averaging'], config['test_start_index'])
    calc_grad_z(model, train_loader, grad_z_outdir, config['gpu'],
                config['test_start_index'])

    # NOTE(review): calc_influence_function only receives the dataset size;
    # presumably it loads the s_test/grad_z files written above from a default
    # location — confirm it matches s_test_outdir / grad_z_outdir.
    train_dataset_len = len(train_loader.dataset)
    influences, harmful, helpful = calc_influence_function(train_dataset_len)

    influence_results = {
        'influences': influences,
        'harmful': harmful,
        'helpful': helpful,
    }
    save_json(influence_results, outdir.joinpath("influence_results.json"))
    return influence_results
# Example #2
# 0
def run_trial(results, mode, domain_adaptation_task, sample_per_class,
              src_repetition, tgt_repetition, batch, epochs, retrain_epochs,
              alpha):
    """Run one experiment trial for a (source, target) repetition pair.

    Depending on `mode`, this trains baseline models, computes influence
    scores of source samples on target samples, derives removal / sample
    weighting strategies from those scores and retrains with each strategy.
    All results are stored in `results[str((src_repetition, tgt_repetition))]`.

    Args:
        results: dict collecting results across trials; mutated in place and
            also returned.
        mode: selects the stages to run: 'train_baseline', 'influence',
            'stats' (strategy derivation AND retraining), or 'all'.
        domain_adaptation_task: task name passed to the dataset classes,
            e.g. 'USPS_to_MNIST' or 'MNIST_to_USPS'.
        sample_per_class: samples per class, passed to the dataset classes.
        src_repetition: repetition identifier for the source data.
        tgt_repetition: repetition identifier for the target data.
        batch: DataLoader batch size.
        epochs: number of training epochs for every training run.
        retrain_epochs: not used in this function.
        alpha: only used to build saved-model file names.

    Returns:
        The same `results` dict that was passed in.
    """

    # NOTE(review): `class_num` is never used below.
    if domain_adaptation_task in ['USPS_to_MNIST', 'MNIST_to_USPS']:
        class_num = 10

    config = get_default_config()
    if not os.path.exists(config['outdir']):
        os.mkdir(config['outdir'])

    # Device is hard-coded to CUDA.
    device = torch.device("cuda")
    repetition = (src_repetition, tgt_repetition)

    # Everything this trial produces lives under the string form of the
    # (src_repetition, tgt_repetition) tuple.
    results[str(repetition)] = {}
    train_set_base = TrainSet(domain_adaptation_task, 'baseline',
                              src_repetition, tgt_repetition, sample_per_class)
    train_set_base_loader = DataLoader(train_set_base,
                                       batch_size=batch,
                                       shuffle=True)
    test_set = TestSet(domain_adaptation_task, 0, sample_per_class)
    test_set_loader = DataLoader(test_set, batch_size=batch, shuffle=True)
    print("Dataset Length Train (baseline) : ", len(train_set_base),
          " Test : ", len(test_set))

    if not os.path.exists('saved_models'):
        os.mkdir('saved_models')

    # Checkpoint paths. `mplain_PATH` has no extension because a per-run
    # suffix `v{i}.pth` is appended in the influence stage.
    # NOTE(review): `mbase_PATH` is never used below.
    mplain_PATH = f'./saved_models/{domain_adaptation_task}_plain_s{str(sample_per_class)}'\
                  f'_r{str(repetition)}_b{str(batch)}_e{str(epochs)}_a{str(alpha)}'

    mbase_PATH = f'./saved_models/{domain_adaptation_task}_baseline_s{str(sample_per_class)}'\
                 f'_r{str(repetition)}_b{str(batch)}_e{str(epochs)}_a{str(alpha)}.pth'

    NEW_PATH = f'./saved_models/re_{domain_adaptation_task}_{mode}_s{str(sample_per_class)}'\
               f'_r{str(repetition)}_b{str(batch)}_e{str(epochs)}_a{str(alpha)}.pth'

    # Averaged influence matrix: written by the 'influence' stage and read
    # back by the 'stats' stage.
    influence_path = config[
        'outdir'] + f'/influence_results_r{str(repetition)}.json'

    # Stage 1: train 10 baseline models and record, for each run, the best
    # test accuracy and the epoch it was reached at, plus the average.
    if mode in ['all', 'train_baseline']:
        results[str(repetition)]['baseline_acc'] = {}
        display = []
        sum_acc = 0
        for i in range(10):
            net = Network().to(device)

            best_test_acc = 0
            best_epoch = 0
            for epoch in range(epochs):
                train_loss = train(device, net, train_set_base_loader)
                test_acc = test(device, net, test_set_loader)
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    best_epoch = epoch

            results[str(repetition)]['baseline_acc'][i] = (best_epoch,
                                                           best_test_acc)
            display.append(best_test_acc)
            print(f'repetition [{str(repetition)}] at time {str(i)}th training best acc at epoch'\
                    f'{str(best_epoch)}: {str(best_test_acc)}')
            sum_acc += best_test_acc

        results[str(repetition)]['baseline_acc']['avg_acc'] = sum_acc / 10
        display.append(sum_acc / 10)
        print(domain_adaptation_task,
              f' repetition [{str(repetition)}] avg test acc: ', sum_acc / 10)
        # Print each run's accuracy followed by the average, one per line.
        for i in display:
            print(i)

    # Stage 2: influence of every source sample on every target sample,
    # computed in 5 independent runs and averaged.
    if mode in ['all', 'influence']:
        infl_src = InflSet(domain_adaptation_task, 'source', src_repetition,
                           sample_per_class)
        infl_src_loader = DataLoader(infl_src, batch_size=batch)
        infl_tgt = InflSet(domain_adaptation_task, 'target', tgt_repetition,
                           sample_per_class)
        infl_tgt_loader = DataLoader(infl_tgt, batch_size=batch)
        print("infl_src : ", len(infl_src), "infl_tgt : ", len(infl_tgt))

        results[str(repetition)]['plain_acc'] = {}
        results[str(repetition)]['influence'] = {}
        # Accumulator of shape (number of target samples, number of source samples).
        infl_sum = np.zeros([len(infl_tgt), len(infl_src)])
        for i in range(5):
            # Train a "plain" model on the source influence set and checkpoint
            # it every time the test accuracy improves.
            net = Network().to(device)
            best_acc = 0
            best_epoch = 0
            for epoch in range(epochs):
                train_loss = train(device, net, infl_src_loader)
                test_acc = test(device, net, test_set_loader)
                if test_acc > best_acc:
                    best_acc = test_acc
                    best_epoch = epoch
                    save_model(net, mplain_PATH + f'v{str(i)}.pth')

            results[str(repetition)]['plain_acc'][i] = (best_epoch, best_acc)
            print(f'repetition [{str(repetition)}] at time {str(i)}th plain model best acc at epoch'\
                    f'{str(best_epoch)}: {str(best_acc)}')

            # Reload the best checkpoint of this run.
            net = load_model(mplain_PATH + f'v{str(i)}.pth')

            # fine-tune the net with 10 target examples, spc = 1
            # (the train permutation keeps all of train_set; validation uses
            # up to 100 random samples of val_set)

            train_set = TestSet(domain_adaptation_task, 1, 1)
            val_set = TestSet(domain_adaptation_task, 2, 1)
            train_set_indices = np.random.permutation(
                len(train_set))[:len(train_set)]
            val_set_indices = np.random.permutation(len(val_set))[:100]
            train_loader = DataLoader(
                train_set,
                batch_size=1,
                shuffle=False,
                sampler=SubsetRandomSampler(train_set_indices))
            val_loader = DataLoader(
                val_set,
                batch_size=2,
                shuffle=False,
                sampler=SubsetRandomSampler(val_set_indices))
            print("fine_tuning the net...")
            train_fine_tune(net, train_loader, val_loader, 200)

            #################

            # NOTE(review): the `calc_img_wise` defined in this file takes 4
            # arguments and returns nothing; this call passes 5 and expects an
            # array shaped like `infl_sum` — confirm which implementation is used.
            infl_arr = calc_img_wise(config, net, infl_src_loader,
                                     infl_tgt_loader, i)
            # results[str(repetition)]['influence'][i] = infl_arr.tolist()
            infl_sum = np.add(infl_sum, infl_arr)

        infl_avg = infl_sum / 5
        results[str(repetition)]['influence']['avg'] = infl_avg.tolist()
        # Print summary statistics for the first 5 target samples.
        # NOTE(review): raises IndexError if there are fewer than 5 targets.
        for target in range(5):
            print("Results stats for target", target)
            acs = np.sort(infl_avg[target])
            print(acs[:10])
            print(acs[-10:])
            print("median", np.median(infl_avg[target]))
            print("mean", np.mean(infl_avg[target]))
            print("std", np.std(infl_avg[target]))

        save_json(infl_avg.tolist(), influence_path)

    # Stage 3: derive removal sets and sample weights from the averaged
    # influence matrix, then retrain with each strategy.
    if mode in ['all', 'stats']:

        # `data` rows are target samples, columns are source samples.
        with open(influence_path) as json_file:
            data = json.load(json_file)

        # One score per source sample in each row:
        #   0: influence summed over all processed target samples
        #   1: +1 / -1 votes for influence above / below mean +- 2 std
        #   2: -1 votes for influence < -std
        #   3: -1 votes for influence < -2 std
        #   4: +1 votes for influence > std
        #   5: -1 votes for |influence| > 2 std
        #   6: 4-level bucket of row 0 by percentiles 10/50/90
        #   7: 6-level bucket of row 0 by percentiles 10/30/50/70/90
        counter = np.zeros((8, len(data[0])))

        # NOTE(review): assumes 10 classes, i.e. sample_per_class * 10 target rows.
        for i in range(sample_per_class * 10):
            infl = data[i]
            std = statistics.stdev(infl)
            avg = np.mean(infl)
            print("std for target image", i, ":", std)

            counter[0] += infl
            counter[1] += [
                1 if x > avg + 2 * std else -1 if x < avg - 2 * std else 0
                for x in infl
            ]
            counter[2] += [-1 if x < -1 * std else 0 for x in infl]
            counter[3] += [-1 if x < -2 * std else 0 for x in infl]
            counter[4] += [1 if x > std else 0 for x in infl]
            counter[5] += [-1 if abs(x) > 2 * std else 0 for x in infl]

        # counter 6
        counter[6][counter[0] < np.percentile(counter[0], 10)] = 1
        counter[6][(counter[0] >= np.percentile(counter[0], 10))
                   & (counter[0] < np.percentile(counter[0], 50))] = 2
        counter[6][(counter[0] >= np.percentile(counter[0], 50))
                   & (counter[0] <= np.percentile(counter[0], 90))] = 3
        counter[6][counter[0] > np.percentile(counter[0], 90)] = 5

        # counter 7
        counter[7][counter[0] < np.percentile(counter[0], 10)] = 1
        counter[7][(counter[0] >= np.percentile(counter[0], 10))
                   & (counter[0] < np.percentile(counter[0], 30))] = 2
        counter[7][(counter[0] >= np.percentile(counter[0], 30))
                   & (counter[0] < np.percentile(counter[0], 50))] = 3
        counter[7][(counter[0] >= np.percentile(counter[0], 50))
                   & (counter[0] < np.percentile(counter[0], 70))] = 4
        counter[7][(counter[0] >= np.percentile(counter[0], 70))
                   & (counter[0] <= np.percentile(counter[0], 90))] = 5
        counter[7][counter[0] > np.percentile(counter[0], 90)] = 7

        # Removal strategies: source indices whose score lies below (or, for
        # removed8/9/12/13, above) the given percentile of a counter row.

        th0 = np.percentile(counter[0], 2)
        removed0 = np.where(counter[0] < th0)[0]

        th1 = np.percentile(counter[1], 2)
        removed1 = np.where(counter[1] < th1)[0]

        th2 = np.percentile(counter[2], 2)
        removed2 = np.where(counter[2] < th2)[0]

        th3 = np.percentile(counter[3], 2)
        removed3 = np.where(counter[3] < th3)[0]

        th4 = np.percentile(counter[4], 2)
        removed4 = np.where(counter[4] < th4)[0]

        th5 = np.percentile(counter[5], 2)
        removed5 = np.where(counter[5] < th5)[0]

        # counter 6
        th6 = np.percentile(counter[6], 2)
        removed6 = np.where(counter[6] < th6)[0]

        th7 = np.percentile(counter[6], 5)
        removed7 = np.where(counter[6] < th7)[0]

        th8 = np.percentile(counter[6], 95)
        removed8 = np.where(counter[6] > th8)[0]

        th9 = np.percentile(counter[6], 98)
        removed9 = np.where(counter[6] > th9)[0]

        # counter 7
        th10 = np.percentile(counter[7], 2)
        removed10 = np.where(counter[7] < th10)[0]

        th11 = np.percentile(counter[7], 5)
        removed11 = np.where(counter[7] < th11)[0]

        th12 = np.percentile(counter[7], 95)
        removed12 = np.where(counter[7] > th12)[0]

        th13 = np.percentile(counter[7], 98)
        removed13 = np.where(counter[7] > th13)[0]

        infl_src = InflSet(domain_adaptation_task, 'source', src_repetition,
                           sample_per_class)

        print("infl_src : ", len(infl_src))

        # Random-removal baselines: 5 % / 2 % of the source set, or 5 samples.
        # (The order of these np.random calls affects reproducibility.)
        removed_random_5 = np.random.permutation(np.arange(
            len(infl_src)))[:int(len(infl_src) * (5 * 0.01))]

        removed_random_2 = np.random.permutation(np.arange(
            len(infl_src)))[:int(len(infl_src) * (2 * 0.01))]

        removed_random_5e = np.random.permutation(np.arange(len(infl_src)))[:5]

        # Sample rate strategy: use counter[i] as weight. Rows 0 and 1 are
        # shifted by |min| + 0.1 so that every weight is strictly positive.

        sample_weight = {}
        sample_weight['pure_sum'] = (counter[0] + abs(np.amin(counter[0])) +
                                     0.1).tolist()
        sample_weight['tri_2std'] = (counter[1] + abs(np.amin(counter[1])) +
                                     0.1).tolist()
        # sample_weight['bi_neg_1std'] = (counter[2] + abs(np.amin(counter[2])) + 0.1).tolist()
        # sample_weight['bi_neg_2std'] = (counter[3] + abs(np.amin(counter[3])) + 0.1).tolist()
        # sample_weight['bi_pos_1std'] = (counter[4] + abs(np.amin(counter[4])) + 0.1).tolist()
        # sample_weight['abs_2std'] = (counter[5] + abs(np.amin(counter[5])) + 0.1).tolist()
        sample_weight['4_seg'] = counter[6].tolist()
        sample_weight['6_seg'] = counter[7].tolist()
        # sample_weight['random_weight1'] = np.random.permutation(counter[0] + abs(np.amin(counter[0])) + 0.1).tolist()
        sample_weight['random_weight2'] = np.random.rand(
            counter[0].shape[0]).tolist()

        sample_weight_path = config['outdir']+f'/sample_weight_{domain_adaptation_task}'\
                        f'_s{str(sample_per_class)}_r{str(repetition)}.json'
        save_json(sample_weight, sample_weight_path)

        # Removal strategies that are actually retrained below.
        stats = {}
        # stats['pure_sum'] = removed0.tolist()
        stats['tri_2std'] = removed1.tolist()
        # stats['bi_neg_1std'] = removed2.tolist()
        # stats['bi_neg_2std'] = removed3.tolist()
        # stats['bi_pos_1std'] = removed4.tolist()
        # stats['abs_2std'] = removed5.tolist()
        # stats['4_seg_2per'] = removed6.tolist()
        # stats['4_seg_5per'] = removed7.tolist()
        # stats['4_seg_95per'] = removed8.tolist()
        # stats['4_seg_98per'] = removed9.tolist()
        # stats['6_seg_2per'] = removed10.tolist()
        # stats['6_seg_5per'] = removed11.tolist()
        # stats['6_seg_95per'] = removed12.tolist()
        # stats['6_seg_98per'] = removed13.tolist()
        stats['random_reomove_2per'] = removed_random_2.tolist()
        # stats['random_reomove_5per'] = removed_random_5.tolist()
        # stats['random_remove_5example'] = removed_random_5e.tolist()
        # NOTE(review): `time` is never used and shadows the `time` module
        # name inside this function.
        time = dt.now().strftime("%Y-%m-%d-%H-%M-%S")
        stats_path = config['outdir']+f'/infl_std_stats_{domain_adaptation_task}'\
                        f'_s{str(sample_per_class)}_r{str(repetition)}.json'
        save_json(stats, stats_path)

        # Retraining (formerly a separate 'retrain' mode, now part of 'stats'):
        # if mode in ['all', 'retrain']:
        # stats_path = config['outdir']+f'infl_stats_{domain_adaptation_task}'\
        #                 f'_s{str(sample_per_class)}_r{str(repetition)}.json'

        # Reload the two files just written; note `data` is rebound here from
        # the influence matrix to the removal-strategy dict.
        with open(sample_weight_path) as json_file:
            sample_weight = json.load(json_file)

        with open(stats_path) as json_file:
            data = json.load(json_file)

        device = torch.device("cuda")
        net = Network().to(device)

        # remove indices: retrain 10 times on the baseline set with each
        # strategy's source indices removed.
        results[str(repetition)]['retrain_remove_indices'] = {}
        for l in data:
            results[str(repetition)]['retrain_remove_indices'][l] = {}
            removed_indices = data[l]
            retrain_set = TrainSet(domain_adaptation_task, 'baseline',
                                   src_repetition, tgt_repetition,
                                   sample_per_class, removed_indices)
            retrain_set_loader = DataLoader(retrain_set,
                                            batch_size=batch,
                                            shuffle=True)
            excel = []
            sum_re_acc = 0
            for i in range(10):

                net = Network().to(device)
                best_test_acc = 0
                for epoch in range(epochs):
                    train_loss = train(device, net, retrain_set_loader)
                    test_acc = test(device, net, test_set_loader)
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        save_model(net, NEW_PATH)
                results[str(repetition
                            )]['retrain_remove_indices'][l][i] = best_test_acc
                excel.append(best_test_acc)
                print(
                    l + ' at time ' + str(i) +
                    ' has retraining best test acc :', best_test_acc)
                sum_re_acc += best_test_acc
            results[str(repetition)]['retrain_remove_indices'][l][
                'avg_re_acc'] = sum_re_acc / 10
            print(l + ' has average retrain test acc :', sum_re_acc / 10)
            excel.append(sum_re_acc / 10)
            for i in excel:
                print(i)

        # sample weight: retrain with a weighted sampler over the baseline set.
        results[str(repetition)]['retrain_sample_weight'] = {}
        for w in sample_weight:
            results[str(repetition)]['retrain_sample_weight'][w] = {}
            # The 10 * sample_per_class extra entries get 1.2x the largest
            # source weight; presumably they are the target samples TrainSet
            # appends after the source samples — TODO confirm.
            m = max(sample_weight[w]) * 1.2
            weight = sample_weight[w] + [m] * 10 * sample_per_class
            # NOTE(review): num_samples is the source count only, so one epoch
            # draws len(sample_weight[w]) samples with replacement.
            sampler = WeightedRandomSampler(weight,
                                            len(sample_weight[w]),
                                            replacement=True)
            retrain_set = TrainSet(domain_adaptation_task, 'baseline',
                                   src_repetition, tgt_repetition,
                                   sample_per_class)
            retrain_set_loader = DataLoader(retrain_set,
                                            batch_size=batch,
                                            shuffle=False,
                                            sampler=sampler)

            sum_re_acc = 0
            excel = []
            for i in range(10):
                net = Network().to(device)
                best_test_acc = 0
                for epoch in range(epochs):
                    train_loss = train(device, net, retrain_set_loader)
                    test_acc = test(device, net, test_set_loader)
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        # save_model(net, NEW_PATH)
                results[str(
                    repetition)]['retrain_sample_weight'][w][i] = best_test_acc
                print(
                    w + ' at time ' + str(i) +
                    ' has retraining best test acc :', best_test_acc)
                excel.append(best_test_acc)
                sum_re_acc += best_test_acc
            results[str(repetition)]['retrain_sample_weight'][w][
                'avg_re_acc'] = sum_re_acc / 10
            print(w + ' average retrain test acc :', sum_re_acc / 10)
            excel.append(sum_re_acc / 10)
            for i in excel:
                print(i)

        # sample weight + cssa: the sampler uses the weights exposed by
        # TrainSet_CSSA (`retrain_set.weights`), training via train_CSSA.
        results[str(repetition)]['retrain_sample_weight_CSSA'] = {}
        for w in sample_weight:
            results[str(repetition)]['retrain_sample_weight_CSSA'][w] = {}
            weight = sample_weight[w]
            retrain_set = TrainSet_CSSA(domain_adaptation_task,
                                        src_repetition,
                                        tgt_repetition,
                                        sample_per_class,
                                        weights=weight)
            sample_rate = retrain_set.weights
            sampler = WeightedRandomSampler(sample_rate,
                                            len(sample_weight[w]),
                                            replacement=True)
            retrain_set_loader = DataLoader(retrain_set,
                                            batch_size=batch,
                                            shuffle=False,
                                            sampler=sampler)
            # retrain_set_loader = DataLoader(retrain_set, batch_size=batch, shuffle=True)

            sum_re_acc = 0
            excel = []
            for i in range(10):
                net = Network().to(device)
                best_test_acc = 0
                for epoch in range(epochs):
                    train_loss = train_CSSA(device, net, retrain_set_loader,
                                            'CSSA')
                    test_acc = test(device, net, test_set_loader)
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        # save_model(net, NEW_PATH)
                results[str(repetition)]['retrain_sample_weight_CSSA'][w][
                    i] = best_test_acc
                print(
                    w + ' at time ' + str(i) +
                    ' has retraining best test acc :', best_test_acc)
                excel.append(best_test_acc)
                sum_re_acc += best_test_acc
            results[str(repetition)]['retrain_sample_weight_CSSA'][w][
                'avg_re_acc'] = sum_re_acc / 10
            print(w + ' average retrain test acc :', sum_re_acc / 10)
            excel.append(sum_re_acc / 10)
            for i in excel:
                print(i)

        print("\n plain cssa \n")
        # plain cssa: same dataset class, but a plain shuffled loader
        # without a weighted sampler.
        results[str(repetition)]['retrain_CSSA'] = {}
        for w in sample_weight:
            results[str(repetition)]['retrain_CSSA'][w] = {}
            weight = sample_weight[w]
            retrain_set = TrainSet_CSSA(domain_adaptation_task,
                                        src_repetition,
                                        tgt_repetition,
                                        sample_per_class,
                                        weights=weight)
            # sample_rate = retrain_set.weights
            # sampler = WeightedRandomSampler(sample_rate, len(sample_weight[w]), replacement=True)
            # retrain_set_loader = DataLoader(retrain_set, batch_size=batch, shuffle=False, sampler=sampler)
            retrain_set_loader = DataLoader(retrain_set,
                                            batch_size=batch,
                                            shuffle=True)

            sum_re_acc = 0
            excel = []
            for i in range(10):
                net = Network().to(device)
                best_test_acc = 0
                for epoch in range(epochs):
                    train_loss = train_CSSA(device, net, retrain_set_loader,
                                            'CSSA')
                    test_acc = test(device, net, test_set_loader)
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        # save_model(net, NEW_PATH)
                results[str(repetition)]['retrain_CSSA'][w][i] = best_test_acc
                print(
                    w + ' at time ' + str(i) +
                    ' has retraining best test acc :', best_test_acc)
                excel.append(best_test_acc)
                sum_re_acc += best_test_acc
            results[str(
                repetition)]['retrain_CSSA'][w]['avg_re_acc'] = sum_re_acc / 10
            print(w + ' average retrain test acc :', sum_re_acc / 10)
            excel.append(sum_re_acc / 10)
            for i in excel:
                print(i)

            # NOTE(review): this `break` stops after the first weighting
            # strategy, so plain CSSA is trained for one key only — presumably
            # intentional since the loader ignores the sampler; confirm.
            break

    return results
def calc_img_wise(config, model, train_loader, test_loader):
    """Calculate the influence function one test point at a time.

    Calculates the `s_test` and `grad_z` values on the fly and discards them
    afterwards. Before the loop, a metadata JSON (the config plus the chosen
    test sample ids) is saved to `config['outdir']`. After every test sample,
    the accumulated results are saved to a temporary JSON. The final results
    go to `influence_results_{test_start_index}_{test_sample_num}.json`.

    Arguments:
        config: dict, contains the configuration from cli params. Read keys:
            'test_sample_num', 'test_start_index', 'outdir', 'recursion_depth',
            'r_averaging', and 'num_classes' when sampling per class.
        model: model passed through to `calc_influence_single`.
        train_loader: DataLoader of the training data.
        test_loader: DataLoader of the test data.

    Returns:
        influences: dict keyed by the test sample id (as str). Each entry holds
            the label, the position in the iteration, the computation time,
            the influence values and the first 500 harmful/helpful ids.
    """
    influences_meta = copy.deepcopy(config)
    test_sample_num = config['test_sample_num']
    test_start_index = config['test_start_index']
    outdir = Path(config['outdir'])

    # If calculating the influence for a subset of the whole dataset,
    # calculate it evenly for the same number of samples from all classes.
    # `test_start_index` is `False` when it hasn't been set by the user. It can
    # also be set to `0`, so it must be compared with `is not False` rather
    # than tested for truthiness. The same flag is reused inside the loop so
    # the index selection matches the sample list computed here.
    sample_per_class_mode = (bool(test_sample_num)
                             and test_start_index is not False)
    if sample_per_class_mode:
        test_dataset_iter_len = test_sample_num * config['num_classes']
        _, sample_list = get_dataset_sample_ids(test_sample_num, test_loader,
                                                config['num_classes'],
                                                test_start_index)
    else:
        test_dataset_iter_len = len(test_loader.dataset)
        # No per-class subset: every dataset index is used in order.
        sample_list = None

    # Set up logging and save the metadata conf file
    logging.info(f"Running on: {test_sample_num} images per class.")
    logging.info(f"Starting at img number: {test_start_index} per class.")
    influences_meta['test_sample_index_list'] = sample_list
    influences_meta_fn = f"influences_results_meta_{test_start_index}-" \
                         f"{test_sample_num}.json"
    influences_meta_path = outdir.joinpath(influences_meta_fn)
    save_json(influences_meta, influences_meta_path)

    influences = {}
    # Defaults so the summary logging below works even if no sample is processed.
    influence, harmful, helpful = [], [], []
    # Main loop for calculating the influence function one test sample per
    # iteration.
    for j in range(test_dataset_iter_len):
        # If we calculate evenly per class, choose the test img indices
        # from the sample_list instead
        if sample_per_class_mode:
            if j >= len(sample_list):
                logging.warning("ERROR: the test sample id is out of index of the"
                                " defined test set. Jumping to next test sample.")
                continue
            i = sample_list[j]
        else:
            i = j

        start_time = time.time()
        # NOTE(review): gpu is hard-coded to 0 rather than read from config.
        influence, harmful, helpful, _ = calc_influence_single(
            model,
            train_loader,
            test_loader,
            test_id_num=i,
            gpu=0,
            recursion_depth=config['recursion_depth'],
            r=config['r_averaging'])
        end_time = time.time()

        # Per-test-sample record, keyed by the dataset index as a string.
        influences[str(i)] = {}
        _, label = test_loader.dataset[i]
        influences[str(i)]['label'] = label
        influences[str(i)]['num_in_dataset'] = j
        influences[str(i)]['time_calc_influence_s'] = end_time - start_time
        infl = [x.cpu().numpy().tolist() for x in influence]
        influences[str(i)]['influence'] = infl
        influences[str(i)]['harmful'] = harmful[:500]
        influences[str(i)]['helpful'] = helpful[:500]

        # Checkpoint the accumulated results after every test sample.
        tmp_influences_path = outdir.joinpath(f"influence_results_tmp_"
                                              f"{test_start_index}_"
                                              f"{test_sample_num}"
                                              f"_last-i_{i}.json")
        save_json(influences, tmp_influences_path)
        display_progress("Test samples processed: ", j, test_dataset_iter_len)

    # Summary of the last processed test sample.
    logging.info("The results for this run are:")
    logging.info("Influences: ")
    logging.info(influence[:3])
    logging.info("Most harmful img IDs: ")
    logging.info(harmful[:3])
    logging.info("Most helpful img IDs: ")
    logging.info(helpful[:3])

    influences_path = outdir.joinpath(f"influence_results_{test_start_index}_"
                                      f"{test_sample_num}.json")
    save_json(influences, influences_path)
    return influences
# Example #4
# 0
    # Experiment settings for a single run. The enclosing function header is
    # not part of this snippet.
    domain_adaptation_task = 'USPS_to_MNIST'
    sample_per_class = 1
    # One entry per (source, target) repetition pair; index i pairs
    # src_repetition[i] with tgt_repetition[i].
    src_repetition = [[4]]
    tgt_repetition = [[4]]
    batch = 128
    epochs = 320
    retrain_epochs = 200
    alpha = 0.25
    # NOTE(review): `logfile` is not used in these lines.
    logfile = 'log.txt'
    config = get_default_config()
    results = {}

    # Only the first repetition pair is run.
    # NOTE(review): `mode` is not assigned in this snippet; presumably it is a
    # parameter or set earlier in the enclosing scope — confirm.
    for i in range(1):
        results = run_trial(results, mode, domain_adaptation_task,
                            sample_per_class, src_repetition[i],
                            tgt_repetition[i], batch, epochs, retrain_epochs,
                            alpha)

    results_path = config['outdir']+f'/results_{domain_adaptation_task}'\
                    f'_s{str(sample_per_class)}_r{str((src_repetition, tgt_repetition))}.json'
    save_json(results, results_path)
    # Older parameter sweep kept for reference; its run_trial call uses a
    # different argument list than the current run_trial signature.
    # cnt = 0
    # # for domain_adaptation_task in ['USPS_to_MNIST', 'MNIST_to_USPS']:
    # for domain_adaptation_task in ['MNIST_to_USPS']:
    #     for sample_per_class in range(1, 3):
    #         for repetition in range(1, 3):
    #             cnt += 1
    #             print(cnt)
    #             weight_strategies = ['remove_most_beneficial', 'remove_most_harmful', 'remove_most_influential', 'remove_no_influential']
    #             remove_percs = [5, 20, 40]
    #             run_trial(mode, domain_adaptation_task, sample_per_class, repetition, weight_strategies, method, sampling, remove_percs, batch, epochs, retrain_epochs, alpha, logfile)