Example #1
0
    def __init__(self, output_dir, data_loader, imsize):
        """Prepare output folders, GPU selection and training constants.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
            data_loader: Batch source; stored unchanged and measured with
                ``len()`` to obtain the number of batches.
            imsize: Not read in the visible part of this initializer.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, leaf)
                for leaf in ('Model', 'Image', 'Log'))
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID holds comma-separated device indices; the first one
        # becomes the current CUDA device.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is scaled by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

        # Receptive-field stride of the current discriminator architecture
        # for the background stage.
        self.patch_stride = [4.0, 8.0]
        # Discriminator output size at the background stage (N x N).
        self.n_out = [24, 23]
        # Receptive field of each member of the N x N output.
        self.recp_field = [34, 74]
        self.crop_size = [126, 250]
Example #2
0
    def __init__(self, output_dir):
        """Prepare output folders, schedule settings and device selection.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
        """
        self.sample_transfer_Stage2Gen = None
        self.sample_transfer_Stage1Gen = None

        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # Device ids always come from the configuration. The CUDA branch used
        # to pin GPU 2 regardless of cfg.GPU_ID and printed a debug marker;
        # both were leftovers from local debugging.
        self.gpus = [int(ix) for ix in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        if cfg.CUDA:
            # With CUDA the configured batch size is scaled by the GPU count.
            self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
            torch.cuda.set_device(self.gpus[0])
            cudnn.benchmark = True
        else:
            self.batch_size = cfg.TRAIN.BATCH_SIZE
            cudnn.benchmark = False
Example #3
0
def _write_scalars(writer, step, scalars):
    """Write every ``(tag, value)`` pair as a scalar summary at ``step``."""
    for tag, value in scalars:
        writer.add_summary(summary=summary.scalar(tag, value),
                           global_step=step)


def main(exercise: str = "Aufgabe-2"):
    """Train the agent described by ``./config.<exercise>.json``.

    The JSON configuration is loaded into nested namedtuples, a gym
    environment and an agent are created, and episodes are run. Every
    ``training_interval`` episodes the agent is trained that many times;
    every tenth such round (when the last loss is truthy) the loss, reward
    statistics and exploration epsilon are written as summaries and printed,
    and the reward buffer is reset.

    Args:
        exercise: Name used to pick the config file and the output folder
            under ``./tmp``.
    """
    with open("./config.{}.json".format(exercise), mode="r") as f:
        args = json.load(f,
                         object_hook=lambda d: namedtuple("X", d.keys())
                         (*d.values()))

    environment = gym.make(args.env)
    try:
        input_shape = environment.observation_space.shape
        num_actions = environment.action_space.n

        output_directory = "./tmp/{}/{}".format(exercise,
                                                datetime.datetime.now())

        writer = FileWriter(output_directory)
        try:
            agent = make_agent(args, input_shape, num_actions,
                               output_directory)

            rewards = []
            for episode in range(args.episodes):
                episode_rewards = run_episode(
                    environment,
                    agent,
                    render=episode % args.render_episode_interval == 0,
                    max_length=args.max_episode_length,
                )
                rewards.append(episode_rewards)
                if episode % args.training_interval != 0:
                    continue

                for _ in range(args.training_interval):
                    loss = agent.train()

                if not (loss
                        and episode % (args.training_interval * 10) == 0):
                    continue

                mean_rewards = np.mean(rewards)
                std_rewards = np.std(rewards)
                epsilon = agent.exploration_strategy.epsilon
                _write_scalars(writer, episode, (
                    ("dqn/loss", loss),
                    ("rewards/mean", mean_rewards),
                    ("rewards/standard deviation", std_rewards),
                    ("dqn/epsilon", epsilon),
                ))
                print("Episode {}\tMean rewards {:f}\tLoss {:f}\tEpsilon {:f}".
                      format(episode, mean_rewards, loss, epsilon))
                rewards.clear()
        finally:
            # Close the writer even on error so buffered events reach disk.
            writer.close()
    finally:
        # Release the environment (and any render window) on every exit path.
        environment.close()
Example #4
0
    def __init__(self, output_dir):
        """Set up folders, devices, text embeddings, networks and plots.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it. It is also handed to the ``Visualizer``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, leaf)
                for leaf in ('Model', 'Image', 'Log'))
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID holds comma-separated device indices; the configured
        # batch size is scaled by how many are listed.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Load the external word vectors (e.g. birds.en.vec) into a frozen
        # 300-dimensional embedding table.
        vec_file = os.path.join(cfg.DATA_DIR, cfg.DATASET_NAME + ".en.vec")
        self.txt_dico, pretrained = load_external_embeddings(vec_file)
        self.txt_emb = nn.Embedding(len(self.txt_dico), 300, sparse=False)
        self.txt_emb.weight.data.copy_(pretrained)
        self.txt_emb.weight.requires_grad = False

        # Networks first, then the evaluator that wraps them.
        self.networks = self.load_network()
        self.evaluator = Evaluator(self.networks, self.txt_emb)

        # Windows on the visdom server: images, captions, then loss curves.
        self.vis = Visualizer(cfg.VISDOM_HOST, cfg.VISDOM_PORT, output_dir)
        for window in ("real_im", "fake_im"):
            self.vis.make_img_window(window)
        for window in ("real_captions", "genr_captions"):
            self.vis.make_txt_window(window)

        # Each plot gets one curve per legend entry.
        plot_legends = (
            ("G_loss", ["errG", "uncond", "cond", "latent", "cycltxt",
                        "autoimg", "autotxt"]),
            ("D_loss", ["errD", "uncond", "cond", "latent"]),
            ("KL_loss", ["kl", "img", "txt", "fakeimg"]),
            ("inception_score", ["real", "fake"]),
        )
        for title, legend in plot_legends:
            self.vis.make_plot_window(title, num=len(legend), legend=legend)
        self.vis.make_plot_window("r_precision", num=1)
Example #5
0
    def __init__(self, output_dir):
        """Create the output folders (when training) and select the GPU.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            subdirs = [os.path.join(output_dir, name)
                       for name in ('Model', 'Image', 'Log')]
            self.model_dir, self.image_dir, self.log_dir = subdirs
            for folder in subdirs:
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID holds comma-separated device indices; the first one
        # becomes the current CUDA device.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
Example #6
0
    def __init__(self, output_dir):
        """Create output folders (when training) and read run settings.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = [
                os.path.join(output_dir, name)
                for name in ('Model', 'Image', 'Log')
            ]
            for folder in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        train_cfg = cfg.TRAIN
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        # One entry per comma-separated id in cfg.GPU_ID; the configured
        # batch size is scaled by that count.
        self.gpus = [int(token) for token in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
Example #7
0
 def __init__(self, output_dir, dataloaders):
     """Build the client nodes, the server, output folders and loggers.

     Args:
         output_dir: Root folder; ``Model`` and ``Log`` are created in it.
         dataloaders: Triple ``(trainloader, validloader, testloader)``.
             The train and valid loaders are indexed per client.
     """
     train_loaders, valid_loaders, test_loader = dataloaders
     self.trainloader = train_loaders
     self.validloader = valid_loaders
     self.testloader = test_loader

     # Split the configured node count into training and test clients.
     total_nodes = cfg.Total.node_num
     self.train_num = int(total_nodes * cfg.Data.frac)
     self.test_num = total_nodes - self.train_num

     self.train_clients = [
         node.Client(cid, self.trainloader[cid], self.validloader[cid])
         for cid in range(self.train_num)
     ]
     # Test clients get ids following the training clients.
     # NOTE(review): they read loaders at index ``offset`` (0-based), not
     # ``offset + train_num`` -- confirm that sharing loaders with the
     # training clients is intended.
     self.test_list = [
         node.Client(offset + self.train_num, self.trainloader[offset],
                     self.validloader[offset])
         for offset in range(self.test_num)
     ]

     self.model_dir = os.path.join(output_dir, 'Model')
     self.log_dir = os.path.join(output_dir, 'Log')
     self.server = node.Server(self.testloader)
     for folder in (self.model_dir, self.log_dir):
         mkdir_p(folder)
     self.summary_writer = FileWriter(self.log_dir)
     self.recorder = Recorder()
Example #8
0
    def __init__(self, output_dir):
        """Set up folders, devices, text embeddings and visdom windows.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, leaf)
                for leaf in ('Model', 'Image', 'Log'))
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID holds comma-separated device indices; the configured
        # batch size is scaled by how many are listed.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # External word vectors, e.g. <DATA_DIR>/birds.en.vec, copied into a
        # frozen 300-dimensional embedding table.
        vec_file = os.path.join(cfg.DATA_DIR, cfg.DATASET_NAME + ".en.vec")
        self.txt_dico, pretrained = load_external_embeddings(vec_file)
        self.txt_emb = nn.Embedding(len(self.txt_dico), 300, sparse=False)
        self.txt_emb.weight.data.copy_(pretrained)
        self.txt_emb.weight.requires_grad = False

        # Visdom session with three placeholder image grids and a text pane.
        # NOTE(review): server, port and env are hard-coded here.
        self.vis = visdom.Visdom(server='http://bvisionserver9.cs.unc.edu',
                                 port=8088,
                                 env="birds_spv2")
        self.vis_win1, self.vis_win2, self.vis_win3 = (
            self.vis.images(np.ones((64, 3, 64, 64))) for _ in range(3))
        self.vis_txt1 = self.vis.text('')
Example #9
0
    def __init__(self, output_dir, experiment_cfg, model_cfg):
        """Prepare the run folder, the classifier and schedule settings.

        When ``cfg.TRAIN.FLAG`` is set, the ``Model`` and ``Log`` folders are
        created, the active configuration is dumped to ``config.yaml``, a
        ``FileWriter`` is opened and the classifier is built (and moved to
        CUDA when ``cfg.CUDA`` is set).

        Args:
            output_dir: Root folder of the run.
            experiment_cfg: Passed to ``BiLSTMClassifier.from_configs`` and
                to the GPU initialisation.
            model_cfg: Passed to ``BiLSTMClassifier.from_configs``.
        """
        if cfg.TRAIN.FLAG:
            print('Creating the directories...')
            self.outputdir = output_dir
            self.model_dir = os.path.join(output_dir, 'Model')
            self.log_dir = os.path.join(output_dir, 'Log')
            for folder in (self.model_dir, self.log_dir):
                mkdir_p(folder)

            # Keep a copy of the configuration next to the outputs.
            config_path = os.path.join(output_dir, 'config.yaml')
            with open(config_path, 'w') as current_cfg_file:
                yaml.dump(cfg, current_cfg_file, default_flow_style=False)

            self.summary_writer = FileWriter(self.log_dir)
            classifier = BiLSTMClassifier.from_configs(experiment_cfg,
                                                       model_cfg)
            self.model = classifier
            if cfg.CUDA:
                self.model.cuda()

        train_cfg = cfg.TRAIN
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL
        self.snapshot_interval_validation = \
            train_cfg.SNAPSHOT_INTERVAL_VALIDATION
        self.checkpoint_step = None

        self.__init_gpus(experiment_cfg)
Example #10
0
class GANTrainer(object):
    """Trainer and sampler for a two-stage text-conditioned GAN.

    Stage I and Stage II networks are built by the ``load_network_*``
    methods; ``train`` runs the adversarial loop and ``sample`` writes
    generated images for a file of text embeddings.
    """

    def __init__(self, output_dir):
        """Create output folders (when training) and read run settings.

        Args:
            output_dir: Root folder. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                below it and a ``FileWriter`` is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID holds comma-separated device indices; the configured
        # batch size is scaled by how many are listed.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the Stage-I generator and discriminator.

        Both networks are initialised with ``weights_init``; checkpoints
        named by ``cfg.NET_G`` / ``cfg.NET_D`` are loaded when non-empty, and
        the networks are moved to CUDA when ``cfg.CUDA`` is set.

        Returns:
            Tuple ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        # map_location keeps the loaded tensors on the CPU.
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the Stage-II generator (wrapping Stage I) and discriminator.

        Generator weights come from ``cfg.NET_G`` (whole Stage-II generator)
        or, failing that, from ``cfg.STAGE1_G`` (embedded Stage-I generator
        only). Discriminator weights are loaded from ``cfg.NET_D`` when set.

        Returns:
            Tuple ``(netG, netD)``.

        Raises:
            ValueError: If neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            # This used to print the message and return None, which made
            # every caller fail later with an opaque tuple-unpacking
            # TypeError. Fail here with the real cause instead.
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Run the adversarial training loop.

        Both optimizers are RMSprop and both learning rates are halved every
        ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs. Discriminator parameters are
        clamped to [-0.01, 0.01] before every discriminator step. Scalar
        summaries and sample images are written every 100 batches, and
        checkpoints every ``snapshot_interval`` epochs and once at the end.

        Args:
            data_loader: Iterable of ``(real_images, text_embedding)``
                batches.
            stage: ``1`` trains Stage I; any other value trains Stage II.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Noise kept fixed so periodically saved samples are comparable.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        # Single-element +1 / -1 tensors handed to the loss helpers.
        real_labels = torch.FloatTensor([1])
        fake_labels = real_labels * -1
        wrong_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        # real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        # fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels, wrong_labels = real_labels.cuda(
            ), fake_labels.cuda(), wrong_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = optim.RMSprop(netD.parameters(), lr=discriminator_lr)
        # Only generator parameters that require gradients are optimised.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerG = optim.RMSprop(netG_para, lr=generator_lr)
        # optimizerD = \
        #     optim.Adam(netD.parameters(),
        #                lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # netG_para = []
        # for p in netG.parameters():
        #     if p.requires_grad:
        #         netG_para.append(p)
        # optimizerG = optim.Adam(netG_para,
        #                         lr=cfg.TRAIN.GENERATOR_LR,
        #                         betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                ############################
                # (3) Update D network
                ###########################
                # Weight clipping on the discriminator parameters.
                for p in netD.parameters():
                    p.data.clamp_(-0.01, 0.01)
                netD.zero_grad()
                # NOTE(review): there is no errD.backward() here, so
                # compute_discriminator_loss presumably back-propagates
                # itself -- confirm.
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                real_labels, fake_labels,wrong_labels,
                                                mu, self.gpus)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                kl_loss = KL_loss(mu, logvar)
                # NOTE(review): only the KL term is back-propagated here;
                # errG is presumably handled inside compute_generator_loss
                # -- confirm.
                errG_total = kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward(retain_graph=True)
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.data[0])
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # Save sample images generated from the fixed noise.
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            # Reports the values of the last batch of the epoch.
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate and save images for the text embeddings in ``datapath``.

        Images are written as ``<index>.png`` into a folder named after
        ``cfg.NET_G`` with everything from ``.pth`` onwards removed.
        Generation stops once more than 3000 embeddings have been processed.

        Args:
            datapath: Torch file providing ``raw_txt`` and ``fea_txt``.
            stage: ``1`` uses the Stage-I generator; any other value uses
                Stage II.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        batch_size = np.minimum(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                # Shift the final window back so the batch stays full; the
                # overlapping indices are regenerated and overwritten.
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            # captions_batch = captions_list[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, mu, logvar = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # Rescale to 8-bit (assumes generator output in [-1, 1]
                # -- TODO confirm) and move channels last for PIL.
                im = (im + 1.0) * 127.5
                im = im.astype(np.uint8)
                # print('im', im.shape)
                im = np.transpose(im, (1, 2, 0))
                # print('im', im.shape)
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size

"""## Hyperparameters"""

# Training settings. LR, MOMENTUM and WEIGHTDECAY are passed to the SGD
# optimizer created in main().
SAVEPATH = '/test/'
WEIGHTDECAY = 5e-4  # Does not have much effect.
MOMENTUM = 0.9
BATCHSIZE = 128
LR = 0.1
EPOCHS = 200
PRINTFREQ = 10

from torchviz import make_dot
from torch.autograd import Variable
"""## Train Model"""
# Module-level event writer targeting 'runs/graph'; created at import time.
summary = FileWriter('runs/graph')


def main():
    """Build the model, optimizer, LR scheduler and loss criterion.

    NOTE(review): nothing created here is used or returned in the visible
    body; the function looks cut off -- confirm against the full source.
    """
    # PyramidNet wrapped in DataParallel.
    model = nn.DataParallel(PyramidNet())
    #model.module.load_state_dict(torch.load(SAVEPATH+'149model_weight.pth'))
    #model.train()
    ##### optimizer / learning rate scheduler / criterion #####
    # SGD with Nesterov momentum, configured from the module-level constants.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LR,
                                momentum=MOMENTUM,
                                weight_decay=WEIGHTDECAY,
                                nesterov=True)
    # NOTE(review): the only milestone is epoch 0 -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [0], gamma=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    ###########################################################
Example #12
0
 def __init__(self, **kwargs):
     """Forward all keyword arguments to the parent, then open a log writer.

     ``self.log_dir`` is read after the parent initializer has run;
     presumably the parent sets it -- confirm.
     """
     super().__init__(**kwargs)
     log_writer = FileWriter(self.log_dir)
     self.writer = log_writer
