# Example No. 1
0
def _percent_correct(label_pred, label_act, num_expected):
    """Return the percentage of positions where ``label_pred`` equals ``label_act``.

    Args:
        label_pred: 1-D array of predicted class indices.
        label_act: 1-D array of true class indices.
        num_expected: Number of entries both arrays must contain.

    Raises:
        AssertionError: ('Major calculation mistake!') if either array is not
            1-D with exactly ``num_expected`` entries.  Raised explicitly
            instead of via ``assert`` so the check survives ``python -O``.
    """
    label_pred = np.asarray(label_pred)
    label_act = np.asarray(label_act)
    if (label_pred.shape != (num_expected,)
            or label_act.shape != (num_expected,)):
        raise AssertionError('Major calculation mistake!')
    return 100. * np.count_nonzero(label_pred == label_act) / num_expected


def test_experiments(architecture, config, num_days, seed_days, seed_test_day,
                     experiment_setup, testing_setup):
    """Evaluate a trained checkpoint on the (optionally augmented) test set.

    The preprocessed dataset is reloaded, the simulation / equalization / CFO
    steps selected in ``experiment_setup`` are re-applied, and the checkpoint
    name is rebuilt from those steps.  The model
    ``<exp_dir>/CFO_channel_experiments/ckpt-<data_format>.h5`` is then loaded
    and its test accuracy reported.  When test-time augmentation enlarges the
    test set, accuracy is reported for 0 augmentations, 1 augmentation, and
    all augmentations (summed logits and summed softmax outputs).

    Args:
        architecture: Model identifier appended to the checkpoint name.
        config: Experiment configuration dict (data location, sampling
            parameters, channel/CFO simulation and augmentation settings).
        num_days: Passed as ``phy_method`` to the physical channel and CFO
            simulation helpers.
        seed_days: Seed(s) of the training-side simulation; ``np.max`` of it
            is used when deriving the test augmentation seed.
        seed_test_day: Seed of the test-side simulation.
        experiment_setup: Flags of the training run (``add_channel``,
            ``add_cfo``, ``remove_cfo``, ``equalize_*``, ``augment_*``,
            ``obtain_residuals``).  They select the processing applied here
            and determine the checkpoint name.
        testing_setup: Test-time flags ``equalize_test``, ``remove_test_cfo``,
            ``augment_test_channel`` and ``augment_test_cfo``.

    Returns:
        Tuple ``(output_dict, num_test_aug // num_test, softmax_test_new,
        softmax_test_augs, label_act)``.  ``output_dict['acc']`` holds the
        accuracies in percent.  ``softmax_test_new`` is the softmax output
        summed over augmentations (or the single prediction).
        ``softmax_test_augs`` stacks the per-augmentation softmax outputs.
        ``label_act`` is the argmax of the ``y_test`` array read from disk.
    """
    # -------------------------------------------------
    # Data configuration
    # -------------------------------------------------
    exp_dir = config['exp_dir']
    sample_duration = config['sample_duration']
    preprocess_type = config['preprocess_type']
    sample_rate = config['sample_rate']

    # -------------------------------------------------
    # Equalization before any preprocessing
    # -------------------------------------------------
    equalize_train_before = experiment_setup['equalize_train_before']
    equalize_test_before = experiment_setup['equalize_test_before']

    # -------------------------------------------------
    # Physical Channel Parameters
    # -------------------------------------------------
    add_channel = experiment_setup['add_channel']

    phy_method = num_days
    seed_phy_train = seed_days
    seed_phy_test = seed_test_day
    channel_type_phy_train = config['channel_type_phy_train']
    channel_type_phy_test = config['channel_type_phy_test']

    # -------------------------------------------------
    # Physical CFO parameters
    # -------------------------------------------------
    add_cfo = experiment_setup['add_cfo']
    remove_cfo = experiment_setup['remove_cfo']

    # The CFO simulation reuses the channel simulation's method and seeds.
    phy_method_cfo = phy_method
    df_phy_train = config['df_phy_train']
    df_phy_test = config['df_phy_test']
    seed_phy_train_cfo = seed_phy_train
    seed_phy_test_cfo = seed_phy_test

    # -------------------------------------------------
    # Equalization params
    # -------------------------------------------------
    equalize_train = experiment_setup['equalize_train']
    equalize_test = testing_setup['equalize_test']
    verbose_train = False
    verbose_test = False

    # -------------------------------------------------
    # Augmentation channel parameters
    # -------------------------------------------------
    augment_channel = experiment_setup['augment_channel']

    channel_type_aug_train = config['channel_type_aug_train']
    channel_type_aug_test = config['channel_type_aug_test']
    num_aug_train = config['num_aug_train']
    num_aug_test = config['num_aug_test']
    aug_type = config['aug_type']
    num_ch_train = config['num_ch_train']
    num_ch_test = config['num_ch_test']
    channel_method = config['channel_method']
    noise_method = config['noise_method']
    delay_seed_aug_test = False
    keep_orig_train = False
    keep_orig_test = False
    snr_train = config['snr_train']
    snr_test = config['snr_test']

    # -------------------------------------------------
    # Augmentation CFO parameters
    # -------------------------------------------------
    augment_cfo = experiment_setup['augment_cfo']

    df_aug_train = df_phy_train
    rand_aug_train = config['rand_aug_train']
    num_aug_train_cfo = config['num_aug_train_cfo']
    keep_orig_train_cfo = config['keep_orig_train_cfo']
    aug_type_cfo = config['aug_type_cfo']

    keep_orig_test_cfo = config["keep_orig_test_cfo"]
    num_aug_test_cfo = config["num_aug_test_cfo"]
    df_aug_test = config["df_aug_test"]

    # -------------------------------------------------
    # Residuals
    # -------------------------------------------------
    obtain_residuals = experiment_setup['obtain_residuals']

    # -------------------------------------------------
    # Loading Data
    # -------------------------------------------------
    data_format = '{:.0f}-pp-{:.0f}-fs-{:.0f}'.format(sample_duration,
                                                      preprocess_type,
                                                      sample_rate)
    outfile = exp_dir + '/sym-' + data_format + '.npz'

    # np.load on an .npz keeps the archive open; the context manager closes
    # it once every array has been read into memory.
    with np.load(outfile) as np_dict:
        dict_wifi = {
            'x_train': np_dict['arr_0'],
            'y_train': np_dict['arr_1'],
            'x_test': np_dict['arr_2'],
            'y_test': np_dict['arr_3'],
            'fc_train': np_dict['arr_4'],
            'fc_test': np_dict['arr_5'],
        }
    dict_wifi['num_classes'] = dict_wifi['y_test'].shape[1]

    y_test_orig = dict_wifi['y_test']

    data_format += '_{}'.format(architecture)

    num_test = dict_wifi['x_test'].shape[0]
    num_classes = dict_wifi['y_train'].shape[1]

    sampling_rate = sample_rate * 1e+6

    # The "-eq" tag depends on both flags, but at test time only the test
    # set is equalized here.
    if equalize_train_before or equalize_test_before:
        print('\nEqualization Before')
        print('\tTrain: {}, Test: {}'.format(equalize_train_before,
                                             equalize_test_before))

        data_format = data_format + '-eq'

    if equalize_test_before is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_test,
                                                  which_set='x_test')

    # --------------------------------------------------------------------------------------------
    # Physical channel simulation (different days)
    # --------------------------------------------------------------------------------------------
    if add_channel:
        dict_wifi, data_format = physical_layer_channel(
            dict_wifi=dict_wifi,
            phy_method=phy_method,
            channel_type_phy_train=channel_type_phy_train,
            channel_type_phy_test=channel_type_phy_test,
            channel_method=channel_method,
            noise_method=noise_method,
            seed_phy_train=seed_phy_train,
            seed_phy_test=seed_phy_test,
            sampling_rate=sampling_rate,
            data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Physical offset simulation (different days)
    # --------------------------------------------------------------------------------------------
    if add_cfo:
        dict_wifi, data_format = physical_layer_cfo(
            dict_wifi=dict_wifi,
            df_phy_train=df_phy_train,
            df_phy_test=df_phy_test,
            seed_phy_train_cfo=seed_phy_train_cfo,
            seed_phy_test_cfo=seed_phy_test_cfo,
            sampling_rate=sampling_rate,
            phy_method_cfo=phy_method_cfo,
            data_format=data_format)

    # Only the name tag is added here.  Presumably it mirrors what
    # cfo_compansator appended during training, so the checkpoint name
    # matches.  TODO confirm.
    if remove_cfo:
        data_format = data_format + '[_comp]-'

    # --------------------------------------------------------------------------------------------
    # Physical offset compensation
    # --------------------------------------------------------------------------------------------
    if testing_setup['remove_test_cfo']:
        dict_wifi, _ = cfo_compansator(dict_wifi=dict_wifi,
                                       sampling_rate=sampling_rate,
                                       data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Equalization
    # --------------------------------------------------------------------------------------------
    if equalize_train or equalize_test:
        print('\nEqualization')
        print('\tTrain: {}, Test: {}'.format(equalize_train, equalize_test))

        data_format = data_format + '-eq'

    if equalize_train is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_train,
                                                  which_set='x_train')

    if equalize_test is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_test,
                                                  which_set='x_test')

    # Training-time augmentations only contribute to the checkpoint name.
    if augment_channel:
        data_format = data_format + '[aug-{}-art-{}-ty-{}-nch-{}-snr-{:.0f}]-'.format(
            num_aug_train, channel_type_aug_train, aug_type, num_ch_train,
            snr_train)

    if augment_cfo:
        data_format = data_format + '[augcfo-{}-df-{}-rand-{}-ty-{}-{}-t-]-'.format(
            num_aug_train_cfo, df_aug_train * 1e6, rand_aug_train,
            aug_type_cfo, keep_orig_train_cfo)

    if obtain_residuals is True:
        print('Residuals are being obtained.')

        dict_wifi, data_format = get_residual(dict_wifi=dict_wifi,
                                              sampling_rate=sampling_rate,
                                              data_format=data_format,
                                              verbosity=verbose_test,
                                              which_set='x_test')

    print(data_format)

    # Checkpoint path
    exp_dir += "/CFO_channel_experiments"
    checkpoint = str(exp_dir + '/ckpt-' + data_format)

    if augment_channel is False:
        num_aug_test = 0

    print(checkpoint)

    # Only the un-augmented test inputs are needed later.  Copy just those
    # instead of the whole dict (which includes the training set).
    x_test_no_aug = copy.deepcopy(dict_wifi['x_test'])

    if testing_setup['augment_test_channel']:

        seed_aug = np.max(seed_phy_train) + seed_phy_test + num_classes + 1

        dict_wifi, _ = augment_with_channel_test(
            dict_wifi=dict_wifi,
            aug_type=aug_type,
            channel_method=channel_method,
            num_aug_train=num_aug_train,
            num_aug_test=num_aug_test,
            keep_orig_train=keep_orig_train,
            keep_orig_test=keep_orig_test,
            num_ch_train=num_ch_train,
            num_ch_test=num_ch_test,
            channel_type_aug_train=channel_type_aug_train,
            channel_type_aug_test=channel_type_aug_test,
            delay_seed_aug_test=delay_seed_aug_test,
            snr_test=snr_test,
            noise_method=noise_method,
            seed_aug=seed_aug,
            sampling_rate=sampling_rate,
            data_format=data_format)

    if testing_setup['augment_test_cfo']:

        dict_wifi = augment_with_cfo_test(
            dict_wifi=dict_wifi,
            aug_type_cfo=aug_type_cfo,
            df_aug_test=df_aug_test,
            num_aug_test_cfo=num_aug_test_cfo,
            keep_orig_test_cfo=keep_orig_test_cfo,
            rand_aug_test=False,
            sampling_rate=sampling_rate)

    print("========================================")
    print("== BUILDING MODEL... ==")

    # ``checkpoint`` is always a string at this point, so the model is loaded
    # directly from ``<checkpoint>.h5``.
    densenet = load_model(checkpoint + '.h5',
                          custom_objects={
                              'ComplexConv1D': ComplexConv1D,
                              'GetAbs': utils.GetAbs,
                              'Modrelu': Modrelu
                          })

    batch_size = 100

    num_test_aug = dict_wifi['x_test'].shape[0]

    output_dict = odict(acc=odict(), comp=odict(), loss=odict())

    if num_test_aug != num_test:
        # The augmented test set is assumed to consist of consecutive blocks
        # of ``num_test`` samples, one block per augmentation.
        num_test_per_aug = num_test_aug // num_test
        num_used = num_test_per_aug * num_test

        # Logits are rebuilt from the penultimate layer's output and the
        # last layer's weight/bias.
        model2 = Model(densenet.input, densenet.layers[-2].output)
        embeddings_test = model2.predict(x=dict_wifi['x_test'],
                                         batch_size=batch_size,
                                         verbose=0)

        softmax_test = densenet.predict(x=dict_wifi['x_test'],
                                        batch_size=batch_size,
                                        verbose=0)

        weight, bias = densenet.layers[-1].get_weights()
        logits_test = embeddings_test.dot(weight) + bias

        # Per-augmentation stacks of shape (num_test_per_aug, num_test,
        # num_classes).  Results are kept in float64, as the former
        # accumulation buffers were.
        logits_test_augs = logits_test[:num_used].reshape(
            num_test_per_aug, num_test, num_classes)
        softmax_test_augs = softmax_test[:num_used].reshape(
            num_test_per_aug, num_test, num_classes).astype(np.float64)
        logits_test_new = logits_test_augs.sum(axis=0, dtype=np.float64)
        softmax_test_new = softmax_test_augs.sum(axis=0)

        label_act_aug = dict_wifi['y_test'][:num_test].argmax(axis=1)

        # All augmentations: sum of logits (LLRs).
        test_acc_llr = _percent_correct(logits_test_new.argmax(axis=1),
                                        label_act_aug, num_test)

        # All augmentations: sum of softmax outputs.
        test_acc_soft = _percent_correct(softmax_test_new.argmax(axis=1),
                                         label_act_aug, num_test)

        # 1 test augmentation: the first block.  Its predictions are already
        # in ``softmax_test``, so no second forward pass is needed.
        test_acc = _percent_correct(softmax_test[:num_test].argmax(axis=1),
                                    label_act_aug, num_test)

        # No test augmentations
        probs = densenet.predict(x=x_test_no_aug,
                                 batch_size=batch_size,
                                 verbose=0)
        label_act = y_test_orig.argmax(axis=1)
        test_acc_no_aug = _percent_correct(probs.argmax(axis=1), label_act,
                                           num_test)

        print('Test accuracy (0 aug): {:.2f}%'.format(test_acc_no_aug))
        print('Test accuracy (1 aug): {:.2f}%'.format(test_acc))
        print('Test accuracy ({} aug) llr: {:.2f}%'.format(
            num_test_per_aug, test_acc_llr))
        print('Test accuracy ({} aug) softmax avg: {:.2f}%'.format(
            num_test_per_aug, test_acc_soft))

        output_dict['acc']['test_zero_aug'] = test_acc_no_aug
        output_dict['acc']['test_one_aug'] = test_acc
        output_dict['acc']['test_many_aug'] = test_acc_llr
        output_dict['acc']['test_many_aug_soft_avg'] = test_acc_soft

    else:
        probs = densenet.predict(x=dict_wifi['x_test'],
                                 batch_size=batch_size,
                                 verbose=0)
        label_act = y_test_orig.argmax(axis=1)
        test_acc_no_aug = _percent_correct(probs.argmax(axis=1), label_act,
                                           num_test_aug)

        print('Test accuracy (no aug): {:.2f}%'.format(test_acc_no_aug))
        output_dict['acc']['test'] = test_acc_no_aug

        softmax_test_new = probs
        softmax_test_augs = probs.reshape(1, -1, num_classes)

    return output_dict, num_test_aug // num_test, softmax_test_new, softmax_test_augs, label_act
