Example #1
    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through-style clipping: zero the gradient wherever the
        # saved forward input fell outside the clipping range.
        if ctx.args.enc_clipping in ['inputs', 'both']:
            input, = ctx.saved_tensors
            grad_output[input > ctx.args.enc_value_limit] = 0
            grad_output[input < -ctx.args.enc_value_limit] = 0

        if ctx.args.enc_clipping in ['gradient', 'both']:
            grad_output = torch.clamp(grad_output, -ctx.args.enc_grad_limit,
                                      ctx.args.enc_grad_limit)

        if ctx.args.train_channel_mode not in [
                'group_norm_noisy', 'group_norm_noisy_quantize'
        ]:
            grad_input = grad_output.clone()
        else:
            # Experimental: pass a noisy, batch-averaged gradient to the
            # encoder (fb_noise_snr sets the feedback noise level).
            grad_noise = snr_db2sigma(ctx.args.fb_noise_snr) * torch.randn(
                grad_output[0].shape, dtype=torch.float)
            ave_temp = grad_output.mean(dim=0) + grad_noise
            # Broadcast the averaged gradient back to every batch element.
            ave_grad = torch.stack(
                [ave_temp for _ in range(ctx.args.batch_size)],
                dim=2).permute(2, 0, 1)
            grad_input = ave_grad + grad_noise

        return grad_input, None
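The method above is only the backward half of a custom torch.autograd.Function; the class wrapper and forward pass are not shown. A minimal sketch of how the pieces could fit together — the class name ClippedEncoderOp and the clamp-on-forward behavior are assumptions, not taken from the source:

import torch

class ClippedEncoderOp(torch.autograd.Function):  # hypothetical name
    """Clamp on forward; clip/average the gradient in backward as above."""

    @staticmethod
    def forward(ctx, inputs, args):
        ctx.args = args
        ctx.save_for_backward(inputs)  # read back via `input, = ctx.saved_tensors`
        return inputs.clamp(-args.enc_value_limit, args.enc_value_limit)

    # backward exactly as in Example #1: it returns (grad_input, None),
    # the None being the gradient placeholder for the non-tensor `args`.

# Usage: out = ClippedEncoderOp.apply(x, args)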
Example #2
def generate_noise(X_train_shape,
                   args,
                   test_sigma='default',
                   snr_low=0.0,
                   snr_high=0.0,
                   mode='encoder'):
    # Pick the noise level: training defaults or an explicit test_sigma.
    if test_sigma == 'default':
        if args.channel == 'bec':
            if mode == 'encoder':
                this_sigma = args.bec_p_enc
            else:
                this_sigma = args.bec_p_dec

        elif args.channel in ['bsc', 'ge']:
            if mode == 'encoder':
                this_sigma = args.bsc_p_enc
            else:
                this_sigma = args.bsc_p_dec
        else:  # general AWGN cases
            this_sigma_low = snr_db2sigma(snr_low)
            this_sigma_high = snr_db2sigma(snr_high)
            # Mixture of noise levels: sigma drawn uniformly between the endpoints.
            this_sigma = (this_sigma_low - this_sigma_high) * torch.rand(
                (X_train_shape[0], X_train_shape[1],
                 args.code_rate_n)) + this_sigma_high

    else:
        if args.channel in ['bec', 'bsc', 'ge']:  # bsc/bec noises
            this_sigma = test_sigma
        else:
            this_sigma = snr_db2sigma(test_sigma)

    # Draw the actual noise realization for the selected channel.
    if args.channel == 'awgn':
        fwd_noise = this_sigma * torch.randn(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            dtype=torch.float)

    elif args.channel == 't-dist':
        fwd_noise = this_sigma * torch.from_numpy(
            np.sqrt((args.vv - 2) / args.vv) * np.random.standard_t(
                args.vv,
                size=(X_train_shape[0], X_train_shape[1],
                      args.code_rate_n))).type(torch.FloatTensor)

    elif args.channel == 'radar':
        add_pos = np.random.choice(
            [0.0, 1.0], (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            p=[1 - args.radar_prob, args.radar_prob])

        corrupted_signal = args.radar_power * np.random.standard_normal(size=(
            X_train_shape[0], X_train_shape[1], args.code_rate_n)) * add_pos
        fwd_noise = this_sigma * torch.randn((X_train_shape[0], X_train_shape[1], args.code_rate_n), dtype=torch.float) +\
                    torch.from_numpy(corrupted_signal).type(torch.FloatTensor)

    elif args.channel in ['bec', 'bsc']:
        # BEC and BSC use the same Bernoulli mask: 0.0 with probability
        # this_sigma, 1.0 otherwise.
        fwd_noise = torch.from_numpy(
            np.random.choice(
                [0.0, 1.0],
                (X_train_shape[0], X_train_shape[1], args.code_rate_n),
                p=[this_sigma, 1 - this_sigma])).type(torch.FloatTensor)
    elif args.channel == 'ge_awgn':
        # Gilbert-Elliott AWGN channel: a two-state Markov chain alternates
        # between a good (lower-sigma) and a bad (higher-sigma) noise level.
        p_gg = 0.8  # probability of staying in the good state
        p_bb = 0.8  # probability of staying in the bad state
        bsc_k = snr_db2sigma(snr_sigma2db(this_sigma) +
                             1)  # sigma in the good state (1 dB better)
        bsc_h = snr_db2sigma(snr_sigma2db(this_sigma) -
                             1)  # sigma in the bad state (1 dB worse)

        fwd_noise = np.zeros(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n))
        for batch_idx in range(X_train_shape[0]):
            for code_idx in range(args.code_rate_n):

                good = True
                for time_idx in range(X_train_shape[1]):
                    if good:
                        if test_sigma == 'default':
                            fwd_noise[batch_idx, time_idx,
                                      code_idx] = bsc_k[batch_idx, time_idx,
                                                        code_idx]
                        else:
                            fwd_noise[batch_idx, time_idx, code_idx] = bsc_k
                        good = np.random.random() < p_gg
                    else:
                        if test_sigma == 'default':
                            fwd_noise[batch_idx, time_idx,
                                      code_idx] = bsc_h[batch_idx, time_idx,
                                                        code_idx]
                        else:
                            fwd_noise[batch_idx, time_idx, code_idx] = bsc_h
                        # Stay in the bad state with probability p_bb.
                        good = np.random.random() >= p_bb

        fwd_noise = torch.from_numpy(fwd_noise).type(
            torch.FloatTensor) * torch.randn(
                (X_train_shape[0], X_train_shape[1], args.code_rate_n),
                dtype=torch.float)

    elif args.channel == 'ge':
        # Gilbert-Elliott discrete channel.
        p_gg = 0.8  # probability of staying in the good state
        p_bb = 0.8  # probability of staying in the bad state
        bsc_k = 1.0  # bit accuracy in the good state
        bsc_h = this_sigma  # bit accuracy in the bad state

        fwd_noise = np.zeros(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n))
        for batch_idx in range(X_train_shape[0]):
            for code_idx in range(args.code_rate_n):

                good = True
                for time_idx in range(X_train_shape[1]):
                    if good:
                        tmp = np.random.choice([0.0, 1.0],
                                               p=[1 - bsc_k, bsc_k])
                        fwd_noise[batch_idx, time_idx, code_idx] = tmp
                        good = np.random.random() < p_gg
                    else:
                        tmp = np.random.choice([0.0, 1.0],
                                               p=[1 - bsc_h, bsc_h])
                        fwd_noise[batch_idx, time_idx, code_idx] = tmp
                        # Stay in the bad state with probability p_bb.
                        good = np.random.random() >= p_bb

        fwd_noise = torch.from_numpy(fwd_noise).type(torch.FloatTensor)

    else:
        # Unrecognized channel: fall back to AWGN.
        fwd_noise = this_sigma * torch.randn(
            (X_train_shape[0], X_train_shape[1], args.code_rate_n),
            dtype=torch.float)

    return fwd_noise
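generate_noise relies on snr_db2sigma / snr_sigma2db helpers defined elsewhere in the project. Under the usual unit-power convention they amount to the following; this is a sketch of that convention plus a hypothetical call (the args fields and values are illustrative, not from the source):

import numpy as np
import torch
from types import SimpleNamespace

def snr_db2sigma(snr_db):
    # Noise std for a unit-power signal: sigma = 10^(-SNR_dB / 20).
    return 10 ** (-snr_db / 20.0)

def snr_sigma2db(sigma):
    # Inverse mapping: SNR_dB = -20 * log10(sigma).
    return -20.0 * np.log10(sigma)

# Only the fields touched by the 'awgn' branch matter for this call.
args = SimpleNamespace(channel='awgn', code_rate_n=3)
noise = generate_noise((500, 100, 1), args, test_sigma=0.0)  # 0 dB -> sigma = 1.0
assert noise.shape == (500, 100, 3)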
Example #3
def ftae_test(model, args, use_cuda=False):

    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    # Precompute normalization statistics.
    if args.precompute_norm_stats:
        with torch.no_grad():
            num_test_batch = int(args.num_block / args.batch_size *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, args.block_len, args.code_rate_k),
                    dtype=torch.float)
                X_test = X_test.to(device)
                _ = model.enc(X_test)
        print('Pre-computed norm statistics mean ', model.enc.mean_scalar,
              'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    snr_interval = (args.snr_test_end -
                    args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    # test_sigma is given in dB; generate_noise converts dB to sigma internally.
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / args.batch_size *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, args.block_len, args.code_rate_k),
                    dtype=torch.float)
                fwd_noise = generate_noise(X_test.shape,
                                           args,
                                           test_sigma=sigma)

                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                print('positional ber', test_pos_ber / num_test_batch)

        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber), 'with bler',
              float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)

    # Compute the adjusted SNR (quantization can make the code power != 1.0).
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs]
    print('adjusted SNR should be', adj_snrs)
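The SNR grid built near the top of ftae_test is an evenly spaced sweep from args.snr_test_start to args.snr_test_end, i.e. a plain numpy.linspace. A quick self-contained check with illustrative endpoint values:

import numpy as np

snr_test_start, snr_test_end, snr_points = -1.5, 2.0, 8  # illustrative values
interval = (snr_test_end - snr_test_start) / (snr_points - 1)
snrs = [interval * i + snr_test_start for i in range(snr_points)]
assert np.allclose(snrs, np.linspace(snr_test_start, snr_test_end, snr_points))
print(snrs)  # [-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]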
Example #4
def test(model, args, block_len='default', use_cuda=False):

    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    if block_len == 'default':
        block_len = args.block_len

    # Precompute normalization statistics.
    if args.precompute_norm_stats:
        with torch.no_grad():
            num_test_batch = int(args.num_block / args.batch_size *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                X_test = X_test.to(device)
                _ = model.enc(X_test)
            print('Pre-computed norm statistics mean ', model.enc.mean_scalar,
                  'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    ber_res_punc, bler_res_punc = [], []
    snr_interval = (args.snr_test_end -
                    args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    # test_sigma is given in dB; generate_noise converts dB to sigma internally.
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / args.batch_size)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                # Use the local block_len so a non-default override also
                # shapes the noise correctly.
                noise_shape = (args.batch_size,
                               int(block_len * args.code_rate_n /
                                   args.mod_rate), args.mod_rate)
                fwd_noise = generate_noise(noise_shape, args, test_sigma=sigma)

                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                res_pos = test_pos_ber / num_test_batch
                res_pos_arg = np.array(res_pos.cpu()).argsort()[::-1]
                res_pos_arg = res_pos_arg.tolist()
                print('positional ber', res_pos)
                print('positional argmax', res_pos_arg)
            try:
                test_ber_punc, test_bler_punc = .0, .0
                for batch_idx in range(num_test_batch):
                    X_test = torch.randint(
                        0,
                        2, (args.batch_size, block_len, args.code_rate_k),
                        dtype=torch.float)
                    noise_shape = (args.batch_size,
                                   int(block_len * args.code_rate_n /
                                       args.mod_rate), args.mod_rate)
                    fwd_noise = generate_noise(noise_shape,
                                               args,
                                               test_sigma=sigma)
                    X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                    X_hat_test, the_codes = model(X_test, fwd_noise)

                    test_ber_punc += errors_ber(
                        X_hat_test,
                        X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])
                    test_bler_punc += errors_bler(
                        X_hat_test,
                        X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])

                    if batch_idx == 0:
                        test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                        codes_power = code_power(the_codes)
                    else:
                        test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                        codes_power += code_power(the_codes)
            except NameError:
                # res_pos_arg only exists when args.print_pos_ber is set.
                print('no pos BER specified.')

        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber), 'with bler',
              float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

        try:
            test_ber_punc /= num_test_batch
            test_bler_punc /= num_test_batch
            print('Punctured Test SNR', this_snr, 'with ber ',
                  float(test_ber_punc), 'with bler', float(test_bler_punc))
            ber_res_punc.append(float(test_ber_punc))
            bler_res_punc.append(float(test_bler_punc))
        except NameError:
            print('No punctured results available.')

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)
    print('final punctured results on SNRs ', snrs)
    print('BER', ber_res_punc)
    print('BLER', bler_res_punc)

    # Compute the adjusted SNR (quantization can make the code power != 1.0).
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs]
    print('adjusted SNR should be', adj_snrs)
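The adjusted-SNR correction at the end of test (and of ftae_test) is just a power rescaling in dB: with SNR_dB = -20 * log10(sigma), dividing sigma by the measured code power P shifts the SNR by +20 * log10(P) dB. A self-contained numeric check, reusing the helper definitions sketched after Example #2 and a hypothetical measured power:

import numpy as np

def snr_db2sigma(snr_db):
    return 10 ** (-snr_db / 20.0)

def snr_sigma2db(sigma):
    return -20.0 * np.log10(sigma)

enc_power = 0.97  # hypothetical measured std of the encoder output
for snr in [-1.5, 0.0, 2.0]:
    adj = snr_sigma2db(snr_db2sigma(snr) / enc_power)
    assert np.isclose(adj, snr + 20.0 * np.log10(enc_power))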
Example #5
def generate_bcjr_example(num_block, block_len, codec, num_iteration,
                          is_save=True, train_snr_db=0.0, save_path='./tmp/',
                          **kwargs):
    '''
    Generate BCJR features and targets for training a BCJR-like RNN decoder
    from scratch.
    '''

    start_time = time.time()
    print('[BCJR] Block Length is ', block_len)
    print('[BCJR] Number of Block is ', num_block)

    input_feature_num = 3  # (systematic, parity, prior LLR) per time step
    noise_type = 'awgn'
    noise_sigma = snr_db2sigma(train_snr_db)

    identity = str(np.random.random())    # random id for saving

    # Unpack Codec
    trellis1    = codec[0]
    trellis2    = codec[1]
    interleaver = codec[2]

    # Initialize the BCJR input/output pairs used for training.
    bcjr_inputs = np.zeros(
        [2 * num_iteration, num_block, block_len, input_feature_num])
    bcjr_outputs = np.zeros([2 * num_iteration, num_block, block_len, 1])

    for block_idx in range(num_block):
        # Generate Noisy Input For Turbo Decoding
        message_bits = np.random.randint(0, 2, block_len)
        [sys, par1, par2] = turbo.turbo_encode(message_bits, trellis1, trellis2, interleaver)

        sys_r = corrupt_signal(sys, noise_type=noise_type, sigma=noise_sigma)
        par1_r = corrupt_signal(par1, noise_type=noise_type, sigma=noise_sigma)
        par2_r = corrupt_signal(par2, noise_type=noise_type, sigma=noise_sigma)

        # Use the commpy MAP (BCJR) decoding algorithm.
        sys_symbols = sys_r
        non_sys_symbols_1 = par1_r
        non_sys_symbols_2 = par2_r
        noise_variance = noise_sigma**2
        sys_symbols_i = interleaver.interlv(sys_symbols)
        trellis = trellis1

        L_int = np.zeros(len(sys_symbols))
        L_int_1 = L_int
        L_ext_2 = L_int_1

        weighted_sys = 2 * sys_symbols / noise_variance  # used in the final decoding step
        weighted_sys_int = interleaver.interlv(weighted_sys)

        for turbo_iteration_idx in range(num_iteration-1):
            L_int_1 = interleaver.deinterlv(L_ext_2)
            # MAP 1
            [L_ext_1, decoded_bits] = turbo.map_decode(sys_symbols, non_sys_symbols_1,
                                                 trellis, noise_variance, L_int_1, 'compute')
            L_ext_1 -= L_int_1
            L_ext_1 -= weighted_sys

            # Add training examples.
            bcjr_inputs[2 * turbo_iteration_idx, block_idx, :, :] = np.concatenate(
                [sys_symbols.reshape(block_len, 1),
                 non_sys_symbols_1.reshape(block_len, 1),
                 L_int_1.reshape(block_len, 1)], axis=1)
            bcjr_outputs[2 * turbo_iteration_idx, block_idx, :, :] = \
                L_ext_1.reshape(block_len, 1)

            # MAP 2
            L_int_2 = interleaver.interlv(L_ext_1)

            #print("+++++++++++++++++++++++++++++++++++")
            #print(sys_symbols_i)
            #print(sys_symbols_i.shape)
            #print("+++++++++++++++++++++++++++++++++++")

            [L_ext_2, decoded_bits] = turbo.map_decode(sys_symbols_i, non_sys_symbols_2,
                                             trellis, noise_variance, L_int_2, 'compute')
            L_ext_2 -= L_int_2
            L_ext_2 -= weighted_sys_int
            # Add training examples.
            bcjr_inputs[2 * turbo_iteration_idx + 1, block_idx, :, :] = np.concatenate(
                [sys_symbols_i.reshape(block_len, 1),
                 non_sys_symbols_2.reshape(block_len, 1),
                 L_int_2.reshape(block_len, 1)], axis=1)
            bcjr_outputs[2 * turbo_iteration_idx + 1, block_idx, :, :] = \
                L_ext_2.reshape(block_len, 1)

        # Final MAP 1 pass.
        L_int_1 = interleaver.deinterlv(L_ext_2)
        [L_ext_1, decoded_bits] = turbo.map_decode(sys_symbols, non_sys_symbols_1,
                                                   trellis, noise_variance,
                                                   L_int_1, 'compute')
        L_ext_1 -= L_int_1
        L_ext_1 -= weighted_sys
        # Add training examples.
        bcjr_inputs[2 * num_iteration - 2, block_idx, :, :] = np.concatenate(
            [sys_symbols.reshape(block_len, 1),
             non_sys_symbols_1.reshape(block_len, 1),
             L_int_1.reshape(block_len, 1)], axis=1)
        bcjr_outputs[2 * num_iteration - 2, block_idx, :, :] = \
            L_ext_1.reshape(block_len, 1)

        # Final MAP 2 pass (decode mode).
        L_int_2 = interleaver.interlv(L_ext_1)
        [L_2, decoded_bits] = turbo.map_decode(sys_symbols_i, non_sys_symbols_2,
                                               trellis, noise_variance,
                                               L_int_2, 'decode')
        L_ext_2 = L_2 - L_int_2
        L_ext_2 -= weighted_sys_int
        # Add training examples.
        bcjr_inputs[2 * num_iteration - 1, block_idx, :, :] = np.concatenate(
            [sys_symbols_i.reshape(block_len, 1),
             non_sys_symbols_2.reshape(block_len, 1),
             L_int_2.reshape(block_len, 1)], axis=1)
        bcjr_outputs[2 * num_iteration - 1, block_idx, :, :] = \
            L_ext_2.reshape(block_len, 1)

    end_time = time.time()
    print('[BCJR] The input feature has shape', bcjr_inputs.shape,
          'the output has shape', bcjr_outputs.shape)
    print('[BCJR] Generating training examples took', end_time - start_time,
          'secs')
    print('[BCJR] file id is', identity)

    bcjr_inputs_train = bcjr_inputs.reshape((-1, block_len, input_feature_num))
    bcjr_outputs_train = bcjr_outputs.reshape((-1, block_len, 1))

    # Target: posterior LLR (extrinsic + prior) squashed through a sigmoid.
    target_train_select = bcjr_outputs_train[:, :, 0] + bcjr_inputs_train[:, :, 2]
    # Equivalent to e**x / (1 + e**x) but numerically stabler for large |x|.
    target_train_select = 1.0 / (1.0 + np.exp(-target_train_select))

    X_input = bcjr_inputs_train.reshape(-1, block_len, input_feature_num)
    X_target = target_train_select.reshape(-1, block_len, 1)

    return X_input, X_target
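A hypothetical invocation, assuming the commpy.channelcoding API that the turbo.turbo_encode / turbo.map_decode calls appear to target; the generator polynomials below are illustrative choices, and corrupt_signal must also be importable from the surrounding project:

import numpy as np
import commpy.channelcoding as cc

block_len = 100
memory = np.array([2])
g_matrix = np.array([[7, 5]])  # illustrative RSC generator polynomials (octal)
trellis1 = cc.Trellis(memory, g_matrix, feedback=7, code_type='rsc')
trellis2 = cc.Trellis(memory, g_matrix, feedback=7, code_type='rsc')
interleaver = cc.RandInterlv(block_len, seed=0)

X_input, X_target = generate_bcjr_example(
    num_block=10, block_len=block_len,
    codec=[trellis1, trellis2, interleaver],
    num_iteration=6, train_snr_db=0.0)
print(X_input.shape, X_target.shape)  # (120, 100, 3), (120, 100, 1)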