예제 #1
0
파일: dataset.py 프로젝트: 9yte/VenoMave
def _make_split_dirs(split_root):
    """Create and return the ``X``, ``Y``, ``text`` and ``wavs`` dirs under *split_root*.

    Uses ``mkdir(parents=True)`` without ``exist_ok``, so it raises
    ``FileExistsError`` if any of them already exists.
    """
    out_dirs = tuple(
        split_root.joinpath(name) for name in ('X', 'Y', 'text', 'wavs'))
    for out_dir in out_dirs:
        out_dir.mkdir(parents=True)
    return out_dirs


def preprocess_dataset(model_type,
                       data_dir,
                       feature_parameters,
                       device='cuda'):
    """Pre-process the raw corpus into a ``plain`` and an ``aligned`` dataset.

    Creates two datasets under ``<data_dir>/<model_type>/``:
        - ``plain`` is simply a pre-processed version of TIDIGITS; its frame
          targets come from the TextGrid alignments.
        - ``aligned`` replaces the targets Y with more precise targets
          (obtained via viterbi training on ``plain/TRAIN``).

    Expected raw layout: ``<data_dir>/raw/<split>/{wav,TextGrid,lab}/``. Each
    utterance has one ``.wav``, ``.TextGrid`` and ``.lab`` file with the same
    stem. A ``TRAIN`` split is required for step 2. Each output split gets
    ``X``, ``Y``, ``text`` and ``wavs`` directories. Each dataset root gets a
    pickled HMM in ``hmm.h5``.

    If both ``hmm.h5`` files already exist, the function returns immediately.
    Otherwise both output directories are deleted and rebuilt from scratch.

    Args:
        model_type: passed to ``init_model``; also the output sub-directory.
        data_dir: dataset root (str or Path).
        feature_parameters: dict providing at least ``sampling_rate``,
            ``window_size`` and ``hop_size``. The latter two are converted to
            samples via ``tools`` and are presumably given in seconds.
        device: device the raw waveforms and targets are moved to while loading.

    Raises:
        AssertionError: if ``<data_dir>/raw`` is missing, or if the
            wav/TextGrid/lab file stems do not line up.
    """

    def load_raw_data_dir(dataset_dir, device='cuda'):
        """Load waveforms, HMM frame targets and transcriptions of one raw split.

        Reads ``sampling_rate``, ``window_size_samples``, ``hop_size_samples``
        and ``hmm`` from the enclosing scope. It may therefore only be called
        after those names are bound.
        """
        dataset_dir = dataset_dir.resolve()  # To resolve symlinks!
        # find raw data; the three file lists are paired by sorted order
        wav_files = sorted(dataset_dir.joinpath('wav').resolve().glob('*.wav'))
        praat_files = sorted(
            dataset_dir.joinpath('TextGrid').resolve().glob('*.TextGrid'))
        lab_files = sorted(dataset_dir.joinpath('lab').resolve().glob('*.lab'))

        # load raw data
        X, Y, texts = [], [], []
        for wav_file, praat_file, lab_file in tqdm(
                zip(wav_files, praat_files, lab_files),
                total=len(wav_files),
                bar_format='    load raw     {l_bar}{bar:30}{r_bar}'):
            # sanity check: the three files must describe the same utterance
            assert wav_file.stem == praat_file.stem == lab_file.stem
            ## load x, trimmed down to a whole number of hops
            x, _ = torchaudio.load(wav_file)
            num_frames = np.floor(x.shape[1] / hop_size_samples)
            x = x[:, :int(num_frames * hop_size_samples)].to(device)
            X.append(x)
            ## load y: per-frame HMM target states from the Praat TextGrid
            y = tools.praat_file_to_target(praat_file, sampling_rate,
                                           window_size_samples,
                                           hop_size_samples, hmm)
            Y.append(torch.from_numpy(y).to(device))
            ## load text
            texts.append(lab_file.read_text().strip())
        return wav_files, X, Y, texts

    # check if data dir exist
    raw_data_dir = Path(data_dir).joinpath('raw')
    assert raw_data_dir.is_dir()

    # data config
    sampling_rate = feature_parameters['sampling_rate']
    window_size_samples = tools.next_pow2_samples(
        feature_parameters['window_size'], sampling_rate)
    hop_size_samples = tools.sec_to_samples(feature_parameters['hop_size'],
                                            sampling_rate)

    # check if dataset is already pre-processed
    plain_out_dir = Path(data_dir).joinpath(model_type, 'plain')
    aligned_out_dir = Path(data_dir).joinpath(model_type, 'aligned')
    if (plain_out_dir.joinpath('hmm.h5').is_file()
            and aligned_out_dir.joinpath('hmm.h5').is_file()):
        logging.info("[+] Dataset already pre-processed")
        return
    for out_dir in (plain_out_dir, aligned_out_dir):
        shutil.rmtree(out_dir, ignore_errors=True)
        out_dir.mkdir(parents=True)

    # Step 1: plain data
    # -> each wav is trimmed to whole frames (the Xs)
    # -> each frame is mapped to the corresponding target state
    #    of the hmm (the Ys)
    #
    # As these targets always depend on a particular hmm,
    # we save the hmm alongside the data
    hmm = HMM.HMM('word')
    with plain_out_dir.joinpath('hmm.h5').open('wb') as hmm_file:
        pickle.dump(hmm, hmm_file)

    # the raw splits (e.g. TRAIN / TEST); processed in both steps
    dataset_names = [d.name for d in raw_data_dir.glob('*') if d.is_dir()]

    # pre-process plain data
    for dataset_name in dataset_names:
        logging.info(f"[+] Pre-process {dataset_name}")
        wav_files, X, Y, texts = load_raw_data_dir(
            raw_data_dir.joinpath(dataset_name), device=device)
        ## dump plain
        X_out_dir, Y_out_dir, text_out_dir, wav_out_dir = _make_split_dirs(
            plain_out_dir.joinpath(dataset_name))
        for wav_file, x, y, text in tqdm(
                zip(wav_files, X, Y, texts),
                total=len(wav_files),
                bar_format='    dump plain  {l_bar}{bar:30}{r_bar}'):
            filename = wav_file.stem
            torch.save(y, Y_out_dir.joinpath(filename).with_suffix('.pt'))
            torch.save(x, X_out_dir.joinpath(filename).with_suffix('.pt'))
            text_out_dir.joinpath(filename).with_suffix('.txt').write_text(
                text)
            shutil.copyfile(wav_file,
                            wav_out_dir.joinpath(filename).with_suffix('.wav'))

    # Step 2: align data
    # -> for the plain data we only used relatively vague alignments between
    #    input frame and target
    # -> to improve this we create a second dataset that uses a hmm
    #    that is trained with viterbi to obtain more precise alignments

    # first we need to train the hmm with viterbi training
    dataset = Dataset(plain_out_dir.joinpath('TRAIN'), feature_parameters)
    model = init_model(model_type, feature_parameters, hmm)
    model.train_model(dataset, epochs=12, batch_size=32)
    model.train_model(dataset, epochs=1, batch_size=32, viterbi_training=True)
    model.hmm.A = hmm.modifyTransitions(model.hmm.A_count)
    model.train_model(dataset, epochs=2, batch_size=32, viterbi_training=True)
    # again, save hmm alongside the data
    # NOTE(review): this pickles `hmm`, not `model.hmm`. That is only the
    # retrained HMM if init_model keeps a reference to the same object;
    # confirm.
    with aligned_out_dir.joinpath('hmm.h5').open('wb') as hmm_file:
        pickle.dump(hmm, hmm_file)

    # pre-process aligned data (re-labelled from the plain copies)
    for dataset_name in dataset_names:
        logging.info(f"[+] Pre-process {dataset_name}")
        dst_path = plain_out_dir.joinpath(dataset_name)
        dataset = Dataset(dst_path, feature_parameters)
        ## dump aligned
        X_out_dir, Y_out_dir, text_out_dir, wav_out_dir = _make_split_dirs(
            aligned_out_dir.joinpath(dataset_name))
        # Progress-bar total: the utterances of *this* split, whose wavs were
        # copied into the plain dataset during step 1.
        num_utterances = sum(
            1 for _ in dst_path.joinpath('wavs').glob('*.wav'))

        with tqdm(
                total=num_utterances,
                bar_format='    dump aligned {l_bar}{bar:30}{r_bar}') as pbar:
            for X_batch, Y_batch, texts_batch, y_true_length, x_true_length, filenames in dataset.generator(
                    return_filename=True, batch_size=32, return_x_length=True):

                # re-derive the targets from the trained model's posteriors
                posteriors = model.features_to_posteriors(X_batch)
                Y_batch = hmm.viterbi_train(posteriors, y_true_length, Y_batch,
                                            texts_batch)

                for filename, x, y, y_length, x_length, text in zip(
                        filenames, X_batch, Y_batch, y_true_length,
                        x_true_length, texts_batch):
                    # cut each sample back to its true length (presumably
                    # removing batch padding) before saving
                    torch.save(y.clone()[:y_length],
                               Y_out_dir.joinpath(filename).with_suffix('.pt'))
                    torch.save(x.clone()[:x_length].unsqueeze(dim=0),
                               X_out_dir.joinpath(filename).with_suffix('.pt'))
                    text_out_dir.joinpath(filename).with_suffix(
                        '.txt').write_text(text)
                    shutil.copyfile(
                        dst_path.joinpath('wavs',
                                          filename).with_suffix('.wav'),
                        wav_out_dir.joinpath(filename).with_suffix('.wav'))
                    pbar.update(1)
