Example #1
    def test(self, sess, ckpt, sample_size):
        assert ckpt is not None, 'no checkpoint provided.'

        gen_res = self.generator(self.z, reuse=False)

        num_batches = int(math.ceil(float(sample_size) / self.num_chain))

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        print('Loading checkpoint {}.'.format(ckpt))

        for i in xrange(num_batches):
            z_vec = np.random.randn(min(sample_size, self.num_chain),
                                    self.z_size)
            g_res = sess.run(gen_res, feed_dict={self.z: z_vec})
            saveSampleResults(g_res,
                              "%s/gen%03d.png" % (self.test_dir, i),
                              col_num=self.nTileCol)

            # output interpolation results
            interp_z = linear_interpolator(z_vec,
                                           npairs=self.nTileRow,
                                           ninterp=self.nTileCol)
            interp = sess.run(gen_res, feed_dict={self.z: interp_z})
            saveSampleResults(interp,
                              "%s/interp%03d.png" % (self.test_dir, i),
                              col_num=self.nTileCol)
            sample_size = sample_size - self.num_chain
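
Note: `linear_interpolator` is not part of the snippet above. Below is a minimal NumPy sketch of what the call site implies, assuming it pairs consecutive rows of `z_vec` and returns `npairs * ninterp` evenly spaced latent codes; the real helper's pairing scheme may differ.

    import numpy as np

    def linear_interpolator(z_vec, npairs, ninterp):
        # Hypothetical helper: interpolate between consecutive pairs of
        # latent codes. Assumes z_vec has at least 2 * npairs rows.
        rows = []
        for p in range(npairs):
            z0, z1 = z_vec[2 * p], z_vec[2 * p + 1]
            for t in np.linspace(0.0, 1.0, ninterp):
                rows.append((1.0 - t) * z0 + t * z1)
        return np.stack(rows)  # shape: (npairs * ninterp, z_dim)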
Example #2
    def test(self, sess, ckpt, sample_size):
        assert ckpt is not None, 'no checkpoint provided.'

        self.gen_res = self.generator(self.z, reuse=False)

        # num_batches = int(math.ceil(sample_size / self.num_chain))
        num_batches = 100

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        print('Loading checkpoint {}.'.format(ckpt))

        test_data = DataSet(self.data_path3, image_size=self.image_size)

        for i in xrange(num_batches):
            test_data_batch = test_data[
                i * self.batch_size:min(len(test_data), (i + 1) *
                                        self.batch_size)]

            z_vec = test_data_batch
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
            saveSampleResults(g_res,
                              "%s/gen%03d.png" % (self.test_dir, i),
                              col_num=self.nTileCol)
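
`DataSet` is an external helper here; the loop above only relies on `len()` and slice indexing returning image arrays. A minimal stand-in under those assumptions (the file pattern, [-1, 1] scaling, and RGB layout are guesses):

    import glob, os
    import numpy as np
    from PIL import Image

    class DataSet(object):
        # Hypothetical stand-in: loads every *.jpg under data_path, resizes
        # to image_size x image_size, scales pixels to [-1, 1].
        def __init__(self, data_path, image_size=64):
            paths = sorted(glob.glob(os.path.join(data_path, '*.jpg')))
            imgs = [np.asarray(Image.open(p).convert('RGB')
                               .resize((image_size, image_size)), np.float32)
                    for p in paths]
            self.images = np.stack(imgs) / 127.5 - 1.0

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return self.images[idx]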
Example #3
    def sampling(self,
                 sess,
                 ckpt,
                 sample_size,
                 sample_step,
                 calculate_inception=False):
        assert ckpt is not None, 'no checkpoint provided.'

        self.t1 = sample_step

        gen_res = self.generator(self.z, reuse=False)
        obs_res = self.descriptor(self.obs, reuse=False)  # build descriptor vars for checkpoint restore

        self.langevin_descriptor = self.langevin_dynamics_descriptor(gen_res)
        num_batches = int(math.ceil(float(sample_size) / self.num_chain))

        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        print('Loading checkpoint {}.'.format(ckpt))

        sample_results_des = np.random.randn(self.num_chain * num_batches,
                                             self.image_size, self.image_size,
                                             self.num_channel)
        for i in xrange(num_batches):
            z_vec = np.random.randn(self.num_chain, self.z_size)

            # synthesis by generator
            g_res = sess.run(gen_res, feed_dict={self.z: z_vec})
            saveSampleResults(g_res,
                              "%s/gen%03d_test.png" % (self.test_dir, i),
                              col_num=self.nTileCol)

            # synthesis by descriptor and generator
            syn = sess.run(self.langevin_descriptor, feed_dict={self.z: z_vec})
            saveSampleResults(syn,
                              "%s/des%03d_test.png" % (self.test_dir, i),
                              col_num=self.nTileCol)

            sample_results_des[i * self.num_chain:(i + 1) *
                               self.num_chain] = syn

            if i % 10 == 0:
                print("Sampling batches: {}, from {} to {}".format(
                    i, i * self.num_chain,
                    min((i + 1) * self.num_chain, sample_size)))
        sample_results_des = sample_results_des[:sample_size]
        sample_results_des = np.minimum(1, np.maximum(-1, sample_results_des))
        sample_results_des = (sample_results_des + 1) / 2 * 255

        if calculate_inception:
            m, s = get_inception_score(sample_results_des)
            print("Inception score: mean {}, sd {}".format(m, s))

        sampling_output_file = os.path.join(self.output_dir, 'samples_des.npy')
        np.save(sampling_output_file, sample_results_des)
        print("The results are saved in folder: {}".format(self.output_dir))
Example #4
    def test(self):
        assert self.opts.ckpt_gen is not None, 'Please specify the path to the checkpoint of generator.'
        assert self.opts.ckpt_des is not None, 'Please specify the path to the checkpoint of descriptor.'
        print('===Test on ' + self.opts.ckpt_gen + ' and ' +
              self.opts.ckpt_des + ' ===')
        generator = torch.load(self.opts.ckpt_gen).eval()
        descriptor = torch.load(self.opts.ckpt_des).eval()

        if not os.path.exists(self.opts.output_dir):
            os.makedirs(self.opts.output_dir)

        test_batch = int(
            np.ceil(self.opts.test_size / self.opts.nRow / self.opts.nCol))
        print('===Generated images saved to %s ===' % (self.opts.output_dir))

        for i in range(test_batch):
            z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
            z = Variable(z.cuda())
            gen_res = generator(z)

            for s in range(self.opts.langevin_step_num_des):
                # clone it and turn x into a leaf variable so the grad won't be thrown away
                gen_res = Variable(gen_res.data, requires_grad=True)
                gen_res_feature = descriptor(gen_res)
                gen_res_feature.backward(
                    torch.ones(self.num_chain, self.opts.z_size).cuda())
                grad = gen_res.grad
                gen_res = gen_res - 0.5 * self.opts.langevin_step_size_des * self.opts.langevin_step_size_des * \
                                    (gen_res / self.opts.sigma_des / self.opts.sigma_des - grad)

            if self.opts.score:
                gen_res = gen_res.detach().cpu()
                for img_no, img in enumerate(gen_res):
                    if i * self.num_chain + img_no + 1 > self.opts.test_size:
                        break
                    print('Generating {:05d}/{:05d}'.format(
                        i * self.num_chain + img_no + 1, self.opts.test_size))
                    saveSampleResults(img[None, :, :, :],
                                      "%s/testres_%03d.png" %
                                      (self.opts.output_dir,
                                       i * self.num_chain + img_no + 1),
                                      col_num=1,
                                      margin_syn=0)
            else:
                gen_res = gen_res.detach().cpu()
                print('Generating {:05d}/{:05d}'.format(i + 1, test_batch))
                saveSampleResults(gen_res,
                                  "%s/testres_%03d.png" %
                                  (self.opts.output_dir, i + 1),
                                  col_num=self.opts.nCol,
                                  margin_syn=0)

        print('===Image generation done.===')
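
All of these snippets share `saveSampleResults`. A rough, PIL-based sketch of what it appears to do (clip from [-1, 1], tile into a grid, write a PNG); the exact normalization and margin handling in the real helper may differ:

    import numpy as np
    from PIL import Image

    def saveSampleResults(samples, filename, col_num=10, margin_syn=2):
        # Hypothetical sketch: accepts NHWC or NCHW arrays/tensors in
        # [-1, 1] and saves them as a col_num-wide grid image.
        samples = np.asarray(samples)
        if samples.ndim == 4 and samples.shape[1] in (1, 3):
            samples = samples.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        samples = ((np.clip(samples, -1, 1) + 1) / 2 * 255).astype(np.uint8)
        n, h, w, c = samples.shape
        rows = int(np.ceil(n / float(col_num)))
        grid = np.zeros((rows * (h + margin_syn),
                         col_num * (w + margin_syn), c), dtype=np.uint8)
        for k in range(n):
            r, col = divmod(k, col_num)
            grid[r * (h + margin_syn):r * (h + margin_syn) + h,
                 col * (w + margin_syn):col * (w + margin_syn) + w] = samples[k]
        Image.fromarray(grid.squeeze()).save(filename)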
Example #5
    def test(self, sess, ckpt, sample_size):
        assert ckpt is not None, 'no checkpoint provided.'

        self.gen_res = self.generator(self.z, reuse=False)

        # num_batches = int(math.ceil(sample_size / self.num_chain))
        num_batches = 100

        saver = tf.train.Saver()

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, ckpt)
        print('Loading checkpoint {}.'.format(ckpt))

        test_data = DataSet(self.data_path3, image_size=self.image_size)

        for i in xrange(num_batches):
            # Feed a batch of test images as the latent input instead of
            # z ~ N(0, 1); batch_size = 100, batches of shape (100, 64, 64, 3).
            test_data_batch = test_data[
                i * self.batch_size:min(len(test_data), (i + 1) *
                                        self.batch_size)]
            z_vec = test_data_batch
            g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
            saveSampleResults(g_res,
                              "%s/gen%03d.png" % (self.test_dir, i),
                              col_num=self.nTileCol)
Example #6
    def train(self):
        if self.opts.ckpt_des is not None and self.opts.ckpt_des != 'None':
            self.descriptor = torch.load(self.opts.ckpt_des)
            print('Loading Descriptor from ' + self.opts.ckpt_des + '...')
        else:
            if self.opts.set == 'scene' or self.opts.set == 'lsun':
                self.descriptor = Descriptor(self.opts).cuda()
                print('Loading Descriptor without initialization...')
            elif self.opts.set == 'cifar':
                self.descriptor = Descriptor_cifar(self.opts).cuda()
                print('Loading Descriptor_cifar without initialization...')
            else:
                raise NotImplementedError(
                    'The set should be either scene, lsun, or cifar')

        if self.opts.ckpt_gen is not None and self.opts.ckpt_gen != 'None':
            self.generator = torch.load(self.opts.ckpt_gen)
            print('Loading Generator from ' + self.opts.ckpt_gen + '...')
        else:
            if self.opts.set == 'scene' or self.opts.set == 'lsun':
                self.generator = Generator(self.opts).cuda()
                print('Loading Generator without initialization...')
            elif self.opts.set == 'cifar':
                self.generator = Generator_cifar(self.opts).cuda()
                print('Loading Generator_cifar without initialization...')
            else:
                raise NotImplementedError(
                    'The set should be either scene, lsun or cifar')

        # TODO -add tensorboard & plot

        batch_size = self.opts.batch_size
        if self.opts.set == 'scene' or self.opts.set == 'cifar':
            train_data = DataSet(os.path.join(self.opts.data_path,
                                              self.opts.category),
                                 image_size=self.opts.img_size)
        else:
            train_data = torchvision.datasets.LSUN(
                root=self.opts.data_path,
                classes=['bedroom_train'],
                transform=transforms.Compose([
                    transforms.Resize(self.img_size),
                    transforms.ToTensor(),
                ]))
        num_batches = int(math.ceil(len(train_data) / batch_size))

        # sample_results = np.random.randn(self.num_chain * num_batches, self.opts.img_size, self.opts.img_size, 3)
        des_optimizer = torch.optim.Adam(self.descriptor.parameters(),
                                         lr=self.opts.lr_des,
                                         betas=[self.opts.beta1_des, 0.999])
        gen_optimizer = torch.optim.Adam(self.generator.parameters(),
                                         lr=self.opts.lr_gen,
                                         betas=[self.opts.beta1_gen, 0.999])

        if not os.path.exists(self.opts.ckpt_dir):
            os.makedirs(self.opts.ckpt_dir)
        if not os.path.exists(self.opts.output_dir):
            os.makedirs(self.opts.output_dir)
        logfile = open(self.opts.ckpt_dir + '/log', 'w+')

        mse_loss = torch.nn.MSELoss(size_average=False, reduce=True)  # summed squared error

        for epoch in range(self.opts.num_epoch):
            start_time = time.time()
            gen_loss_epoch, des_loss_epoch, recon_loss_epoch = [], [], []
            for i in range(num_batches):
                if (i + 1) * batch_size > len(train_data):
                    continue
                obs_data = train_data[i * batch_size:min(
                    (i + 1) * batch_size, len(train_data))]
                obs_data = Variable(
                    torch.Tensor(obs_data).cuda())  # ,requires_grad=True

                # G0
                z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
                z = Variable(z.cuda(), requires_grad=True)
                # NCHW
                gen_res = self.generator(z)

                # D1: revise synthesized images by descriptor Langevin
                # (the later steps assume langevin_step_num_des > 0, so that
                # `revised` is defined)
                if self.opts.langevin_step_num_des > 0:
                    revised = self.langevin_dynamics_descriptor(gen_res)
                # G1: refine z by generator Langevin toward the revised images
                if self.opts.langevin_step_num_gen > 0:
                    z = self.langevin_dynamics_generator(z, revised)

                # D2
                obs_feature = self.descriptor(obs_data)
                revised_feature = self.descriptor(revised)

                des_loss = (revised_feature.mean(0) -
                            obs_feature.mean(0)).sum()

                des_optimizer.zero_grad()
                des_loss.backward()
                des_optimizer.step()

                # G2
                ini_gen_res = gen_res.detach()
                if self.opts.langevin_step_num_gen > 0:
                    gen_res = self.generator(z)
                # gen_res=gen_res.detach()
                gen_loss = 0.5 * self.opts.sigma_gen * self.opts.sigma_gen * mse_loss(
                    gen_res, revised.detach())

                gen_optimizer.zero_grad()
                gen_loss.backward()
                gen_optimizer.step()

                # Compute reconstruction loss
                recon_loss = mse_loss(revised, ini_gen_res)

                gen_loss_epoch.append(gen_loss.cpu().data)
                des_loss_epoch.append(des_loss.cpu().data)
                recon_loss_epoch.append(recon_loss.cpu().data)

            # TO-FIX (conflict between pytorch and tf)
            # if opts.incep_interval>0, compute inception score each [incep_interval] epochs.
            # if self.opts.incep_interval > 0:
            #     import inception_model
            #     if epoch % self.opts.incep_interval == 0:
            #         inception_log_file = os.path.join(self.opts.output_dir, 'inception.txt')
            #         inception_output_file = os.path.join(self.opts.output_dir, 'inception.mat')
            #         sample_results_partial = revised[:len(train_data)]
            #         sample_results_partial = np.minimum(1, np.maximum(-1, sample_results_partial))
            #         sample_results_partial = (sample_results_partial + 1) / 2 * 255
            #         # sample_results_list = sample_results.copy().swapaxes(1, 3)
            #         # sample_results_list = np.split(sample_results, len(sample_results), axis=0)
            #         m, s = get_inception_score(sample_results_partial)
            #         fo = open(inception_log_file, 'a')
            #         fo.write("Epoch {}: mean {}, sd {}".format(epoch, m, s))
            #         fo.close()
            #         inception_mean.append(m)
            #         inception_sd.append(s)
            #         sio.savemat(inception_output_file,
            #                     {'mean': np.asarray(inception_mean), 'sd': np.asarray(inception_sd)})

            try:
                col_num = self.opts.nCol
                saveSampleResults(obs_data.cpu().data[:col_num * col_num],
                                  "%s/observed.png" % (self.opts.output_dir),
                                  col_num=col_num)
            except Exception:
                print('Error when saving obs_data. Skip.')
                continue
            saveSampleResults(revised.cpu().data,
                              "%s/des_%03d.png" %
                              (self.opts.output_dir, epoch + 1),
                              col_num=self.opts.nCol)
            saveSampleResults(gen_res.cpu().data,
                              "%s/gen_%03d.png" %
                              (self.opts.output_dir, epoch + 1),
                              col_num=self.opts.nCol)

            end_time = time.time()
            print(
                'Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
                'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch,
                                       np.mean(des_loss_epoch),
                                       np.mean(gen_loss_epoch),
                                       np.mean(recon_loss_epoch),
                                       end_time - start_time))

            # python 3
            print(
                'Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
                'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch,
                                       np.mean(des_loss_epoch),
                                       np.mean(gen_loss_epoch),
                                       np.mean(recon_loss_epoch),
                                       end_time - start_time),
                file=logfile)
            # python 2.7
            # print >> logfile, ('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
            #     'time: {:.2f}s'.format(epoch,self.opts.num_epoch, np.mean(des_loss_epoch), np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
            #                            end_time - start_time))

            if epoch % self.opts.log_epoch == 0:
                torch.save(
                    self.descriptor,
                    self.opts.ckpt_dir + '/des_ckpt_{}.pth'.format(epoch))
                torch.save(
                    self.generator,
                    self.opts.ckpt_dir + '/gen_ckpt_{}.pth'.format(epoch))
        logfile.close()
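
`self.langevin_dynamics_generator` is called above but not shown. A hedged PyTorch sketch under the option names used in the snippet (`langevin_step_num_gen`, `sigma_gen`) plus an assumed `langevin_step_size_gen`; it refines z so that the generator's output matches the descriptor-revised images:

    def langevin_dynamics_generator(self, z, revised):
        # Hypothetical sketch: noisy gradient steps on the latent posterior
        # -log p(z | revised) = ||revised - G(z)||^2 / (2 sigma^2) + ||z||^2 / 2.
        for _ in range(self.opts.langevin_step_num_gen):
            z = Variable(z.data, requires_grad=True)
            gen_res = self.generator(z)
            recon = 1.0 / (2.0 * self.opts.sigma_gen ** 2) * \
                ((gen_res - revised.detach()) ** 2).sum()
            recon.backward()
            s = self.opts.langevin_step_size_gen
            z = z - 0.5 * s * s * (z + z.grad) \
                + s * Variable(torch.randn(*z.size()).cuda())
        return Variable(z.data)  # detach from the Langevin graph before G2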
Example #7
    def train(self, sess):
        self.build_model()

        # Prepare training data
        train_data = DataSet(self.data_path, image_size=self.image_size)
        num_batches = int(math.ceil(float(len(train_data)) / self.batch_size))

        # initialize training
        sess.run(tf.global_variables_initializer())

        sample_results = np.random.randn(self.num_chain * num_batches, self.image_size, self.image_size, 3)

        saver = tf.train.Saver(max_to_keep=50)

        # make graph immutable
        tf.get_default_graph().finalize()

        # store graph in protobuf
        with open(self.model_dir + '/graph.proto', 'w') as f:
            f.write(str(tf.get_default_graph().as_graph_def()))

        des_loss_vis = Visualizer(title='descriptor', ylabel='normalized negative log-likelihood', ylim=(-200, 200),
                                  save_figpath=self.log_dir + '/des_loss.png', avg_period = self.batch_size)

        gen_loss_vis = Visualizer(title='generator', ylabel='reconstruction error',
                                  save_figpath=self.log_dir + '/gen_loss.png', avg_period = self.batch_size)


        # train
        for epoch in xrange(self.num_epochs):
            start_time = time.time()
            des_loss_avg, gen_loss_avg, mse_avg = [], [], []
            for i in xrange(num_batches):

                obs_data = train_data[i * self.batch_size:min(len(train_data), (i + 1) * self.batch_size)]

                # Step G0: generate X ~ N(0, 1)
                z_vec = np.random.randn(self.num_chain, self.z_size)
                g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
                # Step D1: obtain synthesized images Y
                if self.t1 > 0:
                    syn = sess.run(self.langevin_descriptor, feed_dict={self.syn: g_res})
                # Step G1: update X using Y as training image
                if self.t2 > 0:
                    z_vec = sess.run(self.langevin_generator, feed_dict={self.z: z_vec, self.obs: syn})
                # Step D2: update D net
                d_loss = sess.run([self.des_loss, self.apply_d_grads],
                                  feed_dict={self.obs: obs_data, self.syn: syn})[0]
                # Step G2: update G net
                g_loss = sess.run([self.gen_loss, self.apply_g_grads],
                                  feed_dict={self.obs: syn, self.z: z_vec})[0]

                # Compute MSE for generator
                mse = sess.run(self.recon_err, feed_dict={self.obs: syn, self.syn: g_res})
                sample_results[i * self.num_chain:(i + 1) * self.num_chain] = syn

                des_loss_avg.append(d_loss)
                gen_loss_avg.append(g_loss)
                mse_avg.append(mse)

                des_loss_vis.add_loss_val(epoch*num_batches + i, d_loss / float(self.image_size * self.image_size * 3))
                gen_loss_vis.add_loss_val(epoch*num_batches + i, mse)

                if self.debug:
                    print('Epoch #{:d}, [{:2d}]/[{:2d}], descriptor loss: {:.4f}, generator loss: {:.4f}, '
                          'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches, d_loss.mean(), g_loss.mean(), mse))
                if i == 0 and epoch % self.log_step == 0:
                    if not os.path.exists(self.sample_dir):
                        os.makedirs(self.sample_dir)
                    saveSampleResults(syn, "%s/des%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)
                    saveSampleResults(g_res, "%s/gen%03d.png" % (self.sample_dir, epoch), col_num=self.nTileCol)

            end_time = time.time()
            print('Epoch #{:d}, avg.descriptor loss: {:.4f}, avg.generator loss: {:.4f}, avg.L2 distance: {:4.4f}, '
                  'time: {:.2f}s'.format(epoch, np.mean(des_loss_avg), np.mean(gen_loss_avg), np.mean(mse_avg), end_time - start_time))

            if epoch % self.log_step == 0:
                if not os.path.exists(self.model_dir):
                    os.makedirs(self.model_dir)
                saver.save(sess, "%s/%s" % (self.model_dir, 'model.ckpt'), global_step=epoch)

                if not os.path.exists(self.log_dir):
                    os.makedirs(self.log_dir)

                des_loss_vis.draw_figure()
                gen_loss_vis.draw_figure()
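
`Visualizer` is a plotting helper specific to this repo. A minimal matplotlib sketch consistent with how it is called above (`add_loss_val(step, value)` then `draw_figure()`); the averaging and styling of the original may differ:

    import matplotlib
    matplotlib.use('Agg')  # headless backend for training servers
    import matplotlib.pyplot as plt
    import numpy as np

    class Visualizer(object):
        # Hypothetical sketch: accumulates (step, value) pairs, averages
        # every avg_period points, and writes a PNG on draw_figure().
        def __init__(self, title, ylabel, save_figpath, ylim=None, avg_period=1):
            self.title, self.ylabel = title, ylabel
            self.save_figpath, self.ylim = save_figpath, ylim
            self.avg_period = max(1, avg_period)
            self.steps, self.values = [], []

        def add_loss_val(self, step, value):
            self.steps.append(step)
            self.values.append(float(value))

        def draw_figure(self):
            n = (len(self.values) // self.avg_period) * self.avg_period
            if n == 0:
                return
            avg = np.mean(np.reshape(self.values[:n], (-1, self.avg_period)), axis=1)
            xs = self.steps[self.avg_period - 1:n:self.avg_period]
            plt.figure()
            plt.plot(xs, avg)
            plt.title(self.title)
            plt.ylabel(self.ylabel)
            if self.ylim:
                plt.ylim(self.ylim)
            plt.savefig(self.save_figpath)
            plt.close()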
Example #8
    def train(self):
        if self.opts.ckpt_des is not None and self.opts.ckpt_des != 'None':
            self.descriptor = torch.load(self.opts.ckpt_des)
            print('Loading Descriptor from ' + self.opts.ckpt_des + '...')
        else:
            if self.opts.set == 'scene' or self.opts.set == 'lsun':
                self.descriptor = Descriptor(self.opts).cuda()
                print('Loading Descriptor without initialization...')
            elif self.opts.set == 'cifar':
                self.descriptor = Descriptor_cifar(self.opts).cuda()
                print('Loading Descriptor_cifar without initialization...')
            else:
                raise NotImplementedError('The set should be either scene, lsun, or cifar')

        if self.opts.ckpt_gen is not None and self.opts.ckpt_gen != 'None':
            self.generator = torch.load(self.opts.ckpt_gen)
            print('Loading Generator from ' + self.opts.ckpt_gen + '...')
        else:
            if self.opts.set == 'scene' or self.opts.set == 'lsun':
                self.generator = Generator(self.opts).cuda()
                print('Loading Generator without initialization...')
            elif self.opts.set == 'cifar':
                self.generator = Generator_cifar(self.opts).cuda()
                print('Loading Generator_cifar without initialization...')
            else:
                raise NotImplementedError('The set should be either scene, lsun or cifar')

        batch_size = self.opts.batch_size
        if self.opts.set == 'scene' or self.opts.set == 'cifar':
            train_data = DataSet(os.path.join(self.opts.data_path, self.opts.category), image_size=self.opts.img_size)
        else:
            train_data = torchvision.datasets.LSUN(root=self.opts.data_path,
                                                   classes=['bedroom_train'],
                                                   transform=transforms.Compose([transforms.Resize(self.img_size),
                                                                                 transforms.ToTensor(), ]))
        num_batches = int(math.ceil(len(train_data) / batch_size))

        # sample_results = np.random.randn(self.num_chain * num_batches, self.opts.img_size, self.opts.img_size, 3)
        des_optimizer = torch.optim.Adam(self.descriptor.parameters(), lr=self.opts.lr_des,
                                         betas=[self.opts.beta1_des, 0.999])
        gen_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.opts.lr_gen,
                                         betas=[self.opts.beta1_gen, 0.999])

        if not os.path.exists(self.opts.ckpt_dir):
            os.makedirs(self.opts.ckpt_dir)
        if not os.path.exists(self.opts.output_dir):
            os.makedirs(self.opts.output_dir)
        logfile = open(self.opts.ckpt_dir + '/log', 'w+')

        mse_loss = torch.nn.MSELoss(size_average=False, reduce=True)  # summed squared error

        for epoch in range(self.opts.num_epoch):
            start_time = time.time()
            gen_loss_epoch, des_loss_epoch, recon_loss_epoch = [], [], []
            for i in range(num_batches):
                if (i + 1) * batch_size > len(train_data):
                    continue
                obs_data = train_data[i * batch_size:min((i + 1) * batch_size, len(train_data))]
                obs_data = Variable(torch.Tensor(obs_data).cuda())  # ,requires_grad=True

                # G0
                z = torch.randn(self.num_chain, self.opts.z_size, 1, 1)
                z = Variable(z.cuda(), requires_grad=True)
                # NCHW
                gen_res = self.generator(z)

                # D1: revise synthesized images by descriptor Langevin
                # (assumes langevin_step_num_des > 0, so `revised` is defined)
                if self.opts.langevin_step_num_des > 0:
                    revised = self.langevin_dynamics_descriptor(gen_res)
                # G1: refine z by generator Langevin toward the revised images
                if self.opts.langevin_step_num_gen > 0:
                    z = self.langevin_dynamics_generator(z, revised)

                # D2
                obs_feature = self.descriptor(obs_data)
                revised_feature = self.descriptor(revised)

                des_loss = (revised_feature.mean(0) - obs_feature.mean(0)).sum()

                des_optimizer.zero_grad()
                des_loss.backward()
                des_optimizer.step()

                # G2
                ini_gen_res = gen_res.detach()
                if self.opts.langevin_step_num_gen > 0:
                    gen_res = self.generator(z)
                gen_loss = 1.0 / (2.0 * self.opts.sigma_gen * self.opts.sigma_gen) * \
                    mse_loss(gen_res, revised.detach())

                gen_optimizer.zero_grad()
                gen_loss.backward()
                gen_optimizer.step()

                # Compute reconstruction loss
                recon_loss = mse_loss(revised, ini_gen_res)

                gen_loss_epoch.append(gen_loss.cpu().data)
                des_loss_epoch.append(des_loss.cpu().data)
                recon_loss_epoch.append(recon_loss.cpu().data)

            try:
                col_num = self.opts.nCol
                saveSampleResults(obs_data.cpu().data[:col_num * col_num], "%s/observed.png" % (self.opts.output_dir),
                                  col_num=col_num)
            except Exception:
                print('Error when saving obs_data. Skip.')
                continue
            saveSampleResults(revised.cpu().data, "%s/des_%03d.png" % (self.opts.output_dir, epoch + 1),
                              col_num=self.opts.nCol)
            saveSampleResults(gen_res.cpu().data, "%s/gen_%03d.png" % (self.opts.output_dir, epoch + 1),
                              col_num=self.opts.nCol)

            end_time = time.time()
            print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
                  'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch, np.mean(des_loss_epoch),
                                         np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
                                         end_time - start_time))

            # python 3
            print('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
                  'time: {:.2f}s'.format(epoch + 1, self.opts.num_epoch, np.mean(des_loss_epoch), np.mean(gen_loss_epoch),
                                         np.mean(recon_loss_epoch),
                                         end_time - start_time), file=logfile)
            # python 2.7
            # print >> logfile, ('Epoch #{:d}/{:d}, des_loss: {:.4f}, gen_loss: {:.4f}, recon_loss: {:.4f}, '
            #     'time: {:.2f}s'.format(epoch,self.opts.num_epoch, np.mean(des_loss_epoch), np.mean(gen_loss_epoch), np.mean(recon_loss_epoch),
            #                            end_time - start_time))


            if epoch % self.opts.log_epoch == 0:
                torch.save(self.descriptor, self.opts.ckpt_dir + '/des_ckpt_{}.pth'.format(epoch))
                torch.save(self.generator, self.opts.ckpt_dir + '/gen_ckpt_{}.pth'.format(epoch))
        logfile.close()
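
As in Example #6, the descriptor-side Langevin is external here. A PyTorch sketch mirroring the inline loop in Example #4 (which omits the noise term at test time; it is included here for the training-time sampler); `langevin_step_size_des` and `sigma_des` follow the option names used in that snippet:

    def langevin_dynamics_descriptor(self, gen_res):
        # Hypothetical sketch: noisy gradient ascent on the descriptor score.
        revised = gen_res
        for _ in range(self.opts.langevin_step_num_des):
            revised = Variable(revised.data, requires_grad=True)
            score = self.descriptor(revised).sum()
            score.backward()
            s = self.opts.langevin_step_size_des
            revised = revised - 0.5 * s * s * \
                (revised / self.opts.sigma_des ** 2 - revised.grad) \
                + s * Variable(torch.randn(*revised.size()).cuda())
        return revised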
Example #9
    def train(self, sess):
        self.build_model()

        # Prepare training data
        train_data = DataSet(self.data_path, image_size=self.image_size)
        num_batches = int(math.ceil(float(len(train_data)) / self.batch_size))

        # initialize training
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        sample_results = np.random.randn(self.num_chain * num_batches,
                                         self.image_size, self.image_size, 3)

        saver = tf.train.Saver(max_to_keep=50)

        writer = tf.summary.FileWriter(self.log_dir, sess.graph)

        # symbolic langevins
        langevin_descriptor = self.langevin_dynamics_descriptor(self.syn)
        langevin_generator = self.langevin_dynamics_generator(self.z)

        # make graph immutable
        tf.get_default_graph().finalize()

        # store graph in protobuf
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        with open(self.model_dir + '/graph.proto', 'w') as f:
            f.write(str(tf.get_default_graph().as_graph_def()))

        # train
        for epoch in xrange(self.num_epochs):
            start_time = time.time()
            for i in xrange(num_batches):

                obs_data = train_data[i * self.
                                      batch_size:min(len(train_data), (i + 1) *
                                                     self.batch_size)]

                # Step G0: generate X ~ N(0, 1)
                z_vec = np.random.randn(self.num_chain, self.z_size)
                g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
                # Step D1: obtain synthesized images Y
                if self.t1 > 0:
                    syn = sess.run(langevin_descriptor,
                                   feed_dict={self.syn: g_res})
                # Step G1: update X using Y as training image
                if self.t2 > 0:
                    z_vec = sess.run(langevin_generator,
                                     feed_dict={
                                         self.z: z_vec,
                                         self.obs: syn
                                     })
                # Step D2: update D net
                d_loss = sess.run(
                    [self.des_loss, self.des_loss_update, self.apply_d_grads],
                    feed_dict={
                        self.obs: obs_data,
                        self.syn: syn
                    })[0]
                # Step G2: update G net
                g_loss = sess.run(
                    [self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                    feed_dict={
                        self.obs: syn,
                        self.z: z_vec
                    })[0]

                # Compute MSE
                mse = sess.run([self.recon_err, self.recon_err_update],
                               feed_dict={
                                   self.obs: obs_data,
                                   self.syn: syn
                               })[0]

                sample_results[i * self.num_chain:(i + 1) *
                               self.num_chain] = syn
                print(
                    'Epoch #{:d}, [{:2d}]/[{:2d}], descriptor loss: {:.4f}, generator loss: {:.4f}, '
                    'L2 distance: {:4.4f}'.format(epoch, i + 1, num_batches,
                                                  d_loss, g_loss, mse))
                if i == 0 and epoch % self.log_step == 0:
                    if not os.path.exists(self.sample_dir):
                        os.makedirs(self.sample_dir)
                    saveSampleResults(syn,
                                      "%s/des%03d.png" %
                                      (self.sample_dir, epoch),
                                      col_num=self.nTileCol)
                    saveSampleResults(g_res,
                                      "%s/gen%03d.png" %
                                      (self.sample_dir, epoch),
                                      col_num=self.nTileCol)

            [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
                self.des_loss_mean, self.gen_loss_mean, self.recon_err_mean,
                self.summary_op
            ])
            end_time = time.time()
            print(
                'Epoch #{:d}, avg. descriptor loss: {:.4f}, avg. generator loss: {:.4f}, avg. L2 distance: {:4.4f}, '
                'time: {:.2f}s'.format(epoch, des_loss_avg, gen_loss_avg,
                                       mse_avg, end_time - start_time))
            writer.add_summary(summary, epoch)
            writer.flush()

            if epoch % self.log_step == 0:
                if not os.path.exists(self.model_dir):
                    os.makedirs(self.model_dir)
                saver.save(sess,
                           "%s/%s" % (self.model_dir, 'model.ckpt'),
                           global_step=epoch)
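
The loop above reads paired streaming tensors (`des_loss_mean`/`des_loss_update`, etc.) and a `summary_op` that `build_model` presumably defines; the `tf.local_variables_initializer()` call is what such metrics require. A sketch of how they could be wired in TF 1.x:

    # Hypothetical build_model excerpt: tf.metrics.mean returns a
    # (running_mean, update_op) pair backed by local variables.
    self.des_loss_mean, self.des_loss_update = tf.metrics.mean(self.des_loss)
    self.gen_loss_mean, self.gen_loss_update = tf.metrics.mean(self.gen_loss)
    self.recon_err_mean, self.recon_err_update = tf.metrics.mean(self.recon_err)

    tf.summary.scalar('descriptor_loss', self.des_loss_mean)
    tf.summary.scalar('generator_loss', self.gen_loss_mean)
    tf.summary.scalar('recon_err', self.recon_err_mean)
    self.summary_op = tf.summary.merge_all()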
Example #10
    def train(self, sess):

        self.build_model()

        # Prepare training data
        is_mnist = (self.type == "mnist")
        dataset = DataSet(self.data_path,
                          image_size=self.image_size,
                          batch_sz=self.batch_size,
                          prefetch=self.prefetch,
                          read_len=self.read_len,
                          is_mnist=is_mnist)

        # initialize training
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        sample_results_des = np.random.randn(
            self.num_chain * dataset.num_batch, self.image_size,
            self.image_size, self.num_channel)
        sample_results_gen = np.random.randn(
            self.num_chain * dataset.num_batch, self.image_size,
            self.image_size, self.num_channel)

        saver = tf.train.Saver(max_to_keep=50)

        writer = tf.summary.FileWriter(self.log_dir, sess.graph)

        # measure 1: Parzen-window based likelihood
        if self.calculate_parzen:

            kde = ParsenDensityEsimator(sess)
            parzon_des_mean_list, parzon_des_se_list = [], []
            parzon_gen_mean_list, parzon_gen_se_list = [], []
            parzon_log_file = os.path.join(self.output_dir, 'parzon.txt')
            parzon_write_file = os.path.join(self.output_dir, 'parzon.mat')
            parzon_syn_data_file = os.path.join(self.output_dir,
                                                'parzon_syn_dat.mat')
            parzon_max = -10000

        # measure 2: inception score
        if self.calculate_inception:
            inception_log_file = os.path.join(self.output_dir, 'inception.txt')
            inception_write_file = os.path.join(self.output_dir,
                                                'inception.mat')

        # make graph immutable
        tf.get_default_graph().finalize()

        # store graph in protobuf
        with open(self.model_dir + '/graph.proto', 'w') as f:
            f.write(str(tf.get_default_graph().as_graph_def()))

        inception_mean, inception_sd = [], []

        # train
        minibatch = -1

        for epoch in xrange(self.num_epochs):
            start_time = time.time()
            for i in xrange(dataset.num_batch):
                minibatch = minibatch + 1
                obs_data = dataset.get_batch()

                # Step G0: generate X ~ N(0, 1)
                z_vec = np.random.randn(self.num_chain, self.z_size)
                g_res = sess.run(self.gen_res, feed_dict={self.z: z_vec})
                # Step D1: obtain synthesized images Y
                if self.t1 > 0:
                    syn = sess.run(self.langevin_descriptor,
                                   feed_dict={self.syn: g_res})
                # Step G1: update X using Y as training image
                if self.t2 > 0:
                    z_vec = sess.run(self.langevin_generator,
                                     feed_dict={
                                         self.z: z_vec,
                                         self.obs: syn
                                     })
                # Step D2: update D net
                d_loss = sess.run(
                    [self.des_loss, self.des_loss_update, self.apply_d_grads],
                    feed_dict={
                        self.obs: obs_data,
                        self.syn: syn
                    })[0]
                # Step G2: update G net
                g_loss = sess.run(
                    [self.gen_loss, self.gen_loss_update, self.apply_g_grads],
                    feed_dict={
                        self.obs: syn,
                        self.z: z_vec
                    })[0]

                # Compute MSE
                mse = sess.run([self.recon_err, self.recon_err_update],
                               feed_dict={
                                   self.obs: obs_data,
                                   self.syn: syn
                               })[0]

                sample_results_gen[i * self.num_chain:(i + 1) *
                                   self.num_chain] = g_res
                sample_results_des[i * self.num_chain:(i + 1) *
                                   self.num_chain] = syn

                if minibatch % self.log_step == 0:
                    end_time = time.time()
                    [des_loss_avg, gen_loss_avg, mse_avg, summary] = sess.run([
                        self.des_loss_mean, self.gen_loss_mean,
                        self.recon_err_mean, self.summary_op
                    ])
                    writer.add_summary(summary, minibatch)
                    writer.flush()
                    print(
                        'Epoch #{:d}, minibatch #{:d}, avg.des loss: {:.4f}, avg.gen loss: {:.4f}, '
                        'avg.L2 dist: {:4.4f}, time: {:.2f}s'.format(
                            epoch, minibatch, des_loss_avg, gen_loss_avg,
                            mse_avg, end_time - start_time))
                    start_time = time.time()

                    # save synthesis images
                    if not os.path.exists(self.sample_dir):
                        os.makedirs(self.sample_dir)
                    saveSampleResults(syn,
                                      "%s/des_%06d_%06d.png" %
                                      (self.sample_dir, epoch, minibatch),
                                      col_num=self.nTileCol)
                    saveSampleResults(g_res,
                                      "%s/gen_%06d_%06d.png" %
                                      (self.sample_dir, epoch, minibatch),
                                      col_num=self.nTileCol)

                if minibatch % (self.log_step * 20) == 0:
                    # save check points
                    if not os.path.exists(self.model_dir):
                        os.makedirs(self.model_dir)
                    saver.save(sess,
                               "%s/%s" % (self.model_dir, 'model.ckpt'),
                               global_step=minibatch)

            if self.calculate_inception and epoch % 20 == 0:

                sample_results_partial = sample_results_des[:len(dataset)]
                sample_results_partial = np.minimum(
                    1, np.maximum(-1, sample_results_partial))
                sample_results_partial = (sample_results_partial + 1) / 2 * 255

                m, s = get_inception_score(sample_results_partial)
                print("Inception score: mean {}, sd {}".format(m, s))
                fo = open(inception_log_file, 'a')
                fo.write("Epoch {}: mean {}, sd {} \n".format(epoch, m, s))
                fo.close()
                inception_mean.append(m)
                inception_sd.append(s)
                sio.savemat(
                    inception_write_file, {
                        'mean': np.asarray(inception_mean),
                        'sd': np.asarray(inception_sd)
                    })

            if self.calculate_parzen:

                samples_des = sample_results_des[:10000]
                samples_gen = sample_results_gen[:10000]

                parzon_des_mean, parzon_des_se, parzon_gen_mean, parzon_gen_se = kde.eval_parzen(
                    samples_des, samples_gen)

                parzon_des_mean_list.append(parzon_des_mean)
                parzon_des_se_list.append(parzon_des_se)
                parzon_gen_mean_list.append(parzon_gen_mean)
                parzon_gen_se_list.append(parzon_gen_se)

                if parzon_des_mean > parzon_max:
                    parzon_max = parzon_des_mean
                    sio.savemat(parzon_syn_data_file, {
                        'samples_des': samples_des,
                        'samples_gen': samples_gen
                    })

                fo = open(parzon_log_file, 'a')
                fo.write(
                    "Epoch {}: des mean {}, sd {}; gen mean {}, sd {}, max score {}. \n"
                    .format(epoch, parzon_des_mean, parzon_des_se,
                            parzon_gen_mean, parzon_gen_se, parzon_max))
                fo.close()

                sio.savemat(
                    parzon_write_file, {
                        'parzon_des_mean': np.asarray(parzon_des_mean_list),
                        'parzon_des_se': np.asarray(parzon_des_se_list),
                        'parzon_gen_mean': np.asarray(parzon_gen_mean_list),
                        'parzon_gen_se': np.asarray(parzon_gen_se_list)
                    })
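
`ParsenDensityEsimator` is not shown. For reference, a NumPy-only sketch of the Parzen-window score it reports: the mean log-likelihood (and its standard error) of held-out data under a Gaussian kernel density fit on the synthesized samples. Bandwidth selection and preprocessing in the real estimator are unknown.

    import numpy as np

    def parzen_log_likelihood(samples, test_data, sigma=0.2):
        # Hypothetical sketch: log p(x) under an isotropic Gaussian KDE
        # centered on the samples, evaluated with a stable log-sum-exp.
        s = samples.reshape(len(samples), -1)
        x = test_data.reshape(len(test_data), -1)
        d = s.shape[1]
        lls = []
        for xi in x:
            sq = -np.sum((xi - s) ** 2, axis=1) / (2.0 * sigma ** 2)
            m = sq.max()
            ll = m + np.log(np.exp(sq - m).mean()) \
                - 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
            lls.append(ll)
        lls = np.asarray(lls)
        return lls.mean(), lls.std() / np.sqrt(len(lls))  # mean, std. error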