def test_per_channel_per_snr_conven_approach(args, est_h, test_size, h, test_snr, actual_channel_num):
    """Block error rate of the conventional receiver for one channel and SNR.

    Uncoded 4-QAM symbols (built from random message labels) are passed
    through ``channel`` with noise drawn at ``test_snr`` and detected by
    ``demodulation`` using the channel estimate ``est_h``.

    Args:
        args: experiment configuration (reads ``bit_num``, ``channel_num``,
            ``tap_num``, ``device``, ``if_AWGN``, ``if_fix_random_seed``,
            ``random_seed``).
        est_h: channel estimate handed to ``demodulation``.
        test_size: number of payload messages to transmit.
        h: true channel used to produce the received signal.
        test_snr: Eb/N0 in dB.
        actual_channel_num: dimension of the Gaussian noise vector.

    Returns:
        ``1 - accuracy`` (block error rate) as a Python float.
    """
    batch_size = test_size
    # dB -> linear, then noise variance 1 / (2 * R * Eb/N0) for rate R.
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num),
        noise_var_test * torch.eye(actual_channel_num))
    print('payload num for conv', batch_size)
    # Only the labels are needed: the modulator works on labels, not on m.
    _, label_test = message_gen(args.bit_num, batch_size)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    num_samples = label_test.shape[0]
    assert num_samples == batch_size

    if args.if_fix_random_seed:
        # always have same noise for test since this can see the actual effect of # of adaptations
        reset_randomness(args.random_seed + 7777)

    actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label_test)

    received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
    # rx (maximum likelihood)
    out_test = demodulation(args.bit_num, args.channel_num, args.tap_num, est_h, received_signal, args.device)

    # Vectorised decision: argmax over each flattened row is exactly what
    # torch.argmax(out_test[i]) computed per sample (first maximum wins in
    # both cases), without a host/device synchronisation per sample.
    decisions = out_test[:num_samples].reshape(num_samples, -1).argmax(dim=1)
    decisions = decisions.to(label_test.device)
    success_test = (decisions == label_test.reshape(num_samples)).sum().item()
    accuracy = success_test / num_samples

    print('accuracy', accuracy)

    return 1 - accuracy
Esempio n. 2
0
def test_with_adapt_compact_during_online_meta_training(
        args, common_dir, curr_meta_training_epoch, test_snr_range,
        num_pilots_test, tx_net_for_testtraining, rx_net_for_testtraining,
        Noise, Noise_relax, actual_channel_num, PATH_before_adapt_tx,
        PATH_before_adapt_rx):
    """Mean test BLER after per-channel adaptation, used during meta-training.

    Test channels are loaded from ``args.path_for_test_channels`` and
    truncated to ``args.num_channels_test``. For each channel the networks
    are adapted with ``test_training`` (starting from
    ``PATH_before_adapt_tx``/``PATH_before_adapt_rx``, saving the adapted
    weights under ``common_dir``), then evaluated at every SNR in
    ``test_snr_range`` with ``test_per_channel_per_snr``.

    Returns:
        0-dim tensor: mean block error rate over all (channel, SNR) pairs.
        NaN if no test channel is available.

    Raises:
        NotImplementedError: if ``args.path_for_test_channels`` is None.
    """
    # generate or load test channels
    if args.path_for_test_channels is None:
        raise NotImplementedError(
            'test channels must be generated beforehand and loaded via '
            'args.path_for_test_channels')
    print('load previously generated channels')
    h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # channel files from a trusted location.
    with open(h_list_test_path, 'rb') as f_test_channels:
        h_list_test = pickle.load(f_test_channels)

    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]

    # reset again # to make fair comp. per adapt. and per meta-training epochs
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 2)
    print('curr pilots used for test during online (meta-)learning: ',
          num_pilots_test)

    model_dir = (common_dir + 'saved_model/' + 'with_meta_training_epoch/' +
                 str(curr_meta_training_epoch) + '/')
    pilots_subdir = 'after_adapt/' + str(num_pilots_test) + '_num_pilots_test/'
    tx_dir = model_dir + 'tx/' + pilots_subdir
    rx_dir = model_dir + 'rx/' + pilots_subdir
    # exist_ok: re-running into the same common_dir must not crash.
    os.makedirs(tx_dir, exist_ok=True)
    os.makedirs(rx_dir, exist_ok=True)

    # Size by the channels actually loaded; sizing by num_channels_test left
    # all-zero rows when fewer channels were available, biasing the mean down.
    block_error_rate = torch.zeros(len(h_list_test), len(test_snr_range))
    for ind_h, h in enumerate(h_list_test):
        PATH_after_adapt_tx = tx_dir + str(ind_h) + 'th_adapted_net'
        PATH_after_adapt_rx = rx_dir + str(ind_h) + 'th_adapted_net'
        test_training(args, h, tx_net_for_testtraining,
                      rx_net_for_testtraining, Noise, Noise_relax,
                      PATH_before_adapt_tx, PATH_before_adapt_rx,
                      PATH_after_adapt_tx, PATH_after_adapt_rx,
                      num_pilots_test)
        # test
        for ind_snr, test_snr in enumerate(test_snr_range):
            block_error_rate[ind_h, ind_snr] = test_per_channel_per_snr(
                args, args.test_size_during_meta_update, h,
                tx_net_for_testtraining, rx_net_for_testtraining, test_snr,
                actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx)
    return torch.mean(block_error_rate)
