def backward(ctx, grad_output):
    """Backward pass with optional input/gradient clipping and an
    experimental noisy-feedback mode.

    Steps (all controlled by ``ctx.args``):
      1. enc_clipping in {'inputs', 'both'}: zero the gradient wherever the
         saved forward input exceeded +/- enc_value_limit (straight-through
         style clipping). NOTE: this mutates ``grad_output`` in place.
      2. enc_clipping in {'gradient', 'both'}: clamp the gradient magnitude
         to enc_grad_limit.
      3. For train_channel_mode in {'group_norm_noisy',
         'group_norm_noisy_quantize'}: replace each sample's gradient with
         the batch-averaged gradient perturbed by Gaussian noise whose scale
         is snr_db2sigma(args.fb_noise_snr).

    Returns:
        (grad_input, None) — no gradient flows to the second forward input.
    """
    if ctx.args.enc_clipping in ['inputs', 'both']:
        saved_input, = ctx.saved_tensors
        # Kill gradients where the forward input was outside the
        # representable range (straight-through estimator behavior).
        grad_output[saved_input > ctx.args.enc_value_limit] = 0
        grad_output[saved_input < -ctx.args.enc_value_limit] = 0

    if ctx.args.enc_clipping in ['gradient', 'both']:
        grad_output = torch.clamp(grad_output, -ctx.args.enc_grad_limit,
                                  ctx.args.enc_grad_limit)

    if ctx.args.train_channel_mode not in [
            'group_norm_noisy', 'group_norm_noisy_quantize'
    ]:
        grad_input = grad_output.clone()
    else:
        # Experimental: pass gradient noise to encoder.
        # FIX: allocate the noise on grad_output's device — previously it was
        # always created on CPU, which raised a device-mismatch error when
        # training on GPU.
        grad_noise = snr_db2sigma(ctx.args.fb_noise_snr) * torch.randn(
            grad_output[0].shape, dtype=torch.float,
            device=grad_output.device)
        ave_temp = grad_output.mean(dim=0) + grad_noise
        # Broadcast the batch-averaged gradient back to every sample:
        # stack -> (T, C, B), permute -> (B, T, C).
        ave_grad = torch.stack(
            [ave_temp for _ in range(ctx.args.batch_size)],
            dim=2).permute(2, 0, 1)
        # NOTE(review): grad_noise is added twice (inside ave_temp and again
        # here); this matches the original behavior — confirm it is intended.
        grad_input = ave_grad + grad_noise
    return grad_input, None
