示例#1
0
    def __init__(self, output_dir, label_loader, unlabel_loader):
        """Set up output folders, GPU selection and the two data loaders.

        Args:
            output_dir: root under which ``Model``, ``Image`` and ``Log`` are
                created when ``cfg.TRAIN.FLAG`` is set (together with
                ``theta`` and ``summary_writer``, which exist only then).
            label_loader: loader over labelled data; its ``len()`` is stored
                as ``label_num_batches``.
            unlabel_loader: loader over unlabelled data; its ``len()`` is
                stored as ``num_batches``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = (
                os.path.join(output_dir, sub) for sub in ('Model', 'Image', 'Log'))
            self.theta = 0.5
            for directory in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(directory)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device indices; the first
        # one becomes the current CUDA device.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is per GPU.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.label_loader, self.unlabel_loader = label_loader, unlabel_loader
        self.num_batches = len(unlabel_loader)
        self.label_num_batches = len(label_loader)
class GradientMetric:
    """Write per-network gradient histograms to a summary log directory."""

    def __init__(self):
        self.training_log = 'logs/train/network'
        # Tag template for gradient histograms; '{0}' is the network name.
        self.gradient_metric_name = '{0}_gradient'
        self.summary_writer = FileWriter(self.training_log)

    def log_gradient(self, network_name, gradient, global_step=None):
        """Record a histogram of ``gradient`` under ``<network_name>_gradient``.

        Args:
            network_name: substituted into ``gradient_metric_name`` to form
                the summary tag.
            gradient: values handed to ``summary.histogram``.
            global_step: optional step forwarded to ``add_summary`` so that
                successive histograms are ordered; ``None`` keeps the old
                step-less behaviour.
        """
        assert self.summary_writer
        # Use the tag template defined in __init__ (it was previously unused
        # and the raw network name was logged instead).
        tag = self.gradient_metric_name.format(network_name)
        summary_value = summary.histogram(tag, gradient)
        self.summary_writer.add_summary(summary_value, global_step)
示例#3
0
def test_log_scalar_summary():
    """Write scalars 0..9 at steps 1..10 to ``./experiment/scalar``."""
    writer = FileWriter('./experiment/scalar')
    for step, value in enumerate(range(10), start=1):
        writer.add_summary(scalar('scalar', value), step)
    writer.flush()
    writer.close()
示例#4
0
def configure(logdir, flush_secs=2):
    """Configure logging: a file will be written to logdir, and flushed every
    flush_secs.

    Args:
        logdir: directory handed to ``FileWriter``.
        flush_secs: flush interval in seconds, forwarded to ``FileWriter``.

    Raises:
        ValueError: if the module-level logger has already been configured.
    """
    global _tf_logger
    if _tf_logger is not None:
        raise ValueError(
            'Logger already configured; configure() may only be called once')
    _tf_logger = FileWriter(logdir, flush_secs=flush_secs)
def test_log_histogram_summary():
    """Write ten histograms of N(0.1 * k, 1) samples to ``./experiment/histogram``."""
    writer = FileWriter('./experiment/histogram')
    sigma = 1.0
    for step in range(1, 11):
        mu = (step - 1) * 0.1
        # A large sample gives a smoother-looking histogram.
        values = np.random.normal(mu, sigma, 10000)
        writer.add_summary(summary.histogram('discrete_normal', values), step)
    writer.flush()
    writer.close()
示例#6
0
    def __init__(self, output_dir,
                 max_epoch,
                 snapshot_interval,
                 gpu_id, batch_size,
                 train_flag, net_g,
                 net_d, cuda, stage1_g,
                 z_dim, generator_lr,
                 discriminator_lr, lr_decay_epoch,
                 coef_kl, regularizer
                 ):
        """Record training settings and select the GPU(s).

        When ``train_flag`` is truthy, ``Model``, ``Image`` and ``Log``
        sub-directories of ``output_dir`` are created and a ``FileWriter`` is
        opened on ``Log``; otherwise those attributes are not set.
        ``gpu_id`` is a comma-separated list of device indices, and the stored
        ``batch_size`` is the given value multiplied by the number of GPUs.
        ``z_dim`` is stored as ``nz`` and ``lr_decay_epoch`` as
        ``lr_decay_step``.
        """
        if train_flag:
            paths = [os.path.join(output_dir, sub) for sub in ('Model', 'Image', 'Log')]
            self.model_dir, self.image_dir, self.log_dir = paths
            for path in paths:
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        # Schedule and optimisation settings.
        self.max_epoch = max_epoch
        self.snapshot_interval = snapshot_interval
        self.generator_lr = generator_lr
        self.discriminator_lr = discriminator_lr
        self.lr_decay_step = lr_decay_epoch
        self.coef_kl = coef_kl
        self.regularizer = regularizer
        self.nz = z_dim

        # Network arguments and device flag, stored as given.
        self.net_g = net_g
        self.net_d = net_d
        self.stage1_g = stage1_g
        self.cuda = cuda

        self.gpus = [int(token) for token in gpu_id.split(',')]
        self.num_gpus = len(self.gpus)
        self.batch_size = batch_size * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
示例#7
0
def test_log_image_summary():
    """Log the first ten MNIST training images as ``mnist/<k>`` image summaries."""
    writer = FileWriter('./experiment/image')

    mnist_url = 'http://yann.lecun.com/exdb/mnist/'
    label_file = mnist_url + 'train-labels-idx1-ubyte.gz'
    image_file = mnist_url + 'train-images-idx3-ubyte.gz'
    train_lbl, train_img = read_data(label_file, image_file)

    for idx in range(10):
        # Each image is reshaped to 28x28 with a single channel.
        tensor = np.reshape(train_img[idx], (28, 28, 1))
        # The shared 'mnist/' prefix groups all images under one tag.
        image_summary = summary.image('mnist/' + str(idx), tensor)
        writer.add_summary(image_summary, idx + 1)
    writer.flush()
    writer.close()
示例#8
0
    def __init__(self, data):
        """Configure logging, the CUDA device and corpus attributes, then build the net.

        Args:
            data: object providing ``vocab_size``, ``index2word`` and
                ``indexed_seqs``; it is kept as ``self.data``.
        """
        train_cfg = cfg.TRAIN
        if train_cfg.FLAG:
            self.model_dir = cfg.NET
            self.log_dir = train_cfg.LOG_DIR
            self.summary_writer = FileWriter(self.log_dir)

        self.cuda = torch.cuda.is_available()
        if self.cuda:
            # ``self.gpu`` only exists when CUDA is available.
            self.gpu = int(cfg.GPU_ID)
            torch.cuda.set_device(self.gpu)
            print("Use CUDA")

        self.epoch = train_cfg.NUM_EPOCH
        self.batch_size = train_cfg.BATCH_SIZE

        self.vocab_size = data.vocab_size
        self.embedding_size = cfg.EMBEDDING_SIZE
        self.index2word = data.index2word
        self.seqs = data.indexed_seqs

        # load_network() runs before ``self.data`` is assigned; keep this order.
        self.net = self.load_network()
        self.data = data
示例#9
0
    def __init__(self, output_dir):
        """Create output folders (training only) and configure the GPU(s).

        With ``cfg.TRAIN.FLAG`` set, ``Model``, ``Image`` and ``Log`` are
        created under ``output_dir`` and a ``FileWriter`` is opened on ``Log``.
        ``cfg.GPU_ID`` is a comma-separated list of device indices; the batch
        size is ``cfg.TRAIN.BATCH_SIZE`` times the number of listed GPUs.
        """
        if cfg.TRAIN.FLAG:
            join = os.path.join
            self.model_dir = join(output_dir, 'Model')
            self.image_dir = join(output_dir, 'Image')
            self.log_dir = join(output_dir, 'Log')
            for target in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(target)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        gpu_ids = cfg.GPU_ID.split(',')
        self.gpus = list(map(int, gpu_ids))
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
示例#10
0
class GANTrainer(object):
    """Train and sample the two-stage (Stage-I / Stage-II) text-to-image GAN.

    Stage 1 uses ``model.STAGE1_G`` / ``model.STAGE1_D``; any other stage
    uses ``model.STAGE2_G`` (wrapping a ``STAGE1_G``) / ``model.STAGE2_D``.
    Checkpoint paths (``net_g``, ``net_d``, ``stage1_g``) are loaded when
    they are not the empty string.
    """

    def __init__(self, output_dir,
                 max_epoch,
                 snapshot_interval,
                 gpu_id, batch_size,
                 train_flag, net_g,
                 net_d, cuda, stage1_g,
                 z_dim, generator_lr,
                 discriminator_lr, lr_decay_epoch,
                 coef_kl, regularizer
                 ):
        """Store hyper-parameters, create output directories and pick GPUs.

        Args:
            output_dir: root for ``Model``/``Image``/``Log`` sub-directories
                (created only when ``train_flag`` is truthy).
            max_epoch: number of training epochs.
            snapshot_interval: save a checkpoint every this many epochs
                (0 disables the periodic snapshots).
            gpu_id: comma-separated GPU indices, e.g. ``'0,1'``.
            batch_size: per-GPU batch size; multiplied by the GPU count.
            train_flag: when truthy, also opens ``summary_writer``.
            net_g, net_d: generator / discriminator checkpoint paths, or
                ``''`` to keep the ``weights_init`` initialisation.
            cuda: move networks and tensors to CUDA when truthy.
            stage1_g: Stage-I generator checkpoint used to initialise the
                Stage-II generator when ``net_g`` is empty.
            z_dim: noise dimension (stored as ``nz``).
            generator_lr, discriminator_lr: initial Adam learning rates.
            lr_decay_epoch: halve both learning rates every this many epochs
                (0 disables decay).
            coef_kl: weight of the regularizer term in the generator loss.
            regularizer: ``'KL'`` selects ``KL_loss``; anything else uses
                ``JSD_loss``.
        """
        if train_flag:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            # Only exists in training mode; train() writes to it.
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = max_epoch
        self.snapshot_interval = snapshot_interval
        self.net_g = net_g
        self.net_d = net_d
        self.cuda = cuda
        self.stage1_g = stage1_g
        self.nz = z_dim
        self.generator_lr = generator_lr
        self.discriminator_lr = discriminator_lr
        self.lr_decay_step = lr_decay_epoch
        self.coef_kl = coef_kl
        self.regularizer = regularizer

        s_gpus = gpu_id.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = batch_size * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    @staticmethod
    def _load_state_dict(path):
        """Load a checkpoint's tensors onto the CPU.

        The networks are still on the CPU when their state is restored and are
        moved to CUDA afterwards (if enabled), so this works on CPU-only hosts
        without changing where the parameters finally live.
        """
        return torch.load(path, map_location=lambda storage, loc: storage)

    def _load_networks(self, stage, text_dim, gf_dim, condition_dim,
                       z_dim, df_dim, res_num=None):
        """Build ``(netG, netD)`` via the Stage-I (``stage == 1``) or Stage-II loader.

        Raises:
            ValueError: if a dimension needed for the requested stage is None
                (``res_num`` is only required for Stage-II).
        """
        required = {'text_dim': text_dim, 'gf_dim': gf_dim,
                    'condition_dim': condition_dim, 'z_dim': z_dim,
                    'df_dim': df_dim}
        if stage != 1:
            required['res_num'] = res_num
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError('Missing network dimensions for stage %s: %s'
                             % (stage, ', '.join(missing)))
        if stage == 1:
            return self.load_network_stageI(text_dim, gf_dim, condition_dim,
                                            z_dim, df_dim)
        return self.load_network_stageII(text_dim, gf_dim, condition_dim,
                                         z_dim, df_dim, res_num)

    # ############# For training stageI GAN #############
    def load_network_stageI(self, text_dim, gf_dim, condition_dim, z_dim, df_dim):
        """Build Stage-I ``(netG, netD)``, restoring ``net_g``/``net_d`` if set."""
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D(df_dim, condition_dim)
        netD.apply(weights_init)
        print(netD)

        if self.net_g != '':
            netG.load_state_dict(self._load_state_dict(self.net_g))
            print('Load from: ', self.net_g)
        if self.net_d != '':
            netD.load_state_dict(self._load_state_dict(self.net_d))
            print('Load from: ', self.net_d)
        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self, text_dim, gf_dim, condition_dim, z_dim, df_dim, res_num):
        """Build Stage-II ``(netG, netD)``.

        The generator is restored from ``net_g`` if set; otherwise only its
        embedded Stage-I generator is restored from ``stage1_g``.

        Raises:
            ValueError: if neither ``net_g`` nor ``stage1_g`` is set.
        """
        if self.net_g == '' and self.stage1_g == '':
            # Callers unpack (netG, netD); returning None would only surface
            # later as an opaque unpacking TypeError.
            raise ValueError("Please give the Stage1_G path")

        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG = STAGE2_G(Stage1_G, text_dim, gf_dim, condition_dim, z_dim, res_num, self.cuda)
        netG.apply(weights_init)
        print(netG)
        if self.net_g != '':
            netG.load_state_dict(self._load_state_dict(self.net_g))
            print('Load from: ', self.net_g)
        else:
            netG.STAGE1_G.load_state_dict(self._load_state_dict(self.stage1_g))
            print('Load from: ', self.stage1_g)

        netD = STAGE2_D(df_dim, condition_dim)
        netD.apply(weights_init)
        if self.net_d != '':
            netD.load_state_dict(self._load_state_dict(self.net_d))
            print('Load from: ', self.net_d)
        print(netD)

        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage, text_dim, gf_dim, condition_dim, z_dim, df_dim, res_num):
        """Run the adversarial training loop.

        Each batch: generate fake images, update D with
        ``compute_discriminator_loss``, then update G with
        ``compute_generator_loss`` plus ``coef_kl`` times the KL/JSD
        regularizer. Every 100th batch of an epoch the losses are written to
        ``summary_writer`` and preview images are saved. Checkpoints are saved
        every ``snapshot_interval`` epochs and once more at the end.
        """
        netG, netD = self._load_networks(stage, text_dim, gf_dim, condition_dim,
                                         z_dim, df_dim, res_num)

        batch_size = self.batch_size
        noise = torch.FloatTensor(batch_size, self.nz)
        fixed_noise = torch.FloatTensor(batch_size, self.nz).normal_(0, 1)
        real_labels = torch.ones(batch_size)
        fake_labels = torch.zeros(batch_size)

        if self.cuda:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        optimizerD = optim.Adam(netD.parameters(),
                                lr=self.discriminator_lr, betas=(0.5, 0.999))
        # Only generator parameters with requires_grad are optimised.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                self.generator_lr,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # A falsy lr_decay_step disables decay instead of dividing by zero.
            if epoch > 0 and self.lr_decay_step and epoch % self.lr_decay_step == 0:
                self.generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = self.generator_lr
                self.discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = self.discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_imgs, txt_embedding = data
                if self.cuda:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)

                if self.regularizer == 'KL':
                    regularizer_loss = KL_loss(mu, logvar)
                else:
                    regularizer_loss = JSD_loss(mu, logvar)

                errG_total = errG + regularizer_loss * self.coef_kl
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('Regularizer_loss', regularizer_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # Preview images from the fixed noise every 100 batches;
                    # no gradients are needed for this forward pass.
                    with torch.no_grad():
                        inputs = (txt_embedding, fixed_noise)
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_imgs.cpu(), fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)

            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), regularizer_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            # A falsy snapshot_interval disables periodic snapshots.
            if self.snapshot_interval and epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def _samples_dir(self):
        """Directory for generated samples: ``net_g`` truncated at ``'.pth'``.

        ``str.find`` returns -1 when ``'.pth'`` is absent, so plain slicing
        would silently drop the last character; the path is used as-is then.
        """
        end = self.net_g.find('.pth')
        return self.net_g if end < 0 else self.net_g[:end]

    @staticmethod
    def _save_images(fake_imgs, save_dir, start_index):
        """Save each image of ``fake_imgs`` as ``<save_dir>/<start_index + k>.png``.

        Images are transposed from (C, H, W) to (H, W, C) and mapped with
        ``(x + 1) * 127.5`` to uint8 -- this assumes generator outputs lie in
        [-1, 1] (TODO confirm against the model's output activation).
        """
        for k in range(fake_imgs.size(0)):
            save_name = '%s/%d.png' % (save_dir, start_index + k)
            im = fake_imgs[k].data.cpu().numpy()
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            Image.fromarray(im).save(save_name)

    def sample(self, datapath, stage=1, text_dim=None, gf_dim=None,
               condition_dim=None, z_dim=None, df_dim=None, res_num=None):
        """Generate one PNG per text embedding stored in a torchfile.

        Args:
            datapath: torchfile whose ``raw_txt`` holds the captions and whose
                ``fea_txt`` embeddings are concatenated along axis 0.
            stage: 1 for Stage-I, anything else for Stage-II.
            text_dim, gf_dim, condition_dim, df_dim, res_num: network
                dimensions forwarded to the loader (``res_num`` is only
                needed for Stage-II). They must be provided.
            z_dim: generator noise dimension; defaults to ``self.nz``.

        Images go to ``net_g`` with its ``'.pth'`` suffix stripped. The loop
        stops once ``count`` exceeds 3000.

        Raises:
            ValueError: if a required network dimension is missing.
        """
        if z_dim is None:
            z_dim = self.nz
        netG, _ = self._load_networks(stage, text_dim, gf_dim, condition_dim,
                                      z_dim, df_dim, res_num)
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = self._samples_dir()
        mkdir_p(save_dir)

        batch_size = min(num_embeddings, self.batch_size)
        noise = torch.FloatTensor(batch_size, z_dim)
        if self.cuda:
            noise = noise.cuda()
        count = 0
        with torch.no_grad():
            while count < num_embeddings:
                # Hard cap on the number of generated samples.
                if count > 3000:
                    break
                iend = count + batch_size
                if iend > num_embeddings:
                    # Shift the last window back so every batch is full; a
                    # few images from the previous window are regenerated.
                    iend = num_embeddings
                    count = num_embeddings - batch_size
                txt_embedding = torch.FloatTensor(embeddings[count:iend])
                if self.cuda:
                    txt_embedding = txt_embedding.cuda()

                noise.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                self._save_images(fake_imgs, save_dir, count)
                count += batch_size

    def birds_eval(self, data_loader, stage=2, text_dim=None, gf_dim=None,
                   condition_dim=None, z_dim=None, df_dim=None, res_num=None):
        """Generate one PNG per sample yielded by ``data_loader``.

        ``data_loader`` yields ``(real_imgs, txt_embedding)`` pairs; only the
        embeddings are used. Images are numbered consecutively across batches
        and written to ``net_g`` with its ``'.pth'`` suffix stripped.
        Network dimensions are as in :meth:`sample`.

        Raises:
            ValueError: if a required network dimension is missing.
        """
        if z_dim is None:
            z_dim = self.nz
        netG, _ = self._load_networks(stage, text_dim, gf_dim, condition_dim,
                                      z_dim, df_dim, res_num)
        netG.eval()

        noise = torch.FloatTensor(self.batch_size, z_dim)
        if self.cuda:
            noise = noise.cuda()

        # path to save generated samples
        save_dir = self._samples_dir()
        print("Save directory", save_dir)
        mkdir_p(save_dir)

        count = 0
        with torch.no_grad():
            for batch_idx, (_, txt_embedding) in enumerate(data_loader, 0):
                if self.cuda:
                    txt_embedding = txt_embedding.cuda()
                print("Batch Running:", batch_idx)
                # The final batch may be smaller than self.batch_size.
                batch_noise = noise[:txt_embedding.size(0)]
                batch_noise.normal_(0, 1)
                inputs = (txt_embedding, batch_noise)
                _, fake_imgs, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
                self._save_images(fake_imgs, save_dir, count)
                count += fake_imgs.size(0)
示例#11
0
def train(iter_cnt, model, corpus, args, optimizer):
    """Run one pass over the positive/negative training pairs.

    Pairs are read from ``<args.train>.pos.txt`` and ``<args.train>.neg.txt``
    and batched by ``pad_iter``. Every 100 iterations the running mean loss
    is printed and logged as a ``loss`` scalar under ``<args.run_path>/train``.

    Args:
        iter_cnt: global iteration counter to continue from.
        model: provides ``embedding_layer``, ``compute_loss`` and
            ``compute_similarity``; calling it encodes one input tensor.
        corpus: forwarded to ``pad_iter``.
        args: must provide ``run_path``, ``train``, ``batch_size``,
            ``use_content`` and ``cuda``.
        optimizer: stepped once per batch.

    Returns:
        The updated iteration counter.
    """
    train_writer = FileWriter(args.run_path + "/train", flush_secs=5)
    try:
        pos_file_path = "{}.pos.txt".format(args.train)
        neg_file_path = "{}.neg.txt".format(args.train)
        data_dir = os.path.dirname(args.train)
        pos_batch_loader = FileLoader([(pos_file_path, data_dir)], args.batch_size)
        neg_batch_loader = FileLoader([(neg_file_path, data_dir)], args.batch_size)

        use_content = bool(args.use_content)
        embedding_layer = model.embedding_layer
        criterion = model.compute_loss

        start = time.time()
        tot_loss = 0.0
        tot_cnt = 0

        for batch, labels in tqdm(
                pad_iter(corpus,
                         embedding_layer,
                         pos_batch_loader,
                         neg_batch_loader,
                         use_content,
                         pad_left=False)):
            iter_cnt += 1
            model.zero_grad()
            labels = labels.type(torch.LongTensor)
            if use_content:
                # Flatten the nested groups into one flat list of tensors.
                batch = [tensor for group in batch for tensor in group]
            if args.cuda:
                batch = [x.cuda() for x in batch]
                labels = labels.cuda()
            # A list rather than map(): the batch is indexed below, and map
            # objects are not subscriptable in Python 3.
            batch = [Variable(x) for x in batch]
            labels = Variable(labels)
            if use_content:
                repr_left = model(batch[0]) + model(batch[1])
                repr_right = model(batch[2]) + model(batch[3])
            else:
                repr_left = model(batch[0])
                repr_right = model(batch[1])
            output = model.compute_similarity(repr_left, repr_right)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            n_items = output.size(0)
            # .item() also works for 0-dim losses, where .data[0] fails.
            tot_loss += loss.item() * n_items
            tot_cnt += n_items
            if iter_cnt % 100 == 0:
                outputManager.say("\r" + " " * 50)
                outputManager.say("\r{} loss: {:.4f}  eps: {:.0f} ".format(
                    iter_cnt, tot_loss / tot_cnt, tot_cnt / (time.time() - start)))
                s = summary.scalar('loss', tot_loss / tot_cnt)
                train_writer.add_summary(s, iter_cnt)

        outputManager.say("\n")
    finally:
        # Close the writer even if training raises.
        train_writer.close()

    return iter_cnt
 def __init__(self):
     """Open a FileWriter on 'logs/train/network' and set the gradient tag template."""
     log_dir = 'logs/train/network'
     self.training_log = log_dir
     # '{0}' is a placeholder for a network name.
     self.gradient_metric_name = '{0}_gradient'
     self.summary_writer = FileWriter(log_dir)
示例#13
0
class condGANTrainer(object):
    def __init__(self, output_dir, data_loader, imsize):
        """Prepare output directories, GPU selection and the data loader.

        Args:
            output_dir: root for ``Model``/``Image``/``Log``, created (with a
                ``FileWriter`` on ``Log``) only when ``cfg.TRAIN.FLAG`` is set.
            data_loader: training loader; its ``len()`` is ``num_batches``.
            imsize: accepted but not used here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = [
                os.path.join(output_dir, leaf) for leaf in ('Model', 'Image', 'Log')]
            for path in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        self.gpus = [int(part) for part in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        train_cfg = cfg.TRAIN
        # The configured batch size is per GPU.
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(data_loader)

    def prepare_data(self, data):
        """Wrap one loader batch in Variables, moving them to CUDA if enabled.

        ``data`` unpacks to ``(imgs, w_imgs, t_embedding, _)``. ``imgs`` and
        ``w_imgs`` are indexed once per discriminator (``self.num_Ds``
        entries are used).

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, vembedding)``, where ``imgs`` is
            passed through unchanged.
        """
        imgs, w_imgs, t_embedding, _ = data
        use_cuda = cfg.CUDA

        def to_var(tensor):
            wrapped = Variable(tensor)
            return wrapped.cuda() if use_cuda else wrapped

        vembedding = to_var(t_embedding)
        real_vimgs, wrong_vimgs = [], []
        for level in range(self.num_Ds):
            real_vimgs.append(to_var(imgs[level]))
            wrong_vimgs.append(to_var(w_imgs[level]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """Do one optimisation step for discriminator ``idx`` and return its loss.

        The conditional term labels real images as real and both "wrong"
        and generated images as fake. If the discriminator returns a second
        (unconditional) logit and ``cfg.TRAIN.COEFF.UNCOND_LOSS > 0``, a
        weighted unconditional term is added in which wrong images are
        labelled *real*. Presumably these are real images paired with
        mismatched text; this labelling is deliberate, not a typo. The three
        terms are then summed with equal weight; otherwise the wrong and fake
        terms are averaged. The loss is logged when ``count % 100 == 0``.
        """
        netD = self.netsD[idx]
        optD = self.optimizersD[idx]
        criterion = self.criterion
        mu = self.mu
        batch_size = self.real_imgs[0].size(0)
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        netD.zero_grad()
        # Detached inputs keep this step from back-propagating into G.
        real_logits = netD(self.real_imgs[idx], mu.detach())
        wrong_logits = netD(self.wrong_imgs[idx], mu.detach())
        fake_logits = netD(self.fake_imgs[idx].detach(), mu.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            coeff = cfg.TRAIN.COEFF.UNCOND_LOSS
            errD_real = errD_real + coeff * criterion(real_logits[1], real_labels)
            errD_wrong = errD_wrong + coeff * criterion(wrong_logits[1], real_labels)
            errD_fake = errD_fake + coeff * criterion(fake_logits[1], fake_labels)
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        errD.backward()
        optD.step()

        if count % 100 == 0:
            self.summary_writer.add_summary(
                summary.scalar('D_loss%d' % idx, errD.data), count)
        return errD

    def train_Gnet(self, count):
        """Do one generator optimisation step; return ``(kl_loss, errG_total)``.

        The total loss sums the following terms:

        * for every discriminator, the adversarial loss on its conditional
          logit, plus a ``cfg.TRAIN.COEFF.UNCOND_LOSS``-weighted unconditional
          term when a second logit is returned;
        * optional colour-consistency losses, which compare the mean and
          covariance of adjacent outputs with the lower one detached;
        * ``KL_loss`` weighted by ``cfg.TRAIN.COEFF.KL``.

        Losses are logged when ``count % 100 == 0``.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses.
        # Each entry: (num_Ds must exceed this, higher-output index,
        # lower-output index, tag suffix). Order and tags match the
        # previous unrolled code: first 'mu2'/'cov2', then 'mu1'/'cov1'.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            color_coeff = cfg.TRAIN.COEFF.COLOR_LOSS
            for min_ds, hi, lo, suffix in ((1, -1, -2, '2'), (2, -2, -3, '1')):
                if self.num_Ds <= min_ds:
                    continue
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[hi])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[lo].detach())
                like_mu = color_coeff * nn.MSELoss()(mu1, mu2)
                like_cov = color_coeff * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu + like_cov
                if flag == 0:
                    # MSELoss returns a 0-dim tensor; .item() reads it, whereas
                    # indexing with .data[0] is an error in PyTorch >= 0.5.
                    sum_mu = summary.scalar('G_like_mu' + suffix, like_mu.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov' + suffix, like_cov.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        print("Models loaded")
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        print("Start epoches!")
        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            print("Epoch {:d} started".format(epoch))
            start_t = time.time()
            start_iters = time.time()
            N_TOTAL_BATCH  = len(self.data_loader)
         #   data_iterator = iter(self.data_loader)
	
            for step, data in enumerate(self.data_loader, 0):
        #    for step in range(N_TOTAL_BATCH): #enumerate(self.data_loader, 0):
        #        print("Iter {:d}".format(step))
          #      data = next(data_iterator)
                if step % 50 == 0:
                    curr_time = time.time()
                    print("Itteration {:d}/{:d}\nTime for 50 batch steps:{:.2f}\nTotal time:{:.2f}".format(step, N_TOTAL_BATCH, curr_time - start_iters, curr_time-start_t))
                    start_iters = time.time()

                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                #print("Data prepeared!")
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                #print("Fake image generated")
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                #print("D network updated")
                kl_loss, errG_total = self.train_Gnet(count)
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                #print("G network uddated")
                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data)
                    summary_G = summary.scalar('G_loss', errG_total.data)
                    summary_KL = summary.scalar('KL_loss', kl_loss.data)
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = \
                        self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        # print('mean:', mean, 'std', std)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data, errG_total.data,
                     kl_loss.data, end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames,
                         save_dir, split_dir, imsize):
        """Write one stacked grid PNG per sample.

        ``images_list[j][i]`` is the image produced for sample ``i`` from
        sentence ``j``. For every sample, its images from all sentences are
        reshaped to ``(1, 3, imsize, imsize)``, concatenated along dim 0 and
        saved (normalized, 10 per row) to
        ``<save_dir>/super/<split_dir>/<filename>_<imsize>.png``.
        """
        num_sentences = len(images_list)
        for sample_idx in range(images_list[0].size(0)):
            prefix = '%s/super/%s/%s' % \
                (save_dir, split_dir, filenames[sample_idx])
            # Everything before the last '/' is the directory to create.
            folder = prefix.rpartition('/')[0]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            tiles = [images_list[sent_idx][sample_idx].view(1, 3, imsize, imsize)
                     for sent_idx in range(num_sentences)]
            vutils.save_image(torch.cat(tiles, 0),
                              '%s_%d.png' % (prefix, imsize),
                              nrow=10, normalize=True)

    def save_singleimages(self, images, filenames,
                          save_dir, split_dir, sentenceID, imsize):
        """Save every image of ``images`` as an individual PNG file.

        Pixel values are mapped from [-1, 1] to [0, 255] (uint8) and each image
        is written to
        ``<save_dir>/single_samples/<split_dir>/<filename>_<imsize>_sentence<sentenceID>.png``.
        """
        for idx in range(images.size(0)):
            prefix = '%s/single_samples/%s/%s' % \
                (save_dir, split_dir, filenames[idx])
            folder = prefix.rpartition('/')[0]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            # [-1, 1] -> [0, 255], then move the first axis last for PIL
            pixels = images[idx].add(1).div(2).mul(255).clamp(0, 255).byte()
            as_array = pixels.permute(1, 2, 0).data.cpu().numpy()
            Image.fromarray(as_array).save(
                '%s_%d_sentence%d.png' % (prefix, imsize, sentenceID))

    def evaluate(self, split_dir):
        """Generate and save images for every batch in ``self.data_loader``.

        Loads the generator checkpoint ``cfg.TRAIN.NET_G`` and writes results
        under ``<checkpoint dir>/iteration<N>``, where ``N`` is the integer
        between the last ``_`` and the last ``.`` of the checkpoint path.
        With ``cfg.TEST.B_EXAMPLE`` one grid per sample is saved from the third
        generator output (256 px); otherwise each last-stage image is saved
        individually.

        Args:
            split_dir: sub-directory name used in the output paths.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        # Build and load the generator
        netG = G_NET()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)
        state_dict = \
            torch.load(cfg.TRAIN.NET_G,
                       map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # the path to save generated images
        s_tmp = cfg.TRAIN.NET_G
        istart = s_tmp.rfind('_') + 1
        iend = s_tmp.rfind('.')
        iteration = int(s_tmp[istart:iend])
        s_tmp = s_tmp[:s_tmp.rfind('/')]
        save_dir = '%s/iteration%d' % (s_tmp, iteration)

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        if cfg.CUDA:
            netG.cuda()
            noise = noise.cuda()

        # Switch to evaluate mode. This is pure inference, so autograd is also
        # disabled: no graphs are built and no activations are retained.
        netG.eval()
        with torch.no_grad():
            for data in self.data_loader:
                imgs, t_embeddings, filenames = data
                t_embeddings = Variable(t_embeddings)
                if cfg.CUDA:
                    t_embeddings = t_embeddings.cuda()

                # Each slice t_embeddings[:, i, :] conditions one generation
                # pass for the whole batch.
                num_captions = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(num_captions):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_img_list, filenames,
                                          save_dir, split_dir, 256)
示例#14
0
def test_event_logging():
    """Each new FileWriter on the same log directory adds its own event file.

    Runs in a private temporary directory rather than the shared
    ``./experiment/`` so the file counts do not depend on what other tests
    have written there, and the directory is removed even when an assert fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_root:
        # Not created up front, so the isdir() checks confirm that FileWriter
        # creates the log directory itself.
        logdir = os.path.join(tmp_root, 'experiment')

        summary_writer = FileWriter(logdir)
        scalar_value = 1.0
        s = scalar('test_scalar', scalar_value)
        summary_writer.add_summary(s, global_step=1)
        summary_writer.close()
        assert os.path.isdir(logdir)
        assert len(os.listdir(logdir)) == 1

        summary_writer = FileWriter(logdir)
        scalar_value = 1.0
        s = scalar('test_scalar', scalar_value)
        summary_writer.add_summary(s, global_step=1)
        summary_writer.close()
        assert os.path.isdir(logdir)
        assert len(os.listdir(logdir)) == 2
示例#15
0
import logging
import os
import threading
import gym
from datetime import datetime
import time
from a3cmodule import A3CModule
from tensorboard import summary
from tensorboard import FileWriter

# argparse is used below but is not part of the import block above.
import argparse

# Global counters/limits for the training loop; how they are used is not
# visible in this chunk. NOTE(review): presumably T is a shared step counter
# bounded by TMAX and t_max a per-rollout step limit — confirm.
T = 0
TMAX = 80000000
t_max = 32

# Event files for TensorBoard are written under this directory; the writer is
# created at import time.
logdir = './a3c_logs/'
summary_writer = FileWriter(logdir)

# Command-line options.
parser = argparse.ArgumentParser(description='Traing A3C with OpenAI Gym')
parser.add_argument('--test',
                    action='store_true',
                    help='run testing',
                    default=False)
parser.add_argument('--log-file', type=str, help='the name of log file')
parser.add_argument('--log-dir',
                    type=str,
                    default="./log",
                    help='directory of the log file')
parser.add_argument('--model-prefix',
                    type=str,
                    help='the prefix of the model to load')
parser.add_argument('--save-model-prefix',
示例#16
0
class QRNN(object):
    """Train a next-token language model (``Net``) over indexed sequences.

    All hyper-parameters come from the global ``cfg``. With ``cfg.TRAIN.FLAG``
    a ``FileWriter`` is opened in ``cfg.TRAIN.LOG_DIR`` and the model is later
    saved to ``cfg.NET``; without it, weights are loaded from ``cfg.NET``.
    """

    def __init__(self, data):
        """Store dataset fields, pick the GPU and build the network.

        Args:
            data: dataset exposing ``vocab_size``, ``index2word`` and
                ``indexed_seqs``; it is pickled at the end of ``train``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = cfg.NET
            self.log_dir = cfg.TRAIN.LOG_DIR
            self.summary_writer = FileWriter(self.log_dir)

        self.cuda = torch.cuda.is_available()

        if self.cuda:
            self.gpu = int(cfg.GPU_ID)
            torch.cuda.set_device(self.gpu)
            print("Use CUDA")

        self.epoch = cfg.TRAIN.NUM_EPOCH
        self.batch_size = cfg.TRAIN.BATCH_SIZE

        self.vocab_size = data.vocab_size
        self.embedding_size = cfg.EMBEDDING_SIZE
        self.index2word = data.index2word
        self.seqs = data.indexed_seqs

        self.net = self.load_network()

        self.data = data

    def load_network(self):
        """Build ``Net``; outside training, also restore weights from ``cfg.NET``."""
        net = Net(self.vocab_size)
        if self.cuda:
            net = net.cuda()
        if not cfg.TRAIN.FLAG:
            state_dict = torch.load(cfg.NET,
                                    map_location=lambda storage, loc: storage)
            # Apply the checkpoint to the network; previously it was loaded
            # and discarded. NOTE(review): assumes the file is a plain
            # state_dict as written by save_model — confirm.
            net.load_state_dict(state_dict)
            print("Load from", cfg.NET)

        return net

    def train(self):
        """Run ``self.epoch`` epochs of Adam training, then save model and data.

        The loss of each sequence is cross-entropy between the logits at
        positions ``[0, len-1)`` and the tokens at ``[1, len)`` (predict the
        next token), averaged over the batch. The last batch loss of each
        epoch is logged. Every ``cfg.TRAIN.LR_DECAY_INTERVAL`` epochs the
        learning rate is multiplied by 0.95.
        """
        net = self.net
        criterion = nn.CrossEntropyLoss()
        lr = cfg.TRAIN.LEARNING_RATE
        optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0, 0.99))

        count = 0
        # NOTE(review): start_idx is never advanced, so every step requests
        # the batch at index 0 — confirm whether prepare_batch samples
        # internally or whether this should move through self.seqs.
        start_idx = 0
        iteration = len(self.seqs) // self.batch_size + 1
        for epoch in range(self.epoch):
            for _ in tqdm(range(iteration)):
                x, lengths = prepare_batch(self.seqs, start_idx,
                                           self.batch_size)
                x = Variable(torch.LongTensor(x))
                if self.cuda:
                    x = x.cuda()

                logits = net(x)
                losses = []
                for b in range(logits.size(0)):
                    # Only the first lengths[b] tokens are real; predict token
                    # t + 1 from position t.
                    logit = logits[b][:lengths[b] - 1]
                    target = x[b][1:lengths[b]]
                    losses.append(criterion(logit, target))
                loss = sum(losses) / len(losses)
                net.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() works on the 0-dim loss tensor; .data[0] does not in
            # current PyTorch.
            loss_value = loss.item()
            summary_net = summary.scalar("Loss", loss_value)
            self.summary_writer.add_summary(summary_net, count)
            count += 1
            print("epoch %d done. train loss: %f" % (epoch + 1, loss_value))

            if count % cfg.TRAIN.LR_DECAY_INTERVAL == 0:
                lr = lr * 0.95
                # Decay in place so Adam keeps its moment estimates and the
                # betas it was created with.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print("decayed learning rate")

        save_model(net, self.model_dir)
        with open("output/Data/data.pkl", "wb") as f:
            pickle.dump(self.data, f)

    def predict(self, seq):
        """Not implemented; returns None."""
        pass
def predict(w, x):
    """Return row-wise softmax probabilities of the linear scores ``x @ w``.

    NOTE: scores are exponentiated without subtracting the row maximum, so
    very large scores can overflow.
    """
    unnormalized = np.exp(np.dot(x, w))
    return unnormalized / np.sum(unnormalized, axis=1, keepdims=True)


def train_loss(w, x):
    """Cross-entropy of ``predict(w, x)`` against the module-level ``label``.

    The summed loss is divided by the module-level ``num_samples``.
    """
    log_prob = np.log(predict(w, x))
    return -np.sum(label * log_prob) / num_samples


# Use MinPy's autograd to build a function that returns a (gradient, loss)
# pair for train_loss.
grad_function = grad_and_loss(train_loss)
# Event-file writer for training summaries under <summaries_dir>/train.
train_writer = FileWriter(summaries_dir + '/train')

# Using gradient descent to fit the correct classes.


def train(w, x, loops):
    for i in range(loops):
        dw, loss = grad_function(w, x)
        # gradient descent
        w -= 0.1 * dw
        if i % 10 == 0:
            print('Iter {}, training loss {}'.format(i, loss))
        # summary1 = scalar('loss', loss)
        # train_writer.add_summary(summary1, i)
        # print(loss)
        for ele in loss:
示例#18
0
class GANTrainer(object):
    """Train and sample a layout-conditioned text-to-image GAN.

    Hyper-parameters come from the global ``cfg``. The generator is
    conditioned on a text embedding, per-object affine transformation matrices
    computed from bounding boxes, and one-hot object labels over 81 classes.
    Negative (padding) labels are mapped to class 80.
    """

    def __init__(self, output_dir):
        """Create output folders (when training) and select the GPUs.

        Args:
            output_dir: root under which ``Model``, ``Image`` and ``Log``
                folders are created when ``cfg.TRAIN.FLAG`` is set.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    @staticmethod
    def _labels_to_one_hot(label, num_rows, max_objects, num_classes=81):
        """Return a float one-hot tensor of shape (num_rows, max_objects, num_classes).

        ``label`` holds class indices that are scattered along dim 2, so it is
        presumably shaped ``(num_rows, max_objects, 1)``. Negative entries mark
        padding and are mapped to the last class so that scatter works. The
        result is allocated on ``label``'s device, so CPU runs work too, and
        ``label`` itself is left unmodified.
        """
        labels = label.long().clone()
        labels[labels < 0] = num_classes - 1
        one_hot = torch.zeros(num_rows, max_objects, num_classes,
                              device=labels.device)
        return one_hot.scatter_(2, labels, 1).float()

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build stage-I G and D, optionally restoring ``cfg.NET_G`` / ``cfg.NET_D``."""
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,  map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build stage-II G (wrapping a stage-I G) and D.

        The generator weights come from ``cfg.NET_G`` if set, otherwise the
        embedded stage-I generator is initialised from ``cfg.STAGE1_G``. If
        neither path is given, a message is printed and ``None`` is returned,
        so callers that unpack ``netG, netD`` will fail.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = torch.load(cfg.NET_G,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = torch.load(cfg.STAGE1_G,
                                    map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.STAGE1_G)
        else:
            print("Please give the Stage1_G path")
            return

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1, max_objects=3):
        """Alternately update D and G for ``self.max_epoch`` epochs.

        Both learning rates are halved every ``cfg.TRAIN.LR_DECAY_EPOCH``
        epochs. Losses are logged and sample images saved whenever the batch
        index within an epoch is a multiple of 500, and again at the end of
        each epoch. Checkpoints are written every ``snapshot_interval`` epochs
        and once more after training.

        Args:
            data_loader: yields ``(images, bbox, label, txt_embedding)`` batches.
            stage: 1 loads the stage-I networks, anything else stage-II.
            max_objects: number of objects per layout.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        # Only trainable generator parameters are handed to the optimizer.
        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label, txt_embedding = data

                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    if cfg.STAGE == 1:
                        bbox = bbox.cuda()
                    elif cfg.STAGE == 2:
                        bbox = [bbox[0].cuda(), bbox[1].cuda()]
                    label = label.cuda()
                    txt_embedding = txt_embedding.cuda()

                if cfg.STAGE == 1:
                    bbox = bbox.view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices = compute_transformation_matrix(bbox)
                    transf_matrices = transf_matrices.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                elif cfg.STAGE == 2:
                    _bbox = bbox[0].view(-1, 4)
                    transf_matrices_inv = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv = transf_matrices_inv.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                    _bbox = bbox[1].view(-1, 4)
                    transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                        _bbox)
                    transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)
                    transf_matrices_s2 = compute_transformation_matrix(_bbox)
                    transf_matrices_s2 = transf_matrices_s2.view(
                        real_imgs.shape[0], max_objects, 2, 3)

                # one-hot encodings of the object labels (on label's device)
                label_one_hot = self._labels_to_one_hot(label, noise.shape[0],
                                                        max_objects)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()

                if cfg.STAGE == 1:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices, transf_matrices_inv,
                                                   mu, self.gpus)
                elif cfg.STAGE == 2:
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                                   real_labels, fake_labels,
                                                   label_one_hot, transf_matrices_s2, transf_matrices_inv_s2,
                                                   mu, self.gpus)
                # The generator graph is reused for the G update below.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                if cfg.STAGE == 1:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices,
                                                  transf_matrices_inv, mu,
                                                  self.gpus)
                elif cfg.STAGE == 2:
                    errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                                  label_one_hot,
                                                  transf_matrices_s2,
                                                  transf_matrices_inv_s2, mu,
                                                  self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        if cfg.STAGE == 1:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, label_one_hot)
                        elif cfg.STAGE == 2:
                            inputs = (txt_embedding, noise,
                                      transf_matrices_inv, transf_matrices_s2,
                                      transf_matrices_inv_s2, label_one_hot)
                        lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
            # End-of-epoch samples from the last batch's conditioning inputs.
            with torch.no_grad():
                if cfg.STAGE == 1:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              label_one_hot)
                elif cfg.STAGE == 2:
                    inputs = (txt_embedding, noise, transf_matrices_inv,
                              transf_matrices_s2, transf_matrices_inv_s2,
                              label_one_hot)
                lr_fake, fake, _, _, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                if lr_fake is not None:
                    save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        #
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self,
               datapath,
               num_samples=25,
               stage=1,
               draw_bbox=True,
               max_objects=3):
        """Save comparison strips for randomly chosen validation captions.

        Each strip holds the real validation image followed by ``num_fakes``
        generations for the same caption and layout (different noise), with
        the layout's boxes optionally outlined on every image. Strips are
        written to ``<cfg.NET_G up to '.pth'>_visualize_bbox/<caption>.png``.

        Args:
            datapath: prefix holding ``val_captions.t7`` (string-concatenated,
                so it should end with a separator) and ``filenames.pickle``;
                also passed to ``load_validation_data``.
            num_samples: number of strips to generate.
            stage: 1 for the 64 px stage-I generator, otherwise stage-II (256 px).
            draw_bbox: whether to outline the objects' bounding boxes.
            max_objects: number of objects per layout.
        """
        from PIL import Image
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle
        import torchvision
        import torchvision.utils as vutils
        img_dir = cfg.IMG_DIR
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath + "val_captions.t7")
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        label, bbox = load_validation_data(datapath)

        filepath = os.path.join(datapath, 'filenames.pickle')
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_visualize_bbox"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        # Stage II uses two box tensors: index 0 feeds transf_matrices_inv and
        # index 1 the *_s2 matrices. Build the pair on CPU runs too.
        if cfg.STAGE == 2:
            bbox = [bbox.clone(), bbox]
        if cfg.CUDA:
            if cfg.STAGE == 1:
                bbox = bbox.cuda()
            elif cfg.STAGE == 2:
                bbox = [bbox[0].cuda(), bbox[1].cuda()]
            label = label.cuda()

        #######################################
        # Unflattened copy of the boxes, used for drawing.
        if cfg.STAGE == 1:
            bbox_ = bbox.clone()
        elif cfg.STAGE == 2:
            bbox_ = bbox[0].clone()

        if cfg.STAGE == 1:
            bbox = bbox.view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(bbox)
            transf_matrices_inv = transf_matrices_inv.view(
                num_embeddings, max_objects, 2, 3)
        elif cfg.STAGE == 2:
            _bbox = bbox[0].view(-1, 4)
            transf_matrices_inv = compute_transformation_matrix_inverse(_bbox)
            transf_matrices_inv = transf_matrices_inv.view(
                num_embeddings, max_objects, 2, 3)

            _bbox = bbox[1].view(-1, 4)
            transf_matrices_inv_s2 = compute_transformation_matrix_inverse(
                _bbox)
            transf_matrices_inv_s2 = transf_matrices_inv_s2.view(
                num_embeddings, max_objects, 2, 3)
            transf_matrices_s2 = compute_transformation_matrix(_bbox)
            transf_matrices_s2 = transf_matrices_s2.view(
                num_embeddings, max_objects, 2, 3)

        # one-hot encodings of the object labels (on label's device)
        label_one_hot = self._labels_to_one_hot(label, num_embeddings,
                                                max_objects)
        #######################################

        nz = cfg.Z_DIM
        # Each strip shows the real image followed by this many generations.
        num_fakes = 9
        noise = Variable(torch.FloatTensor(num_fakes, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64 if stage == 1 else 256

        for _ in range(num_samples):
            index = int(np.random.randint(0, num_embeddings))
            key = filenames[index]
            img_name = img_dir + "/" + key + ".jpg"
            # LANCZOS is the filter formerly aliased as ANTIALIAS, which
            # Pillow 10 removed.
            img = Image.open(img_name).convert('RGB').resize((imsize, imsize),
                                                             Image.LANCZOS)
            val_image = torchvision.transforms.functional.to_tensor(img)
            val_image = val_image.view(1, 3, imsize, imsize)
            # [0, 1] -> [-1, 1]
            val_image = (val_image - 0.5) * 2

            # Repeat this caption's conditioning once per generated image.
            embeddings_batch = np.reshape(embeddings[index],
                                          (1, -1)).repeat(num_fakes, 0)
            transf_matrices_inv_batch = transf_matrices_inv[index].view(
                1, max_objects, 2, 3).repeat(num_fakes, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(
                1, max_objects, 81).repeat(num_fakes, 1, 1)

            if cfg.STAGE == 2:
                transf_matrices_s2_batch = transf_matrices_s2[index].view(
                    1, max_objects, 2, 3).repeat(num_fakes, 1, 1, 1)
                transf_matrices_inv_s2_batch = transf_matrices_inv_s2[index].view(
                    1, max_objects, 2, 3).repeat(num_fakes, 1, 1, 1)

            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                label_one_hot_batch = label_one_hot_batch.cuda()
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            if cfg.STAGE == 1:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          label_one_hot_batch)
            elif cfg.STAGE == 2:
                inputs = (txt_embedding, noise, transf_matrices_inv_batch,
                          transf_matrices_s2_batch,
                          transf_matrices_inv_s2_batch, label_one_hot_batch)
            with torch.no_grad():
                _, fake_imgs, mu, logvar, _ = nn.parallel.data_parallel(
                    netG, inputs, self.gpus)

            data_img = torch.FloatTensor(num_fakes + 1, 3, imsize,
                                         imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:] = fake_imgs

            if draw_bbox:
                for idx in range(max_objects):
                    x, y, w, h = tuple(
                        [int(imsize * v) for v in bbox_[index, idx]])
                    # A negative x presumably marks an unused (padding) slot.
                    if x <= -1:
                        break
                    # Keep the far edges inside the image so the index
                    # assignments below cannot go out of bounds.
                    w = min(w, imsize - 1 - x)
                    h = min(h, imsize - 1 - y)
                    data_img[:, :, y, x:x + w] = 1
                    data_img[:, :, y:y + h, x] = 1
                    data_img[:, :, y + h, x:x + w] = 1
                    data_img[:, :, y:y + h, x + w] = 1

            vutils.save_image(data_img,
                              '{}/{}.png'.format(save_dir,
                                                 captions_list[index]),
                              normalize=True,
                              nrow=num_fakes + 1)

        print("Saved {} files to {}".format(num_samples, save_dir))
示例#19
0
class RecurrentGANTrainer:
    """Train and evaluate a recurrent, text-conditioned GAN.

    The generator ``netG`` is called as ``netG(h0, txt_embeddings)`` and
    returns ``(fake_imgs, mus, logvars)``, where each is indexed per time-step.
    A single discriminator ``netD`` judges only the last generated image.
    All hyper-parameters are read from the global ``cfg``.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Create output folders, select GPUs and cache training settings.

        Args:
            output_dir: root folder; ``model``, ``image`` and ``log``
                sub-folders are created in it when ``cfg.TRAIN.FLAG`` is set.
            data_loader: iterable of batches that supports ``len()``.
            imsize: accepted for interface compatibility; unused here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'model')
            self.image_dir = os.path.join(output_dir, 'image')
            self.log_dir = os.path.join(output_dir, 'log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The global batch is split across all listed GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one loader batch in Variables, moving it to GPU if cfg.CUDA.

        Args:
            data: tuple ``(imgs, w_imgs, t_embedding, _, caption_tensors,
                len_vector)``. ``imgs`` and ``w_imgs`` are indexable sequences
                of image tensors. ``caption_tensors`` may be None.

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, vembedding, v_caption_tensors,
            v_len_vector)``. ``imgs`` is returned unchanged. The two caption
            outputs stay ``[]`` when ``caption_tensors`` is None.
        """
        imgs, w_imgs, t_embedding, _, caption_tensors, len_vector = data

        v_caption_tensors = []
        v_len_vector = []
        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
            if caption_tensors is not None:
                v_caption_tensors = Variable(caption_tensors).cuda()
                v_len_vector = len_vector.cuda()
        else:
            vembedding = Variable(t_embedding)
            if caption_tensors is not None:
                v_caption_tensors = Variable(caption_tensors)
                v_len_vector = len_vector

        for i in range(len(imgs)):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))

        return imgs, real_vimgs, wrong_vimgs, vembedding, v_caption_tensors, v_len_vector

    def train_Dnet(self, count):
        """Run one discriminator update against the generator's last output.

        Args:
            count: global iteration counter. Every 100th iteration the loss
                is logged to TensorBoard under ``D_loss0``.

        Returns:
            The discriminator loss tensor for this step.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.citerion

        netD = self.netD
        optD = self.optimizerD
        real_imgs = self.real_imgs[0]
        wrong_imgs = self.wrong_imgs[0]
        fake_imgs = self.fake_imgs[-1]  # only the final time-step is judged

        netD.zero_grad()

        # Label buffers are allocated for a full batch; trim for a short one.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        # Condition on the last step's mu. It is detached so the D loss
        # does not back-propagate into the generator.
        mu = self.mus[-1].detach()
        real_logits = netD(real_imgs, mu)
        wrong_logits = netD(wrong_imgs, mu)
        fake_logits = netD(fake_imgs.detach(), mu)

        # Conditional terms: real pairs -> 1, mismatched and fake pairs -> 0.
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            # Optional unconditional head, weighted by UNCOND_LOSS.
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                wrong_logits[1], fake_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                fake_logits[1], fake_labels)

            errD_real += errD_real_uncond
            errD_wrong += errD_wrong_uncond
            errD_fake += errD_fake_uncond

            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)

        errD.backward()
        optD.step()

        if flag == 0:
            # The original tag 'D_loss%d' was never formatted. 'D_loss0' names
            # the single discriminator and avoids clashing with the combined
            # 'D_loss' that train() logs at the same step.
            summary_D = summary.scalar('D_loss0', errD.item())
            self.summary_writer.add_summary(summary_D, count)

        return errD

    def train_Gnet(self, count):
        """Run one generator update over every generated time-step.

        The generator loss sums the adversarial loss of each step's image
        (plus the optional unconditional term) and the KL loss of each step,
        weighted by ``cfg.TRAIN.COEFF.KL``.

        Args:
            count: global iteration counter. Every 100th iteration the
                per-step adversarial losses are logged as ``G_loss<i>``.

        Returns:
            ``(kl_loss, errG_total)`` for this step.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)

        criterion = self.citerion

        mus, logvars = self.mus, self.logvars

        real_labels = self.real_labels[:batch_size]

        # Adversarial loss for each time-step: G wants D to say "real".
        for i in range(len(self.fake_imgs)):
            logits = self.netD(self.fake_imgs[i], mus[i])
            errG = criterion(logits[0], real_labels)
            if len(logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * criterion(
                    logits[1], real_labels)
                errG += errG_uncond
            errG_total += errG

            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        kl_loss = 0
        for i in range(len(self.fake_imgs)):
            kl_loss += KL_loss(mus[i], logvars[i]) * cfg.TRAIN.COEFF.KL

        errG_total += kl_loss

        # Back-propagate through all time-steps at once, then step.
        errG_total.backward()
        self.optimizerG.step()

        return kl_loss, errG_total

    def save_singleimages(self,
                          images,
                          filenames,
                          save_dir,
                          split_dir,
                          sentenceID,
                          imsize,
                          mean=0):
        """Write each image of a batch to its own PNG file.

        Files are named
        ``<save_dir>/single_samples/<split_dir>/<filename>_<mean>_<imsize>_sentence<sentenceID>.png``.
        Missing folders are created.

        Args:
            images: batch tensor indexed as ``images[i]`` and permuted
                ``(1, 2, 0)`` before saving. Presumably (N, C, H, W) with
                values in [-1, 1]; TODO confirm.
            filenames: per-image base names, indexed in step with ``images``.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i]+'_'+str(mean))
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # Map [-1, 1] to [0, 255] bytes, then move channels last for PIL.
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def train(self):
        """Run the full training loop.

        Each iteration does the following:
        - samples a fresh hidden state and generates images;
        - updates D, then G;
        - keeps an exponential moving average (decay 0.999) of the generator
          parameters.

        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` iterations it does the following:
        - saves the model;
        - renders samples with the averaged weights;
        - once more than 500 prediction batches have accumulated, logs the
          Inception score and NLPP means.
        """
        self.netG, self.netD, self.inception_model, start_count = load_network(
            self.gpus)

        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizerD = define_optimizers(
            self.netG, self.netD)

        self.citerion = nn.BCELoss()

        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))

        # h0 is re-sampled every step. h0_initalized stays fixed so that
        # snapshot images are comparable across checkpoints.
        h0 = Variable(
            torch.FloatTensor(self.batch_size, 1, cfg.HIDDEN_STATE_SIZE,
                              cfg.HIDDEN_STATE_SIZE))
        h0_initalized = Variable(
            torch.FloatTensor(self.batch_size, 1, cfg.HIDDEN_STATE_SIZE,
                              cfg.HIDDEN_STATE_SIZE).normal_(0, 1))

        if cfg.CUDA:
            self.citerion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            h0 = h0.cuda()
            h0_initalized = h0_initalized.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // self.num_batches
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                (self.imgs_tcpu, self.real_imgs, self.wrong_imgs,
                 self.txt_embeddings, self.caption_tensors,
                 self.len_vector) = self.prepare_data(data)

                # 1. Generate fake data from the generator.
                h0.data.normal_(0, 1)
                self.fake_imgs, self.mus, self.logvars = self.netG(
                    h0, self.txt_embeddings)

                # 2. Update the discriminator.
                errD_total = self.train_Dnet(count)

                # 3. Update the generator.
                kl_loss, errG_total = self.train_Gnet(count)

                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count += 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netD, count,
                               self.model_dir)

                    # Render samples with the averaged weights, then restore
                    # the live ones. This is inference only, so no graph.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    with torch.no_grad():
                        self.fake_imgs, _, _ = self.netG(h0_initalized,
                                                         self.txt_embeddings)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, count,
                                     self.image_dir, self.summary_writer)
                    load_params(self.netG, backup_para)

                    # Inception score over the accumulated predictions.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        score_summary = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(score_summary, count)

                        mean_nlpp, std_nlpp = negative_log_posterior_probability(
                            predictions, 10)
                        mean_nlpp_summary = summary.scalar(
                            'NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(
                            mean_nlpp_summary, count)

                        predictions = []
            end_t = time.time()
            print(
                '''[%d/%d][%d--%d] Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      ''' % (epoch, self.max_epoch, self.num_batches, count,
                             errD_total.item(), errG_total.item(),
                             kl_loss.item(), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netD, count, self.model_dir)
        self.summary_writer.close()

    def evaluate(self, split_dir):
        """Generate images for every batch of ``self.data_loader`` using the
        generator checkpoint at ``cfg.TRAIN.NET_G``.

        Images are written via ``save_singleimages`` under
        ``<checkpoint dir>/iteration<N>``. N is the integer between the last
        '_' and the last '.' of the checkpoint path. A ``'test'`` split is
        saved as ``'valid'``. If no checkpoint is configured, an error message
        is printed and nothing else happens.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: Could not find the saved Generator Model.')
            return

        if split_dir == 'test':
            split_dir = 'valid'

        # Build and load the generator.
        netG = Generator()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)

        state_dict = torch.load(cfg.TRAIN.NET_G,
                                map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Loaded weights to Generator Network.', cfg.TRAIN.NET_G)

        # Derive the output folder from the checkpoint file name.
        s_tmp = cfg.TRAIN.NET_G
        istart = s_tmp.rfind('_') + 1
        iend = s_tmp.rfind('.')
        iteration = int(s_tmp[istart:iend])
        s_tmp = s_tmp[:s_tmp.rfind('/')]
        save_dir = '%s/iteration%d' % (s_tmp, iteration)

        # NOTE(review): train() sizes h0 with cfg.HIDDEN_STATE_SIZE while this
        # uses cfg.INITIAL_IMAGE_SIZE; confirm the two are meant to agree.
        h0 = Variable(
            torch.FloatTensor(self.batch_size, 1, cfg.INITIAL_IMAGE_SIZE,
                              cfg.INITIAL_IMAGE_SIZE))

        if cfg.CUDA:
            netG.cuda()
            h0 = h0.cuda()

        netG.eval()
        # This is inference only, so autograd bookkeeping is skipped.
        with torch.no_grad():
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)

                # The last batch can be smaller than self.batch_size.
                batch_size = imgs[0].size(0)
                h0.data.normal_(0, 1)

                fake_imgs, _, _ = netG(h0[:batch_size], t_embeddings)
                self.save_singleimages(fake_imgs[-1], filenames, save_dir,
                                       split_dir, 1, 32)
示例#20
0
class GANTrainer(object):
    """Train and sample an unconditional multi-output GAN.

    One generator ``netG`` maps noise of size ``cfg.GAN.Z_DIM`` to a list of
    images. Generator output ``i`` is judged by discriminator ``netsD[i]``
    against real batch ``real_imgs[i]``, for ``num_Ds`` discriminators in
    total. All hyper-parameters are read from the global ``cfg``.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Create output folders, select GPUs and cache training settings.

        Args:
            output_dir: root folder; ``Model``, ``Image`` and ``Log``
                sub-folders are created in it when ``cfg.TRAIN.FLAG`` is set.
            data_loader: iterable of batches that supports ``len()``.
            imsize: accepted for interface compatibility; unused here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The global batch is split across all listed GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap the first ``self.num_Ds`` image tensors of a batch in Variables.

        Returns:
            ``(imgs, vimgs)``: the batch unchanged, and the wrapped copies,
            moved to GPU when ``cfg.CUDA`` is set.
        """
        imgs = data

        vimgs = []
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))

        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """Run one update of discriminator ``idx`` on real vs. generated images.

        Args:
            idx: index into ``netsD`` / ``real_imgs`` / ``fake_imgs``.
            count: global iteration counter. Every 100th iteration the loss
                is logged as ``D_loss<idx>``.

        Returns:
            The loss tensor of this discriminator.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        # Label buffers are allocated for a full batch; trim for a short one.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        netD.zero_grad()
        real_logits = netD(real_imgs)
        # Detach so this loss does not back-propagate into the generator.
        fake_logits = netD(fake_imgs.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        errD = errD_real + errD_fake
        errD.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one generator update against all discriminators.

        The loss is the sum of the per-output adversarial losses. When
        ``cfg.TRAIN.COEFF.COLOR_LOSS > 0``, it also includes color-consistency
        terms: MSE between the mean and covariance of an output and those of
        the preceding (detached) output.

        Args:
            count: global iteration counter. Every 100th iteration the loss
                terms are logged.

        Returns:
            The total generator loss tensor.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        # Adversarial loss: G wants every discriminator to say "real".
        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_G, count)

        # Color-preserving losses between consecutive outputs.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1

            # like_mu2 / like_cov2 only exist with at least two outputs.
            # Logging them unconditionally raised NameError when num_Ds == 1.
            if flag == 0 and self.num_Ds > 1:
                sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                self.summary_writer.add_summary(sum_mu, count)
                sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                self.summary_writer.add_summary(sum_cov, count)
                if self.num_Ds > 2:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Run the full training loop.

        Each iteration does the following:
        - samples fresh noise and generates images;
        - updates every discriminator, then the generator;
        - keeps an exponential moving average (decay 0.999) of the generator
          parameters.

        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` iterations it does the following:
        - saves the model;
        - renders ``fixed_noise`` samples with the averaged weights;
        - once more than 500 prediction batches have accumulated, logs the
          Inception score and NLPP means.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        # noise is re-sampled every step. fixed_noise keeps snapshot images
        # comparable across checkpoints.
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                # (0) Prepare training data.
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)

                # (1) Generate fake images.
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)

                # (2) Update every D network.
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                # (3) Update the G network: maximize log(D(G(z))).
                errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator parameters.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Collect predictions for the Inception score.
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f'''
                          % (epoch, self.max_epoch, step, self.num_batches,
                             errD_total.item(), errG_total.item()))
                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Render samples with the averaged weights, then restore
                    # the live ones. This is inference only, so no graph.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    with torch.no_grad():
                        self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    load_params(self.netG, backup_para)

                    # Inception score over the accumulated predictions.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)

                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)

                        predictions = []

            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save a whole batch as a single normalized image grid,
        ``<folder>/<startID>_<imsize>.png``."""
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save each image of a batch as ``<folder>/<startID + i>_<imsize>.png``.

        Args:
            images: batch tensor indexed as ``images[i]`` and permuted
                ``(1, 2, 0)`` before saving. Presumably (N, C, H, W) with
                values in [-1, 1]; TODO confirm.
        """
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # Map [-1, 1] to [0, 255] bytes, then move channels last for PIL.
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Sample images from the generator checkpoint at ``cfg.TRAIN.NET_G``.

        ``int(cfg.TEST.SAMPLE_NUM / batch_size)`` full batches are generated.
        The last generator output of each batch is written to
        ``<checkpoint dir>/iteration<N>/<split_dir>/super`` as one grid per
        batch when ``cfg.TEST.B_EXAMPLE`` is set. Otherwise each image goes
        to ``.../single`` as its own file. N is the integer between the last
        '_' and the last '.' of the checkpoint path. If no checkpoint is
        configured, an error message is printed and nothing else happens.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
            return

        # Build and load the generator.
        netG = G_NET()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)
        state_dict = \
            torch.load(cfg.TRAIN.NET_G,
                       map_location=lambda storage, loc: storage)
        netG.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_G)

        # Derive the output folder from the checkpoint file name.
        s_tmp = cfg.TRAIN.NET_G
        istart = s_tmp.rfind('_') + 1
        iend = s_tmp.rfind('.')
        iteration = int(s_tmp[istart:iend])
        s_tmp = s_tmp[:s_tmp.rfind('/')]
        save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
        if cfg.TEST.B_EXAMPLE:
            folder = '%s/super' % (save_dir)
        else:
            folder = '%s/single' % (save_dir)
        print('Make a new folder: ', folder)
        mkdir_p(folder)

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        if cfg.CUDA:
            netG.cuda()
            noise = noise.cuda()

        netG.eval()
        num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
        cnt = 0
        # This is inference only, so autograd bookkeeping is skipped.
        # range (not the Python-2-only xrange) works on both 2 and 3.
        with torch.no_grad():
            for _ in range(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                cnt += self.batch_size
示例#21
0
def fit(args, network, data_loader, batch_end_callback=None):
    """Train ``network`` with MXNet's ``FeedForward`` API.

    Arrays matching ``fc_backward_weight`` are also logged to TensorBoard
    (``./logs/``) as histograms tagged ``args.name``.

    Args:
        args: parsed command-line namespace. Reads ``kv_store``,
            ``model_prefix``, ``load_epoch``, ``save_model_prefix``,
            ``num_examples``, ``batch_size``, ``gpus``, ``init``,
            ``num_epochs``, ``lr`` and ``name``. Also reads, when present,
            ``log_file``/``log_dir``, ``lr_factor``/``lr_factor_epoch`` and
            ``clip_gradient``.
        network: the MXNet symbol to train.
        data_loader: callable ``(args, kv) -> (train_iter, val_iter)``.
        batch_end_callback: optional callback or list of callbacks run after
            every batch. A ``Speedometer`` is always appended. A caller's
            list is copied, not mutated.

    Raises:
        ValueError: if ``args.init`` is not 'uniform', 'normal' or 'xavier'.
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)

    # Batches per epoch. Floor division keeps this an integer on Python 3,
    # matching the original Python 2 behaviour of '/'.
    epoch_size = args.num_examples // args.batch_size

    # load model
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % (kv.rank)
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        model_args = {'arg_params' : tmp.arg_params,
                      'aux_params' : tmp.aux_params,
                      'begin_epoch' : args.load_epoch}
        # TODO: check epoch_size for 'dist_sync'
        model_args['begin_num_update'] = epoch_size * args.load_epoch

    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # train
    devs = [mx.cpu(i) for i in range(4)] if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    if args.kv_store == 'dist_sync':
        # Each worker only sees its share of the data.
        epoch_size //= kv.num_workers
        model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step = max(int(epoch_size * args.lr_factor_epoch), 1),
            factor = args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device ('==' not 'is': compare values)
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    if args.init == 'uniform':
        init = mx.init.Uniform(0.1)
    elif args.init == 'normal':
        init = mx.init.Normal(0,0.1)
    elif args.init == 'xavier':
        init = mx.init.Xavier(factor_type="in", magnitude=2.34)
    else:
        raise ValueError('unknown initializer: %r' % (args.init,))
    model = mx.model.FeedForward(
        ctx                = devs,
        symbol             = network,
        num_epoch          = args.num_epochs,
        learning_rate      = args.lr,
        momentum           = 0.9,
        wd                 = 0.00001,
        initializer        = init,
        **model_args)

    eval_metrics = ['accuracy']
    ## TopKAccuracy only allows top_k > 1
    for top_k in [5]:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k = top_k))

    if batch_end_callback is None:
        batch_end_callback = []
    elif isinstance(batch_end_callback, list):
        batch_end_callback = list(batch_end_callback)  # don't mutate caller's
    else:
        batch_end_callback = [batch_end_callback]
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    logdir = './logs/'
    summary_writer = FileWriter(logdir)

    def get_grad(g):
        """Log ``g`` as a TensorBoard histogram; return its RMS as the
        statistic the Monitor reports."""
        grad = g.asnumpy().flatten()
        s = summary.histogram(args.name, grad)
        summary_writer.add_summary(s)
        return mx.nd.norm(g)/np.sqrt(g.size)

    # Check arrays whose names match 'fc_backward_weight' (by name, the fc
    # weight gradient) once every int(num_examples / batch_size) batches.
    mon = mx.mon.Monitor(int(args.num_examples/args.batch_size), get_grad, pattern='fc_backward_weight')

    try:
        model.fit(
            X                  = train,
            eval_data          = val,
            eval_metric        = eval_metrics,
            kvstore            = kv,
            monitor            = mon,
            batch_end_callback = batch_end_callback,
            epoch_end_callback = checkpoint)
    finally:
        # Flush and release the event file even if training fails.
        summary_writer.close()
示例#22
0
class condGANTrainer(object):
    """Trainer for a multi-stage conditional GAN that is jointly optimised with
    a Show-Attend-and-Tell (SAT) captioning model.

    The SAT captioning loss computed on every generated resolution is added to
    the generator objective, and the SAT decoder (and optionally encoder) is
    updated from the same backward pass.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Prepare output directories, the summary writer and GPU settings.

        Args:
            output_dir: root directory; when cfg.TRAIN.FLAG is set the
                'Model', 'Image' and 'Log' sub-directories are created there
                and a FileWriter is opened on 'Log'.
            data_loader: loader yielding GAN batches (layout described in
                prepare_data); its length defines batches per epoch.
            imsize: accepted for interface compatibility; unused here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is per GPU.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap one GAN batch in Variables, moving them to GPU if cfg.CUDA.

        Args:
            data: tuple (imgs, w_imgs, t_embedding, _). imgs / w_imgs are
                indexed per discriminator stage (range(self.num_Ds)), so
                self.num_Ds must already be set (it is set in train()).

        Returns:
            (imgs, real_vimgs, wrong_vimgs, vembedding), where imgs is the
            untouched CPU input and the others are per-stage Variable lists
            plus the text-embedding Variable.
        """
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Conditional term: real images are scored against real_labels, and
        "wrong" and fake images against fake_labels. When the discriminator
        also returns an unconditional logit and UNCOND_LOSS > 0, weighted
        unconditional terms are added. In those terms wrong images are
        labelled real (presumably because they are genuine images paired
        with mismatched text). The loss is logged every 100 counts.

        Returns:
            The discriminator loss tensor (already back-propagated).
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Forward; the last batch may be smaller than self.batch_size.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # Detach mu / fake images so that no gradient reaches the generator.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Compute, but do not back-propagate, the generator loss.

        The loss sums the adversarial loss against every discriminator, plus
        optional unconditional terms and the colour-consistency losses
        between consecutive stages when cfg.TRAIN.COEFF.COLOR_LOSS > 0, plus
        the KL term. The caller adds further terms and calls backward().

        Returns:
            (kl_loss, errG_total)
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses between adjacent stages; the
        # lower-resolution stage is detached so it acts as a fixed target.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        # Backward is postponed: train() adds the SAT loss first.
        return kl_loss, errG_total

    def _build_sat_model(self, data_folder, data_name, fine_tune_encoder=False):
        """Create the SAT encoder/decoder and their optimizers (call once).

        Args:
            data_folder: folder holding 'WORDMAP_<data_name>.json'.
            data_name: base name shared by the caption data files.
            fine_tune_encoder: passed to encoder.fine_tune(); an encoder
                optimizer is only created when this is True.

        Returns:
            (encoder, decoder, encoder_optimizer, decoder_optimizer), where
            encoder_optimizer is None unless fine_tune_encoder is True.
        """
        from SATmodels import Encoder, DecoderWithAttention

        # The vocabulary size of the decoder comes from the word map.
        word_map_file = os.path.join(data_folder,
                                     'WORDMAP_' + data_name + '.json')
        with open(word_map_file, 'r') as j:
            word_map = json.load(j)

        decoder = DecoderWithAttention(attention_dim=512,
                                       embed_dim=512,
                                       decoder_dim=512,
                                       vocab_size=len(word_map),
                                       dropout=0.5)
        encoder = Encoder()
        if cfg.CUDA:
            decoder = decoder.cuda()
            encoder = encoder.cuda()
        encoder.fine_tune(fine_tune_encoder)

        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=4e-4)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=1e-4) if fine_tune_encoder else None
        return encoder, decoder, encoder_optimizer, decoder_optimizer

    def _compute_sat_loss(self, encoder, decoder, caps, caplens):
        """Sum the SAT captioning loss over every image in self.fake_imgs.

        For each generated resolution the image is encoded, then the
        ground-truth captions are decoded. Cross-entropy (self.SATcriterion)
        is computed on the packed, i.e. non-padded, time steps. The attention
        penalty ((1 - alphas.sum(dim=1)) ** 2).mean() is added with weight 1.

        Returns:
            The summed loss (a tensor when self.fake_imgs is non-empty).
        """
        from torch.nn.utils.rnn import pack_padded_sequence

        sat_loss = 0
        for fake_img in self.fake_imgs:
            img = encoder(fake_img)
            scores, caps_sorted, decode_lengths, alphas, _ = decoder(
                img, caps, caplens)
            # Targets skip the first caption token (presumably <start>).
            targets = caps_sorted[:, 1:]
            # `.data` keeps only the real (unpadded) steps. It works whether
            # PackedSequence has 2 fields (old PyTorch) or 4 (newer). Tuple
            # unpacking only works with 2 fields.
            scores = pack_padded_sequence(scores, decode_lengths,
                                          batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths,
                                           batch_first=True).data
            sat_loss += self.SATcriterion(scores, targets) + 1 * (
                (1. - alphas.sum(dim=1))**2).mean()
        return sat_loss

    def train(self):
        """Run the joint GAN + SAT training loop.

        Per step:
          (0) wrap the GAN batch;
          (1) generate fake images for every stage;
          (*) compute the SAT loss on them;
          (2) update each discriminator;
          (3) update the generator on GAN + KL + colour + SAT losses and
              step the SAT optimizers.
        Losses are logged every 100 steps. Models, sample images and
        inception / NLPP scores are written every cfg.TRAIN.SNAPSHOT_INTERVAL
        steps.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()
        self.SATcriterion = nn.CrossEntropyLoss()

        self.real_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = Variable(
            torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = Variable(
            torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        # Data parameters
        data_folder = 'birds_output'  # folder with data files saved by create_input_files.py
        data_name = 'CUB_5_cap_per_img_5_min_word_freq'  # base name shared by data files

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Show, Attend, and Tell dataloader; only its captions are used below.
        train_loader = torch.utils.data.DataLoader(
            CaptionDataset(data_folder,
                           data_name,
                           'TRAIN',
                           transform=transforms.Compose([normalize])),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=int(cfg.WORKERS),
            pin_memory=True)

        if cfg.CUDA:
            self.criterion.cuda()
            self.SATcriterion.cuda()  # Compute SATloss
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        # Build the SAT model once. It used to be re-created (and the word
        # map re-read) every iteration, so its freshly initialised weights
        # were thrown away right after each optimizer step and never trained.
        encoder, decoder, encoder_optimizer, decoder_optimizer = \
            self._build_sat_model(data_folder, data_name,
                                  fine_tune_encoder=False)

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            # zip() stops at the shorter of the GAN and caption loaders.
            for step, (data, sat_batch) in enumerate(
                    zip(self.data_loader, train_loader), 0):
                _, caps, caplens = sat_batch

                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)
                if cfg.CUDA:
                    # The SAT encoder/decoder were moved to the GPU, so the
                    # caption tensors must be as well.
                    caps, caplens = caps.cuda(), caplens.cuda()

                # Testing line for real samples
                if epoch == start_epoch and step == 0:
                    print('Checking real samples at first...')
                    save_real(self.imgs_tcpu, self.image_dir)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (*) Forward fake images to SAT
                ######################################################
                SATloss = self._compute_sat_loss(encoder, decoder,
                                                 caps, caplens)

                # Clear SAT gradients before the shared backward pass below.
                decoder_optimizer.zero_grad()
                if encoder_optimizer is not None:
                    encoder_optimizer.zero_grad()

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of generator weights.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Combine with G and SAT first, then back propagation
                errG_total += SATloss
                errG_total.backward()
                self.optimizerG.step()

                #######################################################
                # (*) Update SAT network
                ######################################################
                decoder_optimizer.step()
                if encoder_optimizer is not None:
                    encoder_optimizer.step()

                #######################################################
                # (*) Prediction and Inception score:
                ######################################################
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count += 1

                #######################################################
                # (*) Save Images/Log/Model per SNAPSHOT_INTERVAL:
                ######################################################
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images rendered with the averaged generator
                    # weights, then restore the live weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    self.fake_imgs, _, _ = self.netG(fixed_noise,
                                                     self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = negative_log_posterior_probability(
                            predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''

                  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  %
                  (epoch, self.max_epoch, self.num_batches, errD_total.item(),
                   errG_total.item(), kl_loss.item(), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir, split_dir,
                         imsize):
        """Save one grid image per sample, combining all of its sentences.

        images_list[j][i] is sample i generated from sentence j. Each is viewed
        as (1, 3, imsize, imsize), concatenated and written with
        vutils.save_image (10 per row, normalized) to
        '<save_dir>/super/<split_dir>/<filename>_<imsize>.png'.
        """
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = []
            for j in range(num_sentences):
                img = images_list[j][i]
                img = img.view(1, 3, imsize, imsize)
                super_img.append(img)
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize):
        """Save each image of a batch as its own PNG.

        Pixel values are mapped from [-1, 1] to [0, 255]. Files are written
        to '<save_dir>/single_samples/<split_dir>/<filename>_<imsize>_sentence<sentenceID>.png'.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate images from the generator checkpoint cfg.TRAIN.NET_G.

        The iteration number is parsed from the checkpoint file name (the text
        between its last '_' and last '.'). Output goes to
        '<checkpoint dir>/iteration<N>'. 'test' is mapped to the 'valid'
        split. For every batch and every text embedding, either the last-stage
        images are saved individually, or, when cfg.TEST.B_EXAMPLE is set, the
        third-stage images are collected and saved as per-sample grids.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # Load on CPU first; moved to GPU below when cfg.CUDA.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode
            netG.eval()
            for step, data in enumerate(self.data_loader, 0):
                imgs, t_embeddings, filenames = data
                if cfg.CUDA:
                    t_embeddings = Variable(t_embeddings).cuda()
                else:
                    t_embeddings = Variable(t_embeddings)

                # Dim 1 indexes the text embeddings (one per sentence) used
                # in the loop below.
                embedding_dim = t_embeddings.size(1)
                batch_size = imgs[0].size(0)
                noise.data.resize_(batch_size, nz)
                noise.data.normal_(0, 1)

                fake_img_list = []
                for i in range(embedding_dim):
                    fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                    if cfg.TEST.B_EXAMPLE:
                        fake_img_list.append(fake_imgs[2].data.cpu())
                    else:
                        self.save_singleimages(fake_imgs[-1], filenames,
                                               save_dir, split_dir, i, 256)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_img_list, filenames, save_dir,
                                          split_dir, 256)
示例#23
0
class GANTrainer(object):
    def __init__(self, output_dir):
        """Configure directories, the summary writer and GPU settings.

        Args:
            output_dir: root directory. When cfg.TRAIN.FLAG is set,
                'Model', 'Image' and 'Log' sub-directories are created under
                it and a FileWriter is opened on the 'Log' directory.
        """
        if cfg.TRAIN.FLAG:
            # Remember each output sub-directory path and make sure it exists.
            for attr_name, sub_dir in (('model_dir', 'Model'),
                                       ('image_dir', 'Image'),
                                       ('log_dir', 'Log')):
                path = os.path.join(output_dir, sub_dir)
                setattr(self, attr_name, path)
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        train_cfg = cfg.TRAIN
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL
        # Objects per image; train() reshapes per-object affine matrices to
        # (batch, max_objects, 2, 3) using this value.
        self.max_objects = 3

        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        self.batch_size = train_cfg.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the Stage-I generator and discriminator.

        Both networks are initialised with weights_init and printed. Weights
        are restored from cfg.NET_G / cfg.NET_D when those paths are
        non-empty, and both networks are moved to the GPU when cfg.CUDA.

        Returns:
            (netG, netD)
        """
        from model import STAGE1_G, STAGE1_D

        def load_on_cpu(ckpt_path):
            # Deserialize onto the CPU; device placement happens at the end.
            return torch.load(ckpt_path,
                              map_location=lambda storage, loc: storage)

        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            # Generator weights are stored under the "netG" key.
            netG.load_state_dict(load_on_cpu(cfg.NET_G)["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            # NOTE(review): unlike NET_G, the discriminator file is loaded as
            # a bare state dict; confirm this matches what save_model writes.
            netD.load_state_dict(load_on_cpu(cfg.NET_D))
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD


    def train(self, data_loader):
        """Train the Stage-I GAN conditioned on bounding boxes and labels.

        Args:
            data_loader: yields (real_img_cpu, bbox, label) batches. bbox is
                flattened to (-1, 4) and turned into per-object affine
                matrices of shape (batch, self.max_objects, 2, 3). label is
                used as a float one-hot conditioning tensor.

        Both learning rates are halved every cfg.TRAIN.LR_DECAY_EPOCH epochs.
        Losses and sample images are written every 500 batches. Sample images
        are also written at the end of every epoch. Models are saved every
        self.snapshot_interval epochs and once more at the end of training.
        """
        netG, netD = self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))

        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))

        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))

        print("Starting training...")
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, bbox, label = data

                real_imgs = Variable(real_img_cpu)
                # Defined on every path: it used to be assigned only inside
                # the CUDA branch, so CPU training raised NameError.
                label_one_hot = label.float()
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    bbox = bbox.cuda()
                    label_one_hot = label_one_hot.cuda()

                bbox = bbox.view(-1, 4)
                transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float()
                transf_matrices_inv = transf_matrices_inv.view(real_imgs.shape[0], self.max_objects, 2, 3)
                transf_matrices = compute_transformation_matrix(bbox).float()
                transf_matrices = transf_matrices.view(real_imgs.shape[0], self.max_objects, 2, 3)

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                _, fake_imgs = netG(noise, transf_matrices_inv, label_one_hot)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, transf_matrices, transf_matrices_inv, self.gpus)
                # Keep the graph: fake_imgs is reused for the generator loss.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, label_one_hot,
                                              transf_matrices, transf_matrices_inv, self.gpus)
                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (5) Log results
                ###########################
                count += 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

                    # save the image result for each epoch
                    self._save_samples(netG, noise, transf_matrices_inv,
                                       label_one_hot, real_img_cpu, epoch)
            # The helper never rebinds real_img_cpu, so the end-of-epoch save
            # pads the real images exactly once.
            self._save_samples(netG, noise, transf_matrices_inv,
                               label_one_hot, real_img_cpu, epoch)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()

    def _save_samples(self, netG, noise, transf_matrices_inv, label_one_hot,
                      real_img_cpu, epoch):
        """Generate and save sample images for the given conditioning.

        The generator runs through nn.parallel.data_parallel on self.gpus
        without gradient tracking. Real and fake images are padded with
        pad_imgs and saved via save_img_results under index ``epoch``. The
        generator's first output (lr_fake) is saved as well when it is not
        None.
        """
        with torch.no_grad():
            inputs = (noise, transf_matrices_inv, label_one_hot)
            lr_fake, fake = nn.parallel.data_parallel(netG, inputs, self.gpus)
            save_img_results(pad_imgs(real_img_cpu), pad_imgs(fake), epoch,
                             self.image_dir)
            if lr_fake is not None:
                save_img_results(None, lr_fake, epoch, self.image_dir)

    def sample(self, datapath, num_samples=25, draw_bbox=True, num_digits_per_img=3, change_bbox_size=False):
        """Generate and save visualisation grids from the stage-I generator.

        For each of ``num_samples`` randomly chosen validation entries, a
        2x10 grid is written to ``<save_dir>/vis_<k>.png``:
        - The first row holds the real validation image followed by nine
          generated images. All nine share that entry's bounding boxes and
          digit labels.
        - The second row holds the digit identities rendered as text.

        Args:
            datapath: dataset root. Images are read from
                ``<datapath>/normal/imgs/`` and file names from
                ``<datapath>/normal/filenames.pickle``.
            num_samples: number of grids to write.
            draw_bbox: if True, outline each bounding box on the real and
                generated images.
            num_digits_per_img: digit slots per image.
                - Fewer than 3 truncates the validation labels/boxes.
                - More than 3 pads them with random labels and random boxes.
            change_bbox_size: if True, scale the width/height of one randomly
                chosen box slot by a random factor in [0.5, 1) for every entry.
        """
        from PIL import Image, ImageDraw
        try:
            import cPickle as pickle  # Python 2
        except ImportError:  # Python 3: cPickle was folded into pickle
            import pickle
        import torchvision
        import torchvision.utils as vutils

        img_dir = os.path.join(datapath, "normal", "imgs/")
        netG, _ = self.load_network_stageI()
        netG.eval()

        label, bbox = load_validation_data(datapath)
        # Number of validation entries. It is taken from the data so that the
        # reshape of the transformation matrices below matches the set size.
        test_set_size = bbox.shape[0]

        if num_digits_per_img < 3:
            # Keep only the first `num_digits_per_img` digit slots.
            label = label[:, :num_digits_per_img, :]
            bbox = bbox[:, :num_digits_per_img, ...]
        elif num_digits_per_img > 3:
            # Pad the three validation digits with random extra digits.
            def get_one_hot(targets, nb_classes):
                res = np.eye(nb_classes)[np.array(targets).reshape(-1)]
                return res.reshape(list(targets.shape) + [nb_classes])

            num_extra = num_digits_per_img - 3
            labels_sample = np.random.randint(0, 10, size=(bbox.shape[0], num_extra))
            labels_sample = get_one_hot(labels_sample, 10)
            labels_new = np.zeros((label.shape[0], num_digits_per_img, 10))
            labels_new[:, :3, :] = label
            labels_new[:, 3:, :] = labels_sample
            label = torch.from_numpy(labels_new)

            # Random extra boxes. x, y are drawn from [0, 1). Width is 10-19
            # and height 16-19, both divided by 64.
            bboxes_x = np.random.random((bbox.shape[0], num_extra, 1))
            bboxes_y = np.random.random((bbox.shape[0], num_extra, 1))
            bboxes_w = np.random.randint(10, 20, size=(bbox.shape[0], num_extra, 1)) / 64.0
            bboxes_h = np.random.randint(16, 20, size=(bbox.shape[0], num_extra, 1)) / 64.0

            bbox_new_concat = np.concatenate((bboxes_x, bboxes_y, bboxes_w, bboxes_h), axis=2)
            bbox_new = np.zeros([bbox.shape[0], num_digits_per_img, 4])
            bbox_new[:, :3, :] = bbox
            bbox_new[:, 3:, :] = bbox_new_concat
            bbox = torch.from_numpy(bbox_new)

        if change_bbox_size:
            # Shrink one box slot (same slot for every entry); scale in [0.5, 1).
            bbox_idx = np.random.randint(0, bbox.shape[1])
            scale_x = np.random.random(bbox.shape[0])
            scale_x[scale_x < 0.5] = 0.5
            scale_y = np.random.random(bbox.shape[0])
            scale_y[scale_y < 0.5] = 0.5

            bbox[:, bbox_idx, 2] *= torch.from_numpy(scale_x)
            bbox[:, bbox_idx, 3] *= torch.from_numpy(scale_y)

        filepath = os.path.join(datapath, "normal", 'filenames.pickle')
        with open(filepath, 'rb') as f:
            filenames = pickle.load(f)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_samples_" + str(num_digits_per_img) + "_digits"
        if change_bbox_size:
            save_dir += "_change_bbox_size"
        print("Saving {} to {}:".format(num_samples, save_dir))
        mkdir_p(save_dir)
        if cfg.CUDA:
            bbox = bbox.cuda()
            label_one_hot = label.cuda().float()
        else:
            # Needed on the CPU path too: it is indexed inside the loop below.
            label_one_hot = label.float()

        # Per-box inverse transformation matrices (2x3 each), one set per entry.
        bbox_ = bbox.clone()  # un-flattened copy, used for drawing the boxes
        bbox = bbox.view(-1, 4)
        transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float()
        transf_matrices_inv = transf_matrices_inv.view(test_set_size, num_digits_per_img, 2, 3)

        nz = cfg.Z_DIM
        # Nine generated images per grid, all sharing one entry's layout.
        noise = Variable(torch.FloatTensor(9, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64

        for count in range(num_samples):
            index = np.random.randint(0, test_set_size)
            key = filenames[index].split("/")[-1]
            img_name = img_dir + key
            img = Image.open(img_name)
            val_image = torchvision.transforms.functional.to_tensor(img)
            # One-channel imsize x imsize real image, rescaled from [0, 1] to [-1, 1].
            val_image = val_image.view(1, 1, imsize, imsize)
            val_image = (val_image - 0.5) * 2

            transf_matrices_inv_batch = transf_matrices_inv[index].view(
                1, num_digits_per_img, 2, 3).repeat(9, 1, 1, 1)
            label_one_hot_batch = label_one_hot[index].view(
                1, num_digits_per_img, 10).repeat(9, 1, 1)

            #######################################################
            # (2) Generate fake images (inference only: no autograd graph)
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, transf_matrices_inv_batch, label_one_hot_batch, num_digits_per_img)
            with torch.no_grad():
                _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

            # Tiles 0-9: real image + 9 samples; tiles 10-19: label text.
            data_img = torch.FloatTensor(20, 1, imsize, imsize).fill_(0)
            data_img[0] = val_image
            data_img[1:10] = fake_imgs

            if draw_bbox:
                for idx in range(num_digits_per_img):
                    # Box in pixel units, shifted/shrunk until it fits the image.
                    x, y, w, h = tuple([int(imsize * v) for v in bbox_[index, idx]])
                    w = min(w, imsize - 1)
                    h = min(h, imsize - 1)
                    while x + w >= 64:
                        x -= 1
                        w -= 1
                    while y + h >= 64:
                        y -= 1
                        h -= 1
                    if x <= -1:
                        break
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y+h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1

            # write digit identities into image
            text_img = Image.new('L', (imsize*10, imsize), color = 'white')
            d = ImageDraw.Draw(text_img)
            digits = np.argmax(label_one_hot_batch[0].cpu().numpy(), axis=1)
            d.text((10,10), ", ".join([str(digits[j]) for j in range(num_digits_per_img)]))
            text_img = torchvision.transforms.functional.to_tensor(text_img)
            # Split the (imsize*10)-wide strip into ten imsize x imsize tiles.
            text_img = torch.chunk(text_img, 10, 2)
            text_img = torch.cat([text_img[i].view(1, 1, imsize, imsize) for i in range(10)], 0)
            data_img[10:] = text_img
            vutils.save_image(data_img, '{}/vis_{}.png'.format(save_dir, count), normalize=True, nrow=10)
        # Equals count + 1 of the last iteration, and is defined even when
        # num_samples == 0.
        print("Saved {} files to {}".format(num_samples, save_dir))
示例#24
0
class GANTrainer(object):
    """Trainer for a two-stage (stage I / stage II) text-conditioned GAN.

    Builds or loads the generator/discriminator pair for either stage, runs
    the adversarial training loop, and samples images from pre-computed text
    embeddings.
    """

    def __init__(self, output_dir):
        """Set up output directories, logging and GPU configuration.

        Args:
            output_dir: root directory. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                under it and a summary writer is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # cfg.TRAIN.BATCH_SIZE is scaled by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the stage-I generator and discriminator.

        Both are initialised with ``weights_init``. They are then overwritten
        from ``cfg.NET_G`` / ``cfg.NET_D`` when those paths are non-empty.

        Returns:
            ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            # Load onto the CPU, like every other checkpoint here. The .cuda()
            # call below moves the weights to the selected GPU. A hard-coded
            # 'cuda:0' would break CPU-only runs and ignore cfg.GPU_ID.
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the stage-II generator (wrapping a stage-I generator) and discriminator.

        Generator weights come from one of two sources:
        - ``cfg.NET_G``, for the full stage-II generator;
        - otherwise ``cfg.STAGE1_G``, loaded into the wrapped stage-I
          sub-module.

        Discriminator weights come from ``cfg.NET_D`` when it is non-empty.

        Returns:
            ``(netG, netD)``, moved to the GPU when ``cfg.CUDA`` is set.

        Raises:
            ValueError: if neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
                Returning None here would only surface later as an unpacking
                TypeError in the caller.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)
        else:
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Train the stage-``stage`` GAN for ``self.max_epoch`` epochs.

        Each batch from ``data_loader`` is unpacked as
        ``(real_images, text_embeddings)``. Per batch, the discriminator and
        then the generator take one optimiser step. Both learning rates are
        halved every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs.

        Every 100 batches, the losses are logged and images generated from a
        fixed noise vector are saved. Checkpoints are written every
        ``snapshot_interval`` epochs and once more at the end.
        """
        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # Fixed noise so that saved samples are comparable across epochs.
        fixed_noise = torch.FloatTensor(batch_size, nz).normal_(0, 1)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = \
            optim.Adam(netD.parameters(),
                       lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
        # Frozen generator parameters (requires_grad == False) are not optimised.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            num_batches = len(data_loader)
            print('Number of batches: ' + str(num_batches))
            for i, data in enumerate(data_loader, 0):

                print('Epoch number: ' + str(epoch) + '\tBatches: ' + str(i) + '/' + str(num_batches), end='\r')
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()

                ############################
                # (4) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs,
                                              real_labels, mu, self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    # errD / errG / kl_loss are scalar losses (backward() is
                    # called on them), so .item() yields a plain float.
                    summaries = [
                        summary.scalar('D_loss', errD.item()),
                        summary.scalar('D_loss_real', errD_real),
                        summary.scalar('D_loss_wrong', errD_wrong),
                        summary.scalar('D_loss_fake', errD_fake),
                        summary.scalar('G_loss', errG.item()),
                        summary.scalar('KL_loss', kl_loss.item()),
                    ]
                    for s in summaries:
                        self.summary_writer.add_summary(s, count)

                    # Save sample images from the fixed noise. This forward
                    # pass is for visualisation only, so no autograd graph.
                    with torch.no_grad():
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(
                                netG, (txt_embedding, fixed_noise), self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                  % (epoch, self.max_epoch, i, len(data_loader),
                     errD.item(), errG.item(), kl_loss.item(),
                     errD_real, errD_wrong, errD_fake, (end_t - start_t)))
            print('################EPOCH COMPLETED###########')
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)

        #
        save_model(netG, netD, self.max_epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self, datapath, stage=1):
        """Generate one image per caption embedding stored in ``datapath``.

        ``datapath`` is read with ``torchfile.load``:
        - ``raw_txt`` holds the captions;
        - ``fea_txt`` holds the text embeddings, concatenated along axis 0.

        Images are written as ``<n>.png`` into a directory named after
        ``cfg.NET_G`` with its ``.pth`` suffix removed. Generation stops once
        the start index of the next batch exceeds 3000. If the last window
        would run past the end, it is shifted back to stay a full batch.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)

        # Built-in min keeps batch_size a plain int (np.minimum would return
        # a NumPy scalar).
        batch_size = min(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if count > 3000:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                # Shift the final window back so it is still a full batch.
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # (2) Generate fake images (inference only: no autograd graph)
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            with torch.no_grad():
                _, fake_imgs, _, _ = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # Map the (presumably [-1, 1]) output to [0, 255]. Clip so
                # out-of-range values do not wrap around in the uint8 cast.
                im = np.clip((im + 1.0) * 127.5, 0, 255).astype(np.uint8)
                im = np.transpose(im, (1, 2, 0))  # CHW -> HWC for PIL
                im = Image.fromarray(im)
                im.save(save_name)
            count += batch_size
示例#25
0
class condGANTrainer(object):
    """Trainer for a text-conditioned GAN with one discriminator per generator output.

    The generator returns a list of images, one per discriminator in
    ``self.netsD``.

    ``train``:
    - updates every discriminator and then the generator;
    - keeps an exponential moving average of the generator weights;
    - periodically logs inception-score statistics.

    ``evaluate`` samples images for a data split from a saved generator.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Prepare output directories, GPU settings and the data loader.

        Args:
            output_dir: root directory. With ``cfg.TRAIN.FLAG`` set, the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                and a summary writer is opened on ``Log``.
            data_loader: iterable of training/evaluation batches.
            imsize: accepted for interface compatibility; unused here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # cfg.TRAIN.BATCH_SIZE is scaled by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Unpack a training batch and wrap its tensors as (CUDA) Variables.

        Args:
            data: ``(imgs, w_imgs, t_embedding, _)``. ``imgs`` and ``w_imgs``
                are indexable per discriminator (real and mismatched images).

        Returns:
            ``(imgs, real_vimgs, wrong_vimgs, vembedding)``: the unwrapped
            image list, the wrapped real and wrong images (one entry per
            discriminator), and the wrapped text embedding.
        """
        imgs, w_imgs, t_embedding, _ = data

        real_vimgs, wrong_vimgs = [], []
        if cfg.CUDA:
            vembedding = Variable(t_embedding).cuda()
        else:
            vembedding = Variable(t_embedding)
        for i in range(self.num_Ds):
            if cfg.CUDA:
                real_vimgs.append(Variable(imgs[i]).cuda())
                wrong_vimgs.append(Variable(w_imgs[i]).cuda())
            else:
                real_vimgs.append(Variable(imgs[i]))
                wrong_vimgs.append(Variable(w_imgs[i]))
        return imgs, real_vimgs, wrong_vimgs, vembedding

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Conditional logits: real images are labelled real; wrong and fake
        images are labelled fake. When the discriminator returns a second
        logit and ``cfg.TRAIN.COEFF.UNCOND_LOSS > 0``, an extra term is added
        for it. In that term wrong images count as real, and the three losses
        are summed without the 0.5 weighting.

        The loss is logged when ``count`` is a multiple of 100.

        Returns:
            The discriminator loss tensor.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu = self.criterion, self.mu

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        wrong_imgs = self.wrong_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        #
        netD.zero_grad()
        # Labels are allocated for a full batch; trim them to this batch.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]
        # Detach generator outputs so this step does not reach the generator.
        real_logits = netD(real_imgs, mu.detach())
        wrong_logits = netD(wrong_imgs, mu.detach())
        fake_logits = netD(fake_imgs.detach(), mu.detach())
        #
        errD_real = criterion(real_logits[0], real_labels)
        errD_wrong = criterion(wrong_logits[0], fake_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)
        if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
            errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(real_logits[1], real_labels)
            errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(wrong_logits[1], real_labels)
            errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
                criterion(fake_logits[1], fake_labels)
            #
            errD_real = errD_real + errD_real_uncond
            errD_wrong = errD_wrong + errD_wrong_uncond
            errD_fake = errD_fake + errD_fake_uncond
            #
            errD = errD_real + errD_wrong + errD_fake
        else:
            errD = errD_real + 0.5 * (errD_wrong + errD_fake)
        # backward
        errD.backward()
        # update parameters
        optD.step()
        # log (.item(): indexing a 0-dim tensor with [0] fails on PyTorch >= 0.5)
        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.item())
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one generator optimisation step against all discriminators.

        The total loss combines three parts:
        - the adversarial loss from each discriminator, plus its second
          (unconditional) logit when present and enabled;
        - optional colour-consistency terms comparing mean/covariance
          statistics of adjacent outputs;
        - the KL term.

        Returns:
            ``(kl_loss, errG_total)`` loss tensors.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion, mu, logvar = self.criterion, self.mu, self.logvar
        real_labels = self.real_labels[:batch_size]
        for i in range(self.num_Ds):
            outputs = self.netsD[i](self.fake_imgs[i], mu)
            errG = criterion(outputs[0], real_labels)
            if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
                errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
                    criterion(outputs[1], real_labels)
                errG = errG + errG_patch
            errG_total = errG_total + errG
            if flag == 0:
                summary_D = summary.scalar('G_loss%d' % i, errG.item())
                self.summary_writer.add_summary(summary_D, count)

        # Compute color consistency losses
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.item())
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.item())
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.item())
                    self.summary_writer.add_summary(sum_cov, count)

        kl_loss = KL_loss(mu, logvar) * cfg.TRAIN.COEFF.KL
        errG_total = errG_total + kl_loss
        errG_total.backward()
        self.optimizerG.step()
        return kl_loss, errG_total

    def train(self):
        """Run the full training loop, resuming from the loaded iteration count.

        Every ``cfg.TRAIN.SNAPSHOT_INTERVAL`` iterations:
        - a checkpoint is saved;
        - images are generated from fixed noise with the moving-average
          generator weights;
        - once more than 500 batches of predictions have accumulated, the
          inception-score and NLPP statistics are logged.
        """
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                #######################################################
                # (0) Prepare training data
                ######################################################
                self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                    self.txt_embedding = self.prepare_data(data)

                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                self.fake_imgs, self.mu, self.logvar = \
                    self.netG(noise, self.txt_embedding)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                kl_loss, errG_total = self.train_Gnet(count)
                # Exponential moving average of the generator weights
                # (avg = 0.999 * avg + 0.001 * current).
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(p.data, alpha=0.001)

                # For the inception score; evaluation only, no autograd graph.
                with torch.no_grad():
                    pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.item())
                    summary_G = summary.scalar('G_loss', errG_total.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                    # Save images generated with the averaged weights, then
                    # restore the live weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    with torch.no_grad():
                        self.fake_imgs, _, _ = \
                            self.netG(fixed_noise, self.txt_embedding)
                    save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                     count, self.image_dir, self.summary_writer)
                    load_params(self.netG, backup_para)

                    # Compute inception score
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)
                        #
                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)
                        #
                        predictions = []

            end_t = time.time()
            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Loss_KL: %.2f Time: %.2fs
                      '''  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.item(), errG_total.item(),
                     kl_loss.item(), end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames,
                         save_dir, split_dir, imsize):
        """Tile the images of all sentences for each sample into one PNG.

        ``images_list[j][i]`` is the image for sentence ``j`` of sample ``i``,
        reshaped to ``(1, 3, imsize, imsize)``. The tiles are saved with 10
        per row to ``<save_dir>/super/<split_dir>/<filename>_<imsize>.png``.
        """
        batch_size = images_list[0].size(0)
        num_sentences = len(images_list)
        for i in range(batch_size):
            s_tmp = '%s/super/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)
            #
            savename = '%s_%d.png' % (s_tmp, imsize)
            super_img = [images_list[j][i].view(1, 3, imsize, imsize)
                         for j in range(num_sentences)]
            super_img = torch.cat(super_img, 0)
            vutils.save_image(super_img, savename, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames,
                          save_dir, split_dir, sentenceID, imsize):
        """Save each image of a batch as its own PNG.

        Files go to
        ``<save_dir>/single_samples/<split_dir>/<filename>_<imsize>_sentence<sentenceID>.png``.
        """
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' %\
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d_sentence%d.png' % (s_tmp, imsize, sentenceID)
            # range from [-1, 1] to [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Generate images for every text embedding in ``self.data_loader``.

        Setup:
        - The generator is loaded from ``cfg.TRAIN.NET_G``. Its file name
          must end in ``_<iteration>.<ext>``.
        - Results go to ``<checkpoint dir>/iteration<iteration>``.
        - ``split_dir == 'test'`` is mapped to ``'valid'``.

        Output mode:
        - With ``cfg.TEST.B_EXAMPLE``, the third generator output of every
          sentence is tiled into one 256 px "super" image per sample.
        - Otherwise the last output is saved per sentence.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            if split_dir == 'test':
                split_dir = 'valid'
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # the path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d' % (s_tmp, iteration)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # switch to evaluate mode; inference only, so no autograd graph
            netG.eval()
            with torch.no_grad():
                for step, data in enumerate(self.data_loader, 0):
                    imgs, t_embeddings, filenames = data
                    if cfg.CUDA:
                        t_embeddings = Variable(t_embeddings).cuda()
                    else:
                        t_embeddings = Variable(t_embeddings)

                    # Number of sentence embeddings per sample.
                    embedding_dim = t_embeddings.size(1)
                    batch_size = imgs[0].size(0)
                    noise.data.resize_(batch_size, nz)
                    noise.data.normal_(0, 1)

                    fake_img_list = []
                    for i in range(embedding_dim):
                        fake_imgs, _, _ = netG(noise, t_embeddings[:, i, :])
                        if cfg.TEST.B_EXAMPLE:
                            fake_img_list.append(fake_imgs[2].data.cpu())
                        else:
                            self.save_singleimages(fake_imgs[-1], filenames,
                                                   save_dir, split_dir, i, 256)
                    if cfg.TEST.B_EXAMPLE:
                        self.save_superimages(fake_img_list, filenames,
                                              save_dir, split_dir, 256)
示例#26
0
    sys.exit()

# Use every visible GPU when CUDA is enabled in the config; the configured
# batch size is multiplied by the number of GPUs. `and` (not bitwise `&`)
# so a non-bool config value is interpreted by truthiness.
if torch.cuda.device_count() > 0 and config['cuda']:
    batch_size = config["batch_size"]
    batch_size *= torch.cuda.device_count()
    cuda = True
else:
    batch_size = 100
    cuda = False

#%% load data
print("Preparing data")
dir_data = 'logdir/%s/data' % (args.model_dir)

## train dataset
log_writer_train = FileWriter('%s/TB/train' % logdir)
fname_train = '%s/%s' % (dir_data, config['data_train'][0])
if os.path.exists(fname_train):
    if os.path.splitext(fname_train)[-1] == '.pkl':
        # Training data is a pickle file.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(fname_train, 'rb') as f:
            X_train = pickle.load(f)
    else:
        # !!! EXPERIMENTAL !!! fname_train is treated as a directory of .npy files.
        X_train = load_npy_files('%s/*.npy' % (fname_train))
else:
    print("Train data does not exist. Run data preparation. Exiting")
    sys.exit()

## val dataset
log_writer_val = FileWriter('%s/TB/val' % logdir)
示例#27
0
class GANTrainer(object):
    """Trainer for an unconditional GAN with ``num_Ds`` discriminators.

    The networks come from ``load_network`` in :meth:`train`. The first
    output of the generator is indexed per discriminator
    (``self.fake_imgs[idx]``). Every discriminator is trained with
    ``nn.BCELoss`` against all-ones (real) / all-zeros (fake) targets.
    """

    def __init__(self, output_dir, data_loader, imsize):
        """Create output folders, select GPUs and store the data loader.

        Args:
            output_dir: root directory. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                and a ``FileWriter`` is opened on ``Log``.
            data_loader: loader yielding training batches. Its length is used
                as the number of steps per epoch.
            imsize: accepted for interface compatibility; not used here.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        # cfg.GPU_ID is a comma-separated list of device ids; the first one
        # becomes the current CUDA device.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        # The configured batch size is multiplied by the number of GPUs.
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    def prepare_data(self, data):
        """Wrap each of the first ``self.num_Ds`` image tensors in a Variable.

        Returns:
            ``(imgs, vimgs)``: the batch as received and the list of
            ``Variable``s (moved to the GPU when ``cfg.CUDA`` is set).
        """
        imgs = data

        vimgs = []
        for i in range(self.num_Ds):
            if cfg.CUDA:
                vimgs.append(Variable(imgs[i]).cuda())
            else:
                vimgs.append(Variable(imgs[i]))

        return imgs, vimgs

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        Args:
            idx: index into ``self.netsD`` / ``self.optimizersD`` and into the
                per-discriminator image lists.
            count: global iteration. The loss is logged every 100 iterations.

        Returns:
            The discriminator loss (real + fake) after ``backward``/``step``.
        """
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        real_imgs = self.real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        # The label buffers are allocated for the full batch; slice them to
        # the actual (possibly shorter) batch.
        real_labels = self.real_labels[:batch_size]
        fake_labels = self.fake_labels[:batch_size]

        netD.zero_grad()
        real_logits = netD(real_imgs)
        # Detach so this step does not back-propagate into the generator.
        fake_logits = netD(fake_imgs.detach())

        errD_real = criterion(real_logits[0], real_labels)
        errD_fake = criterion(fake_logits[0], fake_labels)

        errD = errD_real + errD_fake
        errD.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
            self.summary_writer.add_summary(summary_D, count)
        return errD

    def train_Gnet(self, count):
        """Run one generator optimisation step against all discriminators.

        The generator loss is the sum of the BCE losses of every
        discriminator on the fakes with "real" targets. When
        ``cfg.TRAIN.COEFF.COLOR_LOSS > 0``, color-consistency terms are added
        between neighbouring generator outputs: mean and covariance MSE via
        ``compute_mean_covariance``, with the lower-resolution side detached.

        Args:
            count: global iteration. Losses are logged every 100 iterations.

        Returns:
            The total generator loss after ``backward``/``step``.
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 100
        batch_size = self.real_imgs[0].size(0)
        criterion = self.criterion
        real_labels = self.real_labels[:batch_size]

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            outputs = netD(self.fake_imgs[i])
            errG = criterion(outputs[0], real_labels)
            errG_total = errG_total + errG
            if flag == 0:
                summary_G = summary.scalar('G_loss%d' % i, errG.data[0])
                self.summary_writer.add_summary(summary_G, count)

        # Color-preserving losses. Each term is logged inside the branch that
        # defines it. Previously the *_mu2/*_cov2 terms were logged even when
        # num_Ds == 1, which raised a NameError.
        if cfg.TRAIN.COEFF.COLOR_LOSS > 0:
            if self.num_Ds > 1:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-1])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-2].detach())
                like_mu2 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov2 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu2 + like_cov2
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu2', like_mu2.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov2', like_cov2.data[0])
                    self.summary_writer.add_summary(sum_cov, count)
            if self.num_Ds > 2:
                mu1, covariance1 = compute_mean_covariance(self.fake_imgs[-2])
                mu2, covariance2 = \
                    compute_mean_covariance(self.fake_imgs[-3].detach())
                like_mu1 = cfg.TRAIN.COEFF.COLOR_LOSS * nn.MSELoss()(mu1, mu2)
                like_cov1 = cfg.TRAIN.COEFF.COLOR_LOSS * 5 * \
                    nn.MSELoss()(covariance1, covariance2)
                errG_total = errG_total + like_mu1 + like_cov1
                if flag == 0:
                    sum_mu = summary.scalar('G_like_mu1', like_mu1.data[0])
                    self.summary_writer.add_summary(sum_mu, count)
                    sum_cov = summary.scalar('G_like_cov1', like_cov1.data[0])
                    self.summary_writer.add_summary(sum_cov, count)

        errG_total.backward()
        self.optimizerG.step()
        return errG_total

    def train(self):
        """Run the full training loop from ``start_count`` to ``max_epoch``.

        Each step does the following:
          * samples noise and generates fakes;
          * updates every discriminator, then the generator;
          * updates an exponential moving average of the generator weights
            (decay 0.999);
          * collects inception-model predictions.

        Every ``self.snapshot_interval`` steps the model is saved. Images are
        then rendered from ``fixed_noise`` with the averaged weights. Once more
        than 500 prediction batches have been collected, the inception score
        and NLPP are also logged.
        """
        print("Try to load network")
        self.netG, self.netsD, self.num_Ds,\
            self.inception_model, start_count = load_network(self.gpus)
        print("Network loaded")
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.BCELoss()

        self.real_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(1))
        self.fake_labels = \
            Variable(torch.FloatTensor(self.batch_size).fill_(0))
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            self.real_labels = self.real_labels.cuda()
            self.fake_labels = self.fake_labels.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, data in enumerate(self.data_loader, 0):
                # (0) Prepare training data
                self.imgs_tcpu, self.real_imgs = self.prepare_data(data)

                # (1) Generate fake images
                noise.data.normal_(0, 1)
                self.fake_imgs, _, _ = self.netG(noise)

                # (2) Update D networks
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                # (3) Update G network: maximize log(D(G(z)))
                errG_total = self.train_Gnet(count)
                # Moving average of G's weights; used for checkpoints and
                # for the rendered snapshot images.
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                # Predictions for the inception score.
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())

                if count % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                if step == 0:
                    print('''[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f'''
                           % (epoch, self.max_epoch, step, self.num_batches,
                              errD_total.data[0], errG_total.data[0]))
                count = count + 1

                if count % self.snapshot_interval == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Render images with the averaged weights, then restore
                    # the live weights.
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    self.fake_imgs, _, _ = self.netG(fixed_noise)
                    save_img_results(self.imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    load_params(self.netG, backup_para)

                    # Inception score / NLPP once enough predictions exist.
                    if len(predictions) > 500:
                        predictions = np.concatenate(predictions, 0)
                        mean, std = compute_inception_score(predictions, 10)
                        m_incep = summary.scalar('Inception_mean', mean)
                        self.summary_writer.add_summary(m_incep, count)

                        mean_nlpp, std_nlpp = \
                            negative_log_posterior_probability(predictions, 10)
                        m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                        self.summary_writer.add_summary(m_nlpp, count)

                        predictions = []

            end_t = time.time()
            print('Total Time: %.2fsec' % (end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)

        self.summary_writer.close()

    def save_superimages(self, images, folder, startID, imsize):
        """Save a whole batch as one grid image ``<folder>/<startID>_<imsize>.png``."""
        fullpath = '%s/%d_%d.png' % (folder, startID, imsize)
        vutils.save_image(images.data, fullpath, normalize=True)

    def save_singleimages(self, images, folder, startID, imsize):
        """Save each image of the batch as ``<folder>/<startID + i>_<imsize>.png``.

        Pixel values are mapped from [-1, 1] to [0, 255] (truncated to bytes).
        """
        for i in range(images.size(0)):
            fullpath = '%s/%d_%d.png' % (folder, startID + i, imsize)
            # [-1, 1] -> [0, 1] -> [0, 255], clamped and cast to uint8.
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # CHW -> HWC for PIL.
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def evaluate(self, split_dir):
        """Sample images from the generator checkpoint ``cfg.TRAIN.NET_G``.

        ``int(cfg.TEST.SAMPLE_NUM / self.batch_size)`` batches are generated
        and written to ``<dir of NET_G>/iteration<N>/<split_dir>/super`` (one
        grid per batch) or ``.../single`` (one file per image), depending on
        ``cfg.TEST.B_EXAMPLE``. ``N`` is parsed from the checkpoint file name
        as the text between its last ``_`` and last ``.``.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            netG = G_NET()
            netG.apply(weights_init)
            netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
            print(netG)
            # Load onto CPU storage first, whatever device the checkpoint
            # was saved from.
            state_dict = \
                torch.load(cfg.TRAIN.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load ', cfg.TRAIN.NET_G)

            # The path to save generated images
            s_tmp = cfg.TRAIN.NET_G
            istart = s_tmp.rfind('_') + 1
            iend = s_tmp.rfind('.')
            iteration = int(s_tmp[istart:iend])
            s_tmp = s_tmp[:s_tmp.rfind('/')]
            save_dir = '%s/iteration%d/%s' % (s_tmp, iteration, split_dir)
            if cfg.TEST.B_EXAMPLE:
                folder = '%s/super' % (save_dir)
            else:
                folder = '%s/single' % (save_dir)
            print('Make a new folder: ', folder)
            mkdir_p(folder)

            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(self.batch_size, nz))
            if cfg.CUDA:
                netG.cuda()
                noise = noise.cuda()

            # Switch to evaluate mode
            netG.eval()
            num_batches = int(cfg.TEST.SAMPLE_NUM / self.batch_size)
            cnt = 0
            # `range` instead of the Python-2-only `xrange`.
            for step in range(num_batches):
                noise.data.normal_(0, 1)
                fake_imgs, _, _ = netG(noise)
                if cfg.TEST.B_EXAMPLE:
                    self.save_superimages(fake_imgs[-1], folder, cnt, 256)
                else:
                    self.save_singleimages(fake_imgs[-1], folder, cnt, 256)
                cnt += self.batch_size
示例#28
0
class condGANTrainer(object):
    def __init__(self, output_dir, label_loader, unlabel_loader):
        """Set up output folders, GPU selection and both data loaders.

        Args:
            output_dir: root directory. When ``cfg.TRAIN.FLAG`` is set, the
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                under it and a ``FileWriter`` is opened on ``Log``.
            label_loader: loader yielding labelled batches.
            unlabel_loader: loader yielding unlabelled batches. Its length is
                stored as ``self.num_batches``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir, self.image_dir, self.log_dir = [
                os.path.join(output_dir, sub)
                for sub in ('Model', 'Image', 'Log')
            ]
            # Half-width of the band used by piece_wise(); train() shrinks it.
            self.theta = 0.5
            for folder in (self.model_dir, self.image_dir, self.log_dir):
                mkdir_p(folder)
            self.summary_writer = FileWriter(self.log_dir)

        # Comma-separated device ids; the first becomes the current device.
        self.gpus = list(map(int, cfg.GPU_ID.split(',')))
        self.num_gpus = len(self.gpus)
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        train_cfg = cfg.TRAIN
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.label_loader = label_loader
        self.unlabel_loader = unlabel_loader
        self.num_batches = len(unlabel_loader)
        self.label_num_batches = len(label_loader)

    def prepare_data(self, data):
        """Unpack a loader batch and wrap its tensors in ``Variable``s.

        Args:
            data: ``(imgs, labels, label_vectors)``. ``imgs`` is indexed by
                discriminator, for ``i in range(self.num_Ds)``.

        Returns:
            ``(imgs, real_vimgs, labels, vembedding)``. These are the ``imgs``
            as received, the per-discriminator image Variables, the label
            Variable and the label-vector Variable. All Variables are moved to
            the GPU when ``cfg.CUDA`` is set.
        """
        imgs, labels, label_vectors = data
        on_gpu = cfg.CUDA

        def wrap(tensor):
            var = Variable(tensor)
            return var.cuda() if on_gpu else var

        vembedding = wrap(label_vectors)
        labels = wrap(labels)

        real_vimgs = []
        for scale in range(self.num_Ds):
            print("images_{}:{}".format(scale, imgs[scale].size()))
            real_vimgs.append(wrap(imgs[scale]))
        return imgs, real_vimgs, labels, vembedding

    def get_extend(self, labels):
        ll = torch.zeros(self.batch_size, 1)

        extend_label = torch.cat((labels, Variable(ll)), 1)
        return extend_label

    def piece_wise(self, logits, count):
        logits_1 = logits > (0.5 - self.theta)
        logits_1 = logits_1.type(torch.FloatTensor)
        if cfg.CUDA:
            logits_1 = logits_1.cuda()
        logits_2 = logits_1 * logits
        logits_3 = logits > (0.5 + self.theta)
        logits_3 = logits_3.type(torch.FloatTensor)
        if cfg.CUDA:
            logits_3 = logits_3.cuda()
        logits_4 = (logits_3 + logits_2) - (logits_3 * logits_2)
        #print(logits_4)
        return logits_4

    def log_sum_exp(self, value, dim=None, keepdim=True):
        """Numerically stable ``log(sum(exp(value)))``.

        The maximum is subtracted before exponentiating and added back
        afterwards.

        Args:
            value: input tensor/Variable.
            dim: dimension to reduce. ``None`` reduces over every element and
                ignores ``keepdim``.
            keepdim: keep the reduced dimension. Only the literal ``False``
                drops it.

        Returns:
            The log-sum-exp along ``dim``, or over all elements when
            ``dim is None``.
        """
        if dim is None:
            peak = torch.max(value)
            total = torch.sum(torch.exp(value - peak))
            if isinstance(total, Number):
                log_total = math.log(total)
            else:
                log_total = torch.log(total)
            return peak + log_total

        peak, _ = torch.max(value, dim=dim, keepdim=True)
        shifted_sum = torch.sum(torch.exp(value - peak), dim=dim,
                                keepdim=keepdim)
        if keepdim is False:
            peak = peak.squeeze(dim)
        return peak + torch.log(shifted_sum)

    def train_Dnet(self, idx, count):
        """Run one optimisation step of discriminator ``idx``.

        The loss is ``supervised + adversarial + hash``, where:

        * supervised: mean of the ``self.criterion`` losses of the labelled
          real images and of ``self.fake_imgs[idx]`` against the true labels,
          and of ``self.fake_imgs_2[idx]`` against ``self.error_labels``.
          Each fake batch is scored against the labels it was generated from.
        * adversarial: mean of ``-0.5*mean(lse) + 0.5*mean(softplus(lse))``
          on unlabelled real images and ``0.5*mean(softplus(lse))`` on each
          fake batch. ``lse`` is the per-sample log-sum-exp of the class
          logits.
        * hash: hinge ``mean(relu(1 + ||h_lab - h_fake||^2 - ||h_lab - h_fake2||^2))``
          on the third discriminator output.

        Args:
            idx: index into ``self.netsD`` / ``self.optimizersD`` and into the
                per-discriminator image lists.
            count: global iteration. Losses are logged every 25 iterations.

        Returns:
            The total discriminator loss (after ``backward`` and ``step``).
        """
        flag = count % 25
        batch_size = cfg.TRAIN.BATCH_SIZE
        print("batch_size:%d" % batch_size)
        criterion = self.criterion

        netD, optD = self.netsD[idx], self.optimizersD[idx]
        label_imgs = self.label_real_imgs[idx]
        unlabel_imgs = self.unlabel_real_imgs[idx]
        fake_imgs = self.fake_imgs[idx]
        fake_imgs_2 = self.fake_imgs_2[idx]

        netD.zero_grad()

        # .long() keeps the labels on their current device (prepare_data has
        # already moved them to the GPU when cfg.CUDA is set).
        lab_labels = self.labels[:batch_size].long()
        error_labels = self.error_labels[:batch_size]
        print("unlabel_imgs:{}".format(unlabel_imgs.size()))
        # Call order is kept (unlabelled, labelled, fake, fake2) since the
        # network may hold state such as BatchNorm statistics.
        unlabel_logits, _, _, _ = netD(unlabel_imgs)
        label_logits, _, label_hash_logits, _ = netD(label_imgs)
        # Fakes are detached so this step does not back-propagate into G.
        fake_logits, _, fake_hash_logits, _ = netD(fake_imgs.detach())
        fake2_logits, _, fake2_hash_logits, _ = netD(fake_imgs_2.detach())

        # Standard classification loss. fake_imgs_2 is scored with its own
        # logits against the wrong labels it was generated from.
        lab_loss = criterion(label_logits, lab_labels)
        fake_lab_loss = criterion(fake_logits, lab_labels)
        fake2_lab_loss = criterion(fake2_logits, error_labels)
        supvised_loss = (lab_loss + fake_lab_loss + fake2_lab_loss) / 3

        # GAN true/fake (adversarial) loss from per-sample log-sum-exp.
        unl_logsumexp = self.log_sum_exp(unlabel_logits, 1)
        fake_logsumexp = self.log_sum_exp(fake_logits, 1)
        fake2_logsumexp = self.log_sum_exp(fake2_logits, 1)

        true_loss = -0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(
            F.softplus(unl_logsumexp))
        fake_loss = 0.5 * torch.mean(F.softplus(fake_logsumexp))
        fake2_loss = 0.5 * torch.mean(F.softplus(fake2_logsumexp))
        adversary_loss = (true_loss + fake_loss + fake2_loss) / 3

        # Hash hinge loss (margin 1). relu(x) equals x * (x > 0).
        print("label_hash:{},fake_hash{}".format(label_hash_logits.size(),
                                                 fake_hash_logits.size()))
        positive = torch.sum((label_hash_logits - fake_hash_logits)**2, 1)
        negtive = torch.sum((label_hash_logits - fake2_hash_logits)**2, 1)
        hash_loss = torch.mean(F.relu(1 + positive - negtive))

        d_total_loss = supvised_loss + adversary_loss + hash_loss
        print("d_supervied_loss_{0}:{1}".format(idx, supvised_loss.data[0]))
        print("d_adversary_loss_{0}:{1}".format(idx, adversary_loss.data[0]))
        print("d_hash_loss_{0}:{1}".format(idx, hash_loss.data[0]))
        print("d_total_loss_{0}:{1}".format(idx, d_total_loss.data[0]))

        d_total_loss.backward()
        optD.step()

        if flag == 0:
            summary_D = summary.scalar('D_supervised%d' % idx,
                                       supvised_loss.data[0])
            summary_D1 = summary.scalar('D_hash_loss_%d' % idx,
                                        hash_loss.data[0])
            summary_D2 = summary.scalar('D_total_loss_%d' % idx,
                                        d_total_loss.data[0])
            summary_D3 = summary.scalar('D_adversary_loss_%d' % idx,
                                        adversary_loss.data[0])

            self.summary_writer.add_summary(summary_D, count)
            self.summary_writer.add_summary(summary_D1, count)
            self.summary_writer.add_summary(summary_D2, count)
            self.summary_writer.add_summary(summary_D3, count)
        return d_total_loss

    def train_Gnet(self, count):
        """Run one generator optimisation step against all discriminators.

        For each discriminator the objective is
        ``supervised - adversarial + hash``, computed on the non-detached
        fakes so gradients reach the generator. The adversarial term is
        maximised by the generator.

        * supervised: mean ``self.criterion`` loss of the labelled real images
          and of ``self.fake_imgs[i]`` against the true labels, and of
          ``self.fake_imgs_2[i]`` against ``self.error_labels``.
        * adversarial: the per-sample (``dim=1``) log-sum-exp real/fake loss.
        * hash: hinge ``mean(relu(1 + ||h_lab - h_fake||^2 - ||h_lab - h_fake2||^2))``.

        No color-consistency or KL term is used.

        Args:
            count: global iteration. Losses are logged every 25 iterations.

        Returns:
            The loss summed over discriminators (after ``backward``/``step``).
        """
        self.netG.zero_grad()
        errG_total = 0
        flag = count % 25
        batch_size = cfg.TRAIN.BATCH_SIZE
        print("batch_size:%d" % batch_size)
        criterion = self.criterion

        for i in range(self.num_Ds):
            netD = self.netsD[i]
            label_imgs = self.label_real_imgs[i]
            unlabel_imgs = self.unlabel_real_imgs[i]
            fake_imgs = self.fake_imgs[i]
            fake_imgs_2 = self.fake_imgs_2[i]

            # .long() keeps the labels on their current device.
            lab_labels = self.labels[:batch_size].long()
            error_labels = self.error_labels[:batch_size]
            label_logits, _, label_hash_logits, _ = netD(label_imgs)
            unlabel_logits, _, _, _ = netD(unlabel_imgs)
            # Fakes are NOT detached: gradients must reach the generator.
            fake_logits, _, fake_hash_logits, _ = netD(fake_imgs)
            fake2_logits, _, fake2_hash_logits, _ = netD(fake_imgs_2)

            # Standard classification loss. fake_imgs_2 is scored with its
            # own logits against the wrong labels it was generated from.
            lab_loss = criterion(label_logits, lab_labels)
            fake_lab_loss = criterion(fake_logits, lab_labels)
            fake2_lab_loss = criterion(fake2_logits, error_labels)
            supvised_loss = (lab_loss + fake_lab_loss + fake2_lab_loss) / 3

            # Adversarial loss from the per-sample log-sum-exp over the class
            # dimension. Without dim=1 the whole batch would collapse into a
            # single scalar.
            unl_logsumexp = self.log_sum_exp(unlabel_logits, 1)
            fake_logsumexp = self.log_sum_exp(fake_logits, 1)
            fake2_logsumexp = self.log_sum_exp(fake2_logits, 1)

            true_loss = -0.5 * torch.mean(unl_logsumexp) + 0.5 * torch.mean(
                F.softplus(unl_logsumexp))
            fake_loss = 0.5 * torch.mean(F.softplus(fake_logsumexp))
            fake2_loss = 0.5 * torch.mean(F.softplus(fake2_logsumexp))
            adversary_loss = (true_loss + fake_loss + fake2_loss) / 3

            # Hash hinge loss (margin 1). relu(x) equals x * (x > 0).
            positive = torch.sum((label_hash_logits - fake_hash_logits)**2, 1)
            negtive = torch.sum((label_hash_logits - fake2_hash_logits)**2, 1)
            hash_loss = torch.mean(F.relu(1 + positive - negtive))

            g_total_loss = supvised_loss - adversary_loss + hash_loss
            print("g_supervied_loss_{0}:{1}".format(i, supvised_loss.data[0]))
            print("g_adversary_loss_{0}:{1}".format(i, adversary_loss.data[0]))
            print("g_hash_loss_{0}:{1}".format(i, hash_loss.data[0]))
            print("g_total_loss_{0}:{1}".format(i, g_total_loss.data[0]))

            errG_total = errG_total + g_total_loss
            print("g_loss_%d: %f" % (i, errG_total.data[0]))

            if flag == 0:
                summary_D = summary.scalar('G_supervised%d' % i,
                                           supvised_loss.data[0])
                summary_D1 = summary.scalar('G_hash_loss_%d' % i,
                                            hash_loss.data[0])
                summary_D2 = summary.scalar('G_total_loss_%d' % i,
                                            g_total_loss.data[0])
                summary_D3 = summary.scalar('G_adversary_loss_%d' % i,
                                            adversary_loss.data[0])

                self.summary_writer.add_summary(summary_D, count)
                self.summary_writer.add_summary(summary_D1, count)
                self.summary_writer.add_summary(summary_D2, count)
                self.summary_writer.add_summary(summary_D3, count)

        # retain_graph=True: train() calls this method twice per step on the
        # same self.fake_imgs / self.fake_imgs_2 graph.
        errG_total.backward(retain_graph=True)
        self.optimizerG.step()
        return errG_total

    def get_error_label(self, labels):
        """Draw, for every sample, a class label different from its true one.

        Args:
            labels: indexable sequence/tensor of true class indices. The first
                ``cfg.TRAIN.BATCH_SIZE`` entries are used.

        Returns:
            ``(error_labels_vector, error_labels)``:
              * a one-hot ``FloatTensor`` of shape
                ``(cfg.TRAIN.BATCH_SIZE, cfg.GAN.CLASS_NUM)``;
              * a ``LongTensor`` of the drawn indices.
            Both are wrapped in ``Variable`` and moved to the GPU when
            ``cfg.CUDA`` is set.

        Raises:
            ValueError: if ``cfg.GAN.CLASS_NUM < 2``, since then no wrong
                label exists.
        """
        class_num = cfg.GAN.CLASS_NUM
        batch_size = cfg.TRAIN.BATCH_SIZE
        if class_num < 2:
            raise ValueError(
                'get_error_label needs at least 2 classes, got %d' % class_num)

        error_labels_vector = torch.FloatTensor(batch_size, class_num).zero_()
        error_labels = []
        for i in range(batch_size):
            true_label = int(labels[i])
            # Uniform over the class_num - 1 wrong classes without a retry
            # loop: draw from [0, class_num - 2] and skip over the true label.
            # Assumes `randint` has inclusive bounds (random.randint), as the
            # original randint(0, class_num - 1) call implies.
            m = randint(0, class_num - 2)
            if m >= true_label:
                m += 1

            error_labels_vector[i][m] = 1
            error_labels.append(m)

        error_labels = torch.LongTensor(error_labels)
        if cfg.CUDA:
            error_labels_vector = Variable(error_labels_vector).cuda()
            error_labels = Variable(error_labels).cuda()
        else:
            error_labels_vector = Variable(error_labels_vector)
            error_labels = Variable(error_labels)
        return error_labels_vector, error_labels

    def adjust_lr(self, optimizer, epoch):
        """Set every param group of ``optimizer`` to a linearly decayed rate.

        The rate equals ``cfg.TRAIN.DISCRIMINATOR_LR`` while
        ``epoch / cfg.TRAIN.MAX_EPOCH <= 2/3``. After that it falls linearly
        and reaches 0 at ``MAX_EPOCH``; the result is clamped at 0.
        """
        progress = float(epoch) / float(cfg.TRAIN.MAX_EPOCH)
        scale = min(3. * (1 - progress), 1.)
        new_lr = max(cfg.TRAIN.DISCRIMINATOR_LR * scale, 0)
        for group in optimizer.param_groups:
            group['lr'] = new_lr

    def train(self):
        self.netG, self.netsD, self.num_Ds, start_count = load_network(
            self.gpus)
        avg_param_G = copy_G_params(self.netG)

        self.optimizerG, self.optimizersD = \
            define_optimizers(self.netG, self.netsD)

        self.criterion = nn.CrossEntropyLoss()

        # self.real_labels = \
        #     Variable(torch.FloatTensor(self.batch_size).fill_(1))
        # self.fake_labels = \
        #     Variable(torch.FloatTensor(self.batch_size).fill_(0))

        self.gradient_one = torch.FloatTensor([1.0])
        self.gradient_half = torch.FloatTensor([0.5])

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        noise_2 = Variable(torch.FloatTensor(self.batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(self.batch_size, nz))#.normal_(0, 1))

        if cfg.CUDA:
            self.criterion.cuda()
            # self.real_labels = self.real_labels.cuda()
            # self.fake_labels = self.fake_labels.cuda()
            self.gradient_one = self.gradient_one.cuda()
            self.gradient_half = self.gradient_half.cuda()
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        predictions = []
        count = start_count
        start_epoch = start_count // (self.num_batches)
        label_iter = iter(self.label_loader)
        cnt = 0
        print("label_num_batch:{}".format(self.label_num_batches))
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            for step, unlabel_data in enumerate(self.unlabel_loader, 0):

                # print("data size:{0}".format(data))
                print("epoch:%d, step:%d" % (epoch, step))
                #######################################################
                # (0) Prepare training data
                ######################################################
                if cnt == self.label_num_batches:
                    label_iter = iter(self.label_loader)
                    cnt = 0
                cnt += 1
                label_data = label_iter.next()
                #print(label_data[0])
                self.label_imgs_tcpu, self.label_real_imgs, \
                self.labels, self.label_vectors = self.prepare_data(label_data)

                self.unlabel_imgs_tcpu, self.unlabel_real_imgs, \
                _, self.unlabel_vectors = self.prepare_data(unlabel_data)

                self.error_label_vector, self.error_labels = self.get_error_label(
                    label_data[1])
                #######################################################
                # (1) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                noise_2.data.normal_(0, 1)

                #
                self.fake_imgs = \
                    self.netG(noise, self.label_vectors)

                self.fake_imgs_2 = self.netG(noise_2, self.error_label_vector)

                #######################################################
                # (2) Update D network
                ######################################################
                errD_total = 0
                for i in range(self.num_Ds):
                    errD = self.train_Dnet(i, count)
                    errD_total += errD

                #######################################################
                # (3) Update G network: maximize log(D(G(z)))
                ######################################################
                self.train_Gnet(count)
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                errG_total = self.train_Gnet(count)
                for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)
                """
                # for inception score
                pred = self.inception_model(self.fake_imgs[-1].detach())
                predictions.append(pred.data.cpu().numpy())
                """

                if count % 25 == 0:
                    summary_D = summary.scalar('D_loss', errD_total.data[0])
                    summary_G = summary.scalar('G_loss', errG_total.data[0])
                    # summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_G, count)
                    # self.summary_writer.add_summary(summary_KL, count)

                count = count + 1

                if count % 3000 == 0:
                    self.theta = self.theta * 0.8
                    print("theta:%f, count:%d" % (self.theta, count))
                if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
                    save_model(self.netG, avg_param_G, self.netsD, count,
                               self.model_dir)
                    # Save images
                    backup_para = copy_G_params(self.netG)
                    load_params(self.netG, avg_param_G)
                    #
                    fixed_noise.data.normal_(0, 1)
                    self.fake_imgs = \
                        self.netG(fixed_noise, self.label_vectors)
                    save_img_results(self.label_imgs_tcpu, self.fake_imgs,
                                     self.num_Ds, count, self.image_dir,
                                     self.summary_writer)
                    #
                    load_params(self.netG, backup_para)

                    # # Compute inception score
                    # if len(predictions) > 500:
                    #     predictions = np.concatenate(predictions, 0)
                    #     mean, std = compute_inception_score(predictions, 10)
                    #     # print('mean:', mean, 'std', std)
                    #     m_incep = summary.scalar('Inception_mean', mean)
                    #     self.summary_writer.add_summary(m_incep, count)
                    #     #
                    #     mean_nlpp, std_nlpp = \
                    #         negative_log_posterior_probability(predictions, 10)
                    #     m_nlpp = summary.scalar('NLPP_mean', mean_nlpp)
                    #     self.summary_writer.add_summary(m_nlpp, count)
                    #     #
                    #     predictions = []

                print(
                    "____________________________________________________________"
                )
                end_t = time.time()

            self.adjust_lr(self.optimizerG, epoch)
            for i in range(cfg.TREE.BRANCH_NUM):
                self.adjust_lr(self.optimizersD[i], epoch)

            print('''[%d/%d][%d]
                         Loss_D: %.2f Loss_G: %.2f Time: %.2fs
                      '''

                  # D(real): %.4f D(wrong):%.4f  D(fake) %.4f
                  % (epoch, self.max_epoch, self.num_batches,
                     errD_total.data[0], errG_total.data[0], end_t - start_t))

        save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
        self.summary_writer.close()

    def save_superimages(self, images_list, filenames, save_dir, split_dir,
                         imsize):
        """Save one grid PNG per sample, stacking that sample's images.

        For each sample index ``k`` of the batch, ``images_list[j][k]`` is
        taken from every entry ``j`` of `images_list`, reshaped to
        ``(1, 3, imsize, imsize)``, and the pieces are concatenated along
        dim 0. The grid is written with
        ``vutils.save_image(..., nrow=10, normalize=True)`` to
        ``<save_dir>/super/<split_dir>/<filenames[k]>_<imsize>.png``.
        Missing parent folders are created with ``mkdir_p``.
        """
        n_samples = images_list[0].size(0)
        for k in range(n_samples):
            stem = '%s/super/%s/%s' % (save_dir, split_dir, filenames[k])
            parent = stem[:stem.rfind('/')]
            if not os.path.isdir(parent):
                print('Make a new folder: ', parent)
                mkdir_p(parent)

            out_path = '%s_%d.png' % (stem, imsize)
            tiles = [per_entry[k].view(1, 3, imsize, imsize)
                     for per_entry in images_list]
            grid = torch.cat(tiles, 0)
            vutils.save_image(grid, out_path, nrow=10, normalize=True)

    def save_singleimages(self, images, filenames, save_dir, split_dir,
                          sentenceID, imsize):
        """Write every image of `images` to its own PNG file.

        Pixel values are mapped from [-1, 1] to bytes via
        ``(x + 1) / 2 * 255`` clamped to [0, 255]. The result is permuted to
        (H, W, C) and saved with PIL to
        ``<save_dir>/single_samples/<split_dir>/<filename>_<imsize>_sentence<sentenceID>.png``.
        Missing parent folders are created with ``mkdir_p``.
        """
        for idx in range(images.size(0)):
            stem = '%s/single_samples/%s/%s' % (save_dir, split_dir,
                                                filenames[idx])
            parent = stem[:stem.rfind('/')]
            if not os.path.isdir(parent):
                print('Make a new folder: ', parent)
                mkdir_p(parent)

            out_path = '%s_%d_sentence%d.png' % (stem, imsize, sentenceID)
            # [-1, 1] -> [0, 255] as 8-bit values
            pixels = images[idx].add(1).div(2).mul(255).clamp(0, 255).byte()
            hwc = pixels.permute(1, 2, 0).data.cpu().numpy()
            Image.fromarray(hwc).save(out_path)

    def evaluate(self, split_dir):
        """Generate samples for every sentence embedding with a saved G.

        The generator checkpoint is read from ``cfg.TRAIN.NET_G``. Its file
        name must end in ``_<iteration>.<ext>``, because the iteration number
        is parsed from the text between the last ``_`` and the last ``.``.
        Images are written below ``<checkpoint dir>/iteration<iteration>``:

        * with ``cfg.TEST.B_EXAMPLE`` set, output index 2 of the generator is
          collected for every sentence and saved as one grid per sample;
        * otherwise the last generator output is saved per sample/sentence.

        Prints an error and returns if ``cfg.TRAIN.NET_G`` is empty.

        NOTE(review): ``self.data_loader`` is not assigned in the visible
        code; confirm it is set before calling this method.
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for morels is not found!')
            return

        if split_dir == 'image_test':
            split_dir = 'valid'

        # Build the generator and load its weights (mapped onto the CPU).
        netG = G_NET()
        netG.apply(weights_init)
        netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
        print(netG)
        netG.load_state_dict(
            torch.load(cfg.TRAIN.NET_G,
                       map_location=lambda storage, loc: storage))
        print('Load ', cfg.TRAIN.NET_G)

        # Output folder: <checkpoint dir>/iteration<N>.
        ckpt_path = cfg.TRAIN.NET_G
        iteration = int(ckpt_path[ckpt_path.rfind('_') + 1:ckpt_path.rfind('.')])
        save_dir = '%s/iteration%d' % (ckpt_path[:ckpt_path.rfind('/')],
                                       iteration)

        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(self.batch_size, nz))
        if cfg.CUDA:
            netG.cuda()
            noise = noise.cuda()

        netG.eval()
        for imgs, t_embeddings, filenames in self.data_loader:
            t_embeddings = Variable(t_embeddings)
            if cfg.CUDA:
                t_embeddings = t_embeddings.cuda()

            # Dim 1 of t_embeddings indexes the sentences of each sample.
            num_sentences = t_embeddings.size(1)
            noise.data.resize_(imgs[0].size(0), nz)
            noise.data.normal_(0, 1)

            collected = []
            for sent_idx in range(num_sentences):
                fake_imgs, _, _ = netG(noise, t_embeddings[:, sent_idx, :])
                if cfg.TEST.B_EXAMPLE:
                    collected.append(fake_imgs[2].data.cpu())
                else:
                    self.save_singleimages(fake_imgs[-1], filenames, save_dir,
                                           split_dir, sent_idx, 256)
            if cfg.TEST.B_EXAMPLE:
                self.save_superimages(collected, filenames, save_dir,
                                      split_dir, 256)

    def get_hash(self, dataloader, netsD, dataset_name, save_steps,
                 output_root="eval"):
        """Run every discriminator branch over `dataloader` and dump outputs.

        For each batch, ``netsD[i]`` is applied to ``real_imgs[i]`` and its
        third output (``hash_logits``) is accumulated per branch, together
        with the batch labels and the first-scale real images. Every
        `save_steps` batches, plus once more for any trailing batches that
        did not fill a full chunk, the accumulated tensors are written as
        ``.npy`` chunks numbered 1, 2, ... under ``<output_root>/<dataset_name>``:
        ``<name>_images_<k>.npy``, ``<name>_label_<k>.npy`` and
        ``branch_<i>_hash_<name>_<k>.npy``.

        Args:
            dataloader: iterable of batches accepted by ``self.prepare_data``.
            netsD: discriminators, one per branch (``cfg.TREE.BRANCH_NUM``).
            dataset_name: split tag used for the sub-folder and file names.
            save_steps: batches per saved chunk; truncated to an int and
                clamped to at least 1.
            output_root: parent folder of the split folder (default "eval").
        """
        # A float step count (len / 10 on Python 3) makes the modulo test
        # unreliable, and 0 would divide by zero.
        save_steps = max(1, int(save_steps))
        output_dir = os.path.join(output_root, dataset_name)
        print("len:%d" % len(dataloader))

        if not os.path.exists(output_dir):
            mkdir_p(output_dir)  # also creates a missing parent folder

        hash_dict = {}
        imgs_total = None
        label_total = None
        chunk = 0

        for step, data in enumerate(dataloader, 0):
            _imgs_tcpu, real_imgs, labels, _label_vectors = \
                self.prepare_data(data)

            for i in range(cfg.TREE.BRANCH_NUM):
                _, _, hash_logits, _ = netsD[i](real_imgs[i])
                # Only the values are kept; detach so the autograd graph of
                # every batch is not held alive while accumulating.
                hash_logits = hash_logits.detach()
                if i in hash_dict:
                    hash_dict[i] = torch.cat((hash_dict[i], hash_logits), 0)
                else:
                    hash_dict[i] = hash_logits

            if label_total is None:
                label_total = labels
                imgs_total = real_imgs[0]
            else:
                label_total = torch.cat((label_total, labels), 0)
                imgs_total = torch.cat((imgs_total, real_imgs[0]), 0)

            if (step + 1) % save_steps == 0:
                chunk += 1
                self._save_hash_chunk(output_dir, dataset_name, chunk,
                                      imgs_total, label_total, hash_dict)
                imgs_total = None
                label_total = None
                hash_dict = {}
            print("step %d done!" % step)

        # Flush trailing batches that did not fill a whole chunk; previously
        # they were silently dropped.
        if label_total is not None:
            chunk += 1
            self._save_hash_chunk(output_dir, dataset_name, chunk,
                                  imgs_total, label_total, hash_dict)

    def _save_hash_chunk(self, output_dir, dataset_name, chunk, imgs_total,
                         label_total, hash_dict):
        """Write one chunk of images, labels and per-branch hashes as .npy."""

        def to_numpy(x):
            return x.cpu().data.numpy() if cfg.CUDA else x.data.numpy()

        np.save(os.path.join(output_dir,
                             "%s_images_%d.npy" % (dataset_name, chunk)),
                to_numpy(imgs_total))
        np.save(os.path.join(output_dir,
                             "%s_label_%d.npy" % (dataset_name, chunk)),
                to_numpy(label_total))
        for i, hashes in sorted(hash_dict.items()):
            np.save(os.path.join(output_dir,
                                 "branch_%d_hash_%s_%d.npy" %
                                 (i, dataset_name, chunk)),
                    to_numpy(hashes))

    def load_Dnet(self, gpus, iteration=70000):
        """Build the discriminator branches and load their checkpoints.

        One discriminator is created per tree branch (``D_NET64``,
        ``D_NET128``, ``D_NET256``, ``D_NET512``, ``D_NET1024``, up to
        ``cfg.TREE.BRANCH_NUM``). Each is initialised with ``weights_init``
        and wrapped in ``torch.nn.DataParallel`` over `gpus`. Branch ``i`` is
        then loaded from ``'%snetD%d_%d.pth' % (cfg.TRAIN.NET_D, i, iteration)``.
        Also sets ``self.num_Ds``.

        Args:
            gpus: device ids passed to ``torch.nn.DataParallel``.
            iteration: checkpoint iteration in the file name (previously
                hard-coded to 70000, which remains the default).

        Returns:
            The list of loaded discriminators, moved to GPU if ``cfg.CUDA``.

        Exits the process with status -1 if ``cfg.TRAIN.NET_D`` is empty.
        """
        if cfg.TRAIN.NET_D == '':
            print('Error: the path for morels is not found!')
            sys.exit(-1)

        netsD = []
        if cfg.TREE.BRANCH_NUM > 0:
            netsD.append(D_NET64())
        if cfg.TREE.BRANCH_NUM > 1:
            netsD.append(D_NET128())
        if cfg.TREE.BRANCH_NUM > 2:
            netsD.append(D_NET256())
        if cfg.TREE.BRANCH_NUM > 3:
            netsD.append(D_NET512())
        if cfg.TREE.BRANCH_NUM > 4:
            netsD.append(D_NET1024())

        self.num_Ds = len(netsD)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
            netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)

        # Iterate over the networks actually built: a BRANCH_NUM above 5
        # would otherwise index past the end of netsD.
        for i, netD in enumerate(netsD):
            ckpt_path = '%snetD%d_%d.pth' % (cfg.TRAIN.NET_D, i, iteration)
            print('Load %s' % ckpt_path)  # log the file actually loaded
            state_dict = torch.load(
                ckpt_path, map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)

        if cfg.CUDA:
            for netD in netsD:
                netD.cuda()
        return netsD

    def get_numpy(self, files):
        """Load ``.npy`` files and concatenate them along axis 0.

        Args:
            files: sequence of paths accepted by ``np.load``; the arrays are
                concatenated in the given order. The sequence is printed.

        Returns:
            The concatenated array; the loaded array itself for a single
            file; ``None`` if `files` is empty.
        """
        print(files)
        # Load everything first and concatenate once. Concatenating inside
        # the loop re-copies the growing array for every file (O(n^2)).
        arrays = [np.load(f) for f in files]
        if not arrays:
            return None
        if len(arrays) == 1:
            return arrays[0]
        return np.concatenate(arrays, axis=0)

    def compute_MAP_sklearn(self,
                            test_features,
                            db_features,
                            test_label,
                            db_label,
                            metric='euclidean'):
        """Compute the mean average precision of distance-based retrieval.

        Each query row of `test_features` ranks all rows of `db_features` by
        ``cdist(..., metric)`` in ascending order. A database item counts as
        relevant when its label equals the query's label. Per-query AP comes
        from ``average_precision_score`` with scores that decrease along the
        ranking, and MAP is their mean. MAP is printed and written to
        ``<metric>_result.txt`` in the current working directory.

        Debug output: the precision/recall-curve shapes of the first query
        and slices of the last query's curves are printed.

        Args:
            test_features: 2-D array-like of query features (one per row).
            db_features: 2-D array-like of database features (one per row).
            test_label: sequence of query labels, indexed per query.
            db_label: numpy array of database labels; it is indexed with an
                integer array, so it must support fancy indexing.
            metric: metric name passed to ``scipy.spatial.distance.cdist``.

        Raises:
            ValueError: if there are no query rows.
        """
        test_num = np.shape(test_features)[0]
        if test_num == 0:
            raise ValueError("compute_MAP_sklearn needs at least one query")

        Y = cdist(test_features, db_features, metric)
        ind = np.argsort(Y, axis=1)
        prec_total = 0.0
        precision = recall = None
        for k in range(test_num):
            class_values = db_label[ind[k, :]]
            y_true = (test_label[k] == class_values)
            y_scores = np.arange(y_true.shape[0], 0, -1)
            prec_total += average_precision_score(y_true, y_scores)
            # The PR curve only feeds the debug prints, so compute it for the
            # first and last query instead of every query. This also keeps
            # `precision`/`recall` bound when there is a single query.
            if k == 0 or k == test_num - 1:
                precision, recall, _ = precision_recall_curve(y_true, y_scores)
                if k == 0:
                    print("precision_shape:{}".format(precision.shape))
                    print("recall_shape:{}".format(recall.shape))
        print(precision[5800:])
        print(recall[:20])
        print("test_num:{}".format(test_num))
        MAP = prec_total / test_num

        print("MAP: %f" % MAP)

        with open(metric + "_result.txt", 'w') as f:
            f.write("MAP:%f\n" % MAP)

    def compute_MAP(self, root, branch, query_labels, db_labels):
        """Evaluate retrieval MAP for one discriminator branch.

        Loads the ``branch_<branch>_*`` ``.npy`` files from ``<root>/test``
        (queries) and ``<root>/db`` (database), sorted by file name, and
        concatenates them. Then ``compute_MAP_sklearn`` runs twice: once on
        the raw outputs and once on the outputs binarised at 0.5.

        Args:
            root: folder containing the ``test`` and ``db`` sub-folders.
            branch: branch index used to select the files.
            query_labels: labels aligned with the concatenated query rows.
            db_labels: labels aligned with the concatenated database rows.

        Returns:
            ``(db_hashs, query_hashs)``: the raw (un-thresholded) arrays.

        Raises:
            ValueError: if no branch files are found or the label and hash
                counts differ.
        """
        # Trailing "_" so that branch 1 does not also match "branch_10_...".
        prefix = "branch_" + str(branch) + "_"

        def branch_files(split):
            names = sorted(os.listdir(os.path.join(root, split)))
            return [os.path.join(root, split, name) for name in names
                    if name.startswith(prefix)]

        query_hash_path = branch_files("test")
        db_hash_path = branch_files("db")
        query_hashs = self.get_numpy(query_hash_path)
        db_hashs = self.get_numpy(db_hash_path)
        if query_hashs is None or db_hashs is None:
            raise ValueError("no %s*.npy files found under %s/test or %s/db"
                             % (prefix, root, root))
        if db_labels.shape[0] != db_hashs.shape[0]:
            raise ValueError("db labels num must be equal to db hash num")
        if query_labels.shape[0] != query_hashs.shape[0]:
            raise ValueError("query labels num must be equal to query hash num")

        print(
            "-----------------------------------use features----------------------"
        )
        self.compute_MAP_sklearn(query_hashs, db_hashs, query_labels,
                                 db_labels)
        print(
            "-----------------------------------use hash--------------------------"
        )
        query_h = query_hashs > 0.5
        db_h = db_hashs > 0.5
        self.compute_MAP_sklearn(query_h, db_h, query_labels, db_labels)

        return db_hashs, query_hashs

    # query_labels = query_labels.tolist()
    # recall_num = 5900
    # MAP = 0
    # cnt = 0
    # precision_topk = [0 for i in range(len(query_labels) / 10 + 1)]
    # recall_curve = [0 for i in range(len(query_labels) / 25 + 1)]
    # precision_curve = [0 for i in range(len(query_labels) / 25 + 1)]

    # for i in range(query_labels.shape[0]):
    #     hamming_dis = 0.0
    #     for h in query_hashs:
    #         q = h[i]
    #         hamming_dis += np.sum((h - q) ** 2, axis=1)
    #     hamming_dis = hamming_dis.tolist()
    #     hamming_label = zip(query_labels, hamming_dis)
    #     sorted_by_hamming = sorted(hamming_label, key=lambda m: m[1])
    #
    #     smi = 0.0
    #     count = 0
    #     ap = 0
    #     for label, hamming in sorted_by_hamming:
    #         count += 1
    #         if label == query_labels[i]:
    #             smi += 1
    #             ap += smi / count
    #         if count % 10 == 0:
    #             precision_topk[count / 10] += smi / count
    #
    #         if count % 25 == 0:
    #             recall_curve[count / 25] += smi / recall_num
    #             precision_curve[count / 25] += smi / count
    #     #print("ap:%f"%(ap / smi))
    #     MAP += ap / smi
    #     cnt += 1
    #     if cnt % 100 == 0:
    #         print("has process %d queries" % cnt)
    #
    # MAP = MAP / len(query_labels)
    # precision_topk = [i / len(query_labels) for i in precision_topk]
    # recall_curve = [i / len(query_labels) for i in recall_curve]
    # precision_curve = [i / len(query_labels) for i in precision_curve]
    #
    # print("MAP_BRANCH_%d: %f" % (branch, MAP))
    # print("precision_top:{0}".format(precision_topk[0]))
    # print("recall curve:{0}".format(recall_curve[0]))
    # print("precision curve:{0}".format(precision_curve[0]))
    #
    # f = open("branch_%d_eval" % branch, 'w')
    # f.write("MAP: %f\n" % (MAP))
    # f.write("precision_top:\n")
    # self.write(f, precision_topk, "precision_top")
    # self.write(f, recall_curve, "recall_curve")
    # self.write(f, precision_curve, "precision_curve")
    # f.close()

    def compute_MAP_hash(self, root, paths, branch):
        """Compute MAP of binarised hashes, querying a set against itself.

        Args:
            root: folder the label/hash files are resolved against.
            paths: ``(query_label_path, query_hash_path)``. Each is a single
                file name or a list/tuple of file names, either relative to
                `root` or absolute.
            branch: branch index, used in the printed tag and output file.

        Hashes are binarised at 0.5. Every query is ranked against the whole
        set (itself included) by Hamming distance. The method reports MAP,
        precision at every 10th rank, and recall/precision at every 25th
        rank; recall assumes 100 relevant items per query. The results are
        printed and written to ``branch_<branch>_eval`` via ``self.write``.

        Raises:
            ValueError: if the label and hash counts differ or the set is
                empty.
        """
        query_label_path, query_hash_path = paths

        def resolve(p):
            # get_numpy takes a single list of full paths; os.path.join
            # leaves absolute entries unchanged.
            if not isinstance(p, (list, tuple)):
                p = [p]
            return [os.path.join(root, f) for f in p]

        query_labels = self.get_numpy(resolve(query_label_path))
        query_hash = self.get_numpy(resolve(query_hash_path))
        print("total test number:%d" % query_hash.shape[0])
        # Previously compared query_hash with itself, so it never fired.
        if query_hash.shape[0] != query_labels.shape[0]:
            raise ValueError("query hash size not equal to query label size")
        if query_hash.shape[0] == 0:
            raise ValueError("compute_MAP_hash needs at least one query")

        query_hash = (query_hash > 0.5) + 0
        query_labels = query_labels.tolist()
        recall_num = 100
        n = len(query_labels)

        # Hamming distance = #bits - (#agreeing ones + #agreeing zeros).
        agree = (np.matmul(query_hash, query_hash.T) +
                 np.matmul(1 - query_hash, (1 - query_hash).T))
        hamming_distance_list = (query_hash.shape[1] - agree).tolist()

        MAP = 0
        cnt = 0
        # Floor division: these are list sizes and indices (a bare `/` gives
        # a float on Python 3, which range() and indexing reject).
        precision_topk = [0] * (n // 10 + 1)
        recall_curve = [0] * (n // 25 + 1)
        precision_curve = [0] * (n // 25 + 1)

        for q_label, hamming_dis in zip(query_labels, hamming_distance_list):
            ranked = sorted(zip(query_labels, hamming_dis),
                            key=lambda m: m[1])

            smi = 0.0
            count = 0
            ap = 0
            for label, _hamming in ranked:
                count += 1
                if label == q_label:
                    smi += 1
                    ap += smi / count
                if count % 10 == 0:
                    precision_topk[count // 10] += smi / count

                if count % 25 == 0:
                    recall_curve[count // 25] += smi / recall_num
                    precision_curve[count // 25] += smi / count
            # smi >= 1: the query is part of the ranked set and matches itself.
            MAP += ap / smi
            cnt += 1
            if cnt % 100 == 0:
                print("has process %d queries" % cnt)

        MAP = MAP / n
        precision_topk = [i / n for i in precision_topk]
        recall_curve = [i / n for i in recall_curve]
        precision_curve = [i / n for i in precision_curve]

        print("MAP_BRANCH_%d: %f" % (branch, MAP))
        print("precision_top:{0}".format(precision_topk))
        print("recall curve:{0}".format(recall_curve))
        print("precision curve:{0}".format(precision_curve))

        with open("branch_%d_eval" % branch, 'w') as f:
            f.write("MAP: %f\n" % (MAP))
            f.write("precision_top:\n")
            self.write(f, precision_topk, "precision_top")
            self.write(f, recall_curve, "recall_curve")
            self.write(f, precision_curve, "precision_curve")

    def evaluate_MAP(self, db_dataloader, query_dataloader, root):
        """Compute retrieval MAP per branch and for branch-averaged outputs.

        If ``<root>/test`` or ``<root>/db`` is missing or empty, discriminator
        outputs are first extracted with ``get_hash``, in about 10 chunks per
        loader with at least one batch per chunk. Labels are loaded from the
        ``*label*`` files of each split. ``compute_MAP`` runs for every
        branch. Finally the branch outputs are averaged and evaluated as raw
        features (euclidean) and as 0.5-thresholded hashes (hamming).

        NOTE(review): ``get_hash`` writes below "eval/<split>"; the cached
        files are only found here when ``root`` is that same folder.
        """
        netsD = self.load_Dnet(self.gpus)
        test_dir = os.path.join(root, "test")
        db_dir = os.path.join(root, "db")

        # Floor division keeps save_steps an int >= 1: len / 10 is a float on
        # Python 3 and 0 for loaders with fewer than 10 batches on Python 2.
        if not os.path.isdir(test_dir) or len(os.listdir(test_dir)) == 0:
            save_steps = max(1, len(query_dataloader) // 10)
            self.get_hash(query_dataloader, netsD, "test", save_steps)
            print("get query image hash")

        if not os.path.isdir(db_dir) or len(os.listdir(db_dir)) == 0:
            save_steps = max(1, len(db_dataloader) // 10)
            self.get_hash(db_dataloader, netsD, "db", save_steps)
            print("get db image hash!")

        # Label chunks are sorted by file name so their order is deterministic.
        query_label_path = [
            os.path.join(test_dir, name)
            for name in sorted(os.listdir(test_dir)) if "label" in name
        ]
        db_label_path = [
            os.path.join(db_dir, name)
            for name in sorted(os.listdir(db_dir)) if "label" in name
        ]
        query_labels = self.get_numpy(query_label_path)
        db_labels = self.get_numpy(db_label_path)

        db_hashs = None
        query_hashs = None
        for i in range(cfg.TREE.BRANCH_NUM):
            print(
                "--------------------------branch %d-------------------------------------------"
                % i)
            db_hash, query_hash = self.compute_MAP(root, i, query_labels,
                                                   db_labels)

            if db_hashs is None:
                db_hashs = db_hash
                query_hashs = query_hash
            else:
                db_hashs += db_hash
                query_hashs += query_hash

        # Average the branch outputs.
        db_hashs /= cfg.TREE.BRANCH_NUM
        query_hashs /= cfg.TREE.BRANCH_NUM

        print(
            "--------------------------------total--------------------------------------------"
        )
        print(
            "-----------------------------------use features----------------------"
        )
        self.compute_MAP_sklearn(query_hashs, db_hashs, query_labels,
                                 db_labels)
        print(
            "-----------------------------------use hash--------------------------"
        )

        db_hashs = db_hashs > 0.5
        query_hashs = query_hashs > 0.5
        self.compute_MAP_sklearn(query_hashs,
                                 db_hashs,
                                 query_labels,
                                 db_labels,
                                 metric="hamming")
示例#29
0
class GANTrainer(object):
    """Train and sample a stage-I generator/discriminator pair.

    The networks come from ``model.STAGE1_G`` / ``model.STAGE1_D``. All
    settings are read from the global ``cfg``.
    """

    def __init__(self, output_dir):
        """Create output folders (training only) and select the GPUs.

        Args:
            output_dir: root folder. When ``cfg.TRAIN.FLAG`` is set,
                ``Model/``, ``Image/`` and ``Log/`` sub-folders are created
                and a ``FileWriter`` logs to ``Log/``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
        self.max_objects = 4

        # cfg.GPU_ID is a comma-separated list of integer device ids.
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        self.batch_size = cfg.TRAIN.BATCH_SIZE
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build G and D, optionally restoring them from checkpoints.

        G is restored from the ``"netG"`` entry of ``cfg.NET_G``; D is
        restored directly from ``cfg.NET_D``. Both are moved to GPU when
        ``cfg.CUDA`` is set.

        Returns:
            ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        netD = STAGE1_D()
        netD.apply(weights_init)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict["netG"])
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,  map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Run adversarial training for ``self.max_epoch`` epochs.

        Both networks use Adam with betas (0.5, 0.999). The learning rates
        are halved every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs. Losses go to
        the summary writer every 500 batches of an epoch. A sample image is
        saved each epoch, and checkpoints are written every
        ``self.snapshot_interval`` epochs and once more at the end.

        Args:
            data_loader: yields ``(images, (transf, transf_inv),
                label_one_hot, _)`` batches.
            stage: unused; kept for interface compatibility.
        """
        netG, netD = self.load_network_stageI()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))

        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                               requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH

        netG_para = []
        for p in netG.parameters():
            if p.requires_grad:
                netG_para.append(p)
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        print("Start training...")
        count = 0
        for epoch in range(self.max_epoch):
            start_t = time.time()
            # Step decay: halve both learning rates every lr_decay_step epochs.
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, transformation_matrices, label_one_hot, _ = data

                transf_matrices, transf_matrices_inv = tuple(
                    transformation_matrices)
                transf_matrices = transf_matrices.detach()
                transf_matrices_inv = transf_matrices_inv.detach()

                real_imgs = Variable(real_img_cpu)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    label_one_hot = label_one_hot.cuda()
                    transf_matrices = transf_matrices.cuda()
                    transf_matrices_inv = transf_matrices_inv.cuda()

                #######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (noise, transf_matrices_inv, label_one_hot)
                fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

                ############################
                # (3) Update D network
                ###########################
                netD.zero_grad()

                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               label_one_hot, transf_matrices, transf_matrices_inv, self.gpus)
                # retain_graph: the graph through fake_imgs is reused by the
                # generator loss below.
                errD.backward(retain_graph=True)
                optimizerD.step()
                ############################
                # (2) Update G network
                ###########################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels,
                                              label_one_hot, transf_matrices,
                                              transf_matrices_inv, self.gpus)

                errG_total = errG
                errG_total.backward()
                optimizerG.step()

                ############################
                # (3) Log results
                ###########################
                count = count + 1
                if i % 500 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)

                    # save the image result for each epoch
                    with torch.no_grad():
                        inputs = (noise, transf_matrices_inv, label_one_hot)
                        fake = nn.parallel.data_parallel(
                            netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
            # End-of-epoch sample, generated from the last batch's layout.
            with torch.no_grad():
                inputs = (noise, transf_matrices_inv, label_one_hot)
                fake = nn.parallel.data_parallel(netG, inputs, self.gpus)
                save_img_results(real_img_cpu, fake, epoch, self.image_dir)
            end_t = time.time()
            print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                  (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                   errG.item(), errD_real, errD_wrong, errD_fake,
                   (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, optimizerG, optimizerD, epoch,
                           self.model_dir)
        #
        save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir)
        #
        self.summary_writer.close()

    def sample(self,
               data_loader,
               num_samples=25,
               draw_bbox=True,
               max_objects=4):
        """Save visualisation grids for the first `num_samples` batches.

        Each grid has three parts:

        * row 1: the real image followed by 9 generated samples;
        * optionally, the object bounding boxes drawn in white on those 10
          tiles;
        * row 2: a caption naming each object's colour and shape.

        Files are written as ``vis_<k>.png`` into a folder derived from
        ``cfg.NET_G``.

        Args:
            data_loader: yields ``(image, (transf, transf_inv), label_one_hot,
                bbox)``. ``transf_inv`` and ``label_one_hot`` are reshaped to
                a leading batch of 1, so one sample per batch is assumed.
            num_samples: number of grids to save.
            draw_bbox: whether to outline the objects' bounding boxes.
            max_objects: number of object slots per image.
        """
        from PIL import Image, ImageDraw, ImageFont
        import pickle
        import torchvision
        import torchvision.utils as vutils
        netG, _ = self.load_network_stageI()
        netG.eval()

        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_samples_" + str(
            max_objects) + "_objects"
        print("saving to:", save_dir)
        mkdir_p(save_dir)

        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(9, nz))
        if cfg.CUDA:
            noise = noise.cuda()

        imsize = 64

        # for count in range(num_samples):
        count = 0
        for i, data in enumerate(data_loader, 0):
            if count == num_samples:
                break
            ######################################################
            # (1) Prepare training data
            ######################################################
            real_img_cpu, transformation_matrices, label_one_hot, bbox = data

            transf_matrices, transf_matrices_inv = tuple(
                transformation_matrices)
            transf_matrices_inv = transf_matrices_inv.detach()

            real_img = Variable(real_img_cpu)
            if cfg.CUDA:
                real_img = real_img.cuda()
                label_one_hot = label_one_hot.cuda()
                transf_matrices_inv = transf_matrices_inv.cuda()

            # Repeat the single layout 9 times, one per generated sample.
            transf_matrices_inv_batch = transf_matrices_inv.view(
                1, max_objects, 2, 3).repeat(9, 1, 1, 1)
            label_one_hot_batch = label_one_hot.view(1, max_objects,
                                                     13).repeat(9, 1, 1)

            #######################################################
            # (2) Generate fake images
            ######################################################
            noise.data.normal_(0, 1)
            inputs = (noise, transf_matrices_inv_batch, label_one_hot_batch)
            with torch.no_grad():
                fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus)

            # Slots 0-9: real + 9 fakes; slots 10-19: caption tiles.
            data_img = torch.FloatTensor(20, 3, imsize, imsize).fill_(0)
            data_img[0] = real_img
            data_img[1:10] = fake_imgs

            if draw_bbox:
                for idx in range(max_objects):
                    x, y, w, h = tuple([int(imsize * v) for v in bbox[0, idx]])
                    # Negative coordinates mark an empty object slot.
                    if x <= -1 or y <= -1:
                        break
                    # Keep the far edges (row y + h, column x + w) inside the
                    # canvas. Clamping w/h to imsize - 1 alone still indexed
                    # past the edge whenever x or y was positive.
                    w = min(w, imsize - 1 - x)
                    h = min(h, imsize - 1 - y)
                    data_img[:10, :, y, x:x + w] = 1
                    data_img[:10, :, y:y + h, x] = 1
                    data_img[:10, :, y + h, x:x + w] = 1
                    data_img[:10, :, y:y + h, x + w] = 1

            # write caption into image
            shape_dict = {0: "cube", 1: "cylinder", 2: "sphere", 3: "empty"}

            color_dict = {
                0: "gray",
                1: "red",
                2: "blue",
                3: "green",
                4: "brown",
                5: "purple",
                6: "cyan",
                7: "yellow",
                8: "empty"
            }
            text_img = Image.new('L', (imsize * 10, imsize), color='white')
            d = ImageDraw.Draw(text_img)
            label = label_one_hot_batch[0]
            label = label.cpu().numpy()
            # First 4 one-hot columns encode the shape, the rest the colour.
            label_shape = label[:, :4]
            label_color = label[:, 4:]
            label_shape = np.argmax(label_shape, axis=1)
            label_color = np.argmax(label_color, axis=1)
            label_combined = ", ".join([
                color_dict[label_color[_]] + " " + shape_dict[label_shape[_]]
                for _ in range(max_objects)
            ])
            d.text((10, 10), label_combined)
            text_img = torchvision.transforms.functional.to_tensor(text_img)
            # Split the caption strip into 10 imsize-wide tiles.
            text_img = torch.chunk(text_img, 10, 2)
            text_img = torch.cat(
                [text_img[k].view(1, 1, imsize, imsize) for k in range(10)], 0)
            data_img[10:] = text_img
            vutils.save_image(data_img,
                              '{}/vis_{}.png'.format(save_dir, count),
                              normalize=True,
                              nrow=10)
            count += 1

        print("Saved {} files to {}".format(count, save_dir))
示例#30
0
def train(iter_cnt, model, domain_d, corpus, args, optimizer_encoder,
          optimizer_domain_d):
    """Run one pass of task training combined with a domain-classifier loss.

    Pairs come from ``{args.train}.pos.txt`` / ``{args.train}.neg.txt``; the
    domain batches come from the ``nl.tsv.gz`` files next to ``args.train``
    and ``args.cross_train``. For every batch the combined loss
    ``task_loss - args.lambda_d * domain_loss`` is backpropagated and both
    optimizers are stepped. Every 100 iterations running averages are printed
    and written as TensorBoard scalars under ``args.run_path + "/train"``.

    Args:
        iter_cnt: global iteration counter at the start of this pass.
        model: encoder exposing ``embedding_layer``, ``compute_loss`` and
            ``compute_similarity``.
        domain_d: domain classifier exposing ``compute_loss``.
        corpus: corpus object forwarded to ``cross_pad_iter``.
        args: parsed command-line options.
        optimizer_encoder: optimizer stepped for the encoder.
        optimizer_domain_d: optimizer stepped for the domain classifier.

    Returns:
        The updated iteration counter.
    """
    train_writer = FileWriter(args.run_path + "/train", flush_secs=5)
    try:
        pos_file_path = "{}.pos.txt".format(args.train)
        neg_file_path = "{}.neg.txt".format(args.train)
        train_dir = os.path.dirname(args.train)

        # for adversarial training just use natural language portions of inputs
        train_corpus_path = train_dir + "/nl.tsv.gz"
        cross_train_corpus_path = (
            os.path.dirname(args.cross_train) + "/nl.tsv.gz")

        use_content = bool(args.use_content)

        pos_batch_loader = FileLoader([(pos_file_path, train_dir)],
                                      args.batch_size)
        neg_batch_loader = FileLoader([(neg_file_path, train_dir)],
                                      args.batch_size)
        cross_loader = TwoDomainLoader(
            [(train_corpus_path, os.path.dirname(train_corpus_path))],
            [(cross_train_corpus_path,
              os.path.dirname(cross_train_corpus_path))],
            args.batch_size * 2)

        embedding_layer = model.embedding_layer

        criterion1 = model.compute_loss
        criterion2 = domain_d.compute_loss

        start = time.time()
        task_loss = 0.0
        task_cnt = 0
        domain_loss = 0.0
        dom_cnt = 0
        total_loss = 0.0
        total_cnt = 0

        for batch, labels, domain_batch, domain_labels in tqdm(
                cross_pad_iter(corpus,
                               embedding_layer,
                               pos_batch_loader,
                               neg_batch_loader,
                               cross_loader,
                               use_content,
                               pad_left=False)):
            iter_cnt += 1

            if use_content:
                # flatten the nested per-side batches into one flat list so
                # they can be indexed as batch[0..3] below
                batch = [y for x in batch for y in x]
                domain_batch = list(domain_batch)

            if args.cuda:
                batch = [x.cuda() for x in batch]
                labels = labels.cuda()
                if not use_content:
                    domain_batch = domain_batch.cuda()
                else:
                    domain_batch = [x.cuda() for x in domain_batch]
                domain_labels = domain_labels.cuda()
            # list(...) keeps the result indexable on Python 3, where map()
            # returns a lazy iterator
            batch = list(map(Variable, batch))
            labels = Variable(labels)
            if not use_content:
                domain_batch = Variable(domain_batch)
            else:
                domain_batch = list(map(Variable, domain_batch))
            domain_labels = Variable(domain_labels)

            model.zero_grad()
            domain_d.zero_grad()

            if not use_content:
                repr_left = model(batch[0])
                repr_right = model(batch[1])
            else:
                repr_left = model(batch[0]) + model(batch[1])
                repr_right = model(batch[2]) + model(batch[3])
            output = model.compute_similarity(repr_left, repr_right)
            loss1 = criterion1(output, labels)
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar accessor;
            # newer versions need `.item()`.
            task_loss += loss1.data[0] * output.size(0)
            task_cnt += output.size(0)

            if not use_content:
                domain_output = domain_d(model(domain_batch))
            else:
                domain_output = domain_d(model(domain_batch[0])) + domain_d(
                    model(domain_batch[1]))
            loss2 = criterion2(domain_output, domain_labels)
            domain_loss += loss2.data[0] * domain_output.size(0)
            dom_cnt += domain_output.size(0)

            # NOTE(review): both optimizers step on this same combined loss, so
            # domain_d receives -lambda_d * d(loss2); confirm that domain_d or
            # its optimizer compensates if it is meant to minimize loss2.
            loss = loss1 - args.lambda_d * loss2
            total_loss += loss.data[0]
            total_cnt += 1
            loss.backward()
            optimizer_encoder.step()
            optimizer_domain_d.step()

            if iter_cnt % 100 == 0:
                outputManager.say("\r" + " " * 50)
                outputManager.say(
                    "\r{} tot_loss: {:.4f} task_loss: {:.4f} domain_loss: {:.4f} eps: {:.0f} "
                    .format(iter_cnt, total_loss / total_cnt,
                            task_loss / task_cnt, domain_loss / dom_cnt,
                            (task_cnt + dom_cnt) / (time.time() - start)))

                s = summary.scalar('total_loss', total_loss / total_cnt)
                train_writer.add_summary(s, iter_cnt)
                s = summary.scalar('domain_loss', domain_loss / dom_cnt)
                train_writer.add_summary(s, iter_cnt)
                s = summary.scalar('task_loss', task_loss / task_cnt)
                train_writer.add_summary(s, iter_cnt)

        outputManager.say("\n")
    finally:
        # close the event file even if training fails mid-way
        train_writer.close()

    return iter_cnt
示例#31
0
def evaluate(iter_cnt, filepath, model, corpus, args, logging=True):
    """Score positive and negative pairs and report AUC metrics.

    Pairs are read from ``{filepath}.neg.txt`` (loader id 0) and
    ``{filepath}.pos.txt`` (loader id 1). The model runs in eval mode, and
    its previous train/eval mode is restored before returning. AUC over the
    full curve and at FPR limits 0.1 / 0.05 / 0.02 / 0.01 is printed. When
    ``logging`` is true, it is also written as TensorBoard scalars under
    ``args.run_path + "/valid"`` at step ``iter_cnt``.

    Args:
        iter_cnt: step used for the TensorBoard scalars.
        filepath: path prefix of the ``.pos.txt`` / ``.neg.txt`` pair files.
        model: encoder exposing ``embedding_layer``, ``compute_similarity``
            and a ``criterion`` name string.
        corpus: object whose ``get`` maps pair ids to raw inputs.
        args: parsed command-line options.
        logging: whether to write TensorBoard scalars.

    Returns:
        The AUC restricted to FPR < 0.05.
    """
    pos_file_path = "{}.pos.txt".format(filepath)
    neg_file_path = "{}.neg.txt".format(filepath)
    # NOTE(review): the corpus directory is taken from args.eval rather than
    # from `filepath`; confirm this is intended when evaluating other splits.
    eval_dir = os.path.dirname(args.eval)
    pos_batch_loader = FileLoader([(pos_file_path, eval_dir)], args.batch_size)
    neg_batch_loader = FileLoader([(neg_file_path, eval_dir)], args.batch_size)

    def batchify(bch):
        return make_batch(model.embedding_layer, bch)

    # two inputs per pair normally, four when content is used
    n_inputs = 4 if args.eval_use_content else 2
    auc_meter = AUCMeter()
    scores = [np.asarray([], dtype='float32') for _ in range(2)]

    # remember the caller's mode so a following training pass is not
    # silently run in eval mode
    was_training = model.training
    model.eval()
    try:
        for loader_id, loader in tqdm(
                enumerate((neg_batch_loader, pos_batch_loader))):
            for data in tqdm(loader):
                # list(...) keeps the result indexable on Python 3
                data = list(map(corpus.get, data))
                if not args.eval_use_content:
                    batch = [batchify(data[0][0]), batchify(data[1][0])]
                else:
                    batch = [batchify(x)
                             for side in (data[0], data[1]) for x in side]
                labels = torch.ones(batch[0].size(1)).type(
                    torch.LongTensor) * loader_id
                if args.cuda:
                    batch = [x.cuda() for x in batch]
                    labels = labels.cuda()
                batch = tuple(Variable(x, volatile=True)
                              for x in batch[:n_inputs])
                labels = Variable(labels)
                if not args.eval_use_content:
                    repr_left = model(batch[0])
                    repr_right = model(batch[1])
                else:
                    repr_left = model(batch[0]) + model(batch[1])
                    repr_right = model(batch[2]) + model(batch[3])
                output = model.compute_similarity(repr_left, repr_right)

                if model.criterion.startswith('classification'):
                    assert output.size(1) == 2
                    output = nn.functional.log_softmax(output, dim=1)
                    current_scores = -output[:, loader_id].data.cpu().squeeze(
                    ).numpy()
                    output = output[:, 1]
                else:
                    assert output.size(1) == 1
                    current_scores = output.data.cpu().squeeze().numpy()
                auc_meter.add(output.data, labels.data)
                scores[loader_id] = np.append(scores[loader_id],
                                              current_scores)
    finally:
        model.train(was_training)

    auc_score = auc_meter.value()
    auc10_score = auc_meter.value(0.1)
    auc05_score = auc_meter.value(0.05)
    auc02_score = auc_meter.value(0.02)
    auc01_score = auc_meter.value(0.01)
    if model.criterion.startswith('classification'):
        avg_score = (scores[1].mean() + scores[0].mean()) * 0.5
    else:
        avg_score = scores[1].mean() - scores[0].mean()
    outputManager.say(
        "\r[{}] auc(.01): {:.3f}  auc(.02): {:.3f}  auc(.05): {:.3f}"
        "  auc(.1): {:.3f}  auc: {:.3f}"
        "  scores: {:.2f} ({:.2f} {:.2f})\n".format(
            os.path.basename(filepath).split('.')[0], auc01_score, auc02_score,
            auc05_score, auc10_score, auc_score, avg_score, scores[1].mean(),
            scores[0].mean()))

    if logging:
        # opened only when needed and always closed, even if a write fails
        valid_writer = FileWriter(args.run_path + "/valid", flush_secs=5)
        try:
            for tag, value in (('auc', auc_score),
                               ('auc (fpr<0.1)', auc10_score),
                               ('auc (fpr<0.05)', auc05_score),
                               ('auc (fpr<0.02)', auc02_score),
                               ('auc (fpr<0.01)', auc01_score)):
                valid_writer.add_summary(summary.scalar(tag, value), iter_cnt)
        finally:
            valid_writer.close()

    return auc05_score
示例#32
0
def train():
    '''Train a weight-clipping WGAN on 64x64 images.

    Settings come from the module-level ``args`` (gpus, batch_size, z_dim,
    lr, epoches, wclip, frequency, model_prefix, data_path, ngf, ndf, nc).
    Each iteration clips the critic weights to [-wclip, wclip], updates the
    critic on fake + real batches, then updates the generator through the
    critic's input gradients. After every epoch, G/D checkpoints are saved
    under ``model/`` and sample grids under ``tmp/``. When ``use_tb`` is
    set, the grids are also logged to TensorBoard at ``tmp/exp``.
    '''
    ctxs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    batch_size = args.batch_size
    z_dim = args.z_dim
    lr = args.lr
    epoches = args.epoches
    wclip = args.wclip
    frequency = args.frequency
    model_prefix = args.model_prefix
    rand_iter = RandIter(batch_size, z_dim)
    image_iter = ImageIter(args.data_path, batch_size, (3, 64, 64))
    # G and D
    symG, symD = dcgan64x64(ngf=args.ngf, ndf=args.ndf, nc=args.nc)
    modG = mx.mod.Module(symbol=symG, data_names=('rand',), label_names=None, context=ctxs)
    modG.bind(data_shapes=rand_iter.provide_data)
    modG.init_params(initializer=mx.init.Normal(0.002))
    modG.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    modD = mx.mod.Module(symbol=symD, data_names=('data',), label_names=None, context=ctxs)
    # inputs_need_grad so D's input gradients can be fed back into G
    modD.bind(data_shapes=image_iter.provide_data,
              inputs_need_grad=True)
    modD.init_params(mx.init.Normal(0.002))
    modD.init_optimizer(
        optimizer='sgd',
        optimizer_params={
            'learning_rate': lr,
        })
    # train
    logging.info('Start training')
    metricD = WGANMetric()
    metricG = WGANMetric()
    fix_noise_batch = mx.io.DataBatch([mx.random.normal(0, 1, shape=(batch_size, z_dim, 1, 1))], [])
    # visualization with TensorBoard if possible; None when disabled so the
    # writer is only flushed/closed when it actually exists
    writer = FileWriter('tmp/exp') if use_tb else None
    try:
        for epoch in range(epoches):
            image_iter.reset()
            metricD.reset()
            metricG.reset()
            for i, batch in enumerate(image_iter):
                # clip weight
                for params in modD._exec_group.param_arrays:
                    for param in params:
                        mx.nd.clip(param, -wclip, wclip, out=param)
                # forward G
                rbatch = rand_iter.next()
                modG.forward(rbatch, is_train=True)
                outG = modG.get_outputs()
                # fake
                modD.forward(mx.io.DataBatch(outG, label=[]), is_train=True)
                fw_g = modD.get_outputs()[0].asnumpy()
                modD.backward([mx.nd.ones((batch_size, 1)) / batch_size])
                # keep a copy of the fake-batch gradients; the real-batch
                # backward below overwrites the gradient buffers
                gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]
                # real
                modD.forward(batch, is_train=True)
                fw_r = modD.get_outputs()[0].asnumpy()
                modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
                for grads_real, grads_fake in zip(modD._exec_group.grad_arrays, gradD):
                    for grad_real, grad_fake in zip(grads_real, grads_fake):
                        grad_real += grad_fake
                modD.update()
                errorD = -(fw_r - fw_g) / batch_size
                metricD.update(errorD.mean())
                # update G
                rbatch = rand_iter.next()
                modG.forward(rbatch, is_train=True)
                outG = modG.get_outputs()
                modD.forward(mx.io.DataBatch(outG, []), is_train=True)
                errorG = -modD.get_outputs()[0] / batch_size
                modD.backward([-mx.nd.ones((batch_size, 1)) / batch_size])
                modG.backward(modD.get_input_grads())
                modG.update()
                metricG.update(errorG.asnumpy().mean())
                # logging state
                if (i+1)%frequency == 0:
                    print("epoch:", epoch+1, "iter:", i+1, "G: ", metricG.get(), "D: ", metricD.get())
            # save checkpoint
            modG.save_checkpoint('model/%s-G'%(model_prefix), epoch+1)
            modD.save_checkpoint('model/%s-D'%(model_prefix), epoch+1)
            rbatch = rand_iter.next()
            modG.forward(rbatch)
            outG = modG.get_outputs()[0]
            canvas = visual('tmp/gout-rand-%d.png'%(epoch+1), outG.asnumpy())
            if writer is not None:
                canvas = canvas[:, :, ::-1]  # BGR -> RGB
                writer.add_summary(summary.image('gout-rand-%d'%(epoch+1), canvas))
            modG.forward(fix_noise_batch)
            outG = modG.get_outputs()[0]
            canvas = visual('tmp/gout-fix-%d.png'%(epoch+1), outG.asnumpy())
            if writer is not None:
                canvas = canvas[:, :, ::-1]
                writer.add_summary(summary.image('gout-fix-%d'%(epoch+1), canvas))
                writer.flush()
    finally:
        # closing unconditionally used to raise UnboundLocalError when
        # TensorBoard logging was disabled
        if writer is not None:
            writer.close()
示例#33
0
class GANTrainer(object):
    """Train and sample a two-stage text-conditioned GAN.

    All hyper-parameters and paths are read from the module-level ``cfg``
    object (``cfg.TRAIN.*``, ``cfg.GPU_ID``, ``cfg.NET_G``, ``cfg.NET_D``,
    ``cfg.STAGE1_G``, ``cfg.Z_DIM``, ``cfg.CUDA``).
    """

    def __init__(self, output_dir):
        """Create output directories (training only) and select the GPUs.

        Args:
            output_dir: root directory. When ``cfg.TRAIN.FLAG`` is set,
                ``Model``, ``Image`` and ``Log`` sub-directories are created
                and a summary writer is opened on ``Log``.
        """
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            self.log_dir = os.path.join(output_dir, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)
            mkdir_p(self.log_dir)
            self.summary_writer = FileWriter(self.log_dir)

        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)
        # cfg.TRAIN.BATCH_SIZE is scaled by the number of GPUs
        self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

    # ############# Checkpoint helpers #############
    @staticmethod
    def _load_cpu_state_dict(path):
        """Load a saved state dict from *path* with all tensors mapped to CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    @staticmethod
    def _start_epoch_from_path(net_g_path):
        """Return the epoch to resume training from for a generator checkpoint.

        The epoch number is the text between the last '_' and the last '.'
        of the path (e.g. ``netG_epoch_600.pth`` -> 600). Training resumes at
        the following epoch. An empty path means "start from scratch" (0).
        """
        if net_g_path == '':
            return 0
        istart = net_g_path.rfind('_') + 1
        iend = net_g_path.rfind('.')
        return int(net_g_path[istart:iend]) + 1

    @staticmethod
    def _optimizer_path(net_g_path, name):
        """Return the path of optimizer file *name* saved beside *net_g_path*.

        ``os.path.dirname`` also handles a bare file name. The previous
        ``rfind('/')`` slicing chopped the last character in that case.
        """
        return os.path.join(os.path.dirname(net_g_path), name)

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the Stage-I G/D, optionally restoring cfg.NET_G / cfg.NET_D.

        Returns:
            ``(netG, netD, start_epoch)``. ``start_epoch`` is 0 unless a
            generator checkpoint was given.
        """
        from model import STAGE1_G, STAGE1_D
        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            netG.load_state_dict(self._load_cpu_state_dict(cfg.NET_G))
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            netD.load_state_dict(self._load_cpu_state_dict(cfg.NET_D))
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()

        return netG, netD, self._start_epoch_from_path(cfg.NET_G)

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the Stage-II G (wrapping a Stage-I G) and Stage-II D.

        Generator weights come from ``cfg.NET_G``. Otherwise, only the
        wrapped Stage-I generator is initialised from ``cfg.STAGE1_G``.

        Returns:
            ``(netG, netD, start_epoch)``.

        Raises:
            ValueError: if neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
                Previously this printed a message and returned None, which
                made every caller fail on tuple unpacking.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            netG.load_state_dict(self._load_cpu_state_dict(cfg.NET_G))
            print('Load from: ', cfg.NET_G)
        elif cfg.STAGE1_G != '':
            netG.STAGE1_G.load_state_dict(
                self._load_cpu_state_dict(cfg.STAGE1_G))
            print('Load from: ', cfg.STAGE1_G)
        else:
            raise ValueError("Please give the Stage1_G path")

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            netD.load_state_dict(self._load_cpu_state_dict(cfg.NET_D))
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()

        return netG, netD, self._start_epoch_from_path(cfg.NET_G)

    def load_optimizers(self, netG, netD):
        """Create Adam optimizers for G and D, restoring them when resuming.

        Only G parameters with ``requires_grad`` are optimized. When
        ``cfg.NET_G`` is set, ``optimizerG.pth`` and ``optimizerD.pth`` are
        loaded from the checkpoint's directory.

        Returns:
            ``(optimizerG, optimizerD)``.
        """
        optimizerD = optim.Adam(netD.parameters(),
                                lr=cfg.TRAIN.DISCRIMINATOR_LR,
                                betas=(0.5, 0.999))
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))
        if cfg.NET_G != '':
            oGname = self._optimizer_path(cfg.NET_G, 'optimizerG.pth')
            optimizerG.load_state_dict(self._load_cpu_state_dict(oGname))
            print('Load optimizerG from: ', oGname)

            oDname = self._optimizer_path(cfg.NET_G, 'optimizerD.pth')
            optimizerD.load_state_dict(self._load_cpu_state_dict(oDname))
            print('Load optimizerD from: ', oDname)
        return optimizerG, optimizerD

    def train(self, data_loader, stage=1):
        """Train the Stage-I (``stage == 1``) or Stage-II GAN.

        Args:
            data_loader: iterable of ``(real_images, text_embeddings)``.
            stage: 1 for Stage-I, anything else for Stage-II.

        Learning rates are halved every ``cfg.TRAIN.LR_DECAY_EPOCH`` epochs.
        Every 100 iterations, losses are logged and samples generated from a
        fixed noise batch are saved. A model/optimizer snapshot is written
        every ``snapshot_interval`` epochs, and the final model after the
        last epoch.
        """
        if stage == 1:
            netG, netD, start_epoch = self.load_network_stageI()
        else:
            netG, netD, start_epoch = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     requires_grad=False)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        # NOTE(review): when resuming (start_epoch > 0) these trackers start
        # from the configured rates, so the next decay is computed from
        # cfg.TRAIN.*_LR rather than from the restored optimizer lr; confirm.
        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerG, optimizerD = self.load_optimizers(netG, netD)

        count = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()
            if epoch % lr_decay_step == 0 and epoch > 0:
                generator_lr *= 0.5
                for param_group in optimizerG.param_groups:
                    param_group['lr'] = generator_lr
                discriminator_lr *= 0.5
                for param_group in optimizerD.param_groups:
                    param_group['lr'] = discriminator_lr

            for i, data in enumerate(data_loader, 0):
                ######################################################
                # (1) Prepare training data
                ######################################################
                real_img_cpu, txt_embedding = data
                real_imgs = Variable(real_img_cpu)
                txt_embedding = Variable(txt_embedding)
                if cfg.CUDA:
                    real_imgs = real_imgs.cuda()
                    txt_embedding = txt_embedding.float().cuda()

                ######################################################
                # (2) Generate fake images
                ######################################################
                noise.data.normal_(0, 1)
                inputs = (txt_embedding, noise)
                _, fake_imgs, mu, logvar = \
                    nn.parallel.data_parallel(netG, inputs, self.gpus)

                ######################################################
                # (3) Update D network
                ######################################################
                netD.zero_grad()
                errD, errD_real, errD_wrong, errD_fake = \
                    compute_discriminator_loss(netD, real_imgs, fake_imgs,
                                               real_labels, fake_labels,
                                               mu, self.gpus)
                errD.backward()
                optimizerD.step()
                ######################################################
                # (4) Update G network
                ######################################################
                netG.zero_grad()
                errG = compute_generator_loss(netD, fake_imgs, real_labels, mu,
                                              self.gpus)
                kl_loss = KL_loss(mu, logvar)
                errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                errG_total.backward()
                optimizerG.step()

                count = count + 1
                if i % 100 == 0:
                    summary_D = summary.scalar('D_loss', errD.item())
                    summary_D_r = summary.scalar('D_loss_real', errD_real)
                    summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
                    summary_D_f = summary.scalar('D_loss_fake', errD_fake)
                    summary_G = summary.scalar('G_loss', errG.item())
                    summary_KL = summary.scalar('KL_loss', kl_loss.item())

                    self.summary_writer.add_summary(summary_D, count)
                    self.summary_writer.add_summary(summary_D_r, count)
                    self.summary_writer.add_summary(summary_D_w, count)
                    self.summary_writer.add_summary(summary_D_f, count)
                    self.summary_writer.add_summary(summary_G, count)
                    self.summary_writer.add_summary(summary_KL, count)

                    # save image samples from the fixed noise every 100 iters
                    inputs = (txt_embedding, fixed_noise)
                    lr_fake, fake, _, _ = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)
                    save_img_results(real_img_cpu, fake, epoch, self.image_dir)
                    if lr_fake is not None:
                        save_img_results(None, lr_fake, epoch, self.image_dir)
            end_t = time.time()
            print(
                '''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  ''' %
                (epoch, self.max_epoch, i, len(data_loader), errD.item(),
                 errG.item(), kl_loss.item(), errD_real, errD_wrong, errD_fake,
                 (end_t - start_t)))
            if epoch % self.snapshot_interval == 0:
                save_model(netG, netD, epoch, self.model_dir)
                save_optimizer(optimizerG, optimizerD, self.model_dir)

        save_model(netG, netD, self.max_epoch, self.model_dir)
        self.summary_writer.close()

    # ############# Sampling helpers #############
    def _prepare_sampling(self, datapath, stage):
        """Load the generator in eval mode, the embeddings and the output dir.

        Returns:
            ``(netG, embeddings, save_dir)``. ``save_dir`` is ``cfg.NET_G``
            truncated at ``'.pth'`` and is created if missing.
        """
        if stage == 1:
            netG, _, _ = self.load_network_stageI()
        else:
            netG, _, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        embeddings = np.load(datapath)
        num_embeddings = embeddings.shape[0]
        print('Successfully load sentences from: ', datapath)
        print('num_embeddings:', num_embeddings, embeddings.shape)
        # path to save generated samples
        save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
        mkdir_p(save_dir)
        return netG, embeddings, save_dir

    def _generate(self, netG, embeddings):
        """Run *netG* over *embeddings* in consecutive batches.

        Yields:
            ``(offset, lr_fake_imgs, fake_imgs, n)``. ``offset`` is the index
            of the batch's first embedding and ``n`` is the batch length.
        """
        num_embeddings = embeddings.shape[0]
        batch_size = np.minimum(num_embeddings, self.batch_size)
        noise = Variable(torch.FloatTensor(batch_size, cfg.Z_DIM))
        if cfg.CUDA:
            noise = noise.cuda()

        count = 0
        while count < num_embeddings:
            iend = min(count + batch_size, num_embeddings)
            embeddings_batch = embeddings[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            n = embeddings_batch.shape[0]
            noise.data.normal_(0, 1)
            # the last batch may be shorter than the noise buffer
            inputs = (txt_embedding, noise[0:n, :])
            lr_fake_imgs, fake_imgs, _, _ = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            yield count, lr_fake_imgs, fake_imgs, n
            count += batch_size

    @staticmethod
    def _to_pil_image(img_tensor):
        """Convert one CHW generator output to a PIL image.

        Values are mapped with ``(x + 1) * 127.5`` (i.e. an expected range of
        [-1, 1]) and clipped to [0, 255] so out-of-range values saturate
        instead of wrapping around in the uint8 cast.
        """
        im = img_tensor.data.cpu().numpy()
        im = np.clip((im + 1.0) * 127.5, 0, 255).astype(np.uint8)
        return Image.fromarray(np.transpose(im, (1, 2, 0)))

    def sample(self, datapath, stage=1):
        """Generate one image per embedding in *datapath*.

        Images are saved as ``<save_dir>/<k>.png`` with k starting at 1.
        """
        netG, embeddings, save_dir = self._prepare_sampling(datapath, stage)
        # inference only: skip building the autograd graph
        with torch.no_grad():
            for offset, _, fake_imgs, n in self._generate(netG, embeddings):
                for i in range(n):
                    save_name = '%s/%d.png' % (save_dir, offset + i + 1)
                    self._to_pil_image(fake_imgs[i]).save(save_name)

    def sample2(self, datapath, stage=2):
        """Generate final and low-resolution images for every embedding.

        Both outputs are saved as ``<save_dir>/<width>_<k>.png``, with k
        starting at 1. For each batch the final images are written first,
        then the low-resolution ones.
        """
        netG, embeddings, save_dir = self._prepare_sampling(datapath, stage)
        with torch.no_grad():
            for offset, lr_fake_imgs, fake_imgs, n in self._generate(
                    netG, embeddings):
                for imgs in (fake_imgs, lr_fake_imgs):
                    for i in range(n):
                        im = self._to_pil_image(imgs[i])
                        save_name = '%s/%d_%d.png' % (save_dir, im.size[0],
                                                      offset + i + 1)
                        im.save(save_name)
示例#34
0
class GANTrainer(object):
    def __init__(self, output_dir):
        """Set up output folders, the summary writer and multi-GPU settings.

        When ``cfg.TRAIN.FLAG`` is true, the ``Model``, ``Image`` and ``Log``
        sub-directories of ``output_dir`` are created (in that order) and a
        ``FileWriter`` is opened on the ``Log`` directory. ``cfg.GPU_ID`` is a
        comma-separated list of integer device ids; the effective batch size
        is ``cfg.TRAIN.BATCH_SIZE`` multiplied by the number of ids, and the
        first id becomes the current CUDA device.

        Args:
            output_dir: root directory for checkpoints, images and logs.
        """
        if cfg.TRAIN.FLAG:
            subdirs = [os.path.join(output_dir, leaf)
                       for leaf in ('Model', 'Image', 'Log')]
            self.model_dir, self.image_dir, self.log_dir = subdirs
            for path in subdirs:
                mkdir_p(path)
            self.summary_writer = FileWriter(self.log_dir)

        train_cfg = cfg.TRAIN
        self.max_epoch = train_cfg.MAX_EPOCH
        self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL

        self.gpus = [int(gpu_id) for gpu_id in cfg.GPU_ID.split(',')]
        self.num_gpus = len(self.gpus)
        # The configured batch size is scaled by the number of GPUs.
        self.batch_size = train_cfg.BATCH_SIZE * self.num_gpus
        # NOTE(review): this runs even when cfg.CUDA is false.
        torch.cuda.set_device(self.gpus[0])
        # Let cuDNN benchmark and pick the fastest convolution algorithms.
        cudnn.benchmark = True

    # ############# For training stageI GAN #############
    def load_network_stageI(self):
        """Build the Stage-I generator/discriminator pair.

        Each network is constructed, initialised with ``weights_init`` and
        printed (generator first). If ``cfg.NET_G`` / ``cfg.NET_D`` is a
        non-empty path, that checkpoint's state dict is loaded onto the CPU
        and applied to the corresponding network. Both networks are moved to
        the GPU when ``cfg.CUDA`` is set.

        Returns:
            tuple: ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D

        def build(network_cls):
            network = network_cls()
            network.apply(weights_init)
            print(network)
            return network

        def restore(network, checkpoint):
            # map_location returning `storage` deserialises onto the CPU.
            network.load_state_dict(
                torch.load(checkpoint,
                           map_location=lambda storage, loc: storage))
            print('Load from: ', checkpoint)

        netG = build(STAGE1_G)
        netD = build(STAGE1_D)

        for network, checkpoint in ((netG, cfg.NET_G), (netD, cfg.NET_D)):
            if checkpoint != '':
                restore(network, checkpoint)

        if cfg.CUDA:
            for network in (netG, netD):
                network.cuda()
        return netG, netD

    # ############# For training stageII GAN  #############
    def load_network_stageII(self):
        """Build the Stage-II generator/discriminator pair.

        The Stage-II generator wraps a freshly constructed Stage-I generator
        and is initialised with ``weights_init``. Its weights then come from
        ``cfg.NET_G`` (a full Stage-II generator checkpoint) when that is set;
        otherwise only the wrapped Stage-I part is restored from
        ``cfg.STAGE1_G`` and the remaining layers keep their ``weights_init``
        values. ``cfg.NET_D`` optionally restores the discriminator.
        Checkpoints are deserialised onto the CPU, and both networks are moved
        to the GPU when ``cfg.CUDA`` is set.

        Returns:
            tuple: ``(netG, netD)``.

        Raises:
            ValueError: if neither ``cfg.NET_G`` nor ``cfg.STAGE1_G`` is set.
                Returning ``None`` instead would only fail later, when the
                caller tries to unpack the result.
        """
        from model import STAGE1_G, STAGE2_G, STAGE2_D

        # Fail before building any network when there is nothing to load.
        if cfg.NET_G == '' and cfg.STAGE1_G == '':
            raise ValueError("Please give the Stage1_G path: set "
                             "cfg.STAGE1_G (or cfg.NET_G for a full "
                             "Stage-II generator checkpoint)")

        Stage1_G = STAGE1_G()
        netG = STAGE2_G(Stage1_G)
        netG.apply(weights_init)
        print(netG)
        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        else:
            # cfg.STAGE1_G is non-empty here (checked above): restore only
            # the embedded Stage-I generator.
            state_dict = \
                torch.load(cfg.STAGE1_G,
                           map_location=lambda storage, loc: storage)
            netG.STAGE1_G.load_state_dict(state_dict)
            print('Load from: ', cfg.STAGE1_G)

        netD = STAGE2_D()
        netD.apply(weights_init)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        print(netD)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD

    def train(self, data_loader, stage=1):
        """Train the Stage-I or Stage-II conditional GAN.

        For every batch the discriminator is updated first, then the
        generator. The generator's objective is its adversarial loss plus
        ``cfg.TRAIN.COEFF.KL`` times the KL term.

        Both learning rates are halved at the start of every
        ``cfg.TRAIN.LR_DECAY_EPOCH``-th epoch (epoch 0 excluded). On every
        100th batch of an epoch (including the first), scalar losses are
        written to the summary writer. On those batches, preview images
        generated from a fixed noise tensor are also saved to
        ``self.image_dir``.

        A checkpoint is written to ``self.model_dir`` every
        ``self.snapshot_interval`` epochs (including epoch 0) and once more
        after the last epoch. The summary writer is closed when training
        ends, whether normally or through an exception.

        Args:
            data_loader: sized iterable yielding ``(images, text_embeddings)``
                batches. Noise and label tensors are allocated with
                ``self.batch_size`` rows, so every batch presumably has exactly
                that many items -- TODO confirm the loader drops the last
                partial batch.
            stage: 1 trains the Stage-I networks; any other value trains the
                Stage-II networks.

        Raises:
            ValueError: if ``data_loader`` is empty. Otherwise the per-epoch
                report would fail on undefined loop variables.
            AttributeError: if no summary writer exists (it is only created
                in ``__init__`` when ``cfg.TRAIN.FLAG`` is set).
        """
        if len(data_loader) == 0:
            raise ValueError('data_loader is empty; nothing to train on')
        # Resolve the writer up front so a missing one fails before training.
        summary_writer = self.summary_writer

        if stage == 1:
            netG, netD = self.load_network_stageI()
        else:
            netG, netD = self.load_network_stageII()

        nz = cfg.Z_DIM
        batch_size = self.batch_size
        noise = Variable(torch.FloatTensor(batch_size, nz))
        # The same fixed noise is reused for every preview; volatile=True
        # (pre-0.4 PyTorch) avoids building an autograd graph for it.
        fixed_noise = \
            Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
                     volatile=True)
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
            real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()

        generator_lr = cfg.TRAIN.GENERATOR_LR
        discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
        lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
        optimizerD = optim.Adam(netD.parameters(),
                                lr=discriminator_lr, betas=(0.5, 0.999))
        # Parameters with requires_grad=False are left out of the generator
        # optimizer.
        netG_para = [p for p in netG.parameters() if p.requires_grad]
        optimizerG = optim.Adam(netG_para,
                                lr=generator_lr, betas=(0.5, 0.999))

        count = 0  # global iteration counter, used as the summary step
        try:
            for epoch in range(self.max_epoch):
                start_t = time.time()
                if epoch % lr_decay_step == 0 and epoch > 0:
                    generator_lr *= 0.5
                    for param_group in optimizerG.param_groups:
                        param_group['lr'] = generator_lr
                    discriminator_lr *= 0.5
                    for param_group in optimizerD.param_groups:
                        param_group['lr'] = discriminator_lr

                for i, data in enumerate(data_loader, 0):
                    ######################################################
                    # (1) Prepare training data
                    ######################################################
                    real_img_cpu, txt_embedding = data
                    real_imgs = Variable(real_img_cpu)
                    txt_embedding = Variable(txt_embedding)
                    if cfg.CUDA:
                        real_imgs = real_imgs.cuda()
                        txt_embedding = txt_embedding.cuda()

                    ######################################################
                    # (2) Generate fake images
                    ######################################################
                    noise.data.normal_(0, 1)
                    inputs = (txt_embedding, noise)
                    _, fake_imgs, mu, logvar = \
                        nn.parallel.data_parallel(netG, inputs, self.gpus)

                    ######################################################
                    # (3) Update D network
                    ######################################################
                    netD.zero_grad()
                    errD, errD_real, errD_wrong, errD_fake = \
                        compute_discriminator_loss(netD, real_imgs,
                                                   fake_imgs, real_labels,
                                                   fake_labels, mu,
                                                   self.gpus)
                    errD.backward()
                    optimizerD.step()

                    ######################################################
                    # (4) Update G network
                    ######################################################
                    netG.zero_grad()
                    errG = compute_generator_loss(netD, fake_imgs,
                                                  real_labels, mu, self.gpus)
                    kl_loss = KL_loss(mu, logvar)
                    errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
                    errG_total.backward()
                    optimizerG.step()

                    count = count + 1
                    if i % 100 == 0:
                        scalars = (('D_loss', errD.data[0]),
                                   ('D_loss_real', errD_real),
                                   ('D_loss_wrong', errD_wrong),
                                   ('D_loss_fake', errD_fake),
                                   ('G_loss', errG.data[0]),
                                   ('KL_loss', kl_loss.data[0]))
                        for tag, value in scalars:
                            summary_writer.add_summary(
                                summary.scalar(tag, value), count)

                        # Preview images from the fixed noise. Only the
                        # epoch (not the iteration) is passed to
                        # save_img_results -- presumably later previews in
                        # the same epoch overwrite earlier ones; TODO confirm.
                        inputs = (txt_embedding, fixed_noise)
                        lr_fake, fake, _, _ = \
                            nn.parallel.data_parallel(netG, inputs, self.gpus)
                        save_img_results(real_img_cpu, fake, epoch,
                                         self.image_dir)
                        # The first generator output may be None.
                        if lr_fake is not None:
                            save_img_results(None, lr_fake, epoch,
                                             self.image_dir)
                end_t = time.time()
                # The continuation-line indentation inside this literal is
                # part of the printed message and is kept unchanged.
                print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
                     Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
                     Total Time: %.2fsec
                  '''
                      % (epoch, self.max_epoch, i, len(data_loader),
                         errD.data[0], errG.data[0], kl_loss.data[0],
                         errD_real, errD_wrong, errD_fake, (end_t - start_t)))
                if epoch % self.snapshot_interval == 0:
                    save_model(netG, netD, epoch, self.model_dir)
            save_model(netG, netD, self.max_epoch, self.model_dir)
        finally:
            # Flush and close the event file even if training is interrupted.
            summary_writer.close()

    def sample(self, datapath, stage=1, max_count=3000):
        """Generate one image per text embedding stored in a Torch7 file.

        Loads the generator for ``stage`` and switches it to eval mode. It
        then reads ``raw_txt`` and ``fea_txt`` from ``datapath`` with
        ``torchfile`` and writes ``<save_dir>/<k>.png`` for the k-th
        embedding.

        ``save_dir`` is ``cfg.NET_G`` cut at its ``.pth`` suffix. When the
        name has no ``.pth``, its extension is stripped instead. The
        directory is created if needed.

        Args:
            datapath: path of the Torch7 file holding the captions
                (``raw_txt``) and their embeddings (``fea_txt``).
            stage: 1 samples from the Stage-I generator; any other value
                uses the Stage-II generator.
            max_count: no new batch is started once more than ``max_count``
                embeddings have been processed. Because the check happens per
                batch, up to about ``max_count + batch_size`` images may be
                written. ``None`` processes every embedding. Defaults to 3000.
        """
        if stage == 1:
            netG, _ = self.load_network_stageI()
        else:
            netG, _ = self.load_network_stageII()
        netG.eval()

        # Load text embeddings generated from the encoder
        t_file = torchfile.load(datapath)
        captions_list = t_file.raw_txt
        embeddings = np.concatenate(t_file.fea_txt, axis=0)
        num_embeddings = len(captions_list)
        print('Successfully load sentences from: ', datapath)
        print('Total number of sentences:', num_embeddings)
        print('num_embeddings:', num_embeddings, embeddings.shape)

        # Samples go next to the checkpoint: 'dir/netG_x.pth' -> 'dir/netG_x'.
        # find() returns -1 when '.pth' is absent; slicing with that would
        # silently drop the last character, so strip the extension instead.
        pth_pos = cfg.NET_G.find('.pth')
        if pth_pos == -1:
            save_dir = os.path.splitext(cfg.NET_G)[0]
        else:
            save_dir = cfg.NET_G[:pth_pos]
        mkdir_p(save_dir)

        batch_size = min(num_embeddings, self.batch_size)
        nz = cfg.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        if cfg.CUDA:
            noise = noise.cuda()
        count = 0
        while count < num_embeddings:
            if max_count is not None and count > max_count:
                break
            iend = count + batch_size
            if iend > num_embeddings:
                # The noise tensor always has batch_size rows, so the last
                # window is shifted back to end at num_embeddings; images in
                # the overlap are regenerated and overwritten.
                iend = num_embeddings
                count = num_embeddings - batch_size
            embeddings_batch = embeddings[count:iend]
            txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
            if cfg.CUDA:
                txt_embedding = txt_embedding.cuda()

            #######################################################
            # Generate fake images
            #######################################################
            noise.data.normal_(0, 1)
            inputs = (txt_embedding, noise)
            _, fake_imgs, _, _ = \
                nn.parallel.data_parallel(netG, inputs, self.gpus)
            for i in range(batch_size):
                save_name = '%s/%d.png' % (save_dir, count + i)
                im = fake_imgs[i].data.cpu().numpy()
                # (x + 1) * 127.5 maps [-1, 1] onto [0, 255].
                im = ((im + 1.0) * 127.5).astype(np.uint8)
                # Move axis 0 (presumably channels) last, as PIL expects.
                im = np.transpose(im, (1, 2, 0))
                Image.fromarray(im).save(save_name)
            count += batch_size