Ejemplo n.º 1
0
def validate(model, optimizer, args, use_cuda=False, verbose=True):
    """Evaluate `model` on randomly generated binary blocks.

    Runs `num_test_batch` batches of random bits through the channel
    autoencoder with freshly generated forward noise and accumulates
    BCE loss, the project's customized loss, and bit error rate (BER).

    Args:
        model: channel autoencoder; called as ``model(X, fwd_noise)``.
        optimizer: kept for interface compatibility; intentionally NOT
            used — validation must not touch gradient state.
        args: namespace with ``num_block``, ``batch_size``, ``test_ratio``,
            ``block_len``, ``code_rate_k``, ``code_rate_n`` and SNR settings.
        use_cuda: run on GPU when True.
        verbose: print a one-line summary when True.

    Returns:
        (report_loss, report_ber): mean BCE loss and mean BER as floats.
    """
    device = torch.device("cuda" if use_cuda else "cpu")

    model.eval()
    test_bce_loss, test_custom_loss, test_ber = 0.0, 0.0, 0.0

    with torch.no_grad():
        num_test_batch = int(args.num_block / args.batch_size *
                             args.test_ratio)
        for batch_idx in range(num_test_batch):
            # Fresh random payload bits in {0, 1}.
            X_test = torch.randint(
                0,
                2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            noise_shape = (args.batch_size, args.block_len, args.code_rate_n)
            # NOTE(review): snr_low == snr_high here, so validation runs at
            # the single SNR args.train_enc_channel_low — confirm intended.
            fwd_noise = generate_noise(noise_shape,
                                       args,
                                       snr_low=args.train_enc_channel_low,
                                       snr_high=args.train_enc_channel_low)

            X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

            # FIX: removed stray `optimizer.zero_grad()` — a training-loop
            # leftover that mutated gradient state during evaluation.
            output, codes = model(X_test, fwd_noise)

            # Clamp into [0, 1] so binary cross entropy is well defined.
            # (.detach() calls removed: redundant inside torch.no_grad().)
            output = torch.clamp(output, 0.0, 1.0)

            test_bce_loss += F.binary_cross_entropy(output, X_test)
            test_custom_loss += customized_loss(output,
                                                X_test,
                                                noise=fwd_noise,
                                                args=args,
                                                code=codes)
            test_ber += errors_ber(output, X_test)

    # Average the accumulated per-batch metrics.
    test_bce_loss /= num_test_batch
    test_custom_loss /= num_test_batch
    test_ber /= num_test_batch

    if verbose:
        print(
            '====> Test set BCE loss',
            float(test_bce_loss),
            'Custom Loss',
            float(test_custom_loss),
            'with ber ',
            float(test_ber),
        )

    report_loss = float(test_bce_loss)
    report_ber = float(test_ber)

    return report_loss, report_ber
Ejemplo n.º 2
0
            # NOTE(review): fragment of a larger training loop — `fake`,
            # `noise_std`, `netDec`, `criterion`, `u`, `device` come from
            # the enclosing (not visible) scope.
            #forward pass encoded real batch to decoder

            # Simulate an AWGN channel: add Gaussian noise to the fake image.
            #add noise to image
            channel_noise = noise_std * torch.randn(
                fake.shape, dtype=torch.float, device=device)
            noisy_fake = channel_noise + fake

            # detach(): this update trains the decoder only, so gradients
            # must not flow back into the generator that produced `fake`.
            output = netDec(noisy_fake.detach())
            #calculate loss
            errDec = criterion(output, u)
            #calculate gradient
            errDec.backward()
            D_E_x_2 = output.mean().item()
            # output right now is set to u decoded from decoder, so we can use that directly for ber
            # ber is calculated
            ber = errors_ber(output, u)
            #forward pass encoded fake batch to decoder
            #output = netDec(fake_enc_img.detach())
            #calculate loss
            #errDec_fake1 = criterion(output,u)
            #calculate gradient
            #errDec_fake1.backward()
            #D_E_G_z_1 = output.mean().item()
            #forward pass fake batch to decoder
            #output = netDec(fake.detach())
            #calculate loss
            #errDec_fake2 = criterion(output,u)
            #calculate gradient
            #errDec_fake2.backward()
            #D_E_G_z_2 = output.mean().item()
            #add all the losses
Ejemplo n.º 3
0
                    # Adversarial loss
                    # WGAN-style critic objective: maximize
                    # E[D(real)] - E[D(fake)], written as minimizing its
                    # negation.
                    loss_D = -torch.mean(
                        discriminator(real_imgs)) + torch.mean(
                            discriminator(fake_imgs))

                    loss_D.backward()
                    optimizer_D.step()

                    # Clip weights of discriminator
                    # (hard weight clipping as in the original WGAN recipe).
                    for p in discriminator.parameters():
                        p.data.clamp_(-args.clip_value, args.clip_value)

            # Periodic console logging of losses and batch decoder BER.
            if i % 100 == 0:
                # Detach before computing the metric — BER needs no grads.
                decoded_info = decoded_info.detach()
                u = u.detach()
                this_ber = errors_ber(decoded_info, u)
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [batch Dec BER: %f]"
                    % (epoch, args.num_epoch, i, len(train_dataloader),
                       d_loss.item(), g_loss.item(), this_ber))

            # Save a 5x5 grid of generated samples every sample_interval
            # batches.
            batches_done = epoch * len(train_dataloader) + i
            if batches_done % args.sample_interval == 0:
                save_image(gen_imgs.data[:25],
                           'images/' + identity + '/%d.png' % batches_done,
                           nrow=5,
                           normalize=True)

    # --------------------------
    #  Testing: only for BER
    # --------------------------
