Esempio n. 1
0
    def _build_toy_problem(self):
        """Construct a toy linear regression problem.

        The model is y_hat = x * w with a single weight w (initialized to 0)
        and targets equal to the inputs (x = y = [1, 2]), so the loss is
        minimized at w = 1. Initial loss should be
          1.25 = mean(0.5 * [1^2, 2^2])

        Returns:
          loss: 0-D Tensor representing loss to be minimized.
          accuracy: 0-D Tensor representing model accuracy (here simply the
            loss, used as a placeholder metric).
          layer_collection: LayerCollection instance describing model
            architecture.
        """
        x = np.asarray([[1.], [2.]]).astype(np.float32)
        # Targets are shaped [batch, 1] to match y_hat = matmul(x, w). A flat
        # [batch] vector would broadcast against y_hat to [batch, batch] and
        # pair every prediction with every target, shifting the optimum away
        # from w = 1.
        y = np.asarray([[1.], [2.]]).astype(np.float32)
        x, y = (tf.data.Dataset.from_tensor_slices(
            (x, y)).repeat(100).batch(2).make_one_shot_iterator().get_next())
        w = tf.get_variable("w",
                            shape=[1, 1],
                            initializer=tf.zeros_initializer())
        y_hat = tf.matmul(x, w)
        loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
        # No separate accuracy metric for a regression toy; reuse the loss.
        accuracy = loss

        layer_collection = lc.LayerCollection()
        layer_collection.register_fully_connected(params=w,
                                                  inputs=x,
                                                  outputs=y_hat)
        layer_collection.register_normal_predictive_distribution(y_hat)

        return loss, accuracy, layer_collection
Esempio n. 2
0
    def testBuildModel(self):
        """Smoke-test convnet.build_model.

        Checks how many fisher blocks and losses get registered in the
        LayerCollection, then runs one forward pass on random inputs to make
        sure the graph executes.
        """
        graph = tf.Graph()
        with graph.as_default():
            images = tf.placeholder(tf.float32, [None, 6, 6, 3])
            labels = tf.placeholder(tf.int64, [None])
            collection = lc.LayerCollection()
            loss, accuracy = convnet.build_model(
                images, labels, num_labels=5, layer_collection=collection)

            # The model is expected to register 3 fisher blocks and 1 loss.
            self.assertEqual(len(collection.fisher_blocks), 3)
            self.assertEqual(len(collection.losses), 1)

            # A single evaluation on random data should not raise.
            with self.test_session() as session:
                session.run(tf.global_variables_initializer())
                batch_images = np.random.randn(10, 6, 6, 3).astype(np.float32)
                batch_labels = np.random.randint(5, size=10).astype(np.int64)
                inputs = {images: batch_images, labels: batch_labels}
                session.run([loss, accuracy], feed_dict=inputs)