Ejemplo n.º 1
0
def search_learning_rate(lrs=None, epochs=500):
  """Grid-search FLAGS.learning_rate by training a fresh model per candidate.

  For every learning rate a new ``model_class()`` is trained for ``epochs``
  epochs. The per-epoch accuracy curves are pickled to
  ``search_learning_rate<epochs>.txt`` and plotted. The learning rate whose
  curve reaches the lowest value is reported as best (lower is better here).

  Args:
    lrs: learning rates to try. Defaults to [0.001, 0.0004, 0.0001, 0.00003].
    epochs: number of epochs each model is trained for.

  Side effects:
    Mutates the global FLAGS (suffix, learning_rate) and writes a pickle file.

  Raises:
    ValueError: if ``lrs`` is empty (there would be no model to report on).
  """
  # A fresh list per call instead of a shared mutable default.
  if lrs is None:
    lrs = [0.001, 0.0004, 0.0001, 0.00003]
  if len(lrs) == 0:
    raise ValueError('search_learning_rate: lrs must not be empty')

  FLAGS.suffix = 'grid_lr'
  ut.print_info('START: search_learning_rate', color=31)

  best_result, best_args = None, None
  result_summary, result_list = [], []

  for lr in lrs:
    ut.print_info('STEP: search_learning_rate', color=31)
    FLAGS.learning_rate = lr
    model = model_class()
    meta, accuracy_by_epoch = model.train(epochs)
    result_list.append((ut.to_file_name(meta), accuracy_by_epoch))
    best_accuracy = np.min(accuracy_by_epoch)
    result_summary.append('\n\r lr:%2.5f \tq:%.2f' % (lr, best_accuracy))
    # Lower is better: keep the run with the smallest minimum.
    if best_result is None or best_result > best_accuracy:
      best_result = best_accuracy
      best_args = lr

  meta = {'suf': 'grid_lr_bs', 'e': epochs, 'lrs': lrs, 'acu': best_result,
          'bs': FLAGS.batch_size, 'h': model.get_layer_info()}
  # Close the dump file deterministically instead of leaking the handle.
  with open('search_learning_rate%d.txt' % epochs, "wb") as dump_file:
    pickle.dump(result_list, dump_file)
  ut.plot_epoch_progress(meta, result_list)
  print(''.join(result_summary))
  ut.print_info('BEST Q: %d IS ACHIEVED FOR LR: %f' % (best_result, best_args), 36)
Ejemplo n.º 2
0
    def _register_epoch(self, epoch, total_epochs, elapsed, sess):
        """Finish the bookkeeping for one training epoch.

        Saves a checkpoint on ``FLAGS.save_every`` stopping points. On
        ``FLAGS.save_encodings_every`` stopping points it also dumps
        encodings/reconstructions to ``FLAGS.save_path`` and plots them.
        It then records the epoch's error metric in
        ``self._stats['epoch_accuracy']``, prints the epoch summary, and resets
        the per-epoch stats unless this was the last epoch.

        Args:
            epoch: zero-based index of the epoch just finished.
            total_epochs: number of epochs in this training run.
            elapsed: forwarded unchanged to ``print_epoch_info``.
            sess: session used for checkpointing and evaluation.
        """
        checkpoint_due = is_stopping_point(epoch, total_epochs,
                                           FLAGS.save_every)
        if checkpoint_due:
            self._saver.save(sess, self.get_checkpoint_path())

        # Total loss normalised by the batch shape and epoch size, then mapped
        # through sqrt and scaled by 100000 to give the reported metric.
        normalized_loss = (self._epoch_stats['total_loss'] /
                           np.prod(self._batch_shape) / FLAGS.epoch_size)
        accuracy = 100000 * np.sqrt(normalized_loss)

        if is_stopping_point(epoch, total_epochs, FLAGS.save_encodings_every):
            digest = self.evaluate(sess, take=self.MAX_IMAGES)
            enc = np.asarray(digest[0])
            rec = np.asarray(digest[1])
            blu = np.asarray(digest[2][:self.MAX_IMAGES])
            data = {'enc': enc, 'rec': rec, 'blu': blu}

            meta = {'suf': 'encodings',
                    'e': '%06d' % int(self.get_past_epochs()),
                    'er': int(accuracy)}
            np.save(ut.to_file_name(meta, FLAGS.save_path), data)
            vis.plot_encoding_crosssection(enc, FLAGS.save_path, meta, blu, rec)

        self._stats['epoch_accuracy'].append(accuracy)
        self.print_epoch_info(accuracy, epoch, total_epochs, elapsed)
        is_last_epoch = epoch + 1 == total_epochs
        if not is_last_epoch:
            self._epoch_stats = self._get_stats_template()
