# ----- Example 1 -----
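# The snippets below assume the following imports; project-specific helpers such as
# message_gen, channel, channel_estimation, demodulation, four_qam_uncoded_modulation,
# receiver_loss, transmitter_loss, feedback_from_rx, reset_randomness, meta_tx and
# meta_rx are assumed to be defined elsewhere in the repository.
import numpy as np
import torch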
def test_per_channel_per_snr(args, h, net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt, if_val):
    if torch.cuda.is_available():
        net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt))
    else:
        net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt, map_location = torch.device('cpu')))


    batch_size = args.test_size
    success_test = 0
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)

    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num),
                                                                            noise_var_test * torch.eye(
                                                                                actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)

    out_test = net_for_testtraining(m_test, h, Noise_test, args.device, args.if_RTN)
    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]
    if not if_val:
        print('for snr: ', test_snr, 'bler: ', 1 - accuracy)

    return 1 - accuracy
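# A short worked example (hypothetical values, assuming actual_channel_num counts the
# real dimensions of the channel) of the Eb/N0-to-noise-variance conversion used by the
# test functions in this file.
def _noise_example():
    bit_num, channel_num, test_snr = 4, 4, 10.0        # hypothetical configuration
    Eb_over_N = 10 ** (test_snr / 10)                   # 10 dB -> 10.0 (linear)
    R = bit_num / channel_num                           # rate = 1.0 bit per channel use
    noise_var = 1 / (2 * R * Eb_over_N)                 # -> 0.05 per real dimension
    actual_channel_num = 2 * channel_num                # assumed real/imag interleaving
    return torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num), noise_var * torch.eye(actual_channel_num))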
def test_per_channel_per_snr_conven_approach(args, est_h, test_size, h, test_snr, actual_channel_num):
    batch_size = test_size
    success_test = 0
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num),
                                                                            noise_var_test * torch.eye(actual_channel_num))
    print('payload num for conventional approach:', batch_size)
    m_test, label_test = message_gen(args.bit_num, batch_size)
    label_test = label_test.type(torch.LongTensor).to(args.device)

    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 7777)  # use the same test noise every time so that the effect of the number of adaptation steps can be isolated

    actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label_test)  # label_test instead of m_test

    received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
    # rx (maximum likelihood)
    out_test = demodulation(args.bit_num, args.channel_num, args.tap_num, est_h, received_signal, args.device)
    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]

    print('accuracy', accuracy)

    return 1 - accuracy
def one_frame_conventional_training_tx_nn_rx_nn(args, h, Noise, Noise_relax,
                                                tx_net_for_testtraining,
                                                rx_net_for_testtraining,
                                                rx_testtraining_optimiser,
                                                epochs,
                                                num_pilots_test_in_one_mb):
    # used only during the runtime (adaptation) phase
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, num_pilots_test_in_one_mb)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    tx_net_for_testtraining.zero_grad()
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(
        m, args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)

    #### multiple training iterations can be performed with the given received_signal
    for rx_training_iter in range(args.fix_tx_multi_adapt_rx_iter_num):
        rx_net_for_testtraining.zero_grad()
        received_signal_curr_mb = received_signal
        label_curr_mb = label
        out = rx_net_for_testtraining(received_signal_curr_mb, args.if_RTN,
                                      args.device)
        loss_rx = receiver_loss(out, label_curr_mb)

        loss_rx.backward()

        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    for f in rx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    rx_testtraining_optimiser = torch.optim.Adam(
                        rx_net_for_testtraining.parameters(),
                        args.lr_testtraining)
                    rx_testtraining_optimiser.step()
                else:
                    rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            for f in rx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)
    loss_tx = 0

    return rx_testtraining_optimiser, float(loss_rx), float(loss_tx)
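# A minimal sketch (assumed toy hyper-parameters) of the update rule used in the
# adaptation loops of this file: plain SGD for the first num_meta_local_updates steps,
# then a switch to a freshly created Adam optimiser for the remaining steps.
def _sgd_then_adam_step(net, step, num_meta_local_updates, lr_inner, lr_adam,
                        optimiser=None):
    # assumes the gradients of net have already been populated by loss.backward()
    if step < num_meta_local_updates:
        for f in net.parameters():
            if f.grad is not None:
                f.data.sub_(f.grad.data * lr_inner)
    elif step == num_meta_local_updates or optimiser is None:
        optimiser = torch.optim.Adam(net.parameters(), lr_adam)
        optimiser.step()
    else:
        optimiser.step()
    return optimiser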
def one_frame_conventional_training_tx_bpsk_rx_nn(args, h, Noise,
                                                  rx_net_for_testtraining,
                                                  rx_testtraining_optimiser,
                                                  num_pilots_test_in_one_mb):
    # used only during the runtime (adaptation) phase
    m, label = message_gen(args.bit_num, num_pilots_test_in_one_mb)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)

    tx_net_for_testtraining = None  # no tx neural network is needed here
    # tx
    actual_transmitted_symbol = four_qam_uncoded_modulation(
        args.bit_num, args.channel_num, label)  # label instead of m
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)

    #### multiple training iterations can be performed with the given received_signal
    for rx_training_iter in range(args.fix_tx_multi_adapt_rx_iter_num):
        rx_net_for_testtraining.zero_grad()

        received_signal_curr_mb = received_signal
        label_curr_mb = label

        out = rx_net_for_testtraining(received_signal_curr_mb, args.if_RTN,
                                      args.device)
        loss_rx = receiver_loss(out, label_curr_mb)

        loss_rx.backward()

        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if rx_training_iter < args.num_meta_local_updates:
                    for f in rx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif rx_training_iter == args.num_meta_local_updates:
                    rx_testtraining_optimiser = torch.optim.Adam(
                        rx_net_for_testtraining.parameters(),
                        args.lr_testtraining)
                    rx_testtraining_optimiser.step()
                else:
                    rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            for f in rx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)
    loss_tx = 0
    return rx_testtraining_optimiser, float(loss_rx), float(loss_tx)
def one_frame_joint_training(args, h, Noise, Noise_relax, curr_tx_net_list,
                             curr_rx_net_list, init_tx_net_list,
                             init_rx_net_list):
    # we only transmit once and update both rx and tx since we are always considering a stochastic encoder
    # joint training can be obtained via MAML with no inner updates
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)

    relax_sigma = args.relax_sigma

    m, label = message_gen(args.bit_num, args.pilots_num_meta_train_query)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)

    if args.fix_bpsk_tx:
        actual_transmitted_symbol = four_qam_uncoded_modulation(
            args.bit_num, args.channel_num, label)  # label instead of m
    else:
        # tx
        tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(
            m, curr_tx_net_list, args.if_bias, args.device, relax_sigma,
            Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)
    # rx
    out = rx_meta_intermediate(received_signal, curr_rx_net_list, args.if_bias,
                               args.device, args.if_RTN)

    # for rx
    loss_rx = receiver_loss(out, label)
    joint_grad_rx = torch.autograd.grad(loss_rx,
                                        init_rx_net_list,
                                        create_graph=False)

    if args.fix_bpsk_tx:
        joint_grad_tx = None
        loss_tx = 0
    else:
        received_reward_from_rx = feedback_from_rx(out, label)
        loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean,
                                   received_reward_from_rx, args.relax_sigma)
        joint_grad_tx = torch.autograd.grad(loss_tx,
                                            init_tx_net_list,
                                            create_graph=False,
                                            retain_graph=True)

    return joint_grad_rx, joint_grad_tx, float(loss_rx), float(loss_tx)
def test_per_channel_per_snr(args, test_size, h, tx_net_for_testtraining, rx_net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx):

    tx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_tx))
    rx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_rx))

    batch_size = test_size
    success_test = 0
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num),
                                                                            noise_var_test * torch.eye(actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)

    for f in tx_net_for_testtraining.parameters():
        if f.grad is not None:
            f.grad.detach_()
            f.grad.zero_()

    rx_net_for_testtraining.zero_grad()

    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 7777)  # use the same test noise every time so that the effect of the number of adaptation steps can be isolated


    if args.fix_bpsk_tx_train_nn_rx_during_runtime:
        actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num,
                                                                label_test)  # label_test instead of m_test
    else:
        tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(m_test, args.device, 0, None)  # no relaxation during test
    tx_symb_mean = None  # not needed during testing
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
    # rx
    out_test = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)

    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]

    return 1 - accuracy
# ----- Example 7 -----
def test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps):  # PATH_before_adapt can be meta-learned
    # initialize network (net_for_testtraining) (net is for meta-training)
    if torch.cuda.is_available():
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt))
    else:
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt, map_location = torch.device('cpu')))
    if args.if_test_training_adam and not args.if_adam_after_sgd:
        testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(), args.lr_testtraining)

    num_adapt = adapt_steps
    for epochs in range(num_adapt):
        m, label = message_gen(args.bit_num, args.mb_size)
        m = m.type(torch.FloatTensor).to(args.device)
        label = label.type(torch.LongTensor).to(args.device)
        for f in net_for_testtraining.parameters():
            if f.grad is not None:
                f.grad.detach_()
                f.grad.zero_()

        out = net_for_testtraining(m, h, Noise, args.device, args.if_RTN)
        loss = torch.nn.functional.cross_entropy(out, label)
        # grad calculation
        loss.backward()
        ### adapt (update) parameter
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    for f in net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(),
                                                              args.lr_testtraining)
                    testtraining_optimiser.step()
                else:
                    testtraining_optimiser.step()
            else:
                testtraining_optimiser.step()
        else:
            for f in net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.data * args.lr_testtraining)
    # save the adapted network to compute BLER
    torch.save(net_for_testtraining.state_dict(), PATH_after_adapt)
def one_iter_mmse_ch_est(args, h, Noise, num_pilots_test):
    m, label = message_gen(args.bit_num, num_pilots_test)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    actual_transmitted_symbol = four_qam_uncoded_modulation(
        args.bit_num, args.channel_num, label)  # label instead of m
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)
    est_h = channel_estimation(args, actual_transmitted_symbol,
                               received_signal, args.tap_num)
    h_numpy = h.cpu().numpy()
    h_complex = np.zeros((args.tap_num, 1), dtype=complex)
    for ind_h in range(args.tap_num):
        h_complex[ind_h] = h_numpy[2 * ind_h] + h_numpy[2 * ind_h + 1] * 1j
    error_h = np.matmul(np.conj(np.transpose(h_complex - est_h)),
                        h_complex - est_h)
    return est_h, error_h
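# A minimal numpy sketch (hypothetical values) of the real-interleaved-to-complex
# conversion and squared-error computation performed in one_iter_mmse_ch_est above,
# where h stores [Re(h0), Im(h0), Re(h1), Im(h1), ...].
def _complex_error_example():
    h_real = np.array([0.3, -0.1, 0.8, 0.2])            # two complex taps, interleaved
    h_complex = h_real[0::2] + 1j * h_real[1::2]         # -> [0.3-0.1j, 0.8+0.2j]
    est_h = np.array([0.25 - 0.05j, 0.85 + 0.15j])       # hypothetical estimate
    # np.vdot conjugates its first argument, so this is the squared error norm
    return np.vdot(h_complex - est_h, h_complex - est_h).real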
# ----- Example 9 -----
def one_iter_joint_training_sim_fix_stoch_encoder(
        args, h, Noise, Noise_relax, Noise_feedback, curr_tx_net_list,
        curr_rx_net_list, init_tx_net_list, init_rx_net_list, if_local_update,
        remove_dependency_to_updated_para_tmp, inner_loop):
    # given the net lists, run the meta networks
    # we only transmit once and update both rx and tx since we are always considering a stochastic encoder
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma

    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(
        m, curr_tx_net_list, args.if_bias, args.device, relax_sigma,
        Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)
    # rx
    out = rx_meta_intermediate(received_signal, curr_rx_net_list, args.if_bias,
                               args.device, args.if_RTN)
    # for rx
    loss_rx = receiver_loss(out, label)
    joint_grad_rx = torch.autograd.grad(loss_rx,
                                        init_rx_net_list,
                                        create_graph=False)
    received_reward_from_rx = feedback_from_rx(out, label, Noise_feedback,
                                               args.device)
    loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean,
                               received_reward_from_rx, args.relax_sigma)
    joint_grad_tx = torch.autograd.grad(loss_tx,
                                        init_tx_net_list,
                                        create_graph=False,
                                        retain_graph=True)

    return joint_grad_rx, joint_grad_tx, float(loss_rx), float(loss_tx)
def one_frame_hybrid_training(args, h, Noise, Noise_relax, para_tx_net_list,
                              para_rx_net_list):
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    m, label = message_gen(
        args.bit_num,
        args.pilots_num_meta_train_query)  # always send this amount

    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    #### we only transmit once and do meta-learning (rx) and joint learning (tx)
    if args.fix_bpsk_tx:
        actual_transmitted_symbol = four_qam_uncoded_modulation(
            args.bit_num, args.channel_num, label)  # label instead of m
    else:
        # tx
        tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(
            m, para_tx_net_list, args.if_bias, args.device, relax_sigma,
            Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)

    ###### set the number of support pilots here; transmission is only done once with mb_size_meta_test
    num_pilots_for_meta_train_supp = args.pilots_num_meta_train_supp
    # rx
    for ind_sim_iter_rx in range(args.num_meta_local_updates):
        if ind_sim_iter_rx == 0:
            out = rx_meta_intermediate(
                received_signal[0:num_pilots_for_meta_train_supp],
                para_rx_net_list, args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out,
                                    label[0:num_pilots_for_meta_train_supp])
            local_grad_rx = torch.autograd.grad(loss_rx,
                                                para_rx_net_list,
                                                create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, para_rx_net_list)))
            first_loss_curr_rx = float(loss_rx.clone().detach())
        else:
            out = rx_meta_intermediate(
                received_signal[0:num_pilots_for_meta_train_supp],
                intermediate_updated_para_list_rx, args.if_bias, args.device,
                args.if_RTN)
            loss_rx = receiver_loss(out,
                                    label[0:num_pilots_for_meta_train_supp])
            local_grad_rx = torch.autograd.grad(
                loss_rx, intermediate_updated_para_list_rx, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, intermediate_updated_para_list_rx)))

    ### now meta-gradient
    if args.separate_meta_training_support_query_set:
        end_ind_for_supp = num_pilots_for_meta_train_supp
    else:  # use whole for query
        end_ind_for_supp = 0

    out = rx_meta_intermediate(received_signal[end_ind_for_supp:],
                               intermediate_updated_para_list_rx, args.if_bias,
                               args.device, args.if_RTN)
    loss_rx_after_local_adaptation = receiver_loss(out,
                                                   label[end_ind_for_supp:])
    meta_grad_rx = torch.autograd.grad(loss_rx_after_local_adaptation,
                                       para_rx_net_list,
                                       create_graph=False)

    # use all transmission blocks in one frame for joint training of encoder
    end_ind_for_supp = 0
    if args.separate_meta_training_support_query_set:  # recompute out over all messages for the tx update
        out = rx_meta_intermediate(received_signal,
                                   intermediate_updated_para_list_rx,
                                   args.if_bias, args.device, args.if_RTN)

    if args.fix_bpsk_tx:
        loss_tx = 0
        joint_grad_tx = None
    else:
        # joint grad. for tx
        received_reward_from_rx = feedback_from_rx(
            out, label[end_ind_for_supp:])  # feedback with adapted rx
        loss_tx = transmitter_loss(
            actual_transmitted_symbol[end_ind_for_supp:],
            tx_symb_mean[end_ind_for_supp:], received_reward_from_rx,
            args.relax_sigma)
        joint_grad_tx = torch.autograd.grad(loss_tx,
                                            para_tx_net_list,
                                            create_graph=False,
                                            retain_graph=True)

    return meta_grad_rx, joint_grad_tx, first_loss_curr_rx, float(
        loss_tx), float(loss_rx_after_local_adaptation), float(loss_tx)
# ----- Example 11 -----
def one_iter_rx_meta_tx_joint_sim_fix_stoch_encoder(args, h, Noise,
                                                    Noise_relax,
                                                    Noise_feedback,
                                                    para_tx_net_list,
                                                    para_rx_net_list):
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    #### we only transmit once and do meta-learning (rx) and joint learning (tx)
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(
        m, para_tx_net_list, args.if_bias, args.device, relax_sigma,
        Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)
    # rx
    for ind_sim_iter_rx in range(args.num_meta_local_updates):
        if ind_sim_iter_rx == 0:
            out = rx_meta_intermediate(received_signal, para_rx_net_list,
                                       args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label)
            local_grad_rx = torch.autograd.grad(loss_rx,
                                                para_rx_net_list,
                                                create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, para_rx_net_list)))
            first_loss_curr_rx = float(loss_rx.clone().detach())
        else:
            out = rx_meta_intermediate(received_signal,
                                       intermediate_updated_para_list_rx,
                                       args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label)
            local_grad_rx = torch.autograd.grad(
                loss_rx, intermediate_updated_para_list_rx, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, intermediate_updated_para_list_rx)))

    ### now meta-gradient
    out = rx_meta_intermediate(received_signal,
                               intermediate_updated_para_list_rx, args.if_bias,
                               args.device, args.if_RTN)
    loss_rx_after_local_adaptation = receiver_loss(out, label)
    meta_grad_rx = torch.autograd.grad(loss_rx_after_local_adaptation,
                                       para_rx_net_list,
                                       create_graph=False)

    # joint grad. for tx
    received_reward_from_rx = feedback_from_rx(
        out, label, Noise_feedback, args.device)  # feedback with adapted rx
    loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean,
                               received_reward_from_rx, args.relax_sigma)
    joint_grad_tx = torch.autograd.grad(loss_tx,
                                        para_tx_net_list,
                                        create_graph=False,
                                        retain_graph=True)

    return meta_grad_rx, joint_grad_tx, first_loss_curr_rx, float(
        loss_tx), float(loss_rx_after_local_adaptation), float(loss_tx)
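# A minimal sketch (toy tensors, assumed inner learning rate) of the functional
# MAML-style inner update used throughout this file: parameters are kept in a list and
# the updated versions are built out-of-place with create_graph=True so that a later
# loss computed with the updated parameters can still be differentiated w.r.t. the
# initial ones.
def _maml_inner_update_example():
    lr_inner = 0.01
    params = [torch.randn(4, 2, requires_grad=True),
              torch.zeros(2, requires_grad=True)]
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = ((x @ params[0] + params[1] - y) ** 2).mean()
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # equivalent to: list(map(lambda p: p[1] - lr_inner * p[0], zip(grads, params)))
    updated = [p - lr_inner * g for g, p in zip(grads, params)]
    return updated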
# ----- Example 12 -----
def one_iter_sim_fix_stoch_encoder_from_rx_meta_tx_joint(
        args, h, Noise, Noise_relax, Noise_feedback, tx_net_for_testtraining,
        rx_net_for_testtraining, tx_testtraining_optimiser,
        rx_testtraining_optimiser, epochs):
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)

    tx_net_for_testtraining.zero_grad()
    rx_net_for_testtraining.zero_grad()
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(
        m, args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device,
                              args.if_AWGN)
    # rx
    out = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)

    loss_rx = receiver_loss(out, label)
    loss_rx.backward()

    if args.if_test_training_adam:
        if args.if_adam_after_sgd:
            if epochs < args.num_meta_local_updates:
                for f in rx_net_for_testtraining.parameters():
                    if f.grad is not None:
                        f.data.sub_(f.grad.data * args.lr_meta_inner)
            elif epochs == args.num_meta_local_updates:
                rx_testtraining_optimiser = torch.optim.Adam(
                    rx_net_for_testtraining.parameters(), args.lr_testtraining)
                rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            rx_testtraining_optimiser.step()
    else:
        for f in rx_net_for_testtraining.parameters():
            if f.grad is not None:
                f.data.sub_(f.grad.detach() * args.lr_testtraining)

    if args.fix_joint_trained_tx_only_adapt_meta_trained_rx:
        tx_testtraining_optimiser = None
        loss_tx = 0
    else:
        out = rx_net_for_testtraining(received_signal, args.if_RTN,
                                      args.device)
        received_reward_from_rx = feedback_from_rx(out, label, Noise_feedback,
                                                   args.device)
        loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean,
                                   received_reward_from_rx, args.relax_sigma)
        loss_tx.backward()
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    for f in tx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    tx_testtraining_optimiser = torch.optim.Adam(
                        tx_net_for_testtraining.parameters(),
                        args.lr_testtraining)
                    tx_testtraining_optimiser.step()
                else:
                    tx_testtraining_optimiser.step()
            else:
                tx_testtraining_optimiser.step()
        else:
            for f in tx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)

    return rx_testtraining_optimiser, tx_testtraining_optimiser, float(
        loss_rx), float(loss_tx)
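# A minimal sketch (assumed outer learning rate; the outer loop itself is not part of
# this file) of how the meta-gradients returned by the functions above could be applied
# to the meta-parameters.
def _apply_meta_grad(para_net_list, meta_grad, lr_meta_outer=0.001):
    with torch.no_grad():
        for p, g in zip(para_net_list, meta_grad):
            p.sub_(lr_meta_outer * g)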