def train_network(self, BP_iter_nums, SNR_set):
    """Train the quantized BP network's alpha/beta parameters per BP iteration.

    For every SNR in ``SNR_set`` and every iteration index in
    ``range(BP_iter_nums[0])``:

    * all graph variables (including the Adam optimizer's slots) are
      re-initialised;
    * the test loss before training becomes the initial best loss;
    * the network is trained with Adam on mini-batches produced by
      ``bp_decoder.quantized_decode_before_nn`` for that iteration;
    * every 100 epochs (and at the last epoch) the test loss is evaluated.
      An improvement snapshots the network via ``save_network_temporarily``.
      Eight consecutive evaluations without improvement stop training early.
    * The best alpha/beta are written into
      ``self.train_config.alpha[iteration]`` / ``beta[iteration]``. The full
      alpha/beta table is saved with ``np.savetxt``.

    Args:
        BP_iter_nums: sequence whose first element is the number of BP
            iterations to train.
        SNR_set: iterable of SNR values. Each selects data files named
            ``<prefix>_<SNR with one decimal>.dat``.
    """
    start = datetime.datetime.now()
    cfg = self.train_config

    bp_decoder = BP_QMS_Decoder.BP_NetDecoder(
        self.code.H_matrix, cfg.training_minibatch_size)
    x_in, xe_0, y_out = self.build_network(bp_decoder)
    y_label = tf.placeholder(
        tf.float32, [cfg.training_minibatch_size, cfg.label_length])
    training_loss = self.cal_training_loss(y_out, y_label)
    test_loss = training_loss
    # Loss evaluated directly on xe_0; handed to test_network_online as the
    # reference ("original") loss.
    orig_loss_for_test = self.cal_training_loss(xe_0, y_label)
    # SGD_Adam
    train_step = tf.train.AdamOptimizer().minimize(training_loss)
    # Created after minimize() so that it also covers the optimizer's slots.
    init = tf.global_variables_initializer()

    def snr_file(prefix, snr):
        # Per-SNR data file name, e.g. "<prefix>_3.0.dat".
        return '%s_%.1f.dat' % (prefix, snr)

    # The context manager guarantees the session is closed even on error.
    with tf.Session() as sess:
        for SNR in SNR_set:
            dataio_train = DataIO.TrainingDataIO(
                snr_file(cfg.training_feature_file, SNR),
                snr_file(cfg.training_label_file, SNR),
                cfg.training_sample_num, cfg.feature_length,
                cfg.label_length)
            dataio_test = DataIO.TestDataIO(
                snr_file(cfg.test_feature_file, SNR),
                snr_file(cfg.test_label_file, SNR),
                cfg.test_sample_num, cfg.feature_length, cfg.label_length)

            for iteration in range(0, BP_iter_nums[0]):
                start1 = datetime.datetime.now()
                sess.run(init)
                # The loss before training is the initial best (min) loss.
                _, min_loss = self.test_network_online(
                    dataio_test, bp_decoder, iteration, x_in, xe_0, y_label,
                    orig_loss_for_test, test_loss, True, sess)
                self.save_network_temporarily(sess)

                # Train
                count = 0  # consecutive evaluations without improvement
                epoch = 0
                print(
                    'Iteration\tBest alpha\tBest beta\tCurrent loss\tCurrent alpha\tCurrent beta'
                )
                while epoch < cfg.epoch_num:
                    epoch += 1
                    batch_xs, batch_ys = dataio_train.load_next_mini_batch(
                        cfg.training_minibatch_size)
                    llr_into_nn_net, xe0_into_nn_net = \
                        bp_decoder.quantized_decode_before_nn(
                            batch_xs, iteration, cfg.alpha, cfg.beta)
                    sess.run(
                        [train_step],
                        feed_dict={
                            x_in: llr_into_nn_net,
                            xe_0: xe0_into_nn_net,
                            y_label: batch_ys
                        })

                    # Evaluate only every 100 epochs and at the final epoch.
                    if epoch % 100 != 0 and epoch != cfg.epoch_num:
                        continue

                    _, ave_loss_after_train = self.test_network_online(
                        dataio_test, bp_decoder, iteration, x_in, xe_0,
                        y_label, orig_loss_for_test, test_loss, False, sess)
                    # Read all four values before any snapshot is taken, so
                    # the printed "best" values are those prior to saving.
                    best_a, best_b, cur_a, cur_b = sess.run(
                        [self.best_alpha, self.best_beta,
                         self.alpha, self.beta])
                    print('%d\t\t%f\t%f\t%f\t%f\t%f' %
                          (epoch, best_a, best_b, ave_loss_after_train,
                           cur_a, cur_b))

                    if ave_loss_after_train < min_loss:
                        min_loss = ave_loss_after_train
                        self.save_network_temporarily(sess)
                        count = 0
                    else:
                        count += 1
                        if count >= 8:  # no patience
                            break

                best_alpha, best_beta = sess.run(
                    [self.best_alpha, self.best_beta])
                cfg.alpha[iteration] = best_alpha
                cfg.beta[iteration] = best_beta
                para_file = '%sPARA(%d_%d)_SNR%.1f_Iter%d.txt' % (
                    cfg.para_folder, cfg.feature_length, cfg.label_length,
                    SNR, iteration + 1)
                np.savetxt(para_file, np.vstack((cfg.alpha, cfg.beta)))
                end1 = datetime.datetime.now()
                # total_seconds(): timedelta.seconds wraps after one day.
                print('Used time for %dth training: %ds' %
                      (iteration + 1, int((end1 - start1).total_seconds())))
                print('\n')

    end = datetime.datetime.now()
    print('Used time for training: %ds' % int((end - start).total_seconds()))
