Beispiel #1
0
 def _get_blurred_dataset(self):
     """Return the dataset blurred at the current sigma, caching the result.

     The blurred copy is recomputed only when the sigma schedule moves to
     a new value; if nothing has been blurred yet, the raw dataset is
     returned instead.
     """
     if FLAGS.blur != 0:
         sigma = self._get_blur_sigma()
         # Only re-blur when the sigma actually changed since last call.
         if sigma != self._last_blur:
             self._last_blur = sigma
             self._blurred_dataset = inp.apply_gaussian(self.dataset, sigma=sigma)
     if self._blurred_dataset is not None:
         return self._blurred_dataset
     return self.dataset
Beispiel #2
0
 def _get_blurred_dataset(self):
   """Lazily (re)blur the dataset when the blur sigma changes.

   Returns the cached blurred dataset, or the raw dataset if no blurred
   copy has been produced yet.
   """
   if FLAGS.blur_sigma != 0:
     sigma = self._get_blur_sigma()
     if sigma != self._last_blur_sigma:
       # Log the sigma transition so schedule changes are visible in stdout.
       print('sigmas current:%f last:%f' % (sigma, self._last_blur_sigma))
       self._last_blur_sigma = sigma
       self._blurred_dataset = inp.apply_gaussian(self.dataset, sigma=sigma)
   if self._blurred_dataset is None:
     return self.dataset
   return self._blurred_dataset
Beispiel #3
0
 def _get_blurred_dataset(self):
     """Return the training set, blurred when FLAGS.blur is enabled.

     The blurred copy is cached and refreshed only when the sigma
     schedule produces a new value; with blur disabled, the raw
     training set is returned unchanged.
     """
     if FLAGS.blur == 0:
         return self.train_set
     sigma = self._get_blur_sigma()
     if sigma != self._last_blur:
         self._last_blur = sigma
         self._blurred_dataset = inp.apply_gaussian(self.train_set,
                                                    sigma=sigma)
         # Print one source/blurred pixel pair as a quick sanity check.
         ut.print_info('blur s:%.1f[%.1f>%.1f]' %
                       (sigma, self.train_set[2, 10, 10, 0],
                        self._blurred_dataset[2, 10, 10, 0]))
     if self._blurred_dataset is not None:
         return self._blurred_dataset
     return self.train_set
