def test_per_channel_per_snr(args, h, net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt, if_val):
    if torch.cuda.is_available():
        net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt))
    else:
        net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt, map_location=torch.device('cpu')))
    batch_size = args.test_size
    success_test = 0
    # noise variance per real dimension from Eb/N0 (dB) and rate R = bit_num / channel_num
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num), noise_var_test * torch.eye(actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    out_test = net_for_testtraining(m_test, h, Noise_test, args.device, args.if_RTN)
    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]
    if not if_val:
        print('for snr: ', test_snr, 'bler: ', 1 - accuracy)
    return 1 - accuracy
def test_per_channel_per_snr_conven_approach(args, est_h, test_size, h, test_snr, actual_channel_num):
    batch_size = test_size
    success_test = 0
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num), noise_var_test * torch.eye(actual_channel_num))
    print('payload num for conv', batch_size)
    m_test, label_test = message_gen(args.bit_num, batch_size)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    if args.if_fix_random_seed:
        # use the same test noise every time so that only the number of adaptations varies
        reset_randomness(args.random_seed + 7777)
    # tx: modulate directly from the labels (rather than m_test)
    actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label_test)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
    # rx (maximum-likelihood demodulation with the estimated channel)
    out_test = demodulation(args.bit_num, args.channel_num, args.tap_num, est_h, received_signal, args.device)
    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]
    print('accuracy', accuracy)
    return 1 - accuracy
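# --- Illustrative helper (not part of the original code): the Eb/N0-to-noise-variance
# conversion used by the test functions above, factored out with a worked example.
# The name noise_var_from_snr_db and the example values are assumptions for clarity only.
def noise_var_from_snr_db(snr_db, bit_num, channel_num):
    # Eb/N0 in linear scale from the value in dB
    eb_over_n = pow(10, snr_db / 10)
    # rate R = information bits per channel use
    rate = bit_num / channel_num
    # per-dimension noise variance: 1 / (2 * R * Eb/N0)
    return 1 / (2 * rate * eb_over_n)
# Example: with bit_num = 4 and channel_num = 4 (R = 1), 0 dB gives a variance of 0.5
# and 10 dB gives 0.05.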
def one_frame_conventional_training_tx_nn_rx_nn(args, h, Noise, Noise_relax, tx_net_for_testtraining,
                                                rx_net_for_testtraining, rx_testtraining_optimiser, epochs,
                                                num_pilots_test_in_one_mb):
    # only used at runtime (test-time training)
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, num_pilots_test_in_one_mb)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    tx_net_for_testtraining.zero_grad()
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(m, args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # the receiver can be updated multiple times with the same received_signal
    for rx_training_iter in range(args.fix_tx_multi_adapt_rx_iter_num):
        rx_net_for_testtraining.zero_grad()
        received_signal_curr_mb = received_signal
        label_curr_mb = label
        out = rx_net_for_testtraining(received_signal_curr_mb, args.if_RTN, args.device)
        loss_rx = receiver_loss(out, label_curr_mb)
        loss_rx.backward()
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    # plain SGD during the first num_meta_local_updates frames
                    for f in rx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    # switch to Adam after the SGD warm-up
                    rx_testtraining_optimiser = torch.optim.Adam(rx_net_for_testtraining.parameters(),
                                                                 args.lr_testtraining)
                    rx_testtraining_optimiser.step()
                else:
                    rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            # plain SGD
            for f in rx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)
    loss_tx = 0
    return rx_testtraining_optimiser, float(loss_rx), float(loss_tx)
def one_frame_conventional_training_tx_bpsk_rx_nn(args, h, Noise, rx_net_for_testtraining,
                                                  rx_testtraining_optimiser, num_pilots_test_in_one_mb):
    # only used at runtime (test-time training)
    m, label = message_gen(args.bit_num, num_pilots_test_in_one_mb)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    tx_net_for_testtraining = None  # no transmitter neural network is needed here
    # tx: modulate directly from the labels (rather than m)
    actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # the receiver can be updated multiple times with the same received_signal
    for rx_training_iter in range(args.fix_tx_multi_adapt_rx_iter_num):
        rx_net_for_testtraining.zero_grad()
        received_signal_curr_mb = received_signal
        label_curr_mb = label
        out = rx_net_for_testtraining(received_signal_curr_mb, args.if_RTN, args.device)
        loss_rx = receiver_loss(out, label_curr_mb)
        loss_rx.backward()
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if rx_training_iter < args.num_meta_local_updates:
                    # plain SGD during the first num_meta_local_updates iterations
                    for f in rx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif rx_training_iter == args.num_meta_local_updates:
                    # switch to Adam after the SGD warm-up
                    rx_testtraining_optimiser = torch.optim.Adam(rx_net_for_testtraining.parameters(),
                                                                 args.lr_testtraining)
                    rx_testtraining_optimiser.step()
                else:
                    rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            # plain SGD
            for f in rx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)
    loss_tx = 0
    return rx_testtraining_optimiser, float(loss_rx), float(loss_tx)
def one_frame_joint_training(args, h, Noise, Noise_relax, curr_tx_net_list, curr_rx_net_list,
                             init_tx_net_list, init_rx_net_list):
    # We only transmit once and update both rx and tx, since we always consider a stochastic encoder.
    # Joint training can be obtained via MAML with no inner update.
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.pilots_num_meta_train_query)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    if args.fix_bpsk_tx:
        # fixed modulation: transmit directly from the labels (rather than m)
        actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label)
    else:
        # tx
        tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(m, curr_tx_net_list, args.if_bias,
                                                                       args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # rx
    out = rx_meta_intermediate(received_signal, curr_rx_net_list, args.if_bias, args.device, args.if_RTN)
    # gradient for rx
    loss_rx = receiver_loss(out, label)
    joint_grad_rx = torch.autograd.grad(loss_rx, init_rx_net_list, create_graph=False)
    if args.fix_bpsk_tx:
        joint_grad_tx = None
        loss_tx = 0
    else:
        # gradient for tx via the reward fed back from the rx
        received_reward_from_rx = feedback_from_rx(out, label)
        loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean, received_reward_from_rx,
                                   args.relax_sigma)
        joint_grad_tx = torch.autograd.grad(loss_tx, init_tx_net_list, create_graph=False, retain_graph=True)
    return joint_grad_rx, joint_grad_tx, float(loss_rx), float(loss_tx)
def test_per_channel_per_snr(args, test_size, h, tx_net_for_testtraining, rx_net_for_testtraining, test_snr,
                             actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx):
    if torch.cuda.is_available():
        tx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_tx))
        rx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_rx))
    else:
        # fall back to CPU loading when no GPU is available
        tx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_tx, map_location=torch.device('cpu')))
        rx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_rx, map_location=torch.device('cpu')))
    batch_size = test_size
    success_test = 0
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num), noise_var_test * torch.eye(actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    for f in tx_net_for_testtraining.parameters():
        if f.grad is not None:
            f.grad.detach()
            f.grad.zero_()
    rx_net_for_testtraining.zero_grad()
    if args.if_fix_random_seed:
        # use the same test noise every time so that only the number of adaptations varies
        reset_randomness(args.random_seed + 7777)
    if args.fix_bpsk_tx_train_nn_rx_during_runtime:
        # fixed modulation: transmit directly from the labels (rather than m_test)
        actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label_test)
    else:
        tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(m_test, args.device, 0, None)  # no relaxation
        tx_symb_mean = None  # not needed during test
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
    # rx
    out_test = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)
    assert label_test.shape[0] == batch_size
    for ind_mb in range(label_test.shape[0]):
        if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]:  # correct classification
            success_test += 1
    accuracy = success_test / label_test.shape[0]
    return 1 - accuracy
def test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps):
    # PATH_before_adapt can point to a meta-learned (or jointly trained) initialization;
    # initialize net_for_testtraining from it (net is used for meta-training).
    if torch.cuda.is_available():
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt))
    else:
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt, map_location=torch.device('cpu')))
    if args.if_test_training_adam and not args.if_adam_after_sgd:
        testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(), args.lr_testtraining)
    num_adapt = adapt_steps
    for epochs in range(num_adapt):
        m, label = message_gen(args.bit_num, args.mb_size)
        m = m.type(torch.FloatTensor).to(args.device)
        label = label.type(torch.LongTensor).to(args.device)
        for f in net_for_testtraining.parameters():
            if f.grad is not None:
                f.grad.detach()
                f.grad.zero_()
        out = net_for_testtraining(m, h, Noise, args.device, args.if_RTN)
        loss = torch.nn.functional.cross_entropy(out, label)
        # gradient calculation
        loss.backward()
        # adapt (update) parameters
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    for f in net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(),
                                                              args.lr_testtraining)
                    testtraining_optimiser.step()
                else:
                    testtraining_optimiser.step()
            else:
                testtraining_optimiser.step()
        else:
            for f in net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.data * args.lr_testtraining)
    # save the adapted network so that BLER can be computed afterwards
    torch.save(net_for_testtraining.state_dict(), PATH_after_adapt)
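# --- Hypothetical usage sketch (illustrative only, not part of the original code): adapt a
# (meta-)trained receiver with test_training, then measure BLER with the single-network
# variant of test_per_channel_per_snr defined first above. The helper name and its
# arguments (args, h, Noise, checkpoint paths, SNR grid) are assumptions about the
# surrounding experiment script.
def _example_adapt_and_evaluate(args, h, Noise, net_for_testtraining, PATH_before_adapt, PATH_after_adapt,
                                adapt_steps, snr_list, actual_channel_num):
    # adapt once from the (meta-)trained initialization and save the adapted weights
    test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps)
    # evaluate BLER of the adapted receiver over a grid of test SNRs
    bler_per_snr = []
    for test_snr in snr_list:
        bler = test_per_channel_per_snr(args, h, net_for_testtraining, test_snr, actual_channel_num,
                                        PATH_after_adapt, if_val=False)
        bler_per_snr.append(bler)
    return bler_per_snr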
def one_iter_mmse_ch_est(args, h, Noise, num_pilots_test):
    m, label = message_gen(args.bit_num, num_pilots_test)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    # tx: modulate directly from the labels (rather than m)
    actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    est_h = channel_estimation(args, actual_transmitted_symbol, received_signal, args.tap_num)
    # convert the interleaved real-valued channel into a complex vector
    h_numpy = h.cpu().numpy()
    h_complex = np.zeros((args.tap_num, 1), dtype=complex)
    for ind_h in range(args.tap_num):
        h_complex[ind_h] = h_numpy[2 * ind_h] + h_numpy[2 * ind_h + 1] * 1j
    # squared Euclidean norm of the channel estimation error
    error_h = np.matmul(np.conj(np.transpose(h_complex - est_h)), h_complex - est_h)
    return est_h, error_h
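# --- Minimal sketch (not part of the original code) of the real-to-complex channel conversion
# used inside one_iter_mmse_ch_est: the channel h is stored as an interleaved real vector
# [Re(h_0), Im(h_0), Re(h_1), Im(h_1), ...]. The helper name real_interleaved_to_complex is an
# assumption for illustration; it is a vectorized equivalent of the loop above.
def real_interleaved_to_complex(h_real, tap_num):
    # h_real: 1-D array of length 2 * tap_num with interleaved real/imaginary parts
    h_real = np.asarray(h_real).reshape(-1)
    return (h_real[0:2 * tap_num:2] + 1j * h_real[1:2 * tap_num:2]).reshape(tap_num, 1)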
def one_iter_joint_training_sim_fix_stoch_encoder(args, h, Noise, Noise_relax, Noise_feedback, curr_tx_net_list,
                                                  curr_rx_net_list, init_tx_net_list, init_rx_net_list,
                                                  if_local_update, remove_dependency_to_updated_para_tmp,
                                                  inner_loop):
    # Given the parameter lists, run the functional (meta) networks.
    # We only transmit once and update both rx and tx, since we always consider a stochastic encoder.
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(m, curr_tx_net_list, args.if_bias, args.device,
                                                                   relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # rx
    out = rx_meta_intermediate(received_signal, curr_rx_net_list, args.if_bias, args.device, args.if_RTN)
    # gradient for rx
    loss_rx = receiver_loss(out, label)
    joint_grad_rx = torch.autograd.grad(loss_rx, init_rx_net_list, create_graph=False)
    # gradient for tx via the (noisy) reward fed back from the rx
    received_reward_from_rx = feedback_from_rx(out, label, Noise_feedback, args.device)
    loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean, received_reward_from_rx, args.relax_sigma)
    joint_grad_tx = torch.autograd.grad(loss_tx, init_tx_net_list, create_graph=False, retain_graph=True)
    return joint_grad_rx, joint_grad_tx, float(loss_rx), float(loss_tx)
def one_frame_hybrid_training(args, h, Noise, Noise_relax, para_tx_net_list, para_rx_net_list):
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    # always send this number of pilots
    m, label = message_gen(args.bit_num, args.pilots_num_meta_train_query)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    # we only transmit once and do meta-learning for the rx and joint learning for the tx
    if args.fix_bpsk_tx:
        # fixed modulation: transmit directly from the labels (rather than m)
        actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label)
    else:
        # tx
        tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(m, para_tx_net_list, args.if_bias,
                                                                       args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # the support-set size can be changed here; transmission is only done once with mb_size_meta_test
    num_pilots_for_meta_train_supp = args.pilots_num_meta_train_supp
    # rx: inner-loop (local) updates on the support set
    for ind_sim_iter_rx in range(args.num_meta_local_updates):
        if ind_sim_iter_rx == 0:
            out = rx_meta_intermediate(received_signal[0:num_pilots_for_meta_train_supp], para_rx_net_list,
                                       args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label[0:num_pilots_for_meta_train_supp])
            local_grad_rx = torch.autograd.grad(loss_rx, para_rx_net_list, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0], zip(local_grad_rx, para_rx_net_list)))
            first_loss_curr_rx = float(loss_rx.clone().detach())
        else:
            out = rx_meta_intermediate(received_signal[0:num_pilots_for_meta_train_supp],
                                       intermediate_updated_para_list_rx, args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label[0:num_pilots_for_meta_train_supp])
            local_grad_rx = torch.autograd.grad(loss_rx, intermediate_updated_para_list_rx, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, intermediate_updated_para_list_rx)))
    # now the meta-gradient on the query set
    if args.separate_meta_training_support_query_set:
        end_ind_for_supp = num_pilots_for_meta_train_supp
    else:
        # use the whole frame as the query set
        end_ind_for_supp = 0
    out = rx_meta_intermediate(received_signal[end_ind_for_supp:], intermediate_updated_para_list_rx,
                               args.if_bias, args.device, args.if_RTN)
    loss_rx_after_local_adaptation = receiver_loss(out, label[end_ind_for_supp:])
    meta_grad_rx = torch.autograd.grad(loss_rx_after_local_adaptation, para_rx_net_list, create_graph=False)
    # use all transmitted blocks in the frame for joint training of the encoder
    end_ind_for_supp = 0
    if args.separate_meta_training_support_query_set:
        # recompute the rx output for the whole set of messages
        out = rx_meta_intermediate(received_signal, intermediate_updated_para_list_rx, args.if_bias, args.device,
                                   args.if_RTN)
    if args.fix_bpsk_tx:
        loss_tx = 0
        joint_grad_tx = None
    else:
        # joint gradient for the tx via the reward fed back from the adapted rx
        received_reward_from_rx = feedback_from_rx(out, label[end_ind_for_supp:])
        loss_tx = transmitter_loss(actual_transmitted_symbol[end_ind_for_supp:], tx_symb_mean[end_ind_for_supp:],
                                   received_reward_from_rx, args.relax_sigma)
        joint_grad_tx = torch.autograd.grad(loss_tx, para_tx_net_list, create_graph=False, retain_graph=True)
    return meta_grad_rx, joint_grad_tx, first_loss_curr_rx, float(loss_tx), \
        float(loss_rx_after_local_adaptation), float(loss_tx)
def one_iter_rx_meta_tx_joint_sim_fix_stoch_encoder(args, h, Noise, Noise_relax, Noise_feedback, para_tx_net_list,
                                                    para_rx_net_list):
    tx_meta_intermediate = meta_tx(if_relu=args.if_relu)
    rx_meta_intermediate = meta_rx(if_relu=args.if_relu)
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    # we only transmit once and do meta-learning (rx) and joint learning (tx)
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_meta_intermediate(m, para_tx_net_list, args.if_bias, args.device,
                                                                   relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # rx
    for ind_sim_iter_rx in range(args.num_meta_local_updates):
        if ind_sim_iter_rx == 0:
            out = rx_meta_intermediate(received_signal, para_rx_net_list, args.if_bias, args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label)
            local_grad_rx = torch.autograd.grad(loss_rx, para_rx_net_list, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0], zip(local_grad_rx, para_rx_net_list)))
            first_loss_curr_rx = float(loss_rx.clone().detach())
        else:
            out = rx_meta_intermediate(received_signal, intermediate_updated_para_list_rx, args.if_bias,
                                       args.device, args.if_RTN)
            loss_rx = receiver_loss(out, label)
            local_grad_rx = torch.autograd.grad(loss_rx, intermediate_updated_para_list_rx, create_graph=True)
            intermediate_updated_para_list_rx = list(
                map(lambda p: p[1] - args.lr_meta_inner * p[0],
                    zip(local_grad_rx, intermediate_updated_para_list_rx)))
    # now the meta-gradient
    out = rx_meta_intermediate(received_signal, intermediate_updated_para_list_rx, args.if_bias, args.device,
                               args.if_RTN)
    loss_rx_after_local_adaptation = receiver_loss(out, label)
    meta_grad_rx = torch.autograd.grad(loss_rx_after_local_adaptation, para_rx_net_list, create_graph=False)
    # joint gradient for the tx, using feedback from the adapted rx
    received_reward_from_rx = feedback_from_rx(out, label, Noise_feedback, args.device)
    loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean, received_reward_from_rx, args.relax_sigma)
    joint_grad_tx = torch.autograd.grad(loss_tx, para_tx_net_list, create_graph=False, retain_graph=True)
    return meta_grad_rx, joint_grad_tx, first_loss_curr_rx, float(loss_tx), \
        float(loss_rx_after_local_adaptation), float(loss_tx)
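# --- Self-contained toy sketch (not from the original code) of the functional MAML-style inner
# update used above: inner-loop gradients are taken with create_graph=True so that the
# meta-gradient can later be back-propagated through the update theta' = theta - lr * grad.
# All names below (w, x, y, lr_inner) are assumptions chosen only to illustrate the pattern.
def _maml_inner_update_demo(lr_inner=0.1):
    w = torch.randn(3, requires_grad=True)               # "initialization" parameters
    x, y = torch.randn(5, 3), torch.randn(5)             # toy support data
    loss = ((x @ w - y) ** 2).mean()                      # support loss at the initialization
    (grad_w,) = torch.autograd.grad(loss, [w], create_graph=True)
    w_adapted = w - lr_inner * grad_w                      # differentiable inner-loop update
    query_loss = ((x @ w_adapted - y) ** 2).mean()         # query loss after adaptation
    (meta_grad,) = torch.autograd.grad(query_loss, [w])    # meta-gradient w.r.t. the initialization
    return meta_grad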
def one_iter_sim_fix_stoch_encoder_from_rx_meta_tx_joint(args, h, Noise, Noise_relax, Noise_feedback,
                                                         tx_net_for_testtraining, rx_net_for_testtraining,
                                                         tx_testtraining_optimiser, rx_testtraining_optimiser,
                                                         epochs):
    relax_sigma = args.relax_sigma
    m, label = message_gen(args.bit_num, args.mb_size)
    m = m.type(torch.FloatTensor).to(args.device)
    label = label.type(torch.LongTensor).to(args.device)
    tx_net_for_testtraining.zero_grad()
    rx_net_for_testtraining.zero_grad()
    # tx
    tx_symb_mean, actual_transmitted_symbol = tx_net_for_testtraining(m, args.device, relax_sigma, Noise_relax)
    # channel
    received_signal = channel(h, actual_transmitted_symbol, Noise, args.device, args.if_AWGN)
    # rx
    out = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)
    loss_rx = receiver_loss(out, label)
    loss_rx.backward()
    if args.if_test_training_adam:
        if args.if_adam_after_sgd:
            if epochs < args.num_meta_local_updates:
                for f in rx_net_for_testtraining.parameters():
                    if f.grad is not None:
                        f.data.sub_(f.grad.data * args.lr_meta_inner)
            elif epochs == args.num_meta_local_updates:
                rx_testtraining_optimiser = torch.optim.Adam(rx_net_for_testtraining.parameters(),
                                                             args.lr_testtraining)
                rx_testtraining_optimiser.step()
            else:
                rx_testtraining_optimiser.step()
        else:
            rx_testtraining_optimiser.step()
    else:
        for f in rx_net_for_testtraining.parameters():
            if f.grad is not None:
                f.data.sub_(f.grad.detach() * args.lr_testtraining)
    if args.fix_joint_trained_tx_only_adapt_meta_trained_rx:
        tx_testtraining_optimiser = None
        loss_tx = 0
    else:
        out = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)
        received_reward_from_rx = feedback_from_rx(out, label, Noise_feedback, args.device)
        loss_tx = transmitter_loss(actual_transmitted_symbol, tx_symb_mean, received_reward_from_rx,
                                   args.relax_sigma)
        loss_tx.backward()
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                if epochs < args.num_meta_local_updates:
                    for f in tx_net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    tx_testtraining_optimiser = torch.optim.Adam(tx_net_for_testtraining.parameters(),
                                                                 args.lr_testtraining)
                    tx_testtraining_optimiser.step()
                else:
                    tx_testtraining_optimiser.step()
            else:
                tx_testtraining_optimiser.step()
        else:
            for f in tx_net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.detach() * args.lr_testtraining)
    return rx_testtraining_optimiser, tx_testtraining_optimiser, float(loss_rx), float(loss_tx)