def test_per_channel_per_snr(args, test_size, h, tx_net_for_testtraining, rx_net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx):
    """Block error rate of the adapted tx/rx networks for one channel and SNR.

    Loads the adapted weights from ``PATH_after_adapt_tx`` /
    ``PATH_after_adapt_rx``, sends ``test_size`` random messages through
    ``channel`` with noise drawn at ``test_snr`` (Eb/N0 in dB), and counts
    how many the receiver classifies correctly.

    If ``args.fix_bpsk_tx_train_nn_rx_during_runtime`` is set, the
    transmitter network is bypassed and ``four_qam_uncoded_modulation``
    is used instead.

    Returns:
        ``1 - accuracy`` (block error rate) as a Python float.
    """
    # map_location lets weights saved on another device be loaded here;
    # load_state_dict copies them into the networks' own parameters anyway.
    tx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_tx, map_location=args.device))
    rx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_rx, map_location=args.device))

    batch_size = test_size
    # dB -> linear, then noise variance 1 / (2 * R * Eb/N0) for rate R.
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num),
        noise_var_test * torch.eye(actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    num_samples = label_test.shape[0]
    assert num_samples == batch_size

    # Clear leftover gradients from adaptation (zeroed in place, not set to None).
    for f in tx_net_for_testtraining.parameters():
        if f.grad is not None:
            f.grad.zero_()

    rx_net_for_testtraining.zero_grad()

    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 7777) # always have same noise for test since this can see the actual effect of # of adaptations

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        if args.fix_bpsk_tx_train_nn_rx_during_runtime:
            actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num,
                                                                    label_test)  # label instead m
        else:
            # first output (symbol mean) is not needed at test time; 0/None: no relaxation
            _, actual_transmitted_symbol = tx_net_for_testtraining(m_test, args.device, 0, None)
        # channel
        received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
        # rx
        out_test = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)

    # Vectorised decision: argmax over each flattened row equals the former
    # per-sample torch.argmax(out_test[i]) (first maximum wins in both).
    decisions = out_test[:num_samples].reshape(num_samples, -1).argmax(dim=1)
    decisions = decisions.to(label_test.device)
    success_test = (decisions == label_test.reshape(num_samples)).sum().item()
    accuracy = success_test / num_samples

    return 1 - accuracy
