def test_per_channel_per_snr_conven_approach(args, est_h, test_size, h, test_snr, actual_channel_num):
    """Return the block error rate of the conventional link on one channel at one SNR.

    The transmitter uses uncoded 4-QAM (`four_qam_uncoded_modulation`) and the
    receiver uses `demodulation` with the channel estimate `est_h`. The true
    channel `h` is used to produce the received signal.

    Args:
        args: experiment config; reads bit_num, channel_num, tap_num, device,
            if_AWGN, if_fix_random_seed and random_seed.
        est_h: channel estimate handed to the demodulator.
        test_size: number of payloads (messages) to transmit.
        h: true channel used by `channel(...)`.
        test_snr: Eb/N0 in dB.
        actual_channel_num: dimension of the real-valued noise vector.

    Returns:
        1 - accuracy, i.e. the fraction of payloads decoded incorrectly.
    """
    num_payloads = test_size
    ebn0_linear = pow(10, (test_snr / 10))
    rate = args.bit_num / args.channel_num
    noise_variance = 1 / (2 * rate * ebn0_linear)
    noise_dist = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num),
        noise_variance * torch.eye(actual_channel_num),
    )
    print('payload num for conv', num_payloads)
    _, labels = message_gen(args.bit_num, num_payloads)
    labels = labels.type(torch.LongTensor).to(args.device)

    # Reseed after the messages are drawn so that the channel noise is identical
    # across calls; this isolates the effect of the number of adaptations.
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 7777)

    # The modulator consumes the integer labels, not the message tensor.
    tx_symbols = four_qam_uncoded_modulation(args.bit_num, args.channel_num, labels)
    rx_signal = channel(h, tx_symbols, noise_dist, args.device, args.if_AWGN)
    decoded = demodulation(args.bit_num, args.channel_num, args.tap_num, est_h, rx_signal, args.device)

    num_correct = 0
    for idx in range(labels.shape[0]):
        assert labels.shape[0] == num_payloads
        # A payload counts as correct when the arg-max output matches its label.
        if torch.argmax(decoded[idx]) == labels[idx]:
            num_correct += 1

    accuracy = num_correct / labels.shape[0]
    print('accuracy', accuracy)
    return 1 - accuracy
def test_with_adapt_compact_during_online_meta_training(
        args, common_dir, curr_meta_training_epoch, test_snr_range,
        num_pilots_test, tx_net_for_testtraining, rx_net_for_testtraining,
        Noise, Noise_relax, actual_channel_num, PATH_before_adapt_tx,
        PATH_before_adapt_rx):
    """Adapt to every stored test channel and return the mean block error rate.

    This is used during online meta-training. For each test channel, the nets
    are adapted with `test_training` and then evaluated with
    `test_per_channel_per_snr` at every SNR in `test_snr_range`. The adapted
    state dicts are exchanged via files under
    `common_dir/saved_model/with_meta_training_epoch/<epoch>/{tx,rx}/after_adapt/`.

    Args:
        args: experiment config; reads path_for_test_channels,
            num_channels_test, if_fix_random_seed, random_seed and
            test_size_during_meta_update.
        common_dir: output root; must end with a path separator.
        curr_meta_training_epoch: epoch index used in the output paths.
        test_snr_range: iterable of Eb/N0 values (dB).
        num_pilots_test: number of pilots passed to `test_training`.
        PATH_before_adapt_tx / PATH_before_adapt_rx: paths forwarded to
            `test_training`. Presumably these hold the meta-learned
            initialization; TODO confirm.

    Returns:
        Scalar tensor: mean BLER over all (loaded channel, SNR) pairs.

    Raises:
        NotImplementedError: when args.path_for_test_channels is None.
        ValueError: when no test channels are available.
    """
    if args.path_for_test_channels is None:
        raise NotImplementedError(
            'test channels must be generated beforehand and loaded via '
            'args.path_for_test_channels')
    print('load previously generated channels')
    h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(h_list_test_path, 'rb') as f_test_channels:
        h_list_test = pickle.load(f_test_channels)
    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]
    if len(h_list_test) == 0:
        raise ValueError('no test channels found in ' + h_list_test_path)

    # Reseed so that every adaptation / meta-training epoch is compared fairly.
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 2)
    print('curr pilots used for test during online (meta-)learning: ', num_pilots_test)

    epoch_dir = (common_dir + 'saved_model/' + 'with_meta_training_epoch/'
                 + str(curr_meta_training_epoch) + '/')
    adapted_dir_tx = (epoch_dir + 'tx/' + 'after_adapt/'
                      + str(num_pilots_test) + '_num_pilots_test/')
    adapted_dir_rx = (epoch_dir + 'rx/' + 'after_adapt/'
                      + str(num_pilots_test) + '_num_pilots_test/')
    # exist_ok: test_with_adapt may later write to the same epoch/pilot directory.
    os.makedirs(adapted_dir_tx, exist_ok=True)
    os.makedirs(adapted_dir_rx, exist_ok=True)

    # Size by the channels actually loaded. Sizing by the config would leave
    # zero rows that bias the mean downward when fewer channels are stored.
    block_error_rate = torch.zeros(len(h_list_test), len(test_snr_range))
    for ind_h, h in enumerate(h_list_test):
        PATH_after_adapt_tx = adapted_dir_tx + str(ind_h) + 'th_adapted_net'
        PATH_after_adapt_rx = adapted_dir_rx + str(ind_h) + 'th_adapted_net'
        test_training(args, h, tx_net_for_testtraining, rx_net_for_testtraining,
                      Noise, Noise_relax, PATH_before_adapt_tx, PATH_before_adapt_rx,
                      PATH_after_adapt_tx, PATH_after_adapt_rx, num_pilots_test)
        # test
        for ind_snr, test_snr in enumerate(test_snr_range):
            block_error_rate[ind_h, ind_snr] = test_per_channel_per_snr(
                args, args.test_size_during_meta_update, h,
                tx_net_for_testtraining, rx_net_for_testtraining, test_snr,
                actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx)
    return torch.mean(block_error_rate)
