Example #1
0
def get_GAN_AB_model(folder_model, model_name, device):
    """Build a ``GeneratorResNet`` (G_AB) and load trained weights into it.

    Args:
        folder_model: Directory that contains the checkpoint file.
        model_name: File name of the checkpoint inside ``folder_model``.
        device: ``torch.device`` (or device string) used both as the
            ``map_location`` for the checkpoint and as the model's target
            device.

    Returns:
        The generator with its state dict restored, moved to ``device``.
    """
    import os

    # Must equal the value used when the G_AB checkpoint was trained,
    # otherwise load_state_dict reports missing/unexpected keys.
    n_residual_blocks = 9
    # NOTE(review): (3, 0) looks like a placeholder; presumably only
    # input_shape[0] (the channel count) is read by GeneratorResNet — confirm.
    G_AB = GeneratorResNet(input_shape=(3, 0), num_residual_blocks=n_residual_blocks)
    checkpoint_path = os.path.join(folder_model, model_name)
    G_AB.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Move to the requested device unconditionally instead of gating on a
    # separate CUDA flag; .to() is a no-op when the model is already there.
    G_AB = G_AB.to(device)
    return G_AB
Example #2
0
def get_generator_model():
    """Return the generator restored from ``PATH_G`` on the CPU, in eval mode.

    The architecture is built from the module-level ``img_shape``,
    ``residual_blocks`` and ``c_dim`` settings; the checkpoint tensors are
    mapped onto the CPU while loading.
    """
    cpu = torch.device('cpu')
    model = GeneratorResNet(
        img_shape=img_shape, res_blocks=residual_blocks, c_dim=c_dim
    )
    weights = torch.load(PATH_G, map_location=cpu)
    model.load_state_dict(weights)
    model.eval()
    return model
opt = parser.parse_args()

SCALE_FACTOR = opt.scale_factor
MODEL_NAME = opt.model_name
hr_shape = (opt.hr_height, opt.hr_width)

# Accumulators for per-image test metrics.
results = {'Test': {'psnr': [], 'ssim': []}}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = GeneratorResNet()
# Only hand DataParallel GPU ids that exist on this host: passing an id
# >= torch.cuda.device_count() makes the wrapper raise at construction time.
# The wrapper itself is kept so checkpoint keys (which presumably carry the
# "module." prefix of a DataParallel model) still match.
available_ids = [i for i in (0, 1, 2) if i < torch.cuda.device_count()]
generator = nn.DataParallel(generator, device_ids=available_ids or None)
generator.to(device)

# map_location lets a checkpoint written on a GPU be restored on a CPU-only host.
generator.load_state_dict(
    torch.load("saved_models/" + MODEL_NAME, map_location=device))
generator.eval()

test_dataloader = DataLoader(
    TestImageDataset("../My_dataset/single_channel_100000/%s" %
                     opt.test_dataset_name,
                     hr_shape=hr_shape,
                     scale_factor=SCALE_FACTOR),
    batch_size=1,
    shuffle=False,
    num_workers=opt.n_cpu,
)

test_bar = tqdm(test_dataloader, desc='[testing datasets]')

test_out_path = 'testing_results/SRF_' + str(SCALE_FACTOR) + '/'
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_cycle.cuda()

if opt.is_print:
    print_network(generator, 'Generator')
    print_network(discriminator, 'Discriminator')

# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

if opt.epoch != 0:
    # Load pre-trained models. Without CUDA, remap stored tensors to the CPU
    # so checkpoints written by a GPU run can still be deserialized.
    generator.load_state_dict(
        torch.load("saved_models/generator_%d.pth" % opt.epoch,
                   map_location=None if cuda else 'cpu'))
    discriminator.load_state_dict(
        torch.load("saved_models/discriminator_%d.pth" % opt.epoch,
                   map_location=None if cuda else 'cpu'))
else:
    # Fresh run: initialize weights instead of resuming.
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# Learning-rate schedulers are currently disabled; to enable them, restore:
# lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
#     optimizer_G,
#     lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
# lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
#     optimizer_D,
#     lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

# Configure transforms
Example #5
0
parser = argparse.ArgumentParser()
parser.add_argument('--check_point', type=str, default='saved_models/G_AB_10.pth',
                    help='check point from which load trained model')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--A_file', type=str, default='test.png', help='path of the data')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--gpu_id', type=int, default=-1, help='GPU id')
opt = parser.parse_args()
# A negative --gpu_id (the default) selects CPU inference.
cuda = opt.gpu_id > -1
if cuda:
    # Make the requested GPU current so the bare .cuda() call and
    # torch.cuda.FloatTensor below target opt.gpu_id instead of device 0.
    torch.cuda.set_device(opt.gpu_id)

# Load pretrained model G_AB
G_AB = GeneratorResNet()
if cuda:
    G_AB = G_AB.cuda()