Esempio n. 4
0
def test_conven_commun_during_online_meta_training(args, test_snr_range,
                                                   num_pilots_test, Noise,
                                                   actual_channel_num):
    """Evaluate the conventional (pilot-estimate + demodulation) baseline.

    For each test channel loaded from ``args.path_for_test_channels``, a
    channel estimate is obtained with ``test_training_conven_commun`` from
    ``num_pilots_test`` pilots. It is then used by
    ``test_per_channel_per_snr_conven_approach`` at every SNR in
    ``test_snr_range``.

    Returns:
        Tuple ``(mean_bler, ch_est_error_avg)``:

        * ``mean_bler``: 0-dim tensor averaging the block error rate over
          all (channel, SNR) pairs.
        * ``ch_est_error_avg``: mean of the ``error_h`` values returned by
          ``test_training_conven_commun``.

        Both are NaN when no test channel is available.

    Raises:
        NotImplementedError: if ``args.path_for_test_channels`` is None.
    """
    # generate or load test channels
    if args.path_for_test_channels is None:
        print(
            'we need to first make the test channels a priori and load it via path (args.path_for_test_channels)'
        )
        raise NotImplementedError
    print('load previously generated channels')
    h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # channel files from a trusted location.
    with open(h_list_test_path, 'rb') as f_test_channels:
        h_list_test = pickle.load(f_test_channels)

    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]
    print('used test channels', h_list_test)

    # reset again # to make fair comp. per adapt. and per meta-training epochs
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 2)
    print('curr num pilots: ', num_pilots_test)

    num_h = len(h_list_test)
    # Size by the channels actually loaded so missing channels do not add
    # all-zero rows to the mean.
    block_error_rate = torch.zeros(num_h, len(test_snr_range))
    ch_est_error_sum = 0
    for ind_h, h in enumerate(h_list_test):
        est_h, error_h = test_training_conven_commun(args, h, Noise,
                                                     num_pilots_test)
        for ind_snr, test_snr in enumerate(test_snr_range):
            block_error_rate[ind_h, ind_snr] = test_per_channel_per_snr_conven_approach(
                args, est_h, args.conv_payload_num, h, test_snr,
                actual_channel_num)
        ch_est_error_sum += error_h
    # Avoid ZeroDivisionError when there are no channels.
    ch_est_error_avg = ch_est_error_sum / num_h if num_h else float('nan')

    return torch.mean(block_error_rate), ch_est_error_avg
Esempio n. 5
0
 def __init__(self, M, n, n_inv_filter, num_neurons_decoder, if_bias,
              if_relu, if_RTN, if_fix_random_seed, random_seed):
     """Create the receiver's layers.

     Args:
         M: output width of ``dec_fc2``.
         n: input width of ``dec_fc1`` and of the ``rtn_*`` layers.
         n_inv_filter: ``rtn_3`` outputs ``2 * n_inv_filter`` values.
         num_neurons_decoder: hidden width between ``dec_fc1`` and
             ``dec_fc2``.
         if_bias: passed as ``bias`` to every ``nn.Linear``.
         if_relu: ``self.activ`` is ReLU if True, Tanh otherwise.
         if_RTN: if True, create the ``rtn_1``/``rtn_2``/``rtn_3`` layers.
         if_fix_random_seed: call ``reset_randomness`` before creating
             layers.
         random_seed: base seed; ``random_seed + 1`` is used before the RTN
             layers and ``random_seed`` before the decoder layers.
     """
     super(receiver, self).__init__()
     # rtn_3 output width is twice n_inv_filter
     # (NOTE(review): presumably two real values per filter tap — confirm).
     num_inv_filter = 2 * n_inv_filter
     # Statement order below matters: the reseeding calls and the order in
     # which layers are created determine their initial weights.
     if if_RTN:
         if if_fix_random_seed:
             reset_randomness(random_seed + 1)
         self.rtn_1 = nn.Linear(n, n, bias=if_bias)
         self.rtn_2 = nn.Linear(n, n, bias=if_bias)
         self.rtn_3 = nn.Linear(n, num_inv_filter, bias=if_bias)
     else:
         pass
     # Reseeding here makes the decoder's initial weights independent of
     # whether the RTN layers were created above (assuming reset_randomness
     # reseeds the RNG used by nn.Linear initialisation — confirm).
     if if_fix_random_seed:
         reset_randomness(random_seed)
     self.dec_fc1 = nn.Linear(n, num_neurons_decoder, bias=if_bias)
     self.dec_fc2 = nn.Linear(num_neurons_decoder, M, bias=if_bias)
     # hidden activation selectable between ReLU and Tanh
     if if_relu:
         self.activ = nn.ReLU()
     else:
         self.activ = nn.Tanh()
     # separate Tanh module kept regardless of if_relu (its use is not
     # visible in this snippet)
     self.tanh = nn.Tanh()