def test_per_channel_per_snr(args, test_size, h, tx_net_for_testtraining, rx_net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt_tx, PATH_after_adapt_rx):
    """Return the block error rate of the adapted tx/rx nets on one channel at one SNR.

    The adapted state dicts are loaded from PATH_after_adapt_tx / _rx.
    `test_size` random messages are then sent through `channel(h, ...)` with
    Gaussian noise set from Eb/N0 = `test_snr` dB, and arg-max decisions are
    compared against the labels.

    Args:
        args: experiment config; reads bit_num, channel_num, device,
            if_fix_random_seed, random_seed, fix_bpsk_tx_train_nn_rx_during_runtime,
            if_AWGN and if_RTN.
        test_size: number of test messages; must be positive.

    Returns:
        1 - accuracy, i.e. the fraction of misdecoded messages.

    Raises:
        ValueError: when the generated test batch is empty.
    """
    tx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_tx))
    rx_net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt_rx))
    batch_size = test_size
    Eb_over_N_test = pow(10, (test_snr / 10))
    R = args.bit_num / args.channel_num
    noise_var_test = 1 / (2 * R * Eb_over_N_test)
    Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(
        torch.zeros(actual_channel_num),
        noise_var_test * torch.eye(actual_channel_num))
    m_test, label_test = message_gen(args.bit_num, batch_size)
    m_test = m_test.type(torch.FloatTensor).to(args.device)
    label_test = label_test.type(torch.LongTensor).to(args.device)
    num_samples = label_test.shape[0]
    if num_samples == 0:
        raise ValueError('test_size must be positive to estimate a block error rate')
    assert num_samples == batch_size

    # Clear gradients left over from adaptation. detach_() (in place) is needed:
    # the previous detach() discarded its result and was a no-op.
    for f in tx_net_for_testtraining.parameters():
        if f.grad is not None:
            f.grad.detach_()
            f.grad.zero_()
    rx_net_for_testtraining.zero_grad()

    # Reseed so the test noise is identical across calls; this isolates the
    # effect of the number of adaptations.
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 7777)

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        if args.fix_bpsk_tx_train_nn_rx_during_runtime:
            # The fixed modulator consumes the integer labels, not the messages.
            actual_transmitted_symbol = four_qam_uncoded_modulation(args.bit_num, args.channel_num, label_test)
        else:
            # The third argument 0 / fourth None presumably disable the training-time relaxation; TODO confirm.
            _, actual_transmitted_symbol = tx_net_for_testtraining(m_test, args.device, 0, None)
        received_signal = channel(h, actual_transmitted_symbol, Noise_test, args.device, args.if_AWGN)
        out_test = rx_net_for_testtraining(received_signal, args.if_RTN, args.device)

    # Per-row arg-max (flattened per sample, as the former element-wise loop
    # did), vectorized to avoid one device sync per sample.
    predictions = torch.argmax(out_test.reshape(num_samples, -1), dim=1)
    success_test = (predictions == label_test.reshape(num_samples)).sum().item()
    accuracy = success_test / num_samples
    return 1 - accuracy