def generate_noise(X_train_shape, args, test_sigma='default', snr_low=0.0, snr_high=0.0, mode='encoder'):
    """Sample a channel-noise tensor of shape
    (X_train_shape[0], X_train_shape[1], args.code_rate_n).

    When ``test_sigma == 'default'`` (training), the noise level comes from
    args (discrete channels) or is drawn per-element from the sigma range
    implied by [snr_low, snr_high] dB (continuous channels). Otherwise
    ``test_sigma`` is a crossover/erasure probability for 'bec'/'bsc'/'ge'
    channels, or a test SNR in dB for the continuous channels.

    NOTE(review): for the discrete channels the returned tensor holds
    0/1 masks (not additive noise) — the caller is expected to apply it
    accordingly; confirm against the channel model in the caller.
    """
    # SNRs at training
    if test_sigma == 'default':
        if args.channel == 'bec':
            # erasure probability, chosen per side (encoder vs decoder)
            if mode == 'encoder':
                this_sigma = args.bec_p_enc
            else:
                this_sigma = args.bec_p_dec
        elif args.channel in ['bsc', 'ge']:
            # flip probability, chosen per side (encoder vs decoder)
            if mode == 'encoder':
                this_sigma = args.bsc_p_enc
            else:
                this_sigma = args.bsc_p_dec
        else:  # general AWGN cases
            this_sigma_low = snr_db2sigma(snr_low)
            this_sigma_high = snr_db2sigma(snr_high)
            # mixture of noise sigma: each element gets a sigma drawn
            # uniformly between the two endpoints.
            this_sigma = (this_sigma_low - this_sigma_high) * torch.rand(
                (X_train_shape[0], X_train_shape[1], args.code_rate_n)) + this_sigma_high
    else:
        # SNRs at testing
        if args.channel in ['bec', 'bsc', 'ge']:  # bsc/bec noises: test_sigma is a probability
            this_sigma = test_sigma
        else:
            this_sigma = snr_db2sigma(test_sigma)

    if args.channel == 'awgn':
        fwd_noise = this_sigma * torch.randn(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            dtype=torch.float)
    elif args.channel == 't-dist':
        # Student-t noise with args.vv degrees of freedom, scaled by
        # sqrt((vv-2)/vv) so its variance matches sigma^2.
        fwd_noise = this_sigma * torch.from_numpy(
            np.sqrt((args.vv - 2) / args.vv) * np.random.standard_t(
                args.vv,
                size=(X_train_shape[0], X_train_shape[1], args.code_rate_n))).type(torch.FloatTensor)
    elif args.channel == 'radar':
        # AWGN plus sparse high-power impulses: each position is hit with
        # probability args.radar_prob and gets Gaussian noise scaled by
        # args.radar_power.
        add_pos = np.random.choice(
            [0.0, 1.0],
            (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            p=[1 - args.radar_prob, args.radar_prob])
        corrupted_signal = args.radar_power * np.random.standard_normal(size=(
            X_train_shape[0], X_train_shape[1], args.code_rate_n)) * add_pos
        fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) +\
            torch.from_numpy(corrupted_signal).type(torch.FloatTensor)
    elif args.channel == 'bec':
        # 0/1 mask: 0 with probability this_sigma (erasure), else 1.
        fwd_noise = torch.from_numpy(
            np.random.choice(
                [0.0, 1.0],
                (X_train_shape[0], X_train_shape[1], args.code_rate_n),
                p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor)
    elif args.channel == 'bsc':
        # Same 0/1 mask construction as 'bec'; the caller's channel model
        # determines how the mask is applied (flip vs erase).
        fwd_noise = torch.from_numpy(
            np.random.choice(
                [0.0, 1.0],
                (X_train_shape[0], X_train_shape[1], args.code_rate_n),
                p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor)
    elif args.channel == 'ge_awgn':
        # G-E AWGN channel: two-state Gilbert-Elliott Markov chain; each
        # state uses a different AWGN sigma (+/- 1 dB around this_sigma).
        p_gg = 0.8  # stay in good state
        p_bb = 0.8  # stay in bad state
        bsc_k = snr_db2sigma(snr_sigma2db(this_sigma) + 1)  # sigma in good state (+1 dB SNR, less noisy)
        bsc_h = snr_db2sigma(snr_sigma2db(this_sigma) - 1)  # sigma in bad state (-1 dB SNR, noisier)
        fwd_noise = np.zeros(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n))
        # Evolve the Markov chain independently per (batch, code stream),
        # writing the state-dependent sigma at every time step.
        for batch_idx in range(X_train_shape[0]):
            for code_idx in range(args.code_rate_n):
                good = True
                for time_idx in range(X_train_shape[1]):
                    if good:
                        if test_sigma == 'default':
                            # training: this_sigma (and hence bsc_k) is a
                            # per-element tensor — index it
                            fwd_noise[batch_idx, time_idx,
                                      code_idx] = bsc_k[batch_idx, time_idx,
                                                        code_idx]
                        else:
                            fwd_noise[batch_idx, time_idx, code_idx] = bsc_k
                        good = np.random.random() < p_gg
                    elif not good:
                        if test_sigma == 'default':
                            fwd_noise[batch_idx, time_idx,
                                      code_idx] = bsc_h[batch_idx, time_idx,
                                                        code_idx]
                        else:
                            fwd_noise[batch_idx, time_idx, code_idx] = bsc_h
                        good = np.random.random() < p_bb
                    else:
                        # Unreachable: `good` is always a bool, so one of the
                        # two branches above fires. Kept as a defensive guard.
                        print('bad!!! something happens')
        # Final noise = state-dependent sigma * standard normal draw.
        fwd_noise = torch.from_numpy(fwd_noise).type(
            torch.FloatTensor) * torch.randn(
                (X_train_shape[0], X_train_shape[1], args.code_rate_n),
                dtype=torch.float)
    elif args.channel == 'ge':
        # G-E discrete channel: Gilbert-Elliott chain emitting 0/1 values.
        p_gg = 0.8  # stay in good state
        p_bb = 0.8  # stay in bad state
        bsc_k = 1.0  # P(emit 1) in good state
        bsc_h = this_sigma  # P(emit 1) in bad state
        fwd_noise = np.zeros(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n))
        for batch_idx in range(X_train_shape[0]):
            for code_idx in range(args.code_rate_n):
                good = True
                for time_idx in range(X_train_shape[1]):
                    if good:
                        tmp = np.random.choice([0.0, 1.0],
                                               p=[1 - bsc_k, bsc_k])
                        fwd_noise[batch_idx, time_idx, code_idx] = tmp
                        good = np.random.random() < p_gg
                    elif not good:
                        tmp = np.random.choice([0.0, 1.0],
                                               p=[1 - bsc_h, bsc_h])
                        fwd_noise[batch_idx, time_idx, code_idx] = tmp
                        good = np.random.random() < p_bb
                    else:
                        # Unreachable (see 'ge_awgn' branch); defensive only.
                        print('bad!!! something happens')
        fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)
    else:
        # Unspecific channel, use AWGN channel.
        fwd_noise = this_sigma * torch.randn(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            dtype=torch.float)

    return fwd_noise
