Beispiel #1
0
        if i % config.make_img_samples == 0:
            for x in range(5):
                make_img_samples(G)


if __name__ == '__main__':
    # Entry point: train a GAN on CelebA using settings from the project
    # `config` module.
    dataset = CelebADataset()

    dataloader = get_dataloader(dataset)

    # Build generator/discriminator and move them to the configured device.
    G = Generator(config.latent_size).to(config.device)
    D = Discriminator().to(config.device)

    optim_G = torch.optim.AdamW(G.parameters(),
                                lr=config.lr,
                                betas=(0.5, 0.999))
    optim_D = torch.optim.AdamW(D.parameters(),
                                lr=config.lr,
                                betas=(0.5, 0.999))

    # Either resume models + optimizer state from a checkpoint, or start
    # from freshly initialized weights.
    if config.continue_training:
        G, optim_G, D, optim_D = load_models_with_optims(
            G, optim_G, D, optim_D, config.train_model_path, config.device)
    else:
        G.apply(weights_init)
        D.apply(weights_init)

    # NOTE(review): CrossEntropyLoss is unusual for GAN real/fake targets
    # (BCE is typical) — presumably D emits two class logits; confirm.
    criterion = nn.CrossEntropyLoss()

    train(dataloader, D, G, optim_D, optim_G, criterion)
Beispiel #2
0
        nn.init.constant_(m.bias.data, 0)
    
# Training hyper-parameters.
img_size = 64
batch_size = 32
lr = 0.0002
beta1 = 0.5  # Adam beta1 (used by the generator optimizer)

# Fall back to CPU when CUDA is unavailable instead of crashing at startup.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


netG = Generator(1).to(device)
netG.apply(weights_init)


netD = Discriminator(1).to(device)
netD.apply(weights_init)

criterion = nn.BCELoss()

# Setup optimizers: SGD for D (the Adam variant below was tried first), Adam for G
#optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))# SGD
optimizerD = optim.SGD(netD.parameters(), lr=0.01)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

print("Data setup")

# Patch loader over the RGB images and their ground-truth masks.
# NOTE(review): the mean is computed on the vaihingen set while the data
# comes from small_potsdam — confirm this mismatch is intentional.
rgb_wi_gt_data = loader(
    ["../../data/small_potsdam/rgb", "../../data/small_potsdam/y"],
    img_size,
    batch_size,
    transformations=[
        lambda x: x - load.get_mean("../../data/vaihingen/rgb"),
        rgb_to_binary,
    ])
data_wi_gt = rgb_wi_gt_data.generate_patch()