def test_conven_commun_during_online_meta_training(args, test_snr_range, num_pilots_test, Noise, actual_channel_num):
    """Evaluate the conventional (pilot-based estimate + 4-QAM demodulation) link on the stored test channels.

    For each test channel, `test_training_conven_commun` yields (est_h, error_h).
    The BLER is then measured at each SNR with
    `test_per_channel_per_snr_conven_approach`, using args.conv_payload_num payloads.

    Returns:
        (mean BLER over all loaded channels and SNRs, mean of error_h over
        channels). error_h is presumably the channel-estimation error; TODO confirm.

    Raises:
        NotImplementedError: when args.path_for_test_channels is None.
        ValueError: when no test channels are available.
    """
    if args.path_for_test_channels is None:
        print(
            'we need to first make the test channels a priori and load it via path (args.path_for_test_channels)'
        )
        raise NotImplementedError
    print('load previously generated channels')
    h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(h_list_test_path, 'rb') as f_test_channels:
        h_list_test = pickle.load(f_test_channels)
    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]
    print('used test channels', h_list_test)
    if len(h_list_test) == 0:
        raise ValueError('no test channels found in ' + h_list_test_path)

    # Reseed so that every adaptation / meta-training epoch is compared fairly.
    if args.if_fix_random_seed:
        reset_randomness(args.random_seed + 2)
    print('curr num pilots: ', num_pilots_test)

    # Size by the channels actually loaded so that no zero rows bias the mean.
    block_error_rate = torch.zeros(len(h_list_test), len(test_snr_range))
    ch_est_error_avg = 0
    for ind_h, h in enumerate(h_list_test):
        est_h, error_h = test_training_conven_commun(args, h, Noise, num_pilots_test)
        for ind_snr, test_snr in enumerate(test_snr_range):
            block_error_rate[ind_h, ind_snr] = test_per_channel_per_snr_conven_approach(
                args, est_h, args.conv_payload_num, h, test_snr, actual_channel_num)
        ch_est_error_avg += error_h
    ch_est_error_avg = ch_est_error_avg / len(h_list_test)
    return torch.mean(block_error_rate), ch_est_error_avg
def __init__(self, M, n, n_inv_filter, num_neurons_decoder, if_bias, if_relu, if_RTN, if_fix_random_seed, random_seed):
    """Create the receiver's layers.

    Args:
        M: number of output classes of the decoder (size of dec_fc2's output).
        n: input dimension of the received signal.
        n_inv_filter: the RTN head outputs 2 * n_inv_filter values.
        num_neurons_decoder: hidden width of the decoder.
        if_bias: whether the Linear layers have biases.
        if_relu: ReLU activation if true, otherwise Tanh.
        if_RTN: whether to create the rtn_1..rtn_3 layers.
        if_fix_random_seed / random_seed: reseed before creating layer groups.
    """
    super(receiver, self).__init__()
    rtn_out_dim = 2 * n_inv_filter
    if if_RTN:
        # RTN layers use their own seed (random_seed + 1). Assuming
        # reset_randomness is a global reseed, the decoder initialization
        # below then does not depend on whether these layers exist.
        if if_fix_random_seed:
            reset_randomness(random_seed + 1)
        self.rtn_1 = nn.Linear(n, n, bias=if_bias)
        self.rtn_2 = nn.Linear(n, n, bias=if_bias)
        self.rtn_3 = nn.Linear(n, rtn_out_dim, bias=if_bias)
    if if_fix_random_seed:
        reset_randomness(random_seed)
    self.dec_fc1 = nn.Linear(n, num_neurons_decoder, bias=if_bias)
    self.dec_fc2 = nn.Linear(num_neurons_decoder, M, bias=if_bias)
    self.activ = nn.ReLU() if if_relu else nn.Tanh()
    self.tanh = nn.Tanh()
