def extract_anglediffs_epoch(filedata):
    """Compute a per-key similarity between averaged clean and poison data
    and store the sequence to ``'{astore_file}_epoch.csv'``.

    Keys of ``filedata`` are processed in sorted order. For each key, every
    CSV file listed under ``'clean'`` is parsed row-by-row into a float
    array and the arrays are averaged element-wise across files; the same
    is done for ``'poison'``. When both averages are non-empty, the value is
    computed with ``_compute_cosine_similarly`` (defined elsewhere in the
    module). Otherwise the most recent value (initially ``0.``) is carried
    forward, so every key contributes exactly one output row.

    Args:
        filedata: mapping from a sortable key to a dict that may hold
            ``'clean'`` and/or ``'poison'`` lists of CSV file paths.

    Side effects:
        Writes the rows with ``io.store_to_csv`` and prints the file path.
        ``astore_file`` is resolved from the enclosing (module) scope.
    """
    def _load_averaged(filepaths):
        # parse each csv into a float array, then average the arrays
        # element-wise; an empty array signals "no data" instead of
        # dividing by zero when no files are listed
        arrays = [
            np.array([list(map(float, row)) for row in io.load_from_csv(fpath)])
            for fpath in filepaths
        ]
        if not arrays:
            return np.array([])
        return sum(arrays) / len(arrays)

    # last computed value, reused when a key lacks clean or poison data
    prev_angdiff = 0.

    # result-holder: one single-element row per key
    epoch_anglediffs = []

    # iterate over the csv files
    for datanum in tqdm(sorted(filedata), desc='  [a-extract]', total=len(filedata)):
        cur_filedata = filedata[datanum]

        # : [clean] average of the clean data (empty when not available)
        cur_cdata = _load_averaged(cur_filedata.get('clean', []))

        # : [poison] average of the poison data (empty when not available)
        cur_pdata = _load_averaged(cur_filedata.get('poison', []))

        # : compute the angle between them, or carry the last one forward
        cur_angdiff = prev_angdiff
        if (cur_cdata.size != 0) and (cur_pdata.size != 0):
            cur_angdiff = _compute_cosine_similarly(cur_cdata, cur_pdata)
            prev_angdiff = cur_angdiff

        # : store them to the data-holder
        epoch_anglediffs.append([cur_angdiff])

    # end for datanum...
    astore_csvfile = '{}_epoch.csv'.format(astore_file)
    io.store_to_csv(astore_csvfile, epoch_anglediffs)
    print(' : Angle differences are stored to [{}]'.format(astore_csvfile))