Ejemplo n.º 3
0
  def _register_epoch(self, epoch, total_epochs, elapsed, sess):
    """Finish the bookkeeping for one training epoch.

    Checkpoints on FLAGS.save_every stopping points. On
    FLAGS.save_encodings_every stopping points it also saves
    encodings/reconstructions and draws a cross-section plot. It then appends
    the epoch metric to self._stats['epoch_accuracy'], prints the epoch
    summary, and resets the per-epoch stats unless this was the final epoch.

    Args:
      epoch: zero-based index of the finished epoch.
      total_epochs: number of epochs in this run.
      elapsed: forwarded unchanged to print_epoch_info.
      sess: session used for checkpointing and evaluation.
    """
    if is_stopping_point(epoch, total_epochs, FLAGS.save_every):
      self._saver.save(sess, self.get_checkpoint_path())

    normalized_loss = self._epoch_stats['total_loss'] / np.prod(self._batch_shape) / FLAGS.epoch_size
    accuracy = 100000 * np.sqrt(normalized_loss)

    if is_stopping_point(epoch, total_epochs, FLAGS.save_encodings_every):
      digest = self.evaluate(sess, take=self.MAX_IMAGES)
      raw_encodings = digest[0]
      data = dict(
        enc=np.asarray(raw_encodings),
        rec=np.asarray(digest[1]),
        blu=np.asarray(digest[2][:self.MAX_IMAGES]),
      )
      past_epochs = '%06d' % int(self.get_past_epochs())
      meta = {'suf': 'encodings', 'e': past_epochs, 'er': int(accuracy)}
      target_file = ut.to_file_name(meta, FLAGS.save_path)
      np.save(target_file, data)
      # The plot receives the raw (non-asarray'd) encodings from the digest.
      vis.visualize_encoding_cross(raw_encodings, FLAGS.save_path, meta, data['blu'], data['rec'])

    self._stats['epoch_accuracy'].append(accuracy)
    self.print_epoch_info(accuracy, epoch, total_epochs, elapsed)
    if epoch + 1 == total_epochs:
      return
    self._epoch_stats = self._get_stats_template()
Ejemplo n.º 4
0
def visualize_encoding(encodings, folder=None, meta=None, original=None, reconstruction=None):
  """Plot a PCA projection of ``encodings``, optionally beside image pairs.

  Args:
    encodings: encodings to project with ``manual_pca`` and plot.
    folder: if given, the figure file name is built from ``meta`` with
      ``postfix='pca'`` in this folder (jpg). Otherwise no file name is passed on.
    meta: file-name metadata. It is copied, so neither the caller's dict nor a
      shared default is modified.
    original: optional image batch. When given, it is stitched with
      ``reconstruction`` and drawn in a subplot next to the encodings.
    reconstruction: reconstructed images; must match ``original`` in length.
  """
  meta = dict(meta) if meta else {}
  # Originals with max < 10 are treated as normalised (presumably [0, 1]) and
  # rescaled to uint8. Skipped for None: np.max(None) < 10 raises TypeError,
  # which previously made the no-original branch unreachable.
  if original is not None and np.max(original) < 10:
    original = (original * 255).astype(np.uint8)
  file_path = None
  if folder:
    meta['postfix'] = 'pca'
    file_path = ut.to_file_name(meta, folder, 'jpg')
  encodings = manual_pca(encodings)

  if original is None:
    visualize_encodings(encodings, file_name=file_path)
    return

  assert len(original) == len(reconstruction)
  fig = plt.figure()

  column_picture, height = stitch_images(original, reconstruction)
  # 122: right half of a 1x2 grid (<= 3 encoding dims);
  # 155: last cell of a 1x5 grid otherwise.
  subplot, proportion = (122, 1) if encodings.shape[1] <= 3 else (155, 3)
  picture = reshape_images(column_picture, height, proportion=proportion)
  if picture.shape[-1] == 1:
    # Single-channel images: drop the channel axis so imshow accepts them.
    picture = picture.squeeze()
  plt.subplot(subplot).imshow(picture)

  visualize_encodings(encodings, file_name=file_path, fig=fig, grid=(3, 5), skip_every=5)