def ftae_test(model, args, use_cuda=False):
    """Evaluate a feedback-turbo-autoencoder model over a sweep of test SNRs.

    For each SNR point, runs `num_test_batch` batches of random binary
    blocks through the model with `generate_noise` channel noise, and
    accumulates BER/BLER (plus positional BER and per-position code power).
    Finally re-estimates the encoder output power and prints the
    power-adjusted SNRs. All results are printed; nothing is returned.

    NOTE(review): `sigmas = snrs` — the raw SNR value (dB) is passed to
    generate_noise as test_sigma, which converts it internally; verify
    against generate_noise's contract.
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    # Precomputes Norm Statistics.
    # Feeds random inputs through the encoder so it can accumulate its
    # normalization statistics (mean_scalar / std_scalar) before testing.
    if args.precompute_norm_stats:
        num_test_batch = int(args.num_block / (args.batch_size) *
                             args.test_ratio)
        for batch_idx in range(num_test_batch):
            X_test = torch.randint(
                0, 2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            _ = model.enc(X_test)
        print('Pre-computed norm statistics mean ', model.enc.mean_scalar,
              'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    # Evenly spaced SNR grid from snr_test_start to snr_test_end (inclusive).
    snr_interval = (args.snr_test_end - args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size) *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0, 2, (args.batch_size, args.block_len, args.code_rate_k),
                    dtype=torch.float)
                fwd_noise = generate_noise(X_test.shape, args,
                                           test_sigma=sigma)
                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                # Accumulate positional diagnostics; initialized on the
                # first batch, summed on the rest.
                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                print('positional ber', test_pos_ber / num_test_batch)

        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber),
              'with bler', float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)

    # compute adjusted SNR. (some quantization might make power!=1.0)
    # Re-measure the encoder output std and rescale each test SNR by it.
    # Reuses num_test_batch from the loop above.
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0, 2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [
        snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs
    ]
    print('adjusted SNR should be', adj_snrs)
def test(model, args, block_len='default', use_cuda=False):
    """Evaluate a coded-modulation model over a sweep of test SNRs.

    Like ftae_test, but supports an overridable `block_len` (defaulting to
    args.block_len), a modulated noise shape (mod_rate symbols per use),
    and an optional "punctured" BER/BLER pass restricted to the worst
    positions found by positional-BER ranking. Results are printed;
    nothing is returned.

    FIX: noise_shape now uses the local `block_len` instead of
    `args.block_len`, so a non-default block_len override produces a noise
    tensor consistent with the codeword length.
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    if block_len == 'default':
        block_len = args.block_len
    else:
        pass

    # Precomputes Norm Statistics.
    # Warm up the encoder's normalization statistics before measuring.
    if args.precompute_norm_stats:
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size) *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0, 2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                X_test = X_test.to(device)
                _ = model.enc(X_test)
            print('Pre-computed norm statistics mean ',
                  model.enc.mean_scalar, 'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    ber_res_punc, bler_res_punc = [], []
    # Evenly spaced SNR grid from snr_test_start to snr_test_end (inclusive).
    snr_interval = (args.snr_test_end - args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size))
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0, 2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                # Noise is shaped per modulated symbol: code_rate_n coded
                # bits per message bit, grouped mod_rate at a time.
                # (Uses the local block_len — see docstring FIX note.)
                noise_shape = (args.batch_size,
                               int(block_len * args.code_rate_n /
                                   args.mod_rate), args.mod_rate)
                fwd_noise = generate_noise(noise_shape, args,
                                           test_sigma=sigma)
                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                # Positional diagnostics; initialized on the first batch.
                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                # Rank positions from worst BER to best.
                res_pos = test_pos_ber / num_test_batch
                res_pos_arg = np.array(res_pos.cpu()).argsort()[::-1]
                res_pos_arg = res_pos_arg.tolist()
                print('positional ber', res_pos)
                print('positional argmax', res_pos_arg)

            # Punctured pass: re-measure BER/BLER restricted to the
            # num_ber_puncture worst positions. The bare except keeps the
            # original best-effort behavior: when print_pos_ber is off,
            # res_pos_arg is undefined and the NameError lands here.
            try:
                test_ber_punc, test_bler_punc = .0, .0
                for batch_idx in range(num_test_batch):
                    X_test = torch.randint(
                        0, 2, (args.batch_size, block_len, args.code_rate_k),
                        dtype=torch.float)
                    # Same FIX as above: local block_len, not args.block_len.
                    noise_shape = (args.batch_size,
                                   int(block_len * args.code_rate_n /
                                       args.mod_rate), args.mod_rate)
                    fwd_noise = generate_noise(noise_shape, args,
                                               test_sigma=sigma)
                    X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                    X_hat_test, the_codes = model(X_test, fwd_noise)

                    test_ber_punc += errors_ber(
                        X_hat_test, X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])
                    test_bler_punc += errors_bler(
                        X_hat_test, X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])

                    if batch_idx == 0:
                        test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                        codes_power = code_power(the_codes)
                    else:
                        test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                        codes_power += code_power(the_codes)
            except:
                print('no pos BER specified.')

        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber),
              'with bler', float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

        # Best-effort report of the punctured metrics (skipped if the
        # punctured pass never ran and left the accumulators undefined).
        try:
            test_ber_punc /= num_test_batch
            test_bler_punc /= num_test_batch
            print('Punctured Test SNR', this_snr, 'with ber ',
                  float(test_ber_punc), 'with bler', float(test_bler_punc))
            ber_res_punc.append(float(test_ber_punc))
            bler_res_punc.append(float(test_bler_punc))
        except:
            print('No puncturation is there.')

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)
    print('final results on punctured SNRs ', snrs)
    print('BER', ber_res_punc)
    print('BLER', bler_res_punc)

    # compute adjusted SNR. (some quantization might make power!=1.0)
    # Re-measure encoder output std and rescale each test SNR by it.
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0, 2, (args.batch_size, block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [
        snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs
    ]
    print('adjusted SNR should be', adj_snrs)
def generate_bcjr_example(num_block, block_len, codec, num_iteration,
                          is_save = True, train_snr_db = 0.0, save_path = './tmp/',
                          **kwargs ):
    '''
    Generate BCJR feature and target for training a BCJR-like RNN codec
    from scratch.

    Runs `num_iteration` turbo-decoding iterations (2 MAP passes each) over
    `num_block` random codewords sent through an AWGN channel at
    `train_snr_db`, recording each MAP decoder's (input, extrinsic-LLR)
    pair as a training example.

    Args:
        num_block: number of random message blocks to simulate.
        block_len: message length in bits.
        codec: (trellis1, trellis2, interleaver) triple for the turbo code.
        num_iteration: number of full turbo iterations.
        is_save, save_path: accepted but not used in this function body —
            presumably consumed by a caller or a removed save step; verify.
        train_snr_db: channel SNR in dB used to corrupt the codewords.

    Returns:
        (X_input, X_target): arrays of shape
        (2*num_iteration*num_block, block_len, 3) and
        (2*num_iteration*num_block, block_len, 1); the target is the
        sigmoid of the combined LLR.
    '''
    start_time = time.time()

    # print
    print('[BCJR] Block Length is ', block_len)
    print('[BCJR] Number of Block is ', num_block)

    input_feature_num = 3  # per-position features: sys, parity, prior LLR
    noise_type = 'awgn'
    noise_sigma = snr_db2sigma(train_snr_db)
    identity = str(np.random.random())  # random id for saving

    # Unpack Codec
    trellis1 = codec[0]
    trellis2 = codec[1]
    interleaver = codec[2]

    # Initialize BCJR input/output pair for training: one slot per MAP pass
    # (2 per turbo iteration) per block.
    bcjr_inputs = np.zeros([2*num_iteration, num_block, block_len ,input_feature_num])
    bcjr_outputs = np.zeros([2*num_iteration, num_block, block_len ,1 ])

    for block_idx in range(num_block):
        # Generate Noisy Input For Turbo Decoding
        message_bits = np.random.randint(0, 2, block_len)
        [sys, par1, par2] = turbo.turbo_encode(message_bits, trellis1, trellis2, interleaver)
        sys_r = corrupt_signal(sys, noise_type =noise_type, sigma = noise_sigma,)
        par1_r = corrupt_signal(par1, noise_type =noise_type, sigma = noise_sigma)
        par2_r = corrupt_signal(par2, noise_type =noise_type, sigma = noise_sigma)

        # Use the Commpy BCJR decoding algorithm
        sys_symbols = sys_r
        non_sys_symbols_1 = par1_r
        non_sys_symbols_2 = par2_r
        noise_variance = noise_sigma**2

        # Interleaved systematic stream, consumed by the second MAP decoder.
        sys_symbols_i = interleaver.interlv(sys_symbols)
        trellis = trellis1

        # Zero prior LLRs for the first iteration.
        L_int = None
        if L_int is None:
            L_int = np.zeros(len(sys_symbols))
        L_int_1 = L_int
        L_ext_2 = L_int_1

        # Channel LLR contribution of the systematic bits; subtracted from
        # each MAP output to isolate the extrinsic information. Used in the
        # final step of decoding as well.
        weighted_sys = 2*sys_symbols*1.0/noise_variance
        weighted_sys_int = interleaver.interlv(weighted_sys)

        # All but the last turbo iteration, in 'compute' mode.
        for turbo_iteration_idx in range(num_iteration-1):
            L_int_1 = interleaver.deinterlv(L_ext_2)
            # MAP 1
            [L_ext_1, decoded_bits] = turbo.map_decode(sys_symbols, non_sys_symbols_1,
                                                       trellis, noise_variance, L_int_1, 'compute')
            # Extrinsic LLR = total LLR - prior - systematic channel LLR.
            L_ext_1 -= L_int_1
            L_ext_1 -= weighted_sys

            # ADD Training Examples: features are (sys, parity1, prior LLR),
            # target is the extrinsic LLR of this MAP pass.
            bcjr_inputs[2*turbo_iteration_idx,block_idx,:,:] = np.concatenate([sys_symbols.reshape(block_len,1),
                                                                               non_sys_symbols_1.reshape(block_len,1),
                                                                               L_int_1.reshape(block_len,1)], axis=1)
            bcjr_outputs[2*turbo_iteration_idx,block_idx,:,:]= L_ext_1.reshape(block_len,1)

            # MAP 2 (operates in the interleaved domain)
            L_int_2 = interleaver.interlv(L_ext_1)
            [L_ext_2, decoded_bits] = turbo.map_decode(sys_symbols_i, non_sys_symbols_2,
                                                       trellis, noise_variance, L_int_2, 'compute')
            L_ext_2 -= L_int_2
            L_ext_2 -= weighted_sys_int

            # ADD Training Examples
            bcjr_inputs[2*turbo_iteration_idx+1,block_idx,:,:] = np.concatenate([sys_symbols_i.reshape(block_len,1),
                                                                                 non_sys_symbols_2.reshape(block_len,1),
                                                                                 L_int_2.reshape(block_len,1)], axis=1)
            bcjr_outputs[2*turbo_iteration_idx+1,block_idx,:,:] = L_ext_2.reshape(block_len,1)

        # Final iteration. MAP 1 still in 'compute' mode.
        L_int_1 = interleaver.deinterlv(L_ext_2)
        [L_ext_1, decoded_bits] = turbo.map_decode(sys_symbols, non_sys_symbols_1,
                                                   trellis, noise_variance, L_int_1, 'compute')
        L_ext_1 -= L_int_1
        L_ext_1 -= weighted_sys

        # ADD Training Examples
        bcjr_inputs[2*num_iteration-2,block_idx,:,:] = np.concatenate([sys_symbols.reshape(block_len,1),
                                                                       non_sys_symbols_1.reshape(block_len,1),
                                                                       L_int_1.reshape(block_len,1)], axis=1)
        bcjr_outputs[2*num_iteration-2,block_idx,:,:] = L_ext_1.reshape(block_len,1)

        # Final MAP 2 runs in 'decode' mode and returns the total LLR L_2.
        L_int_2 = interleaver.interlv(L_ext_1)
        [L_2, decoded_bits] = turbo.map_decode(sys_symbols_i, non_sys_symbols_2,
                                               trellis, noise_variance, L_int_2, 'decode')
        L_ext_2 = L_2 - L_int_2
        L_ext_2 -= weighted_sys_int

        # ADD Training Examples
        bcjr_inputs[2*num_iteration-1,block_idx,:,:] = np.concatenate([sys_symbols_i.reshape(block_len,1),
                                                                       non_sys_symbols_2.reshape(block_len,1),
                                                                       L_int_2.reshape(block_len,1)], axis=1)
        bcjr_outputs[2*num_iteration-1,block_idx,:,:] = L_ext_2.reshape(block_len,1)

    end_time = time.time()
    print('[BCJR] The input feature has shape', bcjr_inputs.shape,'the output has shape', bcjr_outputs.shape)
    print('[BCJR] Generating Training Example takes ', end_time - start_time , 'secs')
    print('[BCJR] file id is', identity)

    # Flatten (pass, block) into one example axis.
    bcjr_inputs_train = bcjr_inputs.reshape((-1, block_len,input_feature_num ))
    bcjr_outputs_train = bcjr_outputs.reshape((-1, block_len, 1))

    # Target = sigmoid(extrinsic LLR + prior LLR), i.e. posterior bit
    # probability. NOTE(review): math.e**x overflows float64 for large LLRs
    # (|x| > ~709), yielding inf/inf -> nan — confirm LLR magnitudes stay
    # bounded upstream.
    target_train_select = bcjr_outputs_train[:,:,0] + bcjr_inputs_train[:,:,2]
    target_train_select[:,:] = math.e**target_train_select[:,:]*1.0/(1+math.e**target_train_select[:,:])

    X_input = bcjr_inputs_train.reshape(-1,block_len,input_feature_num)
    X_target = target_train_select.reshape(-1,block_len,1)

    return X_input, X_target