def extract_magnitudes_epoch(filedata):
    """Compute per-key norms of the averaged clean and poison data and store
    them to ``'{mstore_file}_epoch.csv'``.

    Keys of ``filedata`` are processed in sorted order. For each key, every
    CSV file listed under ``'clean'`` is parsed row-by-row into a float
    array, the arrays are averaged element-wise across files, and
    ``np.linalg.norm(avg, 2)`` is taken. For 1-D data this is the vector
    2-norm; for 2-D arrays NumPy returns the largest singular value. The
    same is done for ``'poison'``. When a key has no clean (resp. poison)
    files, the most recent clean (resp. poison) norm is carried forward. The
    clean value starts at ``0.`` and the poison value at ``None``.

    Args:
        filedata: mapping from a sortable key to a dict that may hold
            ``'clean'`` and/or ``'poison'`` lists of CSV file paths.

    Side effects:
        Writes ``[clean_norm, poison_norm]`` rows with ``io.store_to_csv``
        and prints the file path. ``mstore_file`` is resolved from the
        enclosing (module) scope.
    """
    def _load_averaged(filepaths):
        # parse each csv into a float array, then average the arrays
        # element-wise; an empty array signals "no data" instead of
        # dividing by zero when no files are listed
        arrays = [
            np.array([list(map(float, row)) for row in io.load_from_csv(fpath)])
            for fpath in filepaths
        ]
        if not arrays:
            return np.array([])
        return sum(arrays) / len(arrays)

    # last computed norms, reused when a key lacks clean / poison data
    prev_mclean = 0.
    prev_mpoison = None

    # result-holder: one [clean, poison] row per key
    epoch_magnitudes = []

    # iterate over the csv files
    for datanum in tqdm(sorted(filedata), desc='  [m-extract]', total=len(filedata)):
        cur_filedata = filedata[datanum]

        # : [clean] norm of the averaged clean data (carried forward if absent)
        clean_files = cur_filedata.get('clean')
        if clean_files:
            prev_mclean = np.linalg.norm(_load_averaged(clean_files), 2)
        cur_mclean = prev_mclean

        # : [poison] norm of the averaged poison data (carried forward if absent)
        poison_files = cur_filedata.get('poison')
        if poison_files:
            prev_mpoison = np.linalg.norm(_load_averaged(poison_files), 2)
        cur_mpoison = prev_mpoison

        # : store them to the data-holder
        epoch_magnitudes.append([cur_mclean, cur_mpoison])

    # end for datanum...
    mstore_csvfile = '{}_epoch.csv'.format(mstore_file)
    io.store_to_csv(mstore_csvfile, epoch_magnitudes)
    print(' : Magnitudes are stored to [{}]'.format(mstore_csvfile))
            # :: record the best accuracy
            if best_acc < current_acc:
                best_at = epoch
                best_acc = current_acc

            # :: report the current state (cannot compute the total eps, as we split the ....)
            print (' : Epoch {} - acc {:.4f} (base) / {:.4f} (curr) / {:.4f} (best @ {})'.format( \
                epoch, baseline_acc, current_acc, best_acc, best_at))

            # :: flush the stdouts

            # :: info
            print(' : Poison {}, Clean {}'.format(total_pupdates,

    # end for epoch...

    # report the attack results...
    print (' : [Result] epoch {}, poison {}, base {:.4f}, best {:.4f} @ {}'.format( \
        epoch, x_poison.shape[0], baseline_acc, best_acc, best_at))

    # store the attack results
    attack_results = [[best_at, best_acc, baseline_acc, x_poison.shape[0]]]
    io.store_to_csv(results_data, attack_results)

    # finally
    print(' : Done, don\'t store the model')
    # done.
def do_tpoisoning(arguments):
    """Run targeted clean-label poisoning attacks for one task's queue.

    For each poison file (``.pkl``) in ``task_queue``, this function does the
    following:

    - loads the dataset and the poisons;
    - copies the pre-trained checkpoint at ``args.netpath`` into a
      per-target result directory;
    - blends poisons into the training set (``args.attmode`` 'oneshot':
      one poison at a time; 'multipoison': the first N poisons);
    - re-trains for up to ``epochs`` epochs and appends a CSV row to the
      result file whenever the target is predicted as ``args.t_class``.

    Args:
        arguments: a sequence ``(task_num, task_queue, args)``.
            ``task_num`` is only used in log messages. ``task_queue`` is an
            iterable of poison-file paths. ``args`` holds the options read
            below (presumably an argparse namespace; confirm with the caller).

    NOTE(review): this chunk appears to have lost lines. There are
    unterminated ``print(`` / call expressions, missing ``else:``,
    ``continue`` and ``break`` statements, and comment banners whose ``#``
    markers are gone. As shown it does not parse; the comments below flag
    each spot.
    """

    # --------------------------------------------------------------------------
    #   Passed arguments
    # --------------------------------------------------------------------------
    task_num = arguments[0]
    task_queue = arguments[1]
    args = arguments[2]

    # ------------------------------------------------------------
    #  Tensorflow configurations (load TF here, main causes an error)
    # ------------------------------------------------------------
    import tensorflow as tf
    from tensorflow.python.client import device_lib
    from tensorflow.compat.v1.logging import set_verbosity, ERROR
    from tensorflow.compat.v1.estimator.inputs import numpy_input_fn

    # these will load the tensorflow module, so load it here
    from utils import datasets, models

    # control tensorflow info. level
    # ------------------------------------------------------------
    #  Level | Level for Humans | Level Description
    # -------|------------------|---------------------------------
    #  0     | DEBUG            | [Default] Print all messages
    #  1     | INFO             | Filter out INFO messages
    #  2     | WARNING          | Filter out INFO & WARNING messages
    #  3     | ERROR            | Filter out all messages
    # NOTE(review): set_verbosity / ERROR are imported above but never used
    # in this chunk; the call that applies the level seems to be missing.

    # ------------------------------------------------------------
    #  Run control... (for the error cases)
    # ------------------------------------------------------------
    skip_data = True if args.fromtidx else False
    skip_poison = True if (args.frompidx >= 0) else False
    # NOTE(review): the opening 'print(' of this message appears to be lost.
    # The labels also mention target/poison indexes, but the values printed
    # are the skip_data / skip_poison booleans.
        ' : [Task: {}] skip conditions, from: [{}th target] w. [{}th poison]'.
        format(task_num, skip_data, skip_poison))

    # --------------------------------------------------------------------------
    #   Use the sampled dataset, not the entire one
    # --------------------------------------------------------------------------
    if os.path.exists(args.samples):
        # : load the indexes from the csv file (that contains the list of ints)
        sample_indexes = io.load_from_csv(args.samples)[0]
        sample_indexes = list(map(int, sample_indexes))
        print(' : [Task: {}] consider [{}] target sampled from the entirety'.
              format(task_num, len(sample_indexes)))
    # NOTE(review): an 'else:' header seems to be missing here. As written,
    # the next line discards the samples loaded above, and the print call
    # below is unterminated (its arguments are lost).
        sample_indexes = []
        print(' : [Task: {}] do not sample the targets, consider all.'.format(

    # ------------------------------------------------------------
    #  Do poisoning attacks for each case
    # ------------------------------------------------------------
    for each_data in task_queue:
        # NOTE(review): the next line looks like a comment banner that lost
        # its '#' markers.
            Set store locations
        # extract the store location (ex. vanilla_conv.../10.0_2_2000....)
        # keeps the path components from index 4 onward, so it assumes a
        # fixed directory depth for args.poisond (confirm)
        store_dir = args.poisond.split('/')[4:]
        store_dir = '/'.join(store_dir)

        # the target index
        # (last '_'-separated token of the .pkl file name, kept as a str)
        poison_toks = each_data.split('/')
        poison_tkey = poison_toks[-1].replace('.pkl', '')
        poison_tkey = poison_tkey.split('_')[-1]

        # when we use sampling, check if the indexes are in our interest
        if (sample_indexes) \
            and (int(poison_tkey) not in sample_indexes):
                # NOTE(review): the 'print(' opener and a trailing 'continue'
                # appear to be missing around this message.
                ' : [Task: {}][Target: {}] is not in our samples, skip'.format(
                    task_num, poison_tkey))

        # result dir and the file to store
        results_dir = os.path.join('results', 'tpoisoning', 'clean-labels',
                                   args.attmode, store_dir)
        if not os.path.exists(results_dir): os.makedirs(results_dir)
        # NOTE(review): the file-name argument(s) of this os.path.join call
        # are missing.
        result_file = os.path.join(results_dir,
        print(' : [Task: {}][Target: {}] Store the result to [{}]'.format(
            task_num, poison_tkey, result_file))
        # NOTE(review): the next line looks like a banner that lost its '#'.
            Skip the current data, based on the target index
        if (args.fromtidx == poison_tkey): skip_data = False
        if skip_data:
            print(' : [Task: {}][Target: {}] Skip this...'.format(
                task_num, poison_tkey))
            # NOTE(review): presumably a 'continue' belongs here (confirm).
            # The next line is another banner that lost its '#'.
            Load the attack data
        # : load the dataset
        (x_train, y_train), (x_test, y_test) =  \
            datasets.define_dataset(args.dataset, args.datapth)

        # : bound check for the inputs
        # (both splits must span exactly [0, 1])
        assert (x_train.min() == 0.) and (x_train.max() == 1.) \
            and (x_test.min() == 0.) and (x_test.max() == 1.)
        print (' : [Task: {}][Target: {}] Load the dataset [{}] from [{}]'.format( \
            task_num, poison_tkey, args.dataset, args.datapth))

        # : load the poisons
        (x_poisons, y_poisons), (x_target, y_target) = \
            datasets.load_poisons(each_data, x_test, y_test, sort=True)

        # : existence of the poisons
        if (x_poisons.size == 0) or (y_poisons.size == 0):
                # NOTE(review): the 'print(' opener and a trailing 'continue'
                # appear to be missing around this message.
                ' : [Task: {}][Target: {}] Doesn\'t have poisons, skip'.format(
                    task_num, poison_tkey))

        # : bound check for the poisons
        # NOTE(review): this repeats the x_train / x_test check above and never
        # looks at x_poisons; it was probably meant to validate the poisons.
        assert (x_train.min() == 0.) and (x_train.max() == 1.) \
            and (x_test.min() == 0.) and (x_test.max() == 1.)
        print(' : [Task: {}][Target: {}] Load the poisons from [{}]'.format(
            task_num, poison_tkey, each_data))
        # NOTE(review): the next three lines look like a banner that lost its
        # '#' markers.
            Blend poisons and re-train each model
            1) oneshot: consider only one poison at a time
            2) multipoison: consider multiple poisons at a time (0th ~ nth)
        # : condition to stop attack (once the attacker successes on a target)
        stop_attack = False

        # : decide how many poisons to use
        for pidx in range(len(x_poisons)):

            # :: skip, if the attack has been successful
            if stop_attack: continue

            # :: set the poison index (1-based)
            poison_index = pidx + 1

            # :: consider max. the number of poisons specified
            if (args.poisonn > 0) \
                and (poison_index > args.poisonn):
                print (' : [Task: {}][Target: {}][{:>3}] Stop, # of poisons to consider is [{}]'.format( \
                    task_num, poison_tkey, poison_index, args.poisonn))
                # NOTE(review): the message says "Stop", but nothing stops
                # here; a 'break' (or 'continue') seems to be missing.

            # :: skip the current poison, based on the poison index
            if (args.frompidx == poison_index): skip_poison = False
            if skip_poison:
                print(' : [Task: {}][Target: {}][{:>3}] Skip this poison...'.
                      format(task_num, poison_tkey, poison_index))
                # NOTE(review): presumably a 'continue' belongs here (confirm).

            # :: cleanup directories in the previous runs
            _cleanup_directories(results_dir, poison_tkey)

            # :: copy the checkpoint to the result dir.
            result_pmodel = os.path.join(
                results_dir, '{}_{}'.format(poison_tkey, poison_index))
            shutil.copytree(args.netpath, result_pmodel)
            time.sleep(_wait_ops)  # delay for copying files
            print (' : [Task: {}][Target: {}][{:>3}] Copy the clean model to [{}]'.format( \
                task_num, poison_tkey, poison_index, result_pmodel))

            # :: tensorflow runtime configuration
            # NOTE(review): this RunConfig call is unterminated; the remaining
            # keyword arguments appear to be lost.
            cur_rconf = tf.estimator.RunConfig(
                keep_checkpoint_max=1,  # 0 means all, do not use

            # :: extract the basic information from the model location
            # Assumes the 3rd '/'-component of args.netpath is '_'-separated,
            # with batch size, epochs and learning rate at positions 2, 3 and
            # 4 (and DP parameters at 5-8). Confirm against the naming scheme.
            mtokens = args.netpath.split('/')
            mtokens = mtokens[2].split('_')
            batch_size = int(mtokens[2])
            epochs = int(mtokens[3])
            if ('purchases' == args.dataset):
                epochs = epochs // 2
                # NOTE(review): both assignments run for 'purchases' (epochs
                # is halved twice) and neither runs for other datasets; an
                # 'else:' was likely lost. Confirm the intent.
                epochs = 20 if (epochs > 20) else (epochs // 2)
            learn_rate = float(mtokens[4])

            # :: load the pre-trained model
            if not args.privacy:
                cur_model = models._load_vanilla_model( \
                    cur_rconf, \
                    args.dataset, args.datapth, args.network, result_pmodel, \
                    batch_size, learn_rate)
                print (' : [Task: {}][Target: {}][{:>3}] Load the '.format(task_num, poison_tkey, poison_index) + \
                        'pre-trained vanilla model from [{}]'.format(result_pmodel))
            # NOTE(review): an 'else:' header appears to be missing before the
            # privacy branch below.
                # :: extract the extra information about privacy
                epsilon = float(mtokens[5])
                delta = float(mtokens[6])
                norm_clip = float(mtokens[7])
                noises = float(mtokens[8])

                # :: load the privacy model
                cur_model = models._load_dp_model( \
                    cur_rconf, \
                    args.dataset, args.datapth, x_train.shape[0], args.network, result_pmodel, \
                    batch_size, learn_rate, epsilon, delta, norm_clip, noises)
                print (' : [Task: {}][Target: {}][{:>3}] Load the '.format(task_num, poison_tkey, poison_index) + \
                        'pre-trained privacy model from [{}]'.format(result_pmodel))

            # :: blend poisons into the training data
            # ('oneshot' appends only poison #poison_index; 'multipoison'
            #  appends poisons 1..poison_index)
            if 'oneshot' == args.attmode:
                cur_x_train = np.concatenate(
                    (x_train, x_poisons[poison_index - 1:poison_index]),
                cur_y_train = np.concatenate(
                    (y_train, y_poisons[poison_index - 1:poison_index]),
            # NOTE(review): the two concatenate calls above are unterminated
            # (presumably 'axis=0)' was lost, as in the multipoison branch).
            elif 'multipoison' == args.attmode:
                cur_x_train = np.concatenate(
                    (x_train, x_poisons[:poison_index]), axis=0)
                cur_y_train = np.concatenate(
                    (y_train, y_poisons[:poison_index]), axis=0)
            # NOTE(review): an 'else:' header is missing before this assert,
            # and the assert's format call is unterminated.
                assert False, ('Error: unknown attack mode - {}'.format(

            # :: create the estimator functions
            # NOTE(review): all three numpy_input_fn calls are unterminated;
            # their remaining arguments are lost.
            cur_train_fn = numpy_input_fn(x={'x': cur_x_train},
            cur_test_fn = numpy_input_fn(x={'x': x_test},
            cur_target_fn = numpy_input_fn(x={'x': x_target},

            # :: condition to remove the retrained model
            remove_pmodel = True

            # :: to compare the probability changes from the oracle
            # (predictions of the freshly copied model on the target, taken
            #  before any re-training)
            oracle_predict = cur_model.predict(input_fn=cur_target_fn)
            oracle_predict = list(oracle_predict)[0]
            oracle_bas_prob = oracle_predict['probabilities'][args.b_class]
            oracle_tar_prob = oracle_predict['probabilities'][args.t_class]

            # :: re-train the network with the poisoning data
            cur_steps_per_epoch = cur_x_train.shape[0] // batch_size
            for cur_epoch in range(1, epochs + 1):

                # ::: train for an epoch
                cur_model.train( \
                    input_fn=cur_train_fn, steps=cur_steps_per_epoch)

                # ::: evaluate for one instance
                cur_predicts = cur_model.predict(input_fn=cur_target_fn)
                cur_predicts = list(cur_predicts)[0]
                cur_probs = cur_predicts['probabilities']
                cur_bas_prob = cur_predicts['probabilities'][args.b_class]
                cur_tar_prob = cur_predicts['probabilities'][args.t_class]

                # ::: check if we have the successful attack
                if (cur_predicts['classes'] == args.t_class):

                    # > validate the re-trained model
                    cur_predicts = cur_model.evaluate(input_fn=cur_test_fn)
                    cur_accuracy = cur_predicts['accuracy']

                    # > only compute the accuracy (when no privacy)
                    if not args.privacy:
                        # > store the data to a file
                        cur_result = [[poison_tkey, poison_index, \
                                        oracle_bas_prob, oracle_tar_prob, \
                                        cur_bas_prob, cur_tar_prob, \
                                        cur_epoch, cur_accuracy]]
                        io.store_to_csv(result_file, cur_result, mode='a')

                        # > notify
                        # NOTE(review): the '3:' / '4:' labels are hard-coded,
                        # but the values are for args.b_class / args.t_class.
                        print (' : [Task: {}][Target: {}][{:>3}] epoch {} - attack success!'.format( \
                            task_num, poison_tkey, poison_index, cur_epoch))
                        print ('  - Prob [3:{:.4f} / 4:{:.4f}], acc [{:.4f}]'.format( \
                            cur_bas_prob, cur_tar_prob, cur_accuracy), flush=True)

                    # > compute the epsilon (when privacy)
                    # NOTE(review): an 'else:' header appears to be missing here.
                        cur_epsilon = models.compute_epsilon( \
                            cur_epoch * cur_steps_per_epoch, \
                            cur_x_train.shape[0], batch_size, delta, noises)

                        # > store the data to a file
                        cur_result = [[poison_tkey, poison_index, \
                                        oracle_bas_prob, oracle_tar_prob, \
                                        cur_bas_prob, cur_tar_prob, \
                                        cur_epoch, cur_accuracy, cur_epsilon]]
                        io.store_to_csv(result_file, cur_result, mode='a')

                        # > notify
                        # (the reported eps is the model's original epsilon,
                        #  taken from its path, plus the epsilon of this
                        #  re-training)
                        print (' : [Task: {}][Target: {}][{:>3}] epoch {} - attack success!'.format( \
                            task_num, poison_tkey, poison_index, cur_epoch))
                        print ('  - Prob [3:{:.4f} / 4:{:.4f}], acc [{:.4f}], eps [{:.4f} <- {:.4f} + {:.4f}]'.format( \
                            cur_bas_prob, cur_tar_prob, cur_accuracy, cur_epsilon+epsilon, epsilon, cur_epsilon), flush=True)

                    # > stop the attack process (retain model and stop)
                    # NOTE(review): no 'break' follows, so training continues for
                    # the remaining epochs. stop_attack only skips the remaining
                    # poisons via the check at the top of the pidx loop.
                    remove_pmodel = False
                    stop_attack = True

                # ::: if not successful
                # NOTE(review): an 'else:' header seems to be missing. As
                # indented, the lines below run inside the success branch.
                    if (len(cur_probs) > 10): cur_probs = cur_probs[:10]
                    print (' : [Task: {}][Target: {}][{:>3}] epoch {} - attack fail, keep going - Prob [3:{:.4f} / 4:{:.4f}] - {}'.format( \
                        task_num, poison_tkey, poison_index, cur_epoch, cur_bas_prob, cur_tar_prob, cur_probs), flush=True)
                # ::: end if (cur_accuracy...

            # :: end for epoch...

            # :: remove model if it's true
            if remove_pmodel:
                shutil.rmtree(result_pmodel, ignore_errors=True)
                print(' : [Task: {}][Target: {}] Attack failed, remove [{}]'.
                      format(task_num, poison_tkey, result_pmodel))

            # :: reset the tensorflow graph for another run
            # NOTE(review): the statement that performs this reset appears to
            # be missing.

        # : end for pidx...
    # end for aidx...

    print(' : [Task: {}] finished'.format(task_num))