def LDPC_BP_MS_ACGN_test(code, dec_config, simutimes_range, target_err_bits_num, batch_size):
    """Simulate the BER of min-sum BP decoding on pre-generated channel data.

    For every SNR in ``dec_config.SNR_set``, received vectors and source bits
    are read batch by batch from ``<decoding_y_file>_<SNR>.dat`` and
    ``<decoding_x_file>_<SNR>.dat``. They are decoded with
    ``BP_MS_Decoder.BP_NetDecoder`` for ``BP_iter_nums[0]`` iterations.

    Simulation for an SNR stops once at least ``target_err_bits_num`` bit
    errors have been observed over at least ``simutimes_range[0]`` codewords.
    At most ``simutimes_range[1]`` codewords are simulated. One
    ``"<SNR>\\t<BER>"`` line per SNR is written to a text file in
    ``dec_config.results_folder``.

    Args:
        code: code object providing ``H_matrix``, ``encode_LDPC`` and
            ``dec_src_bits``.
        dec_config: decoding configuration (code size, SNRs, BP iterations,
            alpha/beta, file prefixes, results folder).
        simutimes_range: ``(min_codewords, max_codewords)``.
        target_err_bits_num: number of bit errors after which an SNR may stop.
        batch_size: codewords per batch. The decoder is built for this fixed
            size.
    """
    ## load configurations from dec_config
    N = dec_config.N_code
    K = dec_config.K_code
    H_matrix = code.H_matrix
    SNR_set = dec_config.SNR_set
    BP_iter_num = dec_config.BP_iter_nums
    alpha = dec_config.alpha
    beta = dec_config.beta
    function = 'LDPC_BP_MS_ACGN_test'

    # build BP decoding network
    bp_decoder = BP_MS_Decoder.BP_NetDecoder(H_matrix, batch_size, alpha, beta)

    # init graph
    init = tf.global_variables_initializer()
    # The context manager closes the session even if the simulation raises.
    with tf.Session() as sess:
        print('Open a tf session!')
        sess.run(init)

        ## initialize simulation times
        max_simutimes = simutimes_range[1]
        min_simutimes = simutimes_range[0]
        max_batches, residual_times = np.array(
            divmod(max_simutimes, batch_size), np.int32)
        if residual_times != 0:
            max_batches += 1

        ## generate out ber file
        bp_str = np.array2string(BP_iter_num, separator='_',
                                 formatter={'int': lambda d: "%d" % d})
        bp_str = bp_str[1:(len(bp_str) - 1)]  # strip the surrounding brackets
        ber_file = '%sBER(%d_%d)_BP(%s)_%s.txt' % (
            dec_config.results_folder, N, K, bp_str, function)

        with open(ber_file, 'wt') as fout_ber:
            ## simulation starts
            start = datetime.datetime.now()
            for SNR in SNR_set:
                y_recieve_file = '%s_%.1f.dat' % (dec_config.decoding_y_file, SNR)
                x_transmit_file = '%s_%.1f.dat' % (dec_config.decoding_x_file, SNR)
                dataio_decode = DataIO.BPdecDataIO(
                    y_recieve_file, x_transmit_file, dec_config)
                real_batch_size = batch_size

                # simulation part
                actual_simutimes = 0
                # Python int: an np.int32 accumulator could silently wrap.
                bit_errs = 0
                for ik in range(0, max_batches):
                    print('Batch %d in total %d batches.' % (ik, int(max_batches)), end=' ')
                    if ik == max_batches - 1 and residual_times != 0:
                        real_batch_size = residual_times

                    # encode and transmission
                    # The decoder was built for a fixed batch_size, so a full
                    # batch is always loaded and decoded.
                    y_receive, x_bits = dataio_decode.load_next_batch(batch_size, ik)
                    u_coded_bits = code.encode_LDPC(x_bits)
                    s_mod = Modulation.BPSK(u_coded_bits)
                    ch_noise = y_receive - s_mod
                    LLR = y_receive

                    ## practical noise
                    noise_power = np.mean(np.square(ch_noise))
                    practical_snr = 10 * np.log10(1 / (noise_power * 2.0))
                    print('Practical EbN0: %.2f' % practical_snr)

                    # BP decoder
                    u_BP_decoded = bp_decoder.decode(LLR.astype(np.float32), BP_iter_num[0])

                    # BER
                    output_x = code.dec_src_bits(u_BP_decoded)
                    # Only the first real_batch_size codewords count toward
                    # this batch, consistent with actual_simutimes below.
                    # Assumes axis 0 is the batch axis (batch, K) — TODO confirm.
                    bit_errs += int(np.sum(
                        output_x[:real_batch_size] != x_bits[:real_batch_size]))
                    actual_simutimes += real_batch_size
                    if bit_errs >= target_err_bits_num and actual_simutimes >= min_simutimes:
                        break
                print('%d bits are simulated!' % (actual_simutimes * K))

                # load to files
                ber = np.float64(bit_errs) / float(K * actual_simutimes)
                fout_ber.write(str(SNR) + '\t')
                fout_ber.write(str(ber))
                fout_ber.write('\n')
        # simulation finished

        end = datetime.datetime.now()
        # total_seconds(): timedelta.seconds wraps after one day.
        print('Time: %ds' % int((end - start).total_seconds()))
        print("end\n")
    print('Close the tf session!')