Beispiel #1
0
def inference(config):
    """Run hair-transfer inference and display the result.

    Blends the hair of ``config.celeb_pic`` onto the face in
    ``config.your_pic`` using pretrained ResNet18 encoder + U-Net
    generator checkpoints from ``config.checkpoints``.

    config is expected to provide: your_pic, celeb_pic, image_size,
    device, checkpoints, and optionally inf_epoch -- TODO confirm
    against the caller.
    """
    hair_model, skin_model = load_model(config)

    try:
        your_pic = Image.open(config.your_pic)
        celeb_pic = Image.open(config.celeb_pic)
    except Exception as err:
        # Was a bare ``except`` that returned silently; report why we bail.
        print('Could not open input images: %s' % err)
        return

    your_pic, your_pic_mask, celeb_pic, celeb_pic_mask = DataLoader(
        transform(your_pic, celeb_pic, config.image_size, hair_model,
                  skin_model, config.device))

    # Initialize networks.  (The original also built a Vgg16 instance that
    # was never referenced in this function; it has been removed.)
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)
    # discriminator = Discriminator().to(config.device)

    # The checkpoint epoch used to be hard-coded to 20; allow overriding it
    # via config.inf_epoch while keeping 20 as the backward-compatible
    # default.
    epoch = getattr(config, 'inf_epoch', 20)
    try:
        resnet.load_state_dict(
            torch.load(
                os.path.join(config.checkpoints,
                             'epoch_%d_%s.pth' % (epoch, 'resnet'))))
        generator.load_state_dict(
            torch.load(
                os.path.join(config.checkpoints,
                             'epoch_%d_%s.pth' % (epoch, 'generator'))))
    except OSError:
        print('Check if your pretrained weight is in the right place.')

    # Latent codes: z1 encodes the masked skin region of the user's picture,
    # z2 the masked hair region of the celebrity picture.
    z1 = resnet(your_pic * your_pic_mask)  # skin
    z2 = resnet(celeb_pic * celeb_pic_mask)  # hair
    fake_im = generator(your_pic, z1, z2)  # z1 is skin, z2 is hair

    images = [your_pic[0], celeb_pic[0], fake_im[0]]
    titles = ['Your picture', 'Celebrity picture', 'Synthesized picture']

    fig, axes = plt.subplots(1, len(titles))
    for i in range(len(images)):
        im = images[i]
        # CHW tensor -> HWC numpy image, then map from [-1, 1] to [0, 1].
        im = im.data.cpu().numpy().transpose(1, 2, 0)
        im = (im + 1) / 2
        axes[i].imshow(im)
        axes[i].axis('off')
        axes[i].set_title(titles[i])

    plt.show()
Beispiel #2
0
def main(args):
    """Train a pix2pix-style GAN on the facade dataset named by
    ``args.data_root``.

    Builds the generator/discriminator pair, optionally resumes from a
    checkpoint epoch, constructs train/val data loaders, and hands
    everything to ``train``.
    """
    # Create sample and checkpoint directories
    os.makedirs("images/{}".format(args.data_root), exist_ok=True)
    os.makedirs("saved_models/{}".format(args.data_root), exist_ok=True)

    # (The original also set an unused local ``lambda_pixel = 100`` here;
    # it was never referenced and has been removed.)

    # Initialize generator and discriminator
    generator = GeneratorUNet()
    discriminator = Discriminator()

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    if args.epoch != 0:
        # Load pretrained models.
        # BUG FIX: the original applied the ``%`` operator to
        # ``str.format``-style "{}" templates, which raises TypeError at
        # runtime; use ``.format`` as the placeholders require.
        generator.load_state_dict(
            torch.load("saved_models/{}/generator_{}.pth".format(
                args.data_root, args.epoch)))
        discriminator.load_state_dict(
            torch.load("saved_models/{}/discriminator_{}.pth".format(
                args.data_root, args.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

    # Resize + flip augmentation, then normalize each channel to [-1, 1].
    transforms_ = [
        transforms.Resize((args.img_height, args.img_width), Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]

    # Training data loader
    dataloader = DataLoader(
        FacadeDataset("../../datasets/{}".format(args.data_root), transforms_=transforms_),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_cpu,
    )
    # Test data loader
    val_dataloader = DataLoader(
        FacadeDataset("../../datasets/{}".format(args.data_root), transforms_=transforms_,  mode="val"),
        batch_size=10,
        shuffle=True,
        num_workers=1,
    )

    optimizer_list = [optimizer_G, optimizer_D]
    network_list = [generator, discriminator]

    dataloaders = [dataloader, val_dataloader]
    train(args, network_list, optimizer_list, dataloaders)
Beispiel #3
0
def main(config):
    """Blend the celebrity's hair onto the user's picture and display it."""
    model = load_model(config)

    # Open both input images; abort with a hint when a path is wrong.
    try:
        your_pic = Image.open(config.your_pic)
        celeb_pic = Image.open(config.celeb_pic)
    except Exception:
        print("Provide correct paths!")
        return

    src_im, src_mask, ref_im, ref_mask = DataLoader(
        transform(your_pic, celeb_pic, config.image_size, model,
                  config.device))
    # Duplicate along the batch axis to batch size 2 -- presumably required
    # by the networks (e.g. batch-norm layers); TODO confirm.
    src_im, src_mask, ref_im, ref_mask = [
        t.repeat(2, 1, 1, 1) for t in (src_im, src_mask, ref_im, ref_mask)
    ]

    # Build the networks on the requested device.
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)

    # Restore checkpointed weights for both networks.
    try:
        for net, tag in ((resnet, 'resnet'), (generator, 'generator')):
            ckpt = os.path.join(config.checkpoints,
                                'epoch_%d_%s.pth' % (config.inf_epoch, tag))
            net.load_state_dict(
                torch.load(ckpt, map_location=config.device))
    except OSError:
        print('Check if your pretrained weight is in the right place.')

    with torch.no_grad():
        resnet.eval()
        generator.eval()

        # Encode the masked reference hair region, then synthesize.
        hair_code = resnet(ref_im * ref_mask)
        fake_im = generator(src_im, ref_im, hair_code)

        images = [src_im[0], ref_im[0], fake_im[0]]
        titles = ['Your picture', 'Celebrity picture', 'Synthesized picture']

        fig, axes = plt.subplots(1, len(titles))
        for ax, img, title in zip(axes, images, titles):
            # CHW tensor -> HWC array mapped from [-1, 1] to [0, 1].
            arr = (img.data.cpu().numpy().transpose(1, 2, 0) + 1) / 2
            ax.imshow(arr)
            ax.axis('off')
            ax.set_title(title)

        plt.show()
Beispiel #4
0
# Loss functions
# Adversarial loss; reduction='none' keeps the per-element loss map.
criterion_GAN = torch.nn.MSELoss(reduction='none')
# criterion_GAN = torch.nn.BCEWithLogitsLoss(reduction='none')

# criterion_pixelwise = torch.nn.MSELoss(reduction='none')
# NOTE(review): pos_weight of shape (30, 30) presumably matches the
# discriminator's patch output map -- confirm it broadcasts correctly.
criterion_pixelwise = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=torch.full((30, 30), 10))
# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 1
lambda_GAN = 1

# Calculate output of image discriminator (PatchGAN)
# Presumably two stride-2 downsamplings -> side length / 2**2; TODO confirm
# against the Discriminator definition.
patch = (1, args.img_height // 2 ** 2, args.img_width // 2 ** 2)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    criterion_GAN.to(device)
    criterion_pixelwise.to(device)

if args.epoch != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (args.dataset_name, args.epoch)))
    discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (args.dataset_name, args.epoch)))
