Example #1
import time

import torch

# Assumes the project-local helpers generate_noise and customized_loss are in scope.
def train(epoch, model, optimizer, args, use_cuda=False, verbose=True, mode='encoder'):

    device = torch.device("cuda" if use_cuda else "cpu")

    model.train()
    start_time = time.time()
    train_loss = 0.0
    k_same_code_counter = 0

    for batch_idx in range(int(args.num_block / args.batch_size)):

        optimizer.zero_grad()

        # Optionally reuse the same message bits for k_same_code consecutive
        # batches when training the encoder; otherwise draw fresh random bits.
        if args.is_k_same_code and mode == 'encoder':
            if batch_idx == 0:
                k_same_code_counter += 1
                X_train = torch.randint(0, 2, (args.batch_size, args.block_len, args.code_rate_k), dtype=torch.float)
            elif k_same_code_counter == args.k_same_code:
                k_same_code_counter = 1
                X_train = torch.randint(0, 2, (args.batch_size, args.block_len, args.code_rate_k), dtype=torch.float)
            else:
                k_same_code_counter += 1
        else:
            X_train = torch.randint(0, 2, (args.batch_size, args.block_len, args.code_rate_k), dtype=torch.float)

        # Train the encoder and the decoder over different SNR ranges; this
        # tends to be good practice for channel autoencoders.
        if mode == 'encoder':
            fwd_noise = generate_noise(X_train.shape, args, snr_low=args.train_enc_channel_low, snr_high=args.train_enc_channel_high, mode='encoder')
        else:
            fwd_noise = generate_noise(X_train.shape, args, snr_low=args.train_dec_channel_low, snr_high=args.train_dec_channel_high, mode='decoder')

        X_train, fwd_noise = X_train.to(device), fwd_noise.to(device)

        output, code = model(X_train, fwd_noise)
        output = torch.clamp(output, 0.0, 1.0)

        if mode == 'encoder':
            loss = customized_loss(output, X_train, args, noise=fwd_noise, code=code)
        else:
            # Decoder training currently uses the same custom loss.
            loss = customized_loss(output, X_train, args, noise=fwd_noise, code=code)
            # Alternative: loss = F.binary_cross_entropy(output, X_train)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    end_time = time.time()
    train_loss = train_loss / (args.num_block / args.batch_size)
    if verbose:
        print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss),
              ' running time', str(end_time - start_time))

    return train_loss
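
Note: generate_noise is defined elsewhere in the project and is not shown on this page. For orientation, here is a minimal sketch of what an AWGN-only version might look like, assuming the SNR arguments are in dB and map to a noise standard deviation via sigma = 10 ** (-snr / 20); the name generate_noise_awgn and this exact mapping are assumptions, and the real helper also supports the other channel types used in Example #3.

import torch

def generate_noise_awgn(noise_shape, snr_low, snr_high):
    # Draw one SNR (in dB) uniformly from [snr_low, snr_high] for this batch.
    this_snr = torch.empty(1).uniform_(snr_low, snr_high).item()
    # Assumed dB-to-sigma mapping for unit-power signals.
    this_sigma = 10 ** (-this_snr / 20.0)
    return this_sigma * torch.randn(noise_shape, dtype=torch.float)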
Example #2
# Assumes torch, torch.nn.functional as F, and the project-local helpers
# generate_noise, customized_loss and errors_ber are in scope.
def validate(model, optimizer, args, use_cuda=False, verbose=True):

    device = torch.device("cuda" if use_cuda else "cpu")

    model.eval()
    test_bce_loss, test_custom_loss, test_ber = 0.0, 0.0, 0.0

    with torch.no_grad():
        num_test_batch = int(args.num_block / args.batch_size *
                             args.test_ratio)
        for batch_idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            noise_shape = (args.batch_size, args.block_len, args.code_rate_n)
            # Validate at a single fixed SNR: snr_low == snr_high.
            fwd_noise = generate_noise(noise_shape,
                                       args,
                                       snr_low=args.train_enc_channel_low,
                                       snr_high=args.train_enc_channel_low)

            X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

            output, codes = model(X_test, fwd_noise)

            output = torch.clamp(output, 0.0, 1.0)

            # Inside torch.no_grad() no detach() is needed; accumulate plain floats.
            test_bce_loss += F.binary_cross_entropy(output, X_test).item()
            test_custom_loss += customized_loss(output,
                                                X_test,
                                                noise=fwd_noise,
                                                args=args,
                                                code=codes).item()
            test_ber += errors_ber(output, X_test).item()

    test_bce_loss /= num_test_batch
    test_custom_loss /= num_test_batch
    test_ber /= num_test_batch

    if verbose:
        print('====> Test set BCE loss', float(test_bce_loss),
              'Custom Loss', float(test_custom_loss),
              'with ber ', float(test_ber))

    report_loss = float(test_bce_loss)
    report_ber = float(test_ber)

    return report_loss, report_ber
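
Note: errors_ber is another project-local helper. A minimal sketch consistent with how it is used above (soft outputs in [0, 1] hardened to bits, then compared), written as an assumption rather than the project's actual implementation:

import torch

def errors_ber_sketch(y_pred, y_true):
    # Harden soft values to {0, 1} bits and return the fraction that differ.
    pred_bits = torch.round(torch.clamp(y_pred, 0.0, 1.0))
    true_bits = torch.round(y_true)
    return torch.ne(pred_bits, true_bits).float().mean()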
Example #3
def system(epoch, args, optimizer, enc, dec, use_cuda=False, verbose=True):
    # Assumes time, numpy as np, torch, RandInterlv and the project helpers
    # generate_noise, customized_loss and MyQuantize are in scope.
    device = torch.device("cuda" if use_cuda else "cpu")
    start_time = time.time()
    train_loss = 0.0
    for batch_idx in range(int(args.num_block / args.batch_size)):
        if args.is_variable_block_len:
            block_len = np.random.randint(args.block_len_low,
                                          args.block_len_high)
        else:
            block_len = args.block_len
        optimizer.zero_grad()
        # generate bit and noise
        X_train = torch.randint(0,
                                2,
                                (args.batch_size, block_len, args.code_rate_k),
                                dtype=torch.float)
        # Use the per-batch block_len so the noise matches X_train's shape.
        noise_shape = (args.batch_size, block_len, args.code_rate_n)
        fwd_noise = generate_noise(noise_shape,
                                   args,
                                   snr_low=args.train_dec_channel_low,
                                   snr_high=args.train_dec_channel_high,
                                   mode='decoder')
        X_train, fwd_noise = X_train.to(device), fwd_noise.to(device)
        # pass through the system
        if args.is_interleave == 0:
            pass  # no interleaving
        elif args.is_same_interleaver == 0:
            # Fresh random interleaver for every batch, sized to this
            # batch's block_len.
            interleaver = RandInterlv.RandInterlv(block_len,
                                                  np.random.randint(0, 1000))
            p_array = interleaver.p_array
            enc.set_interleaver(p_array)
            dec.set_interleaver(p_array)
        else:  # args.is_same_interleaver == 1
            # Fixed seed, so the interleaver is the same for every batch.
            interleaver = RandInterlv.RandInterlv(block_len, 0)
            p_array = interleaver.p_array
            enc.set_interleaver(p_array)
            dec.set_interleaver(p_array)
        codes = enc.encode(X_train)
        if args.channel in [
                'awgn', 't-dist', 'radar', 'ge_awgn', 'bikappa'
        ]:
            # Additive-noise channels.
            received_codes = codes + fwd_noise

        elif args.channel == 'bec':
            # Binary erasure channel: a {0, 1} mask zeroes out erased symbols.
            received_codes = codes * fwd_noise

        elif args.channel in ['bsc', 'ge']:
            # Binary symmetric / Gilbert-Elliott: {0, 1} noise maps to a
            # {-1, +1} sign flip on the code symbols.
            received_codes = codes * (2.0 * fwd_noise - 1.0)
            received_codes = received_codes.float()  # float dtype, on the current device
        else:
            print('default AWGN channel')
            received_codes = codes + fwd_noise
        if args.rec_quantize:
            myquantize = MyQuantize.apply
            received_codes = myquantize(received_codes,
                                        args.rec_quantize_level,
                                        args.rec_quantize_level)
        x_dec = dec(received_codes)
        loss = customized_loss(x_dec,
                               X_train,
                               args,
                               noise=fwd_noise,
                               code=codes)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    end_time = time.time()
    train_loss = train_loss / (args.num_block / args.batch_size)
    if verbose:
        print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss),
              ' running time', str(end_time - start_time))

    return train_loss
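
The channel branch in Example #3 reduces each channel model to one line of tensor arithmetic. A small self-contained demonstration of the three forms, assuming BPSK-style ±1 code symbols and {0, 1}-valued noise masks (toy values, not project output):

import torch

codes = torch.tensor([+1.0, -1.0, +1.0, -1.0])

# AWGN-style channels ('awgn', 't-dist', 'radar', ...): additive noise.
print(codes + 0.1 * torch.randn(4))

# 'bec': a {0, 1} mask erases (zeroes) the masked positions.
erasure_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(codes * erasure_mask)  # position 1 is erased to zero

# 'bsc' / 'ge': {0, 1} noise becomes {-1, +1}; a 0 flips the symbol's sign.
keep_bits = torch.tensor([1.0, 1.0, 0.0, 1.0])
print(codes * (2.0 * keep_bits - 1.0))  # position 2 is sign-flipped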
Example #4
def ftae_train(epoch,
               model,
               optimizer,
               args,
               use_cuda=False,
               verbose=True,
               mode='encoder'):

    device = torch.device("cuda" if use_cuda else "cpu")

    model.train()
    start_time = time.time()
    train_loss = 0.0

    for batch_idx in range(int(args.num_block / args.batch_size)):

        optimizer.zero_grad()
        X_train = torch.randint(
            0,
            2, (args.batch_size, args.block_len, args.code_rate_k),
            dtype=torch.float)

        if mode == 'encoder':
            fwd_noise = generate_noise(X_train.shape,
                                       args,
                                       snr_low=args.train_enc_channel_low,
                                       snr_high=args.train_enc_channel_high,
                                       mode='encoder')
        else:
            fwd_noise = generate_noise(X_train.shape,
                                       args,
                                       snr_low=args.train_dec_channel_low,
                                       snr_high=args.train_dec_channel_high,
                                       mode='decoder')

        fb_noise = generate_noise(X_train.shape,
                                  args,
                                  snr_low=args.fb_channel_low,
                                  snr_high=args.fb_channel_high,
                                  mode='decoder')

        X_train, fwd_noise, fb_noise = (X_train.to(device),
                                        fwd_noise.to(device),
                                        fb_noise.to(device))

        output, code = model(X_train, fwd_noise, fb_noise)
        output = torch.clamp(output, 0.0, 1.0)

        # Encoder and decoder modes currently use the same custom loss.
        loss = customized_loss(output,
                               X_train,
                               args,
                               noise=fwd_noise,
                               code=code)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    end_time = time.time()
    train_loss = train_loss / (args.num_block / args.batch_size)
    if verbose:
        print('====> Epoch: {} Average loss: {:.8f}'.format(epoch, train_loss),
              ' running time', str(end_time - start_time))

    return train_loss
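
For orientation, a sketch of how these functions might be wired together: an argparse-style namespace supplies every field the functions read, and encoder and decoder are trained in alternation before validating. The field values and the model construction are illustrative placeholders, not the project's defaults.

import argparse
import torch.optim as optim

args = argparse.Namespace(
    num_block=1000, batch_size=100, block_len=10,
    code_rate_k=1, code_rate_n=3,
    is_k_same_code=False, k_same_code=1,
    train_enc_channel_low=1.0, train_enc_channel_high=1.0,
    train_dec_channel_low=-1.5, train_dec_channel_high=2.0,
    fb_channel_low=1.0, fb_channel_high=1.0,
    test_ratio=1.0,
)

model = ...  # placeholder: the autoencoder module is not shown on this page
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(1, 11):
    # Alternate encoder and decoder training, then validate.
    ftae_train(epoch, model, optimizer, args, mode='encoder')
    ftae_train(epoch, model, optimizer, args, mode='decoder')
    loss, ber = validate(model, optimizer, args)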