Ejemplo n.º 5
0
def search_batch_size(bss=None, strides=None, epochs=500):
  """Grid-search batch size and stride, training a fresh model per pair.

  Every (batch size, stride) combination sets FLAGS.batch_size / FLAGS.stride
  and trains a new ``model_class()`` for ``epochs`` epochs. The curves are
  pickled to ``search_batch_size<epochs>.txt`` and plotted. The pair whose
  accuracy curve reaches the lowest value is reported as best.

  Args:
    bss: batch sizes to try. Defaults to [50].
    strides: strides to try. Defaults to [1, 2, 5, 20].
    epochs: number of epochs each model is trained for.

  Raises:
    ValueError: if ``bss`` or ``strides`` is empty (no run would happen).
  """
  # Fresh lists per call instead of shared mutable defaults.
  if bss is None:
    bss = [50]
  if strides is None:
    strides = [1, 2, 5, 20]
  if len(bss) == 0 or len(strides) == 0:
    raise ValueError('search_batch_size: bss and strides must not be empty')

  FLAGS.suffix = 'grid_bs'
  ut.print_info('START: search_batch_size', color=31)
  best_result, best_args = None, None
  result_summary, result_list = [], []

  print(bss)
  for bs in bss:
    for stride in strides:
      ut.print_info('STEP: search_batch_size %d %d' % (bs, stride), color=31)
      FLAGS.batch_size = bs
      FLAGS.stride = stride
      model = model_class()
      start = dt.now()
      meta, accuracy_by_epoch = model.train(epochs)
      meta['str'] = stride
      # total_seconds(): timedelta.seconds alone drops whole days.
      meta['t'] = int((dt.now() - start).total_seconds())
      # NOTE(review): [22:] strips a fixed-length prefix of the generated
      # file name -- confirm what that prefix is.
      result_list.append((ut.to_file_name(meta)[22:], accuracy_by_epoch))
      best_accuracy = np.min(accuracy_by_epoch)
      result_summary.append('\n\r bs:%d \tst:%d \tq:%.2f' % (bs, stride, best_accuracy))
      # Lower is better.
      if best_result is None or best_result > best_accuracy:
        best_result = best_accuracy
        best_args = (bs, stride)

  meta = {'suf': 'grid_batch_bs', 'e': epochs, 'acu': best_result,
          'h': model.get_layer_info()}
  with open('search_batch_size%d.txt' % epochs, "wb") as dump_file:
    pickle.dump(result_list, dump_file)
  ut.plot_epoch_progress(meta, result_list)
  print(''.join(result_summary))

  ut.print_info('BEST Q: %d IS ACHIEVED FOR bs, st: %d %d' % (best_result, best_args[0], best_args[1]), 36)
Ejemplo n.º 6
0
def visualize_encoding_cross(encodings, folder=None, meta=None, original=None, reconstruction=None, interactive=False):
  """Plot a cross-section view of PCA-projected encodings.

  Args:
    encodings: encodings to project with ``manual_pca``.
    folder: if given, the output file name is built from ``meta`` with
      ``postfix='cross'`` in this folder (jpg).
    meta: file-name metadata. It is copied and never modified.
    original: optional image batch. When given, it is stitched with
      ``reconstruction`` and drawn on the subplot returned by
      ``visualize_cross_section_with_reco``.
    reconstruction: reconstructed images; must match ``original`` in length.
    interactive: if True, show the figure with plt.show() instead of saving.
  """
  meta = dict(meta) if meta else {}
  # Guard against None: np.max(None) < 10 raises TypeError, which made the
  # no-original branch below unreachable.
  if original is not None and np.max(original) < 10:
    # Callers are expected to pass uint8 images already; rescale as fallback.
    print('should not happen')
    original = (original * 255).astype(np.uint8)
  file_path = None

  if folder:
    meta['postfix'] = 'cross'
    file_path = ut.to_file_name(meta, folder, 'jpg')
  encodings = manual_pca(encodings)

  # No figure is created here; None is handed to the plotting helpers.
  fig = None
  if original is not None:
    assert len(original) == len(reconstruction)
    subplot, proportion = visualize_cross_section_with_reco(encodings, fig=fig)
    column_picture, height = stitch_images(original, reconstruction)
    picture = reshape_images(column_picture, height, proportion=proportion)
    if picture.shape[-1] == 1:
      # Single-channel images: drop the channel axis for imshow.
      picture = picture.squeeze()
    subplot.imshow(picture)
  else:
    visualize_cross_section(encodings, fig=fig)
  if not interactive:
    save_fig(file_path, fig)
  else:
    plt.show()