def multi_task_learning(args, tx_net,rx_net, h_list_meta, writer_meta_training, Noise, Noise_relax, Noise_feedback, PATH_before_adapt_rx_intermediate, PATH_before_adapt_tx_intermediate):
    """Meta-train tx_net and rx_net with Adam over args.num_epochs_meta_train epochs.

    Each epoch does the following:
      - Every 100 epochs, checkpoint both state dicts to
        PATH_before_adapt_{rx,tx}_intermediate + str(epoch).
      - For each of args.tasks_per_metaupdate tasks, pick a channel and run
        `joint_training` or `maml_for_rx_joint_for_tx`. The channel is either
        freshly generated with an AR Rayleigh model, or sampled from
        h_list_meta[:args.num_channels_meta].
      - Log the averaged losses to writer_meta_training.
      - Set each parameter's .grad to f.total_grad / tasks_per_metaupdate and
        step both optimizers.

    NOTE(review): f.total_grad is read here but never reset here. Presumably
    the per-task training functions reset/accumulate it; TODO confirm.
    Returns None.
    """
    meta_optimiser_tx = torch.optim.Adam(tx_net.parameters(), args.lr_meta_update)
    meta_optimiser_rx = torch.optim.Adam(rx_net.parameters(), args.lr_meta_update)
    h_list_train = h_list_meta[:args.num_channels_meta]
    if args.if_fix_random_seed:
        # Seed counter, advanced once per task so each task gets a distinct seed.
        random_seed_init = args.random_seed + 99999
    else:
        pass
    # AR channel state; a single state is shared by consecutive tasks and epochs.
    previous_channel = None
    for epochs in range(args.num_epochs_meta_train):
        if epochs % 100 == 0:
            curr_path_rx = PATH_before_adapt_rx_intermediate + str(epochs)
            curr_path_tx = PATH_before_adapt_tx_intermediate + str(epochs)
            torch.save(rx_net.state_dict(), curr_path_rx)
            torch.save(tx_net.state_dict(), curr_path_tx)
            print('stochactic meta-learning epoch', epochs)
        first_loss_rx = 0
        second_loss_rx = 0
        first_loss_tx = 0
        second_loss_tx = 0
        iter_in_sampled_device = 0  # for averaging meta-devices
        for ind_meta_dev in range(args.tasks_per_metaupdate):
            if args.if_always_generate_new_meta_training_channels:
                if args.if_fix_random_seed:
                    reset_randomness(random_seed_init)  # to make noise same as much as possible for joint and meta
                    random_seed_init += 1
                else:
                    pass
                if args.if_Rayleigh_channel_model_AR:
                    # Every keep_AR_period epochs, draw a new variance multiplier
                    # and ask the generator to reset the AR process.
                    if epochs % args.keep_AR_period == 0:
                        h_var_dist = torch.distributions.uniform.Uniform(torch.FloatTensor([args.mul_h_var_min]), torch.FloatTensor([args.mul_h_var_max]))
                        mul_h_var = h_var_dist.sample()
                        if_reset_AR = True
                    else:
                        if_reset_AR = False
                    current_channel = channel_set_gen_AR(args.tap_num, previous_channel, args.rho, mul_h_var, if_reset_AR)  # num_channels = 1 since we are generating per channel
                    previous_channel = current_channel
                else:
                    raise NotImplementedError
            else:
                if args.if_fix_random_seed:
                    reset_randomness(random_seed_init)  # to make noise same as much as possible for joint and meta
                    random_seed_init += 1
                else:
                    pass
                # Pick one of the stored meta-training channels for this task.
                # NOTE(review): a fresh permutation is drawn for every task, so
                # tasks within one meta-batch are independent uniform draws and
                # CAN repeat, despite the "no rep." comment below. This also
                # raises IndexError if tasks_per_metaupdate > len(h_list_train).
                channel_list_total = torch.randperm(len(h_list_train))  # sampling with replacement
                current_channel_ind = channel_list_total[ind_meta_dev]  # randomly sample meta-batches (no rep. inside meta-batch)
                current_channel = h_list_train[current_channel_ind]
            # Per-task training; presumably this is where meta-gradients accumulate into f.total_grad.
            if args.if_joint_training:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = joint_training(args, iter_in_sampled_device, tx_net, rx_net, current_channel, Noise, Noise_relax, Noise_feedback)
            elif args.if_joint_training_tx_meta_training_rx:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = maml_for_rx_joint_for_tx(
                    args, iter_in_sampled_device, tx_net, rx_net, current_channel, Noise, Noise_relax, Noise_feedback)
            else:  # maml
                raise NotImplementedError
            first_loss_rx = first_loss_rx + first_loss_curr_rx
            second_loss_rx = second_loss_rx + second_loss_curr_rx
            first_loss_tx = first_loss_tx + first_loss_curr_tx
            second_loss_tx = second_loss_tx + second_loss_curr_tx
        # Average the per-task losses over the meta-batch for logging.
        first_loss_tx = first_loss_tx / args.tasks_per_metaupdate
        second_loss_tx = second_loss_tx / args.tasks_per_metaupdate
        first_loss_rx = first_loss_rx / args.tasks_per_metaupdate
        second_loss_rx = second_loss_rx / args.tasks_per_metaupdate
        writer_meta_training.add_scalar('first rx loss', first_loss_rx, epochs)
        writer_meta_training.add_scalar('first tx loss', first_loss_tx, epochs)
        writer_meta_training.add_scalar('second rx loss', second_loss_rx, epochs)
        writer_meta_training.add_scalar('second tx loss', second_loss_tx, epochs)
        # Meta-update: use the averaged accumulated gradient as .grad, then step Adam.
        meta_optimiser_rx.zero_grad()
        meta_optimiser_tx.zero_grad()
        for f in rx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        for f in tx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        meta_optimiser_rx.step()  # Adam
        meta_optimiser_tx.step()  # Adam