Beispiel #3
0
class VaeGanModule(pl.LightningModule):
    """VAE-GAN: a variational auto-encoder whose reconstructions are also
    scored by a GAN discriminator.

    Uses the legacy Lightning results API (``pl.TrainResult`` /
    ``pl.EvalResult``) and two optimizers (VAE params, then discriminator)
    dispatched via ``optimizer_idx`` in ``training_step``.
    """

    def __init__(self, hparams):
        """Build encoder, decoder, discriminator and loss functions.

        Args:
            hparams: argparse-style namespace; see ``add_argparse_args`` for
                the expected fields (ngf, z_dim, use_vgg, learning rates, ...).
        """
        super().__init__()
        self.hparams = hparams
        # Encoder
        self.encoder = Encoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.encoder.apply(weights_init)
        # assumes an int `gpus` hparam means "train on CUDA" — TODO confirm
        device = "cuda" if isinstance(self.hparams.gpus, int) else "cpu"
        # Decoder
        self.decoder = Decoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.decoder.apply(weights_init)
        # Discriminator
        self.discriminator = Discriminator()
        self.discriminator.apply(weights_init)

        # Losses
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionGAN = GANLoss(gan_mode="lsgan")

        if self.hparams.use_vgg:
            # kept inside a list — presumably to avoid registering the VGG
            # loss as a submodule of this LightningModule; confirm
            self.criterion_perceptual_style = [Perceptual_Loss(device)]

    @staticmethod
    def reparameterize(mu, logvar, mode='train'):
        """Sample z ~ N(mu, sigma^2) with the reparameterization trick.

        In 'train' mode returns mu + eps * std with eps ~ N(0, I);
        in any other mode returns the mean deterministically.
        """
        if mode == 'train':
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def discriminate(self, fake_image, real_image):
        """Score fake and real pairs with the discriminator.

        Pairs are concatenated along the channel dimension; returns
        (pred_fake, pred_real).
        """
        input_concat_fake = torch.cat(
            (fake_image.detach(), real_image),
            dim=1)  # not sure whether .detach() is necessary in lightning
        # NOTE(review): the "real" pair concatenates the real image with
        # itself — confirm this is the intended conditioning.
        input_concat_real = torch.cat((real_image, real_image), dim=1)

        return (self.discriminator(input_concat_fake),
                self.discriminator(input_concat_real))

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One optimization step: VAE when optimizer_idx == 0, discriminator
        when optimizer_idx == 1 (order matches ``configure_optimizers``).

        NOTE(review): ``result`` is unbound for any other optimizer_idx —
        confirm exactly two optimizers are ever configured.
        """
        x, _ = batch

        # train VAE
        if optimizer_idx == 0:

            # encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # decode
            fake_image = self.decoder(z_repar)

            # reconstruction (L1) and KL divergence to the unit Gaussian
            reconstruction_loss = self.criterionFeat(fake_image, x)
            kld_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) -
                                         log_var.exp())

            # Discriminate (no detach here: gradients must reach the decoder)
            input_concat_fake = torch.cat((fake_image, x), dim=1)
            pred_fake = self.discriminator(input_concat_fake)

            # Losses
            loss_G_GAN = self.criterionGAN(pred_fake, True)
            if self.hparams.use_vgg:
                loss_G_perceptual = \
                    self.criterion_perceptual_style[0](fake_image, x)
            else:
                loss_G_perceptual = 0.0
            # reconstruction weighted x20; other terms unweighted
            g_loss = (reconstruction_loss *
                      20) + kld_loss + loss_G_GAN + loss_G_perceptual

            # Results are collected in a TrainResult object
            result = pl.TrainResult(g_loss)
            result.log("rec_loss", reconstruction_loss * 10, prog_bar=True)
            result.log("loss_G_GAN", loss_G_GAN, prog_bar=True)
            result.log("kld_loss", kld_loss, prog_bar=True)
            result.log("loss_G_perceptual", loss_G_perceptual, prog_bar=True)

        # train Discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # Encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # Decode
            fake_image = self.decoder(z_repar)

            # how well can it label as real? (discriminate() detaches fakes)
            pred_fake, pred_real = self.discriminate(fake_image, x)

            # Fake loss
            d_loss_fake = self.criterionGAN(pred_fake, False)

            # Real Loss
            d_loss_real = self.criterionGAN(pred_real, True)

            # Total loss is average of prediction of fakes and reals
            loss_D = (d_loss_fake + d_loss_real) / 2

            # Results are collected in a TrainResult object
            result = pl.TrainResult(loss_D)
            result.log("loss_D_real", d_loss_real, prog_bar=True)
            result.log("loss_D_fake", d_loss_fake, prog_bar=True)

        return result

    def training_epoch_end(self, training_step_outputs):
        """Sample 16 images from the prior at epoch end, save them to disk,
        log them to TensorBoard, and log the mean epoch loss."""
        # `.minimize` is the loss tensor held by the legacy TrainResult;
        # its device tells us where to sample the latents
        z_appr = torch.normal(mean=0,
                              std=1,
                              size=(16, self.hparams.z_dim),
                              device=training_step_outputs[0].minimize.device)

        # Generate images from latent vector
        sample_imgs = self.decoder(z_appr)
        grid = torchvision.utils.make_grid(sample_imgs,
                                           normalize=True,
                                           range=(-1, 1))

        # where to save the image
        path = os.path.join(self.hparams.generated_images_folder,
                            f"generated_images_{self.current_epoch}.png")
        torchvision.utils.save_image(sample_imgs,
                                     path,
                                     normalize=True,
                                     range=(-1, 1))

        # Log images in tensorboard
        self.logger.experiment.add_image(f'generated_images', grid,
                                         self.current_epoch)

        # Epoch level metrics
        epoch_loss = torch.mean(
            torch.stack([x['minimize'] for x in training_step_outputs]))
        results = pl.TrainResult()
        results.log("epoch_loss", epoch_loss, prog_bar=False)

        return results

    def validation_step(self, batch, batch_idx):
        """Deterministic reconstruction (mean latent) scored with MSE;
        the loss is used as the checkpointing metric."""
        x, _ = batch

        # Encode
        mu, log_var = self.encoder(x)
        z_repar = VaeGanModule.reparameterize(mu, log_var)

        # Decode
        recons = self.decoder(z_repar)
        reconstruction_loss = nn.functional.mse_loss(recons, x)

        # Results are collected in a EvalResult object
        result = pl.EvalResult(checkpoint_on=reconstruction_loss)
        return result

    # NOTE(review): Lightning's test hook is named `test_step`; this
    # `testing_step` alias is likely never invoked by the framework — confirm.
    testing_step = validation_step

    def configure_optimizers(self):
        """Return [VAE optimizer, discriminator optimizer]; the list order
        defines `optimizer_idx` in ``training_step``."""
        params_vae = concat_generators(self.encoder.parameters(),
                                       self.decoder.parameters())
        opt_vae = torch.optim.Adam(params_vae,
                                   lr=self.hparams.learning_rate_vae)

        parameters_discriminator = self.discriminator.parameters()
        opt_d = torch.optim.Adam(parameters_discriminator,
                                 lr=self.hparams.learning_rate_d)

        return [opt_vae, opt_d]

    @staticmethod
    def add_argparse_args(parser):
        """Register this module's hyper-parameters on `parser` and return it."""

        parser.add_argument('--generated_images_folder',
                            required=False,
                            default="./output",
                            type=str)
        parser.add_argument('--ngf', type=int, default=128)
        parser.add_argument('--z_dim', type=int, default=128)
        parser.add_argument('--learning_rate_vae',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument('--learning_rate_d',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument("--use_vgg", action="store_true", default=False)

        return parser
Beispiel #4
0
class AdvGAN_Attack:
    """AdvGAN-style attack: trains a generator to emit bounded perturbations
    that protect a face image from deepfake swapping, while a discriminator
    keeps the perturbed images close to the real-image distribution."""

    def __init__(self, device, model, image_nc, box_min, box_max):
        """
        Args:
            device: torch device for the networks and loss targets.
            model: target model under attack (stored; not used in this snippet).
            image_nc: number of image channels (input and output are the same).
            box_min: lower clamp bound for adversarial pixels.
            box_max: upper clamp bound for adversarial pixels.
        """
        output_nc = image_nc
        self.device = device
        self.model = model
        self.input_nc = image_nc
        self.output_nc = output_nc

        self.box_min = box_min
        self.box_max = box_max

        self.gen_input_nc = image_nc
        self.netG = Generator(self.gen_input_nc, image_nc).to(device)
        self.netDisc = Discriminator(image_nc).to(device)

        # initialize all weights
        self.netG.apply(weights_init)
        self.netDisc.apply(weights_init)

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.001)
        self.optimizer_D = torch.optim.Adam(self.netDisc.parameters(), lr=0.001)

        # `models_path` is a module-level setting defined elsewhere in the file
        if not os.path.exists(models_path):
            os.makedirs(models_path)

    def train_batch(self, x, path, alignment):
        """One D step then one G step on a single batch.

        x is the cropped 256x256 face to perturb; `path`/`alignment` locate
        the crop inside the original image for compositing.
        TODO find a way to associate image with the image it came from
        (see if we can do it by filename)

        Returns (loss_D_GAN, loss_G_fake, loss_perturb, loss_adv) as floats.
        """
        perturbation = self.netG(x)

        # clipping trick: bound the perturbation, then the resulting image
        adv_images = torch.clamp(perturbation, -0.3, 0.3) + x
        adv_images = torch.clamp(adv_images, self.box_min, self.box_max)    # 256 x 256

        # NOTE(review): Ellipsis placeholder — crashes downstream until the
        # original deepfake image is actually loaded here.
        original_deepfake = y = ... # TODO load image

        # apply the adversarial crop back onto the full image
        protected_image = compose(adv_images, path, alignment)     # TODO: Original image size

        # ---- optimize D ----
        self.optimizer_D.zero_grad()
        pred_real = self.netDisc(x)
        loss_D_real = F.mse_loss(
            pred_real, torch.ones_like(pred_real, device=self.device)
        )
        loss_D_real.backward()

        # fakes are detached so the D step does not update G
        pred_fake = self.netDisc(adv_images.detach())
        loss_D_fake = F.mse_loss(
            pred_fake, torch.zeros_like(pred_fake, device=self.device)
        )
        loss_D_fake.backward()
        loss_D_GAN = loss_D_fake + loss_D_real
        self.optimizer_D.step()

        # ---- optimize G ----
        self.optimizer_G.zero_grad()

        # G's GAN loss: fool D into labelling adversarial images as real
        pred_fake = self.netDisc(adv_images)
        loss_G_fake = F.mse_loss(
            pred_fake, torch.ones_like(pred_fake, device=self.device)
        )
        loss_G_fake.backward(retain_graph=True)

        # L2 norm of the perturbation, averaged over the batch
        loss_perturb = torch.mean(
            torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1)
        )

        # adversarial objective: similarity between the deepfake of the
        # protected image and the original deepfake.
        # NOTE(review): torch.dot expects 1-D tensors but torch.norm(t, 2)
        # yields scalars here — confirm the intended similarity measure.
        y_ = swapfaces(protected_image)
        norm_similarity = torch.abs(torch.dot(torch.norm(y_, 2), torch.norm(original_deepfake, 2)))
        loss_adv = norm_similarity
        # fixed: retain the graph — loss_G.backward() below traverses it again
        loss_adv.backward(retain_graph=True)

        adv_lambda = 10
        pert_lambda = 1
        # NOTE(review): the separate backwards above make loss_G_fake and
        # loss_adv gradients accumulate on top of this one — confirm weighting.
        loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturb
        loss_G.backward()
        self.optimizer_G.step()

        return (
            loss_D_GAN.item(),
            loss_G_fake.item(),
            loss_perturb.item(),
            loss_adv.item(),
        )

    def train(self, train_dataloader, epochs):
        """Train for `epochs` epochs with a stepped learning-rate schedule
        (1e-3, then 1e-4 from epoch 50, then 1e-5 from epoch 80); saves G and
        D checkpoints every 5 epochs."""
        for epoch in range(1, epochs + 1):

            # learning-rate schedule: rebuild the optimizers at the steps
            if epoch == 50:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.0001)
                self.optimizer_D = torch.optim.Adam(
                    self.netDisc.parameters(), lr=0.0001
                )
            if epoch == 80:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.00001)
                self.optimizer_D = torch.optim.Adam(
                    self.netDisc.parameters(), lr=0.00001
                )
            loss_D_sum = 0
            loss_G_fake_sum = 0
            loss_perturb_sum = 0
            loss_adv_sum = 0
            for i, data in enumerate(train_dataloader, start=0):
                (images, _, paths) = data
                images = images.to(self.device)

                # NOTE(review): train_batch requires (x, path, alignment) —
                # this call is missing path/alignment and will raise a
                # TypeError; wire them through once compositing is finalized.
                (
                    loss_D_batch,
                    loss_G_fake_batch,
                    loss_perturb_batch,
                    loss_adv_batch,
                ) = self.train_batch(images)
                loss_D_sum += loss_D_batch
                loss_G_fake_sum += loss_G_fake_batch
                loss_perturb_sum += loss_perturb_batch
                loss_adv_sum += loss_adv_batch

            # print statistics
            num_batch = len(train_dataloader)
            print(
                "epoch %d:\nloss_D: %.3f, loss_G_fake: %.3f,\
             \nloss_perturb: %.3f, loss_adv: %.3f, \n"
                % (
                    epoch,
                    loss_D_sum / num_batch,
                    loss_G_fake_sum / num_batch,
                    loss_perturb_sum / num_batch,
                    loss_adv_sum / num_batch,
                )
            )

            # save generator and discriminator
            if epoch % 5 == 0:
                netG_file_name = models_path + "netG_epoch_" + str(epoch) + ".pth"
                torch.save(self.netG.state_dict(), netG_file_name)

                netD_file_name = models_path + "netD_epoch_" + str(epoch) + ".pth"
                # fixed: the discriminator attribute is netDisc (netD undefined)
                torch.save(self.netDisc.state_dict(), netD_file_name)