Ejemplo n.º 7
0
  def train(self, epochs_to_train=5):
    """Train the model for ``epochs_to_train`` epochs in a new TF session.

    Builds the model and restores the checkpoint when FLAGS.load_state is set
    and a checkpoint exists. For every batch of every epoch it runs one
    session step, feeding the batch's inputs as both the input and the
    reconstruction target. Progress is recorded through the _register_* hooks.

    Returns:
      (meta, epoch_accuracy): the dict from _register_training() and
      self._stats['epoch_accuracy'].
    """
    meta = self.get_meta()
    ut.print_time('train started: \n%s' % ut.to_file_name(meta))
    ut.configure_folders(FLAGS, meta)

    self.fetch_datasets(self._activation)
    self.build_model()
    self._register_training_start()

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      self._saver = tf.train.Saver()

      if FLAGS.load_state and os.path.exists(self.get_checkpoint_path()):
        self._saver.restore(sess, self.get_checkpoint_path())
        ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31)

      # MAIN LOOP -- range (not the Python-2-only xrange) runs on 2 and 3.
      for current_epoch in range(epochs_to_train):
        start = time.time()
        for batch in self._get_epoch_dataset():
          # Only the loss is used; the encode/decode outputs are fetched but
          # discarded, and _train/_step are presumably update ops.
          _, _, loss, _, _ = sess.run(
            [self._encode, self._decode, self._reco_loss, self._train, self._step],
            feed_dict={self._input: batch[0], self._reconstruction: batch[0]})
          self._register_batch(loss)
        self._register_epoch(current_epoch, epochs_to_train, time.time() - start, sess)
      self._writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
      meta = self._register_training()
    return meta, self._stats['epoch_accuracy']
Ejemplo n.º 8
0
 def _register_training(self):
   """Summarise a finished run and return its metadata.

   Takes the lowest value in self._stats['epoch_accuracy'] as the best
   quality. It stores that value, truncated to int, under 'acu' and the
   past-epoch count under 'e' in a fresh self.get_meta() dict, logs it, and
   returns the dict.
   """
   lowest = np.min(self._stats['epoch_accuracy'])
   summary = self.get_meta()
   summary.update(acu=int(lowest), e=self.get_past_epochs())
   run_name = ut.to_file_name(summary)
   ut.print_time('Best Quality: %f for %s' % (lowest, run_name))
   return summary
Ejemplo n.º 9
0
def visualize_encoding(encodings, folder=None, meta=None, original=None, reconstruction=None):
  """Plot a PCA projection of ``encodings``, optionally beside image pairs.

  Args:
    encodings: encodings to project with ``manual_pca`` and plot.
    folder: if given, the figure file name is built from ``meta`` with
      ``postfix='pca'`` in this folder (jpg).
    meta: file-name metadata. It is copied, so neither the caller's dict nor a
      shared default is modified.
    original: optional image batch. When given, it is stitched with
      ``reconstruction`` and shown in a titled subplot next to the encodings.
    reconstruction: reconstructed images; must match ``original`` in length.
  """
  meta = dict(meta) if meta else {}
  # Originals with max < 10 are treated as normalised (presumably [0, 1]) and
  # rescaled to uint8. Guarded because np.max(None) < 10 raises TypeError,
  # which previously made the no-original branch unreachable.
  if original is not None and np.max(original) < 10:
    original = (original * 255).astype(np.uint8)
  file_path = None
  if folder:
    meta['postfix'] = 'pca'
    file_path = ut.to_file_name(meta, folder, 'jpg')
  encodings = manual_pca(encodings)

  if original is None:
    visualize_encodings(encodings, file_name=file_path)
    return

  assert len(original) == len(reconstruction)
  fig = get_figure()

  column_picture, height = _stitch_images(original, reconstruction)
  # 122: right half of a 1x2 grid (<= 3 encoding dims);
  # 155: last cell of a 1x5 grid otherwise.
  subplot, proportion = (122, 1) if encodings.shape[1] <= 3 else (155, 3)
  picture = _reshape_column_image(column_picture, height, proportion=proportion)
  if picture.shape[-1] == 1:
    picture = picture.squeeze()
  # Create/look up the axes once instead of calling plt.subplot three times.
  axes = plt.subplot(subplot)
  axes.set_title("Original/reconstruction")
  axes.imshow(picture)
  axes.axis('off')

  visualize_encodings(encodings, file_name=file_path, fig=fig, grid=(3, 5), skip_every=5)