def test_with_adapt(args, common_dir, common_dir_over_multi_rand_seeds, test_snr_range, test_num_pilots_available, meta_training_epoch_for_test, tx_net_for_testtraining, rx_net_for_testtraining, Noise, Noise_relax, actual_channel_num, save_test_result_dict_total, test_result_all_PATH_for_all_meta_training_epochs, PATH_before_adapt_tx, PATH_before_adapt_rx):
    """Adapt and test on every test channel for every (meta-training epoch, pilot count) pair.

    Test channels are generated and pickled when args.path_for_test_channels is
    None; otherwise they are loaded. Results are written as .mat files:
      - per pilot count: block_error_rate of shape [channels, snrs];
      - per meta-training epoch: block_error_rate_total of shape
        [channels, snrs, pilots], saved under the multi-seed directory when
        common_dir_over_multi_rand_seeds is given, else under common_dir;
      - optionally (args.path_for_meta_trained_net_total_per_epoch) the full
        [channels, snrs, pilots, epochs] tensor, stored in
        save_test_result_dict_total and written to
        test_result_all_PATH_for_all_meta_training_epochs.
    TensorBoard scalars go to common_dir + 'TB/test'.

    Fixes compared to the previous version:
      - the epoch index now advances, so each meta-training epoch fills its own
        slice instead of overwriting slice 0;
      - makedirs tolerates directories that already exist (the per-epoch result
        directory was always created first by the pilot loop);
      - BLER tensors are sized by the channels actually loaded.
    """
    # generate or load test channels
    if args.path_for_test_channels is None:
        if args.if_fix_random_seed:
            reset_randomness(args.random_seed + 11)
        print('generate test channels')
        h_list_test = channel_set_gen(args.num_channels_test, args.tap_num, args.if_toy)
        test_channel_dir = common_dir + 'test_channels/'
        os.makedirs(test_channel_dir, exist_ok=True)
        h_list_test_path = test_channel_dir + 'test_channels.pckl'
        with open(h_list_test_path, 'wb') as f_test_channels:
            pickle.dump(h_list_test, f_test_channels)
    else:
        print('load previously generated channels')
        h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(h_list_test_path, 'rb') as f_test_channels:
            h_list_test = pickle.load(f_test_channels)
    if len(h_list_test) > args.num_channels_test:
        h_list_test = h_list_test[:args.num_channels_test]
    print('used test channels', h_list_test)
    num_test_channels = len(h_list_test)

    dir_test = common_dir + 'TB/' + 'test'
    writer_test = SummaryWriter(dir_test)
    total_total_block_error_rate = torch.zeros(
        num_test_channels, len(test_snr_range), len(test_num_pilots_available),
        len(meta_training_epoch_for_test))
    for ind_meta_training_epoch, meta_training_epochs in enumerate(meta_training_epoch_for_test):
        if common_dir_over_multi_rand_seeds is not None:
            # NOTE(review): there is no '/' between the iteration count and 'rho/';
            # kept as-is because downstream readers may depend on this layout.
            seed_result_dir = (common_dir_over_multi_rand_seeds + 'test_result_after_meta_training/' + 'iter/'
                               + str(args.fix_tx_multi_adapt_rx_iter_num) + 'rho/' + str(args.rho)
                               + '/rand_seeds/' + str(args.random_seed) + '/test_result/'
                               + 'with_meta_training_epoch/' + str(meta_training_epochs) + '/')
            os.makedirs(seed_result_dir, exist_ok=True)
            test_result_all_PATH_per_rand_seeds = seed_result_dir + 'test_result.mat'
        save_test_result_dict = {}
        # start with given initialization
        print('start adaptation with test set')
        total_block_error_rate = torch.zeros(num_test_channels, len(test_snr_range), len(test_num_pilots_available))
        epoch_result_dir = common_dir + 'test_result/' + 'with_meta_training_epoch/' + str(meta_training_epochs) + '/'
        epoch_model_dir = common_dir + 'saved_model/' + 'with_meta_training_epoch/' + str(meta_training_epochs) + '/'
        for ind_num_pilots_test, num_pilots_test in enumerate(test_num_pilots_available):
            # Reseed so every adaptation / meta-training epoch is compared fairly.
            if args.if_fix_random_seed:
                reset_randomness(args.random_seed + 2)
            print('curr pilots num: ', num_pilots_test)
            pilots_suffix = str(num_pilots_test) + '_num_pilots_test/'
            adapted_dir_tx = epoch_model_dir + 'tx/' + 'after_adapt/' + pilots_suffix
            adapted_dir_rx = epoch_model_dir + 'rx/' + 'after_adapt/' + pilots_suffix
            pilots_result_dir = epoch_result_dir + pilots_suffix
            os.makedirs(adapted_dir_tx, exist_ok=True)
            os.makedirs(adapted_dir_rx, exist_ok=True)
            os.makedirs(pilots_result_dir, exist_ok=True)
            test_result_per_num_pilots_test = pilots_result_dir + 'test_result.mat'
            save_test_result_dict_per_num_pilots_test = {}
            block_error_rate = torch.zeros(num_test_channels, len(test_snr_range))
            for ind_h, h in enumerate(h_list_test):
                print('current channel ind', ind_h)
                PATH_after_adapt_tx = adapted_dir_tx + str(ind_h) + 'th_adapted_net'
                PATH_after_adapt_rx = adapted_dir_rx + str(ind_h) + 'th_adapted_net'
                test_training(args, h, tx_net_for_testtraining, rx_net_for_testtraining,
                              Noise, Noise_relax, PATH_before_adapt_tx, PATH_before_adapt_rx,
                              PATH_after_adapt_tx, PATH_after_adapt_rx, num_pilots_test)
                # test
                for ind_snr, test_snr in enumerate(test_snr_range):
                    block_error_rate_per_snr_per_channel = test_per_channel_per_snr(
                        args, args.test_size, h, tx_net_for_testtraining,
                        rx_net_for_testtraining, test_snr, actual_channel_num,
                        PATH_after_adapt_tx, PATH_after_adapt_rx)
                    block_error_rate[ind_h, ind_snr] = block_error_rate_per_snr_per_channel
                    total_block_error_rate[ind_h, ind_snr, ind_num_pilots_test] = block_error_rate_per_snr_per_channel
                    total_total_block_error_rate[
                        ind_h, ind_snr, ind_num_pilots_test, ind_meta_training_epoch] = block_error_rate_per_snr_per_channel
            save_test_result_dict_per_num_pilots_test['block_error_rate'] = block_error_rate.detach().numpy()
            sio.savemat(test_result_per_num_pilots_test, save_test_result_dict_per_num_pilots_test)
            mean_bler = torch.mean(block_error_rate)
            writer_test.add_scalar('average (h) block error rate per num pilots', mean_bler, num_pilots_test)
            print('curr pilots num', num_pilots_test, 'bler', mean_bler)
        save_test_result_dict['block_error_rate_total'] = total_block_error_rate.detach().numpy()
        if common_dir_over_multi_rand_seeds is not None:
            sio.savemat(test_result_all_PATH_per_rand_seeds, save_test_result_dict)
        else:
            # exist_ok: the pilot loop above already created this directory as a parent.
            os.makedirs(epoch_result_dir, exist_ok=True)
            test_result_all_PATH = epoch_result_dir + 'test_result.mat'
            sio.savemat(test_result_all_PATH, save_test_result_dict)
        if args.path_for_meta_trained_net_total_per_epoch:
            save_test_result_dict_total[
                'block_error_rate_total_total_meta_training_epoch'] = total_total_block_error_rate.detach().numpy()
            sio.savemat(test_result_all_PATH_for_all_meta_training_epochs, save_test_result_dict_total)
    writer_test.close()