예제 #2
0
파일: eval.py 프로젝트: 9yte/VenoMave
def eval_victim(model_type, feature_parameters, dataset, dataset_test, target,
                dropout):
    """Train one victim model on the (poisoned) dataset and report attack metrics.

    The victim is built with ``init_model`` on ``dataset.hmm``, moved to CUDA
    and trained for 15 epochs. It is then measured on:
      * clean accuracy on ``dataset_test`` (``model.parallel_test``),
      * classification loss of the poisons (all frames and important indices),
      * loss of the target utterance w.r.t. its original and adversarial states,
      * accuracy (in %) of the argmax states on ``target.adv_indices``,
      * the bullseye loss (evaluation only, no gradients).
    A human-readable summary is printed to stdout.

    Returns:
        ``(model, res)``. ``res`` holds the metrics above, except the clean
        accuracy, which is only printed. It also holds the decoded prediction
        (``model_pred``) and the second value returned by
        ``model.parallel_test`` (``test_res``).
    """
    model = init_model(model_type,
                       feature_parameters,
                       dataset.hmm,
                       dropout=dropout).cuda()
    model.train_model(dataset, epochs=15)

    # benign accuracy of victim model
    model_acc, test_res = model.parallel_test(dataset_test)

    # Poisons Classification Loss.
    poisons_x, poisons_y, poisons_true_length, poisons_imp_indices, _ = \
        dataset.poisons.get_all_poisons()
    poisons_classification_loss, poisons_imp_indices_classification_loss = \
        model.compute_loss_batch(poisons_x, poisons_y, poisons_true_length,
                                 important_indices=poisons_imp_indices)

    # similarity to original and target states
    loss_original_states = model.compute_loss_single(
        target.x, target.original_states).item()
    loss_adversarial_states = model.compute_loss_single(
        target.x, target.adversarial_states, target.adv_indices).item()

    # predicted transcription plus the viterbi and argmax state sequences
    posteriors = model.features_to_posteriors(target.x)
    pred_phoneme_seq, victim_states_viterbi = model.hmm.posteriors_to_words(
        posteriors)
    pred_phoneme_seq = tools.str_to_digits(pred_phoneme_seq)
    victim_states_argmax = np.argmax(posteriors, axis=1)

    # target states
    target_states = target.adversarial_states

    # percentage of adversarial frames whose argmax state equals the target state
    victim_adv_states_acc = (100.0 * sum([
        v == t for v, t in zip(victim_states_argmax[target.adv_indices],
                               target_states[target.adv_indices].tolist())
    ])) / len(target.adv_indices)

    # Bullseye Loss (just for evaluation!)
    from craft_poisons import bullseye_loss
    victim_bullseye_loss = bullseye_loss(target,
                                         dataset.poisons, [model],
                                         compute_gradients=False)

    def _rows(states, width=28):
        """Split *states* into printable rows of at most *width* entries."""
        return [states[i:i + width] for i in range(0, len(states), width)]

    # logging
    print('')
    print(
        f'    -> loss original states              : {loss_original_states:6.4f}'
    )
    print(
        f'    -> loss adversarial states           : {loss_adversarial_states:6.4f}'
    )
    print(f'    -> clean accuracy                    : {model_acc:6.4f}')
    print(
        f'    -> poisons cls. loss                 : {poisons_classification_loss.item():6.4f}'
    )
    print(
        f'    -> imp. indices poisons cls. loss    : {poisons_imp_indices_classification_loss.item():6.4f}'
    )
    print(
        f'    -> bullseye loss                     : {victim_bullseye_loss:6.4f}'
    )
    print(
        f'    -> adv. states acc.                  : {victim_adv_states_acc:6.4f}'
    )
    print(
        f"    -> model decoded                     : {', '.join([f'{p:>3}' for p in pred_phoneme_seq])}"
    )
    print(
        f"    -> target label                      : {', '.join([f'{p:>3}' for p in target.target_transcription])}"
    )
    print(f"    -> model output\n")
    for original_seq, target_seq, victim_argmax_seq, victim_viterbi_seq in zip(
            _rows(target.original_states.tolist()),
            _rows(target_states.tolist()),
            _rows(victim_states_argmax),
            _rows(victim_states_viterbi)):
        print("       " + "| ORIGINAL  " +
              " ".join([f'{x:2}' for x in original_seq]))
        print("       " + "| TARGET    " +
              " ".join([f'{x:2}' for x in target_seq]))
        print("       " + "| ARGMAX    " +
              " ".join([f'{x:2}' for x in victim_argmax_seq]))
        print("       " + "| VITERBI   " +
              " ".join([f'{x:2}' for x in victim_viterbi_seq]))
        print('')

    # the clean accuracy (model_acc) is intentionally only printed, not stored
    res = {
        "loss_original_states": loss_original_states,
        "loss_adversarial_states": loss_adversarial_states,
        "poisons_classification_loss": poisons_classification_loss.item(),
        "poisons_imp_indices_classification_loss":
        poisons_imp_indices_classification_loss.item(),
        "bullseye_loss": victim_bullseye_loss,
        "adv_states_acc": victim_adv_states_acc,
        "model_pred": "".join([str(p) for p in pred_phoneme_seq]),
        "test_res": test_res,
    }

    return model, res