Ejemplo n.º 10
0
    def train(self, epochs_to_train=5):
        """Train with the clamped-encoding scheme for ``epochs_to_train`` epochs.

        Each batch is processed in three session stages:
          1. Encode the batch forward, then clamp the encoding with ``_clamp``
             using the batch's filter row.
          2. Assign the clamped encoding and run the decoder step. This
             fetches the reconstruction, decoder loss and the gradient w.r.t.
             the clamped encoding.
          3. Convert that gradient with ``_declamp_grad`` and run the encoder
             step, feeding ``encoding - declamped_grad`` as ``self._encoding``.

        Returns:
            (meta, epoch_accuracy): the dict from _register_training() and
            self._stats['epoch_accuracy'].
        """
        meta = self.get_meta()
        ut.print_time('train started: \n%s' % ut.to_file_name(meta))
        ut.configure_folders(FLAGS, meta)

        self._dataset, self._filters = self.fetch_datasets(self._activation)
        self.build_model()
        self._register_training_start()

        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            self._saver = tf.train.Saver()

            if FLAGS.load_state and os.path.exists(self.get_checkpoint_path()):
                self._saver.restore(sess, self.get_checkpoint_path())
                ut.print_info('Restored requested. Previous epoch: %d' %
                              self.get_past_epochs(),
                              color=31)

            # MAIN LOOP -- range (not the Python-2-only xrange).
            for current_epoch in range(epochs_to_train):
                feed, permutation = self._get_epoch_dataset()
                for batch in feed:
                    # `batch_filter` (formerly `filter`, which shadowed the
                    # builtin) is the first filter row. The assert checks that
                    # the first and last rows agree in column 0, presumably so
                    # the whole batch shares one filter.
                    batch_filter = batch[1][0]
                    assert batch[1][0, 0] == batch[1][-1, 0]

                    # 1.1 encode forward
                    encoding, = sess.run([self._encode],
                                         feed_dict={self._input: batch[0]})
                    # 1.2 clamp
                    clamped_enc, vae_grad = _clamp(encoding, batch_filter)

                    # 2.1 decode forward + backward on the clamped encoding
                    sess.run(self._assign_clamped,
                             feed_dict={self._clamped: clamped_enc})
                    reconstruction, loss, clamped_gradient, _ = sess.run(
                        [
                            self._decode, self._decoder_loss,
                            self._clamped_grad, self._train_decoder
                        ],
                        feed_dict={
                            self._clamped: clamped_enc,
                            self._reconstruction: batch[0]
                        })

                    # 2.2 prepare the gradient for the encoder
                    declamped_grad = _declamp_grad(vae_grad, clamped_gradient,
                                                   batch_filter)
                    # 3.0 encode backward path (fetched values are unused)
                    sess.run([self._train_encoder, self._step],
                             feed_dict={
                                 self._input: batch[0],
                                 self._encoding: encoding - declamped_grad
                             })

                    self._register_batch(batch, encoding, reconstruction, loss)
                # NOTE(review): `permutation` goes in the position other
                # _register_epoch variants name `elapsed` -- confirm this
                # class's signature.
                self._register_epoch(current_epoch, epochs_to_train,
                                     permutation, sess)
            self._writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
            meta = self._register_training()
        return meta, self._stats['epoch_accuracy']
Ejemplo n.º 11
0
 def _register_training(self):
     """Summarise a finished run, close the summary writer, return metadata.

     The lowest value in self._stats['epoch_accuracy'] is treated as the best
     quality. It is stored, truncated to int, under 'acu' next to the
     past-epoch count 'e' in a fresh self.get_meta() dict, and then logged.
     Finally self.summary_writer is closed and the dict is returned.
     """
     # NOTE(review): this relies on self.summary_writer being set elsewhere;
     # the visible train() variants assign self._writer instead -- confirm.
     lowest = np.min(self._stats['epoch_accuracy'])
     summary = self.get_meta()
     summary['acu'] = int(lowest)
     summary['e'] = self.get_past_epochs()
     run_name = ut.to_file_name(summary)
     ut.print_time('Best Quality: %f for %s' % (lowest, run_name))
     self.summary_writer.close()
     return summary
Ejemplo n.º 12
0
def plot_reconstruction(original, reconstruction, meta=None, interactive=False):
  """Show original and reconstructed images side by side.

  Args:
    original: images passed to ``stitch_side_by_side`` as the left side.
    reconstruction: images passed to ``stitch_side_by_side`` as the right side.
    meta: file-name metadata used when saving. Defaults to
      ``{'debug': 'true'}``, built fresh on each call so no state can leak
      between calls through a shared mutable default.
    interactive: if True, redraw the current figure and briefly run the GUI
      event loop instead of saving a jpg to FLAGS.save_path.
  """
  if meta is None:
    meta = {'debug': 'true'}
  picture = stitch_side_by_side(original, reconstruction)
  plt.imshow(picture)
  if interactive:
    plt.draw()
    # A short pause lets the GUI process the redraw without blocking.
    plt.pause(0.001)
  else:
    file_path = ut.to_file_name(meta, FLAGS.save_path, 'jpg')
    save_fig(file_path)