else:
    # Initialize weights
    # NOTE(review): only the generator is re-initialized here; this snippet
    # appears truncated -- discriminator.apply(weights_init_normal) is
    # probably missing.  TODO confirm against the original source.
    generator.apply(weights_init_normal)
opt = parser.parse_args()
print(opt)

# Output/checkpoint directories.  Note the singular 'saved_model' here vs
# the plural 'saved_models' used in other snippets of this file.
os.makedirs('images/%s' % opt.dataset_name, exist_ok=True)
os.makedirs('saved_model/%s' % opt.dataset_name, exist_ok=True)

# Use CUDA when available.
cuda = True if torch.cuda.is_available() else False

# Loss functions
criterion = nn.BCELoss()

# L1 reconstruction loss between translated and target image.
criterion_pixelwise = torch.nn.L1Loss()

# Initialize network
Unet = GeneratorUNet()

if cuda:
    Unet = torch.nn.DataParallel(Unet).cuda()
    criterion.cuda()
    criterion_pixelwise.cuda()

if opt.epoch != 0:
    # Load pretrained models
    # NOTE(review): os.path.join discards ``opt.path`` entirely because the
    # second component starts with '/' (absolute-path semantics); the
    # leading slash is likely unintended.  TODO confirm and drop it.
    Unet.load_state_dict(
        torch.load(
            os.path.join(
                opt.path, '/saved_model/%s/generator_%d.pth' %
                (opt.dataset_name, opt.epoch))))
else:
    # Initialize weights
    # NOTE(review): the ``else`` body is missing -- this snippet is
    # truncated at the scrape boundary.
