示例#1
0
    # For every SNR in `noise_ratio`, average two metrics over
    # `num_test_epochs` random message batches (all under torch.no_grad()):
    #   * BER between the random bit tensor and the final decoder output;
    #   * MSE between netCCE's output and netIDec's output, divided by
    #     img_channels * img_size**2.
    # Per-SNR averages are appended to `real_ber` / `real_mse` and printed.
    real_ber = []
    real_mse = []
    for snr in noise_ratio:
        ber_total = 0.0
        mse_total = 0.0
        for epoch in range(num_test_epochs):
            with torch.no_grad():
                message_shape = (args.batch_size, args.code_rate_k,
                                 args.img_size, args.img_size)
                # Random {0, 1} bits stored as floats.
                u_message = torch.randint(0, 2, message_shape,
                                          dtype=torch.float, device=device)

                # Encode (presumably channel code, then embed into the
                # image batch `real_cpu`) and pass through an AWGN channel.
                x_message = netCCE(u_message, real_cpu)
                encoded = netIE(x_message, real_cpu)
                received = channel_test(encoded, snr, args, mode='awgn')

                # Two decoding stages: netIDec's output is fed, together
                # with the received tensor, into netCDec.
                decoded_code = netIDec(received)
                decoded_bits = netCDec(received, decoded_code)

                ber_total += errors_ber(decoded_bits, u_message).item()
                mse_value = MSE_loss(x_message, decoded_code).item()
                mse_total += mse_value / (args.img_channels *
                                          args.img_size**2)

        real_avg_ber = ber_total / num_test_epochs
        real_avg_mse = mse_total / num_test_epochs
        real_ber.append(real_avg_ber)
        real_mse.append(real_avg_mse)

        print('AWGN ber for snr : %.2f \t is %.4f (ber) \t is %.4f (MSE)' %
              (snr, real_avg_ber, real_avg_mse))
示例#2
0
    # For every SNR in `noise_ratio`, average the bit-error rate (BER) over
    # `num_test_epochs` random message batches (all under torch.no_grad()).
    # Two measurements are taken per batch:
    #   * fake: message hidden in images produced by netG from random noise;
    #   * real: the same message hidden in the image batch `real_cpu`
    #     (provided by the enclosing code, not sampled here).
    # Per-SNR averages are appended to `real_ber` / `fake_ber` and printed.
    # NOTE(review): `u` uses torch.randint's default integer dtype (no dtype
    # given) - confirm netE and errors_ber expect that.
    real_ber, fake_ber = [], []
    for snr in noise_ratio:
        real_this_ber, fake_this_ber = 0.0, 0.0
        for _ in range(num_test_epochs):
            with torch.no_grad():
                noise = torch.randn(args.batch_size, nz, 1, 1, device=device)
                u = torch.randint(0,
                                  2, (args.batch_size, ud, im_u, im_u),
                                  device=device)
                fake = netG(noise)

                # Generated images: encode -> AWGN channel -> decode.
                fake_enc = netE(fake, u)
                noisy_fake = channel_test(fake_enc, snr, args, mode='awgn')
                fake_output = netDec(noisy_fake)
                fake_this_ber += errors_ber(fake_output, u).item()

                # Real images: the channel must carry `real_enc` (not
                # `fake_enc`); otherwise this "real" BER just re-measures
                # the generated images.
                real_enc = netE(real_cpu, u)
                noisy_real = channel_test(real_enc, snr, args, mode='awgn')
                real_output = netDec(noisy_real)
                real_this_ber += errors_ber(real_output, u).item()

        real_avg_ber = real_this_ber / num_test_epochs
        fake_avg_ber = fake_this_ber / num_test_epochs
        real_ber.append(real_avg_ber)
        fake_ber.append(fake_avg_ber)

        print('AWGN ber for snr : %.2f \t is %.4f (real) and  %.4f (fake)' %
              (snr, real_avg_ber, fake_avg_ber))
示例#3
0
        for epoch in range(args.num_epoch):
            # For each batch in the dataloader
            for i, data in enumerate(dataloader, 0):
                # Format batch
                real_cpu = data[0].to(device)
                if real_cpu.shape[0] != args.batch_size:
                    print('batch size mismatch!')
                    continue

                #######################################################################
                # Train enc discriminator (D2), maximize log(D(x)) + log(1 - D(G(z))),
                # Idea is to discriminate encoded image/non-info image.
                #######################################################################
                for run in range(args.num_train_D2):
                    netED.zero_grad()
                    nreal_cpu = channel_test(real_cpu, args)
                    routput = netED(real_cpu).view(-1)
                    errED_real1 = BCE_loss(routput, labelr)
                    errED_real1.backward()

                    if args.use_cce:
                        u_message = torch.randint(
                            0,
                            2, (args.batch_size, args.code_rate_k,
                                args.img_size, args.img_size),
                            dtype=torch.float,
                            device=device)
                        x_message = netCCE(u_message)
                    else:  # pre-train the whole channel, no cce. co
                        u_message = torch.randint(
                            0,