Beispiel #5
0
class Trainer():
    """CycleGAN-style trainer: generators G_X (Y->X) and G_Y (X->Y) plus
    discriminators D_X/D_Y trained with least-squares GAN losses, with
    optional cycle-consistency and identity losses."""

    def __init__(self, config):
        """Read hyper-parameters from `config`, then build models, data
        loaders, optimizers and the generator LR scheduler."""
        self.batch_size = config.batchSize
        self.epochs = config.epochs

        self.use_cycle_loss = config.cycleLoss
        self.cycle_multiplier = config.cycleMultiplier

        self.use_identity_loss = config.identityLoss
        self.identity_multiplier = config.identityMultiplier

        self.load_models = config.loadModels
        self.data_x_loc = config.dataX
        self.data_y_loc = config.dataY

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Fixed: these must be assigned BEFORE init_models()/init_data_loaders()
        # — checkpoint loading reads self.output_path and the transforms read
        # self.img_width/self.img_height (previously an AttributeError).
        self.output_path = "./outputs/"
        self.img_width = 256
        self.img_height = 256

        self.init_models()
        self.init_data_loaders()
        self.g_optimizer = torch.optim.Adam(list(self.G_X.parameters()) +
                                            list(self.G_Y.parameters()),
                                            lr=config.lr)
        self.d_optimizer = torch.optim.Adam(list(self.D_X.parameters()) +
                                            list(self.D_Y.parameters()),
                                            lr=config.lr)
        self.scheduler_g = torch.optim.lr_scheduler.StepLR(self.g_optimizer,
                                                           step_size=1,
                                                           gamma=0.95)

    # Load/Construct the models
    def init_models(self):
        """Build both generators and discriminators, load checkpoints when
        configured (otherwise initialize weights), and move all to device."""

        self.G_X = Generator(3, 3, nn.InstanceNorm2d)
        self.D_X = Discriminator(3)
        self.G_Y = Generator(3, 3, nn.InstanceNorm2d)
        self.D_Y = Discriminator(3)

        if self.load_models:
            # Checkpoints are loaded on CPU first; .to(device) happens below.
            self.G_X.load_state_dict(
                torch.load(self.output_path + "models/G_X",
                           map_location='cpu'))
            self.G_Y.load_state_dict(
                torch.load(self.output_path + "models/G_Y",
                           map_location='cpu'))
            self.D_X.load_state_dict(
                torch.load(self.output_path + "models/D_X",
                           map_location='cpu'))
            self.D_Y.load_state_dict(
                torch.load(self.output_path + "models/D_Y",
                           map_location='cpu'))
        else:
            self.G_X.apply(init_func)
            self.G_Y.apply(init_func)
            self.D_X.apply(init_func)
            self.D_Y.apply(init_func)

        self.G_X.to(self.device)
        self.G_Y.to(self.device)
        self.D_X.to(self.device)
        self.D_Y.to(self.device)

    # Initialize data loaders and image transformer
    def init_data_loaders(self):
        """Create shuffled ImageFolder loaders for domains X and Y with
        resize + [-1, 1] normalization."""

        transform = transforms.Compose([
            transforms.Resize((self.img_width, self.img_height)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        X_folder = torchvision.datasets.ImageFolder(self.data_x_loc, transform)
        self.X_loader = torch.utils.data.DataLoader(X_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

        Y_folder = torchvision.datasets.ImageFolder(self.data_y_loc, transform)
        self.Y_loader = torch.utils.data.DataLoader(Y_folder,
                                                    batch_size=self.batch_size,
                                                    shuffle=True)

    def save_models(self):
        """Persist all four model state dicts under output_path/models/."""
        torch.save(self.G_X.state_dict(), self.output_path + "models/G_X")
        torch.save(self.D_X.state_dict(), self.output_path + "models/D_X")
        torch.save(self.G_Y.state_dict(), self.output_path + "models/G_Y")
        torch.save(self.D_Y.state_dict(), self.output_path + "models/D_Y")

    # Reset gradients for all models, needed for between every training
    def reset_gradients(self):
        """Zero both optimizers' gradients before each backward pass."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    # Sample image from training data every %x epoch and save them for judging
    def save_samples(self, epoch):
        """Translate one image per domain and save originals + translations
        as PNGs for visual inspection."""
        x_iter = iter(self.X_loader)
        y_iter = iter(self.Y_loader)

        img_data_x, _ = next(x_iter)
        img_data_y, _ = next(y_iter)

        original_x = np.array(img_data_x[0])
        generated_y = np.array(
            self.G_Y(img_data_x[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]

        original_y = np.array(img_data_y[0])
        generated_x = np.array(
            self.G_X(img_data_y[0].view(1, 3, self.img_width,
                                        self.img_height).to(
                                            self.device)).cpu().detach())[0]

        def prepare_image(img):
            # CHW -> HWC, then map [-1, 1] back to [0, 1] for imsave
            img = img.transpose((1, 2, 0))
            return img / 2 + 0.5

        original_x = prepare_image(original_x)
        generated_y = prepare_image(generated_y)

        original_y = prepare_image(original_y)
        generated_x = prepare_image(generated_x)

        plt.imsave('./outputs/samples/original_X_{}.png'.format(epoch),
                   original_x)
        plt.imsave('./outputs/samples/original_Y_{}.png'.format(epoch),
                   original_y)

        plt.imsave('./outputs/samples/generated_X_{}.png'.format(epoch),
                   generated_x)
        plt.imsave('./outputs/samples/generated_Y_{}.png'.format(epoch),
                   generated_y)

    # Training loop
    def train(self):
        """Alternate discriminator and generator updates over paired batches
        from both domains; saves models each epoch and loss curves at the end."""
        D_X_losses = []
        D_Y_losses = []

        G_X_losses = []
        G_Y_losses = []

        for epoch in range(self.epochs):
            print("======")
            print("Epoch {}!".format(epoch + 1))

            # Track progress
            if epoch % 5 == 0:
                self.save_samples(epoch)

            # Paper reduces lr after 100 epochs
            if epoch > 100:
                self.scheduler_g.step()

            # zip() stops at the shorter of the two loaders
            for (data_X, _), (data_Y, _) in zip(self.X_loader, self.Y_loader):
                data_X = data_X.to(self.device)
                data_Y = data_Y.to(self.device)

                # =====================================
                # Train Discriminators
                # =====================================

                # Train fake X (LSGAN: fakes pushed toward 0)
                self.reset_gradients()
                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)
                d_x_f_loss = torch.mean(out_fake_X**2)
                d_x_f_loss.backward()
                self.d_optimizer.step()

                # Train fake Y
                self.reset_gradients()
                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)
                d_y_f_loss = torch.mean(out_fake_Y**2)
                d_y_f_loss.backward()
                self.d_optimizer.step()

                # Train true X (reals pushed toward 1)
                self.reset_gradients()
                out_true_X = self.D_X(data_X)
                d_x_t_loss = torch.mean((out_true_X - 1)**2)
                d_x_t_loss.backward()
                self.d_optimizer.step()

                # Train true Y
                self.reset_gradients()
                out_true_Y = self.D_Y(data_Y)
                d_y_t_loss = torch.mean((out_true_Y - 1)**2)
                d_y_t_loss.backward()
                self.d_optimizer.step()

                D_X_losses.append([
                    d_x_t_loss.cpu().detach().numpy(),
                    d_x_f_loss.cpu().detach().numpy()
                ])
                D_Y_losses.append([
                    d_y_t_loss.cpu().detach().numpy(),
                    d_y_f_loss.cpu().detach().numpy()
                ])

                # =====================================
                # Train GENERATORS
                # =====================================

                # Cycle X -> Y -> X
                self.reset_gradients()

                fake_Y = self.G_Y(data_X)
                out_fake_Y = self.D_Y(fake_Y)

                g_loss1 = torch.mean((out_fake_Y - 1)**2)
                # Fixed: default to zero so logging and the total below work
                # when the cycle loss is disabled (g_loss2 was unbound).
                g_loss2 = torch.tensor(0.0, device=self.device)
                if self.use_cycle_loss:
                    reconst_X = self.G_X(fake_Y)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_X - reconst_X)**2)

                G_Y_losses.append([
                    g_loss1.cpu().detach().numpy(),
                    g_loss2.cpu().detach().numpy()
                ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # Cycle Y -> X -> Y
                self.reset_gradients()

                fake_X = self.G_X(data_Y)
                out_fake_X = self.D_X(fake_X)

                g_loss1 = torch.mean((out_fake_X - 1)**2)
                g_loss2 = torch.tensor(0.0, device=self.device)
                if self.use_cycle_loss:
                    reconst_Y = self.G_Y(fake_X)
                    g_loss2 = self.cycle_multiplier * torch.mean(
                        (data_Y - reconst_Y)**2)

                G_X_losses.append([
                    g_loss1.cpu().detach().numpy(),
                    g_loss2.cpu().detach().numpy()
                ])
                g_loss = g_loss1 + g_loss2
                g_loss.backward()
                self.g_optimizer.step()

                # =====================================
                # Train image IDENTITY
                # =====================================

                if self.use_identity_loss:
                    self.reset_gradients()

                    # X should be same after G(X)
                    same_X = self.G_X(data_X)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_X - same_X)**2)
                    g_loss.backward()
                    self.g_optimizer.step()

                    # Y should be same after G(Y)
                    # Fixed: use G_Y here (was G_X, contradicting the comment
                    # and the symmetric identity objective).
                    same_Y = self.G_Y(data_Y)
                    g_loss = self.identity_multiplier * torch.mean(
                        (data_Y - same_Y)**2)
                    g_loss.backward()
                    self.g_optimizer.step()

            # Epoch done, save models
            self.save_models()

        # Save losses for analysis
        np.save(self.output_path + 'losses/G_X_losses.npy',
                np.array(G_X_losses))
        np.save(self.output_path + 'losses/G_Y_losses.npy',
                np.array(G_Y_losses))
        np.save(self.output_path + 'losses/D_X_losses.npy',
                np.array(D_X_losses))
        np.save(self.output_path + 'losses/D_Y_losses.npy',
                np.array(D_Y_losses))
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=400,
                        help='number of epochs of training')
    parser.add_argument('--batchSize',
                        type=int,
                        default=10,
                        help='size of the batches')
    parser.add_argument('--dataroot',
                        type=str,
                        default='datasets/genderchange/',
                        help='root directory of the dataset')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='initial learning rate')
    parser.add_argument(
        '--decay_epoch',
        type=int,
        default=100,
        help='epoch to start linearly decaying the learning rate to 0')
    parser.add_argument('--size',
                        type=int,
                        default=256,
                        help='size of the data crop (squared assumed)')
    parser.add_argument('--input_nc',
                        type=int,
                        default=3,
                        help='number of channels of input data')
    parser.add_argument('--output_nc',
                        type=int,
                        default=3,
                        help='number of channels of output data')
    parser.add_argument('--cuda',
                        action='store_true',
                        help='use GPU computation')
    parser.add_argument(
        '--n_cpu',
        type=int,
        default=8,
        help='number of cpu threads to use during batch generation')
    opt = parser.parse_args()
    print(opt)

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    ###### Definition of variables ######
    # Networks
    netG_A2B = Generator(opt.input_nc, opt.output_nc)
    netG_B2A = Generator(opt.output_nc, opt.input_nc)
    netD_A = Discriminator(opt.input_nc)
    netD_B = Discriminator(opt.output_nc)

    if opt.cuda:
        netG_A2B.cuda()
        netG_B2A.cuda()
        netD_A.cuda()
        netD_B.cuda()

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    # Lossess
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()

    # Optimizers & LR schedulers
    optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                                   netG_B2A.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D_A = torch.optim.Adam(netD_A.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))
    optimizer_D_B = torch.optim.Adam(netD_B.parameters(),
                                     lr=opt.lr,
                                     betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_A,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D_B,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Inputs & targets memory allocation
    Tensor = torch.cuda.FloatTensor if opt.cuda else torch.Tensor
    input_A = Tensor(opt.batchSize, opt.input_nc, opt.size, opt.size)
    input_B = Tensor(opt.batchSize, opt.output_nc, opt.size, opt.size)
    target_real = Variable(Tensor(opt.batchSize).fill_(1.0),
                           requires_grad=False)
    target_fake = Variable(Tensor(opt.batchSize).fill_(0.0),
                           requires_grad=False)

    fake_A_buffer = ReplayBuffer()
    fake_B_buffer = ReplayBuffer()

    # Dataset loader
    transforms_ = [
        transforms.Resize(int(opt.size * 1.2), Image.BICUBIC),
        transforms.CenterCrop(opt.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]

    dataloader = DataLoader(ImageDataset(opt.dataroot,
                                         transforms_=transforms_,
                                         unaligned=True),
                            batch_size=opt.batchSize,
                            shuffle=True,
                            num_workers=opt.n_cpu,
                            drop_last=True)

    # Plot Loss and Images in Tensorboard
    experiment_dir = 'logs/{}@{}'.format(
        opt.dataroot.split('/')[1],
        datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
    os.makedirs(experiment_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

    metric_dict = defaultdict(list)
    n_iters_total = 0

    ###################################
    ###### Training ######
    for epoch in range(opt.epoch, opt.n_epochs):
        # Snapshot where this epoch's metrics begin so the "*_epoch"
        # TensorBoard scalars average over THIS epoch only. Previously
        # np.mean(value) ran over the entire (never-reset) history, so the
        # per-epoch curves were cumulative means since the start of training.
        epoch_start_iter = n_iters_total

        for i, batch in enumerate(dataloader):

            # Set model input (input_A/input_B are preallocated tensors;
            # Variable is a no-op on modern PyTorch, kept for compatibility).
            real_A = Variable(input_A.copy_(batch['A']))
            real_B = Variable(input_B.copy_(batch['B']))

            ###### Generators A2B and B2A ######
            optimizer_G.zero_grad()

            # Identity loss: G_A2B(B) should equal B if real B is fed.
            same_B = netG_A2B(real_B)
            loss_identity_B = criterion_identity(
                same_B, real_B) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # G_B2A(A) should equal A if real A is fed.
            same_A = netG_B2A(real_A)
            loss_identity_A = criterion_identity(
                same_A, real_A) * 5.0  # [batchSize, 3, ImgSize, ImgSize]

            # GAN loss: generators try to make the discriminators say "real".
            fake_B = netG_A2B(real_A)
            pred_fake = netD_B(fake_B).view(-1)
            loss_GAN_A2B = criterion_GAN(pred_fake, target_real)  # [batchSize]

            fake_A = netG_B2A(real_B)
            pred_fake = netD_A(fake_A).view(-1)
            loss_GAN_B2A = criterion_GAN(pred_fake, target_real)  # [batchSize]

            # Cycle-consistency loss: A -> B -> A and B -> A -> B round trips.
            recovered_A = netG_B2A(fake_B)
            loss_cycle_ABA = criterion_cycle(
                recovered_A, real_A) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            recovered_B = netG_A2B(fake_A)
            loss_cycle_BAB = criterion_cycle(
                recovered_B, real_B) * 10.0  # [batchSize, 3, ImgSize, ImgSize]

            # Total generator loss (identity 5.0 + GAN 1.0 + cycle 10.0 weights).
            loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB

            loss_G.backward()
            optimizer_G.step()
            ###################################

            ###### Discriminator A ######
            optimizer_D_A.zero_grad()

            # Real loss
            pred_real = netD_A(real_A).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss: sample from the replay buffer; detach so no gradient
            # flows back into the generator during the D step.
            fake_A = fake_A_buffer.push_and_pop(fake_A)
            pred_fake = netD_A(fake_A.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss (0.5 halves D's learning speed relative to G).
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A.backward()

            optimizer_D_A.step()
            ###################################

            ###### Discriminator B ######
            optimizer_D_B.zero_grad()

            # Real loss
            pred_real = netD_B(real_B).view(-1)
            loss_D_real = criterion_GAN(pred_real, target_real)  # [batchSize]

            # Fake loss (replay buffer + detach, as for D_A).
            fake_B = fake_B_buffer.push_and_pop(fake_B)
            pred_fake = netD_B(fake_B.detach()).view(-1)
            loss_D_fake = criterion_GAN(pred_fake, target_fake)  # [batchSize]

            # Total loss
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B.backward()

            optimizer_D_B.step()
            ###################################

            # Record this iteration's scalar losses.
            metric_dict['loss_G'].append(loss_G.item())
            metric_dict['loss_G_identity'].append(loss_identity_A.item() +
                                                  loss_identity_B.item())
            metric_dict['loss_G_GAN'].append(loss_GAN_A2B.item() +
                                             loss_GAN_B2A.item())
            metric_dict['loss_G_cycle'].append(loss_cycle_ABA.item() +
                                               loss_cycle_BAB.item())
            metric_dict['loss_D'].append(loss_D_A.item() + loss_D_B.item())

            for title, value in metric_dict.items():
                writer.add_scalar('train/{}'.format(title), value[-1],
                                  n_iters_total)

            n_iters_total += 1

        # Epoch summary. Use .item() so plain floats are printed instead of
        # tensor reprs like "tensor(1.234, grad_fn=<AddBackward0>)".
        print("""
        -----------------------------------------------------------
        Epoch : {} Finished
        Loss_G : {}
        Loss_G_identity : {}
        Loss_G_GAN : {}
        Loss_G_cycle : {}
        Loss_D : {}
        -----------------------------------------------------------
        """.format(epoch, loss_G.item(),
                   (loss_identity_A + loss_identity_B).item(),
                   (loss_GAN_A2B + loss_GAN_B2A).item(),
                   (loss_cycle_ABA + loss_cycle_BAB).item(),
                   (loss_D_A + loss_D_B).item()))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        # Save model checkpoints when the generator loss is low, or
        # periodically late in training. (The two original branches saved the
        # exact same files, so they are merged into one condition.)
        if loss_G.item() < 2.5 or (epoch > 100 and epoch % 40 == 0):
            ckpt_dir = os.path.join(experiment_dir, str(epoch))
            os.makedirs(ckpt_dir, exist_ok=True)
            for fname, model in (('netG_A2B', netG_A2B),
                                 ('netG_B2A', netG_B2A),
                                 ('netD_A', netD_A),
                                 ('netD_B', netD_B)):
                torch.save(model.state_dict(),
                           os.path.join(ckpt_dir, fname + '.pth'))

        # Epoch-mean scalars: average only the entries appended this epoch.
        for title, value in metric_dict.items():
            writer.add_scalar("train/{}_epoch".format(title),
                              np.mean(value[epoch_start_iter:]), epoch)