def multi_task_learning(args, common_dir, tx_net, rx_net, writer_meta_training, Noise, Noise_relax, actual_channel_num, PATH_before_adapt_rx_intermediate, PATH_before_adapt_tx_intermediate, rx_net_for_testtraining, tx_net_for_testtraining):
    """Online meta-training of tx_net / rx_net on per-device AR Rayleigh channels.

    Every args.meta_tr_epoch_num_for_test epochs, both nets are checkpointed.
    When args.see_test_bler_during_meta_update is set, the checkpoint is also
    evaluated with test_with_adapt_compact_during_online_meta_training, and
    the list of test BLERs collected so far is saved as a .mat file.

    Each epoch runs one training call (joint_training or
    maml_for_rx_joint_for_tx) per meta-device on that device's next AR channel,
    logs the averaged losses, and applies Adam with grad = total_grad /
    tasks_per_metaupdate. With args.fix_bpsk_tx, tx_net is left untouched.

    When args.if_get_conven_commun_performance is set, the conventional
    baseline is reported once and the process exits via SystemExit(0).

    Returns:
        (test_bler_per_meta_training_epochs, channel_per_meta_training_epochs).
        NOTE(review): the second list is never appended to here, so it is
        always empty; confirm whether callers rely on it.
    """
    meta_optimiser_tx = torch.optim.Adam(tx_net.parameters(), args.lr_meta_update)
    meta_optimiser_rx = torch.optim.Adam(rx_net.parameters(), args.lr_meta_update)
    test_result_PATH_per_meta_training_test_bler_dir = common_dir + 'test_result_during_meta_training/'
    save_test_result_dict_total_per_meta_training_test_bler = {}
    if args.if_fix_random_seed:
        # Seed counter, advanced once per task so each task gets a distinct seed.
        random_seed_init = args.random_seed + 99999
    # for online: one AR channel state per meta-device, carried across epochs
    previous_channel = [None] * args.tasks_per_metaupdate
    test_bler_per_meta_training_epochs = []
    channel_per_meta_training_epochs = []
    second_loss_rx_best_for_stopping_criteria = float('inf')
    for epochs in range(args.num_epochs_meta_train):
        if epochs % args.meta_tr_epoch_num_for_test == 0:
            curr_path_rx = PATH_before_adapt_rx_intermediate + str(epochs)
            curr_path_tx = PATH_before_adapt_tx_intermediate + str(epochs)
            torch.save(rx_net.state_dict(), curr_path_rx)
            torch.save(tx_net.state_dict(), curr_path_tx)
            print('meta-learning epoch', epochs)
            if args.see_test_bler_during_meta_update:
                # see test bler here, adapting from the checkpoint just saved
                test_snr_range = [args.Eb_over_N_db_test]
                num_pilots_test = args.pilots_num_meta_test
                PATH_before_adapt_tx = curr_path_tx
                PATH_before_adapt_rx = curr_path_rx
                if args.if_get_conven_commun_performance:
                    test_bler_mean_curr_epoch_conven_approach, ch_est_error_avg = test_conven_commun_during_online_meta_training(
                        args, test_snr_range, num_pilots_test, Noise, actual_channel_num)
                    writer_meta_training.add_scalar(
                        'conv. approach test bler during meta-training',
                        test_bler_mean_curr_epoch_conven_approach, epochs)
                    print('conven bler: ', test_bler_mean_curr_epoch_conven_approach)
                    print('conven ch. est: ', ch_est_error_avg)
                    print(
                        'as this is for BPSK with maximum likelihood decoder, we only need this once so we stop the code here'
                    )
                    # Stop explicitly. The previous code evaluated an undefined
                    # name and died with a NameError.
                    raise SystemExit(0)
                test_bler_mean_curr_epoch = test_with_adapt_compact_during_online_meta_training(
                    args, common_dir, epochs, test_snr_range, num_pilots_test,
                    tx_net_for_testtraining, rx_net_for_testtraining, Noise,
                    Noise_relax, actual_channel_num, PATH_before_adapt_tx,
                    PATH_before_adapt_rx)
                writer_meta_training.add_scalar(
                    'test bler during meta-training', test_bler_mean_curr_epoch, epochs)
                test_bler_per_meta_training_epochs.append(test_bler_mean_curr_epoch)

        first_loss_rx = 0
        second_loss_rx = 0
        first_loss_tx = 0
        second_loss_tx = 0
        iter_in_sampled_device = 0  # for averaging meta-devices
        for ind_meta_dev in range(args.tasks_per_metaupdate):
            if args.if_fix_random_seed:
                reset_randomness(random_seed_init)  # to make noise same as much as possible for joint and meta
                random_seed_init += 1
            if args.if_Rayleigh_channel_model_AR:
                # num_channels = 1 since we are generating per channel
                current_channel = channel_set_gen_AR(args.tap_num, previous_channel[ind_meta_dev], args.rho)
                previous_channel[ind_meta_dev] = current_channel
            else:
                raise NotImplementedError
            if args.if_joint_training:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = joint_training(
                    args, iter_in_sampled_device, tx_net, rx_net, current_channel, Noise, Noise_relax)
            elif args.if_joint_training_tx_meta_training_rx:
                iter_in_sampled_device, first_loss_curr_rx, first_loss_curr_tx, second_loss_curr_rx, second_loss_curr_tx = maml_for_rx_joint_for_tx(
                    args, iter_in_sampled_device, tx_net, rx_net, current_channel, Noise, Noise_relax)
            else:  # maml
                raise NotImplementedError
            first_loss_rx = first_loss_rx + first_loss_curr_rx
            second_loss_rx = second_loss_rx + second_loss_curr_rx
            first_loss_tx = first_loss_tx + first_loss_curr_tx
            second_loss_tx = second_loss_tx + second_loss_curr_tx
        first_loss_tx = first_loss_tx / args.tasks_per_metaupdate
        # tx is trained jointly, so it has only one loss (first_loss_tx == second_loss_tx)
        second_loss_tx = second_loss_tx / args.tasks_per_metaupdate
        first_loss_rx = first_loss_rx / args.tasks_per_metaupdate
        second_loss_rx = second_loss_rx / args.tasks_per_metaupdate
        if not args.if_TB_loss_ignore:
            writer_meta_training.add_scalar('RX loss before local adaptation', first_loss_rx, epochs)
            writer_meta_training.add_scalar('RX loss after local adaptation', second_loss_rx, epochs)
            writer_meta_training.add_scalar('TX loss', first_loss_tx, epochs)
        if args.if_use_stopping_criteria_during_meta_training:
            if second_loss_rx < second_loss_rx_best_for_stopping_criteria:
                # Track the best loss so far. Previously it was never updated,
                # so the "best" checkpoint was overwritten every epoch.
                second_loss_rx_best_for_stopping_criteria = float(second_loss_rx)
                curr_path_rx_best_training_loss = PATH_before_adapt_rx_intermediate + 'best_model_based_on_meta_training_loss'
                curr_path_tx_best_training_loss = PATH_before_adapt_tx_intermediate + 'best_model_based_on_meta_training_loss'
                torch.save(rx_net.state_dict(), curr_path_rx_best_training_loss)
                torch.save(tx_net.state_dict(), curr_path_tx_best_training_loss)
        meta_optimiser_rx.zero_grad()
        meta_optimiser_tx.zero_grad()
        # rx meta-update from the accumulated per-task gradients
        for f in rx_net.parameters():
            f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
        meta_optimiser_rx.step()  # Adam
        # tx meta-update
        if not args.fix_bpsk_tx:  # nothing to meta-learn for a fixed BPSK encoder
            for f in tx_net.parameters():
                f.grad = f.total_grad.clone() / args.tasks_per_metaupdate
            meta_optimiser_tx.step()  # Adam
        if epochs % args.meta_tr_epoch_num_for_test == 0:
            epoch_result_dir = test_result_PATH_per_meta_training_test_bler_dir + 'epochs/' + str(epochs) + '/'
            os.makedirs(epoch_result_dir, exist_ok=True)
            test_result_PATH_per_meta_training_test_bler = epoch_result_dir + 'test_result_per_meta_training_epochs.mat'
            save_test_result_dict_total_per_meta_training_test_bler[
                'test_bler_during_meta_training'] = test_bler_per_meta_training_epochs
            sio.savemat(test_result_PATH_per_meta_training_test_bler,
                        save_test_result_dict_total_per_meta_training_test_bler)
    return test_bler_per_meta_training_epochs, channel_per_meta_training_epochs