Ejemplo n.º 13
0
def get_template(name: str) -> Toml:
    """Load a template from the user's template directory.

    ``name`` is turned into a file name with ``utils.to_file_name`` and looked
    up under ``USER_CONFIG_DIR / "templates"``.

    Args:
        name: template name as given by the user.

    Returns:
        The TOML document parsed by ``toml.load``.

    Raises:
        typer.Exit: with exit code 1 if the template file does not exist.
    """
    # TODO: support loading templates from URLs.
    file_name = utils.to_file_name(name)
    # NOTE(review): the path is resolved but not checked to stay inside the
    # templates directory; confirm utils.to_file_name strips path separators.
    template_path = USER_CONFIG_DIR.joinpath("templates", file_name).resolve()
    # Debug output goes through the logger instead of a stray stdout print.
    log.debug("Template path: %s" % template_path)
    if not template_path.is_file():
        log.error("The template does not exist")
        raise typer.Exit(1)

    return toml.load(template_path)
Ejemplo n.º 14
0
  def train(self, epochs_to_train=5):
    """Train with the clamped-encoding scheme for ``epochs_to_train`` epochs.

    Each batch is processed in three session stages:
      1. Encode forward, then clamp the encoding with ``_clamp`` using the
         batch's filter row.
      2. Assign the clamped encoding and run the decoder step. This fetches
         the reconstruction, decoder loss and gradient w.r.t. the clamped
         encoding.
      3. Convert that gradient with ``_declamp_grad`` and run the encoder step,
         feeding ``encoding - declamped_grad`` as ``self._encoding``.

    Returns:
      (meta, epoch_accuracy): the dict from _register_training() and
      self._stats['epoch_accuracy'].
    """
    meta = self.get_meta()
    ut.print_time('train started: \n%s' % ut.to_file_name(meta))
    ut.configure_folders(FLAGS, meta)

    self._dataset, self._filters = self.fetch_datasets(self._activation)
    self.build_model()
    self._register_training_start()

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      self._saver = tf.train.Saver()

      if FLAGS.load_state and os.path.exists(self.get_checkpoint_path()):
        self._saver.restore(sess, self.get_checkpoint_path())
        ut.print_info('Restored requested. Previous epoch: %d' % self.get_past_epochs(), color=31)

      # MAIN LOOP -- range (not the Python-2-only xrange).
      for current_epoch in range(epochs_to_train):
        feed, permutation = self._get_epoch_dataset()
        for batch in feed:
          # `batch_filter` (formerly `filter`, which shadowed the builtin) is
          # the first filter row. The assert compares column 0 of the first and
          # last rows, presumably so the whole batch shares one filter.
          batch_filter = batch[1][0]
          assert batch[1][0, 0] == batch[1][-1, 0]

          # 1.1 encode forward
          encoding, = sess.run([self._encode], feed_dict={self._input: batch[0]})
          # 1.2 clamp
          clamped_enc, vae_grad = _clamp(encoding, batch_filter)

          # 2.1 decode forward + backward on the clamped encoding
          sess.run(self._assign_clamped, feed_dict={self._clamped: clamped_enc})
          reconstruction, loss, clamped_gradient, _ = sess.run(
            [self._decode, self._decoder_loss, self._clamped_grad, self._train_decoder],
            feed_dict={self._clamped: clamped_enc, self._reconstruction: batch[0]})

          # 2.2 prepare the gradient for the encoder
          declamped_grad = _declamp_grad(vae_grad, clamped_gradient, batch_filter)
          # 3.0 encode backward path (fetched values are unused)
          sess.run(
            [self._train_encoder, self._step],
            feed_dict={self._input: batch[0], self._encoding: encoding - declamped_grad})

          self._register_batch(batch, encoding, reconstruction, loss)
        # NOTE(review): `permutation` goes in the position other
        # _register_epoch variants name `elapsed` -- confirm the signature.
        self._register_epoch(current_epoch, epochs_to_train, permutation, sess)
      self._writer = tf.train.SummaryWriter(FLAGS.logdir, sess.graph)
      meta = self._register_training()
    return meta, self._stats['epoch_accuracy']