Example #13
0
class WHAI_GAN_Trainer(object):
    def __init__(self, output_dir, data_loader):
        """Create the output folders, select the GPU and read run settings.

        Args:
            output_dir: Root folder; ``Model``, ``Image`` and ``Log`` are
                created below it and a ``FileWriter`` is opened on ``Log``.
            data_loader: Batch source; stored unchanged and measured with
                ``len()`` to obtain the number of batches.
        """
        self.model_dir, self.image_dir, self.log_dir = (
            os.path.join(output_dir, leaf)
            for leaf in ('Model', 'Image', 'Log'))
        for directory in (self.model_dir, self.image_dir, self.log_dir):
            mkdir_p(directory)
        self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID holds comma-separated device indices; the first one
        # becomes the current CUDA device.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is scaled by the number of GPUs.
        train_cfg = cfg.TRAIN
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

    def prepare_data(self, data):
        real_vimgs, wrong_vimgs = [], []
        imgs, texts, w_imgs, _ = data
        if cfg.CUDA:
            vtxts = Variable(texts).cuda()
        else:
            vtxts = Variable(texts)
        for i in xrange(3):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, vtxts, real_vimgs, wrong_vimgs

    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, mu = self.criterion_1, self.mu_theta1

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_tgpu[idx]
        wrong_imgs = self.wrong_tgpu[idx]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()

        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)

            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, float(errD.data[0]))
            self.summary_writer.add_summary(summary_D, count)
        return float(errD)

    def train_Gnet(self, count):
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion_1, mu = self.criterion_1, self.mu_theta1

        params = [
            self.shape1, self.scale1, self.Phi1, self.theta1, self.txtbow
        ]
        criterion, optEnG, netG = self.criterion_2, self.optimizerEnG, self.netG
        loss, theta1_KL, Likelihood, p1, p2, p3, shape1, scale1 = criterion(
            params)

        real_labels = self.real_labels[:batch_size]
        for i in xrange(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion_1(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                             criterion_1(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)

        errG_total += loss
        optEnG.zero_grad()

        errG_total.backward()

        optEnG.step()
        return float(errG_total
                     ), theta1_KL, Likelihood, loss, p1, p2, p3, shape1, scale1

    def updatePhi(self, miniBatch, Phi, Theta, MBratio, MBObserved, NDot):
        """Update ``Phi`` from one mini-batch with a noisy gradient-like step.

        Counts drawn by ``PGBN_sampler.Multrnd_Matrix`` are scaled by
        ``MBratio``, blended into the running column totals ``NDot`` using
        ``self.ForgetRate``, and turned into a step of size
        ``self.epsit[MBObserved]`` plus Gaussian noise. The result is passed
        through ``PGBN_sampler.ProjSimplexSpecial``.

        Args:
            miniBatch: Observed data matrix for this mini-batch.
            Phi: Current factor matrix (2-D; both shape entries are read).
            Theta: Current factor scores.
            MBratio: Multiplier applied to the sampled counts.
            MBObserved: Index of this mini-batch into ``self.ForgetRate``
                and ``self.epsit``; ``0`` resets ``NDot``.
            NDot: Running column totals from the previous call.

        Returns:
            Tuple ``(Phi, NDot)`` with the updated values.
        """
        _, sampled_counts = PGBN_sampler.Multrnd_Matrix(
            miniBatch.astype('double'),
            Phi.astype('double'),
            Theta.astype('double'))

        scaled_counts = MBratio * sampled_counts
        column_totals = scaled_counts.sum(0)

        if MBObserved == 0:
            # First mini-batch: start the running totals from scratch.
            NDot = column_totals
        else:
            forget = self.ForgetRate[MBObserved]
            NDot = (1 - forget) * NDot + forget * column_totals

        smoothed = scaled_counts + self.eta
        drift = (1 / NDot) * (smoothed - smoothed.sum(0) * Phi)
        noise_var = (2 / NDot) * Phi

        step = self.epsit[MBObserved]
        gaussian = np.random.randn(Phi.shape[0], Phi.shape[1])
        proposal = Phi + step * drift + np.sqrt(step * noise_var) * gaussian
        Phi = PGBN_sampler.ProjSimplexSpecial(proposal, Phi, 0)

        return Phi, NDot

    def train(self):
        """Run the adversarial training loop with per-batch updates of Phi1.

        Networks and optimizers are obtained from ``load_network`` and
        ``define_optimizers``.  For every mini-batch the method encodes the
        real image with ``self.netIMG``, generates fake images from
        ``scale * Gamma(1 + 1 / shape)``, updates each discriminator through
        ``self.train_Dnet``, updates the generator/encoder through
        ``self.train_Gnet``, refreshes the running average of the generator
        weights and resamples ``self.Phi1`` through ``self.updatePhi``.
        Scalar summaries are written every 100 steps, snapshots every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps and every 50 epochs, and a
        final checkpoint is saved before the summary writer is closed.
        """
        self.netG, self.netsD, self.netIMG, self.inception_model,\
        self.num_Ds, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizersD, self.optimizerEnG = \
            define_optimizers(self.netG, self.netsD, self.netIMG)

        self.criterion_1 = nn.BCELoss()
        self.criterion_2 = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        # Prepare PHI: random positive 1000 x 256 matrix whose columns are
        # normalised to sum to one.
        real_min = np.float64(2.2e-308)
        Phi1 = 0.2 + 0.8 * np.float64(np.random.rand(1000, 256))
        Phi1 = Phi1 / np.maximum(real_min, Phi1.sum(0))
        # The NumPy array has to be converted to a tensor explicitly
        # (``Variable`` rejects ndarrays), and ``self.Phi1`` must exist on
        # CPU-only runs too because it is read on every step below.
        self.Phi1 = torch.from_numpy(Phi1).float()
        if cfg.CUDA:
            self.Phi1 = self.Phi1.cuda()
            self.criterion_1.cuda()
            self.criterion_2.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches

        batch_length = self.num_batches
        self.NDot = 0
        # Per-step schedules, indexed by the global step counter ``count``.
        self.ForgetRate = np.power(
            (0 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                             cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.eta = 0.1
        epsit = np.power(
            (20 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                              cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.epsit = 1 * epsit / epsit[0]

        num_total_samples = batch_length * self.batch_size

        start_t = time.time()

        # Keep ``epoch`` bound for the final checkpoint even when the loop
        # body never runs (resuming at or past ``self.max_epoch``).
        epoch = start_epoch
        for epoch in range(start_epoch, self.max_epoch):
            LL = 0
            KL = 0
            LS = 0
            DL = 0
            GL = 0
            for data in self.data_loader:
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu = \
                    self.prepare_data(data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])

                #######################################################
                # (2) Get shape, scale and sample of theta
                ######################################################
                self.theta1, self.shape1, self.scale1 = self.netIMG(self.flat)

                shape1 = self.shape1.detach().cpu().numpy()
                scale1 = self.scale1.detach().cpu().numpy()
                # scale * Gamma(1 + 1/shape); presumably the Weibull mean,
                # assuming ``ss`` is ``scipy.special`` -- TODO confirm.
                mu_theta1 = scale1 * ss.gamma(1 + 1 / shape1)
                # NOTE(review): this tensor is created on the CPU; presumably
                # netG moves it to the right device -- confirm.
                self.mu_theta1 = torch.tensor(mu_theta1, dtype=torch.float32)
                #######################################################
                # (3) Generate fake images
                ######################################################
                self.fake_imgs, _ = self.netG(self.mu_theta1.detach())

                #######################################################
                # (4) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (5) Update G network (or En network): maximize log(D(G(z)))
                ######################################################
                (errG_total, self.KL1, self.LL, self.LS, self.p1, self.p2,
                 self.p3, _shape1, _scale1) = self.train_Gnet(count)
                LL += self.LL
                KL += self.KL1
                LS += self.LS

                # Exponential moving average of the generator weights.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                DL += errD_total
                GL += errG_total

                #######################################################
                # (6) Update Phi
                #######################################################
                input_txt = np.array(np.transpose(self.txtbow.cpu().numpy()),
                                     order='C').astype('double')
                Phi1 = np.array(self.Phi1.cpu().numpy(),
                                order='C').astype('double')
                # ``theta1`` comes out of the encoder, so it has to be
                # detached before it can be converted to NumPy.
                Theta1 = np.array(
                    np.transpose(self.theta1.detach().cpu().numpy()),
                    order='C').astype('double')
                phi1, self.NDot = self.updatePhi(input_txt, Phi1, Theta1,
                                                 int(batch_length), count,
                                                 self.NDot)
                self.Phi1 = torch.tensor(phi1, dtype=torch.float32)
                if cfg.CUDA:
                    self.Phi1 = self.Phi1.cuda()

                # for inception score
                pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total)
                    summary_G = summary.scalar('G_loss', errG_total)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netIMG, self.netG, avg_param_G, self.netsD,
                               epoch, count, self.model_dir)
                    # Save images generated with the averaged weights, then
                    # restore the live training weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)

                    self.fake_imgs, _ = \
                        self.netG(self.theta1)

                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []
                count += 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL = KL / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d,   Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f          (LL: %.6f, KL: %.6f, Loss_D:'
                '%.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL, DL, GL))
            start_t = time.time()
            if epoch % 50 == 0:
                save_model(self.netIMG, self.netG, avg_param_G, self.netsD,
                           epoch, count, self.model_dir)
        # save the model at the last updating
        save_model(self.netIMG, self.netG, avg_param_G, self.netsD, epoch,
                   count, self.model_dir)
        self.summary_writer.close()
Example #14
0
class condGANTrainer(object):
    def __init__(self, output_dir, data_loader):
        """Set up output folders, GPU selection and loop bookkeeping.

        Args:
            output_dir: Root folder.  When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                beneath it and a ``FileWriter`` is opened on the log folder.
            data_loader: Sized iterable of training batches; its length is
                stored as the number of batches per epoch.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = [
                os.path.join(output_dir, leaf)
                for leaf in ('Model', 'Image', 'Log')
            ]
            for folder in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device indices; the first
        # one becomes the current CUDA device.
        gpu_ids = [int(token) for token in cfg.GPU_ID.split(',')]
        self.gpus = gpu_ids
        self.num_gpus = len(gpu_ids)
        torch.cuda.set_device(gpu_ids[0])
        cudnn.benchmark = True

        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

    def prepare_data(self, data):
        """Wrap one loader batch in Variables, on the GPU when enabled.

        Args:
            data: 4-tuple ``(imgs, texts, w_imgs, _)``; ``imgs`` and
                ``w_imgs`` are indexable sequences of equal length and the
                fourth entry is ignored.

        Returns:
            ``(imgs, vtxts, real_vimgs, wrong_vimgs)``: the untouched input
            images, the wrapped text tensor, and the lists of wrapped real
            and wrong images.
        """
        imgs, texts, w_imgs, _ = data

        vtxts = Variable(texts)
        if cfg.CUDA:
            vtxts = vtxts.cuda()

        real_vimgs = []
        wrong_vimgs = []
        for pos, real in enumerate(imgs):
            real_var = Variable(real)
            wrong_var = Variable(w_imgs[pos])
            if cfg.CUDA:
                real_var = real_var.cuda()
                wrong_var = wrong_var.cuda()
            real_vimgs.append(real_var)
            wrong_vimgs.append(wrong_var)
        return imgs, vtxts, real_vimgs, wrong_vimgs

    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, c_code = self.criterion, self.c_code[idx // 3]

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_tgpu[int((idx // 3) + idx % 3)]
        wrong_imgs = self.wrong_tgpu[int((idx // 3) + idx % 3)]
        fake_imgs = self.fake_imgs[idx]

        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        real_logits = netD(real_imgs, c_code.detach())
        wrong_logits = netD(wrong_imgs, c_code.detach())
        fake_logits = netD(fake_imgs.detach(), c_code.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)

            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, float(errD.data[0]))
            self.summary_writer.add_summary(summary_D, count)
        return float(errD)

    def train_Gnet(self, idx, count):
        optG = self.optimizersG[idx]
        optG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_tgpu[0].size(0)
        criterion, c_code = self.criterion, self.c_code[idx]
        real_labels = self.real_labels[:batch_size]
        for i in xrange(len(self.netsG)):
            outputs = self.netsD[idx * 3 + i](self.fake_imgs[idx * 3 + i],
                                              c_code)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)

        errG_total = errG_total
        errG_total.backward()
        optG.step()
        return float(errG_total)

    def train_Enet(self, count):
        errEn_total = 0
        flag = count % 100
        params = [
            self.shape_1, self.scale_1, self.shape_2, self.scale_2,
            self.shape_3, self.scale_3, self.Phi[0], self.theta_1, self.Phi[1],
            self.theta_2, self.Phi[2], self.theta_3, self.txtbow
        ]
        criterion, optEn, netEn = self.vae, self.optimizerEn, self.netEn
        loss, theta3_KL, theta2_KL, theta1_KL, Likelihood, Lowerbound = criterion(
            params)
        errEn_total = errEn_total + loss
        netEn.zero_grad()
        # backward
        errEn_total.backward()
        # update parameters
        optEn.step()
        if flag == 0:
            summary_LS = summary.scalar('En_loss', float(loss.data.item()))
            summary_LB = summary.scalar('En_lowerbound',
                                        float(Lowerbound.data.item()))
            summary_LL = summary.scalar('En_likelihood',
                                        float(Likelihood.data.item()))
            summary_KL1 = summary.scalar('En_kl1',
                                         float(theta1_KL.data.item()))
            summary_KL2 = summary.scalar('En_kl2',
                                         float(theta2_KL.data.item()))
            summary_KL3 = summary.scalar('En_kl3',
                                         float(theta3_KL.data.item()))
            self.summary_writer.add_summary(summary_LS, count)
            self.summary_writer.add_summary(summary_LB, count)
            self.summary_writer.add_summary(summary_LL, count)
            self.summary_writer.add_summary(summary_KL1, count)
            self.summary_writer.add_summary(summary_KL2, count)
            self.summary_writer.add_summary(summary_KL3, count)
        return theta1_KL, theta2_KL, theta3_KL, Likelihood, Lowerbound, loss

    def updatePhi(self, miniBatch, Phi, Theta, MBratio, MBObserved):
        real_min = 1e-6
        Xt = miniBatch

        for t in range(len(Phi)):
            if t == 0:
                self.Xt_to_t1[t], self.WSZS[t] = PGBN_sampler.Multrnd_Matrix(
                    Xt.astype('double'), Phi[t], Theta[t])
            else:
                self.Xt_to_t1[t], self.WSZS[
                    t] = PGBN_sampler.Crt_Multirnd_Matrix(
                        self.Xt_to_t1[t - 1], Phi[t], Theta[t])

            self.EWSZS[t] = MBratio * self.WSZS[t]

            if (MBObserved == 0):
                self.NDot[t] = self.EWSZS[t].sum(0)
            else:
                self.NDot[t] = (1 - self.ForgetRate[MBObserved]) * self.NDot[t] + self.ForgetRate[MBObserved] * \
                               self.EWSZS[t].sum(0)

            tmp = self.EWSZS[t] + self.eta[t]
            tmp = (1 / (self.NDot[t] + real_min)) * (tmp - tmp.sum(0) * Phi[t])
            tmp1 = (2 / (self.NDot[t] + real_min)) * Phi[t]
            tmp = Phi[t] + self.epsit[MBObserved] * tmp + np.sqrt(
                self.epsit[MBObserved] * tmp1) * np.random.randn(
                    Phi[t].shape[0], Phi[t].shape[1])
            Phi[t] = PGBN_sampler.ProjSimplexSpecial(tmp, Phi[t], 0)

        return Phi

    def train(self):
        """Train the encoder, the generators and the discriminators.

        Networks and optimizers are obtained from ``load_network`` and
        ``define_optimizers``.  Three ``Phi`` matrices (1000x256, 256x128,
        128x64) are initialised with column-normalised random values.  For
        every mini-batch the method encodes the real image with
        ``self.netEn``, generates fake images from the detached thetas,
        updates the encoder (``train_Enet``), resamples ``Phi``
        (``updatePhi``), updates every discriminator (``train_Dnet``) and
        every generator (``train_Gnet``), and refreshes the running averages
        of the generator weights.  Summaries are written every 100 steps,
        snapshots every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps, and a final
        checkpoint is saved before the summary writer is closed.
        """
        self.netEn, self.netsG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = [copy_G_params(netG) for netG in self.netsG]

        self.optimizerEn, self.optimizersG, self.optimizersD = \
            define_optimizers(self.netEn, self.netsG, self.netsD)

        self.criterion = nn.BCELoss()
        self.vae = myLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches

        batch_length = self.num_batches

        # Random positive matrices whose columns are normalised to sum to
        # one; layer 0 has 1000 rows, deeper layers chain K[i - 1] -> K[i].
        self.Phi = []
        self.eta = []
        K = [256, 128, 64]
        real_min = np.float64(2.2e-308)
        eta = np.ones(3) * 0.1
        for i in range(3):
            self.eta.append(eta[i])
            n_rows = 1000 if i == 0 else K[i - 1]
            phi = 0.2 + 0.8 * np.float64(np.random.rand(n_rows, K[i]))
            self.Phi.append(phi / np.maximum(real_min, phi.sum(0)))

        self.NDot = [0] * 3
        self.Xt_to_t1 = [0] * 3
        self.WSZS = [0] * 3
        self.EWSZS = [0] * 3

        # Per-step schedules, indexed by the global step counter ``count``.
        self.ForgetRate = np.power(
            (0 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                             cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        epsit = np.power(
            (20 + np.linspace(1, cfg.TRAIN.MAX_EPOCH * int(batch_length),
                              cfg.TRAIN.MAX_EPOCH * int(batch_length))), -0.7)
        self.epsit = 1 * epsit / epsit[0]

        num_total_samples = batch_length * self.batch_size

        # Phi is always stored as float32 tensors: the per-step update below
        # calls ``.cpu().numpy()`` on it, which a raw NumPy array (as left
        # behind on CPU-only runs before) does not support.
        self.Phi = [torch.from_numpy(phi).float() for phi in self.Phi]
        if cfg.CUDA:
            self.Phi = [phi.cuda() for phi in self.Phi]
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        # Keep ``epoch`` bound for the final checkpoint even when the loop
        # body never runs (resuming at or past ``self.max_epoch``).
        epoch = start_epoch
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            LL = 0
            KL1 = 0
            KL2 = 0
            KL3 = 0
            LS = 0
            DL = 0
            GL = 0
            for data in self.data_loader:
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.img_tcpu, self.txtbow, self.real_tgpu, self.wrong_tgpu = \
                    self.prepare_data(data)

                #######################################################
                # (1) Get conv hidden units
                ######################################################
                _, self.flat = self.inception_model(self.real_tgpu[-1])

                self.theta_1, self.shape_1, self.scale_1, self.theta_2,\
                self.shape_2, self.scale_2, self.theta_3, self.shape_3,\
                self.scale_3 = self.netEn(self.flat)

                # Generators are fed deepest layer first, detached so the
                # adversarial losses do not reach the encoder.
                self.txt_embedding = [
                    self.theta_3.detach(),
                    self.theta_2.detach(),
                    self.theta_1.detach(),
                ]
                #######################################################
                # (2) Generate fake images
                ######################################################
                per_generator = []
                self.c_code = []
                x_embedding = None
                for it in range(len(self.netsG)):
                    fake_imgs, c_code, x_embedding = \
                        self.netsG[it](self.txt_embedding[it], x_embedding)
                    per_generator.append(fake_imgs)
                    self.c_code.append(c_code)

                # Flatten the per-generator image lists into one list.
                self.fake_imgs = [
                    img for imgs in per_generator for img in imgs
                ]

                #######################################################
                # (3) Update En network
                ######################################################
                self.KL1, self.KL2, self.KL3, self.LL, self.LB, self.LS = \
                    self.train_Enet(count)
                # Accumulate plain floats: summing the graph-attached tensors
                # would keep autograd history alive for the whole epoch.
                LL += float(self.LL)
                KL1 += float(self.KL1)
                KL2 += float(self.KL2)
                KL3 += float(self.KL3)
                LS += float(self.LS)

                if count % 100 == 0:
                    print(self.LS)
                    print(self.KL1)
                    print(self.KL2)
                    print(self.KL3)

                #######################################################
                # (4) Update Phi
                #######################################################
                input_txt = np.array(np.transpose(self.txtbow.cpu().numpy()),
                                     order='C').astype('double')
                self.theta = [self.theta_1, self.theta_2, self.theta_3]
                Phi = []
                theta = []
                for i in range(len(self.Phi)):
                    Phi.append(
                        np.array(self.Phi[i].cpu().numpy(),
                                 order='C').astype('double'))
                    theta.append(
                        np.array(np.transpose(
                            self.theta[i].detach().cpu().numpy()),
                                 order='C').astype('double'))
                phi = self.updatePhi(input_txt, Phi, theta, int(batch_length),
                                     count)
                for i, new_phi in enumerate(phi):
                    self.Phi[i] = torch.tensor(new_phi, dtype=torch.float32)
                    if cfg.CUDA:
                        self.Phi[i] = self.Phi[i].cuda()

                #######################################################
                # (5) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                DL += errD_total

                #######################################################
                # (6) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = 0
                for i in range(len(self.netsG)):
                    errG = self.train_Gnet(i, count)
                    errG_total += errG
                    # Exponential moving average of the generator weights.
                    for p, avg_p in zip(self.netsG[i].parameters(),
                                        avg_param_G[i]):
                        avg_p.mul_(0.999).add_(0.001, p.data)
                GL += errG_total

                # for inception score
                if cfg.INCEPTION:
                    pred, _ = self.inception_model(self.fake_imgs[-1].detach())
                    predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total)
                    summary_G = summary.scalar('G_loss', errG_total)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netEn, self.netsG, avg_param_G, self.netsD,
                               epoch, self.model_dir)
                    # Save images generated with the averaged weights, then
                    # restore the live training weights.
                    backup_para = []
                    for i in range(len(self.netsG)):
                        backup_para.append(copy_G_params(self.netsG[i]))
                        load_params(self.netsG[i], avg_param_G[i])

                    x_embedding = None
                    self.fake_imgs = []
                    for it in range(len(self.netsG)):
                        fake_imgs, _, x_embedding = self.netsG[it](
                            self.txt_embedding[it], x_embedding)
                        self.fake_imgs.append(fake_imgs[-1])
                    save_img_results(self.img_tcpu, self.fake_imgs,
                                     len(self.netsG), count, self.image_dir)

                    for i in range(len(self.netsG)):
                        load_params(self.netsG[i], backup_para[i])

                    if cfg.INCEPTION:
                        # Compute inception score
                        if len(predictions) > 500:
                            predictions = np.concatenate(predictions, 0)
                            mean, std = compute_inception_score(
                                predictions, 10)
                            m_incep = summary.scalar('Inception_mean', mean)
                            self.summary_writer.add_summary(m_incep, count)

                            mean_nlpp, std_nlpp = \
                                negative_log_posterior_probability(predictions, 10)
                            m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                            self.summary_writer.add_summary(m_nlpp, count)

                            predictions = []

                count = count + 1

            end_t = time.time()
            LS = LS / num_total_samples
            LL = LL / num_total_samples
            KL1 = KL1 / num_total_samples
            KL2 = KL2 / num_total_samples
            KL3 = KL3 / num_total_samples
            DL = DL / num_total_samples
            GL = GL / num_total_samples
            print(
                'Epoch: %d/%d,   Time elapsed: %.4fs\n'
                '* Batch Train Loss: %.6f          (LL: %.6f, KL1: %.6f, KL2: %.6f,'
                'KL3: %.6f, Loss_D: %.2f Loss_G: %.2f)\n' %
                (epoch, self.max_epoch, end_t - start_t, LS, LL, KL1, KL2, KL3,
                 DL, GL))
        save_model(self.netEn, self.netsG, avg_param_G, self.netsD, epoch,
                   self.model_dir)
        self.summary_writer.close()
Example #15
0
class FineGAN_trainer(object):
    def __init__(self, output_dir, data_loader, imsize):
        """Set up output folders, GPU selection and loop bookkeeping.

        Args:
            output_dir: Root folder.  When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-folders are created
                beneath it and a ``FileWriter`` is opened on the log folder.
            data_loader: Sized iterable of training batches; its length is
                stored as the number of batches per epoch.
            imsize: Accepted but not used by this constructor.
        """
        if cfg.TRAIN.FLAG:
            join = os.path.join
            self.model_dir = join(output_dir, 'Model')
            self.image_dir = join(output_dir, 'Image')
            self.log_dir = join(output_dir, 'Log')
            for folder in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device indices; the first
        # one becomes the current CUDA device.
        self.gpus = [int(token) for token in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        train_cfg = cfg.TRAIN
        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

    def prepare_data(self, data):
        """Wrap one loader batch in Variables, on the GPU when enabled.

        Only the first entry of ``fimgs`` and of ``cimgs`` is wrapped.  The
        bounding-box entries are cast to float on both the CPU and the CUDA
        path (previously only the CUDA path cast them) and are written back
        into the ``warped_bbox`` container that was passed in.

        Args:
            data: 5-tuple ``(fimgs, cimgs, c_code, _, warped_bbox)``; the
                fourth entry is ignored.

        Returns:
            ``(fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox)``: the
            untouched ``fimgs``, one-element lists holding the wrapped first
            full / cropped image, the wrapped code and the converted
            bounding boxes.
        """
        fimgs, cimgs, c_code, _, warped_bbox = data
        use_cuda = cfg.CUDA

        vc_code = Variable(c_code)
        if use_cuda:
            vc_code = vc_code.cuda()

        # Same dtype on both devices so downstream arithmetic does not
        # depend on where the batch lives.
        for i in range(len(warped_bbox)):
            bbox = Variable(warped_bbox[i]).float()
            warped_bbox[i] = bbox.cuda() if use_cuda else bbox

        real_fimg = Variable(fimgs[0])
        real_cimg = Variable(cimgs[0])
        if use_cuda:
            real_fimg = real_fimg.cuda()
            real_cimg = real_cimg.cuda()
        real_vfimgs = [real_fimg]
        real_vcimgs = [real_cimg]

        return fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox

    def train_Dnet(self, idx, count):
        if idx == 0 or idx == 2:  # Discriminator is only trained in background and child stage. (NOT in parent stage)
            flag = count % 100
            batch_size = self.real_fimgs[0].size(0)
            criterion, criterion_one = self.criterion, self.criterion_one

            netD, optD = self.netsD[idx], self.optimizersD[idx]
            if idx == 0:
                real_imgs = self.real_fimgs[0]

            elif idx == 2:
                real_imgs = self.real_cimgs[0]

            fake_imgs = self.fake_imgs[idx]
            netD.zero_grad()
            real_logits = netD(real_imgs)

            if idx == 2:
                fake_labels = torch.zeros_like(real_logits[1])
                real_labels = torch.ones_like(real_logits[1])
            elif idx == 0:

                fake_labels = torch.zeros_like(real_logits[1])
                ext, output = real_logits
                weights_real = torch.ones_like(output)
                real_labels = torch.ones_like(output)

                for i in range(batch_size):
                    x1 = self.warped_bbox[0][i]
                    x2 = self.warped_bbox[2][i]
                    y1 = self.warped_bbox[1][i]
                    y2 = self.warped_bbox[3][i]

                    a1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((x1 - self.recp_field) / self.patch_stride))
                    a2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - x2) /
                                    self.patch_stride)) + 1
                    b1 = max(
                        torch.tensor(0).float().cuda(),
                        torch.ceil((y1 - self.recp_field) / self.patch_stride))
                    b2 = min(
                        torch.tensor(self.n_out - 1).float().cuda(),
                        torch.floor((self.n_out - 1) -
                                    ((126 - self.recp_field) - y2) /
                                    self.patch_stride)) + 1

                    if (x1 != x2 and y1 != y2):
                        weights_real[
                            i, :,
                            a1.type(torch.int):a2.type(torch.int),
                            b1.type(torch.int):b2.type(torch.int)] = 0.0

                norm_fact_real = weights_real.sum()
                norm_fact_fake = weights_real.shape[0] * weights_real.shape[
                    1] * weights_real.shape[2] * weights_real.shape[3]
                real_logits = ext, output

            fake_logits = netD(fake_imgs.detach())

            if idx == 0:  # Background stage

                errD_real_uncond = criterion(
                    real_logits[1], real_labels
                )  # Real/Fake loss for 'real background' (on patch level)
                errD_real_uncond = torch.mul(
                    errD_real_uncond, weights_real
                )  # Masking output units which correspond to receptive fields which lie within the boundin box
                errD_real_uncond = errD_real_uncond.mean()

                errD_real_uncond_classi = criterion(
                    real_logits[0],
                    weights_real)  # Background/foreground classification loss
                errD_real_uncond_classi = errD_real_uncond_classi.mean()

                errD_fake_uncond = criterion(
                    fake_logits[1], fake_labels
                )  # Real/Fake loss for 'fake background' (on patch level)
                errD_fake_uncond = errD_fake_uncond.mean()

                if (
                        norm_fact_real > 0
                ):  # Normalizing the real/fake loss for background after accounting the number of masked members in the output.
                    errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                    (norm_fact_real * 1.0))
                else:
                    errD_real = errD_real_uncond

                errD_fake = errD_fake_uncond
                errD = ((errD_real + errD_fake) *
                        cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi

            if idx == 2:

                errD_real = criterion_one(
                    real_logits[1],
                    real_labels)  # Real/Fake loss for the real image
                errD_fake = criterion_one(
                    fake_logits[1],
                    fake_labels)  # Real/Fake loss for the fake image
                errD = errD_real + errD_fake

            if (idx == 0 or idx == 2):
                errD.backward()
                optD.step()

            if (flag == 0):
                summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
                self.summary_writer.add_summary(summary_D, count)
                summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                                errD_real.data[0])
                self.summary_writer.add_summary(summary_D_real, count)
                summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                                errD_fake.data[0])
                self.summary_writer.add_summary(summary_D_fake, count)

            return errD

    def train_Gnet(self, count):
        """Run one generator update and return the summed generator loss.

        Loss terms accumulated per discriminator stage ``i``:

        * ``i == 0`` (background) and ``i == 2`` (child): real/fake loss
          from ``self.criterion_one`` on the second discriminator output.
          Stage 0 scales it by ``cfg.TRAIN.BG_LOSS_WT`` and adds the
          patch-level background/foreground classification loss computed on
          the first discriminator output.
        * ``i == 1`` (parent) and ``i == 2`` (child): information loss from
          ``self.criterion_class`` between the discriminator's class
          prediction on the masked foreground ``self.fg_mk[i - 1]`` and the
          index of the non-zero entry of the parent / child code.

        Args:
            count: Global iteration counter; scalar summaries are written
                whenever ``count % 100 == 0``.

        Returns:
            The total loss tensor that was back-propagated.
        """
        self.netG.zero_grad()
        for netD in self.netsD:
            netD.zero_grad()

        errG_total = 0
        flag = count % 100
        criterion_one = self.criterion_one
        criterion_class = self.criterion_class
        c_code, p_code = self.c_code, self.p_code

        for i in range(self.num_Ds):

            # The forward pass is run for every stage, although only stages
            # 0 and 2 consume `outputs` below.
            outputs = self.netsD[i](self.fake_imgs[i])

            if i == 0 or i == 2:  # real/fake loss for background (0) and child (2) stage
                real_labels = torch.ones_like(outputs[1])
                errG = criterion_one(outputs[1], real_labels)
                if i == 0:
                    errG = errG * cfg.TRAIN.BG_LOSS_WT
                    # Background/foreground classification loss for the fake
                    # background image (on patch level).
                    errG_classi = criterion_one(outputs[0], real_labels)
                    errG = errG + errG_classi
                errG_total = errG_total + errG

            if i == 1:  # Mutual information loss for the parent stage (1)
                pred_p = self.netsD[i](self.fg_mk[i - 1])
                errG_info = criterion_class(pred_p[0],
                                            torch.nonzero(p_code.long())[:, 1])
            elif i == 2:  # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i - 1])
                errG_info = criterion_class(pred_c[0],
                                            torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            if flag == 0:
                # `.item()` replaces `.data[0]`: indexing a 0-dim loss tensor
                # is an error on PyTorch >= 0.5.
                if i > 0:
                    summary_D_class = summary.scalar('Information_loss_%d' % i,
                                                     errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)

                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)

        errG_total.backward()
        # Step the first len(self.netsD) optimizers held in self.optimizerG.
        for myit in range(len(self.netsD)):
            self.optimizerG[myit].step()
        return errG_total

    def train(self):
        """Train FineGAN: normal training followed by hard-negative training.

        Phase 1 (normal training) resumes from the iteration count returned
        by ``load_network`` and runs until ``self.max_epoch``. Every
        iteration updates the stage-0 and stage-2 discriminators, then the
        generator, then the running average of the generator parameters.
        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` iterations the models are saved
        and sample images are rendered with the averaged parameters.

        Phase 2 (hard-negative training) makes a single pass over
        ``self.data_loader`` and stops early once
        ``cfg.TRAIN.HARDNEG_MAX_ITER`` iterations are reached. Even
        iterations train on an ordinary batch. Odd iterations sample
        ``10 * self.batch_size`` random (noise, child class) pairs and train
        on the ``self.batch_size`` pairs for which the stage-2 discriminator
        assigns the lowest softmax probability to the sampled class.

        The method has no return value; it writes checkpoints, images and
        summaries, and closes ``self.summary_writer`` at the end.
        """
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss(reduce=False)
        self.criterion_one = nn.BCELoss()
        self.criterion_class = nn.CrossEntropyLoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))
        # NOTE(review): hard_noise is moved to the GPU unconditionally,
        # unlike the other tensors, which honour cfg.CUDA.
        hard_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)).cuda()

        # Receptive field stride given the current discriminator architecture
        # for the background stage.
        self.patch_stride = float(4)
        # Output size of the discriminator at the background stage; N x N
        # where N = 24.
        self.n_out = 24
        # Receptive field of each member of the N x N output.
        self.recp_field = 34

        if cfg.CUDA:
            self.criterion.cuda()
            self.criterion_one.cuda()
            self.criterion_class.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        print("Starting normal FineGAN training..")
        count = start_count
        start_epoch = start_count // (self.num_batches)

        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):

                self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
                    self.c_code, self.warped_bbox = self.prepare_data(data)

                # Feedforward through Generator. Obtain stagewise fake images
                noise.data.normal_(0, 1)
                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                # Obtain the parent code given the child code
                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update Discriminator networks
                errD_total = 0
                for i in range(self.num_Ds):
                    if i == 0 or i == 2:  # only at background (0) and child (2) stage
                        errD = self.train_Dnet(i, count)
                        errD_total += errD

                # Update the Generator networks
                errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    backup_para = copy_G_params(self.netG)
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images rendered with the averaged parameters,
                    # then restore the live training parameters.
                    load_params(self.netG, avg_param_G)

                    self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                        self.netG(fixed_noise, self.c_code)
                    save_img_results(self.imgs_tcpu,
                                     (self.fake_imgs + self.fg_imgs +
                                      self.mk_imgs + self.fg_mk), self.num_Ds,
                                     count, self.image_dir,
                                     self.summary_writer)
                    load_params(self.netG, backup_para)

            end_t = time.time()
            # `.item()` replaces `.data[0]`: indexing a 0-dim loss tensor is
            # an error on PyTorch >= 0.5.
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Time: %.2fs
                      ''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        print(
            "Done with the normal training. Now performing hard negative training.."
        )
        count = 0
        start_t = time.time()
        for step, data in enumerate(self.data_loader, 0):

            self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
                self.c_code, self.warped_bbox = self.prepare_data(data)

            if (count % 2) == 0:  # Train on normal batch of images

                # Feedforward through Generator. Obtain stagewise fake images
                noise.data.normal_(0, 1)
                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update discriminator networks
                errD_total = 0
                for i in range(self.num_Ds):
                    if i == 0 or i == 2:
                        errD = self.train_Dnet(i, count)
                        errD_total += errD

                # Update the generator network
                errG_total = self.train_Gnet(count)

            else:  # Train on degenerate images
                repeat_times = 10
                pool_size = self.batch_size * repeat_times
                all_hard_z = Variable(torch.zeros(pool_size, nz)).cuda()
                all_hard_class = Variable(
                    torch.zeros(pool_size,
                                cfg.FINE_GRAINED_CATEGORIES)).cuda()
                all_logits = Variable(torch.zeros(pool_size, )).cuda()

                # Build a pool of candidate (noise, class) pairs and record,
                # for each, the probability the child discriminator assigns
                # to the class that was sampled.
                for hard_it in range(repeat_times):
                    hard_noise = hard_noise.data.normal_(0, 1)
                    hard_class = Variable(
                        torch.zeros(
                            [self.batch_size,
                             cfg.FINE_GRAINED_CATEGORIES])).cuda()
                    my_rand_id = []

                    for c_it in range(self.batch_size):
                        # One-element list holding a random class index.
                        rand_class = random.sample(
                            range(cfg.FINE_GRAINED_CATEGORIES), 1)
                        hard_class[c_it][rand_class] = 1
                        my_rand_id.append(rand_class)

                    lo = self.batch_size * hard_it
                    hi = self.batch_size * (hard_it + 1)
                    all_hard_z[lo:hi] = hard_noise.data
                    all_hard_class[lo:hi] = hard_class.data
                    self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = self.netG(
                        hard_noise.detach(), hard_class.detach())

                    fake_logits = self.netsD[2](self.fg_mk[1].detach())
                    smax_class = softmax(fake_logits[0], dim=1)

                    for b_it in range(self.batch_size):
                        all_logits[lo + b_it] = \
                            smax_class[b_it][my_rand_id[b_it]]

                # Ascending sort: the first batch_size entries are the
                # candidates with the lowest class probability.
                sorted_val, indices_hard = torch.sort(all_logits)
                noise = all_hard_z[indices_hard[0:self.batch_size]]
                self.c_code = all_hard_class[indices_hard[0:self.batch_size]]

                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(noise, self.c_code)

                self.p_code = child_to_parent(self.c_code,
                                              cfg.FINE_GRAINED_CATEGORIES,
                                              cfg.SUPER_CATEGORIES)

                # Update Discriminator networks
                errD_total = 0
                for i in range(self.num_Ds):
                    if i == 0 or i == 2:
                        errD = self.train_Dnet(i, count)
                        errD_total += errD

                # Update generator network
                errG_total = self.train_Gnet(count)

            for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(0.001, p.data)
            count = count + 1

            if count % cfg.TRAIN.SNAPSHOT_INTERVAL_HARDNEG == 0:
                backup_para = copy_G_params(self.netG)
                # The 500000 offset is added to the iteration count passed
                # to save_model for hard-negative snapshots.
                save_model(self.netG, avg_param_G, self.netsD, count + 500000,
                           self.model_dir)
                load_params(self.netG, avg_param_G)

                self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
                    self.netG(fixed_noise, self.c_code)
                save_img_results(self.imgs_tcpu,
                                 (self.fake_imgs + self.fg_imgs +
                                  self.mk_imgs + self.fg_mk), self.num_Ds,
                                 count, self.image_dir, self.summary_writer)
                load_params(self.netG, backup_para)

            end_t = time.time()

            if (count % 100) == 0:
                print(
                    '''[%d/%d][%d]
                             Loss_D: %.2f Loss_G: %.2f Time: %.2fs
                          ''' %
                    (count, cfg.TRAIN.HARDNEG_MAX_ITER, self.num_batches,
                     errD_total.item(), errG_total.item(), end_t - start_t))

            if count == cfg.TRAIN.HARDNEG_MAX_ITER:
                # Hard negative training complete
                break

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()
Example #16
0
class GANTrainerSpv(object):
    def __init__(self, output_dir):
        """Set up output folders, GPU selection, text embeddings and visdom.

        Args:
            output_dir: Root directory; 'Model', 'Image' and 'Log'
                sub-directories are created under it when
                ``cfg.TRAIN.FLAG`` is set.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID is a comma-separated list of device ids; the effective
        # batch size scales with the number of devices.
        self.gpus = [int(gpu_id) for gpu_id in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Load the external word vectors into a frozen embedding layer.
        embedding_file = os.path.join(cfg.DATA_DIR,
                                      cfg.DATASET_NAME + ".en.vec")
        dictionary, pretrained_vectors = load_external_embeddings(
            embedding_file)
        embedding = nn.Embedding(len(dictionary), 300, sparse=False)
        embedding.weight.data.copy_(pretrained_vectors)
        embedding.weight.requires_grad = False
        self.txt_dico = dictionary
        self.txt_emb = embedding

        # Visdom windows: three image grids and one text pane, created with
        # placeholder content.
        self.vis = visdom.Visdom(server='http://bvisionserver9.cs.unc.edu',
                                 port=8088,
                                 env="birds_spv2")
        placeholder_shape = (64, 3, 64, 64)
        self.vis_win1 = self.vis.images(np.ones(placeholder_shape))
        self.vis_win2 = self.vis.images(np.ones(placeholder_shape))
        self.vis_win3 = self.vis.images(np.ones(placeholder_shape))
        self.vis_txt1 = self.vis.text('')

    def ind_from_sent(self, caption):
        return [self.txt_dico.SOS_TOKEN] + [
            self.txt_dico.word2id[word]
            if word in self.txt_dico.word2id else self.txt_dico.UNK_TOKEN
            for word in caption.split(" ")
        ] + [self.txt_dico.EOS_TOKEN]

    def pad_seq(self, seq, max_length):
        seq += [self.txt_dico.PAD_TOKEN for i in range(max_length - len(seq))]
        return seq

    def process_captions(self, captions):
        """Turn caption strings into a padded index tensor plus lengths.

        Returns:
            A pair ``(input_var, lengths)``: ``input_var`` is the padded
            id matrix transposed so that its first dimension indexes token
            positions, and ``lengths`` holds the unpadded length of each
            caption. Both are moved to the GPU when ``cfg.CUDA`` is set.
        """
        seqs = [self.ind_from_sent(caption) for caption in captions]
        input_lengths = [len(seq) for seq in seqs]

        # Lengths are recorded before padding, which mutates the sequences.
        longest = max(input_lengths) if input_lengths else 0
        padded = [self.pad_seq(seq, longest) for seq in seqs]

        input_var = Variable(torch.LongTensor(padded)).transpose(0, 1)
        lengths = torch.LongTensor(input_lengths)
        if cfg.CUDA:
            input_var, lengths = input_var.cuda(), lengths.cuda()
        return input_var, lengths

    def add_noise(self, inds, lens):
        """Corrupt a batch of token-id sequences by word dropout and local shuffling.

        Steps performed:
          1. Keep each position where a uniform random draw exceeds ``pwd``
             (0.1), compact the kept tokens per row and right-pad with PAD.
          2. Derive a length per row from the position of an EOS token.
          3. Randomly permute tokens inside consecutive windows of ``k`` (3)
             positions, skipping position 0 (SOS) and stopping before EOS.

        Args:
            inds: Token-id tensor; it is transposed on entry, so it is
                presumably laid out as (seq_len, batch) like the output of
                ``process_captions`` — TODO confirm.
            lens: Only ``lens.size(0)`` is read, as the number of rows.

        Returns:
            ``(seq, seq_lens)``: the noised sequences transposed back to the
            input layout, and a LongTensor of per-row lengths.

        NOTE(review): the only call site visible in ``train`` is commented
        out with "need to fix noise addition".
        """

        pwd = 0.1  # threshold for the keep mask below
        k = 3  # shuffle window size

        inds = inds.transpose(0, 1)

        # True where the token is kept.
        mask = torch.rand(inds.size()) > pwd
        if cfg.CUDA: mask = mask.cuda()
        mask = mask  # no-op

        max_len = inds.size(1)
        # masked_select flattens the kept tokens; chopped_lens says how many
        # of them belong to each row.
        masked = torch.masked_select(inds, mask)
        chopped_lens = torch.sum(mask, dim=1)
        i = 0
        seq = []
        for cl in chopped_lens:
            # Despite the name, `zeros` is a run of PAD tokens that fills the
            # row back up to max_len.
            zeros = torch.ones(max_len - cl).long(
            ) * self.txt_dico.PAD_TOKEN
            zeros = zeros.cuda() if cfg.CUDA else zeros
            seq.append(torch.cat((masked[i:i + cl], zeros)))
            i += cl
        seq = torch.stack(seq)

        # get sequence lengths
        EOS = self.txt_dico.EOS_TOKEN

        seq_lens = []
        # Rows of eos_inds are (row index, column index) of each EOS token.
        eos_inds = torch.nonzero(seq == EOS)

        # For rows without an EOS (it may have been dropped by the mask),
        # append a fallback entry with column 20.
        # NOTE(review): 20 is hard-coded, and fallback entries are appended
        # after all existing ones, so the order of seq_lens built below may
        # not follow row order — confirm.
        for b_idx in range(lens.size(0)):
            if eos_inds.size() == () or b_idx not in eos_inds[:, 0]:
                app = torch.cuda.LongTensor(
                    [b_idx, 20]) if cfg.CUDA else torch.LongTensor([b_idx, 20])
                eos_inds = torch.cat((eos_inds, app.unsqueeze(0)), dim=0)

        # Take the first EOS entry encountered for each row index.
        ind = -1
        for s in eos_inds:
            if s[0] != ind:
                if s[1] == 0:  # HACK: give rows whose EOS is at index 0 a non-zero length
                    seq_lens.append(s[1] + 1)
                else:
                    seq_lens.append(s[1])
                ind = s[0]

        # permute words in window of k
        for b_idx, s in enumerate(seq):
            l = seq_lens[b_idx] - 1  # 1 for EOS
            for i in range(1, l, k):  # skip SOS
                # The last window is shortened so it stops before position l.
                ki = k if i + k < l else l - i
                p = torch.randperm(ki) + i
                p = p.cuda() if cfg.CUDA else p
                seq[b_idx, i:ki + i] = s[p]

        seq_lens = torch.cuda.LongTensor(
            seq_lens) if cfg.CUDA else torch.LongTensor(seq_lens)
        return seq.transpose(0, 1), seq_lens

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build all stage-I networks and optionally restore checkpoints.

        A checkpoint is loaded for a network only when the corresponding
        ``cfg`` path (NET_G, NET_D, ENCODER, DECODER, IMAGE_ENCODER) is not
        the empty string.

        Returns:
            Tuple ``(netG, netD, encoder, decoder, image_encoder, e_disc,
            clf_model)``; ``clf_model`` is a pretrained VGG16 with all
            parameters frozen. All are moved to the GPU when ``cfg.CUDA``.
        """
        from model import STAGE1_G, STAGE1_D
        from model import EncoderRNN, LuongAttnDecoderRNN
        from model import STAGE1_ImageEncoder, EncodingDiscriminator

        def _restore(module, checkpoint_path):
            # An empty path means no checkpoint is configured for the module.
            if checkpoint_path != '':
                state_dict = torch.load(
                    checkpoint_path,
                    map_location=lambda storage, loc: storage)
                module.load_state_dict(state_dict)
                print('Load from: ', checkpoint_path)

        netG = STAGE1_G()
        netG.apply(weights_init)
        netD = STAGE1_D()
        netD.apply(weights_init)

        emb_dim = 300
        encoder = EncoderRNN(emb_dim, self.txt_emb, 1, dropout=0.0)
        decoder = LuongAttnDecoderRNN('general',
                                      emb_dim,
                                      len(self.txt_dico.id2word),
                                      self.txt_emb,
                                      n_layers=1,
                                      dropout=0.0)

        image_encoder = STAGE1_ImageEncoder()
        image_encoder.apply(weights_init)

        e_disc = EncodingDiscriminator(emb_dim)

        _restore(netG, cfg.NET_G)
        _restore(netD, cfg.NET_D)
        _restore(encoder, cfg.ENCODER)
        _restore(decoder, cfg.DECODER)
        _restore(image_encoder, cfg.IMAGE_ENCODER)

        # Load the classification model and freeze its weights.
        clf_model = models.vgg16(pretrained=True)
        for frozen_param in clf_model.parameters():
            frozen_param.requires_grad = False

        if cfg.CUDA:
            for module in (netG, netD, encoder, decoder, image_encoder,
                           e_disc, clf_model):
                module.cuda()

        # NOTE: fine-tuning the classifier head is not implemented here.
        return netG, netD, encoder, decoder, image_encoder, e_disc, clf_model

    #
    # do an initial pass for autoencoding and finetuning the classifier model
    #
    def train_initial_step(self, data_loader, dataset):
        """Placeholder for the initial autoencoding / fine-tuning pass.

        Currently it only builds the stage-I networks and prints "to do";
        neither ``data_loader`` nor ``dataset`` is used yet.
        """
        networks = self.load_network_stageI()
        (netG, netD, encoder, decoder, image_encoder, enc_disc,
         clf_model) = networks
        print("to do")

    #
    # train with both autoencoding and cross-domain losses
    #
    def train(self, data_loader, dataset, stage=1):
        """Jointly train the text/image encoders, decoder, generator and
        discriminators with auto-encoding and cross-domain losses.

        Per batch: captions are encoded, images are generated from the text
        code and from the image code, captions are decoded back from both
        the text code and the image-encoder code, then the image
        discriminator and the encoding discriminator are updated, followed
        by the generator, the encoders and the decoder.

        Args:
            data_loader: Iterable of 5-tuples
                ``(_, real_images, _, captions, pred_cap)``.
            dataset: Unused.
            stage: Unused.

        Side effects: pushes images/captions to visdom every 100 batches,
        prints a loss line per epoch, saves checkpoints every
        ``self.snapshot_interval`` epochs and at the end, and closes
        ``self.summary_writer``.
        """
        netG, netD, encoder, decoder, image_encoder, enc_disc, clf_model = \
            self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(
            1))  # try discriminator smoothing
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        optim_fn, optim_params = get_optimizer("adam,lr=0.001")
        enc_params = filter(lambda p: p.requires_grad, encoder.parameters())
        enc_optimizer = optim_fn(enc_params, **optim_params)
        optim_fn, optim_params = get_optimizer("adam,lr=0.001")
        dec_params = filter(lambda p: p.requires_grad, decoder.parameters())
        dec_optimizer = optim_fn(dec_params, **optim_params)

        # The image encoder uses plain SGD rather than Adam.
        image_enc_optimizer = \
            optim.SGD(image_encoder.parameters(),
                      lr=cfg.TRAIN.DISCRIMINATOR_LR)

        enc_disc_optimizer = \
            optim.Adam(enc_disc.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))

        count = 0

        criterionCycle = nn.SmoothL1Loss()

        for epoch in range(self.max_epoch):

            start_t = time.time()
            # Decay generator and image-discriminator learning rates by
            # 0.75 every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.75
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.75
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                _, real_img_cpu, _, captions, pred_cap = data

                raw_inds, raw_lengths = self.process_captions(captions)

                # Noise addition is disabled until add_noise is fixed:
                # inds, lengths = self.add_noise(raw_inds.data, raw_lengths)
                inds, lengths = raw_inds.data, raw_lengths

                inds = Variable(inds)
                # Sort captions by decreasing length; everything derived
                # from the text encoder is in this sorted order.
                lens_sort, sort_idx = lengths.sort(0, descending=True)
                sorted_inds = inds[:, sort_idx]

                # need to dataparallel the encoders?
                txt_encoder_output = encoder(sorted_inds,
                                             lens_sort.cpu().numpy(), None)
                encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output

                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                # Real images reordered to match the sorted captions.
                sorted_real_imgs = real_imgs[sort_idx]

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (real_txt_code, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                #######################################################
                # (2b) Decode z from txt and calc auto-encoding loss
                ######################################################
                loss_auto = 0
                auto_dec_inp = Variable(
                    torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                     self.batch_size))
                auto_dec_inp = auto_dec_inp.cuda(
                ) if cfg.CUDA else auto_dec_inp
                auto_dec_hidden = real_txt_code.unsqueeze(0)

                max_target_length = inds.size(0)

                # Teacher forcing: the ground-truth token at step t is the
                # decoder input at step t + 1.
                for t in range(max_target_length):

                    auto_dec_out, auto_dec_hidden, auto_dec_attn = decoder(
                        auto_dec_inp, auto_dec_hidden, encoder_out)

                    loss_auto = loss_auto + F.cross_entropy(
                        auto_dec_out,
                        sorted_inds[t],
                        ignore_index=self.txt_dico.PAD_TOKEN)
                    auto_dec_inp = sorted_inds[t]

                loss_auto = loss_auto / lengths.float().sum()

                #######################################################
                # (2c) Decode z from real imgs and calc auto-encoding loss
                ######################################################

                real_img_out = nn.parallel.data_parallel(
                    image_encoder, sorted_real_imgs, self.gpus)

                real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out

                noise.data.normal_(0, 1)
                inputs = (real_img_code, noise)
                _, fake_from_real_img, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                loss_img = criterionCycle(F.sigmoid(fake_from_real_img),
                                          F.sigmoid(sorted_real_imgs))

                #######################################################
                # (2d) Decode z from fake imgs and calc cycle loss
                ######################################################

                # NOTE(review): the section is about fake images, but the
                # image encoder is fed the real images here — confirm
                # whether `fake_imgs` was intended.
                fake_img_out = nn.parallel.data_parallel(
                    image_encoder, sorted_real_imgs, self.gpus)

                fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
                fake_img_feats = fake_img_feats.transpose(0, 1)

                loss_cd = 0
                cd_dec_inp = Variable(
                    torch.LongTensor([self.txt_dico.SOS_TOKEN] *
                                     self.batch_size))
                cd_dec_inp = cd_dec_inp.cuda() if cfg.CUDA else cd_dec_inp

                cd_dec_hidden = fake_img_code.unsqueeze(0)

                max_target_length = inds.size(0)

                for t in range(max_target_length):

                    cd_dec_out, cd_dec_hidden, cd_dec_attn = decoder(
                        cd_dec_inp, cd_dec_hidden, fake_img_feats)

                    loss_cd = loss_cd + F.cross_entropy(
                        cd_dec_out,
                        sorted_inds[t],
                        ignore_index=self.txt_dico.PAD_TOKEN)
                    cd_dec_inp = sorted_inds[t]

                loss_cd = loss_cd / lengths.float().sum()

                loss_dc = criterionCycle(fake_imgs, sorted_real_imgs)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                enc_disc.zero_grad()

                errD = 0

                # NOTE(review): `real_imgs` is passed in loader order while
                # `fake_imgs` and `real_txt_mu` are in length-sorted order —
                # confirm whether `real_imgs[sort_idx]` was intended.
                errD_im, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               real_txt_mu, self.gpus)

                # Encoding discriminator targets: text embeddings -> 0,
                # image embeddings -> 1.
                txt_enc_labels = Variable(
                    torch.FloatTensor(batch_size).fill_(0))
                img_enc_labels = Variable(
                    torch.FloatTensor(batch_size).fill_(1))
                if cfg.CUDA:
                    txt_enc_labels = txt_enc_labels.cuda()
                    img_enc_labels = img_enc_labels.cuda()

                # Detached so this update does not reach the encoders.
                disc_real_txt_emb = encoder_hidden[0].detach()
                disc_real_img_emb = real_img_emb.detach()

                pred_txt = enc_disc(disc_real_txt_emb)
                pred_img = enc_disc(disc_real_img_emb)

                enc_disc_loss_txt = F.binary_cross_entropy_with_logits(
                    pred_txt.squeeze(), txt_enc_labels)
                enc_disc_loss_img = F.binary_cross_entropy_with_logits(
                    pred_img.squeeze(), img_enc_labels)

                errD = errD + errD_im + enc_disc_loss_txt + enc_disc_loss_img

                # check NaN (NaN is the only value not equal to itself)
                if (errD != errD).data.any():
                    print("NaN detected (discriminator)")
                    pdb.set_trace()
                    exit()

                errD.backward()

                optimizerD.step()
                enc_disc_optimizer.step()

                ############################
                # (4) Update G network
                ###########################
                encoder.zero_grad()
                decoder.zero_grad()
                netG.zero_grad()
                image_encoder.zero_grad()

                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              real_txt_mu, self.gpus)

                img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
                txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)

                kl_loss = img_kl_loss + txt_kl_loss

                pred_txt_g = enc_disc(encoder_hidden[0])
                pred_img_g = enc_disc(real_img_emb)

                # Labels are swapped relative to the discriminator update:
                # image embeddings are scored against the text label and
                # text embeddings against the image label.
                enc_fake_loss_txt = F.binary_cross_entropy_with_logits(
                    pred_img_g.squeeze(), txt_enc_labels)
                enc_fake_loss_img = F.binary_cross_entropy_with_logits(
                    pred_txt_g.squeeze(), img_enc_labels)

                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL + loss_cd + loss_dc + loss_img + loss_auto + enc_fake_loss_txt + enc_fake_loss_img

                # check NaN
                if (errG_total != errG_total).data.any():
                    print("NaN detected (generator)")
                    pdb.set_trace()
                    exit()

                errG_total.backward()

                optimizerG.step()
                image_enc_optimizer.step()
                enc_optimizer.step()
                dec_optimizer.step()

                count = count + 1
                if i % 100 == 0:
                    # Scalar summaries are currently disabled; only visdom
                    # is updated.

                    # Forward pass with the fixed noise; its outputs are not
                    # used below.
                    inputs = (real_txt_code, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)

                    self.vis.images(normalize(
                        sorted_real_imgs.data.cpu().numpy()),
                                    win=self.vis_win1)
                    self.vis.images(normalize(fake_imgs.data.cpu().numpy()),
                                    win=self.vis_win2)
                    self.vis.text("\n*".join(captions), win=self.vis_txt1)
                    if (len(pred_cap)):
                        # Third window: images regenerated from the image
                        # encoder's code. This replaces a reference to the
                        # undefined name `fake_from_fake_img`, which raised
                        # NameError whenever pred_cap was non-empty.
                        self.vis.images(normalize(
                            fake_from_real_img.data.cpu().numpy()),
                                        win=self.vis_win3)

            end_t = time.time()

            # NOTE(review): `.data[0]` on a 0-dim tensor is an error on
            # PyTorch >= 0.5; switch to `.item()` if running on a newer
            # release.
            prefix = "E%d/%s, %.1fs" % (epoch, time.strftime('D%d %X'),
                                        (end_t - start_t))
            gen_str = "G_all: %.3f Cy_T: %.3f AE_T: %.3f AE_I %.3f KL_T %.3f KL_I %.3f" % (
                errG_total.data[0], loss_cd.data[0], loss_auto.data[0],
                loss_img.data[0], txt_kl_loss.data[0], img_kl_loss.data[0])

            dis_str = "D_all: %.3f D_I: %.3f D_zT: %.3f D_zI: %.3f" % (
                errD.data[0], errD_im.data[0], enc_disc_loss_txt.data[0],
                enc_disc_loss_img.data[0])

            print("%s %s, %s" % (prefix, gen_str, dis_str))

            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, encoder, decoder, image_encoder, epoch,
                           self.model_dir)

        # Final checkpoint after the last epoch.
        save_model(netG, netD, encoder, decoder, image_encoder, self.max_epoch,
                   self.model_dir)

        self.summary_writer.close()
Example #17
0
class FineGAN_trainer(object):
    def __init__(self, output_dir):
        """Prepare output directories (training only) and the GPU setup.

        Args:
            output_dir: Root directory. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                below it and a summary writer is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            # Record the three paths first, then create them in that order.
            for attr, sub_dir in (('model_dir', 'Model'),
                                  ('image_dir', 'Image'),
                                  ('log_dir', 'Log')):
                setattr(self, attr, os.path.join(output_dir, sub_dir))
            for path in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated string of integer device ids.
        gpu_ids = [int(token) for token in cfg.GPU_ID.split(',')]
        self.gpus = gpu_ids
        self.num_gpus = len(gpu_ids)
        # The first listed device becomes the default CUDA device.
        torch.cuda.set_device(gpu_ids[0])
        cudnn.benchmark = True
        # Filled lazily by get_dataloader() when a dataset subset is requested.
        self.subdataset_idx = None

    def prepare_data(self, data):
        """Unpack one batch and move its tensors to the GPU when enabled.

        Args:
            data: A 6-tuple ``(fimgs, cimgs, c_code, _, masks, aux_masks)``;
                the fourth element is ignored.

        Returns:
            ``(fimgs, real_vfimgs, real_vcimgs, vc_code, masks, aux_masks)``.
            ``fimgs`` is returned untouched; the remaining entries are the
            wrapped (and, with ``cfg.CUDA``, GPU-resident) versions.
        """
        fimgs, cimgs, c_code, _, masks, aux_masks = data

        if cfg.CUDA:
            # Same wrap-then-transfer treatment for every tensor.
            vc_code, masks, aux_masks, real_vfimgs, real_vcimgs = [
                Variable(tensor).cuda()
                for tensor in (c_code, masks, aux_masks, fimgs, cimgs)
            ]
        else:
            vc_code = Variable(c_code)
            # On CPU the masks are detached rather than wrapped.
            masks, aux_masks = masks.detach(), aux_masks.detach()
            real_vfimgs, real_vcimgs = Variable(fimgs), Variable(cimgs)

        return fimgs, real_vfimgs, real_vcimgs, vc_code, masks, aux_masks

    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_fimgs.size(0)
        criterion, criterion_one = self.criterion, self.criterion_one

        if idx == 0:
            real_imgs = self.real_fimgs
            fake_imgs = self.fake_imgs[0]
            optD = self.optimizersD[0]

            netD = self.netsD[0]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha, self.masks.detach())
            fake_logits = netD(fake_imgs.detach(), self.alpha, self.aux_masks)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])

            errD_real = criterion_one(
                real_logits[1],
                real_labels)  # Real/Fake loss for the real image
            errD_fake = criterion_one(
                fake_logits[1],
                fake_labels)  # Real/Fake loss for the fake image
            errD0 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_GLB

            netD = self.netsD[3]
            netD.zero_grad()

            _fg = self.masks == 0
            rev_masks = torch.zeros_like(self.masks)
            rev_masks.masked_fill_(_fg, 1.0)
            real_logits = netD(real_imgs, self.alpha, rev_masks)

            fake_labels = torch.zeros_like(real_logits[1])
            ext, output, fnl_masks = real_logits
            weights_real = torch.ones_like(output)
            real_labels = torch.ones_like(output)

            # for i in range(batch_size):
            invalid_patch = fnl_masks != 0.0
            weights_real.masked_fill_(invalid_patch, 0.0)

            norm_fact_real = weights_real.sum()
            norm_fact_fake = weights_real.shape[0] * weights_real.shape[
                1] * weights_real.shape[2] * weights_real.shape[3]
            real_logits = ext, output

            fake_logits = netD(fake_imgs.detach(), self.alpha)

            errD_real_uncond = criterion(
                real_logits[1], real_labels
            )  # Real/Fake loss for 'real background' (on patch level)
            errD_real_uncond = torch.mul(
                errD_real_uncond, weights_real
            )  # Masking output units which correspond to receptive fields which lie within the boundin box
            errD_real_uncond = errD_real_uncond.mean()

            errD_fake_uncond = criterion(
                fake_logits[1], fake_labels
            )  # Real/Fake loss for 'fake background' (on patch level)
            errD_fake_uncond = errD_fake_uncond.mean()

            if norm_fact_real > 0:  # Normalizing the real/fake loss for background after accounting the number of masked members in the output.
                errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /
                                                (norm_fact_real * 1.0))
            else:
                errD_real = errD_real_uncond

            errD_fake = errD_fake_uncond
            errD1 = (errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT_LCL

            # Background/foreground classification loss
            errD_real_uncond_classi = criterion(real_logits[0], weights_real)
            errD_real_uncond_classi = errD_real_uncond_classi.mean()
            errD_classi = errD_real_uncond_classi * cfg.TRAIN.BG_CLASSI_WT

            # print(errD0, errD1)
            # sys.exit(0)

            errD = errD0 + errD1 + errD_classi

        elif idx == 2:  # Discriminator is only trained in background and child stage. (NOT in parent stage)
            netD, optD = self.netsD[2], self.optimizersD[2]
            real_imgs = self.real_cimgs
            fake_imgs = self.fake_imgs[2]
            netD.zero_grad()
            real_logits = netD(real_imgs, self.alpha)
            fake_logits = netD(fake_imgs.detach(), self.alpha)
            real_labels = torch.ones_like(real_logits[1])
            fake_labels = torch.zeros_like(real_logits[1])

            errD_real = criterion_one(
                real_logits[1],
                real_labels)  # Real/Fake loss for the real image
            errD_fake = criterion_one(
                fake_logits[1],
                fake_labels)  # Real/Fake loss for the fake image
            errD = errD_real + errD_fake

        errD.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
            summary_D_real = summary.scalar('D_loss_real_%d' % idx,
                                            errD_real.item())
            self.summary_writer.add_summary(summary_D_real, count)
            summary_D_fake = summary.scalar('D_loss_fake_%d' % idx,
                                            errD_fake.item())
            self.summary_writer.add_summary(summary_D_fake, count)

        return errD

    def train_Gnet(self, count):
        """Run one generator update across the three stages.

        Accumulates, into a single scalar, the adversarial losses of the
        background (0) and child (2) stages plus the information
        (classification) losses of the parent (1) and child (2) stages, then
        back-propagates once and steps the first three optimizers stored in
        ``self.optimizerG``.

        Args:
            count: Global iteration counter; scalar summaries are written
                whenever it is a multiple of 100.

        Returns:
            The accumulated generator loss tensor.
        """
        # Clear stale gradients on the generator and on all four
        # discriminators, since the losses below flow through them.
        self.netG.zero_grad()
        for myit in range(4):
            self.netsD[myit].zero_grad()

        errG_total = 0
        flag = count % 100
        # NOTE(review): batch_size is computed but never used in this method.
        batch_size = self.real_fimgs.size(0)
        criterion_one, criterion_class, c_code, p_code = self.criterion_one, self.criterion_class, self.c_code, self.p_code

        for i in range(3):
            if i == 0 or i == 2:  # real/fake loss for background (0) and child (2) stage
                if i == 0:
                    # Global real/fake loss: the generator wants netsD[0] to
                    # label the fake background as real.
                    outputs = self.netsD[0](self.fake_imgs[0], self.alpha,
                                            self.aux_masks)
                    real_labels = torch.ones_like(outputs[1])
                    errG0 = criterion_one(outputs[1], real_labels)
                    errG0 = errG0 * cfg.TRAIN.BG_LOSS_WT_GLB

                    # Patch-level real/fake loss from netsD[3].
                    outputs = self.netsD[3](self.fake_imgs[0], self.alpha)
                    real_labels = torch.ones_like(outputs[1])
                    errG1 = criterion_one(outputs[1], real_labels)
                    errG1 = errG1 * cfg.TRAIN.BG_LOSS_WT_LCL

                    errG_classi = criterion_one(
                        outputs[0], real_labels
                    )  # Background/Foreground classification loss for the fake background image (on patch level)
                    errG_classi = errG_classi * cfg.TRAIN.BG_CLASSI_WT

                    errG = errG0 + errG1 + errG_classi
                    errG_total = errG_total + errG

                else:  # i = 2
                    outputs = self.netsD[2](self.fake_imgs[2], self.alpha)
                    real_labels = torch.ones_like(outputs[1])
                    errG = criterion_one(outputs[1], real_labels)
                    errG_total = errG_total + errG

            # The class target is the column index of each non-zero entry of
            # the code; presumably the codes are one-hot -- TODO confirm.
            if i == 1:  # Mutual information loss for the parent stage (1)
                pred_p = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(pred_p[0],
                                            torch.nonzero(p_code.long())[:, 1])
            elif i == 2:  # Mutual information loss for the child stage (2)
                pred_c = self.netsD[i](self.fg_mk[i - 1], self.alpha)
                errG_info = criterion_class(pred_c[0],
                                            torch.nonzero(c_code.long())[:, 1])

            if i > 0:
                errG_total = errG_total + errG_info

            if flag == 0:
                if i > 0:
                    summary_D_class = summary.scalar('Information_loss_%d' % i,
                                                     errG_info.item())
                    self.summary_writer.add_summary(summary_D_class, count)

                if i == 0 or i == 2:
                    summary_D = summary.scalar('G_loss%d' % i, errG.item())
                    self.summary_writer.add_summary(summary_D, count)

        # Single backward pass over the summed loss, then step the three
        # optimizers held in self.optimizerG.
        errG_total.backward()
        for myit in range(3):
            self.optimizerG[myit].step()
        return errG_total

    def get_dataloader(self, cur_depth):
        """Create the shuffled training DataLoader for one depth.

        Args:
            cur_depth: Current depth index; the image side is
                ``32 * 2 ** (cur_depth + 1)`` pixels.

        Returns:
            A ``torch.utils.data.DataLoader`` over the (optionally
            sub-sampled) dataset, dropping the last incomplete batch.
        """
        imsize = 32 * (2**(cur_depth + 1))
        # Resize slightly larger than the target, then random-crop and flip.
        augmentations = [
            transforms.Resize(int(imsize * 76 / 64)),
            transforms.RandomCrop(imsize),
            transforms.RandomHorizontalFlip()
        ]
        dataset = Dataset(cfg.DATA_DIR,
                          cur_depth=cur_depth,
                          transform=transforms.Compose(augmentations))

        restrict_size = cfg.TRAIN.DATASET_SIZE != -1
        if restrict_size:
            # Draw the subset indices once and reuse them on later calls.
            if self.subdataset_idx is None:
                self.subdataset_idx = random.sample(range(0, len(dataset)),
                                                    cfg.TRAIN.DATASET_SIZE)
            dataset = torch.utils.data.Subset(dataset, self.subdataset_idx)

        assert dataset
        print('training dataset size: ', len(dataset))

        gpu_count = len(cfg.GPU_ID.split(','))
        loader_kwargs = dict(
            batch_size=batchsize_per_depth[cur_depth] * gpu_count,
            drop_last=True,
            shuffle=True,
            num_workers=int(cfg.WORKERS))
        return torch.utils.data.DataLoader(dataset, **loader_kwargs)

    def train(self):
        """Train FineGAN progressively, one depth after another.

        Loads the networks and optimizers, then for every depth in
        ``start_depth..end_depth`` runs the blend + stable epochs: each
        iteration generates stage-wise fakes, updates the background and
        child discriminators, updates the generator, and maintains a running
        average of the generator parameters. Snapshots (images / models) are
        written at the configured iteration intervals, and the networks are
        grown via ``update_network`` at the end of every depth.
        """
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        # Stays True until at least one training iteration has run; guards
        # the end-of-depth save below.
        newly_loaded = True
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        # Element-wise BCE (no reduction) is needed for the per-patch
        # weighting in train_Dnet.
        # NOTE(review): ``reduce=False`` is the legacy spelling; newer
        # PyTorch documents ``reduction='none'`` instead.
        self.criterion = nn.BCELoss(reduce=False)
        self.criterion_one = nn.BCELoss()
        self.criterion_class = nn.CrossEntropyLoss()

        nz = cfg.GAN.Z_DIM

        if cfg.CUDA:
            self.criterion.cuda()
            self.criterion_one.cuda()
            self.criterion_class.cuda()

        print("Starting normal FineGAN training..")
        count = start_count

        for cur_depth in range(start_depth, end_depth + 1):
            max_epoch = blend_epochs_per_depth[cur_depth] + \
                stable_epochs_per_depth[cur_depth]
            dataloader = self.get_dataloader(cur_depth)
            num_batches = len(dataloader)

            depth_ep_ctr = 0  # depth epoch counter
            batch_size = batchsize_per_depth[cur_depth] * self.num_gpus

            # ``noise`` is re-sampled every iteration; ``fixed_noise`` is
            # sampled once per depth and used for the snapshot images.
            noise = Variable(torch.FloatTensor(batch_size, nz))
            fixed_noise = Variable(
                torch.FloatTensor(batch_size, nz).normal_(0, 1))

            if cfg.CUDA:
                noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

            # Resume mid-depth only for the first depth processed:
            # start_count is reset so later depths start at epoch 0.
            # NOTE(review): depth_ep_ctr still starts at 0 when resuming from
            # a non-zero start_epoch -- confirm the alpha schedule below is
            # intended to restart in that case.
            start_epoch = start_count // (num_batches)
            start_count = 0

            for epoch in range(start_epoch, max_epoch):
                depth_ep_ctr += 1

                # Blending factor: ramps up with the depth epoch counter
                # during the blend epochs, then stays at 1.
                if depth_ep_ctr < blend_epochs_per_depth[cur_depth]:
                    self.alpha = depth_ep_ctr / blend_epochs_per_depth[
                        cur_depth]
                else:
                    self.alpha = 1

                start_t = time.time()
                for step, data in enumerate(dataloader, 0):
                    count += 1
                    _, self.real_fimgs, self.real_cimgs, \
                        self.c_code, self.masks, self.aux_masks = self.prepare_data(data)

                    # Feedforward through Generator. Obtain stagewise fake images
                    noise.data.normal_(0, 1)
                    fake_imgs, fg_imgs, mk_imgs, fg_mk = self.netG(
                        noise, self.c_code, self.alpha)

                    # Keep only the outputs of the current depth: three
                    # stage images and two foreground/mask outputs per depth.
                    self.fake_imgs = fake_imgs[cur_depth * 3:cur_depth * 3 + 3]
                    self.fg_imgs = fg_imgs[cur_depth * 2:cur_depth * 2 + 2]
                    self.mk_imgs = mk_imgs[cur_depth * 2:cur_depth * 2 + 2]
                    self.fg_mk = fg_mk[cur_depth * 2:cur_depth * 2 + 2]

                    # Obtain the parent code given the child code
                    self.p_code = child_to_parent(self.c_code,
                                                  cfg.FINE_GRAINED_CATEGORIES,
                                                  cfg.SUPER_CATEGORIES)

                    # Update Discriminator networks
                    errD_total = 0
                    for i in range(3):
                        if i == 0 or i == 2:  # only at background (0) and child (2) stage
                            errD = self.train_Dnet(i, count)
                            errD_total += errD

                    # Update the Generator networks
                    errG_total = self.train_Gnet(count)
                    # Running average of the generator weights:
                    # avg = 0.999 * avg + 0.001 * current.
                    # NOTE(review): ``add_(scalar, tensor)`` is the legacy
                    # argument order; newer PyTorch documents
                    # ``add_(tensor, alpha=scalar)``.
                    for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                        avg_p.mul_(0.999).add_(0.001, p.data)

                    newly_loaded = False
                    if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                        # Remember the live weights so they can be restored
                        # after sampling with the averaged ones.
                        backup_para = copy_G_params(self.netG)
                        if count % cfg.TRAIN.SAVEMODEL_INTERVAL == 0:
                            save_model(self.netG, avg_param_G, self.netsD,
                                       count, self.model_dir, cur_depth)

                        # Save images
                        load_params(self.netG, avg_param_G)

                        fake_imgs, fg_imgs, mk_imgs, fg_mk = self.netG(
                            fixed_noise, self.c_code, self.alpha)
                        save_img_results((fake_imgs[cur_depth*3:cur_depth*3+3] + fg_imgs[cur_depth*2:cur_depth*2+2] \
                                            + mk_imgs[cur_depth*2:cur_depth*2+2] + fg_mk[cur_depth*2:cur_depth*2+2]),
                                         count, self.image_dir, self.summary_writer, cur_depth)
                        # Switch back to the live (non-averaged) weights.
                        load_params(self.netG, backup_para)

                end_t = time.time()
                # Reports the losses of the last iteration of the epoch.
                # NOTE(review): errD_total / errG_total are unbound here if
                # the dataloader yielded no batch in this epoch.
                print('''[%d/%d][%d]Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' %
                      (epoch, max_epoch, num_batches, errD_total.item(),
                       errG_total.item(), end_t - start_t))
            # sys.exit(0)
            # Save at the end of the depth (unless nothing was trained), then
            # grow the networks and restart the parameter average.
            if not newly_loaded:
                save_model(self.netG, avg_param_G, self.netsD, count,
                           self.model_dir, cur_depth)
            self.update_network()
            avg_param_G = copy_G_params(self.netG)

    def update_network(self):
        """Grow all networks by one depth and rebuild every optimizer.

        Calls ``inc_depth()`` on the wrapped generator and discriminators,
        moves them to the GPU when ``cfg.CUDA`` is set, and replaces
        ``self.optimizersD`` / ``self.optimizerG`` with fresh Adam optimizers
        over the (now larger) parameter sets.
        """

        def make_adam(params, lr):
            # All optimizers share the same betas; only the rate differs.
            return optim.Adam(params, lr=lr, betas=(0.5, 0.999))

        self.netG.module.inc_depth()
        # self.netG = torch.nn.DataParallel(self.netG, device_ids=self.gpus)
        print(self.netG)

        for disc in self.netsD:
            disc.module.inc_depth()
            # netD = torch.nn.DataParallel(netD, device_ids=self.gpus)
            print(disc)

        if cfg.CUDA:
            self.netG.cuda()
            for disc in self.netsD:
                disc.cuda()

        # One discriminator optimizer per discriminator.
        self.optimizersD = [
            make_adam(disc.parameters(), cfg.TRAIN.DISCRIMINATOR_LR)
            for disc in self.netsD
        ]

        # Generator-side optimizers, in order: the generator itself, the
        # whole parent discriminator (netsD[1]), and the jointConv + logits
        # sub-modules of the child discriminator (netsD[2]).
        child_head = self.netsD[2].module.down_net[0]
        self.optimizerG = [
            make_adam(self.netG.parameters(), cfg.TRAIN.GENERATOR_LR),
            make_adam(self.netsD[1].parameters(), cfg.TRAIN.GENERATOR_LR),
            make_adam([{
                'params': child_head.jointConv.parameters()
            }, {
                'params': child_head.logits.parameters()
            }], cfg.TRAIN.GENERATOR_LR),
        ]
def run(args):
    """Build the TF1 adversarial graph, then train it or generate data.

    The graph encodes ``feature_1``, decodes it together with ``feature_2``
    into a reconstruction of ``feature_3``, and trains it with a critic loss
    plus gradient penalty, an L1 reconstruction loss and (optionally) an
    auxiliary classification loss.

    Relies on module-level names defined elsewhere in the file (``dim_G``,
    ``dim_D``, ``data_num``, ``log_path``, ``model_pth``, ``input_data``).

    Args:
        args: Parsed options. Reads ``input_dimension``, ``ac_weight``,
            ``num_predicates``, ``L1_weight``, ``training``, ``max_epoch``
            and ``test_setting``.
    """
    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_flag = tf.placeholder(tf.bool)
        keep_prob = tf.placeholder(tf.float32)
        feature_1 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_1')
        feature_2 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_2')
        feature_3 = tf.placeholder(tf.float32, [None, args.input_dimension], 'feature_3')
        real_labels = tf.placeholder(tf.int32, [None, ], 'real_labels')

        ### generated predicate features ###
        bottle_z = ST_encoder(dim_G, feature_1, keep_prob, reuse=False, training=train_flag)
        reconstruction = ST_decoder(dim_D, 1000, bottle_z, feature_2, keep_prob, reuse=False, training=train_flag)
        errL1 = tf.reduce_mean(tf.losses.absolute_difference(reconstruction, feature_3, reduction=tf.losses.Reduction.NONE))
        if args.ac_weight > 0:
            ac_loss = aux_classifier(reconstruction, real_labels, args.num_predicates, keep_prob, reuse=False, training=train_flag)
        else:
            # Placeholder value so that 'ac_loss' can always be fetched.
            ac_loss = tf.zeros(1, dtype=tf.dtypes.float32)
        #errL1 = tf.reduce_mean(tf.abs(reconstruction - feature_3))

        # Critic scores for the reconstruction (fake) and the target (real);
        # the second call reuses the variables created by the first.
        errD_fake = netD(256, reconstruction, n_layers=0, reuse=False)
        errD_real = netD(256, feature_3, n_layers=0, reuse=True)

        # cost functions
        errD = tf.reduce_mean(errD_fake) - tf.reduce_mean(errD_real)
        errG = -tf.reduce_mean(errD_fake)
        if args.ac_weight > 0:
            errG_total = errG + errL1 * args.L1_weight + args.ac_weight * ac_loss
        else:
            errG_total = errG + errL1 * args.L1_weight

        # gradient penalty, evaluated on a random interpolation between the
        # real and the reconstructed features (one scalar epsilon per step)
        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = feature_3 * (1 - epsilon) + epsilon * reconstruction
        d_hat = netD(256, x_hat, n_layers=0, reuse=True)
        gradients = tf.gradients(d_hat, x_hat)[0]
        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = 10 * tf.reduce_mean((slopes - 1.0) ** 2)
        errD_total = errD + gradient_penalty

        # Split the trainable variables by scope-name substring.
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'Discriminator' in var.name]
        g_vars = [var for var in t_vars if 'Generator' in var.name]
        learning_rate = get_learning_rate(data_num, global_step)
        # NOTE(review): both train ops are given global_step, so each of
        # them increments it when run -- confirm this is what the
        # learning-rate schedule expects.
        G_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, beta2=0.9).minimize(errG_total, global_step, var_list=g_vars)
        D_train_op = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5, beta2=0.9).minimize(errD_total, global_step, var_list=d_vars)

        ops = {'D_train_op': D_train_op,
               'G_train_op': G_train_op,
               'feature_1': feature_1,
               'feature_2': feature_2,
               'feature_3': feature_3,
               'keep_prob': keep_prob,
               'real_labels': real_labels,
               'train_flag': train_flag,
               'errD': errD,
               'errG': errG,
               'errL1': errL1,
               'ac_loss': ac_loss,
               'reconstruction': reconstruction}

        saver = tf.train.Saver(max_to_keep=None)
        init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        ### make gpu memory grow according to needed ###
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        sess.run(init)
        summary_writer = FileWriter(log_path, graph=tf.get_default_graph())

        # tf.add_to_collection('train_op', train_op)
        tf.add_to_collection('G_train_op', G_train_op)
        tf.add_to_collection('D_train_op', D_train_op)

        start_epoch = 0
        if args.training:
            # restore previous model if there is one
            ckpt = tf.train.get_checkpoint_state(model_pth)
            if ckpt and ckpt.model_checkpoint_path:
                print("Restoring previous model...")
                try:
                    # Checkpoints are named 'checkpoint-<epoch>'; resume with
                    # the epoch after the saved one.
                    resume_epoch = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) + 1
                    print(resume_epoch)
                    saver.restore(sess, ckpt.model_checkpoint_path)
                except Exception:
                    # Best effort: fall back to training from epoch 0.
                    print("Could not restore model")
                else:
                    # Only skip ahead once the weights were really restored;
                    # otherwise training would resume at a later epoch with
                    # freshly initialised variables.
                    start_epoch = resume_epoch
                    print("Model restored")

            ########################################### training portion
            for epoch in range(start_epoch, args.max_epoch):
                start = time.time()
                train_loss_d, train_loss_g, train_loss_L1Loss, train_loss_acLoss = train_one_epoch(sess, input_data, ops, args)
                print('epoch:', epoch, 'D loss:', train_loss_d.avg, 'G_loss:', train_loss_g.avg, 'L1:', train_loss_L1Loss.avg, 'AC:', train_loss_acLoss.avg, 'time:', time.time() - start)

                summary_D = summary.scalar('D_loss', train_loss_d.avg)
                summary_writer.add_summary(summary_D, epoch)
                summary_G = summary.scalar('G_loss', train_loss_g.avg)
                summary_writer.add_summary(summary_G, epoch)
                summary_G_L1 = summary.scalar('G_L1', train_loss_L1Loss.avg)
                summary_writer.add_summary(summary_G_L1, epoch)
                summary_AC = summary.scalar('G_AC', train_loss_acLoss.avg)
                summary_writer.add_summary(summary_AC, epoch)
                # Checkpoint every 10th epoch.
                if (epoch + 1) % 10 == 0:
                    print('save model')
                    if not os.path.exists(model_pth):
                        os.makedirs(model_pth)
                    # model_pth is concatenated directly, so it is expected
                    # to end with a path separator.
                    saver.save(sess, model_pth + 'checkpoint-' + str(epoch))
                    saver.export_meta_graph(model_pth + 'checkpoint-' + str(epoch) + '.meta')
        else:
            print('evaluation')
            ckpt = tf.train.get_checkpoint_state(model_pth)
            try:
                # A missing checkpoint (ckpt is None) also lands in the
                # handler below via the attribute access.
                epoch = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
                saver.restore(sess, ckpt.model_checkpoint_path)
                print("Model restored")
            except Exception:
                print("Could not restore model")
                # NOTE(review): exits with status 0 although evaluation could
                # not start; kept as-is for callers relying on the status.
                exit(0)

            ### generate whole data###
            if args.test_setting == 'wholedata':
                print('generate whole data:', epoch)
                generate_wholedata(sess, input_data, ops, epoch)
            ### generate lowshot vrd data###
            elif args.test_setting == 'lowshot':
                print('generate lowshot data:', epoch)
                generate_lowshot(sess, input_data, ops, args, epoch)
    input_data.close()
Example #19
0
class condGANTrainer(object):
    def __init__(self, output_dir, data_loader, imsize):
        """Record output paths, GPU setup and basic training settings.

        Args:
            output_dir: Root directory for the ``Model``, ``Image``, ``Log``
                and ``TestImage`` sub-directories (created only when
                ``cfg.TRAIN.FLAG`` is set).
            data_loader: Training data loader; its length is stored as the
                number of batches per epoch.
            imsize: Not used by this constructor.
        """
        # The four paths are always recorded, even outside training.
        for attr, sub_dir in (('model_dir', 'Model'),
                              ('image_dir', 'Image'),
                              ('log_dir', 'Log'),
                              ('testImage_dir', 'TestImage')):
            setattr(self, attr, os.path.join(output_dir, sub_dir))

        if cfg.TRAIN.FLAG:
            for path in (self.model_dir, self.image_dir, self.log_dir,
                         self.testImage_dir):
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated string of integer device ids.
        gpu_ids = [int(token) for token in cfg.GPU_ID.split(',')]
        self.gpus = gpu_ids
        self.num_gpus = len(gpu_ids)
        torch.cuda.set_device(gpu_ids[0])
        cudnn.benchmark = True

        # The configured batch size is per GPU.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

    def prepare_data(self, data):
        """Unpack one batch and wrap its tensors (on GPU when enabled).

        Args:
            data: A 6-tuple ``(imgs, w_imgs, s_imgs, t_embedding, class_id,
                _)``; ``imgs`` and ``w_imgs`` are indexed once per
                discriminator and the last element is ignored.

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, same_vimg, vembedding,
            class_id)`` where the two ``*_vimgs`` entries are lists with
            ``self.num_Ds`` elements each.
        """
        imgs, w_imgs, s_imgs, t_embedding, class_id, _ = data

        def wrap(tensor):
            # Wrap as Variable and move to the GPU only when CUDA is on.
            wrapped = Variable(tensor)
            return wrapped.cuda() if cfg.CUDA else wrapped

        vembedding = wrap(t_embedding)
        same_vimg = wrap(s_imgs)

        # Wrap real/wrong images scale by scale, real first at every scale.
        pairs = [(wrap(imgs[i]), wrap(w_imgs[i]))
                 for i in range(self.num_Ds)]
        real_vimgs = [real for real, _wrong in pairs]
        wrong_vimgs = [wrong for _real, wrong in pairs]

        return imgs, real_vimgs, wrong_vimgs, same_vimg, vembedding, class_id

    def train_MDnet(self, count):
        """Run one update of the pairwise discriminator ``netMD``.

        Features of the real, similar, fake and wrong images (taken at the
        last scale) are extracted with ``image_cnn`` + ``image_encoder`` and
        paired with the real-image features. ``netMD`` is trained with
        cross-entropy to assign each pair type its own label.

        Args:
            count: Global iteration counter; the loss summary is written
                whenever it is a multiple of 100.

        Returns:
            The total ``netMD`` loss tensor.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        real_imgs = self.real_imgs[-1]
        wrong_imgs = self.wrong_imgs[-1]
        fake_imgs = self.fake_imgs[-1]
        similar_imgs = self.similar_imgs
        #
        netMD = self.netMD
        optMD = self.optimizerMD
        netMD.zero_grad()

        same_labels = self.same_labels[:batch_size]
        real_labels = self.real_labels[:batch_size]
        # Fixed: this used to slice ``self.real_labels``, which gave netMD
        # the very same target for (real, fake) pairs that the generator
        # optimises for in train_Gnet, removing the adversarial signal.
        fake_labels = self.fake_labels[:batch_size]
        wrong_labels = self.wrong_labels[:batch_size]

        # Inputs and CNN features are detached, so no gradient reaches the
        # generator or image_cnn from this loss.
        real_feat = self.image_cnn(real_imgs.detach())
        real_feat = self.image_encoder(real_feat.detach())
        similar_feat = self.image_cnn(similar_imgs.detach())
        similar_feat = self.image_encoder(similar_feat.detach())
        fake_feat = self.image_cnn(fake_imgs.detach())
        fake_feat = self.image_encoder(fake_feat.detach())
        wrong_feat = self.image_cnn(wrong_imgs.detach())
        wrong_feat = self.image_encoder(wrong_feat.detach())

        # Every pair is anchored on the real-image features.
        same_logits = netMD(real_feat, real_feat)
        real_logits2 = netMD(real_feat, similar_feat)
        fake_logits2 = netMD(real_feat, fake_feat.detach())
        wrong_logits2 = netMD(real_feat, wrong_feat)

        cross_entropy = nn.CrossEntropyLoss()
        md_weight = cfg.TRAIN.COEFF.MD_LOSS
        errMD_si = md_weight * cross_entropy(real_logits2, real_labels.long())
        errMD_sa = md_weight * cross_entropy(same_logits, same_labels.long())
        errMD_fa = md_weight * cross_entropy(fake_logits2, fake_labels.long())
        errMD_wr = md_weight * cross_entropy(wrong_logits2,
                                             wrong_labels.long())
        # The identical-pair term is only used for 'birds' and 'flowers'.
        if cfg.DATASET_NAME == 'birds' or cfg.DATASET_NAME == 'flowers':
            errMD = errMD_si + errMD_sa + errMD_fa + errMD_wr
        else:
            errMD = errMD_si + errMD_fa + errMD_wr

        # backward
        errMD.backward()
        optMD.step()
        # log
        if flag == 0:
            summary_MD = summary.scalar('MD_loss', errMD.item())
            self.summary_writer.add_summary(summary_MD, count)
        return errMD

    def train_Dnet(self, idx, count):
        """Run one update of the ``idx``-th image discriminator.

        Real, wrong and fake images of scale ``idx`` are scored conditioned
        on the detached ``self.mu``. When the discriminator returns a second
        output and ``cfg.TRAIN.COEFF.UNCOND_LOSS`` is positive, weighted
        losses on that second output are added as well.

        Args:
            idx: Index of the discriminator / image scale to train.
            count: Global iteration counter; the loss summary is written
                whenever it is a multiple of 100.

        Returns:
            The total loss tensor of this discriminator.
        """
        n_samples = self.real_imgs[0].size(0)
        criterion = self.criterion
        disc, disc_opt = self.netsD[idx], self.optimizersD[idx]
        disc.zero_grad()

        lbl_real = self.real_labels[:n_samples]
        lbl_fake = self.fake_labels[:n_samples]

        # The condition and the fake images are detached so this update
        # cannot reach the generator side.
        cond = self.mu.detach()
        out_real = disc(self.real_imgs[idx], cond)
        out_wrong = disc(self.wrong_imgs[idx], cond)
        out_fake = disc(self.fake_imgs[idx].detach(), cond)

        # First output: wrong and fake images both take the fake labels.
        errD_real = criterion(out_real[0], lbl_real)
        errD_wrong = criterion(out_wrong[0], lbl_fake)
        errD_fake = criterion(out_fake[0], lbl_fake)

        if len(out_real) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Second output: here the wrong images take the real labels.
            extra = [
                cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(out[1], lbl)
                for out, lbl in ((out_real, lbl_real),
                                 (out_wrong, lbl_real),
                                 (out_fake, lbl_fake))
            ]
            errD_real = errD_real + extra[0]
            errD_wrong = errD_wrong + extra[1]
            errD_fake = errD_fake + extra[2]
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        errD.backward()
        disc_opt.step()

        if count % 100 == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one generator update over all scales.

        Sums, per scale, the adversarial loss and the optional
        content-consistency, semantic-consistency and (last scale only)
        netMD losses, then adds the optional colour-consistency losses
        between neighbouring scales and the KL term, back-propagates once
        and steps ``self.optimizerG``.

        Args:
            count: Global iteration counter; scalar summaries are written
                whenever it is a multiple of 100.

        Returns:
            ``(kl_loss, errG_total)``: the weighted KL term and the total
            generator loss.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            # Adversarial loss: the generator wants its fakes scored with
            # the real labels on the first discriminator output ...
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            # ... and, when present and enabled, on the second output too.
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                    outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            # Image features are computed only when one of the losses that
            # needs them is enabled.
            if cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0 or cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0 or cfg.TRAIN.COEFF.MD_LOSS > 0:
                fake_feat = self.image_cnn(self.fake_imgs[i])
                fake_feat = self.image_encoder(fake_feat)
            if cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0 or cfg.TRAIN.COEFF.MD_LOSS > 0:
                real_feat = self.image_cnn(self.real_imgs[i])
                real_feat = self.image_encoder(real_feat)
            # Content consistency: real vs. fake image features.
            if cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(real_feat, fake_feat, self.class_ids)
                errG_CC = loss1 + loss2
                errG_total = errG_total + errG_CC * cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS
            # Semantic consistency: text embedding vs. fake image features.
            if cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0:
                loss1, loss2 = batch_loss(self.txt_embedding, fake_feat,
                                          self.class_ids)
                errG_SC = loss1 + loss2
                errG_total = errG_total + errG_SC * cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS
            # netMD loss, last scale only: the (real, fake) pair should be
            # classified with the real labels.
            if cfg.TRAIN.COEFF.MD_LOSS > 0 and i == (self.num_Ds - 1):
                outputs2 = self.netMD(real_feat, fake_feat)
                errMG = nn.CrossEntropyLoss()(outputs2, real_labels.long())
                errG_total = errG_total + errMG * cfg.TRAIN.COEFF.MD_LOSS
            # Only the adversarial part (errG) of each scale is logged.
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        # Each pair matches the mean/covariance of one scale against the
        # (detached) next-smaller scale; the covariance term is weighted 5x.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, global_step=count)

                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, global_step=count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = compute_mean_covariance(
                    self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * nn.MSELoss()(
                    covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)

                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        # KL term on (mu, logvar), weighted by cfg.TRAIN.COEFF.KL.
        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Train the generator and discriminators up to ``self.max_epoch``.

        Loads (or resumes) the networks via ``load_network``, optionally
        builds the frozen image feature extractors needed by the content /
        semantic consistency and MD losses, then alternates discriminator and
        generator updates over every batch of ``self.data_loader``.

        An exponential moving average of the generator weights
        (``avg_param_G``) is maintained after every generator step and is the
        set of weights used for the periodic image snapshots.  Checkpoints are
        written every ``cfg.TRAIN.SAVE_EPOCH`` epochs and once more after the
        last epoch that actually ran.
        """
        use_md = cfg.TRAIN.COEFF.MD_LOSS > 0
        if use_md:
            (self.netG, self.netsD, self.netMD, self.num_Ds,
             self.inception_model, start_count) = load_network(
                 self.gpus, self.num_batches)
        else:
            (self.netG, self.netsD, self.num_Ds, self.inception_model,
             start_count) = load_network(self.gpus, self.num_batches)

        # Discriminator bundle handed to save_model(); selected once because
        # the network objects are not rebound during training.
        dis_nets = [self.netsD, self.netMD] if use_md else self.netsD

        avg_param_G = copy_G_params(self.netG)

        needs_image_features = (cfg.TRAIN.COEFF.CONTENTCONSIST_LOSS > 0
                                or cfg.TRAIN.COEFF.SEMANTICONSIST_LOSS > 0
                                or use_md)
        if needs_image_features:
            self.image_cnn = Inception_v3()
            self.image_encoder = LINEAR_ENCODER()
            if not isinstance(self.image_cnn, torch.nn.DataParallel):
                self.image_cnn = nn.DataParallel(self.image_cnn)
            if not isinstance(self.image_encoder, torch.nn.DataParallel):
                self.image_encoder = nn.DataParallel(self.image_encoder)
            # Pre-trained encoder weights exist only for these two datasets;
            # any other dataset keeps the encoder's initial weights.
            if cfg.DATASET_NAME == 'birds':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/birds/models/best_image_model.pth"))
            if cfg.DATASET_NAME == 'flowers':
                self.image_encoder.load_state_dict(
                    torch.load(
                        "outputs/pre_train/flowers/models/best_image_model.pth"
                    ))

            if cfg.CUDA:
                self.image_cnn = self.image_cnn.cuda()
                self.image_encoder = self.image_encoder.cuda()
            # The feature extractors are fixed: eval mode, no gradients.
            self.image_cnn.eval()
            self.image_encoder.eval()
            for p in self.image_cnn.parameters():
                p.requires_grad = False
            for p in self.image_encoder.parameters():
                p.requires_grad = False

        self.optimizerG, self.optimizersD = define_optimizers(
            self.netG, self.netsD)
        if use_md:
            self.optimizerMD = optim.Adam(self.netMD.parameters(),
                                          lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                          betas=(0.5, 0.999))

        self.criterion = nn.BCELoss()
        # Constant label vectors; the per-step methods slice them as needed.
        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.same_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))
        self.wrong_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(2))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        # `noise` is re-sampled in place every step; `fixed_noise` is sampled
        # once so that snapshots are comparable across epochs.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = Variable(
            torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.same_labels = self.same_labels.cuda()
            self.wrong_labels = self.wrong_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        count = start_count
        start_epoch = start_count // (self.num_batches)
        # Stays None when resuming at/after max_epoch, i.e. when the loop
        # below never runs; checked before the final save.
        epoch = None
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                (self.imgs_tcpu, self.real_imgs, self.wrong_imgs,
                 self.similar_imgs, self.txt_embedding,
                 self.class_ids) = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = self.netG(
                    noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD
                # Update MD network.
                # NOTE(review): called even when MD_LOSS == 0, in which case
                # self.netMD / self.optimizerMD are never set here -- confirm
                # that train_MDnet guards against that.
                errMD = self.train_MDnet(count)
                errD_total += errMD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator weights.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

            if epoch % cfg.TRAIN.SAVE_EPOCH == 0:
                save_model(self.netG, avg_param_G, dis_nets, epoch,
                           self.model_dir)
            if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0:
                # Save images generated with the averaged weights, then
                # restore the live training weights.
                backup_para = copy_G_params(self.netG)
                load_params(self.netG, avg_param_G)
                # Uses the text embedding of the last batch of the epoch.
                self.fake_imgs, _, _ = self.netG(fixed_noise,
                                                 self.txt_embedding)
                save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                 epoch, self.image_dir, self.summary_writer)
                load_params(self.netG, backup_para)
                # Inception-score / NLPP tracking is currently disabled.

            end_t = time.time()
            # Reports the losses of the last batch of the epoch.
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      ''' %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), kl_loss.item(), end_t - start_t))

        # Final checkpoint; skipped when no epoch ran (nothing new to save,
        # and `epoch` would otherwise be undefined).
        if epoch is not None:
            save_model(self.netG, avg_param_G, dis_nets, epoch,
                       self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir, split_dir,
                         imsize):
        """Save one image grid per sample, combining all its generated images.

        For sample ``i`` the ``i``-th image of every entry in ``images_list``
        is stacked into a single batch and written as a grid (10 per row) to
        ``<save_dir>/super/<split_dir>/<filenames[i]>_<imsize>.png``.  Missing
        output folders are created on demand.

        Args:
            images_list: Sequence of image batches; the number of samples is
                taken from the first entry.
            filenames: Per-sample relative file names (may contain folders).
            save_dir: Root output directory.
            split_dir: Sub-folder name placed under ``super``.
            imsize: Side length each image is viewed as (3 x imsize x imsize).
        """
        sample_count = images_list[0].size(0)
        for idx in range(sample_count):
            prefix = '%s/super/%s/%s' % (save_dir, split_dir, filenames[idx])
            out_dir = prefix[:prefix.rfind('/')]
            if not os.path.isdir(out_dir):
                print('Make a new folder: ', out_dir)
                mkdir_p(out_dir)

            # One tile per entry of images_list, each with a leading batch
            # axis so they can be concatenated into a single grid batch.
            tiles = [
                batch[idx].view(1, 3, imsize, imsize) for batch in images_list
            ]
            grid = torch.cat(tiles, 0)
            vutils.save_image(grid,
                              '%s_%d.png' % (prefix, imsize),
                              nrow=10,
                              normalize=True)

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize):
        """Write every image of a batch to its own PNG file.

        Image ``i`` is saved as
        ``<save_dir>/<filenames[i]>_<imsize>_sentence<sentenceID>.png``;
        missing output folders are created on demand.  ``split_dir`` is
        accepted for signature compatibility but not used.

        Args:
            images: Batch of images indexed along dimension 0.
            filenames: Per-image relative file names (may contain folders).
            save_dir: Root output directory.
            split_dir: Unused.
            sentenceID: Integer embedded in the file name.
            imsize: Integer embedded in the file name.
        """
        image_count = images.size(0)
        for idx in range(image_count):
            prefix = '%s/%s' % (save_dir, filenames[idx])
            out_dir = prefix[:prefix.rfind('/')]
            if not os.path.isdir(out_dir):
                print('Make a new folder: ', out_dir)
                mkdir_p(out_dir)

            # Rescale from [-1, 1] to [0, 255] and convert to bytes.
            scaled = images[idx].add(1).div(2).mul(255)
            pixels = scaled.clamp(0, 255).byte()
            # Move axis 0 last before handing the array to PIL (presumably
            # channels-first -> channels-last; confirm against the generator).
            array = pixels.permute(1, 2, 0).data.cpu().numpy()
            picture = Image.fromarray(array)
            picture.save('%s_%d_sentence%d.png' %
                         (prefix, imsize, sentenceID))

    def evaluate(self, split_dir):
        """Generate images with every saved generator checkpoint in range.

        Scans ``self.model_dir`` for files whose name contains ``netG`` and
        whose trailing ``_<epoch>.<ext>`` number lies in ``[100, 600]``.  Each
        such checkpoint is loaded in ascending epoch order and used to
        generate images for every batch of ``self.data_loader``; results go
        under ``<self.testImage_dir>/epoch<epoch>``.

        Args:
            split_dir: Name of the data split.  ``'test'`` is mapped to
                ``'valid'`` before being passed to the save helpers.

        Raises:
            ValueError: If a matching file name has no integer between its
                last ``_`` and last ``.``.
        """
        if split_dir == 'test':
            split_dir = 'valid'

        # Parse the epoch once, from the bare file name.  Re-parsing it from
        # the joined path would pick up a '_' or '.' in a directory name.
        checkpoints = []
        for net_name in os.listdir(self.model_dir):
            if 'netG' not in net_name:
                continue
            epoch = int(net_name[net_name.rfind('_') + 1:net_name.rfind('.')])
            if 100 <= epoch <= 600:
                checkpoints.append((epoch, net_name))
        # os.listdir order is arbitrary; sort for a reproducible run order.
        checkpoints.sort()

        nz = cfg.GAN.Z_DIM
        for epoch, net_name in checkpoints:
            NET_G_path = os.path.join(self.model_dir, net_name)
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # Load onto CPU first; the model is moved to GPU below if needed.
            state_dict = torch.load(NET_G_path,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', NET_G_path)

            # the path to save generated images
            save_dir = '%s/epoch%d' % (self.testImage_dir, epoch)

            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                t_embeddings = Variable(t_embeddings)
                if cfg.CUDA:
                    t_embeddings = t_embeddings.cuda()

                # Dimension 1 of t_embeddings is iterated below, one
                # generator pass per index along it.
                num_embeddings = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                # The batch may be smaller than self.batch_size.
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                # Inference only: do not build autograd graphs.
                with torch.no_grad():
                    for i in range(num_embeddings):
                        fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                        if cfg.TEST.B_EXAMPLE:
                            # Index 2 is hard-coded here (the single-image
                            # branch uses the last output instead).
                            fake_img_list.append(fake_imgs[2].data.cpu())
                        else:
                            self.save_singleimages(fake_imgs[-1], filenames,
                                                   save_dir, split_dir, i,
                                                   256)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_img_list, filenames, save_dir,
                                          split_dir, 256)
Example #20
0
class GANTrainer(object):
    """Trainer for an unconditional GAN with several discriminators.

    The generator is called with noise only and returns a sequence of image
    batches, one per discriminator (``self.num_Ds`` of them).  The class
    covers training (``train``) and sampling from a saved generator
    (``evaluate``).
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Set up output folders, GPU list and basic training settings.

        Args:
            output_dir: Root directory; ``Model``, ``Image`` and ``Log``
                sub-folders are created when ``cfg.TRAIN.FLAG`` is set.
            data_loader: Iterable of training batches; must support ``len``.
            imsize: Unused; kept for signature compatibility.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # NOTE: the default CUDA device is deliberately not set here
        # (torch.cuda.set_device was disabled in the original code).
        cudnn.benchmark = True

        # cfg.TRAIN.BATCH_SIZE is per GPU.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap the first ``self.num_Ds`` entries of a batch as Variables.

        Args:
            data: Indexable batch; entry ``i`` is the real-image batch for
                discriminator ``i``.

        Returns:
            ``(imgs, vimgs)``: the untouched input and the list of wrapped
            (and, when ``cfg.CUDA`` is set, GPU-resident) image batches.
        """
        imgs = data

        vimgs = []
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))

        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Args:
            idx: Index of the discriminator / image batch to train on.
            count: Global step counter; losses are logged every 100 steps.

        Returns:
            The discriminator loss (real + fake BCE terms).
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        # Labels are pre-allocated at full batch size; slice for short
        # batches.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        netD.zero_grad()

        real_logits = netD(real_imgs)
        # detach(): the generator must not receive gradients from this step.
        fake_logits = netD(fake_imgs.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        errD = errD_real + errD_fake
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one optimisation step of the generator.

        The loss is the sum of the adversarial terms of all discriminators
        plus, when ``cfg.TRAIN.COEFF.COLOR_LOSS > 0``, mean/covariance
        matching terms between consecutive generator outputs.

        Args:
            count: Global step counter; losses are logged every 100 steps.

        Returns:
            The total generator loss.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_G, count)

        # Compute color preserve losses.  Logging lives inside each guard so
        # that a term is only logged when it was actually computed.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Train the generator and discriminators up to ``self.max_epoch``.

        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps the model is
        checkpointed, sample images are rendered with the averaged generator
        weights and, once more than 500 prediction batches have been
        collected, inception / NLPP statistics are logged.  A final
        checkpoint is written when training ends.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        # `noise` is re-sampled in place every step; `fixed_noise` is sampled
        # once so that snapshots are comparable over time.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator weights.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f''' %
                          (epoch, self.max_epoch, step, self.num_batches,
                           errD_total.item(), errG_total.item()))
                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images generated with the averaged weights, then
                    # restore the live training weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)

                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)

                        # Start a fresh window of predictions.
                        predictions = []

            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save a whole batch as one image grid.

        The file is written to ``<folder>/<startID>_<imsize>.png``.
        """
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save every image of a batch to its own PNG file.

        Image ``i`` is written to ``<folder>/<startID + i>_<imsize>.png``.
        """
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # Rescale from [-1, 1] to [0, 255] and convert to bytes.
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Sample images from the generator stored at ``cfg.TRAIN.NET_G``.

        Generates ``cfg.TEST.SAMPLE_NUM // self.batch_size`` batches and
        writes them under ``<checkpoint dir>/iteration<N>/<split_dir>/`` in a
        ``super`` (grids) or ``single`` (one file per image) sub-folder,
        depending on ``cfg.TEST.B_EXAMPLE``.  ``N`` is parsed from the
        checkpoint file name (the digits between its last ``_`` and ``.``).

        Args:
            split_dir: Sub-folder name used in the output path.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        # Build and load the generator
        netG = G_NET()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)
        # Load onto CPU first; the model is moved to GPU below if needed.
        state_dict = \
            torch.load(cfg.TRAIN.NET_G,
                       map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # the path to save generated images
        s_tmp = cfg.TRAIN.NET_G
        istart = s_tmp.rfind('_') + 1
        iend = s_tmp.rfind('.')
        iteration = int(s_tmp[istart:iend])
        s_tmp = s_tmp[:s_tmp.rfind('/')]
        save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
        if cfg.TEST.B_EXAMPLE:
            folder = '%s/super' % (save_dir)
        else:
            folder = '%s/single' % (save_dir)
        print('Make a new folder: ', folder)
        mkdir_p(folder)

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        if cfg.CUDA:
            netG.cuda()
            noise = noise.cuda()

        # switch to evaluate mode
        netG.eval()
        num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
        cnt = 0
        # `range` works on both Python 2 and 3 (`xrange` is Python 2 only).
        for step in range(num_batches):
            noise.data.normal_(0, 1)
            fake_imgs, _, _ = netG(noise)
            if cfg.TEST.B_EXAMPLE:
                self.save_superimages(fake_imgs[-1], folder, cnt, 256)
            else:
                self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
            cnt += self.batch_size
Example #21
0
else:
    base_path = args.base_path + args.mode + '/' + args.output_path
    model_path = base_path + '/model'
    log_path = base_path + '/log'

# Create the checkpoint and log directories if they do not exist yet.
if not os.path.exists(model_path):
    os.makedirs(model_path)
if not os.path.exists(log_path):
    os.makedirs(log_path)
# exit(0)

# Record the parsed command-line arguments next to the outputs, one
# "name : value" pair per line.
with open(base_path + '/args.txt', 'w') as output_file:
    for x, y in vars(args).items():
        output_file.write("{} : {}\n".format(x, y))

summary_writer = FileWriter(log_path)

# Training data: an HDF5 file, opened read-only, with a 'feature' dataset and
# a 'pre_label' dataset that must have the same number of rows.
train_feature_all = h5py.File(args.train_path, 'r')
N_train_gt = train_feature_all['pre_label'].shape[0]
train_feature_use = train_feature_all['feature']
train_label_use = train_feature_all['pre_label']
N_train = N_train_gt
assert train_feature_use.shape[0] == train_label_use.shape[0]

# Optional test data; N_val is only defined when a test path was given.
if args.test_path is not None:
    test_feature_all = h5py.File(args.test_path, 'r')
    N_val = test_feature_all['pre_label'].shape[0]


def get_learning_rate(data_num, batch):
    learning_rate = tf.train.exponential_decay(
Example #22
0
class Trainer(object):
    def __init__(self, output_dir):
        """Set up output folders, devices, embeddings, networks and plots.

        Args:
            output_dir: root directory for this run. When ``cfg.TRAIN.FLAG``
                is set, the ``Model``, ``Image`` and ``Log`` sub-directories
                are created under it and a ``FileWriter`` is opened on the
                log directory. It is also passed to the ``Visualizer``.
        """
        if cfg.TRAIN.FLAG:
            layout = (('model_dir', 'Model'),
                      ('image_dir', 'Image'),
                      ('log_dir', 'Log'))
            for attr_name, sub_dir in layout:
                setattr(self, attr_name, os.path.join(output_dir, sub_dir))
            for created_dir in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(created_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # cfg.GPU_ID is a comma-separated list of device ids; the effective
        # batch size scales with the number of listed devices.
        self.gpus = [int(gpu_id) for gpu_id in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # Load the external word vectors (<DATASET_NAME>.en.vec under
        # DATA_DIR) into a frozen 300-wide embedding table.
        vec_path = os.path.join(cfg.DATA_DIR, cfg.DATASET_NAME + ".en.vec")
        dictionary, pretrained = load_external_embeddings(vec_path)
        embedding = nn.Embedding(len(dictionary), 300, sparse=False)
        embedding.weight.data.copy_(pretrained)
        embedding.weight.requires_grad = False
        self.txt_dico = dictionary
        self.txt_emb = embedding

        # Networks need self.txt_emb / self.txt_dico, so build them only now.
        self.networks = self.load_network()
        self.evaluator = Evaluator(self.networks, self.txt_emb)

        # Visdom windows: two image panes, two text panes, then the plots.
        self.vis = Visualizer(cfg.VISDOM_HOST, cfg.VISDOM_PORT, output_dir)
        for image_window in ("real_im", "fake_im"):
            self.vis.make_img_window(image_window)
        for text_window in ("real_captions", "genr_captions"):
            self.vis.make_txt_window(text_window)

        plot_legends = (
            ("G_loss", ["errG", "uncond", "cond", "latent",
                        "cycltxt", "autoimg", "autotxt"]),
            ("D_loss", ["errD", "uncond", "cond", "latent"]),
            ("KL_loss", ["kl", "img", "txt", "fakeimg"]),
            ("inception_score", ["real", "fake"]),
        )
        for plot_name, legend in plot_legends:
            self.vis.make_plot_window(plot_name, num=len(legend),
                                      legend=legend)
        self.vis.make_plot_window("r_precision", num=1)
              
    #
    # convert a text sentence into indices
    #
    def ind_from_sent(self, caption):
        s = [self.txt_dico.SOS_TOKEN]
        s += [self.txt_dico.word2id[word] \
            if word in self.txt_dico.word2id \
            else self.txt_dico.UNK_TOKEN for word in caption.split(" ")]
        s += [self.txt_dico.EOS_TOKEN]
        return s
    
    #
    # pad a sequence with PAD TOKENs
    #
    def pad_seq(self, seq, max_length):
        seq += [self.txt_dico.PAD_TOKEN for i in range(max_length - len(seq))]
        return seq 
    
    #
    # convert a list of sentences into a padded tensor w lengths
    #
    def process_captions(self, captions):
        """Encode a batch of caption strings as a padded index tensor.

        Each caption is converted with ``ind_from_sent`` and padded with
        ``pad_seq`` to the longest sequence in the batch.

        Returns:
            tuple: ``(input_var, lengths)``. ``input_var`` is the padded
            ``LongTensor`` wrapped in a ``Variable`` with its first two
            dimensions swapped. ``lengths`` is a ``LongTensor`` of the
            unpadded sequence lengths. Both are moved to the GPU when
            ``cfg.CUDA`` is set.
        """
        encoded = [self.ind_from_sent(sentence) for sentence in captions]

        # Lengths must be captured before padding mutates the sequences.
        input_lengths = [len(sequence) for sequence in encoded]
        longest = max(input_lengths) if input_lengths else 0
        padded = [self.pad_seq(sequence, longest) for sequence in encoded]

        token_tensor = torch.LongTensor(padded)
        input_var = Variable(token_tensor).transpose(0, 1)
        lengths = torch.LongTensor(input_lengths)
        if not cfg.CUDA:
            return input_var, lengths
        return input_var.cuda(), lengths.cuda()

    #
    # load model components
    #
    def load_network(self):
        """Build the six sub-networks and optionally restore saved weights.

        Reads ``self.txt_emb`` and ``self.txt_dico``, so both must be set
        before this is called.

        A checkpoint is loaded for a component when the matching config entry
        is a non-empty path: ``cfg.NET_G`` (image generator), ``cfg.NET_D``
        (image discriminator), ``cfg.ENCODER`` (text encoder), ``cfg.DECODER``
        (text generator) and ``cfg.IMAGE_ENCODER`` (image encoder). The
        latent discriminator has no checkpoint option.

        Returns:
            tuple: ``(image_encoder, image_generator, text_encoder,
            text_generator, disc_image, disc_latent)``, moved to the GPU when
            ``cfg.CUDA`` is set.
        """
        image_generator = ImageGenerator()
        image_generator.apply(weights_init)

        disc_image = DiscriminatorImage()
        disc_image.apply(weights_init)

        # Same width as the 300-d embedding table built in __init__.
        emb_dim = 300
        text_encoder = TextEncoder(emb_dim, self.txt_emb,
                                   1, dropout=0.0)

        attn_model = 'general'
        text_generator = TextGenerator(attn_model, emb_dim,
                                       len(self.txt_dico.id2word),
                                       self.txt_emb,
                                       n_layers=1, dropout=0.0)

        image_encoder = ImageEncoder()
        image_encoder.apply(weights_init)

        disc_latent = DiscriminatorLatent(emb_dim)

        # The previous code loaded into `netG`, `netD`, `encoder` and
        # `decoder`, names that do not exist in this method, so any configured
        # checkpoint raised NameError. Load into the modules built above.
        # NOTE(review): the config-key -> module pairing is inferred from the
        # key names; confirm it matches how the checkpoints were saved.
        checkpoints = (
            (cfg.NET_G, image_generator),
            (cfg.NET_D, disc_image),
            (cfg.ENCODER, text_encoder),
            (cfg.DECODER, text_generator),
            (cfg.IMAGE_ENCODER, image_encoder),
        )
        for checkpoint_path, module in checkpoints:
            if checkpoint_path != '':
                # Map every storage to CPU first; .cuda() below moves the
                # module to the GPU if requested.
                state_dict = torch.load(
                    checkpoint_path,
                    map_location=lambda storage, loc: storage)
                module.load_state_dict(state_dict)
                print('Load from: ', checkpoint_path)

        if cfg.CUDA:
            image_encoder.cuda()
            image_generator.cuda()
            text_encoder.cuda()
            text_generator.cuda()
            disc_image.cuda()
            disc_latent.cuda()

        return image_encoder, image_generator, text_encoder, text_generator, disc_image, disc_latent
    
    def define_optimizers(self, 
                          image_encoder, image_generator, 
                          text_encoder, text_generator, 
                          disc_image, disc_latent):
        """Create one optimizer per sub-network.

        * image discriminator / latent discriminator: Adam with
          ``cfg.TRAIN.DISCRIMINATOR_LR`` and betas ``(0.5, 0.999)``, over all
          parameters.
        * image generator: Adam with ``cfg.TRAIN.GENERATOR_LR`` and the same
          betas, over parameters with ``requires_grad`` only.
        * text encoder / text generator: built by
          ``get_optimizer("adam,lr=0.001")``, over parameters with
          ``requires_grad`` only.
        * image encoder: plain SGD with ``cfg.TRAIN.DISCRIMINATOR_LR``, over
          all parameters.

        Returns:
            tuple: ``(optim_img_enc, optim_img_gen, optim_txt_enc,
            optim_txt_gen, optim_disc_img, optim_disc_latent)``.
        """
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        adam_betas = (0.5, 0.999)

        def trainable(module):
            # Frozen weights (e.g. a fixed embedding table) are left out.
            return [p for p in module.parameters() if p.requires_grad]

        def text_optimizer(module):
            optim_fn, optim_params = get_optimizer("adam,lr=0.001")
            return optim_fn(trainable(module), **optim_params)

        optim_disc_img = optim.Adam(disc_image.parameters(),
                                    lr=discriminator_lr, betas=adam_betas)
        optim_img_gen = optim.Adam(trainable(image_generator),
                                   lr=generator_lr, betas=adam_betas)

        optim_txt_enc = text_optimizer(text_encoder)
        optim_txt_gen = text_optimizer(text_generator)

        # NOTE(review): the image encoder uses SGD at the discriminator
        # learning rate; kept as in the original, confirm it is intentional.
        optim_img_enc = optim.SGD(image_encoder.parameters(),
                                  lr=discriminator_lr)

        optim_disc_latent = optim.Adam(disc_latent.parameters(),
                                       lr=discriminator_lr, betas=adam_betas)

        return optim_img_enc, optim_img_gen, \
                optim_txt_enc, optim_txt_gen, \
                optim_disc_img, optim_disc_latent
                        
    #
    # train with both autoencoding and cross-domain losses
    #
    def train(self, data_loader, dataset, stage=1):
        """Run the joint image/text adversarial training loop.

        For ``self.max_epoch`` epochs, every batch from ``data_loader`` goes
        through: caption encoding, image generation from the text code,
        text and image auto-encoding losses, a text cycle loss, one
        discriminator update (image + latent), and one update of the image
        encoder, image generator, text encoder and text generator. Losses and
        samples are pushed to ``self.vis`` every 100 batches. After each epoch
        inception scores are plotted, a summary line is printed, and a
        snapshot is saved every ``self.snapshot_interval`` epochs. A final
        snapshot is saved after the last epoch and ``self.summary_writer`` is
        closed.

        Args:
            data_loader: iterable of 5-tuples; the 2nd, 4th and 5th elements
                are used as the image batch, the caption strings and the
                (possibly empty) predicted captions.
            dataset: only referenced by the commented-out caption-saving
                block; unused by the active code.
            stage: unused in this method.

        Returns:
            None.
        """
        
        image_encoder, image_generator, text_encoder, text_generator, disc_image, disc_latent = self.networks
                           
        nz = cfg.Z_DIM
        batch_size = self.batch_size
        # `noise` is a reusable buffer that is re-sampled in place before
        # every use below.
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # NOTE(review): fixed_noise is only moved to the GPU below and never
        # read afterwards in this method.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
            
        #
        # make labels for real/fake
        #
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))  # try discriminator smoothing
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        
        # Labels for the latent discriminator: 0 for text-encoder outputs,
        # 1 for image-encoder outputs.
        txt_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 
        img_enc_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 
        
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
            txt_enc_labels = txt_enc_labels.cuda()
            img_enc_labels = img_enc_labels.cuda()                

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        
        optims = self.define_optimizers(image_encoder, image_generator, 
                                   text_encoder, text_generator, 
                                   disc_image, disc_latent)
        optim_img_enc, optim_img_gen, optim_txt_enc, optim_txt_gen, optim_disc_img, optim_disc_latent = optims
        
        # Global batch counter, used as the x-coordinate of the loss plots.
        count = 0
                
        for epoch in range(self.max_epoch):
            
            start_t = time.time()
            
            # Every lr_decay_step epochs (except epoch 0) multiply both
            # learning rates by 0.75. Only the image generator and image
            # discriminator optimizers are updated; the others keep their
            # initial rates.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.75
                for param_group in optim_img_gen.param_groups:
                    param_group['lr'] = generator_lr
                    
                discriminator_lr *= 0.75
                for param_group in optim_disc_img.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                _, real_img_cpu, _, captions, pred_cap = data

                raw_inds, raw_lengths = self.process_captions(captions)
                
                inds, lengths = raw_inds.data, raw_lengths
                
                inds = Variable(inds)
                # Reorder the batch by descending caption length; the same
                # sort_idx is applied to the images further down.
                lens_sort, sort_idx = lengths.sort(0, descending=True)
                                
                # need to dataparallel the encoders?
                txt_encoder_output = text_encoder(inds[:, sort_idx], lens_sort.cpu().numpy(), None)
                encoder_out, encoder_hidden, real_txt_code, real_txt_mu, real_txt_logvar = txt_encoder_output
                
                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()

                #######################################################
                # (2) Generate fake images and their latent codes
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (real_txt_code, noise)
                fake_imgs = \
                    nn.parallel.data_parallel(image_generator, inputs, self.gpus)
                                        
                fake_img_out = nn.parallel.data_parallel(
                    image_encoder, (fake_imgs), self.gpus
                )
            
                fake_img_feats, fake_img_emb, fake_img_code, fake_img_mu, fake_img_logvar = fake_img_out
                fake_img_feats = fake_img_feats.transpose(0,1)                    
                    
                #######################################################
                # (2b) Calculate auto encoding loss for text
                ######################################################           
                loss_auto_txt, _ = compute_text_gen_loss(text_generator, 
                                                      inds[:,sort_idx],
                                                      real_txt_code.unsqueeze(0), 
                                                      encoder_out, 
                                                      self.txt_dico)
                # Normalise by the total number of caption tokens in the batch.
                loss_auto_txt = loss_auto_txt / lengths.float().sum() 

                #######################################################
                # (2c) Decode z from real imgs and calc auto-encoding loss
                ######################################################                    
                
                real_img_out = nn.parallel.data_parallel(
                    image_encoder, (real_imgs[sort_idx]), self.gpus
                )
                
                real_img_feats, real_img_emb, real_img_code, real_img_mu, real_img_logvar = real_img_out

                noise.data.normal_(0, 1)
                loss_auto_img, _ = compute_image_gen_loss(image_generator, 
                                                       real_imgs[sort_idx],
                                                       real_img_code,
                                                       noise,
                                                       self.gpus)
                
                #######################################################
                # (2c) Decode z from fake imgs and calc cycle loss
                ######################################################                    
                
                loss_cycle_text, gen_captions = compute_text_gen_loss(text_generator, 
                                                        inds[:,sort_idx], 
                                                        fake_img_code.unsqueeze(0), 
                                                        fake_img_feats, 
                                                        self.txt_dico)

                loss_cycle_text = loss_cycle_text / lengths.float().sum()
                
                ###############################################################
                # (2d) Generate image from predicted cap, calc img cycle loss
                ###############################################################
                
                # NOTE(review): loss_cycle_img is computed here but is never
                # added to errG or errD below.
                loss_cycle_img = 0
                # NOTE(review): this branch calls `encoder` and `netG`, which
                # are not defined anywhere in this method; unless they exist
                # as module-level globals it raises NameError whenever
                # pred_cap is non-empty. It also unpacks four outputs from
                # netG, whereas image_generator above yields a single value.
                if (len(pred_cap)):
                    pred_inds, pred_lens = pred_cap
                    pred_inds = Variable(pred_inds.transpose(0,1))
                    pred_inds = pred_inds.cuda() if cfg.CUDA else pred_inds

                    pred_output = encoder(pred_inds[:, sort_idx], pred_lens.cpu().numpy(), None)
                    pred_txt_out, pred_txt_hidden, pred_txt_code, pred_txt_mu, pred_txt_logvar = pred_output
                  
                    noise.data.normal_(0, 1)
                    inputs = (pred_txt_code, noise)
                    _, fake_from_fake_img, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                  
                    pred_img_out = nn.parallel.data_parallel(
                        image_encoder, (fake_from_fake_img), self.gpus
                    )                    
                  
                    pred_img_feats, pred_img_emb, pred_img_code, pred_img_mu, pred_img_logvar = pred_img_out
                  
                    semantic_target = Variable(torch.ones(batch_size))
                    if cfg.CUDA:
                        semantic_target = semantic_target.cuda()
                                              
                    loss_cycle_img = cosine_emb_loss(
                        pred_img_feats.contiguous().view(batch_size, -1), real_img_feats.contiguous().view(batch_size, -1), semantic_target
                    )
                
                ###########################
                # (3) Update D network
                ###########################
                optim_disc_img.zero_grad()
                optim_disc_latent.zero_grad()
                
                errD = 0
                
                errD_fake_imgs = compute_cond_discriminator_loss(disc_image, fake_imgs, 
                                                   fake_labels, encoder_hidden[0], self.gpus)               
                
                # errD_real and errD_fake are returned but not used below.
                errD_im, errD_real, errD_fake = \
                    compute_uncond_discriminator_loss(disc_image, real_imgs, fake_imgs,
                                                      real_labels, fake_labels,
                                                      self.gpus)
                    
                err_latent_disc = compute_latent_discriminator_loss(disc_latent, 
                                                                    real_img_emb, encoder_hidden[0],
                                                                    img_enc_labels, txt_enc_labels,
                                                                    self.gpus)
                
                # NOTE(review): `compute_cond_disc` and `netD` are not defined
                # in this method either; same NameError risk as above.
                if (len(pred_cap)):
                    errD_fake_from_fake_imgs = compute_cond_disc(netD, fake_from_fake_img, 
                                                                 fake_labels, pred_txt_hidden[0], self.gpus)
                    errD += errD_fake_from_fake_imgs                 
                
                errD = errD + errD_im + errD_fake_imgs + err_latent_disc
                
                # check NaN
                # (x != x) is true only for NaN entries.
                if (errD != errD).data.any():
                    print("NaN detected (discriminator)")
                    pdb.set_trace()
                    exit()
                    
                errD.backward()
                                
                optim_disc_img.step()
                optim_disc_latent.step()
                
                ############################
                # (4) Update G network
                ###########################
                optim_img_enc.zero_grad()
                optim_img_gen.zero_grad()
                optim_txt_enc.zero_grad()
                optim_txt_gen.zero_grad()
                
                errG_total = 0
                
                err_g_uncond_loss = compute_uncond_generator_loss(disc_image, fake_imgs,
                                              real_labels, self.gpus)
                
                err_g_cond_disc_loss = compute_cond_generator_loss(disc_image, fake_imgs, 
                                                                   real_labels, encoder_hidden[0], self.gpus)
                                    
                err_latent_gen = compute_latent_generator_loss(disc_latent, 
                                                               real_img_emb, encoder_hidden[0],
                                                               img_enc_labels, txt_enc_labels,
                                                               self.gpus)
                
                errG = err_g_uncond_loss + err_g_cond_disc_loss + err_latent_gen + \
                        loss_cycle_text + \
                        loss_auto_img + \
                        loss_auto_txt
                
                # NOTE(review): same undefined `compute_cond_disc` / `netD`
                # as in the discriminator step.
                if (len(pred_cap)):
                    errG_fake_from_fake_imgs = compute_cond_disc(netD, fake_from_fake_img, 
                                                                 real_labels, pred_txt_hidden[0], self.gpus)
                    errG += errG_fake_from_fake_imgs                
                
                img_kl_loss = KL_loss(real_img_mu, real_img_logvar)
                txt_kl_loss = KL_loss(real_txt_mu, real_txt_logvar)
                f_img_kl_loss = KL_loss(fake_img_mu, fake_img_logvar)

                kl_loss = img_kl_loss + txt_kl_loss + f_img_kl_loss
                           
                # Total generator-side objective: adversarial + reconstruction
                # terms plus the KL terms weighted by cfg.TRAIN.COEFF.KL.
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                
                # check NaN
                if (errG_total != errG_total).data.any():
                    print("NaN detected (generator)")
                    pdb.set_trace()
                    exit()
                
                errG_total.backward()
                
                optim_img_enc.step()
                optim_img_gen.step()
                optim_txt_enc.step()
                optim_txt_gen.step()               
                
                count = count + 1
                # Every 100 batches (including the first of each epoch) push
                # losses, images and captions to the visualizer.
                if i % 100 == 0:
                    self.vis.add_to_plot("D_loss", np.asarray([[
                                                    errD.data[0],
                                                    errD_im.data[0],
                                                    errD_fake_imgs.data[0],
                                                    err_latent_disc.data[0]
                                                    ]]), 
                                                    np.asarray([[count] * 4]))
                    self.vis.add_to_plot("G_loss", np.asarray([[
                                                    errG.data[0], 
                                                    err_g_uncond_loss.data[0],
                                                    err_g_cond_disc_loss.data[0],
                                                    err_latent_gen.data[0],
                                                    loss_cycle_text.data[0],
                                                    loss_auto_img.data[0],
                                                    loss_auto_txt.data[0]
                                                    ]]),
                                                    np.asarray([[count] * 7]))
                    self.vis.add_to_plot("KL_loss", np.asarray([[
                                                    kl_loss.data[0],
                                                    img_kl_loss.data[0],
                                                    txt_kl_loss.data[0],
                                                    f_img_kl_loss.data[0]
                                                    ]]), 
                                         np.asarray([[count] * 4]))
                
                    self.vis.show_images("real_im", real_imgs[sort_idx].data.cpu().numpy())
                    self.vis.show_images("fake_im", fake_imgs.data.cpu().numpy())
                    
                    # Captions in the same (length-sorted) order as the images.
                    sorted_captions = [captions[i] for i in sort_idx.cpu().tolist()]
                    gen_cap_text = []
                    # Convert generated index sequences back to text, stopping
                    # at EOS and skipping SOS.
                    # NOTE(review): the inner loop variable `i` shadows the
                    # batch index `i` of the enclosing enumerate loop; it is
                    # not read again before the next iteration rebinds it.
                    for d_i, d in enumerate(gen_captions):
                        s = u""
                        for i in d:
                            if i == self.txt_dico.EOS_TOKEN:
                                break
                            if i != self.txt_dico.SOS_TOKEN:
                                s += self.txt_dico.id2word[i] + u" "
                        gen_cap_text.append(s)
                        
                    self.vis.show_text("real_captions", sorted_captions)
                    self.vis.show_text("genr_captions", gen_cap_text)
                    
                    # r_precision is only (re)assigned inside this branch; the
                    # end-of-epoch summary reuses the latest value.
                    r_precision = self.evaluator.r_precision_score(fake_img_code, real_txt_code)
                    self.vis.add_to_plot("r_precision", np.asarray([r_precision.data[0]]), np.asarray([count]))
                                                        
                        
#             # save pred caps for next iteration
#             for i, data in enumerate(data_loader, 0):
#                 keys, real_img_cpu, _, _, _ = data
#                 real_imgs = Variable(real_img_cpu)
#                 if cfg.CUDA:
#                     real_imgs = real_imgs.cuda()                
                
#                 cap_img_out = nn.parallel.data_parallel(
#                     image_encoder, (real_imgs[sort_idx]), self.gpus
#                 )
                
#                 cap_img_feats, cap_img_emb, cap_img_code, cap_img_mu, cap_img_logvar = cap_img_out
#                 cap_img_feats = cap_img_feats.transpose(0,1)
                                                
#                 cap_features = cap_img_code.unsqueeze(0)
                
#                 cap_dec_inp = Variable(torch.LongTensor([self.txt_dico.SOS_TOKEN] * self.batch_size))
#                 cap_dec_inp = cap_dec_inp.cuda() if cfg.CUDA else cap_dec_inp

#                 cap_dec_hidden = cap_features.detach()

#                 seq = torch.LongTensor([])
#                 seq = seq.cuda() if cfg.CUDA else seq

#                 max_target_length = 20
                
#                 lengths = torch.LongTensor(batch_size).fill_(20)

#                 for t in range(max_target_length):

#                     cap_dec_out, cap_dec_hidden, cap_dec_attn = decoder(
#                         cap_dec_inp, cap_dec_hidden, cap_img_feats
#                     )

#                     topv, topi = cap_dec_out.topk(1, dim=1)

#                     cap_dec_inp = topi #.squeeze(dim=2)
#                     cap_dec_inp = cap_dec_inp.cuda() if cfg.CUDA else cap_dec_inp

#                     seq = torch.cat((seq, cap_dec_inp.data), dim=1)

#                 dataset.save_captions(keys, seq.cpu(), lengths.cpu())

            # End of epoch: score the real and generated images of the last
            # batch only (real_imgs / fake_imgs still hold that batch).
            iscore_mu_real, _ = self.evaluator.inception_score(real_imgs[sort_idx])
            iscore_mu_fake, _ = self.evaluator.inception_score(fake_imgs)
            self.vis.add_to_plot("inception_score", np.asarray([[
                        iscore_mu_real,
                        iscore_mu_fake
                    ]]),
                    np.asarray([[epoch] * 2]))    
            
            end_t = time.time()
            
            # The summary line reports the losses of the last batch.
            prefix = "Epoch %d; %s, %.1f sec" % (epoch, time.strftime('D%d %X'), (end_t-start_t))
            gen_str = "G_total: %.3f Gen loss: %.3f KL loss %.3f" % (
                                                                         errG_total.data[0],
                                                                         errG.data[0],
                                                                         kl_loss.data[0]
                                                                        )
            
            dis_str = "Img Disc: %.3f Latent Disc: %.3f" % (
                errD.data[0], 
                err_latent_disc.data[0]
            )
            
            # NOTE(review): r_precision is formatted directly here, while the
            # plot above reads r_precision.data[0]; confirm it formats as a
            # float.
            eval_str = "Incep real: %.3f Incep fake: %.3f R prec %.3f" % (
                iscore_mu_real, 
                iscore_mu_fake,
                r_precision
            )
                
            print("%s %s, %s; %s" % (prefix, gen_str, dis_str, eval_str))
            
            if epoch % self.snapshot_interval == 0:
                save_model(image_encoder, image_generator, 
                           text_encoder, text_generator, 
                           disc_image, disc_latent,
                           epoch, self.model_dir)

        # Final snapshot, tagged with the index of the last completed epoch.
        # `epoch` is unbound here if self.max_epoch is 0.
        save_model(image_encoder, image_generator, 
                   text_encoder, text_generator, 
                   disc_image, disc_latent, 
                   epoch, self.model_dir)
        
        # summary_writer is only created in __init__ when cfg.TRAIN.FLAG is set.
        self.summary_writer.close()
    
    def sample(self, data_loader, stage=1):
        print("todo")
Example #23
0
class condGANTrainer(object):
    def __init__(self, output_dir, data_loader, imsize):
        """Set up output folders, GPU bookkeeping and training hyper-parameters.

        Args:
            output_dir: root folder; in training mode the sub-folders
                ``Model``, ``Image`` and ``Log`` are created beneath it.
            data_loader: sized iterable of batches; its length is stored as
                the number of batches per epoch.
            imsize: accepted for interface compatibility; not used here.

        Note that ``model_dir``, ``image_dir``, ``log_dir`` and
        ``summary_writer`` only exist when ``cfg.TRAIN.FLAG`` is truthy.
        """
        train_cfg = cfg.TRAIN
        if train_cfg.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, sub)
                for sub in ('Model', 'Image', 'Log')
            )
            for folder in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated string of device indices.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        # NOTE: explicit CUDA device selection (torch.cuda.set_device) is
        # deliberately not performed here.
        cudnn.benchmark = True

        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one loader batch in Variables (moved to GPU if cfg.CUDA).

        Args:
            data: 4-tuple ``(imgs, w_imgs, t_embedding, _)``; ``imgs`` and
                ``w_imgs`` are indexed ``0 .. self.num_Ds - 1``, and the
                fourth element is ignored.

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, vembedding)`` where ``imgs`` is
            the untouched input and the other three are the wrapped versions.
        """
        imgs, w_imgs, t_embedding, _ = data
        use_cuda = cfg.CUDA

        def wrap(tensor):
            # Same conversion for every tensor: Variable, then optional .cuda()
            var = Variable(tensor)
            return var.cuda() if use_cuda else var

        vembedding = wrap(t_embedding)

        real_vimgs, wrong_vimgs = [], []
        for level in range(self.num_Ds):
            real_vimgs.append(wrap(imgs[level]))
            wrong_vimgs.append(wrap(w_imgs[level]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # for real
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one optimisation step for the generator.

        The total loss is the sum of:
          * one adversarial term per discriminator (fake images scored
            against the real labels, plus the weighted second-output term
            when present and enabled),
          * optional colour-consistency terms between consecutive generator
            outputs (mean and covariance MSE), and
          * the KL term scaled by ``cfg.TRAIN.COEFF.KL``.

        Args:
            count: global step; scalars are written to the summary writer
                when ``count % 100 == 0``.

        Returns:
            ``(kl_loss, errG_total)``.
        """
        self.netG.zero_grad()
        errG_total = 0
        log_now = (count % 100) == 0
        n_samples = self.real_imgs[0].size(0)
        bce, mu, logvar = self.criterion, self.mu, self.logvar
        ones = self.real_labels[:n_samples]

        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = bce(outputs[0], ones)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG = errG + cfg.TRAIN.COEFF.UNCOND_LOSS * bce(outputs[1], ones)
            errG_total = errG_total + errG
            if log_now:
                summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_D, count)

        # Colour-consistency losses between neighbouring generator outputs.
        # Each row: (num_Ds must exceed this, index receiving gradients,
        #            detached reference index, mean tag, covariance tag).
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            pairs = (
                (1, -1, -2, 'G_like_mu2', 'G_like_cov2'),
                (2, -2, -3, 'G_like_mu1', 'G_like_cov1'),
            )
            for min_ds, cur, ref, tag_mu, tag_cov in pairs:
                if not self.num_Ds > min_ds:
                    continue
                mu_cur, cov_cur = compute_mean_covariance(self.fake_imgs[cur])
                mu_ref, cov_ref = compute_mean_covariance(
                    self.fake_imgs[ref].detach())
                like_mu = cfg.TRAIN.COEFF.COLOR_LOSS * \
                    nn.MSELoss()(mu_cur, mu_ref)
                # The covariance term carries an extra factor of 5.
                like_cov = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(cov_cur, cov_ref)
                errG_total = errG_total + like_mu + like_cov
                if log_now:
                    sum_mu = summary.scalar(tag_mu, like_mu.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar(tag_cov, like_cov.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Run the full training loop and save the final model.

        Loads the networks via ``load_network``, builds optimisers, label
        buffers and noise tensors, then for every epoch and batch:
        generates fake images, updates each discriminator, updates the
        generator, refreshes the running average of the generator
        parameters and collects inception-model predictions. Every
        ``cfg.TRAIN.SNAPSHOT_INTERVAL`` steps it saves the model and sample
        images and, when more than 500 prediction batches have accumulated,
        logs inception statistics. Requires the attributes that ``__init__``
        only sets in training mode (``model_dir``, ``image_dir``,
        ``summary_writer``).
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        # Separate copy of the generator parameters; kept below as a running
        # average of the trained weights and used for snapshots.
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        # Full-batch label buffers; train_Dnet / train_Gnet slice them down
        # to the actual size of each batch.
        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        # NOTE(review): these two tensors are not used anywhere in this
        # method — confirm whether other code still reads them.
        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        # `noise` is re-sampled in place every step; `fixed_noise` is drawn
        # once and reused for the snapshot images.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        # Inception-model outputs accumulated between score computations.
        predictions = []
        # Resume the global step / epoch from the loaded network's counter.
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Running average: avg = 0.999 * avg + 0.001 * current
                # (uses the legacy add_(scalar, tensor) call form).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images: temporarily swap the averaged parameters
                    # into the generator, render with the fixed noise, then
                    # restore the live training parameters.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score (only once more than 500
                    # batches of predictions have been collected).
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        # Start a fresh accumulation window.
                        predictions = []

            end_t = time.time()
            # Per-epoch report; the losses shown are those of the last batch
            # processed in this epoch.
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''

                  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  %
                  (epoch, self.max_epoch, self.num_batches, errD_total.data[0],
                   errG_total.data[0], kl_loss.data[0], end_t - start_t))

        # Final snapshot after the last epoch, then flush/close the log.
        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir, split_dir,
                         imsize):
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                # print(img.size())
                img = img.view(1, 3, imsize, imsize)
                # print(img.size())
                super_img.append(img)
                # break
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize):
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate and save images with the checkpoint in cfg.TRAIN.NET_G.

        If ``cfg.TRAIN.NET_G`` is empty an error is printed and nothing else
        happens. Otherwise the generator is built, the checkpoint is loaded,
        and for every batch of ``self.data_loader`` one image set is
        generated per text embedding. Depending on ``cfg.TEST.B_EXAMPLE`` the
        results are saved as one grid per sample (``save_superimages``) or as
        individual files (``save_singleimages``).

        Output goes to ``<checkpoint folder>/iteration<N>`` where ``N`` is the
        integer between the last ``_`` and the last ``.`` of the checkpoint
        path. A checkpoint path without any ``/`` is treated as living in the
        current directory.

        Args:
            split_dir: dataset split name used as an output sub-folder;
                ``'test'`` is mapped to ``'valid'``.

        Raises:
            ValueError: if no integer can be parsed from the checkpoint name.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator
        netG = G_NET()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)
        # map_location keeps the loaded storages where they were
        # deserialised instead of restoring them to the saved device.
        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # Derive the output folder from the checkpoint path.
        ckpt_path = cfg.TRAIN.NET_G
        num_start = ckpt_path.rfind('_') + 1
        num_end = ckpt_path.rfind('.')
        iteration = int(ckpt_path[num_start:num_end])
        # rfind returns -1 when there is no '/'; slicing with -1 would drop
        # the last character of the file name, so fall back to '.'.
        slash = ckpt_path.rfind('/')
        ckpt_dir = ckpt_path[:slash] if slash >= 0 else '.'
        save_dir = '%s/iteration%d' % (ckpt_dir, iteration)

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        if cfg.CUDA:
            netG.cuda()
            noise = noise.cuda()

        # switch to evaluate mode
        netG.eval()
        for data in self.data_loader:
            imgs, t_embeddings, filenames = data
            t_embeddings = Variable(t_embeddings)
            if cfg.CUDA:
                t_embeddings = t_embeddings.cuda()

            # Axis 1 of t_embeddings enumerates the embeddings per sample.
            num_embeddings = t_embeddings.size(1)
            # The last batch may be smaller, so resize the noise to match.
            batch_size = imgs[0].size(0)
            noise.data.resize_(batch_size, nz)
            noise.data.normal_(0, 1)

            fake_img_list = []
            for i in range(num_embeddings):
                fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                if cfg.TEST.B_EXAMPLE:
                    # Only generator output index 2 is collected for grids.
                    fake_img_list.append(fake_imgs[2].data.cpu())
                else:
                    # Only the last generator output is saved, labelled 256.
                    self.save_singleimages(fake_imgs[-1], filenames,
                                           save_dir, split_dir, i, 256)
            if cfg.TEST.B_EXAMPLE:
                self.save_superimages(fake_img_list, filenames, save_dir,
                                      split_dir, 256)
Example #24
0
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (7,7) # Enlarge the figures

from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical

from tensorboardX import FileWriter, summary
import datetime

# Create the tensorboardX FileWriter used for TensorBoard output. Every run
# logs into its own directory: ./tmp/Aufgabe-1/<timestamp at start-up>.
output_directory = "./tmp/{}/{}".format("Aufgabe-1", datetime.datetime.now())
writer = FileWriter(output_directory)


# writes metrics to log file, is called after each epoch is finished
def log_history(hist, epoch):
    """Write every metric stored in ``hist`` to the module-level writer.

    For each entry ``ep`` of ``hist.epoch`` and each metric name in
    ``hist.history``, the value ``hist.history[name][ep]`` is logged under
    the tag ``"nn/<name>"``. All values are recorded at the same
    ``global_step``, namely ``epoch``.
    """
    for ep in hist.epoch:
        for name, series in hist.history.items():
            scalar = summary.scalar("nn/" + name, series[ep])
            writer.add_summary(summary=scalar, global_step=epoch)


# Number of target classes: MNIST labels are the digits 0 to 9.
nb_classes = 10

# Load the MNIST train and test splits (inputs and labels) through Keras,
# which downloads the data set when necessary. The "_orig" suffix marks
# these as the unprocessed arrays.
(x_train_orig, y_train_orig), (x_test_orig, y_test_orig) = mnist.load_data()