Beispiel #6
0
def main(config):
    """Train the hair-swap GAN.

    Jointly optimizes a ResNet18 hair encoder and a U-Net generator with
    perceptual hair (style+content) and face (content) losses computed on
    VGG features, plus a PatchGAN adversarial loss during the first
    ``config.gan_epochs`` epochs.  Logs losses and saves sample images
    every ``config.sample_interval`` batches, runs one validation batch at
    each sampling step, and checkpoints weights every
    ``config.checkpoint_interval`` epochs.
    """
    model = load_model(config)
    train_loader, val_loader = get_loaders(model, config)

    # Make dirs
    if not os.path.exists(config.checkpoints):
        os.makedirs(config.checkpoints, exist_ok=True)
    if not os.path.exists(config.save_path):
        os.makedirs(config.save_path, exist_ok=True)

    # Loss Functions
    criterion_GAN = mse_loss

    # Per-image output shape of the discriminator (PatchGAN); presumably
    # four stride-2 downsamplings -> image_size / 2**4 per side -- TODO
    # confirm against the Discriminator definition.
    patch = (1, config.image_size // 2**4, config.image_size // 2**4)

    # Initialize networks (vgg is only used as a fixed feature extractor).
    vgg = Vgg16().to(config.device)
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)
    discriminator = Discriminator().to(config.device)

    if config.epoch != 0:
        # Resume from the previous epoch's checkpoints.
        resnet.load_state_dict(
            torch.load(
                os.path.join(config.checkpoints, 'epoch_%d_%s.pth' %
                             (config.epoch - 1, 'resnet'))))
        generator.load_state_dict(
            torch.load(
                os.path.join(
                    config.checkpoints,
                    'epoch_%d_%s.pth' % (config.epoch - 1, 'generator'))))
        discriminator.load_state_dict(
            torch.load(
                os.path.join(
                    config.checkpoints,
                    'epoch_%d_%s.pth' % (config.epoch - 1, 'discriminator'))))
    else:
        # Initialize weights
        # resnet.apply(weights_init_normal)
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_resnet = torch.optim.Adam(resnet.parameters(),
                                        lr=config.lr,
                                        betas=(config.b1, config.b2))
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=config.lr,
                                   betas=(config.b1, config.b2))

    # ----------
    #  Training
    # ----------

    resnet.train()
    generator.train()
    discriminator.train()
    for epoch in range(config.epoch, config.n_epochs):
        for i, (im1, m1, im2, m2) in enumerate(train_loader):
            assert im1.size(0) == im2.size(0)
            # PatchGAN targets: 1 for real pairs...
            valid = Variable(torch.Tensor(np.ones(
                (im1.size(0), *patch))).to(config.device),
                             requires_grad=False)
            # ...and 0 for fakes.  BUG FIX: the original built this with
            # np.ones, so loss_fake rewarded the discriminator for calling
            # fakes real.
            fake = Variable(torch.Tensor(np.zeros(
                (im1.size(0), *patch))).to(config.device),
                            requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_resnet.zero_grad()
            optimizer_G.zero_grad()

            # GAN loss; z encodes the masked reference hair region.
            z = resnet(im2 * m2)
            if epoch < config.gan_epochs:
                # Early epochs: the generator inpaints the hair region
                # removed from im1 with the reference hair.
                fake_im = generator(im1 * (1 - m1), im2 * m2, z)
            else:
                fake_im = generator(im1, im2, z)
            if epoch < config.gan_epochs:
                pred_fake = discriminator(fake_im, im2)
                gan_loss = config.lambda_gan * criterion_GAN(pred_fake, valid)
            else:
                gan_loss = torch.Tensor([0]).to(config.device)

            # Hair, Face loss: segment the synthesized image's hair and use
            # that mask when its area is plausible (within 50%-150% of the
            # source mask m1); otherwise fall back to m1.
            fake_m2 = torch.argmax(model(fake_im),
                                   1).unsqueeze(1).type(torch.uint8).repeat(
                                       1, 3, 1, 1).to(config.device)
            if 0.5 * torch.sum(m1) <= torch.sum(
                    fake_m2) <= 1.5 * torch.sum(m1):
                hair_loss = config.lambda_style * calc_style_loss(
                    fake_im * fake_m2, im2 * m2, vgg) + calc_content_loss(
                        fake_im * fake_m2, im2 * m2, vgg)
                face_loss = calc_content_loss(fake_im, im1, vgg)
            else:
                hair_loss = config.lambda_style * calc_style_loss(
                    fake_im * m1, im2 * m2, vgg) + calc_content_loss(
                        fake_im * m1, im2 * m2, vgg)
                face_loss = calc_content_loss(fake_im, im1, vgg)
            hair_loss *= config.lambda_hair
            face_loss *= config.lambda_face

            # Total loss
            loss = gan_loss + hair_loss + face_loss

            loss.backward()
            optimizer_resnet.step()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            if epoch < config.gan_epochs:
                optimizer_D.zero_grad()

                # Real loss: a composite of the source face and the
                # reference hair region serves as the "real" example.
                pred_real = discriminator(im1 * (1 - m1) + im2 * m2, im2)
                loss_real = criterion_GAN(pred_real, valid)
                # Fake loss
                pred_fake = discriminator(fake_im.detach(), im2)
                loss_fake = criterion_GAN(pred_fake, fake)
                # Total loss
                loss_D = 0.5 * (loss_real + loss_fake)

                loss_D.backward()
                optimizer_D.step()

            if i % config.sample_interval == 0:
                msg = "Train || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
                    (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
                sys.stdout.write("Epoch: %d || Batch: %d\n" % (epoch, i))
                sys.stdout.write(msg)
                fname = os.path.join(
                    config.save_path,
                    "Train_Epoch:%d_Batch:%d.png" % (epoch, i))
                sample_images([im1[0], im2[0], fake_im[0]],
                              ["img1", "img2", "img1+img2"], fname)
                # Evaluate one validation batch.  The loop variables are
                # prefixed v_ so they no longer clobber the training batch
                # tensors (the original reused im1/m1/im2/m2 here), and the
                # unused ``fake`` label tensor was dropped.
                for j, (v_im1, v_m1, v_im2, v_m2) in enumerate(val_loader):
                    with torch.no_grad():
                        valid = Variable(torch.Tensor(
                            np.ones((v_im1.size(0), *patch))).to(config.device),
                                         requires_grad=False)

                        # GAN loss
                        z = resnet(v_im2 * v_m2)
                        if epoch < config.gan_epochs:
                            fake_im = generator(v_im1 * (1 - v_m1),
                                                v_im2 * v_m2, z)
                        else:
                            fake_im = generator(v_im1, v_im2, z)

                        if epoch < config.gan_epochs:
                            pred_fake = discriminator(fake_im, v_im2)
                            gan_loss = config.lambda_gan * criterion_GAN(
                                pred_fake, valid)
                        else:
                            gan_loss = torch.Tensor([0]).to(config.device)

                        # Hair, Face loss (same mask fallback as training)
                        fake_m2 = torch.argmax(
                            model(fake_im),
                            1).unsqueeze(1).type(torch.uint8).repeat(
                                1, 3, 1, 1).to(config.device)
                        if 0.5 * torch.sum(v_m1) <= torch.sum(
                                fake_m2) <= 1.5 * torch.sum(v_m1):
                            hair_loss = config.lambda_style * calc_style_loss(
                                fake_im * fake_m2, v_im2 * v_m2,
                                vgg) + calc_content_loss(
                                    fake_im * fake_m2, v_im2 * v_m2, vgg)
                            face_loss = calc_content_loss(fake_im, v_im1, vgg)
                        else:
                            hair_loss = config.lambda_style * calc_style_loss(
                                fake_im * v_m1, v_im2 * v_m2,
                                vgg) + calc_content_loss(
                                    fake_im * v_m1, v_im2 * v_m2, vgg)
                            face_loss = calc_content_loss(fake_im, v_im1, vgg)
                        hair_loss *= config.lambda_hair
                        face_loss *= config.lambda_face

                        # Total loss
                        loss = gan_loss + hair_loss + face_loss

                        msg = "Validation || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
                                (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
                        sys.stdout.write(msg)
                        fname = os.path.join(
                            config.save_path,
                            "Validation_Epoch:%d_Batch:%d.png" % (epoch, i))
                        sample_images([v_im1[0], v_im2[0], fake_im[0]],
                                      ["img1", "img2", "img1+img2"], fname)
                        break

        if epoch % config.checkpoint_interval == 0:
            # Once the adversarial phase is over the discriminator is no
            # longer updated, so it is not checkpointed.
            if epoch < config.gan_epochs:
                models = [resnet, generator, discriminator]
                fnames = ['resnet', 'generator', 'discriminator']
            else:
                models = [resnet, generator]
                fnames = ['resnet', 'generator']
            fnames = [
                os.path.join(config.checkpoints,
                             'epoch_%d_%s.pth' % (epoch, s)) for s in fnames
            ]
            save_weights(models, fnames)
else:
    criterion_samplewise = torch.nn.MSELoss()

# Output size of the discriminator (PatchGAN)
patch = (1, 9)

# Load data
dataloader, val_dataloader = get_data_loader(
    args.dataset_prefix,
    args.batch_size,
    from_ppg=args.from_ppg,
    shuffle_training=args.shuffle_training,
    shuffle_testing=args.shuffle_testing)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    # With more than two visible GPUs, train data-parallel on the first
    # three devices; otherwise place both networks on the single device.
    if torch.cuda.device_count() > 2:
        # if False:
        generator = torch.nn.DataParallel(generator, device_ids=[0, 1,
                                                                 2]).to(device)
        discriminator = torch.nn.DataParallel(discriminator,
                                              device_ids=[0, 1, 2]).to(device)
    else:
        generator = generator.to(device)
        discriminator = discriminator.to(device)

    criterion_samplewise.to(device)
Beispiel #8
0
    def __init__(self, config, use_wandb, device, dataset_path):
        """Build the GAN training harness from a ``config`` mapping.

        Constructs the U-Net generator and PatchGAN discriminator, their
        Adam optimizers (wrapped by ``amp.initialize`` -- presumably NVIDIA
        Apex mixed precision, TODO confirm), the loss functions, and the
        CAMUS train/valid/test datasets with their data loaders.

        Args:
            config: dict-like of upper-case hyper-parameter keys
                (NAME, EPOCHS, PATCH_SIZE, CHANNELS, IMAGE_RES, ...).
            use_wandb: whether Weights & Biases logging is enabled.
            device: torch device the networks are moved to.
            dataset_path: root directory of the CAMUS dataset.
        """
        # Configure data loader
        self.config = config
        self.result_name = config['NAME']

        self.use_wandb = use_wandb
        self.device = device
        self.epochs = config['EPOCHS']
        self.log_interval = config['LOG_INTERVAL']
        # Bookkeeping counters for resumable training.
        self.step = 0
        self.loaded_epoch = 0
        self.epoch = 0
        self.base_dir = './'

        self.patch = (1, config['PATCH_SIZE'], config['PATCH_SIZE'])
        # self.patch = (1, 256 // 2 ** 4, 256 // 2 ** 4)

        # Input shape
        self.channels = config['CHANNELS']
        self.img_rows = config['IMAGE_RES'][0]
        self.img_cols = config['IMAGE_RES'][1]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        assert self.img_rows == self.img_cols, 'The current code only works with same values for img_rows and img_cols'

        # Input images and their conditioning images

        self.conditional_d = config.get('CONDITIONAL_DISCRIMINATOR', False)
        self.recon_loss = config.get('RECON_LOSS', 'basic')
        self.loss_weight_d = config["LOSS_WEIGHT_DISC"]
        self.loss_weight_g = config["LOSS_WEIGHT_GEN"]

        # Calculate output shape of D (PatchGAN)
        patch_size = config['PATCH_SIZE']
        patch_per_dim = int(self.img_rows / patch_size)
        self.num_patches = (patch_per_dim, patch_per_dim, 1)

        # Number of filters in the first layer of G and D
        self.gf = config['FIRST_LAYERS_FILTERS']
        self.df = config['FIRST_LAYERS_FILTERS']
        self.skipconnections_generator = config['SKIP_CONNECTIONS_GENERATOR']
        self.output_activation = config['GEN_OUTPUT_ACT']
        self.decay_factor_G = config['LR_EXP_DECAY_FACTOR_G']
        self.decay_factor_D = config['LR_EXP_DECAY_FACTOR_D']

        self.generator = GeneratorUNet(in_channels=self.channels, out_channels=self.channels).to(self.device)
        self.discriminator = Discriminator(img_size=(self.img_rows, self.img_cols), in_channels=self.channels, patch_size=(patch_size, patch_size)).to(self.device)

        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=config['LEARNING_RATE_G'],
                                            betas=(config['ADAM_B1'], 0.999))  # 0.0002
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=config['LEARNING_RATE_D'], betas=(config['ADAM_B1'], 0.999))

        # Wrap model+optimizer pairs for mixed-precision training ('O1' =
        # presumably Apex's conservative mixed-precision level; TODO confirm).
        opt_level = 'O1'
        self.generator, self.optimizer_G = amp.initialize(self.generator,
                                                          self.optimizer_G,
                                                          opt_level=opt_level)
        self.discriminator, self.optimizer_D = amp.initialize(self.discriminator,
                                                              self.optimizer_D,
                                                              opt_level=opt_level)

        # Least-squares adversarial loss over the PatchGAN output map.
        self.criterion_GAN = torch.nn.MSELoss().to(self.device)

        self.criterion_pixelwise = torch.nn.L1Loss(reduction='none').to(self.device)  # MAE
        # self.criterion_pixelwise = torch.nn.BCEWithLogitsLoss().to(self.device) # + weight + mean

        # Collect every AUG_* config entry as an augmentation setting.
        self.augmentation = dict()
        for key, value in config.items():
            if 'AUG_' in key:
                self.augmentation[key] = value

        # Train/valid/test splits share all dataset parameters and differ
        # only in ``subset`` (augmentation wiring is commented out).
        self.train_data = DatasetCAMUS(dataset_path=dataset_path,
                                       random_state=config['RANDOM_SEED'],
                                       img_size=config['IMAGE_RES'],
                                       classes=config['LABELS'],
                                       train_ratio=config['TRAIN_RATIO'],
                                       valid_ratio=config['VALID_RATIO'],
                                       heart_states=config['HEART_STATES'],
                                       views=config['HEART_VIEWS'],
                                       image_qualities=config['IMAGE_QUALITIES'],
                                       patient_qualities=config['PATIENT_QUALITIES'],
                                       # augment=self.augmentation,
                                       subset='train')
        self.valid_data = DatasetCAMUS(dataset_path=dataset_path,
                                       random_state=config['RANDOM_SEED'],
                                       img_size=config['IMAGE_RES'],
                                       classes=config['LABELS'],
                                       train_ratio=config['TRAIN_RATIO'],
                                       valid_ratio=config['VALID_RATIO'],
                                       heart_states=config['HEART_STATES'],
                                       views=config['HEART_VIEWS'],
                                       image_qualities=config['IMAGE_QUALITIES'],
                                       patient_qualities=config['PATIENT_QUALITIES'],
                                       # augment=self.augmentation,
                                       subset='valid')

        self.test_data = DatasetCAMUS(dataset_path=dataset_path,
                                      random_state=config['RANDOM_SEED'],
                                      img_size=config['IMAGE_RES'],
                                      classes=config['LABELS'],
                                      train_ratio=config['TRAIN_RATIO'],
                                      valid_ratio=config['VALID_RATIO'],
                                      heart_states=config['HEART_STATES'],
                                      views=config['HEART_VIEWS'],
                                      image_qualities=config['IMAGE_QUALITIES'],
                                      patient_qualities=config['PATIENT_QUALITIES'],
                                      # augment=self.augmentation,
                                      subset='test')

        # Only the training loader shuffles.
        self.train_loader = torch.utils.data.DataLoader(self.train_data,
                                                        batch_size=config['BATCH_SIZE'],  # 32 max
                                                        shuffle=True,
                                                        num_workers=config['NUM_WORKERS'])
        self.valid_loader = torch.utils.data.DataLoader(self.valid_data,
                                                        batch_size=config['BATCH_SIZE'],
                                                        shuffle=False,
                                                        num_workers=config['NUM_WORKERS'])
        self.test_loader = torch.utils.data.DataLoader(self.test_data,
                                                       batch_size=config['BATCH_SIZE'],
                                                       shuffle=False,
                                                       num_workers=config['NUM_WORKERS'])

        # Training hyper-parameters
        self.batch_size = config['BATCH_SIZE']
        # self.max_iter = config['MAX_ITER']
        self.val_interval = config['VAL_INTERVAL']
        # NOTE(review): self.log_interval was already assigned above from the
        # same key; this duplicate assignment is harmless but redundant.
        self.log_interval = config['LOG_INTERVAL']
        self.save_model_interval = config['SAVE_MODEL_INTERVAL']
        self.lr_G = config['LEARNING_RATE_G']
        self.lr_D = config['LEARNING_RATE_D']
Beispiel #9
0
class GAN:
    """Pix2pix-style conditional GAN for the CAMUS echocardiography dataset.

    The generator (U-Net) synthesizes ultrasound images from full
    segmentation masks; a PatchGAN discriminator classifies per-patch
    real/fake with an LSGAN (MSE) objective.  Training uses NVIDIA apex
    mixed precision (``amp``) and optionally logs to Weights & Biases.
    """

    def __init__(self, config, use_wandb, device, dataset_path):
        """
        Args:
            config: dict of hyper-parameters (upper-case keys).
            use_wandb: if True, log training metrics/images to wandb.
            device: torch device models and tensors are placed on.
            dataset_path: root directory of the CAMUS dataset.
        """
        # Experiment bookkeeping.
        self.config = config
        self.result_name = config['NAME']
        self.use_wandb = use_wandb
        self.device = device
        self.epochs = config['EPOCHS']
        self.log_interval = config['LOG_INTERVAL']
        self.step = 0           # global step counter used for wandb logging
        self.loaded_epoch = 0   # set by load() when resuming from a checkpoint
        self.epoch = 0
        self.base_dir = './'

        # Shape of one adversarial ground-truth map fed to the PatchGAN loss.
        self.patch = (1, config['PATCH_SIZE'], config['PATCH_SIZE'])

        # Input shape.
        self.channels = config['CHANNELS']
        self.img_rows = config['IMAGE_RES'][0]
        self.img_cols = config['IMAGE_RES'][1]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        assert self.img_rows == self.img_cols, \
            'The current code only works with same values for img_rows and img_cols'

        self.conditional_d = config.get('CONDITIONAL_DISCRIMINATOR', False)
        self.recon_loss = config.get('RECON_LOSS', 'basic')
        # Weights of the adversarial / pixel-wise terms in the generator loss.
        self.loss_weight_d = config["LOSS_WEIGHT_DISC"]
        self.loss_weight_g = config["LOSS_WEIGHT_GEN"]

        # Output grid of the PatchGAN discriminator.
        patch_size = config['PATCH_SIZE']
        patch_per_dim = int(self.img_rows / patch_size)
        self.num_patches = (patch_per_dim, patch_per_dim, 1)

        # Number of filters in the first layer of G and D.
        self.gf = config['FIRST_LAYERS_FILTERS']
        self.df = config['FIRST_LAYERS_FILTERS']
        self.skipconnections_generator = config['SKIP_CONNECTIONS_GENERATOR']
        self.output_activation = config['GEN_OUTPUT_ACT']
        self.decay_factor_G = config['LR_EXP_DECAY_FACTOR_G']
        self.decay_factor_D = config['LR_EXP_DECAY_FACTOR_D']

        self.generator = GeneratorUNet(in_channels=self.channels,
                                       out_channels=self.channels).to(self.device)
        self.discriminator = Discriminator(img_size=(self.img_rows, self.img_cols),
                                           in_channels=self.channels,
                                           patch_size=(patch_size, patch_size)).to(self.device)

        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=config['LEARNING_RATE_G'],
                                            betas=(config['ADAM_B1'], 0.999))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=config['LEARNING_RATE_D'],
                                            betas=(config['ADAM_B1'], 0.999))

        # Mixed-precision wrapping (NVIDIA apex); O1 = patch most ops to fp16.
        opt_level = 'O1'
        self.generator, self.optimizer_G = amp.initialize(self.generator,
                                                          self.optimizer_G,
                                                          opt_level=opt_level)
        self.discriminator, self.optimizer_D = amp.initialize(self.discriminator,
                                                              self.optimizer_D,
                                                              opt_level=opt_level)

        # LSGAN adversarial criterion.
        self.criterion_GAN = torch.nn.MSELoss().to(self.device)

        # Per-pixel L1 kept un-reduced so it can be weighted by a per-pixel
        # weight map before averaging (see train()).
        self.criterion_pixelwise = torch.nn.L1Loss(reduction='none').to(self.device)

        # Collect augmentation flags from the config (currently not passed to
        # DatasetCAMUS — see the commented-out `augment=` below).
        self.augmentation = {key: value for key, value in config.items() if 'AUG_' in key}

        # Datasets and loaders for the three CAMUS splits.
        self.train_data = self._make_dataset(dataset_path, 'train')
        self.valid_data = self._make_dataset(dataset_path, 'valid')
        self.test_data = self._make_dataset(dataset_path, 'test')

        self.train_loader = self._make_loader(self.train_data, shuffle=True)
        self.valid_loader = self._make_loader(self.valid_data, shuffle=False)
        self.test_loader = self._make_loader(self.test_data, shuffle=False)

        # Training hyper-parameters.
        self.batch_size = config['BATCH_SIZE']
        self.val_interval = config['VAL_INTERVAL']
        self.log_interval = config['LOG_INTERVAL']
        self.save_model_interval = config['SAVE_MODEL_INTERVAL']
        self.lr_G = config['LEARNING_RATE_G']
        self.lr_D = config['LEARNING_RATE_D']

    def _make_dataset(self, dataset_path, subset):
        """Build a DatasetCAMUS for one split ('train' / 'valid' / 'test')."""
        config = self.config
        return DatasetCAMUS(dataset_path=dataset_path,
                            random_state=config['RANDOM_SEED'],
                            img_size=config['IMAGE_RES'],
                            classes=config['LABELS'],
                            train_ratio=config['TRAIN_RATIO'],
                            valid_ratio=config['VALID_RATIO'],
                            heart_states=config['HEART_STATES'],
                            views=config['HEART_VIEWS'],
                            image_qualities=config['IMAGE_QUALITIES'],
                            patient_qualities=config['PATIENT_QUALITIES'],
                            # augment=self.augmentation,
                            subset=subset)

    def _make_loader(self, dataset, shuffle):
        """Wrap a dataset in a DataLoader with the configured batch size."""
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.config['BATCH_SIZE'],
                                           shuffle=shuffle,
                                           num_workers=self.config['NUM_WORKERS'])

    def train(self):
        """Run the adversarial training loop for ``self.epochs`` epochs.

        Each iteration alternates one discriminator update (real vs. detached
        fake) and one generator update (adversarial + weighted L1 pixel loss),
        prints progress to stdout, periodically saves sample images, and
        checkpoints both models every ``save_model_interval`` epochs.
        """
        prev_time = time.time()

        for epoch in range(self.loaded_epoch, self.epochs):
            self.epoch = epoch
            for i, batch in enumerate(self.train_loader):

                image, mask, full_mask, weight_map, segment_mask, quality, heart_state, view = batch

                mask = mask.to(self.device)
                image = image.to(self.device)
                full_mask = full_mask.to(self.device)
                weight_map = weight_map.to(self.device)
                segment_mask = segment_mask.to(self.device)

                # Adversarial ground truths for the PatchGAN output map
                # (built directly on the device; no numpy round trip needed).
                patch_real = torch.ones((mask.size(0), *self.patch),
                                        dtype=torch.float32, device=self.device)
                patch_fake = torch.zeros((mask.size(0), *self.patch),
                                         dtype=torch.float32, device=self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.generator.eval()
                self.discriminator.train()

                self.optimizer_D.zero_grad()

                fake_echo = self.generator(full_mask)  # * segment_mask  # mask

                # Real loss.
                pred_real = self.discriminator(image, mask)
                loss_real = self.criterion_GAN(pred_real, patch_real)

                # Fake loss (generator output detached so only D is updated).
                pred_fake = self.discriminator(fake_echo.detach(), mask)
                loss_fake = self.criterion_GAN(pred_fake, patch_fake)

                # Total discriminator loss.
                loss_D = 0.5 * (loss_real + loss_fake)

                with amp.scale_loss(loss_D, self.optimizer_D) as scaled_loss:
                    scaled_loss.backward()

                self.optimizer_D.step()

                # -----------------
                #  Train Generator
                # -----------------
                self.generator.train()
                self.discriminator.eval()

                self.optimizer_G.zero_grad()

                # Adversarial loss: the generator tries to make D label its
                # output as REAL, so the target is patch_real.  (BUG FIX: the
                # original compared pred_fake against patch_fake, which
                # rewards the generator for being detected as fake and pushes
                # it in the wrong direction.)
                fake_echo = self.generator(full_mask)
                pred_fake = self.discriminator(fake_echo, mask)
                loss_GAN = self.criterion_GAN(pred_fake, patch_real)

                # Pixel-wise L1, weighted per pixel by weight_map.
                loss_pixel = torch.mean(self.criterion_pixelwise(fake_echo, image) * weight_map)  # * segment_mask

                # Total generator loss.
                loss_G = self.loss_weight_d * loss_GAN + self.loss_weight_g * loss_pixel  # e.g. 1 / 100

                with amp.scale_loss(loss_G, self.optimizer_G) as scaled_loss:
                    scaled_loss.backward()

                self.optimizer_G.step()

                # --------------
                #  Log progress
                # --------------

                # Approximate time left based on the last batch duration.
                batches_done = self.epoch * len(self.train_loader) + i
                batches_left = self.epochs * len(self.train_loader) - batches_done
                time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()

                # Image-quality metrics of the synthesized echo against the
                # real image.  (BUG FIX: the original compared against `mask`,
                # the conditioning input, not the reconstruction target.)
                psnr = metrics.psnr(image, fake_echo)  # * segment_mask
                ssim = metrics.ssim(image, fake_echo, window_size=11, size_average=True)  # * segment_mask

                # Print a single carriage-return-refreshed progress line.
                sys.stdout.write(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f patch_fake: %f real: %f] [G loss: %f, pixel: %f, adv: %f] PSNR: %f SSIM: %f ETA: %s"
                    % (
                        self.epoch,
                        self.epochs,
                        i,
                        len(self.train_loader),
                        loss_D.item(),
                        loss_fake.item(),
                        loss_real.item(),
                        loss_G.item(),
                        loss_pixel.item(),
                        loss_GAN.item(),
                        psnr,
                        ssim,
                        time_left,
                    )
                )

                # Periodically save qualitative samples.
                if batches_done % self.log_interval == 0:
                    self.generator.eval()
                    self.discriminator.eval()
                    self.sample_images(batches_done)
                    self.sample_images2(batches_done)

                # wandb logging (lazy import keeps wandb an optional dep).
                self.step += 1
                if self.use_wandb:
                    import wandb
                    wandb.log({'loss_D': loss_D, 'loss_real_D': loss_real, 'loss_fake_D': loss_fake,
                               'loss_G': loss_G, 'loss_pixel': loss_pixel, 'loss_GAN': loss_GAN,
                               'PSNR': psnr, 'SSIM': ssim},
                              step=self.step)

            # Checkpoint both models every save_model_interval epochs.
            if (epoch + 1) % self.save_model_interval == 0:
                self.save(f'{self.base_dir}/generator_last_checkpoint.bin', model='generator')
                self.save(f'{self.base_dir}/discriminator_last_checkpoint.bin', model='discriminator')

    def sample_images(self, batches_done):
        """Save a (real | generated | mask) grid from the validation set."""
        image, mask, full_mask, weight_map, segment_mask, quality, heart_state, view = next(iter(self.valid_loader))
        image = image.to(self.device)
        mask = mask.to(self.device)
        full_mask = full_mask.to(self.device)
        quality = quality.to(self.device)
        segment_mask = segment_mask.to(self.device)
        fake_echo = self.generator(full_mask)  # * segment_mask # , quality)
        # Stack real / fake / mask vertically (dim -2 = image height).
        img_sample = torch.cat((image.data, fake_echo.data, mask.data), -2)
        save_image(img_sample, "images/%s.png" % batches_done, nrow=4, normalize=True)

    # paper-like figure + wandb logging
    def sample_images2(self, batches_done):
        """Save a matplotlib figure of condition/generated/original rows.

        NOTE(review): assumes a validation batch of at least 5 samples and a
        one-hot `quality` vector per sample — confirm against DatasetCAMUS.
        """
        image, mask, full_mask, weight_map, segment_mask, quality, heart_state, view = next(iter(self.valid_loader))
        mask = mask.to(self.device)
        full_mask = full_mask.to(self.device)
        image = image.to(self.device)
        quality = quality.to(self.device)
        segment_mask = segment_mask.to(self.device)
        fake_echo = self.generator(full_mask)  # * segment_mask  # , quality)

        image = image.cpu().detach().numpy()
        fake_echo = fake_echo.cpu().detach().numpy()
        mask = mask.cpu().detach().numpy()
        quality = quality.cpu().detach().numpy()

        batch = 5  # number of columns shown in the figure

        # Channel axis now holds [image, fake, mask] so row r, column c
        # indexes sample c's r-th panel.
        img_sample = np.concatenate([image,
                                     fake_echo,
                                     mask], axis=1)
        q = ['low', 'med', 'high']
        import matplotlib.pyplot as plt
        rows, cols = 3, batch
        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(rows, cols)
        for row in range(rows):
            for col in range(cols):
                class_label = np.argmax(quality[col], axis=1)[0]

                axs[row, col].imshow(img_sample[col, row, :, :], cmap='gray')
                axs[row, col].set_title(titles[row] + ' ' + q[class_label], fontdict={'fontsize': 6})
                axs[row, col].axis('off')

        fig.savefig("images/_%s.png" % batches_done)

        if self.use_wandb:
            import wandb
            wandb.log({'val_image': fig}, step=self.step)

    def save(self, path, model='generator'):
        """Checkpoint the generator or discriminator (weights + optimizer + epoch)."""
        if model == 'generator':
            self.generator.eval()
            torch.save({
                'model_state_dict': self.generator.state_dict(),
                'optimizer_state_dict': self.optimizer_G.state_dict(),
                'epoch': self.epoch,
            }, path)
        elif model == 'discriminator':
            self.discriminator.eval()
            torch.save({
                'model_state_dict': self.discriminator.state_dict(),
                'optimizer_state_dict': self.optimizer_D.state_dict(),
                'epoch': self.epoch,
            }, path)

    def load(self, path, model='generator'):
        """Restore a checkpoint written by save(); resumes from epoch + 1."""
        if model == 'generator':
            checkpoint = torch.load(path)
            self.generator.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
            self.loaded_epoch = checkpoint['epoch'] + 1
            print('generator loaded, epoch ', self.loaded_epoch)
        elif model == 'discriminator':
            checkpoint = torch.load(path)
            self.discriminator.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
            self.loaded_epoch = checkpoint['epoch'] + 1
            print('discriminator loaded, epoch ', self.loaded_epoch)