Ejemplo n.º 15
0
def search_batch_size(bss=None, strides=None, epochs=500):
    """Grid-search batch size and stride, training a fresh model per pair.

    Every (batch size, stride) combination sets FLAGS.batch_size /
    FLAGS.stride and trains a new ``model_class()`` for ``epochs`` epochs. The
    curves are pickled to ``search_batch_size<epochs>.txt`` and plotted. The
    pair whose accuracy curve reaches the lowest value is reported as best.

    Args:
        bss: batch sizes to try. Defaults to [50].
        strides: strides to try. Defaults to [1, 2, 5, 20].
        epochs: number of epochs each model is trained for.

    Raises:
        ValueError: if ``bss`` or ``strides`` is empty (no run would happen).
    """
    # Fresh lists per call instead of shared mutable defaults.
    if bss is None:
        bss = [50]
    if strides is None:
        strides = [1, 2, 5, 20]
    if len(bss) == 0 or len(strides) == 0:
        raise ValueError(
            'search_batch_size: bss and strides must not be empty')

    FLAGS.suffix = 'grid_bs'
    ut.print_info('START: search_batch_size', color=31)
    best_result, best_args = None, None
    result_summary, result_list = [], []

    print(bss)
    for bs in bss:
        for stride in strides:
            ut.print_info('STEP: search_batch_size %d %d' % (bs, stride),
                          color=31)
            FLAGS.batch_size = bs
            FLAGS.stride = stride
            model = model_class()
            start = dt.now()
            meta, accuracy_by_epoch = model.train(epochs)
            meta['str'] = stride
            # total_seconds(): timedelta.seconds alone drops whole days.
            meta['t'] = int((dt.now() - start).total_seconds())
            # NOTE(review): [22:] strips a fixed-length prefix of the
            # generated file name -- confirm what that prefix is.
            result_list.append((ut.to_file_name(meta)[22:], accuracy_by_epoch))
            best_accuracy = np.min(accuracy_by_epoch)
            result_summary.append('\n\r bs:%d \tst:%d \tq:%.2f' %
                                  (bs, stride, best_accuracy))
            # Lower is better.
            if best_result is None or best_result > best_accuracy:
                best_result = best_accuracy
                best_args = (bs, stride)

    meta = {
        'suf': 'grid_batch_bs',
        'e': epochs,
        'acu': best_result,
        'h': model.get_layer_info()
    }
    with open('search_batch_size%d.txt' % epochs, "wb") as dump_file:
        pickle.dump(result_list, dump_file)
    ut.plot_epoch_progress(meta, result_list)
    print(''.join(result_summary))

    ut.print_info(
        'BEST Q: %d IS ACHIEVED FOR bs, st: %d %d' %
        (best_result, best_args[0], best_args[1]), 36)
Ejemplo n.º 16
0
def search_learning_rate(lrs=None, epochs=500):
    """Grid-search FLAGS.learning_rate by training a fresh model per candidate.

    For every learning rate a new ``model_class()`` is trained for ``epochs``
    epochs. The per-epoch accuracy curves are pickled to
    ``search_learning_rate<epochs>.txt`` and plotted. The learning rate whose
    curve reaches the lowest value is reported as best (lower is better here).

    Args:
        lrs: learning rates to try. Defaults to
            [0.001, 0.0004, 0.0001, 0.00003].
        epochs: number of epochs each model is trained for.

    Side effects:
        Mutates the global FLAGS (suffix, learning_rate) and writes a pickle
        file.

    Raises:
        ValueError: if ``lrs`` is empty (there would be no model to report on).
    """
    # A fresh list per call instead of a shared mutable default.
    if lrs is None:
        lrs = [
            0.001,
            0.0004,
            0.0001,
            0.00003,
        ]
    if len(lrs) == 0:
        raise ValueError('search_learning_rate: lrs must not be empty')

    FLAGS.suffix = 'grid_lr'
    ut.print_info('START: search_learning_rate', color=31)

    best_result, best_args = None, None
    result_summary, result_list = [], []

    for lr in lrs:
        ut.print_info('STEP: search_learning_rate', color=31)
        FLAGS.learning_rate = lr
        model = model_class()
        meta, accuracy_by_epoch = model.train(epochs)
        result_list.append((ut.to_file_name(meta), accuracy_by_epoch))
        best_accuracy = np.min(accuracy_by_epoch)
        result_summary.append('\n\r lr:%2.5f \tq:%.2f' % (lr, best_accuracy))
        # Lower is better: keep the run with the smallest minimum.
        if best_result is None or best_result > best_accuracy:
            best_result = best_accuracy
            best_args = lr

    meta = {
        'suf': 'grid_lr_bs',
        'e': epochs,
        'lrs': lrs,
        'acu': best_result,
        'bs': FLAGS.batch_size,
        'h': model.get_layer_info()
    }
    # Close the dump file deterministically instead of leaking the handle.
    with open('search_learning_rate%d.txt' % epochs, "wb") as dump_file:
        pickle.dump(result_list, dump_file)
    ut.plot_epoch_progress(meta, result_list)
    print(''.join(result_summary))
    ut.print_info(
        'BEST Q: %d IS ACHIEVED FOR LR: %f' % (best_result, best_args), 36)
