def test_get_simple_gaussian_embedder(self):
    """Checks keys and static shapes produced by the simple Gaussian embedder.

    Builds the embedder with a simple base model and Gaussian embeddings, runs
    it on a rank-3 float input, and verifies both the output dictionary
    (means, stddevs, samples) and the activation dictionary.
    """
    # Input tensor of shape [4, 2, 3] holding the values 1.0 .. 24.0.
    input_features = tf.constant([
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
        [[13.0, 14.0, 15.0], [16.0, 17.0, 18.0]],
        [[19.0, 20.0, 21.0], [22.0, 23.0, 24.0]],
    ])
    embedder_kwargs = dict(
        base_model_type=common.BASE_MODEL_TYPE_SIMPLE,
        embedding_type=common.EMBEDDING_TYPE_GAUSSIAN,
        num_embedding_components=3,
        embedding_size=16,
        num_embedding_samples=32,
        is_training=True,
        weight_max_norm=0.0,
    )
    embedder_fn = models.get_embedder(**embedder_kwargs)
    outputs, activations = embedder_fn(input_features)

    # (output key, expected static shape) pairs; samples carry an extra axis
    # of size num_embedding_samples (32) before the embedding dimension.
    expected_output_shapes = (
        (common.KEY_EMBEDDING_MEANS, [4, 2, 3, 16]),
        (common.KEY_EMBEDDING_STDDEVS, [4, 2, 3, 16]),
        (common.KEY_EMBEDDING_SAMPLES, [4, 2, 3, 32, 16]),
    )
    self.assertCountEqual(outputs.keys(),
                          [key for key, _ in expected_output_shapes])
    for key, expected_shape in expected_output_shapes:
        self.assertAllEqual(outputs[key].shape.as_list(), expected_shape)

    activation_key = 'base_activations'
    self.assertCountEqual(activations.keys(), [activation_key])
    self.assertAllEqual(activations[activation_key].shape.as_list(),
                        [4, 2, 1024])
# Example #2
# 0
def main(_):
    """Runs inference and writes the embedding outputs to CSV files.

    Reads 2D keypoints with `read_inputs`, converts them into model inputs,
    builds the embedder configured by command-line flags, restores weights
    from `FLAGS.checkpoint_path` (the moving-average variables when
    `FLAGS.use_moving_average` is set), evaluates the embedder outputs once,
    and saves every available embedding tensor (means, stddevs, samples) to
    `<FLAGS.output_dir>/<key>.csv`. Each CSV row is one entry along the
    tensor's leading dimension, with all remaining dimensions flattened.

    Args:
      _: Unused positional argument (presumably argv passed by the app
        runner -- confirm against the entry point).
    """
    keypoint_profile_2d = keypoint_profiles.create_keypoint_profile_or_die(
        FLAGS.input_keypoint_profile_name_2d)

    g = tf.Graph()
    with g.as_default():
        keypoints_2d, keypoint_masks_2d = read_inputs(keypoint_profile_2d)

        model_inputs, _ = input_generator.create_model_input(
            keypoints_2d,
            keypoint_masks_2d=keypoint_masks_2d,
            keypoints_3d=None,
            model_input_keypoint_type=(
                common.MODEL_INPUT_KEYPOINT_TYPE_2D_INPUT),
            model_input_keypoint_mask_type=(
                FLAGS.model_input_keypoint_mask_type),
            keypoint_profile_2d=keypoint_profile_2d,
            # Fix seed for determinism.
            seed=1)

        embedder_fn = models.get_embedder(
            base_model_type=FLAGS.base_model_type,
            embedding_type=FLAGS.embedding_type,
            num_embedding_components=FLAGS.num_embedding_components,
            embedding_size=FLAGS.embedding_size,
            num_embedding_samples=FLAGS.num_embedding_samples,
            is_training=False,
            num_fc_blocks=FLAGS.num_fc_blocks,
            num_fcs_per_block=FLAGS.num_fcs_per_block,
            num_hidden_nodes=FLAGS.num_hidden_nodes,
            num_bottleneck_nodes=FLAGS.num_bottleneck_nodes,
            weight_max_norm=FLAGS.weight_max_norm)

        outputs, _ = embedder_fn(model_inputs)

        if FLAGS.use_moving_average:
            variables_to_restore = (
                pipeline_utils.get_moving_average_variables_to_restore())
            saver = tf.train.Saver(variables_to_restore)
        else:
            saver = tf.train.Saver()

        scaffold = tf.train.Scaffold(init_op=tf.global_variables_initializer(),
                                     saver=saver)
        session_creator = tf.train.ChiefSessionCreator(
            scaffold=scaffold,
            master=FLAGS.master,
            checkpoint_filename_with_path=FLAGS.checkpoint_path)

        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=None) as sess:
            outputs_result = sess.run(outputs)

    tf.gfile.MakeDirs(FLAGS.output_dir)
    for key in (common.KEY_EMBEDDING_MEANS, common.KEY_EMBEDDING_STDDEVS,
                common.KEY_EMBEDDING_SAMPLES):
        if key not in outputs_result:
            continue
        output = outputs_result[key]
        # One row per leading-dimension entry; trailing dims are flattened.
        flattened_output = output.reshape([output.shape[0], -1])
        output_path = os.path.join(FLAGS.output_dir, key + '.csv')
        # Write through tf.gfile (as MakeDirs above does) rather than letting
        # np.savetxt open a plain local path, so output directories on any
        # filesystem TensorFlow supports work for both steps.
        with tf.gfile.GFile(output_path, 'w') as output_file:
            np.savetxt(output_file, flattened_output, delimiter=',')