Example #1
def train():
    RDN = Generator("RDN")
    D = Discriminator("discriminator")
    HR = tf.placeholder(tf.float32, [None, 96, 96, 3])
    LR = tf.placeholder(tf.float32, [None, 24, 24, 3])
    SR = RDN(LR)
    fake_logits = D(SR, LR)
    real_logits = D(HR, LR)
    D_loss, G_loss = Hinge_Loss(fake_logits, real_logits)
    G_loss += MSE(SR, HR) * LAMBDA
    itr = tf.Variable(MAX_ITERATION, dtype=tf.int32, trainable=False)
    learning_rate = tf.Variable(2e-4, trainable=False)
    op_sub = tf.assign_sub(itr, 1)
    D_opt = tf.train.AdamOptimizer(learning_rate, beta1=0., beta2=0.9).minimize(D_loss, var_list=D.var_list())
    with tf.control_dependencies([op_sub]):
        G_opt = tf.train.AdamOptimizer(learning_rate, beta1=0., beta2=0.9).minimize(G_loss, var_list=RDN.var_list())
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    while True:
        HR_data, LR_data = read_crop_data(TRAINING_SET_PATH, BATCH_SIZE, [96, 96, 3], 4)
        sess.run(D_opt, feed_dict={HR: HR_data, LR: LR_data})
        [_, iteration] = sess.run([G_opt, itr], feed_dict={HR: HR_data, LR: LR_data})
        iteration = MAX_ITERATION - iteration
        if iteration >= MAX_ITERATION:
            break
        if iteration < MAX_ITERATION // 2:
            # ramp the learning rate linearly over the first half of training;
            # Variable.load updates the variable the optimizers actually read
            learning_rate.load(2e-4 * iteration * 2 / MAX_ITERATION, sess)
        if iteration % 10 == 0:
            [D_LOSS, G_LOSS, LEARNING_RATE, img] = sess.run([D_loss, G_loss, learning_rate, SR], feed_dict={HR: HR_data, LR: LR_data})
            output = (np.concatenate((HR_data[0, :, :, :], img[0, :, :, :]), axis=1) + 1) * 127.5
            Image.fromarray(np.uint8(output)).save(RESULTS+str(iteration)+".jpg")
            print("Iteration: %d, D_loss: %f, G_loss: %f, LearningRate: %f"%(iteration, D_LOSS, G_LOSS, LEARNING_RATE))
        if iteration % 500 == 0:
            saver.save(sess, SAVE_MODEL + "model.ckpt")
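The Hinge_Loss and MSE helpers called above are not shown in this excerpt; a minimal sketch of the standard hinge GAN losses they presumably compute (only the names and return order are taken from the call sites, the bodies are assumptions):

import tensorflow as tf

def Hinge_Loss(fake_logits, real_logits):
    # discriminator: penalize real logits below +1 and fake logits above -1
    D_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) \
           + tf.reduce_mean(tf.nn.relu(1. + fake_logits))
    # generator: push fake logits up
    G_loss = -tf.reduce_mean(fake_logits)
    return D_loss, G_loss

def MSE(x, y):
    # mean squared pixel error between super-resolved and ground-truth images
    return tf.reduce_mean(tf.square(x - y))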
Example #2
    def __init__(self, device):
        self.device = device        
        self.netG_A = self.__init_weights(Generator(3, use_dropout=False).to(self.device))
        self.netG_B = self.__init_weights(Generator(3, use_dropout=False).to(self.device))
        self.netD_A = self.__init_weights(Discriminator(3).to(self.device))
        self.netD_B = self.__init_weights(Discriminator(3).to(self.device))
        
        self.criterion_gan = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_idt = nn.L1Loss()

        self.optimizer_G = torch.optim.Adam(itertools.chain(
            self.netG_A.parameters(), self.netG_B.parameters()), lr=0.0002, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(itertools.chain(
            self.netD_A.parameters(), self.netD_B.parameters()), lr=0.0002, betas=(0.5, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D]

        self.fake_A_pool = ImagePool(50)
        self.fake_B_pool = ImagePool(50)

        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']

        self.lambda_A = 10
        self.lambda_B = 10
        self.lambda_idt = 0.5

        self.save_dir = './models'
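ImagePool(50) refers to the history buffer from the original CycleGAN code; a minimal sketch of that technique, assuming the standard behavior of returning a previously generated image with probability 0.5 to stabilize discriminator training:

import random
import torch

class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        # return a batch mixing fresh fakes with fakes from earlier iterations
        if self.pool_size == 0:
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                self.images.append(image)   # fill the pool first
                out.append(image)
            elif random.random() > 0.5:
                idx = random.randrange(self.pool_size)
                out.append(self.images[idx].clone())  # serve an old fake
                self.images[idx] = image              # store the new one
            else:
                out.append(image)
        return torch.cat(out, 0)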
Example #3
    def __init__(self, config):
        super(InpaintModel, self).__init__('InpaintModel', config)

        #generator = Generator()
        generator = Generator_SE()
        discriminator = Discriminator(in_channels=3, use_sigmoid=False)

        # data = torch.load('ablation_v0/InpaintModel121_gen.pth')
        # generator.load_state_dict(data['generator'])
        # self.iteration = data['iteration']
        # data = torch.load('ablation_v0/InpaintModel121_dis.pth')
        # discriminator.load_state_dict(data['discriminator'])

        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type='hinge')
        perceptual_loss = PerceptualLoss()
        style_loss = StyleLoss()

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)
        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)
        self.add_module('perceptual_loss', perceptual_loss)
        self.add_module('style_loss', style_loss)

        self.optimizer = optim.Adam(params=generator.parameters(),
                                    lr=config.LR)
        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=config.LR * config.D2G_LR)
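AdversarialLoss(type='hinge') appears here and in several later examples without its definition; a sketch modeled on the EdgeConnect-style loss module (the is_real/is_disc forward signature is an assumption):

import torch
import torch.nn as nn

class AdversarialLoss(nn.Module):
    def __init__(self, type='nsgan'):
        super().__init__()
        self.type = type
        if type == 'nsgan':
            self.criterion = nn.BCELoss()   # expects sigmoid outputs
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()

    def forward(self, outputs, is_real, is_disc=False):
        if self.type == 'hinge':
            if is_disc:
                # hinge loss for the discriminator
                outputs = -outputs if is_real else outputs
                return torch.relu(1. + outputs).mean()
            # generator simply maximizes the critic score
            return (-outputs).mean()
        labels = torch.full_like(outputs, 1. if is_real else 0.)
        return self.criterion(outputs, labels)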
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning-rate', '-lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--data-parallel', action='store_true')
    parser.add_argument('--num-d-iterations', type=int, default=1)
    args = parser.parse_args()
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    print(args)

    device = torch.device('cuda' if args.cuda else 'cpu')

    net_g = Generator(ch=128).to(device)
    net_d = Discriminator(ch=128).to(device)

    optim_g = optim.Adam(
        net_g.parameters(), lr=args.learning_rate, betas=(0.5, 0.999))
    optim_d = optim.Adam(
        net_d.parameters(), lr=args.learning_rate, betas=(0.5, 0.999))

    dataloader = get_cat_dataloader()

    trainer = Trainer(net_g, net_d, optim_g, optim_d, dataloader, device,
                      args.num_d_iterations)

    os.makedirs('samples', exist_ok=True)

    trainer.train(args.epochs)
Example #5
    def __init__(self, config):
        super(InpaintingModel, self).__init__()
        self.name = 'InpaintingModel'
        self.config = config
        self.iteration = 0
        self.gen_weights_path = os.path.join(config.save_model_dir, 'InpaintingModel_gen.pth')
        self.dis_weights_path = os.path.join(config.save_model_dir, 'InpaintingModel_dis.pth')

        self.generator = FRRNet()
        self.discriminator = Discriminator(in_channels=3, use_sigmoid=True)

        if torch.cuda.device_count() > 1:
            device_ids = range(torch.cuda.device_count())
            self.generator = nn.DataParallel(self.generator, device_ids)
            self.discriminator = nn.DataParallel(self.discriminator, device_ids)

        self.l1_loss = nn.L1Loss()
        self.l2_loss = nn.MSELoss()
        self.style_loss = StyleLoss()
        self.adversarial_loss = AdversarialLoss()  

        self.gen_optimizer = optim.Adam(
            params=self.generator.parameters(),
            lr=float(config.LR),
            betas=(0.0, 0.9)
        )

        self.dis_optimizer = optim.Adam(
            params=self.discriminator.parameters(),
            lr=float(config.LR) * float(config.D2G_LR),
            betas=(0.0, 0.9)
        )
Example #6
    def __init__(self, args):

        self.z_dim = args.z_dim
        self.decay_rate = args.decay_rate
        self.learning_rate = args.learning_rate
        self.model_name = args.model_name
        self.batch_size = args.batch_size

        #initialize networks
        self.Generator = Generator(self.z_dim).cuda()
        self.Encoder = Encoder(self.z_dim).cuda()
        self.Discriminator = Discriminator().cuda()

        #set optimizers for all networks
        self.optimizer_G_E = torch.optim.Adam(
            list(self.Generator.parameters()) +
            list(self.Encoder.parameters()),
            lr=self.learning_rate,
            betas=(0.5, 0.999))

        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=self.learning_rate,
                                            betas=(0.5, 0.999))

        #initialize network weights
        self.Generator.apply(weights_init)
        self.Encoder.apply(weights_init)
        self.Discriminator.apply(weights_init)
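weights_init is applied above but not defined in this excerpt; a minimal sketch, assuming the common DCGAN-style initializer:

import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # zero-mean normal init for (transposed) convolutions
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # unit-mean normal for batch-norm scale, zero bias
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)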
Example #7
    def __init__(self, config):
        super().__init__('EdgeModel', config)

        # generator input: [rgb(3) + edge(1)]
        # discriminator input: (rgb(3) + edge(1))
        generator = EdgeGenerator(use_spectral_norm=True)
        discriminator = Discriminator(
            in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')  #4-->2

        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)

        l1_loss = nn.L1Loss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(params=generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
Example #8
    def __init__(self, batch_size=4, input_channels=3, use_multiple_gpu=False,
                       learning_rate=1e-4,
                       model_path='model', device='cuda:0', mode='train', train_dataset_dir='data_scene_flow/training', 
                       val_dataset_dir='data_scene_flow/testing', num_workers=4, do_augmentation=True,
                       output_directory='outputs',
                       input_height=256, input_width=512, augment_parameters=[0.8, 1.2, 0.5, 2.0, 0.8, 1.2]):

        self.batch_size = batch_size
        self.input_channels = input_channels
        self.model_path = model_path
        self.device = device
        self.use_multiple_gpu = use_multiple_gpu

        self.g_LL = Resnet50_md(self.input_channels).to(self.device)
        self.d_R = Discriminator(self.input_channels).to(self.device)

        if self.use_multiple_gpu:
            self.g_LL = torch.nn.DataParallel(self.g_LL)
            self.d_R = torch.nn.DataParallel(self.d_R)

        self.learning_rate = learning_rate
        self.mode = mode
        self.input_height = input_height
        self.input_width = input_width
        self.augment_parameters = augment_parameters
        self.train_dataset_dir = train_dataset_dir
        self.val_dataset_dir = val_dataset_dir
        self.g_best_val_loss = float('inf')
        self.num_workers = num_workers
        self.do_augmentation = do_augmentation

        if self.mode == 'train':
            self.criterion_GAN = HcGANLoss().to(self.device)
            self.criterion_mono = MonodepthLoss()
            
            self.optimizer = optim.Adam(
                self.g_LL.parameters(),
                lr=self.learning_rate
            )
            self.val_n_img, self.val_loader = prepare_dataloader(self.val_dataset_dir, self.mode, self.augment_parameters,
                                                False, self.batch_size,
                                                (self.input_height, self.input_width),
                                                self.num_workers, shuffle=False)
        else:
            self.augment_parameters = None
            self.do_augmentation = False

        self.n_img, self.loader = prepare_dataloader(self.train_dataset_dir, self.mode,
                                                    self.augment_parameters,
                                                    self.do_augmentation, self.batch_size,
                                                    (self.input_height, self.input_width),
                                                    self.num_workers)
        self.output_directory = output_directory
        if 'cuda' in self.device:
            torch.cuda.synchronize()
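prepare_dataloader is external to this excerpt; a sketch matching only its call signature and (count, loader) return convention — the StereoDataset class and its arguments are hypothetical placeholders:

from torch.utils.data import DataLoader

def prepare_dataloader(data_dir, mode, augment_parameters, do_augmentation,
                       batch_size, size, num_workers, shuffle=True):
    # StereoDataset stands in for whatever Dataset class the project defines
    dataset = StereoDataset(data_dir, mode, size,
                            augment_parameters if do_augmentation else None)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        num_workers=num_workers, pin_memory=True)
    return len(dataset), loader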
Example #9
    def build_model(self):
        self.G = AdaINGen(self.label_dim, self.gen_var)
        self.D = Discriminator(self.label_dim, self.dis_var)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.lr,
                                            [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.lr,
                                            [self.beta1, self.beta2])

        self.G.to(self.device)
        self.D.to(self.device)
Example #10
    def __init__(self, hyperparameters):
        super(Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        # auto-encoder for domain a
        self.trait_dim = hyperparameters['gen']['trait_dim']

        self.gen_a = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # auto-encoder for domain b
        self.gen_b = VAEGen(hyperparameters['input_dim'],
                            hyperparameters['basis_encoder_dims'],
                            hyperparameters['trait_encoder_dims'],
                            hyperparameters['decoder_dims'], self.trait_dim)
        # discriminator for domain a
        self.dis_a = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)
        # discriminator for domain b
        self.dis_b = Discriminator(hyperparameters['input_dim'],
                                   hyperparameters['dis_dims'], 1)

        # fix the noise used in sampling
        self.trait_a = torch.randn(8, self.trait_dim, 1, 1)
        self.trait_b = torch.randn(8, self.trait_dim, 1, 1)

        # Setup the optimizers
        dis_params = list(self.dis_a.parameters()) + \
            list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + \
            list(self.gen_b.parameters())
        for _p in gen_params:
            print(_p.data.shape)
        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.gen_a.apply(weights_init('gaussian'))
        self.gen_b.apply(weights_init('gaussian'))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))
Example #11
def train():
    generator = Generator()
    discriminator = Discriminator()
    generator.to("cuda:0")
    discriminator.to("cuda:0")
    Opt_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    Opt_G = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    data = np.concatenate((sio.loadmat("D:/cifar10/data_batch_1.mat")["data"],
                           sio.loadmat("D:/cifar10/data_batch_2.mat")["data"],
                           sio.loadmat("D:/cifar10/data_batch_3.mat")["data"],
                           sio.loadmat("D:/cifar10/data_batch_4.mat")["data"],
                           sio.loadmat("D:/cifar10/data_batch_5.mat")["data"]))
    nums = data.shape[0]

    for i in range(100000):
        rand_idx = np.random.choice(range(nums), batchsize)
        batch = np.reshape(data[rand_idx], [batchsize, 3, 32, 32])
        batch = torch.tensor(batch / 127.5 - 1,
                             dtype=torch.float32).to("cuda:0")
        for j in range(n_cri):
            z = torch.randn(batchsize, 128).to("cuda:0")
            fake_img = generator(z).detach()
            fake_logits = discriminator(fake_img)
            real_logits = discriminator(batch)
            D_loss = torch.mean(
                torch.max(torch.zeros_like(real_logits),
                          1. - real_logits)) + torch.mean(
                              torch.max(torch.zeros_like(fake_logits),
                                        1. + fake_logits))
            Opt_D.zero_grad()
            D_loss.backward()
            Opt_D.step()
        z = torch.randn(batchsize, 128).to("cuda:0")
        fake_img = generator(z)
        fake_logits = discriminator(fake_img)
        G_loss = -torch.mean(fake_logits)
        Opt_G.zero_grad()
        G_loss.backward()
        Opt_G.step()
        if i % 100 == 0:
            img = (fake_img[0] + 1) * 127.5
            Image.fromarray(
                np.uint8(
                    np.transpose(img.cpu().detach().numpy(),
                                 axes=[1, 2, 0]))).save("./results/" + str(i) +
                                                        ".jpg")
            print("Iteration: %d, D_loss: %f, G_loss: %f" %
                  (i, D_loss, G_loss))
        if i % 1000 == 0:
            torch.save(generator, "generator.pth")
            torch.save(discriminator, "discriminator.pth")
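The two torch.max(torch.zeros_like(...), ...) terms above implement the standard hinge loss written out by hand; an equivalent, more compact formulation, for reference:

import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    # identical to max(0, 1 - real) + max(0, 1 + fake), averaged over the batch
    return F.relu(1. - real_logits).mean() + F.relu(1. + fake_logits).mean()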
Example #12
    def __init__(self, hyperparameters):
        super(STGANtrainer, self).__init__()
        self.hyperparameters = hyperparameters
        self.gen = Generator(5, 5, 4, 13)
        self.dis = Discriminator(5, 64, 13)

        self.dis_opt = Adam(self.dis.parameters(),
                            lr=self.hyperparameters['lr_dis'],
                            betas=self.hyperparameters['lr_beta'])
        self.gen_opt = Adam(self.gen.parameters(),
                            lr=self.hyperparameters['lr_gen'],
                            betas=self.hyperparameters['lr_beta'])
        self.dis_attr_opt = Adam(self.dis.parameters(),
                                 lr=0.5 * self.hyperparameters['lr_gen'],
                                 betas=self.hyperparameters['lr_beta'])
Example #13
def main():
    opt = get_opt()
    #     opt.cuda = False
    #     opt.batch_size = 1
    #     opt.name = "test"
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)
    board = SummaryWriter(log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    G = WUTON(opt)
    D = Discriminator()
    if opt.checkpoint != '' and os.path.exists(opt.checkpoint):  # TODO
        load_checkpoint(G, opt.checkpoint)
    train(opt, train_loader, G, D, board)
    # train2(opt, train_loader, G, board)
    save_checkpoint(
        G, os.path.join(opt.checkpoint_dir, opt.name, 'wuton_final.pth'))

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #14
def init():
    for dirname in os.listdir(state_dict_path):
        discriminator = Discriminator()
        discriminator.load_state_dict(
            torch.load(state_dict_path + '/' + dirname + '/discriminator'))
        generator = Generator()
        generator.load_state_dict(
            torch.load(state_dict_path + '/' + dirname + '/generator'))
        network = {
            'id': dirname,
            'discriminator': discriminator,
            'generator': generator,
        }
        networks.append(network)
        network_dict[dirname] = network
        print('Model #' + dirname + ' loaded.')
Example #15
    def __init__(self, generator_path, discriminator_path):

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Path for loading and saving model weights
        self.generator_path = generator_path
        self.discriminator_path = discriminator_path

        # Get training and testing data
        train_set = PegasusDataset('data',
                                   train=True,
                                   download=True,
                                   transform=torchvision.transforms.Compose(
                                       [torchvision.transforms.ToTensor()]))

        test_set = PegasusDataset('data',
                                  train=False,
                                  download=True,
                                  transform=torchvision.transforms.Compose(
                                      [torchvision.transforms.ToTensor()]))

        self.train_loader = torch.utils.data.DataLoader(train_set,
                                                        shuffle=True,
                                                        batch_size=BATCH_SIZE,
                                                        drop_last=True)
        self.test_loader = torch.utils.data.DataLoader(test_set,
                                                       shuffle=True,
                                                       batch_size=BATCH_SIZE,
                                                       drop_last=True)

        # Create generator and discriminator
        self.G = Generator().to(self.device)
        self.D = Discriminator().to(self.device)

        #self.loadModels(self.generator_path, self.discriminator_path)

        # initialise the optimiser
        self.optimiser_G = torch.optim.Adam(self.G.parameters(),
                                            lr=0.0002,
                                            betas=(0.5, 0.99))
        self.optimiser_D = torch.optim.Adam(self.D.parameters(),
                                            lr=0.0002,
                                            betas=(0.5, 0.99))
        self.bce_loss = nn.BCELoss()

        self.dgRatio = 0
Example #16
    def create_discriminator(self):
        kernels_dis = [
            (64, 2, 0),  # [batch, 32, 32, ch] => [batch, 16, 16, 64]
            (128, 2, 0),  # [batch, 16, 16, 64] => [batch, 8, 8, 128]
            (256, 2, 0),  # [batch, 8, 8, 128] => [batch, 4, 4, 256]
            (512, 1, 0),  # [batch, 4, 4, 256] => [batch, 4, 4, 512]
        ]

        return Discriminator('dis', kernels_dis)
Example #17
 def build_model(self):
     """ DataLoader """
     train_transform = transforms.Compose([
         transforms.RandomHorizontalFlip(),
         transforms.Resize((self.img_size + 30, self.img_size + 30)),
         transforms.RandomCrop(self.img_size),
         transforms.ToTensor(),
         transforms.Normalize(mean=0.5, std=0.5)
     ])
     test_transform = transforms.Compose([
         transforms.Resize((self.img_size, self.img_size)),
         transforms.ToTensor(),
         transforms.Normalize(mean=0.5, std=0.5)
     ])
     self.trainA_loader = paddle.batch(
         a_reader(shuffle=True, transforms=train_transform),
         self.batch_size)()
     self.trainB_loader = paddle.batch(
         b_reader(shuffle=True, transforms=train_transform),
         self.batch_size)()
     self.testA_loader = a_test_reader(transforms=test_transform)
     self.testB_loader = b_test_reader(transforms=test_transform)
     """ Define Generator, Discriminator """
     self.genA2B = ResnetGenerator(input_nc=3,
                                   output_nc=3,
                                   ngf=self.ch,
                                   n_blocks=self.n_res,
                                   img_size=self.img_size,
                                   light=self.light)
     self.genB2A = ResnetGenerator(input_nc=3,
                                   output_nc=3,
                                   ngf=self.ch,
                                   n_blocks=self.n_res,
                                   img_size=self.img_size,
                                   light=self.light)
     self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
     self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
     self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
     self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
     """ Define Loss """
     self.L1_loss = L1Loss()
     self.MSE_loss = MSELoss()
     self.BCE_loss = BCEWithLogitsLoss()
     """ Trainer """
     self.G_optim = self.optimizer_setting(self.genA2B.parameters() +
                                           self.genB2A.parameters())
     self.D_optim = self.optimizer_setting(self.disGA.parameters() +
                                           self.disGB.parameters() +
                                           self.disLA.parameters() +
                                           self.disLB.parameters())
     """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
     self.Rho_clipper = RhoClipper(0, 1)
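optimizer_setting is defined elsewhere in that project; a sketch in paddle's fluid 1.x dygraph style, mirroring the Adam settings of the PyTorch UGATIT examples later in this collection (the self.lr attribute name and beta values are assumptions):

import paddle.fluid as fluid

def optimizer_setting(self, parameters):
    # one Adam instance per network group, as in the PyTorch UGATIT trainers
    return fluid.optimizer.Adam(learning_rate=self.lr,
                                beta1=0.5, beta2=0.999,
                                parameter_list=parameters)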
Example #18
File: model.py Project: xlnwel/cv
    def _build_graph(self):
        with tf.device('/cpu:0'):
            self.image = self._prepare_data()
        gen_args = self.args['generator']
        gen_args['batch_size'] = self.batch_size
        self.generator = Generator('Generator', 
                                    gen_args, 
                                    self.graph, 
                                    self.training,
                                    scope_prefix=self.name,
                                    log_tensorboard=self.log_tensorboard,
                                    log_params=self.log_params)
        self.gen_image = self.generator.image
        dis_args = self.args['discriminator']
        self.real_discriminator = Discriminator('Discriminator', 
                                                dis_args, 
                                                self.graph, 
                                                self.image,
                                                False,
                                                self.training,
                                                scope_prefix=self.name,
                                                log_tensorboard=self.log_tensorboard,
                                                log_params=self.log_params)
        self.fake_discriminator = Discriminator('Discriminator',
                                                dis_args,
                                                self.graph,
                                                self.gen_image,
                                                False,
                                                self.training,
                                                scope_prefix=self.name,
                                                log_tensorboard=False,
                                                log_params=False,
                                                reuse=True)
        
        self.gen_loss = self._generator_loss()
        self.dis_loss = self._discriminator_loss()

        self.gen_opt_op, _, _ = self.generator._optimization_op(self.gen_loss)
        self.dis_opt_op, _, _ = self.real_discriminator._optimization_op(self.dis_loss)
        
        with tf.device('/cpu:0'):
            self._log_train_info()
Example #19
    def __init__(self, hp, class_emb_vis, class_emb_all):
        super(OurModel, self).__init__()
        self.hp = hp

        self.Em_vis = nn.Embedding.from_pretrained(class_emb_vis).cuda()
        self.Em_vis.weight.requires_grad = False
        self.Em_all = nn.Embedding.from_pretrained(class_emb_all).cuda()
        self.Em_all.weight.requires_grad = False

        self.prior = np.ones((hp['dis']['out_dim_cls'] - 1))
        for k in range(hp['dis']['out_dim_cls'] - hp['num_unseen'] - 1,
                       hp['dis']['out_dim_cls'] - 1):
            self.prior[k] = self.prior[k] + hp['gen_unseen_rate']
        self.prior_ = self.prior / np.linalg.norm(self.prior, ord=1)

        self.gen = Generator(hp['gen'])
        self.dis = Discriminator(hp['dis'])
        self.back = DeepLabV2_ResNet101_local_MSC(hp['back'])

        self.discLoss, self.contentLoss, self.clsLoss = init_loss(hp)
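The normalized prior_ above suggests class-label sampling for the generator, with unseen classes up-weighted by gen_unseen_rate; a hypothetical one-liner showing how such a prior is typically consumed (not from the source):

import numpy as np

prior_ = np.array([0.2, 0.2, 0.6])  # illustrative distribution only
label = np.random.choice(len(prior_), p=prior_)  # sample a class index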
Example #20
    def create_discriminator(self):
        kernels_dis = [
            (64, 2, 0),  # [batch, 256, 256, ch] => [batch, 128, 128, 64]
            (128, 2, 0),  # [batch, 128, 128, 64] => [batch, 64, 64, 128]
            (256, 2, 0),  # [batch, 64, 64, 128] => [batch, 32, 32, 256]
            (512, 1, 0),  # [batch, 32, 32, 256] => [batch, 32, 32, 512]
        ]

        return Discriminator('dis',
                             kernels_dis,
                             training=self.options.training)
Example #21
    def __init__(self, config):
        super().__init__('SRModel', config)

        # generator input: [gray(1) + edge(1)]
        # discriminator input: [gray(1)]
        generator = SRGenerator()
        discriminator = Discriminator(
            in_channels=1, use_sigmoid=config.GAN_LOSS != 'hinge')  # 3-->1

        if len(config.GPU) > 1:
            generator = nn.DataParallel(generator, config.GPU)
            discriminator = nn.DataParallel(discriminator, config.GPU)

        l1_loss = nn.L1Loss()
        content_loss = ContentLoss()
        style_loss = StyleLoss()
        adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)

        kernel = np.zeros((self.config.SCALE, self.config.SCALE))
        kernel[0, 0] = 1
        #kernel_weight = torch.tensor(np.tile(kernel, (3, 1, 1, 1))).float().to(config.DEVICE)     # (out_channels, in_channels/groups, height, width)

        #self.add_module('scale_kernel', kernel_weight)
        #self.scale_kernel = torch.tensor(np.tile(kernel, (1, 1, 1, 1))).float().to(config.DEVICE)  #3-->1

        self.add_module('generator', generator)
        self.add_module('discriminator', discriminator)

        self.add_module('l1_loss', l1_loss)
        self.add_module('content_loss', content_loss)
        self.add_module('style_loss', style_loss)
        self.add_module('adversarial_loss', adversarial_loss)

        self.gen_optimizer = optim.Adam(params=generator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))

        self.dis_optimizer = optim.Adam(params=discriminator.parameters(),
                                        lr=float(config.LR),
                                        betas=(config.BETA1, config.BETA2))
Example #22
def train(path):
    imgs = get_training_imgs(path)
    nums = len(imgs)
    Gs = []
    Ds = []
    fixed_Zs = []
    sigmas = []
    ch = 16
    for i in range(nums):
        if i % 4 == 0:
            ch = ch * 2
        G = Generator(ch)
        D = Discriminator(ch)
        G.to("cuda:0")
        D.to("cuda:0")
        if i > 0:
            try:
                G.load_state_dict(G_.state_dict())
                D.load_state_dict(D_.state_dict())
                del G_, D_
            except:
                pass
        Gs.append(G)
        Ds.append(D)
        print(".............Total Scale: %d, current scale: %d............."%(nums, i+1))
        G_, D_ = train_single_scale(Gs, Ds, imgs[:i+1], sigmas, fixed_Zs)
    state_dict = {}
    state_dict["Gs"] = Gs
    state_dict["sigmas"] = sigmas
    state_dict["imgs"] = imgs
    torch.save(state_dict, path[:-3]+"pth")
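torch.save here pickles the whole Generator modules in Gs, which ties the checkpoint to the exact source layout; a more portable variant under the same data layout might save state_dicts instead:

state_dict = {
    "Gs": [g.state_dict() for g in Gs],  # weights only, no pickled classes
    "sigmas": sigmas,
    "imgs": imgs,
}
torch.save(state_dict, path[:-3] + "pth")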
Example #23
    def __init__(self, dataset_dir, generator_channels, discriminator_channels, nz, style_depth, lrs, betas, eps,
                 phase_iter, weights_halflife, batch_size, n_cpu, opt_level):
        self.nz = nz
        self.dataloader = Dataloader(dataset_dir, batch_size, phase_iter * 2, n_cpu)

        self.generator = Generator(generator_channels, nz, style_depth).cuda()
        self.generator_ema = Generator(generator_channels, nz, style_depth).cuda()
        self.generator_ema.load_state_dict(copy.deepcopy(self.generator.state_dict()))
        self.discriminator = Discriminator(discriminator_channels).cuda()

        self.tb = tensorboard.tf_recorder('StyleGAN')

        self.phase_iter = phase_iter
        self.lrs = lrs
        self.betas = betas
        self.weights_halflife = weights_halflife

        self.opt_level = opt_level

        self.ema = None

        torch.backends.cudnn.benchmark = True
Example #24
    def build_model(self):
        """ DataLoader """
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size+30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
        self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
        self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
        self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
        self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
        self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)

        """ Define Generator, Discriminator """
        self.genA2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                      n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.genB2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=self.ch,
                                      n_blocks=self.n_res, img_size=self.img_size, light=self.light).to(self.device)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

        """ Define Loss """
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        """ Trainer """
        self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
                                        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)
        self.D_optim = torch.optim.Adam(itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                                        self.disLA.parameters(), self.disLB.parameters()),
                                        lr=self.lr, betas=(0.5, 0.999), weight_decay=self.weight_decay)

        """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
        self.Rho_clipper = RhoClipper(0, 1)
Example #25
    def __init__(self, args, max_steps):
        super().__init__()
        self.hparams = args
        self.max_steps = max_steps
        self.save_hyperparameters(args)

        # Define Generator, Discriminator
        self.genA2B = ResnetGenerator(img_size=self.hparams.img_size)
        self.genB2A = ResnetGenerator(img_size=self.hparams.img_size)
        self.disGA = Discriminator(n_layers=7)
        self.disGB = Discriminator(n_layers=7)
        self.disLA = Discriminator(n_layers=5)
        self.disLB = Discriminator(n_layers=5)

        # Define Loss
        self.L1_loss = nn.L1Loss()
        self.MSE_loss = nn.MSELoss()
        self.BCE_loss = nn.BCEWithLogitsLoss()

        # Define Rho clipper to constraint the value of rho in AdaILN and ILN
        self.Rho_clipper = RhoClipper(0, 1)
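This LightningModule defines its networks and losses but no optimizers; in Lightning those belong in configure_optimizers. A sketch consistent with the Adam settings of the preceding UGATIT example (self.hparams.lr and self.hparams.weight_decay are assumed hyperparameter names; assumes import itertools and torch):

    def configure_optimizers(self):
        G_optim = torch.optim.Adam(
            itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
            lr=self.hparams.lr, betas=(0.5, 0.999),
            weight_decay=self.hparams.weight_decay)
        D_optim = torch.optim.Adam(
            itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                            self.disLA.parameters(), self.disLB.parameters()),
            lr=self.hparams.lr, betas=(0.5, 0.999),
            weight_decay=self.hparams.weight_decay)
        return G_optim, D_optim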
Example #26
    def __init__(self, dataset_dir, log_dir, generator_channels,
                 discriminator_channels, nz, style_depth, lrs, betas, eps,
                 phase_iter, batch_size, n_cpu, opt_level):
        self.nz = nz
        self.dataloader = Dataloader(dataset_dir, batch_size, phase_iter * 2,
                                     n_cpu)

        self.generator = cuda(
            DataParallel(Generator(generator_channels, nz, style_depth)))
        self.discriminator = cuda(
            DataParallel(Discriminator(discriminator_channels)))

        self.tb = tensorboard.tf_recorder('StyleGAN', log_dir)

        self.phase_iter = phase_iter
        self.lrs = lrs
        self.betas = betas

        self.opt_level = opt_level
Example #27
 def build_model(self):
     '''DataLoader'''
     gl._init()
     gl.set_value('rho', 0)
     l2 = fluid.regularizer.L2Decay(self.weight_decay)
     self.train_reader, self.test_reader = reader(self.batch_size)
     self.genA2B = ResnetGenerator(in_channels=3,
                                   out_channels=3,
                                   ngf=self.ch,
                                   n_blocks=self.n_res,
                                   img_size=self.img_size,
                                   light=self.light)
     self.genB2A = ResnetGenerator(in_channels=3,
                                   out_channels=3,
                                   ngf=self.ch,
                                   n_blocks=self.n_res,
                                   img_size=self.img_size,
                                   light=self.light)
     self.disGA = Discriminator(in_channels=3, ndf=self.ch, n_layers=7)
     self.disGB = Discriminator(in_channels=3, ndf=self.ch, n_layers=7)
     self.disLA = Discriminator(in_channels=3, ndf=self.ch, n_layers=5)
     self.disLB = Discriminator(in_channels=3, ndf=self.ch, n_layers=5)
     self.clip = fluid.clip.GradientClipByValue(1,
                                                0,
                                                need_clip=self.fileter_func)
     self.G_opt = fluid.optimizer.Adam(
         learning_rate=self.lr1,
         beta1=0.5,
         beta2=0.999,
         regularization=l2,
         parameter_list=self.genA2B.parameters() + self.genB2A.parameters())
     self.D_opt = fluid.optimizer.Adam(
         learning_rate=self.lr2,
         beta1=0.5,
         beta2=0.999,
         regularization=l2,
         parameter_list=self.disGA.parameters() + self.disGB.parameters() +
         self.disLA.parameters() + self.disLB.parameters())
     self.L1loss = fluid.dygraph.L1Loss()
     self.BCELoss = fluid.dygraph.BCELoss()
Example #28
    def __init__(self, batch_size, noise_dim=100, version="CGAN"):
        super(CGAN, self).__init__(batch_size, version)

        self.noise_dim = noise_dim

        self.data = datamanager('CT', train_ratio=0.8, expand_dim=3, seed=0)
        self.data_test = self.data(self.batch_size,
                                   'test',
                                   var_list=['data', 'labels'])
        self.class_num = self.data.class_num

        self.Generator = Generator(output_dim=1, name='G')
        self.Discriminator = Discriminator(name='D')

        self.build_placeholder()
        self.build_gan()
        self.build_optimizer()
        self.build_summary()

        self.build_sess()
        self.build_dirs()
Example #29
class Trainer:
    def __init__(self, arg, device, device_ids):
        print("\ninitializing trainer ...\n")
        # network architecture
        self.nc = arg.nc
        self.nz = arg.nz
        self.init_size = arg.init_size
        self.size = arg.size
        # training
        self.batch_size = arg.batch_size
        self.unit_epoch = arg.unit_epoch
        self.lambda_gp  = arg.lambda_gp
        self.lambda_drift = arg.lambda_drift
        self.num_aug = arg.num_aug
        self.lr = arg.lr
        self.outf = arg.outf
        self.device = device
        self.device_ids = device_ids
        self.writer = SummaryWriter(self.outf)
        self.init_trainer()
        print("done\n")
    def init_trainer(self):
        # networks
        self.G = Generator(nc=self.nc, nz=self.nz, size=self.size)
        self.D = Discriminator(nc=self.nc, nz=self.nz, size=self.size)
        self.G_EMA = copy.deepcopy(self.G)
        # move to GPU
        self.G = nn.DataParallel(self.G, device_ids=self.device_ids).to(self.device)
        self.D = nn.DataParallel(self.D, device_ids=self.device_ids).to(self.device)
        self.G_EMA = self.G_EMA.to('cpu') # keep this model on CPU to save GPU memory
        for param in self.G_EMA.parameters():
            param.requires_grad_(False) # turn off grad because G_EMA will only be used for inference
        # optimizers
        self.opt_G = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0,0.99), eps=1e-8, weight_decay=0.)
        self.opt_D = optim.Adam(self.D.parameters(), lr=self.lr, betas=(0,0.99), eps=1e-8, weight_decay=0.)
        # data loader
        self.transform = transforms.Compose([
            RatioCenterCrop(1.),
            transforms.Resize((300,300), Image.ANTIALIAS),
            transforms.RandomCrop((self.size,self.size)),
            RandomRotate(),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        self.dataset = ISIC_GAN('train_gan.csv', transform=self.transform)
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size,
            shuffle=True, num_workers=8, worker_init_fn=_worker_init_fn_(), drop_last=True)
        # tickers (used for fading in)
        self.tickers = self.unit_epoch * self.num_aug * len(self.dataloader)
    def update_trainer(self, stage, inter_ticker):
        if stage == 1:
            current_alpha = 0
        else:
            total_stages = int(math.log2(self.size/self.init_size)) + 1
            assert stage <= total_stages, 'Invalid stage number!'
            flag_opt = False
            delta = 1. / self.tickers
            if inter_ticker == 0:
                self.G.module.grow_network()
                self.D.module.grow_network()
                self.G_EMA.grow_network()
                flag_opt = True
            elif (inter_ticker > 0) and (inter_ticker < self.tickers):
                self.G.module.model.fadein.update_alpha(delta)
                self.D.module.model.fadein.update_alpha(delta)
                self.G_EMA.model.fadein.update_alpha(delta)
                flag_opt = False
            elif inter_ticker == self.tickers:
                self.G.module.flush_network()
                self.D.module.flush_network()
                self.G_EMA.flush_network()
                flag_opt = True
            else:
                flag_opt = False
            # archive alpha (the fadein layer no longer exists after flush_network)
            try:
                current_alpha = self.G.module.model.fadein.get_alpha()
            except AttributeError:
                current_alpha = 1
            # move to device & update optimizer
            if flag_opt:
                self.G.to(self.device)
                self.D.to(self.device)
                self.G_EMA.to('cpu')
                # opt_G
                opt_G_state_dict = self.opt_G.state_dict()
                old_opt_G_state = opt_G_state_dict['state']
                self.opt_G = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0,0.99), eps=1e-8, weight_decay=0.)
                new_opt_G_param_id = self.opt_G.state_dict()['param_groups'][0]['params']
                opt_G_state = copy.deepcopy(old_opt_G_state)
                for key in old_opt_G_state.keys():
                    if key not in new_opt_G_param_id:
                        del opt_G_state[key]
                opt_G_state_dict['param_groups'] = self.opt_G.state_dict()['param_groups']
                opt_G_state_dict['state'] = opt_G_state
                self.opt_G.load_state_dict(opt_G_state_dict)
                # opt_D
                opt_D_state_dict = self.opt_D.state_dict()
                old_opt_D_state = opt_D_state_dict['state']
                self.opt_D = optim.Adam(self.D.parameters(), lr=self.lr, betas=(0,0.99), eps=1e-8, weight_decay=0.)
                new_opt_D_param_id = self.opt_D.state_dict()['param_groups'][0]['params']
                opt_D_state = copy.deepcopy(old_opt_D_state)
                for key in old_opt_D_state.keys():
                    if key not in new_opt_D_param_id:
                        del opt_D_state[key]
                opt_D_state_dict['param_groups'] = self.opt_D.state_dict()['param_groups']
                opt_D_state_dict['state'] = opt_D_state
                self.opt_D.load_state_dict(opt_D_state_dict)
        return current_alpha
    def update_moving_average(self, decay=0.999):
        # update exponential running average (EMA) for the weights of the generator
        # W_EMA_t = decay * W_EMA_{t-1} + (1-decay) * W_G
        with torch.no_grad():
            param_dict_G = dict(self.G.module.named_parameters())
            for name, param_EMA in self.G_EMA.named_parameters():
                param_G = param_dict_G[name]
                assert (param_G is not param_EMA)
                param_EMA.copy_(decay * param_EMA + (1. - decay) * param_G.detach().cpu())
    def update_network(self, real_data):
        # switch to training mode
        self.G.train(); self.D.train()
        ##########
        ## Train Discriminator
        ##########
        # clear grad cache
        self.D.zero_grad()
        self.opt_D.zero_grad()
        # D loss - real data
        pred_real = self.D(real_data)
        loss_real = pred_real.mean().mul(-1.)
        loss_real_drift = pred_real.pow(2.).mean()
        # D loss - fake data
        z = torch.FloatTensor(real_data.size(0), self.nz).normal_(0.0, 1.0).to(self.device)
        fake_data = self.G(z)
        pred_fake = self.D(fake_data.detach())
        loss_fake = pred_fake.mean()
        # D loss - gradient penalty
        gp = self.gradient_penalty(real_data, fake_data)
        # update D
        D_loss = loss_real + loss_fake + self.lambda_drift * loss_real_drift + self.lambda_gp * gp
        W_dist = loss_real.item() + loss_fake.item()
        D_loss.backward()
        self.opt_D.step()
        ##########
        ## Train Generator
        ##########
        # clear grad cache
        self.G.zero_grad()
        self.opt_G.zero_grad()
        # G loss
        z = torch.FloatTensor(real_data.size(0), self.nz).normal_(0.0, 1.0).to(self.device)
        fake_data = self.G(z)
        pred_fake = self.D(fake_data)
        # update G
        G_loss = pred_fake.mean().mul(-1.)
        G_loss.backward()
        self.opt_G.step()
        return [G_loss.item(), D_loss.item(), W_dist]
    def gradient_penalty(self, real_data, fake_data):
        alpha = torch.rand(real_data.size(0),1,1,1).to(self.device)
        interpolates = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
        interpolates.requires_grad_(True)
        disc_interpolates = self.D(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates).to(self.device), create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = gradients.norm(2, dim=1).sub(1.).pow(2.).mean()
        return gradient_penalty
    def train(self):
        global_step = 0
        global_epoch = 0
        total_stages = int(math.log2(self.size/self.init_size)) + 1
        fixed_z = torch.FloatTensor(self.batch_size, self.nz).normal_(0.0, 1.0).to('cpu')
        for stage in range(1, total_stages+1):
            if stage == 1:
                M = self.unit_epoch
            elif stage <= 4:
                M = self.unit_epoch * 2
            else:
                M = self.unit_epoch * 3
            current_size = self.init_size * (2 ** (stage-1))
            ticker = 0
            for epoch in range(M):
                torch.cuda.empty_cache()
                for aug in range(self.num_aug):
                    for i, data in enumerate(self.dataloader, 0):
                        current_alpha = self.update_trainer(stage, ticker)
                        self.writer.add_scalar('archive/current_alpha', current_alpha, global_step)
                        real_data_current = data
                        real_data_current = F.adaptive_avg_pool2d(real_data_current, current_size)
                        if stage > 1 and current_alpha < 1:
                            real_data_previous = F.interpolate(F.avg_pool2d(real_data_current, 2), scale_factor=2., mode='nearest')
                            real_data = (1 - current_alpha) * real_data_previous + current_alpha * real_data_current
                        else:
                            real_data = real_data_current
                        real_data = real_data.mul(2.).sub(1.) # [0,1] --> [-1,1]
                        real_data = real_data.to(self.device)
                        G_loss, D_loss, W_dist = self.update_network(real_data)
                        self.update_moving_average()
                        if i % 10 == 0:
                            self.writer.add_scalar('train/G_loss', G_loss, global_step)
                            self.writer.add_scalar('train/D_loss', D_loss, global_step)
                            self.writer.add_scalar('train/W_dist', W_dist, global_step)
                            print("[stage {}/{}][epoch {}/{}][aug {}/{}][iter {}/{}] G_loss {:.4f} D_loss {:.4f} W_Dist {:.4f}" \
                                .format(stage, total_stages, epoch+1, M, aug+1, self.num_aug, i+1, len(self.dataloader), G_loss, D_loss, W_dist))
                        global_step += 1
                        ticker += 1
                global_epoch += 1
                if epoch % 10 == 9:
                    # log image
                    print('\nlog images...\n')
                    I_real = utils.make_grid(real_data, nrow=4, normalize=True, scale_each=True)
                    self.writer.add_image('stage_{}/real'.format(stage), I_real, epoch)
                    with torch.no_grad():
                        self.G_EMA.eval()
                        fake_data = self.G_EMA(fixed_z)
                        I_fake = utils.make_grid(fake_data, nrow=4, normalize=True, scale_each=True)
                        self.writer.add_image('stage_{}/fake'.format(stage), I_fake, epoch)
                    # save checkpoints
                    print('\nsaving checkpoints...\n')
                    checkpoint = {
                        'G_state_dict': self.G.module.state_dict(),
                        'G_EMA_state_dict': self.G_EMA.state_dict(),
                        'D_state_dict': self.D.module.state_dict(),
                        'opt_G_state_dict': self.opt_G.state_dict(),
                        'opt_D_state_dict': self.opt_D.state_dict(),
                        'stage': stage
                    }
                    torch.save(checkpoint, os.path.join(self.outf,'stage{}.tar'.format(stage))) # overwrite if exist
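For completeness, a hypothetical driver for the Trainer above; every value in arg is illustrative only (the attribute names mirror what __init__ reads):

from argparse import Namespace
import torch

arg = Namespace(nc=3, nz=512, init_size=4, size=256,   # network shape
                batch_size=16, unit_epoch=10,          # schedule
                lambda_gp=10., lambda_drift=1e-3,      # WGAN-GP penalties
                num_aug=1, lr=1e-3, outf='./logs')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
trainer = Trainer(arg, device, device_ids=[0])
trainer.train()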