# Example 1
class SAGAN_test(object):
    """Evaluation harness for a trained Self-Attention GAN generator.

    Loads a pretrained ``Generator``, draws real batches from the data
    loader, samples conditioned fake images, computes FID per batch and
    saves one grid of generated images to ``SAGAN_test.png``.

    NOTE(review): most hyper-parameters (``model``, ``adv_loss``,
    ``imsize``, ``z_dim``, ``batch_size``, ...) are read from the
    enclosing/global scope rather than passed as arguments — confirm they
    are defined before this class is instantiated.
    """

    def __init__(self, data_loader):

        # Data loader yielding (images, labels) batches.
        self.data_loader = data_loader
        # CIFAR-10 class-index -> human-readable name mapping.
        self.labels_dict = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck'
        }
        # exact model and loss (taken from enclosing scope)
        self.model = model
        self.adv_loss = adv_loss

        # Model hyper-parameters (taken from enclosing scope)
        self.imsize = imsize
        self.g_num = g_num
        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim
        self.d_conv_dim = d_conv_dim
        self.parallel = parallel

        self.d_iters = d_iters
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.pretrained_model = pretrained_model

        self.dataset = dataset

        self.model_save_path = model_save_path
        self.test_path = test_path

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.fid_model = FID("./log_path",
                             device)  # as a var log path not a string changed

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def test(self):
        """Evaluate the generator: average FID over a few batches and
        save one grid of fake images."""
        # Single iterator so successive loop iterations consume successive
        # batches. (The previous version rebuilt the iterator with
        # ``iter(self.data_loader)`` inside the loop, which restarted it
        # and returned the same first batch every time.)
        data_iter = iter(self.data_loader)

        self.G.eval()
        fid_scores = []
        n_batches = 2
        for i in range(n_batches):

            real_images, labels = next(data_iter)

            # Print class names for the final batch only, and only when
            # the batch is small enough to be readable.
            if i == n_batches - 1:
                if self.batch_size <= 10:
                    for label in labels:
                        print(self.labels_dict[label.item()])
                else:
                    print(
                        "Avoiding to print labels since batch size greater than 10"
                    )
            # Compute loss with real images
            real_images = tensor2var(real_images)
            labels = tensor2var(encode(labels))

            # Latent dim is hard-coded to 90, deliberately overriding
            # self.z_dim (kept from the original '#self.z_dim' note) —
            # it must match the trained generator's input size.
            z = tensor2var(torch.randn(real_images.size(0), 90))
            fake_images, gf1, gf2 = self.G(z, labels)
            fid_score = self.fid_model.compute_fid(real_images, fake_images)
            fid_scores.append(fid_score)

        # The per-batch score was already computed inside the loop; the
        # redundant extra compute_fid call on the last batch was removed.
        save_image(denorm(fake_images.data), 'SAGAN_test.png')
        print("Image saved as SAGAN_test.png")
        avg_fid_score = sum(fid_scores) / len(fid_scores)
        print("Average FID_score for SA GAN, for ", n_batches, " is:",
              avg_fid_score)

    def build_model(self):
        """Instantiate the generator on GPU when available, else CPU."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).to(device)

    def build_tensorboard(self):
        # Intentionally a no-op in the test harness.
        return
        #from logger import Logger
        #self.logger = Logger(self.log_path)

    def load_pretrained_model(self):
        """Load generator weights '<step>_G.pth' from model_save_path."""
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))

        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def save_sample(self, data_iter):
        """Save one real batch as 'real.png' for visual reference.

        NOTE(review): ``self.sample_path`` is never assigned in this
        class — confirm it is set elsewhere before calling this.
        """
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
# Example 2
    # NOTE(review): this span is the tail of a per-epoch training loop whose
    # header lies outside this chunk; 'epoch', 'G', 'save_noise', 'D_losses',
    # 'G_losses', 'writer', 'train_hist', 'fid_model', 'real_image',
    # 'G_result', 'per_epoch_ptime' and 'output_dir' come from that
    # enclosing scope — confirm against the full file.

    # Path for this epoch's generated-sample image, e.g. "<output_dir>5.png".
    fixed_p = output_dir + str(epoch + 1) + '.png'

    vutils.save_image(G(save_noise).detach(), fixed_p, normalize=True)

    # Mean discriminator/generator losses over the epoch, logged together
    # under one TensorBoard tag.
    num_info = {
        'Discriminator loss': torch.mean(torch.FloatTensor(D_losses)),
        'Generator loss': torch.mean(torch.FloatTensor(G_losses))
    }
    # Fresh fake samples from the fixed noise for visual inspection.
    fake_to_show = G(save_noise).detach()

    #tensorboard logging
    writer.add_scalars('Loss', num_info, epoch)
    writer.add_image('Fake Samples', fake_to_show[0].cpu())
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
    # FID is expensive — compute and log it only every 30th epoch.
    if epoch % 30 == 0:
        fid_score = fid_model.compute_fid(real_image, G_result)
        print("FID score", fid_score)
        writer.add_scalar('FID Score', fid_score, epoch)

# --- Post-training epilogue: timing summary, history dump, plot. ---

# Total wall-clock time for the whole run.
end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

# Mean per-epoch time, reported alongside the total.
avg_epoch_ptime = torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes']))
print("Avg one epoch ptime: %.2f, total %d epochs ptime: %.2f" %
      (avg_epoch_ptime, train_epoch, total_ptime))
writer.close()

# Persist the training history so it can be re-analysed later.
with open(report_dir + 'train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

# Render and save the loss/time curves.
show_train_hist(train_hist, save=True, path=report_dir + 'train_hist.png')