Code example #1
    def train(self):
        if not os.path.exists(os.path.join("data", self.dataset_name)):
            print "No GAN training files found. Training aborted. =("
            return

        dataset_files = glob.glob("data/" + self.dataset_name + "/*.png")
        dataset_files.sort(key=ops.alphanum_key)
        dataset_files = np.array(dataset_files)
        dataset_encodings = np.load("data/" + self.dataset_name +
                                    "/encondings.npy")

        n_files = dataset_files.shape[0]
        training_step = 0
        # Hold out one fixed random batch to monitor the loss on during training
        test_idxs = np.random.permutation(range(n_files))[0:self.batch_size]
        test_imgs = ops.load_imgbatch(dataset_files[test_idxs], color=False)
        test_encs = dataset_encodings[test_idxs, :]

        self.session.run(tf.initialize_all_variables())
        self.load()
        for epoch in xrange(self.n_iterations):

            rand_idxs = np.random.permutation(range(n_files))
            n_batches = n_files // self.batch_size

            for batch_i in xrange(n_batches):
                idxs_i = rand_idxs[batch_i * self.batch_size:(batch_i + 1) *
                                   self.batch_size]
                imgs_batch = ops.load_imgbatch(dataset_files[idxs_i],
                                               color=False)
                self.session.run(self.optimizer,
                                 feed_dict={
                                     self.images: imgs_batch,
                                     self.encodings:
                                     dataset_encodings[idxs_i, :],
                                     self.train_flag: True
                                 })
                training_step += 1
                current_loss = self.session.run(self.loss,
                                                feed_dict={
                                                    self.images: test_imgs,
                                                    self.encodings: test_encs,
                                                    self.train_flag: False
                                                })
                print "Epoch {}/{}, Batch {}/{}, Loss {}".format(
                    epoch + 1, self.n_iterations, batch_i + 1, n_batches,
                    current_loss)
                # Save checkpoint
                if training_step % 1000 == 0:
                    if not os.path.exists("checkpoint"):
                        print "Checkpoint folder not found. reating one..."
                        os.makedirs("checkpoint")
                        print "Done."
                    create_folder('checkpoint/InversePrGAN{}'.format(
                        self.dataset_name))
                    self.saver.save(
                        self.session,
                        'checkpoint/InversePrGAN{}/model.ckpt'.format(
                            self.dataset_name),
                        global_step=training_step)
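These examples sort the PNG file list with ops.alphanum_key, which is evidently a natural-sort key (so "frame2.png" orders before "frame10.png"). A minimal sketch of such a key, assuming only that behavior (the project's actual implementation may differ):

    import re

    def alphanum_key(s):
        # Split the string into text and digit runs, e.g.
        # "frame10.png" -> ["frame", 10, ".png"], so numeric runs
        # compare as integers rather than character by character.
        return [int(part) if part.isdigit() else part
                for part in re.split(r"(\d+)", s)]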
Code example #2
File: RenderNet.py Project: a-programr/CS726project
    def train(self):
        if not os.path.exists(os.path.join("data", "train")):
            print "No training files found. Training aborted. =("
            return

        dataset_files = glob.glob("data/train/*.png")
        dataset_files.sort(key=ops.alphanum_key)
        dataset_files = np.array(dataset_files)
        dataset_params = np.load("train_params.npy")

        n_files = dataset_params.shape[0]

        # Sample a fixed batch (with replacement) to track the test loss
        testset_idxs = np.random.choice(range(n_files), self.batch_size)
        test_imgs = ops.load_imgbatch(dataset_files[testset_idxs])
        training_step = 0

        self.session.run(tf.initialize_all_variables())
        for epoch in xrange(self.n_iterations):

            rand_idxs = np.random.permutation(range(n_files))
            n_batches = n_files // self.batch_size

            for batch_i in xrange(n_batches):
                idxs_i = rand_idxs[batch_i * self.batch_size: (batch_i + 1) * self.batch_size]
                imgs_batch = ops.load_imgbatch(dataset_files[idxs_i])
                self.session.run(self.optimizer, feed_dict={self.img_params: dataset_params[idxs_i, :],
                                                            self.final_image: imgs_batch})
                training_step += 1

                current_loss = self.session.run(self.loss, feed_dict={self.img_params: dataset_params[testset_idxs, :],
                                                                      self.final_image: test_imgs})

                print "Epoch {}/{}, Batch {}/{}, Loss {}".format(epoch + 1, self.n_iterations,
                                                                 batch_i + 1, n_batches, current_loss)

                # Save checkpoint
                if training_step % 1000 == 0:
                    if not os.path.exists("checkpoint"):
                        print "Checkpoint folder not found. Creating one..."
                        os.makedirs("checkpoint")
                        print "Done."
                    self.saver.save(self.session, 'checkpoint/model.ckpt', global_step=training_step)
Code example #3
    def test(self, path):
        self.load()
        imgs_path = glob.glob(path)
        imgs_path.sort(key=ops.alphanum_key)
        imgs_path = np.array(imgs_path)
        # Encode a fixed batch of 64 images (hard-coded offset 8 * 337
        # into the naturally sorted file list)
        imgs_batch = ops.load_imgbatch(imgs_path[range(8 * 337, 8 * 337 + 64)],
                                       color=False)
        encs = self.z.eval(session=self.session,
                           feed_dict={
                               self.images: imgs_batch,
                               self.train_flag: False
                           })
        np.save("inverse_encs.npy", encs)
Code example #4
File: RenderNet.py Project: a-programr/CS726project
    def test(self):
        test_files = glob.glob("data/test/*.png")
        test_files.sort(key=ops.alphanum_key)
        test_files = np.array(test_files)
        test_params = np.load("test_params.npy")

        n_files = test_params.shape[0]

        test_idxs = np.random.choice(range(n_files), self.batch_size)
        test_imgs = ops.load_imgbatch(test_files[test_idxs])

        # Save the sampled ground-truth images as an 8 x 8 grid
        ops.save_images(test_imgs, [8, 8], 'ground_truth.png')

        # Render the same parameters with the network and save the results
        result_imgs = self.forward_batch(test_params[test_idxs, :])
        ops.save_images(result_imgs, [8, 8], 'test_results.png')
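ops.save_images(batch, [8, 8], path) is used throughout these examples to tile a batch into an 8 x 8 image grid. The helper itself is not shown; a minimal sketch of such a tiler, assuming float images in [0, 1] with shape (H, W, C) (the project's actual signature and normalization may differ):

    import numpy as np
    from PIL import Image

    def save_images(batch, grid, path):
        # Tile a batch of float images (values in [0, 1]) into a
        # grid[0] x grid[1] mosaic and write it out as a single file.
        rows, cols = grid
        h, w, c = batch[0].shape
        mosaic = np.zeros((rows * h, cols * w, c), dtype=np.float32)
        for i, img in enumerate(batch[:rows * cols]):
            r, col = divmod(i, cols)
            mosaic[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = img
        out = (mosaic * 255).astype(np.uint8)
        Image.fromarray(out.squeeze()).save(path)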
Code example #5
    def train(self):
        if not os.path.exists(os.path.join("data", self.dataset_name)):
            print "No GAN training files found. Training aborted. =("
            return

        dataset_files = glob.glob("data/" + self.dataset_name + "/*.png")
        dataset_files = np.array(dataset_files)
        n_files = dataset_files.shape[0]
        sample_z = np.random.uniform(-1, 1, [self.batch_size, self.z_size])
        training_step = 0

        self.session.run(tf.initialize_all_variables())
        self.load()
        for epoch in xrange(self.n_iterations):

            rand_idxs = np.random.permutation(range(n_files))
            n_batches = n_files // self.batch_size

            for batch_i in xrange(n_batches):
                idxs_i = rand_idxs[batch_i * self.batch_size:(batch_i + 1) *
                                   self.batch_size]
                imgs_batch = ops.load_imgbatch(dataset_files[idxs_i],
                                               color=False)
                #imgs_batch = ops.load_voxelbatch(dataset_files[idxs_i])
                batch_z = np.random.uniform(-1, 1,
                                            [self.batch_size, self.z_size])

                dloss_fake = self.D_fake.eval(session=self.session,
                                              feed_dict={
                                                  self.z: batch_z,
                                                  self.train_flag: False
                                              })
                dloss_real = self.D_real.eval(session=self.session,
                                              feed_dict={
                                                  self.images: imgs_batch,
                                                  self.train_flag: False
                                              })
                gloss = self.G_loss.eval(session=self.session,
                                         feed_dict={
                                             self.z: batch_z,
                                             self.images: imgs_batch,
                                             self.train_flag: False
                                         })

                train_discriminator = True

                # Estimate the discriminator's accuracy: D_real should be
                # near 1 on real images and D_fake near 0 on samples.
                margin = 0.8
                dacc_real = np.mean(dloss_real)
                dacc_fake = np.mean(np.ones_like(dloss_fake) - dloss_fake)
                dacc = (dacc_real + dacc_fake) * 0.5
                # Skip the discriminator update once it is winning by a
                # wide margin, so it does not overpower the generator.
                if dacc > margin:
                    train_discriminator = False
                #if dloss_fake > 1.0-margin or dloss_real > 1.0-margin:
                #    train_generator = False
                #if train_discriminator is False and train_generator is False:
                #    train_generator = train_discriminator = True

                print "EPOCH[{}], BATCH[{}/{}]".format(epoch, batch_i,
                                                       n_batches)
                print "Discriminator avg acc: {}".format(dacc)
                print "Discriminator real mean: {}".format(np.mean(dloss_real))
                print "Discriminator fake mean: {}".format(np.mean(dloss_fake))
                print "Generator Loss:{}".format(gloss)

                # Update discriminator
                if train_discriminator:
                    print "***Discriminator trained.***"
                    self.session.run(self.D_optim,
                                     feed_dict={
                                         self.images: imgs_batch,
                                         self.z: batch_z,
                                         self.train_flag: True
                                     })
                # Update generator
                #if dacc > 0.9:
                #    self.session.run(self.G_optim_classic, feed_dict={self.z: batch_z})
                #if dacc > margin + 1.0:
                self.session.run(self.G_optim_classic,
                                 feed_dict={
                                     self.z: batch_z,
                                     self.images: imgs_batch,
                                     self.train_flag: True
                                 })
                #self.session.run(self.G_optim, feed_dict={self.z: batch_z, self.images: imgs_batch, self.train_flag: True})

                # Every 50 batches, render samples from the fixed noise
                # vector and dump images, voxels, and a checkpoint
                if batch_i % 50 == 0:
                    rendered_images = self.G.eval(session=self.session,
                                                  feed_dict={
                                                      self.z: sample_z,
                                                      self.images: imgs_batch,
                                                      self.train_flag: False
                                                  })
                    rendered_images = np.array(rendered_images)

                    voxels = self.voxels.eval(session=self.session,
                                              feed_dict={
                                                  self.z: sample_z,
                                                  self.images: imgs_batch,
                                                  self.train_flag: False
                                              })
                    voxels = np.array(voxels)

                    create_folder("results/{}".format(self.dataset_name))
                    ops.save_images(
                        rendered_images, [8, 8], "results/{}/{}.png".format(
                            self.dataset_name, epoch * n_batches + batch_i))
                    ops.save_images(imgs_batch, [8, 8], "sanity_chairs.png")
                    ops.save_voxels(voxels,
                                    "results/{}".format(self.dataset_name))

                    print "Saving checkpoint..."
                    create_folder('checkpoint/{}'.format(self.dataset_name))
                    self.saver.save(self.session,
                                    'checkpoint/{}/model.ckpt'.format(
                                        self.dataset_name),
                                    global_step=training_step)
                    print "***CHECKPOINT SAVED***"
                training_step += 1

                self.history["generator"].append(gloss)
                self.history["discriminator_real"].append(dloss_real)
                self.history["discriminator_fake"].append(dloss_fake)

        np.save(os.path.join(self.logpath, "generator.npy"),
                np.array(self.history["generator"]))
        np.save(os.path.join(self.logpath, "discriminator_real.npy"),
                np.array(self.history["discriminator_real"]))
        np.save(os.path.join(self.logpath, "discriminator_fake.npy"),
                np.array(self.history["discriminator_fake"]))
Code example #6
    def train(self):
        if not os.path.exists(os.path.join("data", "gan")):
            print "No GAN training files found. Training aborted. =("
            return

        dataset_files = glob.glob("data/gan/*.png")
        dataset_files = np.array(dataset_files)
        n_files = dataset_files.shape[0]
        sample_z = np.random.uniform(-1, 1, [self.batch_size, self.z_size])

        self.session.run(tf.initialize_all_variables())
        self.rendernet.load('checkpoint')
        for epoch in xrange(self.n_iterations):

            rand_idxs = np.random.permutation(range(n_files))
            n_batches = n_files // self.batch_size

            for batch_i in xrange(n_batches):
                idxs_i = rand_idxs[batch_i * self.batch_size:(batch_i + 1) *
                                   self.batch_size]
                imgs_batch = ops.load_imgbatch(dataset_files[idxs_i])
                batch_z = np.random.uniform(-1, 1,
                                            [self.batch_size, self.z_size])

                dloss_fake = self.D_loss_fake.eval(session=self.session,
                                                   feed_dict={self.z: batch_z})
                dloss_real = self.D_loss_real.eval(
                    session=self.session, feed_dict={self.images: imgs_batch})
                gloss = self.G_loss.eval(session=self.session,
                                         feed_dict={self.z: batch_z})

                train_discriminator = True
                train_generator = True

                # Freeze whichever player is winning: skip the discriminator
                # when its losses are already small, skip the generator when
                # they are large; if both would be frozen, train both.
                margin = 0.3
                if dloss_fake < margin or dloss_real < margin:
                    train_discriminator = False
                if dloss_fake > 1.0 - margin or dloss_real > 1.0 - margin:
                    train_generator = False
                if train_discriminator is False and train_generator is False:
                    train_generator = train_discriminator = True

                # Update discriminator
                if train_discriminator:
                    self.session.run(self.D_optim,
                                     feed_dict={
                                         self.images: imgs_batch,
                                         self.z: batch_z
                                     })
                # Update generator (five updates per batch to keep G competitive)
                if train_generator:
                    for i in xrange(5):
                        self.session.run(self.G_optim,
                                         feed_dict={self.z: batch_z})

                if batch_i % 10 == 0:
                    rendered_images = self.G.eval(session=self.session,
                                                  feed_dict={self.z: sample_z})
                    rendered_images = np.array(rendered_images)
                    ops.save_images(
                        rendered_images, [8, 8],
                        "results/gancubes{}.png".format(epoch * n_batches +
                                                        batch_i))

                print "EPOCH[{}], BATCH[{}/{}]".format(epoch, batch_i,
                                                       n_batches)
                print "Discriminator Loss - Real:{} / Fake:{} - Total:{}".format(
                    dloss_real, dloss_fake, dloss_real + dloss_fake)
                print "Generator Loss:{}".format(gloss)