Example 1
def main(_):
    # Parse config dict from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(FLAGS.config_paths,
                                         FLAGS.model_params)

    # Get tables to embed.
    query_records_dir = FLAGS.query_records_dir
    query_records = util.GetFilesRecursively(query_records_dir)

    target_records_dir = FLAGS.target_records_dir
    target_records = util.GetFilesRecursively(target_records_dir)

    height = config.data.raw_height
    width = config.data.raw_width
    mode = FLAGS.mode
    if mode == 'multi':
        # Generate videos where target set is composed of multiple videos.
        MultiImitationVideos(query_records, target_records, config, height,
                             width)
    elif mode == 'single':
        # Generate videos where target set is a single video.
        SingleImitationVideos(query_records, target_records, config, height,
                              width)
    elif mode == 'same':
        # Generate videos where target set is same as query, from another view.
        SameSequenceVideos(query_records, config, height, width)
    else:
        raise ValueError('Unknown mode %s' % mode)
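Example 1 reads everything through module-level FLAGS, which the snippet does not define. A minimal sketch of the flag declarations it implies, assuming TF 1.x's tf.app.flags; the defaults and help strings here are guesses, not the repo's actual values:

import tensorflow as tf  # TF 1.x

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('config_paths', '',
                           'Comma-separated list of YAML config files.')
tf.app.flags.DEFINE_string('model_params', '',
                           'Config overrides passed on the command line.')
tf.app.flags.DEFINE_string('query_records_dir', '',
                           'Directory holding query tfrecords.')
tf.app.flags.DEFINE_string('target_records_dir', '',
                           'Directory holding target tfrecords.')
tf.app.flags.DEFINE_string('mode', 'multi',
                           "One of 'multi', 'single', or 'same'.")

if __name__ == '__main__':
  tf.app.run()  # Parses flags, then calls main(argv).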
Example 2
def get_labeled_tables(config):
  """Gets either labeled test or validation tables, based on flags."""
  # Get a list of filenames corresponding to labeled data.
  mode = FLAGS.mode
  if mode == 'validation':
    labeled_tables = util.GetFilesRecursively(config.data.labeled.validation)
  elif mode == 'test':
    labeled_tables = util.GetFilesRecursively(config.data.labeled.test)
  else:
    raise ValueError('Unknown dataset: %s' % mode)
  return labeled_tables
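Every example expands a directory into a flat list of record files via util.GetFilesRecursively, whose implementation is not shown. A minimal sketch of what such a helper might look like using os.walk (the repo's version may instead use tf.gfile so that Cloud Storage paths also work):

import os

def GetFilesRecursively(topdir):
  """Returns a list of all files found under topdir, recursively."""
  files = []
  for dirpath, _, filenames in os.walk(topdir):
    for filename in filenames:
      files.append(os.path.join(dirpath, filename))
  if not files:
    raise ValueError('No files found under %s.' % topdir)
  return files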
Example 3
  def train(self):
    """Runs training."""
    # Get a list of training tfrecords.
    config = self._config
    training_dir = config.data.training
    training_records = util.GetFilesRecursively(training_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined training input function.
    train_input_fn = self.construct_input_fn(
        training_records, is_training=True)

    # Create the estimator.
    estimator = self._build_estimator(is_training=True)

    train_hooks = None
    if config.use_tpu:
      # TPU training initializes pretrained weights using a custom train hook.
      train_hooks = []
      if tf.train.latest_checkpoint(self._logdir) is None:
        train_hooks.append(
            InitFromPretrainedCheckpointHook(
                config[config.embedder_strategy].pretrained_checkpoint))

    # Run training.
    estimator.train(input_fn=train_input_fn, hooks=train_hooks,
                    steps=config.learning.max_step)
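On TPU, train seeds the very first run from a pretrained checkpoint via InitFromPretrainedCheckpointHook, but only when tf.train.latest_checkpoint finds nothing in the log directory, so later restarts restore their own weights instead. The hook class itself is not shown; a minimal sketch of the usual pattern, assuming a tf.train.SessionRunHook built on tf.train.init_from_checkpoint, with the map-everything assignment map as a guess:

import tensorflow as tf  # TF 1.x

class InitFromPretrainedCheckpointHook(tf.train.SessionRunHook):
  """Initializes graph variables from a pretrained checkpoint at startup."""

  def __init__(self, pretrained_checkpoint):
    self._pretrained_checkpoint = pretrained_checkpoint

  def begin(self):
    # Called before the session is created: overwrite the initializers of
    # all same-named variables with values from the pretrained checkpoint.
    tf.train.init_from_checkpoint(self._pretrained_checkpoint, {'/': '/'})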
Example 4
  def evaluate(self):
    """Runs `Estimator` validation."""
    config = self._config

    # Get a list of validation tfrecords.
    validation_dir = config.data.validation
    validation_records = util.GetFilesRecursively(validation_dir)

    # Define batch size.
    self._batch_size = config.data.batch_size

    # Create a subclass-defined validation input function.
    validation_input_fn = self.construct_input_fn(
        validation_records, is_training=False)

    # Create the estimator.
    estimator = self._build_estimator(is_training=False)

    # Run validation.
    eval_batch_size = config.data.batch_size
    num_eval_samples = config.val.num_eval_samples
    num_eval_batches = num_eval_samples // eval_batch_size
    estimator.evaluate(input_fn=validation_input_fn,
                       steps=num_eval_batches)
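Note that evaluate rounds the step count down to whole batches, so up to batch_size - 1 samples are silently skipped each pass: with num_eval_samples = 1000 and batch_size = 64, it runs 15 batches and evaluates 960 samples.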
Example 5
def main(_):
    # Parse config dict from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(FLAGS.config_paths,
                                         FLAGS.model_params)
    num_views = config.data.num_views

    validation_records = util.GetFilesRecursively(config.data.validation)
    batch_size = config.data.batch_size

    checkpointdir = FLAGS.checkpointdir

    # If evaluating a specific checkpoint, do that.
    if FLAGS.checkpoint_iter:
        checkpoint_path = os.path.join(
            checkpointdir, 'model.ckpt-%s' % FLAGS.checkpoint_iter)
        evaluate_once(config, checkpointdir, validation_records,
                      checkpoint_path, batch_size, num_views)
    else:
        for checkpoint_path in tf.contrib.training.checkpoints_iterator(
                checkpointdir):
            evaluate_once(config, checkpointdir, validation_records,
                          checkpoint_path, batch_size, num_views)
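The two branches above give one-shot versus continuous evaluation: with --checkpoint_iter set, a single named checkpoint is scored and the process exits, while tf.contrib.training.checkpoints_iterator blocks until a new checkpoint file appears under checkpointdir and yields its path, so the loop keeps evaluating for as long as training keeps writing checkpoints. The second main below reuses the same checkpoint flags to drive a TensorBoard projector visualization instead.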
def main(_):
    """Runs main labeled eval loop."""
    # Parse config dict from yaml config files / command line flags.
    config = util.ParseConfigsToLuaTable(FLAGS.config_paths,
                                         FLAGS.model_params)

    # Choose an estimator based on training strategy.
    checkpointdir = FLAGS.checkpointdir
    checkpoint_path = os.path.join(
        checkpointdir, 'model.ckpt-%s' % FLAGS.checkpoint_iter)
    estimator = get_estimator(config, checkpointdir)

    # Get records to embed.
    validation_dir = FLAGS.embedding_records
    validation_records = util.GetFilesRecursively(validation_dir)

    sequences_to_data = {}
    for (view_embeddings, view_raw_image_strings,
         seqname) in estimator.inference(validation_records,
                                         checkpoint_path,
                                         config.data.embed_batch_size,
                                         num_sequences=FLAGS.num_sequences):
        sequences_to_data[seqname] = {
            'embeddings': view_embeddings,
            'images': view_raw_image_strings,
        }

    all_embeddings = np.zeros((0, config.embedding_size))
    all_ims = []
    all_seqnames = []

    num_embeddings = FLAGS.num_embed
    # Concatenate all views from all sequences into a big flat list.
    for seqname, data in sequences_to_data.items():
        embs = data['embeddings']
        ims = data['images']
        for v in range(config.data.num_views):
            for (emb, im) in zip(embs[v], ims[v]):
                all_embeddings = np.append(all_embeddings, [emb], axis=0)
                all_ims.append(im)
                all_seqnames.append(seqname)

    # Choose N indices uniformly from all images.
    random_indices = list(range(all_embeddings.shape[0]))
    random.shuffle(random_indices)
    viz_indices = random_indices[:num_embeddings]

    # Extract embs.
    viz_embs = np.array(all_embeddings[viz_indices])

    # Extract and decode ims.
    viz_ims = list(np.array(all_ims)[viz_indices])
    decoded_ims = []

    sprite_dim = FLAGS.sprite_dim
    for i, im in enumerate(viz_ims):
        if i % 100 == 0:
            print('Decoding image %d/%d.' % (i, num_embeddings))
        nparr_i = np.frombuffer(im, np.uint8)
        img_np = cv2.imdecode(nparr_i, 1)
        img_np = img_np[..., [2, 1, 0]]

        img_np = cv2.resize(img_np, (sprite_dim, sprite_dim))
        decoded_ims.append(img_np)
    decoded_ims = np.array(decoded_ims)

    # Write projector artifacts to the output directory.
    outdir = FLAGS.outdir

    # The embedding variable, which needs to be stored.
    # Note: this must be a Variable, not a Tensor!
    embedding_var = tf.Variable(viz_embs, name='viz_embs')

    with tf.Session() as sess:
        sess.run(embedding_var.initializer)
        summary_writer = tf.summary.FileWriter(outdir)
        config = projector.ProjectorConfig()
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name

        # Comment out the next two statements if you don't want sprites.
        embedding.sprite.image_path = os.path.join(outdir, 'sprite.png')
        embedding.sprite.single_image_dim.extend(
            [decoded_ims.shape[1], decoded_ims.shape[1]])

        projector.visualize_embeddings(summary_writer, config)
        saver = tf.train.Saver([embedding_var])
        saver.save(sess, os.path.join(outdir, 'model2.ckpt'), 1)

    sprite = images_to_sprite(decoded_ims)
    # cv2 expects BGR channel order; the decoded images are RGB.
    cv2.imwrite(os.path.join(outdir, 'sprite.png'), sprite[..., ::-1])
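images_to_sprite is not shown either. The TensorBoard projector expects a single sprite image with all thumbnails packed into a row-major square grid; a minimal sketch under that assumption:

import numpy as np

def images_to_sprite(images):
  """Tiles [N, H, W, 3] images into one row-major square sprite image."""
  n = int(np.ceil(np.sqrt(images.shape[0])))
  # Pad the batch with black images up to the next perfect square.
  pad = ((0, n ** 2 - images.shape[0]), (0, 0), (0, 0), (0, 0))
  images = np.pad(images, pad, mode='constant')
  h, w = images.shape[1], images.shape[2]
  # [n*n, h, w, 3] -> [n, n, h, w, 3] -> [n, h, n, w, 3] -> [n*h, n*w, 3]
  grid = images.reshape((n, n, h, w, 3)).transpose((0, 2, 1, 3, 4))
  return grid.reshape((n * h, n * w, 3))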