# Example No. 2
0
def multiple_day_fingerprint(architecture,
                             config,
                             num_days,
                             seed_days,
                             seed_test_day,
                             experiment_setup,
                             n_val=True):

    # print(architecture)

    # -------------------------------------------------
    # Analysis
    # -------------------------------------------------

    plot_signal = False
    check_signal_power_effect = False

    # -------------------------------------------------
    # Data configuration
    # -------------------------------------------------

    exp_dir = config['exp_dir']
    sample_duration = config['sample_duration']
    preprocess_type = config['preprocess_type']
    sample_rate = config['sample_rate']

    # -------------------------------------------------
    # Training configuration
    # -------------------------------------------------
    epochs = config['epochs']

    # -------------------------------------------------
    # Equalization before any preprocessing
    # -------------------------------------------------
    equalize_train_before = experiment_setup['equalize_train_before']
    equalize_test_before = experiment_setup['equalize_test_before']

    # -------------------------------------------------
    # Physical Channel Parameters
    # -------------------------------------------------
    add_channel = experiment_setup['add_channel']

    phy_method = num_days
    seed_phy_train = seed_days
    # seed_phy_test = config['seed_phy_test']
    seed_phy_test = seed_test_day
    channel_type_phy_train = config['channel_type_phy_train']
    channel_type_phy_test = config['channel_type_phy_test']
    phy_noise = config['phy_noise']
    snr_train_phy = config['snr_train_phy']
    snr_test_phy = config['snr_test_phy']

    # -------------------------------------------------
    # Physical CFO parameters
    # -------------------------------------------------

    add_cfo = experiment_setup['add_cfo']
    remove_cfo = experiment_setup['remove_cfo']

    phy_method_cfo = phy_method  # config["phy_method_cfo"]
    df_phy_train = config['df_phy_train']
    df_phy_test = config['df_phy_test']
    seed_phy_train_cfo = seed_phy_train  # config['seed_phy_train_cfo']
    seed_phy_test_cfo = seed_phy_test  # config['seed_phy_test_cfo']

    # -------------------------------------------------
    # Equalization params
    # -------------------------------------------------
    equalize_train = experiment_setup['equalize_train']
    equalize_test = experiment_setup['equalize_test']
    verbose_train = False
    verbose_test = False

    # -------------------------------------------------
    # Augmentation channel parameters
    # -------------------------------------------------
    augment_channel = experiment_setup['augment_channel']

    channel_type_aug_train = config['channel_type_aug_train']
    channel_type_aug_test = config['channel_type_aug_test']
    num_aug_train = config['num_aug_train']
    num_aug_test = config['num_aug_test']
    aug_type = config['aug_type']
    num_ch_train = config['num_ch_train']
    num_ch_test = config['num_ch_test']
    channel_method = config['channel_method']
    noise_method = config['noise_method']
    delay_seed_aug_train = False
    delay_seed_aug_test = False
    keep_orig_train = False
    keep_orig_test = False
    snr_train = config['snr_train']
    snr_test = config['snr_test']
    beta = config['beta']

    # -------------------------------------------------
    # Augmentation CFO parameters
    # -------------------------------------------------
    augment_cfo = experiment_setup['augment_cfo']

    df_aug_train = df_phy_train
    rand_aug_train = config['rand_aug_train']
    num_aug_train_cfo = config['num_aug_train_cfo']
    keep_orig_train_cfo = config['keep_orig_train_cfo']
    aug_type_cfo = config['aug_type_cfo']

    # -------------------------------------------------
    # Residuals
    # -------------------------------------------------

    obtain_residuals = experiment_setup['obtain_residuals']

    # -------------------------------------------------
    # Loading Data
    # -------------------------------------------------

    data_format = '{:.0f}-pp-{:.0f}-fs-{:.0f}'.format(sample_duration,
                                                      preprocess_type,
                                                      sample_rate)
    outfile = exp_dir + '/sym-' + data_format + '.npz'

    np_dict = np.load(outfile)
    dict_wifi = {}
    dict_wifi['x_train'] = np_dict['arr_0']
    dict_wifi['y_train'] = np_dict['arr_1']
    dict_wifi['x_test'] = np_dict['arr_2']
    dict_wifi['y_test'] = np_dict['arr_3']
    dict_wifi['fc_train'] = np_dict['arr_4']
    dict_wifi['fc_test'] = np_dict['arr_5']
    dict_wifi['num_classes'] = dict_wifi['y_test'].shape[1]

    # import pdb
    # pdb.set_trace()

    data_format += '_{}'.format(architecture)

    num_train = dict_wifi['x_train'].shape[0]
    num_test = dict_wifi['x_test'].shape[0]
    num_classes = dict_wifi['y_train'].shape[1]

    sampling_rate = sample_rate * 1e+6
    fs = sample_rate * 1e+6

    x_train_orig = dict_wifi['x_train'].copy()
    y_train_orig = dict_wifi['y_train'].copy()

    x_test_orig = dict_wifi['x_test'].copy()
    y_test_orig = dict_wifi['y_test'].copy()

    if check_signal_power_effect == True:
        dict_wifi, data_format = signal_power_effect(dict_wifi=dict_wifi,
                                                     data_format=data_format)

    if plot_signal == True:
        plot_signals(dict_wifi=dict_wifi)

    if equalize_train_before or equalize_test_before:
        print('\nEqualization Before')
        print('\tTrain: {}, Test: {}'.format(equalize_train_before,
                                             equalize_test_before))

        data_format = data_format + '-eq'

    if equalize_train_before is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_train,
                                                  which_set='x_train')

    if equalize_test_before is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_test,
                                                  which_set='x_test')

    # --------------------------------------------------------------------------------------------
    # Physical channel simulation (different days)
    # --------------------------------------------------------------------------------------------
    if add_channel:
        dict_wifi, data_format = physical_layer_channel(
            dict_wifi=dict_wifi,
            phy_method=phy_method,
            channel_type_phy_train=channel_type_phy_train,
            channel_type_phy_test=channel_type_phy_test,
            channel_method=channel_method,
            noise_method=noise_method,
            seed_phy_train=seed_phy_train,
            seed_phy_test=seed_phy_test,
            sampling_rate=sampling_rate,
            data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Physical offset simulation (different days)
    # --------------------------------------------------------------------------------------------
    if add_cfo:

        dict_wifi, data_format = physical_layer_cfo(
            dict_wifi=dict_wifi,
            df_phy_train=df_phy_train,
            df_phy_test=df_phy_test,
            seed_phy_train_cfo=seed_phy_train_cfo,
            seed_phy_test_cfo=seed_phy_test_cfo,
            sampling_rate=sampling_rate,
            phy_method_cfo=phy_method_cfo,
            data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Physical offset compensation
    # --------------------------------------------------------------------------------------------
    if remove_cfo:
        dict_wifi, data_format = cfo_compansator(dict_wifi=dict_wifi,
                                                 sampling_rate=sampling_rate,
                                                 data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Equalization
    # --------------------------------------------------------------------------------------------
    if equalize_train or equalize_test:
        print('\nEqualization')
        print('\tTrain: {}, Test: {}'.format(equalize_train, equalize_test))

        data_format = data_format + '-eq'

    if equalize_train is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_train,
                                                  which_set='x_train')

    if equalize_test is True:
        dict_wifi, data_format = equalize_channel(dict_wifi=dict_wifi,
                                                  sampling_rate=sampling_rate,
                                                  data_format=data_format,
                                                  verbosity=verbose_test,
                                                  which_set='x_test')

    # --------------------------------------------------------------------------------------------
    # Channel augmentation
    # --------------------------------------------------------------------------------------------
    if augment_channel is True:

        seed_aug = np.max(seed_phy_train) + seed_phy_test + num_classes + 1

        dict_wifi, data_format = augment_with_channel(
            dict_wifi=dict_wifi,
            aug_type=aug_type,
            channel_method=channel_method,
            num_aug_train=num_aug_train,
            num_aug_test=num_aug_test,
            keep_orig_train=keep_orig_train,
            keep_orig_test=keep_orig_test,
            num_ch_train=num_ch_train,
            num_ch_test=num_ch_test,
            channel_type_aug_train=channel_type_aug_train,
            channel_type_aug_test=channel_type_aug_test,
            delay_seed_aug_train=delay_seed_aug_train,
            snr_train=snr_train,
            noise_method=noise_method,
            seed_aug=seed_aug,
            sampling_rate=sampling_rate,
            data_format=data_format)

    # --------------------------------------------------------------------------------------------
    # Carrier Frequency Offset augmentation
    # --------------------------------------------------------------------------------------------
    if augment_cfo is True:

        seed_aug_cfo = np.max(
            seed_phy_train_cfo) + seed_phy_test_cfo + num_classes + 1

        dict_wifi, data_format = augment_with_cfo(
            dict_wifi=dict_wifi,
            aug_type_cfo=aug_type_cfo,
            df_aug_train=df_aug_train,
            num_aug_train_cfo=num_aug_train_cfo,
            keep_orig_train_cfo=keep_orig_train_cfo,
            rand_aug_train=rand_aug_train,
            sampling_rate=sampling_rate,
            seed_aug_cfo=seed_aug_cfo,
            data_format=data_format)

    if obtain_residuals is True:
        print('Residuals are being obtained.')

        dict_wifi, data_format = get_residual(dict_wifi=dict_wifi,
                                              sampling_rate=sampling_rate,
                                              data_format=data_format,
                                              verbosity=verbose_train,
                                              which_set='x_train')

        dict_wifi, _ = get_residual(dict_wifi=dict_wifi,
                                    sampling_rate=sampling_rate,
                                    data_format=data_format,
                                    verbosity=verbose_test,
                                    which_set='x_test')

    print(data_format)
    # --------------------------------------------------------------------------------------------
    # Train
    # --------------------------------------------------------------------------------------------

    # Checkpoint path
    exp_dir += "/CFO_channel_experiments"
    checkpoint = str(exp_dir + '/ckpt-' + data_format)

    if augment_channel is False:
        num_aug_test = 0

    print(checkpoint)
    print('-----------------------\nExperiment:\n' + exp_dir +
          '\n-----------------------')
    if sample_rate == 20:
        train_output, model_name, summary = train_20(dict_wifi,
                                                     checkpoint_in=None,
                                                     num_aug_test=num_aug_test,
                                                     checkpoint_out=checkpoint,
                                                     architecture=architecture,
                                                     epochs=epochs)
    elif sample_rate == 200:
        train_output, model_name, summary = train_200(
            dict_wifi,
            checkpoint_in=None,
            num_aug_test=num_aug_test,
            checkpoint_out=checkpoint,
            architecture=architecture,
            epochs=epochs,
            n_val=n_val)

    else:
        raise NotImplementedError
    print('-----------------------\nExperiment:\n' + exp_dir +
          '\n-----------------------')

    # --------------------------------------------------------------------------------------------
    # Write in log file
    # --------------------------------------------------------------------------------------------

    # Write logs
    with open(exp_dir + '/logs-' + data_format + '.txt', 'a+') as f:
        f.write('\n\n-----------------------\n' + str(model_name) + '\n\n')

        # f.write('Different day scenario\n')
        # if equalize_train is True:
        #   f.write('With preamble equalization\n\n')
        # f.write('Channel augmentations: {}, keep_orig: {} \n'.format(num_aug_train, keep_orig_train))
        # f.write('Channel type: Phy_train: {}, Phy_test: {}, Aug_Train: {}, Aug_Test: {} \n'.format(channel_type_phy_train, channel_type_phy_test, channel_type_aug_train, channel_type_aug_test))
        # f.write('Seed: Phy_train: {}, Phy_test: {}'.format(seed_phy_train, seed_phy_test))
        # f.write('No of channels: Train: {}, Test: {} \n'.format(num_ch_train, num_ch_test))
        # f.write('SNR: Train: {} dB, Test {} dB\n'.format(snr_train, snr_test))

        f.write('\nPreprocessing\n\tType: {}\n\tFs: {} MHz\n\tLength: {} us'.
                format(preprocess_type, sample_rate, sample_duration))

        if equalize_train_before == True:
            f.write('\nEqualized signals before any preprocessing')

        if add_channel is True:
            f.write('\nPhysical Channel is added!')
            f.write('\nPhysical channel simulation (different days)')
            f.write('\tMethod: {}'.format(phy_method))
            f.write('\tChannel type: Train: {}, Test: {}'.format(
                channel_type_phy_train, channel_type_phy_test))
            f.write('\tSeed: Train: {}, Test: {}'.format(
                seed_phy_train, seed_phy_test))
        else:
            f.write('\nPhysical Channel is not added!')

        if add_cfo is True:
            f.write('\nPhysical CFO is added!')
            f.write('\nPhysical CFO simulation (different days)')
            f.write('\tMethod: {}'.format(phy_method_cfo))
            f.write('\tdf_train: {}, df_test: {}'.format(
                df_phy_train, df_phy_test))
            f.write('\tSeed: Train: {}, Test: {}'.format(
                seed_phy_train_cfo, seed_phy_test_cfo))
        else:
            f.write('\nPhysical CFO is not added!')

        if remove_cfo is True:
            f.write('\nCFO is compensated!')
        else:
            f.write('\nCFO is not compensated!')

        f.write('\nEqualization')
        f.write('\tTrain: {}, Test: {}'.format(equalize_train, equalize_test))

        if augment_channel is True:
            f.write('\nChannel augmentation')
            f.write('\tAugmentation type: {}'.format(aug_type))
            f.write('\tChannel Method: {}'.format(channel_method))
            f.write('\tNoise Method: {}'.format(noise_method))
            f.write(
                '\tNo of augmentations: Train: {}, Test: {}\n\tKeep originals: Train: {}, Test: {}'
                .format(num_aug_train, num_aug_test, keep_orig_train,
                        keep_orig_test))
            f.write('\tNo. of channels per aug: Train: {}, Test: {}'.format(
                num_ch_train, num_ch_test))
            f.write('\tChannel type: Train: {}, Test: {}'.format(
                channel_type_aug_train, channel_type_aug_test))
            f.write('\tSNR: Train: {}, Test: {}'.format(snr_train, snr_test))
            f.write('\tBeta {}'.format(beta))
        else:
            f.write('\nChannel is not augmented')

        if augment_cfo is True:
            f.write('\nCFO augmentation')
            f.write('\tAugmentation type: {}'.format(aug_type_cfo))
            f.write(
                '\tNo of augmentations: Train: {}, \n\tKeep originals: Train: {}'
                .format(num_aug_train_cfo, keep_orig_train))
        else:
            f.write('\nCFO is not augmented')

        for keys, dicts in train_output.items():
            f.write(str(keys) + ':\n')
            for key, value in dicts.items():
                f.write('\t' + str(key) + ': {:.2f}%'.format(value) + '\n')
        f.write('\n' + str(summary))

    print('\nPreprocessing\n\tType: {}\n\tFs: {} MHz\n\tLength: {} us'.format(
        preprocess_type, sample_rate, sample_duration))

    if equalize_train_before == True:
        print('\nEqualized signals before any preprocessing')

    if add_channel is True:
        print('\nPhysical Channel is added!')
        print('\nPhysical channel simulation (different days)')
        print('\tMethod: {}'.format(phy_method))
        print('\tChannel type: Train: {}, Test: {}'.format(
            channel_type_phy_train, channel_type_phy_test))
        print('\tSeed: Train: {}, Test: {}'.format(seed_phy_train,
                                                   seed_phy_test))
    else:
        print('\nPhysical Channel is not added!')

    if add_cfo is True:
        print('\nPhysical CFO is added!')
        print('\nPhysical CFO simulation (different days)')
        print('\tMethod: {}'.format(phy_method_cfo))
        print('\tdf_train: {}, df_test: {}'.format(df_phy_train, df_phy_test))
        print('\tSeed: Train: {}, Test: {}'.format(seed_phy_train_cfo,
                                                   seed_phy_test_cfo))
    else:
        print('\nPhysical CFO is not added!')

    if remove_cfo is True:
        print('\nCFO is compensated!')
    else:
        print('\nCFO is not compensated!')

    print('\nEqualization')
    print('\tTrain: {}, Test: {}'.format(equalize_train, equalize_test))

    if augment_channel is True:
        print('\nChannel augmentation')
        print('\tAugmentation type: {}'.format(aug_type))
        print('\tChannel Method: {}'.format(channel_method))
        print('\tNoise Method: {}'.format(noise_method))
        print(
            '\tNo of augmentations: Train: {}, Test: {}\n\tKeep originals: Train: {}, Test: {}'
            .format(num_aug_train, num_aug_test, keep_orig_train,
                    keep_orig_test))
        print('\tNo. of channels per aug: Train: {}, Test: {}'.format(
            num_ch_train, num_ch_test))
        print('\tChannel type: Train: {}, Test: {}'.format(
            channel_type_aug_train, channel_type_aug_test))
        print('\tSNR: Train: {}, Test: {}'.format(snr_train, snr_test))
        print('\tBeta {}'.format(beta))
    else:
        print('\nChannel is not augmented')

    if augment_cfo is True:
        print('\nCFO augmentation')
        print('\tAugmentation type: {}'.format(aug_type_cfo))
        print(
            '\tNo of augmentations: Train: {}, \n\tKeep originals: Train: {}'.
            format(num_aug_train_cfo, keep_orig_train))
    else:
        print('\nCFO is not augmented')

    return train_output