# Pilot counts evaluated at meta-test time: a sweep of powers of two, or only 8.
test_num_pilots_available = [1, 2, 4, 8, 16, 32, 64, 128] if args.if_exp_over_multi_pilots_test else [8]
print('test available pilots: ', test_num_pilots_available)
test_result_PATH_per_meta_training_test_bler = (
    common_dir + 'test_result_during_meta_training/'
    + 'test_result_per_meta_training_epochs.mat'
)
save_test_result_dict_total_per_meta_training_test_bler = {}
# complex symbol: each complex channel use occupies two real dimensions
actual_channel_num = args.channel_num * 2

# Both transmitters are built right after the same reseed, so (assuming
# reset_randomness is a global reseed) they start from identical weights.
if args.if_fix_random_seed:
    reset_randomness(args.random_seed + 999)
tx_net = tx_dnn(
    M=pow(2, args.bit_num),
    num_neurons_encoder=args.num_neurons_encoder,
    n=actual_channel_num,
    if_bias=args.if_bias,
    if_relu=args.if_relu,
)
if torch.cuda.is_available():
    tx_net = tx_net.to(args.device)

if args.if_fix_random_seed:
    reset_randomness(args.random_seed + 999)
tx_net_for_testtraining = tx_dnn(
    M=pow(2, args.bit_num),
    num_neurons_encoder=args.num_neurons_encoder,
    n=actual_channel_num,
    if_bias=args.if_bias,
    if_relu=args.if_relu,
)