예제 #3
0
def eval_victim(model_type,
                feature_parameters,
                dataset,
                dataset_test,
                target,
                repeat_evaluation_num=1,
                dropout=0.0):
    """Train and evaluate up to ``repeat_evaluation_num`` victim models.

    Victim ``i`` (1-based) is trained from scratch on ``dataset`` with seed
    ``202020 + i``. It is evaluated on clean accuracy, poison classification
    loss, target losses, adversarial-state accuracy and bullseye loss, and a
    summary is written with ``logging.info``. Evaluation stops at the first
    victim whose decoded transcription does not match
    ``target.target_transcription``: further victims cannot make the attack
    successful.

    Returns:
        ``(loss_adversarial_states, model_acc, victim_adv_states_acc, res,
        success)``.
            - The first three values refer to the last victim evaluated.
            - ``res`` maps each evaluation index to that victim's metrics dict.
            - ``success`` is True only if every evaluated victim produced the
              target transcription.

    Raises:
        ValueError: if ``repeat_evaluation_num`` is smaller than 1, since no
            victim would be evaluated and there would be nothing to return.
    """
    if repeat_evaluation_num < 1:
        raise ValueError(
            f"repeat_evaluation_num must be >= 1, got {repeat_evaluation_num}")

    def _rows(states, width=28):
        """Split *states* into printable rows of at most *width* entries."""
        return [states[i:i + width] for i in range(0, len(states), width)]

    res = {}
    for eval_idx in range(1, repeat_evaluation_num + 1):
        logging.info("[+] Evaluating the victim - {}".format(eval_idx))

        # init network (deterministic, but different seed per victim)
        tools.set_seed(202020 + eval_idx)
        model = init_model(model_type,
                           feature_parameters,
                           dataset.hmm,
                           dropout=dropout)
        model.train_model(dataset)

        # benign accuracy of victim model
        model_acc = model.test(dataset_test)

        # Poisons Classification Loss.
        poisons_x, poisons_y, poisons_true_length, poisons_imp_indices, _ = \
            dataset.poisons.get_all_poisons()
        poisons_classification_loss, poisons_imp_indices_classification_loss = \
            model.compute_loss_batch(poisons_x, poisons_y, poisons_true_length,
                                     important_indices=poisons_imp_indices)

        # similarity to original and target states
        loss_original_states = model.compute_loss_single(
            target.x, target.original_states).item()
        loss_adversarial_states = model.compute_loss_single(
            target.x, target.adversarial_states, target.adv_indices).item()

        # predicted transcription plus the viterbi and argmax state sequences
        posteriors = model.features_to_posteriors(target.x)
        pred_phoneme_seq, victim_states_viterbi = dataset.hmm.posteriors_to_words(
            posteriors)
        pred_phoneme_seq = tools.str_to_digits(pred_phoneme_seq)
        victim_states_argmax = np.argmax(posteriors, axis=1)

        # target states
        target_states = target.adversarial_states

        # percentage of adversarial frames whose argmax state equals the target
        victim_adv_states_acc = (100.0 * sum([
            v == t for v, t in zip(victim_states_argmax[target.adv_indices],
                                   target_states[target.adv_indices].tolist())
        ])) / len(target.adv_indices)

        # Bullseye Loss (just for evaluation!)
        victim_bullseye_loss = bullseye_loss(target,
                                             dataset.poisons, [model],
                                             compute_gradients=False)

        # logging
        logging.info('')
        logging.info(
            f'    -> loss original states              : {loss_original_states:6.4f}'
        )
        logging.info(
            f'    -> loss adversarial states           : {loss_adversarial_states:6.4f}'
        )
        logging.info(
            f'    -> clean accuracy                    : {model_acc:6.4f}')
        logging.info(
            f'    -> poisons cls. loss                 : {poisons_classification_loss.item():6.4f}'
        )
        logging.info(
            f'    -> imp. indices poisons cls. loss    : {poisons_imp_indices_classification_loss.item():6.4f}'
        )
        logging.info(
            f'    -> bullseye loss                     : {victim_bullseye_loss:6.4f}'
        )
        logging.info(
            f'    -> adv. states acc.                  : {victim_adv_states_acc:6.4f}'
        )
        logging.info(
            f"    -> model decoded                     : {', '.join([f'{p:>3}' for p in pred_phoneme_seq])}"
        )
        logging.info(
            f"    -> target label                      : {', '.join([f'{p:>3}' for p in target.target_transcription])}"
        )
        logging.info(f"    -> model output\n")
        for original_seq, target_seq, victim_argmax_seq, victim_viterbi_seq in zip(
                _rows(target.original_states.tolist()),
                _rows(target_states.tolist()),
                _rows(victim_states_argmax),
                _rows(victim_states_viterbi)):
            logging.info("       " + "| ORIGINAL  " +
                         " ".join([f'{x:2}' for x in original_seq]))
            logging.info("       " + "| TARGET    " +
                         " ".join([f'{x:2}' for x in target_seq]))
            logging.info("       " + "| ARGMAX    " +
                         " ".join([f'{x:2}' for x in victim_argmax_seq]))
            logging.info("       " + "| VITERBI   " +
                         " ".join([f'{x:2}' for x in victim_viterbi_seq]))
            logging.info('')

        res[eval_idx] = {
            "loss_original_states": loss_original_states,
            "loss_adversarial_states": loss_adversarial_states,
            "model_clean_test_acc": model_acc,
            "poisons_classification_loss": poisons_classification_loss.item(),
            "poisons_imp_indices_classification_loss":
            poisons_imp_indices_classification_loss.item(),
            "bullseye_loss": victim_bullseye_loss,
            "adv_states_acc": victim_adv_states_acc,
            "model_pred": "".join([str(p) for p in pred_phoneme_seq]),
        }

        # if not successful we do not have to eval more victim networks
        if transcription2string(pred_phoneme_seq) != transcription2string(
                target.target_transcription):
            return loss_adversarial_states, model_acc, victim_adv_states_acc, res, False

    return loss_adversarial_states, model_acc, victim_adv_states_acc, res, True