Ejemplo n.º 17
0
def main(_=None, weight_init=tf.random_normal, activation_f=tf.nn.sigmoid, data_min=0, data_scale=1.0, epochs=50,
         learning_rate=0.01, prefix=None):
    """Train a prettytensor MNIST autoencoder and return its best test loss.

    Args:
        _: unused positional argument (presumably argv from tf.app.run --
            confirm).
        weight_init: initializer forwarded to ``decoder``.
        activation_f: default ``activation_fn`` for the prettytensor layers.
        data_min: forwarded as ``min`` to ``ut.mnist_select_n_classes`` for
            both splits.
        data_scale: forwarded as ``scale`` to ``ut.mnist_select_n_classes``.
        epochs: number of training epochs.
        learning_rate: Adam learning rate.
        prefix: stored under 'i' in the saved-run metadata.

    Returns:
        The lowest per-epoch value of ``pretty_loss`` on the test split. The
        log calls this "accuracy"/"quality". The search starts from a sentinel
        of 100000.
    """
    tf.reset_default_graph()
    input_placeholder  = tf.placeholder(tf.float32, [BATCH_SIZE, 28, 28, 1])
    output_placeholder = tf.placeholder(tf.float32, [BATCH_SIZE, 28, 28, 1])

    # Grab the data as numpy arrays.
    train_input, train_output = data_utils.mnist(training=True)
    test_input,  test_output  = data_utils.mnist(training=False)
    train_set = ut.mnist_select_n_classes(train_input, train_output, NUM_CLASSES, min=data_min, scale=data_scale)
    test_set  = ut.mnist_select_n_classes(test_input,  test_output,  NUM_CLASSES, min=data_min, scale=data_scale)
    # Autoencoder: the selected images serve as both input and target.
    train_input, train_output = train_set[0], train_set[0]
    test_input,  test_output  = test_set[0],  test_set[0]
    ut.print_info('train (min, max): (%f, %f)' % (np.min(train_set[0]), np.max(train_set[0])))
    # Fixed batch whose reconstructions are snapshotted during training.
    visual_inputs, visual_output = train_set[0][0:BATCH_SIZE], train_set[0][0:BATCH_SIZE]

    epoch_reconstruction = []

    EPOCH_SIZE = len(train_input) // BATCH_SIZE
    TEST_SIZE = len(test_input) // BATCH_SIZE

    assert_model(input_placeholder, output_placeholder, test_input, test_output, train_input, train_output, visual_inputs, visual_output)

    with pt.defaults_scope(activation_fn=activation_f):
        with pt.defaults_scope(phase=pt.Phase.train):
            with tf.variable_scope("model"):
                output_tensor = decoder(encoder(input_placeholder), weight_init=weight_init)

    pretty_loss = loss(output_tensor, output_placeholder)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train = pt.apply_optimizer(optimizer, losses=[pretty_loss])

    init = tf.initialize_all_variables()
    runner = pt.train.Runner(save_path=FLAGS.save_path)

    # Sentinel: an epoch only becomes the best if its loss is below this.
    best_q = 100000
    with tf.Session() as sess:
        sess.run(init)
        # Snapshot reconstructions roughly 40 times per run (and on the last
        # epoch). Loop-invariant, so it is computed once.
        snapshot_every = np.ceil(epochs / 40.0)
        # range (not the Python-2-only xrange) runs on both 2 and 3.
        for epoch in range(epochs):
            additional_info = ''

            if epoch % snapshot_every == 0 or epoch + 1 == epochs:
                reconstruct = sess.run([output_tensor, pretty_loss],
                                       {input_placeholder: visual_inputs, output_placeholder: visual_output})[0]
                epoch_reconstruction.append(reconstruct)
                additional_info += 'epoch:%d (min, max): (%f %f)' %(epoch, np.min(reconstruct), np.max(reconstruct))

            # Shuffle the training data.
            train_input, train_output = data_utils.permute_data(
                (train_input, train_output))

            runner.train_model(
                train,
                pretty_loss,
                EPOCH_SIZE,
                feed_vars=(input_placeholder, output_placeholder),
                feed_data=pt.train.feed_numpy(BATCH_SIZE, train_input, train_output),
                print_every=None
            )
            accuracy = runner.evaluate_model(
                pretty_loss,
                TEST_SIZE,
                feed_vars=(input_placeholder, output_placeholder),
                feed_data=pt.train.feed_numpy(BATCH_SIZE, test_input, test_output))
            ut.print_time('Accuracy after %2d/%d epoch %.2f; %s' % (epoch + 1, epochs, accuracy, additional_info))
            # Lower loss is better.
            if best_q > accuracy:
                best_q = accuracy

        save_params = {'suf': 'mn_basic', 'act': activation_f, 'e': epochs, 'opt': optimizer, 'lr': learning_rate,
                       'init': weight_init, 'acu': int(best_q), 'bs': BATCH_SIZE, 'h': HIDDEN_0_SIZE, 'i':prefix}
        ut.reconstruct_images_epochs(np.asarray(epoch_reconstruction), visual_output, save_params=save_params)

    ut.print_time('Best Quality: %f for %s' % (best_q, ut.to_file_name(save_params)))
    ut.reset_start_time()
    return best_q