# Stand-alone inference setup: build the validation loader and restore the
# trained generator before main() runs.
test_dataloader = DataLoader(ImageDataset(os.path.join(
    opt.path, "Data/%s" % opt.dataset_name),
                                          transforms_=transforms_,
                                          mode="val"),
                             batch_size=1,   # evaluate one image at a time
                             shuffle=False,  # deterministic order for saved results
                             num_workers=1)

# Use the GPU when available.
cuda = torch.cuda.is_available()

# Tensor type matching the chosen device.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

generator = GeneratorUNet()

if cuda:
    generator = torch.nn.DataParallel(generator).cuda()

# map_location makes the checkpoint loadable on CPU-only machines even when
# it was saved from a GPU run (torch.load would otherwise fail trying to
# restore CUDA tensors).
# NOTE(review): assumes the checkpoint's state_dict keys match the (possibly
# DataParallel-wrapped) model, i.e. it was saved with the same wrapping and
# 'module.'-prefixed keys — confirm against the training script.
generator.load_state_dict(
    torch.load(os.path.join(opt.path, 'trained_model/generator.pth'),
               map_location=torch.device('cuda' if cuda else 'cpu')))
generator.eval()


def main():
    elapse = []
    os.makedirs(os.path.join(opt.path, 'test_results/gt/'), exist_ok=True)
    os.makedirs(os.path.join(opt.path, 'test_results/image/'), exist_ok=True)
    os.makedirs(os.path.join(opt.path, 'test_results/pred_map/'),
                exist_ok=True)