Beispiel #4
0
 def evaluate(self, sess, take):
   """Run encode/decode over the blurred test set, batch by batch.

   Args:
     sess: TensorFlow session used to run the encode/decode ops.
     take: number of reconstructions to keep per batch (forwarded to
       self._concatenate).

   Returns:
     Tuple (encoded, reconstructed, blurred): accumulated encodings,
     accumulated reconstructions, and the blurred test set. The first
     two are None if the test set is smaller than one batch.
   """
   encoded, reconstructed = None, None
   blurred = inp.apply_gaussian(self.test_set, self._get_blur_sigma())
   # Floor division replaces int(a / b): same result, but avoids the
   # float round-trip that can lose precision for very large datasets.
   batch_count = len(self.test_set) // FLAGS.batch_size
   for i in range(batch_count):
     batch = blurred[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
     encoding, reconstruction = sess.run(
       [self._encode, self._decode],
       feed_dict={self._input: batch})
     encoded = self._concatenate(encoded, encoding)
     reconstructed = self._concatenate(reconstructed, reconstruction, take=take)
   return encoded, reconstructed, blurred
Beispiel #5
0
    def evaluate(self, sess, take):
        """Evaluate the model on the blurred test set.

        Encodes the entire (blurred) test set, stores the embedding for
        visualization, accumulates reconstruction/prediction losses over
        permutation batches, and renders one fixed batch of
        reconstruction visualizations.

        Args:
          sess: TensorFlow session (also used implicitly by .eval()).
          take: number of examples to keep for visualization.

        Returns:
          Bunch with fields encoded, reconstructed, source, loss,
          eval_loss, dumb_loss and size (NaN losses are sanitized by
          guard_nan).
        """
        digest = Bunch(encoded=None,
                       reconstructed=None,
                       source=None,
                       loss=.0,
                       eval_loss=.0,
                       dumb_loss=.0)
        blurred = inp.apply_gaussian(self.test_set, self._get_blur_sigma())
        # Encode the whole test set in order (shuffle disabled so that
        # consecutive-frame arithmetic below stays aligned).
        for i, batch in enumerate(self._batch_generator(blurred,
                                                        shuffle=False)):
            encoding = self.encode.eval(feed_dict={self.input: batch[0]})
            digest.encoded = ut.concatenate(digest.encoded, encoding)
        # Save encoding for visualization
        encoded_no_nan = np.nan_to_num(digest.encoded)
        self.embedding_assign.eval(
            feed_dict={self.embedding_test_ph: encoded_no_nan})
        try:
            self.embedding_saver.save(sess,
                                      self.get_checkpoint_path() + EMB_SUFFIX)
        except Exception:
            # Best-effort save: log and continue. Was a bare `except:`,
            # which also swallowed SystemExit/KeyboardInterrupt.
            ut.print_info("Unexpected error: %s" % str(sys.exc_info()[0]),
                          color=33)

        # Latent-space arithmetic over consecutive frames: expected[i]
        # linearly extrapolates frame i+2 from frames i and i+1;
        # average[i] is their midpoint.
        expected = digest.encoded[1:-1] * 2 - digest.encoded[:-2]
        average = 0.5 * (digest.encoded[1:-1] + digest.encoded[:-2])
        digest.size = len(expected)
        # evaluation summaries
        self.summary_writer.add_summary(self.eval_summs.eval(
            feed_dict={self.blur_ph: self._get_blur_sigma()}),
                                        global_step=self.get_past_epochs())
        # evaluation losses, accumulated over permutation batches
        for p in self._batch_permutation_generator(digest.size, shuffle=False):
            # Reconstruction loss from the true encoding of frame i+2.
            digest.loss += self.eval_loss.eval(
                feed_dict={
                    self.encoding: digest.encoded[p + 2],
                    self.target: blurred[p + 2]
                })
            # Prediction loss from the extrapolated encoding.
            digest.eval_loss += self.eval_loss.eval(feed_dict={
                self.encoding: expected[p],
                self.target: blurred[p + 2]
            })
            # Baseline: feed frame i directly, compare against frame i+2.
            digest.dumb_loss += self.loss_ae.eval(feed_dict={
                self.input: blurred[p],
                self.target: blurred[p + 2]
            })

        # Reconstruction visualizations
        for p in self._batch_permutation_generator(digest.size,
                                                   shuffle=True,
                                                   batches=1):
            # Pin the very first permutation drawn so every evaluation
            # visualizes the same examples.
            if self.visualization_batch_perm is None:
                self.visualization_batch_perm = p
            p = self.visualization_batch_perm
            # NOTE(review): the original evaluated eval_decode(expected[p])
            # into digest.source and immediately overwrote it on the next
            # line; that dead computation was removed here. Confirm the
            # overwrite (not the decode) was the intended value.
            digest.source = blurred[(p + 2)[:take]]
            digest.reconstructed = self.eval_decode.eval(
                feed_dict={self.encoding: average[p]})[:take]
            self._eval_image_summaries(blurred[p], digest.encoded[p],
                                       average[p], expected[p])

        digest.dumb_loss = guard_nan(digest.dumb_loss)
        digest.eval_loss = guard_nan(digest.eval_loss)
        digest.loss = guard_nan(digest.loss)
        return digest
Beispiel #6
0
    def train(self):
        """Top-level training loop.

        Builds the model selected by FLAGS.model (autoencoder, predictive,
        or denoising fallback), then trains for FLAGS.max_epochs epochs
        inside a fresh TF session. KeyboardInterrupt is caught and routed
        to _on_training_abort for a clean shutdown.
        """
        self.fetch_datasets()
        if FLAGS.model == AUTOENCODER:
            self.build_ae_model()
        elif FLAGS.model == PREDICTIVE:
            self.build_predictive_model()
        else:
            self.build_denoising_model()
        self._init_optimizer()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self._on_training_start(sess)

            try:
                for current_epoch in range(FLAGS.max_epochs):
                    start = time.time()
                    # Blur the full set up front only when it is small
                    # (< 50000 examples); larger sets are blurred per batch
                    # below — presumably to bound memory, TODO confirm.
                    full_set_blur = len(self.train_set) < 50000
                    ds = self._get_blurred_dataset(
                    ) if full_set_blur else self.train_set
                    if FLAGS.model == AUTOENCODER:

                        # Autoencoder Training
                        for batch in self._batch_generator():
                            summs, encoding, reconstruction, loss, _, step = sess.run(
                                [
                                    self.summs_train, self.encode, self.decode,
                                    self.loss_ae, self.train_ae, self.step
                                ],
                                feed_dict={
                                    self.input: batch[0],
                                    self.target: batch[1]
                                })
                            self._on_batch_finish(summs, loss, batch, encoding,
                                                  reconstruction)

                    else:

                        # Predictive and Denoising training
                        # Each example is a stack of three consecutive
                        # frames (i, i+1, i+2).
                        for batch_indexes in self._batch_permutation_generator(
                                len(ds) - 2):
                            batch = np.stack(
                                (ds[batch_indexes], ds[batch_indexes + 1],
                                 ds[batch_indexes + 2]))
                            if not full_set_blur:
                                # Large dataset: re-blur just this batch.
                                # NOTE(review): _get_blur_sigma() is called
                                # three times here — confirm it returns the
                                # same value within one batch.
                                batch = np.stack(
                                    (inp.apply_gaussian(
                                        ds[batch_indexes],
                                        sigma=self._get_blur_sigma()),
                                     inp.apply_gaussian(
                                         ds[batch_indexes + 1],
                                         sigma=self._get_blur_sigma()),
                                     inp.apply_gaussian(
                                         ds[batch_indexes + 2],
                                         sigma=self._get_blur_sigma())))

                            summs, loss, _ = sess.run([
                                self.summs_train, self.loss_total, self._train
                            ],
                                                      feed_dict={
                                                          self.inputs: batch,
                                                          self.targets: batch
                                                      })
                            self._on_batch_finish(summs, loss)

                    self._on_epoch_finish(current_epoch, start, sess)
                self._on_training_finish(sess)
            except KeyboardInterrupt:
                self._on_training_abort(sess)