Пример #1
0
    discriminator = discriminator.cuda()
    criterion_cycle.cuda()

# Pick the tensor constructor that matches the device in use.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# Optionally report both network architectures.
if opt.is_print:
    print_network(generator, 'Generator')
    print_network(discriminator, 'Discriminator')

if opt.epoch == 0:
    # Fresh run: apply the custom weight initialiser to both networks.
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)
else:
    # Resuming: restore the checkpoints written for epoch `opt.epoch`.
    generator.load_state_dict(
        torch.load("saved_models/generator_%d.pth" % opt.epoch)
    )
    discriminator.load_state_dict(
        torch.load("saved_models/discriminator_%d.pth" % opt.epoch)
    )

# One Adam optimizer per network, both using the same learning rate and betas.
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    betas=(opt.b1, opt.b2),
    lr=opt.lr,
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    betas=(opt.b1, opt.b2),
    lr=opt.lr,
)
'''
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G,
                                                   lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D,
                                                   lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
'''
# Configure transforms
train_transforms = [
    transforms.Resize(int(1.12 * opt.img_height), Image.BICUBIC),
    transforms.CenterCrop(opt.img_height),
Пример #2
0
def main():
    """Train two generators and two discriminators on unpaired image domains.

    All settings are read from the module-level ``opt`` object, and the loss
    modules ``criterion_GAN``, ``criterion_cycle`` and ``criterion_identity``
    are taken from module scope as well. The function:

    * builds ``G_AB``/``G_BA``/``D_A``/``D_B`` and either restores them from
      ``saved_models/<dataset_name>/`` (when ``opt.epoch != 0``) or applies
      ``weights_init_normal``;
    * creates Adam optimizers and ``LambdaLR`` schedulers;
    * iterates over the training loader, updating the generators first and
      then each discriminator on replay-buffered fakes;
    * writes a sample image grid every ``opt.sample_interval`` batches and
      model checkpoints every ``opt.checkpoint_interval`` epochs (disabled
      when that interval is ``-1``).

    Returns:
        None. Results are written to disk and progress to stdout.
    """
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_height, opt.img_width)

    # Initialize generator and discriminator
    G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
    G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
    D_A = Discriminator(input_shape)
    D_B = Discriminator(input_shape)

    if cuda:
        G_AB = G_AB.cuda()
        G_BA = G_BA.cuda()
        D_A = D_A.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_cycle.cuda()
        criterion_identity.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        G_AB.load_state_dict(
            torch.load("saved_models/%s/G_AB_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        G_BA.load_state_dict(
            torch.load("saved_models/%s/G_BA_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_A.load_state_dict(
            torch.load("saved_models/%s/D_A_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
        D_B.load_state_dict(
            torch.load("saved_models/%s/D_B_%d.pth" %
                       (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        G_AB.apply(weights_init_normal)
        G_BA.apply(weights_init_normal)
        D_A.apply(weights_init_normal)
        D_B.apply(weights_init_normal)

    # Optimizers: both generators share one optimizer, each discriminator
    # has its own.
    optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(),
                                                   G_BA.parameters()),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))
    optimizer_D_B = torch.optim.Adam(D_B.parameters(),
                                     lr=opt.lr,
                                     betas=(opt.b1, opt.b2))

    # Learning rate update schedulers
    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    # Buffers of previously generated samples
    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Image transformations
    transforms_ = [
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("../../data/%s" % opt.dataset_name,
                     transforms_=transforms_,
                     unaligned=True,
                     mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

    def sample_images(batches_done):
        """Save a grid of real and translated test images.

        The output file is ``images/<dataset_name>/<batches_done>.png``.
        Both generators are switched to eval mode here; the training loop
        puts them back into train mode at the start of each iteration.
        """
        imgs = next(iter(val_dataloader))
        G_AB.eval()
        G_BA.eval()
        # Pure inference: without no_grad every sample would build and keep
        # an autograd graph for both generators.
        with torch.no_grad():
            real_A = Variable(imgs["A"].type(Tensor))
            fake_B = G_AB(real_A)
            real_B = Variable(imgs["B"].type(Tensor))
            fake_A = G_BA(real_B)
        # Arrange images along x-axis
        real_A = make_grid(real_A, nrow=5, normalize=True)
        real_B = make_grid(real_B, nrow=5, normalize=True)
        fake_A = make_grid(fake_A, nrow=5, normalize=True)
        fake_B = make_grid(fake_B, nrow=5, normalize=True)
        # Arrange images along y-axis
        image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
        save_image(image_grid,
                   "images/%s/%s.png" % (opt.dataset_name, batches_done),
                   normalize=False)

    # ----------
    #  Training
    # ----------
    # The loader length does not change during training; look it up once.
    n_batches = len(dataloader)

    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths, sized from the current batch so a
            # smaller final batch still gets matching targets.
            valid = Variable(Tensor(
                np.ones((real_A.size(0), *D_A.output_shape))),
                             requires_grad=False)
            fake = Variable(Tensor(
                np.zeros((real_A.size(0), *D_A.output_shape))),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples); detach so
            # no gradient flows back into the generators.
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left: remaining batches times the
            # duration of the batch that just finished.
            batches_done = epoch * n_batches + i
            batches_left = opt.n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    n_batches,
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_cycle.item(),
                    loss_identity.item(),
                    time_left,
                ))

            # If at sample interval save image
            if batches_done % opt.sample_interval == 0:
                sample_images(batches_done)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(
                G_AB.state_dict(),
                "saved_models/%s/G_AB_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                G_BA.state_dict(),
                "saved_models/%s/G_BA_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_A.state_dict(),
                "saved_models/%s/D_A_%d.pth" % (opt.dataset_name, epoch))
            torch.save(
                D_B.state_dict(),
                "saved_models/%s/D_B_%d.pth" % (opt.dataset_name, epoch))
Пример #3
0
class Trainer():
    """Set up and run unpaired A<->B image-translation training.

    Construction builds two generators (``G_AB``, ``G_BA``), two
    discriminators (``D_A``, ``D_B``), their losses, optimizers,
    learning-rate schedulers, replay buffers and the train/validation data
    loaders. ``train_epoch`` then performs one pass over the training loader.
    """

    def __init__(self, opt):
        """Build models, losses, optimizers and data loaders.

        Args:
            opt: options object. Attributes read here: ``model_name``,
                ``gpu_id``, ``n_residual_blocks``, ``large_patch``, ``lr``,
                ``b1``, ``b2``, ``n_epochs``, ``decay_epoch``, ``lambda_id``
                and ``rotate_degree``. ``train_epoch`` additionally reads
                ``sample_interval`` and ``checkpoint_interval``.
        """
        self.config = opt

        # Make output dirs
        os.makedirs('saved_models/%s' % (opt.model_name), exist_ok=True)
        os.makedirs('images/%s' % (opt.model_name), exist_ok=True)

        # Any non-negative gpu_id selects CUDA.
        self.cuda = opt.gpu_id > -1

        # Gs and Ds
        self.G_AB = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        self.G_BA = GeneratorResNet(res_blocks=opt.n_residual_blocks)
        if opt.large_patch:
            self.D_A = LargePatchDiscriminator()
            self.D_B = LargePatchDiscriminator()
        else:
            self.D_A = Discriminator()
            self.D_B = Discriminator()

        # Per-sample shape of the adversarial targets built in train_epoch.
        # NOTE(review): presumably equals the chosen discriminator's output
        # shape for 256x256 inputs -- confirm against the model definitions.
        if opt.large_patch:
            self.patch = (1, 64, 64)
        else:
            self.patch = (1, 16, 16)

        # Weight init
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Loss
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        if self.cuda:
            self.G_AB = self.G_AB.cuda()
            self.G_BA = self.G_BA.cuda()
            self.D_A = self.D_A.cuda()
            self.D_B = self.D_B.cuda()
            self.criterion_GAN = self.criterion_GAN.cuda()
            self.criterion_cycle = self.criterion_cycle.cuda()
            self.criterion_identity = self.criterion_identity.cuda()

        # Optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.G_AB.parameters(), self.G_BA.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # The discriminator optimizer must own the parameters of BOTH
        # discriminators; with only D_A registered, D_B would never be
        # updated even though its loss is back-propagated.
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
                                            lr=opt.lr,
                                            betas=(opt.b1, opt.b2))
        # Learning rate update schedulers
        self.lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=LambdaLR(opt.n_epochs, 0, opt.decay_epoch).step)
        self.Tensor = torch.cuda.FloatTensor if self.cuda else torch.Tensor
        # Loss weights: the identity weight is expressed relative to the
        # cycle weight.
        self.lambda_cyc = 10
        self.lambda_id = opt.lambda_id * self.lambda_cyc

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations (domain A gets an extra centre crop before
        # resizing; both end in a random 256x256 crop).
        A_transforms_ = [
            transforms.CenterCrop((178, 178)),
            transforms.Resize((300, 300)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        B_transforms_ = [
            transforms.Resize((360, 360)),
            transforms.RandomCrop((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=opt.rotate_degree,
                                    fillcolor=(255, 255, 255)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
        # Training data loader
        self.train_dataloader = DataLoader(
            ImageDataset("./data/",
                         A_transforms_=A_transforms_,
                         B_transforms_=B_transforms_),
            batch_size=1,
            shuffle=True,
        )
        # Test data loader
        self.val_dataloader = DataLoader(ImageDataset(
            "./data/",
            A_transforms_=A_transforms_,
            B_transforms_=B_transforms_,
            mode='test'),
                                         batch_size=1)

    def train_epoch(self, epoch):
        """Run one training epoch and step the learning-rate schedulers.

        For every batch the generators are updated first (adversarial +
        cycle + identity losses), then both discriminators are updated on
        real images and replay-buffered fakes. A sample image is written
        every ``config.sample_interval`` batches, and ``G_AB`` is
        checkpointed when ``epoch`` is a multiple of
        ``config.checkpoint_interval`` (unless that interval is ``-1``).

        Args:
            epoch: zero-based index of the current epoch; used for the
                progress log, the sample schedule and checkpoint file names.
        """
        # The loader length is constant for the whole epoch.
        n_batches = len(self.train_dataloader)

        prev_time = time.time()
        for i, batch in enumerate(self.train_dataloader):

            # Model input
            real_A = Variable(batch['A'].type(self.Tensor))
            real_B = Variable(batch['B'].type(self.Tensor))

            # Adversarial ground truths

            valid = Variable(self.Tensor(np.ones(
                (real_A.size(0), *self.patch))),
                             requires_grad=False)
            fake = Variable(self.Tensor(np.zeros(
                (real_A.size(0), *self.patch))),
                            requires_grad=False)

            #  Train Generators

            self.optimizer_G.zero_grad()

            # GAN loss
            fake_B = self.G_AB(real_A)
            loss_GAN_AB = self.criterion_GAN(self.D_B(fake_B), valid)
            fake_A = self.G_BA(real_B)
            loss_GAN_BA = self.criterion_GAN(self.D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = self.G_BA(fake_B)
            loss_cycle_A = self.criterion_cycle(recov_A, real_A)
            recov_B = self.G_AB(fake_A)
            loss_cycle_B = self.criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Identity loss

            loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
            loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # Total loss
            loss_G = loss_GAN + self.lambda_cyc * loss_cycle + self.lambda_id * loss_identity
            loss_G.backward()
            self.optimizer_G.step()

            #  Train Discriminators
            #
            # optimizer_D owns the parameters of D_A and D_B. Gradients are
            # cleared once (this also discards what the generator step left
            # on the discriminators), each loss is back-propagated into its
            # own network, and a single step updates both. Stepping once
            # keeps the optimizer's per-parameter state advancing exactly
            # once per iteration for each discriminator.
            self.optimizer_D.zero_grad()

            # Discriminator A: real loss
            loss_real = self.criterion_GAN(self.D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples); detach
            # so nothing flows back into the generators.
            fake_A_ = self.fake_A_buffer.push_and_pop(fake_A)
            loss_fake = self.criterion_GAN(self.D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2
            loss_D_A.backward()

            # Discriminator B: same recipe on domain B
            loss_real = self.criterion_GAN(self.D_B(real_B), valid)
            fake_B_ = self.fake_B_buffer.push_and_pop(fake_B)
            loss_fake = self.criterion_GAN(self.D_B(fake_B_.detach()), fake)
            loss_D_B = (loss_real + loss_fake) / 2
            loss_D_B.backward()

            self.optimizer_D.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # Determine approximate time left: remaining batches times the
            # duration of the batch that just finished.
            batches_done = epoch * n_batches + i
            batches_left = self.config.n_epochs * n_batches - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                % (epoch, self.config.n_epochs, i, n_batches,
                   loss_D.item(), loss_G.item(), loss_GAN.item(),
                   loss_cycle.item(), loss_identity.item(), time_left))

            if batches_done % self.config.sample_interval == 0:
                # Sample a picture. Inference only, so skip autograd
                # bookkeeping instead of building a graph that is thrown away.
                imgs = next(iter(self.val_dataloader))
                with torch.no_grad():
                    real_A = Variable(imgs['A'].type(self.Tensor))
                    fake_B = self.G_AB(real_A)
                    real_B = Variable(imgs['B'].type(self.Tensor))
                    fake_A = self.G_BA(real_B)
                img_sample = torch.cat(
                    (real_A.data, fake_B.data, real_B.data, fake_A.data), 0)
                save_image(img_sample,
                           'images/%s/%s.png' %
                           (self.config.model_name, batches_done),
                           nrow=4,
                           normalize=True)

        self.lr_scheduler_G.step()
        self.lr_scheduler_D.step()

        if self.config.checkpoint_interval != -1 and epoch % self.config.checkpoint_interval == 0:
            # Only the A->B generator's weights are written here.
            torch.save(
                self.G_AB.state_dict(), 'saved_models/%s/G_AB_%d.pth' %
                (self.config.model_name, epoch))
Пример #4
0
if opt.epoch == 0:
    # Fresh run: apply the custom weight initialiser to all four networks.
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)
else:
    # Resuming: restore every network from the checkpoint of epoch `opt.epoch`.
    G_AB.load_state_dict(torch.load(
        'saved_models/%s/G_AB_%d.pth' % (opt.model_name, opt.epoch)
    ))
    G_BA.load_state_dict(torch.load(
        'saved_models/%s/G_BA_%d.pth' % (opt.model_name, opt.epoch)
    ))
    D_A.load_state_dict(torch.load(
        'saved_models/%s/D_A_%d.pth' % (opt.model_name, opt.epoch)
    ))
    D_B.load_state_dict(torch.load(
        'saved_models/%s/D_B_%d.pth' % (opt.model_name, opt.epoch)
    ))

# Loss weights: the identity weight is given relative to the cycle weight.
lambda_cyc = 10
lambda_id = lambda_cyc * opt.lambda_id

# A single Adam optimizer updates both generators together.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    betas=(opt.b1, opt.b2),
    lr=opt.lr,
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(),
                                 lr=opt.lr,