Esempio n. 6
0
def multi_task_learning(args, tx_net,rx_net, h_list_meta, writer_meta_training, Noise, Noise_relax, Noise_feedback, PATH_before_adapt_rx_intermediate, PATH_before_adapt_tx_intermediate):
    """Outer training loop: joint training, or joint-tx / MAML-rx training.

    Each epoch processes ``args.tasks_per_metaupdate`` channels (tasks) with
    ``joint_training`` or ``maml_for_rx_joint_for_tx``. It then logs the
    task-averaged losses to ``writer_meta_training`` and takes one Adam step
    on ``tx_net`` and on ``rx_net``. Each step uses
    ``f.total_grad / tasks_per_metaupdate`` as the gradient of every
    parameter ``f``.

    Channel source per task:
      * ``args.if_always_generate_new_meta_training_channels``: a new
        channel from ``channel_set_gen_AR`` (only the AR Rayleigh model is
        implemented; otherwise NotImplementedError);
      * otherwise: an element of ``h_list_meta[:args.num_channels_meta]``.

    At epochs that are multiples of 100 (including epoch 0), the current
    state dicts are saved to ``PATH_before_adapt_*_intermediate + str(epochs)``.

    NOTE(review): ``f.total_grad`` is not set in this function; presumably
    the per-task training helpers accumulate it — confirm.
    NOTE(review): another ``def multi_task_learning`` appears later in this
    file; if both are in the same module, the later one rebinds the name.

    Returns:
        None.
    """
    meta_optimiser_tx = torch.optim.Adam(tx_net.parameters(), args.lr_meta_update)
    meta_optimiser_rx = torch.optim.Adam(rx_net.parameters(), args.lr_meta_update)
    h_list_train = h_list_meta[:args.num_channels_meta]

    # random_seed_init exists only when seeding is enabled; every use below
    # is guarded by the same flag.
    if args.if_fix_random_seed:
        random_seed_init = args.random_seed + 99999
    else:
        pass
    # One AR state shared by all tasks: each generated channel is chained
    # from the previously generated one, across tasks and across epochs.
    previous_channel = None

    for epochs in range(args.num_epochs_meta_train):
        if epochs % 100 == 0:
            # periodic checkpoint of the current (pre-adaptation) networks
            curr_path_rx = PATH_before_adapt_rx_intermediate + str(epochs)
            curr_path_tx = PATH_before_adapt_tx_intermediate + str(epochs)
            torch.save(rx_net.state_dict(), curr_path_rx)
            torch.save(tx_net.state_dict(), curr_path_tx)
            print('stochactic meta-learning epoch', epochs)
        first_loss_rx = 0
        second_loss_rx = 0
        first_loss_tx = 0
        second_loss_tx = 0
        iter_in_sampled_device = 0  # for averaging meta-devices
        for ind_meta_dev in range(args.tasks_per_metaupdate):
            if args.if_always_generate_new_meta_training_channels:
                if args.if_fix_random_seed:
                    reset_randomness(random_seed_init) # one distinct seed per task so joint and meta training see the same randomness
                    random_seed_init += 1
                else:
                    pass
                if args.if_Rayleigh_channel_model_AR:
                    # The condition depends only on `epochs`: in such an epoch
                    # every task draws a new variance multiplier and resets the
                    # AR process. Epoch 0 always qualifies, so mul_h_var is
                    # defined before its first use.
                    if epochs % args.keep_AR_period == 0:
                        h_var_dist = torch.distributions.uniform.Uniform(torch.FloatTensor([args.mul_h_var_min]), torch.FloatTensor([args.mul_h_var_max]))
                        mul_h_var = h_var_dist.sample()
                        if_reset_AR = True
                    else:
                        if_reset_AR = False
                    current_channel = channel_set_gen_AR(args.tap_num, previous_channel, args.rho, mul_h_var, if_reset_AR)  # one channel generated per task
                    previous_channel = current_channel
                else:
                    raise NotImplementedError
            else:
                if args.if_fix_random_seed:
                    reset_randomness(random_seed_init) # one distinct seed per task so joint and meta training see the same randomness
                    random_seed_init += 1
                else:
                    pass
                # A fresh permutation is drawn for every task, so the same
                # channel can be picked by several tasks of one meta-batch.
                # Indexing by ind_meta_dev requires
                # tasks_per_metaupdate <= len(h_list_train).
                channel_list_total = torch.randperm(len(h_list_train))  # NOTE(review): regenerated per task -> effectively sampling with replacement
                current_channel_ind = channel_list_total[
                    ind_meta_dev]  # NOTE(review): earlier comment claimed "no rep. inside meta-batch"; confirm intended sampling
                current_channel = h_list_train[current_channel_ind]

            if args.if_joint_training:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = joint_training(args, iter_in_sampled_device,
                                                                                 tx_net, rx_net,
                                                                                 current_channel, Noise, Noise_relax, Noise_feedback)
            elif args.if_joint_training_tx_meta_training_rx:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = maml_for_rx_joint_for_tx(
                    args, iter_in_sampled_device,
                    tx_net, rx_net,
                    current_channel, Noise, Noise_relax, Noise_feedback)
            else:  # maml
                raise NotImplementedError
            first_loss_rx = first_loss_rx + first_loss_curr_rx
            second_loss_rx = second_loss_rx + second_loss_curr_rx
            first_loss_tx = first_loss_tx + first_loss_curr_tx
            second_loss_tx = second_loss_tx + second_loss_curr_tx
        # average the per-task losses over the meta-batch for logging
        first_loss_tx = first_loss_tx / args.tasks_per_metaupdate
        second_loss_tx = second_loss_tx / args.tasks_per_metaupdate
        first_loss_rx = first_loss_rx / args.tasks_per_metaupdate
        second_loss_rx = second_loss_rx / args.tasks_per_metaupdate
        writer_meta_training.add_scalar('first rx loss', first_loss_rx, epochs)
        writer_meta_training.add_scalar('first tx loss', first_loss_tx, epochs)
        writer_meta_training.add_scalar('second rx loss', second_loss_rx, epochs)
        writer_meta_training.add_scalar('second tx loss', second_loss_tx, epochs)
        meta_optimiser_rx.zero_grad()
        meta_optimiser_tx.zero_grad()
        # overwrite .grad with the task-averaged accumulated gradient, then step
        for f in rx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        for f in tx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        meta_optimiser_rx.step()  # Adam
        meta_optimiser_tx.step()  # Adam