# Remap stored tensors onto the selected device, so a GPU-saved checkpoint
# can also be used for CPU inference.
G_AB.load_state_dict(
    torch.load(opt.check_point,
               map_location=('cuda:%d' % opt.gpu_id) if cuda else 'cpu'))
# Inference only: put any train/eval-dependent layers into eval mode.
G_AB.eval()

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Image transformations: resize, convert to a [0, 1] tensor, then map each
# RGB channel to [-1, 1].
transforms_ = [transforms.Resize((opt.img_height, opt.img_width)),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
img_transformer = transforms.Compose(transforms_)

# Test data: a single RGB image reshaped into a batch of one.
img = img_transformer(Image.open(opt.A_file).convert("RGB"))
real_A = Variable(img.reshape(1, 3, opt.img_height, opt.img_width).type(Tensor))

# Translating the test image needs no autograd graph.
with torch.no_grad():
    img_sample = G_AB(real_A)
Example #6
0
def main():
    """Train a CycleGAN (G_AB, G_BA, D_A, D_B) on ``opt.dataset_name``.

    Hyperparameters come from the module-level ``opt`` namespace and the
    losses from the module-level ``criterion_GAN``, ``criterion_cycle`` and
    ``criterion_identity``. Every ``opt.sample_interval`` batches a grid of
    test-set translations is written to ``images/<dataset_name>/``; every
    ``opt.checkpoint_interval`` epochs (unless it is -1) the four networks'
    state dicts are written to ``saved_models/<dataset_name>/``.
    """
    import os

    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # save_image / torch.save write into these directories but do not create them.
    os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
    os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        G_AB.load_state_dict(
            torch.load("saved_models/%s/G_AB_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(
            torch.load("saved_models/%s/G_BA_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(
            torch.load("saved_models/%s/D_A_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(
            torch.load("saved_models/%s/D_B_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # Optimizers: one for both generators, one per discriminator.
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Saves a generated sample from the test set"""
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        # Sampling is inference only; don't build an autograd graph.
        with torch.no_grad():
            real_A = Variable(imgs["A"].type(Tensor))
            fake_B = G_AB(real_A)
            real_B = Variable(imgs["B"].type(Tensor))
            fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid,
                   "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(
                np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(
                np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            # sample_images() may have left the generators in eval mode.
            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left (extrapolated from the last batch)
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(
                G_AB.state_dict(),
                "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                G_BA.state_dict(),
                "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_A.state_dict(),
                "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_B.state_dict(),
                "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
Example #7
0
# L1 losses used for the cycle and identity terms.
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if opt.epoch != 0:
    # Load pretrained models. Without CUDA, remap stored tensors to the CPU so
    # checkpoints saved by a GPU run can still be deserialized.
    G_AB.load_state_dict(
        torch.load('saved_models/%s/G_AB_%d.pth' %
                   (opt.model_name, opt.epoch),
                   map_location=None if cuda else 'cpu'))
    G_BA.load_state_dict(
        torch.load('saved_models/%s/G_BA_%d.pth' %
                   (opt.model_name, opt.epoch),
                   map_location=None if cuda else 'cpu'))
    D_A.load_state_dict(
        torch.load('saved_models/%s/D_A_%d.pth' % (opt.model_name, opt.epoch),
                   map_location=None if cuda else 'cpu'))
    D_B.load_state_dict(
        torch.load('saved_models/%s/D_B_%d.pth' % (opt.model_name, opt.epoch),
                   map_location=None if cuda else 'cpu'))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
 def _define_generator(self, path_to_model):
     """Create a ``GeneratorResNet`` and restore weights from *path_to_model*."""
     model = GeneratorResNet()
     weights = torch.load(path_to_model)
     model.load_state_dict(weights)
     return model
Example #9
0
# Two discriminators built from the same input_shape. G_AB, G_BA, input_shape,
# cuda, opt and the criterion_* losses are presumably defined before this
# fragment — confirm.
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

# Move all networks and loss modules to the default CUDA device when enabled.
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

# Resume from the checkpoints of epoch opt.epoch, or start from fresh weights.
if opt.epoch != 0:
    # Load pretrained models
    # NOTE(review): no map_location is passed to torch.load, so checkpoints
    # saved on a GPU may fail to load when cuda is False.
    G_AB.load_state_dict(
        torch.load("saved_models/%s/G_AB_%d.pth" %
                   (opt.dataset_name, opt.epoch)))
    G_BA.load_state_dict(
        torch.load("saved_models/%s/G_BA_%d.pth" %
                   (opt.dataset_name, opt.epoch)))
    D_A.load_state_dict(
        torch.load("saved_models/%s/D_A_%d.pth" %
                   (opt.dataset_name, opt.epoch)))
    D_B.load_state_dict(
        torch.load("saved_models/%s/D_B_%d.pth" %
                   (opt.dataset_name, opt.epoch)))
else:
    # Initialize weights
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)