예제 #4
0
파일: eval.py 프로젝트: 9yte/VenoMave
                       poisoned_dataset_path.joinpath("raw", "TEST"))

            # This only generates files in the plain folder, aligned based on a new HMM.
            preprocess(poisoned_dataset_path, feature_parameters)

            dataset = Dataset(poisoned_dataset_path.joinpath('plain', 'TRAIN'),
                              feature_parameters)
            dataset_test = Dataset(
                poisoned_dataset_path.joinpath('plain', 'TEST'),
                feature_parameters)
            dataset.poisons = Poisons(
                poisoned_dataset_path.joinpath('plain', 'TRAIN'), poison_paths,
                attack_dir.joinpath("poisons").with_suffix(".json"))
            # Training the model
            model = init_model(params.model_type,
                               feature_parameters,
                               dataset.hmm,
                               dropout=params.dropout)

            model.train_model(dataset, epochs=10, batch_size=params.batch_size)
            model.train_model(dataset,
                              epochs=1,
                              batch_size=params.batch_size,
                              viterbi_training=True)
            model.hmm.A = model.hmm.modifyTransitions(model.hmm.A_count)
            model.train_model(dataset,
                              epochs=1,
                              batch_size=params.batch_size,
                              viterbi_training=True)
            model.train_model(dataset,
                              epochs=1,
                              batch_size=params.batch_size,
예제 #5
0
def _append_victim_metrics(losses_dict, adv_loss, clean_acc, adv_states_acc):
    """Append one victim evaluation, rounded to 4 decimals, to *losses_dict*."""
    losses_dict['victim_adv_losses'].append(np.round(adv_loss, 4))
    losses_dict['victim_adv_word_accs'].append(np.round(adv_states_acc, 4))
    losses_dict['victim_clean_accs'].append(np.round(clean_acc, 4))


def _target_features(models, target_x, mult_draw):
    """Return the detached ``forward(..., penu=True)`` features of *target_x*, one per model.

    When ``mult_draw > 1``, the features of that many forward passes are
    averaged. This smooths out dropout noise when dropout stays enabled at
    test time.
    """
    features = []
    for model in models:
        draws = [
            model.forward(target_x, penu=True).squeeze().detach()
            for _ in range(mult_draw)
        ]
        features.append(draws[0] if mult_draw == 1 else sum(draws) / len(draws))
    return features


def _no_progress(current, previous, rel_tol=0.0001):
    """Return True if *current* changed by at most *rel_tol* relative to *previous*.

    A zero *previous* loss is handled explicitly instead of dividing by zero:
    an unchanged loss (0 -> 0) counts as no progress, and any change counts
    as progress.
    """
    if previous == 0:
        return current == 0
    return abs((current - previous) / previous) <= rel_tol


def craft_poisons(model_type,
                  data_dir,
                  exp_dir,
                  target_filename,
                  adv_label,
                  seed,
                  feature_parameters,
                  poison_parameters,
                  device,
                  victim_hmm_seed=123456):
    """Optimize poison samples so that victims trained on them mis-transcribe a target.

    Pipeline:
      1. Pre-process ``data_dir`` (``preprocess_dataset``). Load the ``aligned``
         TRAIN split and a ``subset=10`` view of the TEST split.
      2. Build the ``Target`` and the ``Poisons``, and store the poison info in
         ``exp_dir/poisons.json``.
      3. Evaluate a victim on the unmodified poison bases (step 0).
      4. For each outer step:
           - retrain ``num_models`` surrogate models with step-specific seeds;
           - run up to ``inner_crafting_steps`` Adam updates on the poisons,
             minimising ``bullseye_loss`` and clipping with ``eps`` after each
             update;
           - save the poisons to ``exp_dir/<step>`` and evaluate 3 victims;
           - stop early once ``eval_victim`` reports success.
      5. Log the loss histories and dump the per-step results to
         ``exp_dir/logs.json``.

    Args:
        model_type: model identifier passed to ``init_model``.
        data_dir: dataset root (str or Path).
        exp_dir: experiment output directory (str or Path).
        target_filename, adv_label: forwarded to ``Target``.
        seed: base seed for data selection and surrogate-model seeds.
        feature_parameters: feature config. Reads ``sampling_rate``,
            ``left_context`` and ``right_context`` here, plus whatever
            ``preprocess_dataset`` needs.
        poison_parameters: reads ``eps``, ``dropout``, ``num_models``,
            ``outer_crafting_steps`` and ``inner_crafting_steps`` here, and is
            also passed to ``Poisons``.
        device: device for the target and the surrogate models.
        victim_hmm_seed: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if outer crafting steps are requested but
            ``inner_crafting_steps`` is smaller than 1.
    """
    epsilon = poison_parameters['eps']
    dropout = poison_parameters['dropout']
    outer_steps = poison_parameters['outer_crafting_steps']
    # Each outer step needs at least one inner step, otherwise the step loss
    # that is logged and stored below would never be computed.
    if outer_steps > 0 and poison_parameters['inner_crafting_steps'] < 1:
        raise ValueError(
            "poison_parameters['inner_crafting_steps'] must be >= 1, got "
            f"{poison_parameters['inner_crafting_steps']}")

    # accept both str and Path for the directories
    data_dir = Path(data_dir)
    exp_dir = Path(exp_dir)

    tools.set_seed(seed)

    # load dataset
    preprocess_dataset(model_type, data_dir, feature_parameters)
    dataset = Dataset(data_dir.joinpath(model_type, 'aligned', 'TRAIN'),
                      feature_parameters)
    dataset_test = Dataset(data_dir.joinpath(model_type, 'aligned', 'TEST'),
                           feature_parameters,
                           subset=10)

    # select target and find poisons
    target = Target(data_dir.joinpath(model_type), target_filename, adv_label,
                    feature_parameters, dataset, device)
    dataset.poisons = Poisons(poison_parameters, dataset, target,
                              (feature_parameters['left_context'],
                               feature_parameters['right_context']))

    # save the poisons info - for future reference!
    dataset.poisons.save_poisons_info(
        exp_dir.joinpath("poisons").with_suffix(".json"))

    losses_dict = defaultdict(list)

    # First, let's test the victim against the original poison base samples
    victim_adv_loss, victim_clean_acc, victim_adv_states_acc, victim_res, _ = \
        eval_victim(model_type, feature_parameters, dataset, dataset_test,
                    target, dropout=dropout)
    _append_victim_metrics(losses_dict, victim_adv_loss, victim_clean_acc,
                           victim_adv_states_acc)

    # Optimizer over the poison inputs; the lr is halved at each decay step.
    opt = torch.optim.Adam(dataset.poisons.X, lr=0.0004, betas=(0.9, 0.999))
    decay_steps = [10, 20, 30, 50]  # [10, 30, 50, 80]
    decay_ratio = 0.5

    # With dropout enabled, phi(target) is averaged over several draws.
    mult_draw = 20 if dropout > 0.0 else 1

    res = {0: victim_res}
    for step in range(1, outer_steps + 1):
        res[step] = {}
        cur_time = time.strftime("%Y-%m-%d %H:%M:%S")
        logging.info('-' * 50)
        logging.info(f'{cur_time}: Step {step} of crafting poisons')
        res[step]['time'] = {'start': cur_time}

        # adjust learning rate
        if step in decay_steps:
            for param_group in opt.param_groups:
                param_group['lr'] *= decay_ratio
            logging.info(f'[+] Adjust lr to {opt.param_groups[-1]["lr"]:.2e}')

        # Now, let's refresh the models! (fresh surrogates every outer step,
        # each with its own reproducible seed)
        logging.info(
            f'''[+] Train/refresh {poison_parameters['num_models']} models''')
        models = []
        for m in range(poison_parameters['num_models']):
            tools.set_seed(
                tools.get_seed(model_idx=m, crafting_step=step, init_seed=seed))
            model = init_model(model_type,
                               feature_parameters,
                               dataset.hmm,
                               device=device,
                               dropout=dropout,
                               test_dropout_enabled=True)
            model.train_model(dataset)
            models.append(model)

        logging.info(f'[+] Optimizing the poisons')

        # The models are fixed during the inner optimization, so phi(target)
        # is computed only once per outer step.
        phi_x_target_models = _target_features(models, target.x, mult_draw)

        last_inner_step_loss = None
        with tqdm(total=poison_parameters['inner_crafting_steps'],
                  bar_format='    {l_bar}{bar:30}{r_bar}') as pbar:
            for inner_step in range(
                    1, poison_parameters['inner_crafting_steps'] + 1):
                opt.zero_grad()

                # bullseye_loss performs the backward pass itself
                inner_step_loss = bullseye_loss(
                    target,
                    dataset.poisons,
                    models,
                    compute_gradients=True,
                    phi_x_target_models=phi_x_target_models)

                if inner_step == 1:
                    res[step]['subs_bullseye_loss'] = {
                        'start': inner_step_loss
                    }

                opt.step()

                pbar.set_description(f'Bullseye Loss: {inner_step_loss:6.4f}')
                pbar.update(1)

                # clip the updated poisons with the epsilon budget
                dataset.poisons.clip(epsilon=epsilon)

                if last_inner_step_loss is not None and _no_progress(
                        inner_step_loss, last_inner_step_loss):
                    # We are not making much progress in decreasing the
                    # bullseye loss. Let's take a break
                    break

                last_inner_step_loss = inner_step_loss

        logging.info(f'Bullseye Loss: {inner_step_loss:6.4f}')
        res[step]['subs_bullseye_loss']['end'] = inner_step_loss
        res[step]['subs_bullseye_loss']['inner_step'] = inner_step

        # append bullseye loss to history
        losses_dict['step_losses'].append(np.round(inner_step_loss, 4))

        step_dir = exp_dir.joinpath(f"{step}")
        step_dir.mkdir(parents=True)
        dataset.poisons.calc_snrseg(step_dir,
                                    feature_parameters['sampling_rate'])
        dataset.poisons.save(step_dir, feature_parameters['sampling_rate'])
        logging.info(f"Step {step} - Poisons saved at {step_dir}")

        res[step]['time']['end'] = time.strftime("%Y-%m-%d %H:%M:%S")

        # Now, let's test against the victim model (after every step)
        victim_adv_loss, victim_clean_acc, victim_adv_states_acc, res[step]['victim'], successful = \
            eval_victim(model_type, feature_parameters, dataset, dataset_test,
                        target, repeat_evaluation_num=3, dropout=dropout)
        _append_victim_metrics(losses_dict, victim_adv_loss, victim_clean_acc,
                               victim_adv_states_acc)

        if successful:
            logging.info(
                "Early stopping of the attack after {} steps".format(step))
            break

    logging.info("Bullseye Losses (Substitute networks): \n{}".format(
        losses_dict['step_losses']))
    logging.info("+" * 20)
    logging.info("Victim adv losses: \n{}".format(
        losses_dict['victim_adv_losses']))
    logging.info("+" * 20)
    logging.info("Victim adv states accs: \n{}".format(
        losses_dict['victim_adv_word_accs']))
    logging.info("+" * 20)
    logging.info("Victim clean accs: \n{}".format(
        losses_dict['victim_clean_accs']))
    logging.info("+" * 20)

    with open(exp_dir.joinpath("logs.json"), 'w') as f:
        json.dump(res, f)