Esempio n. 7
0
def test_with_adapt(args, common_dir, common_dir_over_multi_rand_seeds,
                    test_snr_range, test_num_pilots_available,
                    meta_training_epoch_for_test, tx_net_for_testtraining,
                    rx_net_for_testtraining, Noise, Noise_relax,
                    actual_channel_num, save_test_result_dict_total,
                    test_result_all_PATH_for_all_meta_training_epochs,
                    PATH_before_adapt_tx, PATH_before_adapt_rx):
    """Adapt to every test channel and record block error rates to disk.

    Test channels are generated with ``channel_set_gen`` and pickled under
    ``common_dir + 'test_channels/'`` when ``args.path_for_test_channels``
    is None. Otherwise they are loaded from that path. They are truncated
    to ``args.num_channels_test``.

    Then, for every meta-training epoch in ``meta_training_epoch_for_test``
    and every pilot budget in ``test_num_pilots_available``:

    * each channel is adapted to with ``test_training``, starting from
      ``PATH_before_adapt_tx``/``PATH_before_adapt_rx``;
    * it is evaluated with ``test_per_channel_per_snr`` at every SNR in
      ``test_snr_range``;
    * the BLERs are saved as ``.mat`` files and logged to TensorBoard
      (``common_dir + 'TB/test'``).

    When ``args.path_for_meta_trained_net_total_per_epoch`` is truthy,
    ``save_test_result_dict_total`` is updated in place and saved to
    ``test_result_all_PATH_for_all_meta_training_epochs``.

    Returns:
        None.
    """
    # generate or load test channels
    if args.path_for_test_channels is None:
        if args.if_fix_random_seed:
            reset_randomness(args.random_seed + 11)
        print('generate test channels')
        h_list_test = channel_set_gen(args.num_channels_test, args.tap_num,
                                      args.if_toy)
        test_channels_dir = common_dir + 'test_channels/'
        # the directory may not exist yet; open(..., 'wb') would fail
        os.makedirs(test_channels_dir, exist_ok=True)
        with open(test_channels_dir + 'test_channels.pckl', 'wb') as f_test_channels:
            pickle.dump(h_list_test, f_test_channels)
    else:
        print('load previously generated channels')
        h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # channel files from a trusted location.
        with open(h_list_test_path, 'rb') as f_test_channels:
            h_list_test = pickle.load(f_test_channels)

    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]
    print('used test channels', h_list_test)
    # Size results by the channels actually used, so that missing channels
    # do not leave all-zero rows in the saved/averaged BLER.
    num_h = len(h_list_test)

    writer_test = SummaryWriter(common_dir + 'TB/' + 'test')
    try:
        total_total_block_error_rate = torch.zeros(
            num_h, len(test_snr_range),
            len(test_num_pilots_available), len(meta_training_epoch_for_test))
        # enumerate: the epoch index must advance, otherwise every epoch
        # overwrites slice 0 of total_total_block_error_rate.
        for ind_meta_training_epoch, meta_training_epochs in enumerate(
                meta_training_epoch_for_test):
            if common_dir_over_multi_rand_seeds is not None:
                per_seed_dir = (common_dir_over_multi_rand_seeds +
                                'test_result_after_meta_training/' + 'iter/' +
                                str(args.fix_tx_multi_adapt_rx_iter_num) + 'rho/' +
                                str(args.rho) + '/rand_seeds/' +
                                str(args.random_seed) + '/test_result/' +
                                'with_meta_training_epoch/' +
                                str(meta_training_epochs) + '/')
                os.makedirs(per_seed_dir, exist_ok=True)
                test_result_all_PATH_per_rand_seeds = per_seed_dir + 'test_result.mat'

            save_test_result_dict = {}
            # start with given initialization
            print('start adaptation with test set')
            total_block_error_rate = torch.zeros(num_h,
                                                 len(test_snr_range),
                                                 len(test_num_pilots_available))
            epoch_result_dir = (common_dir + 'test_result/' +
                                'with_meta_training_epoch/' +
                                str(meta_training_epochs) + '/')
            model_dir = (common_dir + 'saved_model/' +
                         'with_meta_training_epoch/' +
                         str(meta_training_epochs) + '/')
            for ind_num_pilots_test, num_pilots_test in enumerate(
                    test_num_pilots_available):
                # reset again # to make fair comp. per adapt. and per meta-training epochs
                if args.if_fix_random_seed:
                    reset_randomness(args.random_seed + 2)
                print('curr pilots num: ', num_pilots_test)
                pilots_subdir = 'after_adapt/' + str(num_pilots_test) + '_num_pilots_test/'
                tx_dir = model_dir + 'tx/' + pilots_subdir
                rx_dir = model_dir + 'rx/' + pilots_subdir
                pilots_result_dir = epoch_result_dir + str(num_pilots_test) + '_num_pilots_test/'
                for needed_dir in (tx_dir, rx_dir, pilots_result_dir):
                    os.makedirs(needed_dir, exist_ok=True)
                test_result_per_num_pilots_test = pilots_result_dir + 'test_result.mat'

                block_error_rate = torch.zeros(num_h, len(test_snr_range))
                for ind_h, h in enumerate(h_list_test):
                    print('current channel ind', ind_h)
                    PATH_after_adapt_tx = tx_dir + str(ind_h) + 'th_adapted_net'
                    PATH_after_adapt_rx = rx_dir + str(ind_h) + 'th_adapted_net'
                    test_training(args, h, tx_net_for_testtraining,
                                  rx_net_for_testtraining, Noise, Noise_relax,
                                  PATH_before_adapt_tx, PATH_before_adapt_rx,
                                  PATH_after_adapt_tx, PATH_after_adapt_rx,
                                  num_pilots_test)
                    # test
                    for ind_snr, test_snr in enumerate(test_snr_range):
                        bler = test_per_channel_per_snr(
                            args, args.test_size, h, tx_net_for_testtraining,
                            rx_net_for_testtraining, test_snr, actual_channel_num,
                            PATH_after_adapt_tx, PATH_after_adapt_rx)
                        block_error_rate[ind_h, ind_snr] = bler
                        total_block_error_rate[ind_h, ind_snr,
                                               ind_num_pilots_test] = bler
                        total_total_block_error_rate[
                            ind_h, ind_snr, ind_num_pilots_test,
                            ind_meta_training_epoch] = bler
                save_test_result_dict_per_num_pilots_test = {
                    'block_error_rate': block_error_rate.detach().numpy()
                }
                sio.savemat(test_result_per_num_pilots_test,
                            save_test_result_dict_per_num_pilots_test)
                mean_bler = torch.mean(block_error_rate)
                writer_test.add_scalar(
                    'average (h) block error rate per num pilots',
                    mean_bler, num_pilots_test)
                print('curr pilots num', num_pilots_test, 'bler', mean_bler)

            save_test_result_dict[
                'block_error_rate_total'] = total_block_error_rate.detach().numpy()

            if common_dir_over_multi_rand_seeds is not None:
                sio.savemat(test_result_all_PATH_per_rand_seeds,
                            save_test_result_dict)
            else:
                # This directory already exists as the parent of every
                # per-pilot result directory created above, so exist_ok is
                # required (without it os.makedirs raised FileExistsError).
                os.makedirs(epoch_result_dir, exist_ok=True)
                sio.savemat(epoch_result_dir + 'test_result.mat',
                            save_test_result_dict)

        if args.path_for_meta_trained_net_total_per_epoch:
            save_test_result_dict_total[
                'block_error_rate_total_total_meta_training_epoch'] = total_total_block_error_rate.detach(
                ).numpy()
            sio.savemat(test_result_all_PATH_for_all_meta_training_epochs,
                        save_test_result_dict_total)
    finally:
        # flush pending TensorBoard events and release the file handle
        writer_test.close()