Ejemplo n.º 4
0
                # NOTE(review): fragment of the decoder training step —
                # `netG`, `netE`, `netDec`, `channel`, `criterion`, `noise`,
                # `b_size`, `ud`, `im_u` come from the enclosing scope.
                # Fresh random message bits in {0, 1}.
                u = torch.randint(0,
                                  2, (b_size, ud, im_u, im_u),
                                  dtype=torch.float,
                                  device=device)

                # forward pass fake encoded images
                # add noise to image
                fake_img = netG(noise)
                # detach at every stage: only the decoder receives gradients
                # in this update.
                fake_enc = netE(fake_img.detach(), u)
                nfake_enc = channel(fake_enc.detach(), args.awgn, args)
                foutput = netDec(nfake_enc.detach())

                errDec_fakeenc = criterion(foutput, u)
                errDec_fakeenc.backward()
                fber = errors_ber(foutput, u)

                errDec = errDec_fakeenc
                #errDec.backward()

                ber = fber.item()
                optimizerDec.step()

            #######################################################################
            # Train Encoder+Generator, minimize
            # Encoder should encode real+fake images
            #######################################################################

            for run in range(args.num_train_Enc):
                netE.zero_grad()
                netG.zero_grad()
Ejemplo n.º 5
0
                    # NOTE(review): fragment — `foutput`,
                    # `foutput_noiseless`, `u_message`, `criterion`,
                    # `optimizerDec` come from the enclosing scope.
                    # Loss on the noisy-channel decode.
                    # noisy
                    errDec_fakeenc = criterion(foutput, u_message)
                    errDec_fakeenc.backward()

                    # Loss on the clean (no-channel) decode; gradients from
                    # both accumulate before the single optimizer step below.
                    # noiseless
                    errDec_fakeenc_noiseless = criterion(
                        foutput_noiseless, u_message
                    )  # both noiseless and noisy decoding should be there.
                    errDec_fakeenc_noiseless.backward()

                    errDec = errDec_fakeenc

                    optimizerDec.step()

                    # BER of the noisy decode.
                    fber = errors_ber(foutput, u_message)
                    ber = fber.item()

                    # BER of the noiseless decode.
                    fber_noiseless = errors_ber(foutput_noiseless, u_message)
                    ber_noiseless = fber_noiseless.item()

                #######################################################################
                # Train Encoder+Generator, minimize
                # Encoder should encode real+fake images
                #######################################################################

                for run in range(args.num_train_Enc):
                    netE.zero_grad()

                    u = torch.randint(0,
                                      2, (args.batch_size, args.code_rate_k,
Ejemplo n.º 6
0
def ftae_test(model, args, use_cuda=False):
    """Sweep test SNRs and report BER/BLER for the autoencoder `model`.

    For each SNR point, decodes random binary blocks under generated
    channel noise; also accumulates per-position BER and code power,
    and finally computes an adjusted SNR from the measured encoder
    output power (quantization can make the power differ from 1.0).
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    # Precomputes Norm Statistics.
    # Feeds random batches through the encoder so it can accumulate the
    # running mean/std it uses for normalization at test time.
    if args.precompute_norm_stats:
        num_test_batch = int(args.num_block / (args.batch_size) *
                             args.test_ratio)
        for batch_idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            _ = model.enc(X_test)
        print('Pre-computed norm statistics mean ', model.enc.mean_scalar,
              'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    # Evenly spaced SNR grid from snr_test_start to snr_test_end.
    snr_interval = (args.snr_test_end -
                    args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    # NOTE(review): SNR values (dB) are passed to generate_noise as
    # test_sigma — presumably generate_noise converts dB to sigma
    # internally; confirm against its definition.
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size) *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                # Fresh random payload bits in {0, 1}.
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, args.block_len, args.code_rate_k),
                    dtype=torch.float)
                fwd_noise = generate_noise(X_test.shape,
                                           args,
                                           test_sigma=sigma)

                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                # Initialize the positional accumulators on the first
                # batch, then add onto them.
                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                print('positional ber', test_pos_ber / num_test_batch)

        # Average accumulated metrics over batches for this SNR.
        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber), 'with bler',
              float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)

    # compute adjusted SNR. (some quantization might make power!=1.0)
    # NOTE: reuses num_test_batch from the last SNR loop above.
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, args.block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs]
    print('adjusted SNR should be', adj_snrs)
Ejemplo n.º 7
0
def test(model, args, block_len='default', use_cuda=False):
    """Sweep test SNRs and report BER/BLER, plus punctured metrics.

    Args:
        model: autoencoder with ``.enc``; called as ``model(X, fwd_noise)``.
        args: namespace with batch/SNR/code-rate/modulation settings.
        block_len: overrides ``args.block_len`` when not 'default'.
        use_cuda: run on GPU when True.

    Prints per-SNR BER/BLER, optional positional BER and code power,
    punctured BER/BLER (only when positional BER was computed), and an
    adjusted SNR derived from the measured encoder output power.
    """
    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval()

    if block_len == 'default':
        block_len = args.block_len
    else:
        pass

    # Precomputes Norm Statistics.
    # Feeds random batches through the encoder so it can accumulate the
    # running mean/std used for normalization at test time.
    if args.precompute_norm_stats:
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size) *
                                 args.test_ratio)
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                X_test = X_test.to(device)
                _ = model.enc(X_test)
            print('Pre-computed norm statistics mean ', model.enc.mean_scalar,
                  'std ', model.enc.std_scalar)

    ber_res, bler_res = [], []
    ber_res_punc, bler_res_punc = [], []
    # Evenly spaced SNR grid from snr_test_start to snr_test_end.
    snr_interval = (args.snr_test_end -
                    args.snr_test_start) * 1.0 / (args.snr_points - 1)
    snrs = [
        snr_interval * item + args.snr_test_start
        for item in range(args.snr_points)
    ]
    print('SNRS', snrs)
    # NOTE(review): SNR values (dB) are handed to generate_noise as
    # test_sigma — presumably converted internally; confirm upstream.
    sigmas = snrs

    for sigma, this_snr in zip(sigmas, snrs):
        test_ber, test_bler = .0, .0
        with torch.no_grad():
            num_test_batch = int(args.num_block / (args.batch_size))
            for batch_idx in range(num_test_batch):
                X_test = torch.randint(
                    0,
                    2, (args.batch_size, block_len, args.code_rate_k),
                    dtype=torch.float)
                # FIX: use the (possibly overridden) local block_len, not
                # args.block_len, so a custom block_len keeps the noise
                # shape consistent with X_test.
                noise_shape = (args.batch_size,
                               int(block_len * args.code_rate_n /
                                   args.mod_rate), args.mod_rate)
                fwd_noise = generate_noise(noise_shape, args, test_sigma=sigma)

                X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                X_hat_test, the_codes = model(X_test, fwd_noise)

                test_ber += errors_ber(X_hat_test, X_test)
                test_bler += errors_bler(X_hat_test, X_test)

                # Initialize positional accumulators on the first batch.
                if batch_idx == 0:
                    test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                    codes_power = code_power(the_codes)
                else:
                    test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                    codes_power += code_power(the_codes)

            if args.print_pos_power:
                print('code power', codes_power / num_test_batch)
            if args.print_pos_ber:
                # Positions sorted by descending BER: the worst positions
                # come first and are the puncturing candidates below.
                res_pos = test_pos_ber / num_test_batch
                res_pos_arg = np.array(res_pos.cpu()).argsort()[::-1]
                res_pos_arg = res_pos_arg.tolist()
                print('positional ber', res_pos)
                print('positional argmax', res_pos_arg)
            try:
                # Punctured evaluation: skip the num_ber_puncture worst
                # positions. If print_pos_ber was off, res_pos_arg is
                # undefined and the NameError is caught below.
                test_ber_punc, test_bler_punc = .0, .0
                for batch_idx in range(num_test_batch):
                    X_test = torch.randint(
                        0,
                        2, (args.batch_size, block_len, args.code_rate_k),
                        dtype=torch.float)
                    # FIX: honor the local block_len here as well.
                    noise_shape = (args.batch_size,
                                   int(block_len * args.code_rate_n /
                                       args.mod_rate), args.mod_rate)
                    fwd_noise = generate_noise(noise_shape,
                                               args,
                                               test_sigma=sigma)
                    X_test, fwd_noise = X_test.to(device), fwd_noise.to(device)

                    X_hat_test, the_codes = model(X_test, fwd_noise)

                    test_ber_punc += errors_ber(
                        X_hat_test,
                        X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])
                    test_bler_punc += errors_bler(
                        X_hat_test,
                        X_test,
                        positions=res_pos_arg[:args.num_ber_puncture])

                    if batch_idx == 0:
                        test_pos_ber = errors_ber_pos(X_hat_test, X_test)
                        codes_power = code_power(the_codes)
                    else:
                        test_pos_ber += errors_ber_pos(X_hat_test, X_test)
                        codes_power += code_power(the_codes)
            except Exception:
                # FIX: narrowed from bare `except:` so KeyboardInterrupt /
                # SystemExit are not swallowed.
                print('no pos BER specified.')

        test_ber /= num_test_batch
        test_bler /= num_test_batch
        print('Test SNR', this_snr, 'with ber ', float(test_ber), 'with bler',
              float(test_bler))
        ber_res.append(float(test_ber))
        bler_res.append(float(test_bler))

        try:
            # Punctured metrics only exist when the block above succeeded.
            test_ber_punc /= num_test_batch
            test_bler_punc /= num_test_batch
            print('Punctured Test SNR', this_snr, 'with ber ',
                  float(test_ber_punc), 'with bler', float(test_bler_punc))
            ber_res_punc.append(float(test_ber_punc))
            bler_res_punc.append(float(test_bler_punc))
        except Exception:
            # FIX: narrowed from bare `except:`.
            print('No puncturation is there.')

    print('final results on SNRs ', snrs)
    print('BER', ber_res)
    print('BLER', bler_res)
    print('final results on punctured SNRs ', snrs)
    print('BER', ber_res_punc)
    print('BLER', bler_res_punc)

    # compute adjusted SNR. (some quantization might make power!=1.0)
    # NOTE: reuses num_test_batch from the last SNR loop above.
    enc_power = 0.0
    with torch.no_grad():
        for idx in range(num_test_batch):
            X_test = torch.randint(
                0,
                2, (args.batch_size, block_len, args.code_rate_k),
                dtype=torch.float)
            X_test = X_test.to(device)
            X_code = model.enc(X_test)
            enc_power += torch.std(X_code)
    enc_power /= float(num_test_batch)
    print('encoder power is', enc_power)
    adj_snrs = [snr_sigma2db(snr_db2sigma(item) / enc_power) for item in snrs]
    print('adjusted SNR should be', adj_snrs)
Ejemplo n.º 8
0
            # NOTE(review): fragment — `errED_fake1/2`, `errED_1`,
            # `fake_enc`, `enc_img`, `channel`, `criterion`, `u` come from
            # the enclosing (not visible) scope.
            errED_2 = errED_fake2 - errED_fake1

            errED = (errED_1 + errED_2) / 2.0

            optimizerED.step()

            ########## Train decoder, maximize decinfoloss ################################
            netDec.zero_grad()
            # forward pass fake encoded images
            #fake_enc = netE(fake_img.detach(),u)
            # add noise to image
            # detach(): decoder-only update; no gradients into the encoder.
            nfake_enc = channel(fake_enc.detach(), args.noise)
            foutput = netDec(nfake_enc)
            errDec_fakeenc = criterion(foutput,u)
            errDec_fakeenc.backward()
            fber = errors_ber(foutput, u)

            #forward pass real encoded images
            # add noise to enc image
            nenc_img = channel(enc_img.detach(), args.noise)
            routput = netDec(nenc_img)
            errDec_realenc = criterion(routput,u)
            errDec_realenc.backward()
            rber  = errors_ber(routput, u)

            # Report the mean of the fake- and real-encoding losses/BERs;
            # both backward() calls above already accumulated gradients.
            errDec = (errDec_fakeenc + errDec_realenc)/2.0
            #errDec.backward()

            ber = (fber.item() + rber.item())/2.0

            optimizerDec.step()
Ejemplo n.º 9
0
                        # NOTE(review): fragment — most names (dec_loss,
                        # enc_gan_imgs, decoder, filewriter, ...) come from
                        # the enclosing scope.
                        dec_loss.backward(retain_graph=True)
                        optimizer_Dec.step()

                    #train the enc generator now
                    # Composite encoder-generator loss: lambda_I weights
                    # image reconstruction (MSE), lambda_G the adversarial
                    # terms, and the remainder the decoding (BCE) terms.
                    encg_loss = args.lambda_I*(MSELoss(enc_gan_imgs,gen_imgs)+MSELoss(enc_real_imgs,real_imgs))/2 + \
                            args.lambda_G*(BCELoss(encdiscriminator(real_imgs),valid)+BCELoss(encdiscriminator(enc_gan_imgs),fake)+BCELoss(encdiscriminator(enc_real_imgs),fake))/3 + \
                            (1-args.lambda_I-args.lambda_G)*((BCELoss(decoder(enc_gan_imgs), u) + BCELoss(decoder(enc_real_imgs), u))/2)

                    encg_loss.backward(retain_graph=True)
                    optimizer_ganD.step()

                #calculate ber loss
                # Detach before the metric; mean BER over both decodes.
                decoded_info_1 = decoded_info_1.detach()
                decoded_info_2 = decoded_info_2.detach()
                u = u.detach()
                this_ber = (errors_ber(decoded_info_1, u) +
                            errors_ber(decoded_info_2, u)) / 2

                # Periodic console logging.
                if i % 100 == 0:

                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [Gan D loss: %f] [Gan G loss: %f] [Enc D loss: %f] [Enc G Loss: %f] [batch Dec BER: %f]"
                        % (epoch, args.num_epoch, i, len(train_dataloader),
                           gand_loss.item(), gang_loss.item(),
                           encd_loss.item(), encg_loss.item(), this_ber))

                batches_done = epoch * len(train_dataloader) + i
                #saving log
                # Write the CSV header once, on the very first batch.
                if batches_done == 0:
                    filewriter.writerow([
                        'Batchnumber', 'Gan D loss', 'Dec loss', 'Gan G loss',
Ejemplo n.º 10
0
                # NOTE(review): fragment of a test loop — `encoded`, `u`,
                # `netDec`, `channel`, `i`, `j` come from the enclosing scope.
                noisy_encoded = channel(encoded, (j * args.awgn), args)
                #save this encoded image
                save_image(encoded.data,
                           'images/test/' + args.model_id +
                           '%d_encoded.png' % i,
                           normalize=True,
                           pad_value=1.0)
                save_image(noisy_encoded.data,
                           'images/test/' + args.model_id + '%d_noisy.png' % i,
                           normalize=True,
                           pad_value=1.0)

                #decode message from encoded
                decoded_u = netDec(encoded)
                noisy_decoded = netDec(noisy_encoded)
                ber = errors_ber(decoded_u, u)
                nber = errors_ber(noisy_decoded, u)
                print(ber.item(), nber.item())
                #find in entire batch
                # NOTE(review): this inner `i` shadows the outer `i` used in
                # the filenames above — likely unintended; review.
                for i in range(len(decoded_u)):

                    #convert u to bits array
                    # NOTE(review): this single_u is immediately overwritten
                    # below — only the channel-averaged version is used.
                    single_u = torch.round(decoded_u[i]).view(-1)
                    noisysingle_u = torch.round(noisy_decoded[i]).view(-1)
                    #print(single_u.size())
                    #single_u = torch.round(torch.sum(torch.round(decoded_u[i]), dim=0) / 3.0).view(-1)
                    # Majority vote across the 3 channels, then flatten.
                    single_u = torch.round(
                        torch.sum(decoded_u[i], dim=0) / 3.0).view(-1)
                    #print(single_u.size())

                    print('this is something', errors_ber(single_u, u[i]))
Ejemplo n.º 11
0
            # NOTE(review): fragment of a decoder training step — `b_size`,
            # `ud`, `im_u`, `real_cpu`, `netE`, `netDec`, `criterion` come
            # from the enclosing (not visible) scope.
            netDec.zero_grad()

            # Fresh random message bits in {0, 1}.
            u = torch.randint(0,
                              2, (b_size, ud, im_u, im_u),
                              dtype=torch.float,
                              device=device)

            #forward pass real encoded images
            # add noise to enc image
            # NOTE(review): the channel/noise line is commented out, so the
            # decoder is trained on noiseless encodings here — confirm
            # that is intended.
            enc_img = netE(real_cpu, u)
            #nenc_img = channel(enc_img.detach(), args.awgn, args)
            routput = netDec(enc_img)
            errDec_realenc = criterion(routput, u)
            errDec_realenc.backward()
            rber = errors_ber(routput, u)

            errDec = errDec_realenc
            #errDec.backward()

            ber = rber.item()

            optimizerDec.step()

            #######################################################################
            # Train Encoder, minimize
            # Encoder should encode real+fake images
            #######################################################################
            netE.zero_grad()

            u = torch.randint(0,