def train_cycle_gan(data_root, semi_supervised=False):
    """Train the CycleGAN on images under ``data_root``, resuming from the
    last checkpoint recorded by the epoch tracker."""
    opt = get_opts()

    ensure_dir(models_prefix)
    ensure_dir(images_prefix)

    cycle_gan = CycleGAN(device,
                         models_prefix,
                         opt["lr"],
                         opt["b1"],
                         train=True,
                         semi_supervised=semi_supervised)
    data = DataLoader(data_root=data_root,
                      image_size=(opt['img_height'], opt['img_width']),
                      batch_size=opt['batch_size'])

    total_images = len(data.names)
    print("Total Training Images", total_images)

    total_batches = int(ceil(total_images / opt['batch_size']))

    for epoch in range(cycle_gan.epoch_tracker.epoch, opt['n_epochs']):
        for iteration in range(total_batches):

            # Resume support: skip batches that the checkpoint recorded in
            # epoch_tracker has already covered.
            if (epoch == cycle_gan.epoch_tracker.epoch
                    and iteration < cycle_gan.epoch_tracker.iter):
                continue

            y, x = next(data.data_generator(iteration))

            real_A = Variable(x.type(Tensor))
            real_B = Variable(y.type(Tensor))

            cycle_gan.set_input(real_A, real_B)
            cycle_gan.train()

            message = (
                "\r[Epoch {}/{}] [Batch {}/{}] [DA:{}, DB:{}] [GA:{}, GB:{}, cycleA:{}, cycleB:{}, G:{}]"
                .format(epoch, opt["n_epochs"], iteration, total_batches,
                        cycle_gan.loss_disA.item(), cycle_gan.loss_disB.item(),
                        cycle_gan.loss_genA.item(), cycle_gan.loss_genB.item(),
                        cycle_gan.loss_cycle_A.item(),
                        cycle_gan.loss_cycle_B.item(),
                        cycle_gan.loss_G.item()))
            print(message)
            logger.info(message)

            if iteration % opt['sample_interval'] == 0:
                cycle_gan.save_progress(images_prefix, epoch, iteration)
        cycle_gan.save_progress(images_prefix,
                                epoch,
                                total_batches,
                                save_epoch=True)
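
# NOTE (illustrative): train_cycle_gan() above and test_cycle_gan() below rely
# on module-level globals defined elsewhere in this repository: `device`,
# `Tensor`, `models_prefix`, `images_prefix` and `logger`. A minimal setup
# could look like the commented sketch below; the directory names and logger
# configuration are assumptions, not the project's actual values, and the
# `logging`/`torch` imports are taken for granted.
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#   models_prefix = "saved_models/"    # assumed checkpoint directory
#   images_prefix = "saved_images/"    # assumed sample-image directory
#   logger = logging.getLogger("cycle_gan")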


def test_cycle_gan(data_root, semi_supervised=True):
    """Run the trained CycleGAN on the test split, saving translated images
    and reporting per-image and average test losses."""
    opt = get_opts()

    ensure_dir(models_prefix)
    ensure_dir(images_prefix)

    cycle_gan = CycleGAN(device,
                         models_prefix,
                         opt["lr"],
                         opt["b1"],
                         train=False,
                         semi_supervised=semi_supervised)
    data = DataLoader(data_root=data_root,
                      image_size=(opt['img_height'], opt['img_width']),
                      batch_size=1,
                      train=False)

    total_images = len(data.names)
    print("Total Testing Images", total_images)

    loss_A = 0.0
    loss_B = 0.0
    name_loss_A = []
    name_loss_B = []

    for i in range(total_images):
        print(i, "/", total_images)
        x, y = next(data.data_generator(i))
        name = data.names[i]

        real_A = Variable(x.type(Tensor))
        real_B = Variable(y.type(Tensor))

        cycle_gan.set_input(real_A, real_B)
        cycle_gan.test()
        cycle_gan.save_image(images_prefix, name)
        loss_A += cycle_gan.test_A
        loss_B += cycle_gan.test_B
        name_loss_A.append((cycle_gan.test_A, name))
        name_loss_B.append((cycle_gan.test_B, name))

    info = "Average Loss A:{} B :{}".format(loss_A / (1.0 * total_images),
                                            loss_B / (1.0 * total_images))
    print(info)
    logger.info(info)
    # Sorting ascending by loss puts the best-reconstructed images first.
    name_loss_A = sorted(name_loss_A)
    name_loss_B = sorted(name_loss_B)
    print("10 lowest-loss images (A, then B)")
    print(name_loss_A[:10])
    print(name_loss_B[:10])
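

# Illustrative entry point -- an assumption, not necessarily how this
# repository wires the two routines together. `DATA_ROOT` is a placeholder
# path, and the calls rely on the module-level globals sketched above.
if __name__ == "__main__":
    DATA_ROOT = "path/to/dataset"  # hypothetical dataset location
    train_cycle_gan(DATA_ROOT, semi_supervised=False)
    test_cycle_gan(DATA_ROOT, semi_supervised=False)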
            torch.load("saved_models/discriminator.pth"))
else:
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt["lr"],
                               betas=(opt["b1"], opt["b2"]))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt["lr"],
                               betas=(opt["b1"], opt["b2"]))

for epoch in range(opt['n_epochs']):
    # The number of batches per epoch is hard-coded here and appears to
    # assume a fixed training-set size of 25000 images.
    for i in range(25000 // opt['batch_size']):

        y, x = next(data.data_generator())

        real_A = Variable(x.type(Tensor))
        real_B = Variable(y.type(Tensor))

        # Adversarial ground truths: a target of 1 for every real image and 0
        # for every generated image, one scalar label per batch element.
        valid = Variable(Tensor(np.ones((real_A.size(0), 1))),
                         requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), 1))),
                        requires_grad=False)

        # ---------------- Generator update ----------------
        optimizer_G.zero_grad()

        # Translate A -> B, score the translation with the discriminator
        # against the "valid" target (adversarial loss), and compare it with
        # the paired ground truth (pixel-wise loss).
        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B)
        loss_GAN = criterion_GAN(pred_fake, valid)
        loss_pixel = criterion_pixelwise(fake_B, real_B)
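
        # A pix2pix-style update would typically combine these two terms into
        # the full generator objective, e.g.
        #   loss_G = loss_GAN + lambda_pixel * loss_pixel
        # (lambda_pixel being a hypothetical weighting factor), call
        # loss_G.backward() and optimizer_G.step(), and then update the
        # discriminator through optimizer_D.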