def multi_task_learning(args, common_dir, tx_net, rx_net, writer_meta_training,
                        Noise, Noise_relax, actual_channel_num,
                        PATH_before_adapt_rx_intermediate,
                        PATH_before_adapt_tx_intermediate,
                        rx_net_for_testtraining, tx_net_for_testtraining):
    meta_optimiser_tx = torch.optim.Adam(tx_net.parameters(),
                                         args.lr_meta_update)
    meta_optimiser_rx = torch.optim.Adam(rx_net.parameters(),
                                         args.lr_meta_update)

    test_result_PATH_per_meta_training_test_bler_dir = common_dir + 'test_result_during_meta_training/'
    save_test_result_dict_total_per_meta_training_test_bler = {}

    if args.if_fix_random_seed:
        random_seed_init = args.random_seed + 99999
    else:
        pass
    # for online
    previous_channel = []
    for ind_dev in range(args.tasks_per_metaupdate):
        previous_channel.append(None)

    test_bler_per_meta_training_epochs = []
    channel_per_meta_training_epochs = []

    second_loss_rx_best_for_stopping_criteria = 99999999999

    for epochs in range(args.num_epochs_meta_train):
        if epochs % args.meta_tr_epoch_num_for_test == 0:
            curr_path_rx = PATH_before_adapt_rx_intermediate + str(epochs)
            curr_path_tx = PATH_before_adapt_tx_intermediate + str(epochs)
            torch.save(rx_net.state_dict(), curr_path_rx)
            torch.save(tx_net.state_dict(), curr_path_tx)
            print('meta-learning epoch', epochs)
            if args.see_test_bler_during_meta_update:
                # see test bler here
                test_snr_range = [args.Eb_over_N_db_test]
                num_pilots_test = args.pilots_num_meta_test
                PATH_before_adapt_tx = curr_path_tx
                PATH_before_adapt_rx = curr_path_rx

                if args.if_get_conven_commun_performance:
                    test_bler_mean_curr_epoch_conven_approach, ch_est_error_avg = test_conven_commun_during_online_meta_training(
                        args, test_snr_range, num_pilots_test, Noise,
                        actual_channel_num)
                    writer_meta_training.add_scalar(
                        'conv. approach test bler during meta-training',
                        test_bler_mean_curr_epoch_conven_approach, epochs)
                    print('conven bler: ',
                          test_bler_mean_curr_epoch_conven_approach)
                    print('conven ch. est: ', ch_est_error_avg)
                    print(
                        'as this is for BPSK with maximum likelihood decoder, we only need this once so we stop the code here'
                    )
                    dfdfdfdfd
                else:
                    pass

                test_bler_mean_curr_epoch = test_with_adapt_compact_during_online_meta_training(
                    args, common_dir, epochs, test_snr_range, num_pilots_test,
                    tx_net_for_testtraining, rx_net_for_testtraining, Noise,
                    Noise_relax, actual_channel_num, PATH_before_adapt_tx,
                    PATH_before_adapt_rx)
                writer_meta_training.add_scalar(
                    'test bler during meta-training',
                    test_bler_mean_curr_epoch, epochs)
                test_bler_per_meta_training_epochs.append(
                    test_bler_mean_curr_epoch)

        first_loss_rx = 0
        second_loss_rx = 0
        first_loss_tx = 0
        second_loss_tx = 0
        iter_in_sampled_device = 0  # for averaging meta-devices
        for ind_meta_dev in range(args.tasks_per_metaupdate):
            if args.if_fix_random_seed:
                reset_randomness(
                    random_seed_init
                )  # to make noise same as much as possible for joint and meta
                random_seed_init += 1
            else:
                pass
            if args.if_Rayleigh_channel_model_AR:
                current_channel = channel_set_gen_AR(
                    args.tap_num, previous_channel[ind_meta_dev], args.rho
                )  # num_channels = 1 since we are generating per channel
                previous_channel[ind_meta_dev] = current_channel
            else:
                raise NotImplementedError

            if args.if_joint_training:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = joint_training(
                    args, iter_in_sampled_device, tx_net, rx_net,
                    current_channel, Noise, Noise_relax)
            elif args.if_joint_training_tx_meta_training_rx:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = maml_for_rx_joint_for_tx(
                    args, iter_in_sampled_device, tx_net, rx_net,
                    current_channel, Noise, Noise_relax)
            else:  # maml
                raise NotImplementedError

            first_loss_rx = first_loss_rx + first_loss_curr_rx
            second_loss_rx = second_loss_rx + second_loss_curr_rx
            first_loss_tx = first_loss_tx + first_loss_curr_tx
            second_loss_tx = second_loss_tx + second_loss_curr_tx
        first_loss_tx = first_loss_tx / args.tasks_per_metaupdate
        second_loss_tx = second_loss_tx / args.tasks_per_metaupdate  # we joint train tx so we only have one loss (which means, first_loss_tx = second_loss_tx)
        first_loss_rx = first_loss_rx / args.tasks_per_metaupdate
        second_loss_rx = second_loss_rx / args.tasks_per_metaupdate
        if args.if_TB_loss_ignore:
            pass
        else:
            writer_meta_training.add_scalar('RX loss before local adaptation',
                                            first_loss_rx, epochs)
            writer_meta_training.add_scalar('RX loss after local adaptation',
                                            second_loss_rx, epochs)
            writer_meta_training.add_scalar('TX loss', first_loss_tx, epochs)
        if args.if_use_stopping_criteria_during_meta_training:
            if second_loss_rx < second_loss_rx_best_for_stopping_criteria:
                curr_path_rx_best_training_loss = PATH_before_adapt_rx_intermediate + 'best_model_based_on_meta_training_loss'
                curr_path_tx_best_training_loss = PATH_before_adapt_tx_intermediate + 'best_model_based_on_meta_training_loss'
                torch.save(rx_net.state_dict(),
                           curr_path_rx_best_training_loss)
                torch.save(tx_net.state_dict(),
                           curr_path_tx_best_training_loss)
            else:
                pass
        else:
            pass

        meta_optimiser_rx.zero_grad()
        meta_optimiser_tx.zero_grad()
        # rx meta-update
        for f in rx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        meta_optimiser_rx.step()  # Adam
        # tx meta-update
        if args.fix_bpsk_tx:  # nothing to meta-learn for tx since tx is BPSK encoder
            pass
        else:
            for f in tx_net.parameters():
                f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
            meta_optimiser_tx.step()  # Adam

        if epochs % args.meta_tr_epoch_num_for_test == 0:
            os.makedirs(test_result_PATH_per_meta_training_test_bler_dir +
                        'epochs/' + str(epochs) + '/')
            test_result_PATH_per_meta_training_test_bler = test_result_PATH_per_meta_training_test_bler_dir + 'epochs/' + str(
                epochs) + '/' + 'test_result_per_meta_training_epochs.mat'
            save_test_result_dict_total_per_meta_training_test_bler[
                'test_bler_during_meta_training'] = test_bler_per_meta_training_epochs
            sio.savemat(
                test_result_PATH_per_meta_training_test_bler,
                save_test_result_dict_total_per_meta_training_test_bler)
        else:
            pass

    return test_bler_per_meta_training_epochs, channel_per_meta_training_epochs
    if args.if_exp_over_multi_pilots_test:
        test_num_pilots_available = [1,2,4,8,16,32,64,128]
    else:
        test_num_pilots_available = [8]
    print('test available pilots: ', test_num_pilots_available)

    test_result_PATH_per_meta_training_test_bler = common_dir + 'test_result_during_meta_training/' + 'test_result_per_meta_training_epochs.mat'

    save_test_result_dict_total_per_meta_training_test_bler = {}

    # complex symbol
    actual_channel_num = args.channel_num * 2

    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 999)

    tx_net = tx_dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, if_bias=args.if_bias,
              if_relu=args.if_relu)

    if torch.cuda.is_available():
        tx_net = tx_net.to(args.device)

    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 999)

    tx_net_for_testtraining = tx_dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder,
                                     n=actual_channel_num,
                                     if_bias=args.if_bias,
                                     if_relu=args.if_relu)