Esempio n. 1
0
    def test_weights_in_out_none(self):
        """Test case with no weights at all.

        The caller passes no weights and the wrapped function returns
        `None` as weights: the wrapped function must be invoked with
        `None` as its third argument and the average must be computed
        with `weights=None` inside the requested scope.
        """
        scope = 'StreamingLossScope'
        targets = tf.constant([[0, 1, 2], [0, 9, 23]], dtype=tf.int32)
        predictions = tf.constant([[0, 1, 2], [0, 9, 23]], dtype=tf.int32)
        values = tf.constant([5, 6, 7], dtype=tf.float32)

        # The wrapped function returns a `(values, weights)` pair and can
        # be invoked only once (single-item side effect list).
        func = mock.Mock()
        func.side_effect = [(values, None)]

        # Replace the actual computation so that only the call is recorded.
        avg = streaming.StreamingAverage()
        avg.compute = mock.MagicMock()

        loss = losses.StreamingLoss(func, avg)
        loss.compute(targets, predictions, scope=scope)

        func.assert_called_once_with(targets, predictions, None)
        avg.compute.assert_called_once()
        args, kwargs = avg.compute.call_args
        act_values, = args
        # Identity check: the very tensor returned by the wrapped function
        # must be forwarded, regardless of how `==` is defined on tensors.
        self.assertIs(act_values, values)
        # Index instead of pop() so the recorded call is left untouched.
        self.assertIn('weights', kwargs)
        self.assertIsNone(kwargs['weights'])
        self.assertIn('scope', kwargs)
        self.assertEqual(kwargs['scope'].name, scope)
Esempio n. 2
0
    def test_weights_out_none(self):
        """Test case with no weights returned by the wrapped function.

        The caller passes weights, which must reach the wrapped function
        unchanged; the wrapped function returns `None` as weights, and
        that `None` must be forwarded to the average as second positional
        argument, inside the requested scope.
        """
        scope = 'MyScope'
        targets = tf.constant([[0, 1, 2], [0, 9, 23]], dtype=tf.int32)
        predictions = tf.constant([[0, 1, 2], [0, 9, 23]], dtype=tf.int32)
        weights = tf.constant([[1, 1, 1], [0, 0, 1]], dtype=tf.float32)
        values = tf.constant([5, 6, 7], dtype=tf.float32)

        # The wrapped function returns a `(values, weights)` pair and can
        # be invoked only once (single-item side effect list).
        func = mock.Mock()
        func.side_effect = [(values, None)]

        # Replace the actual computation so that only the call is recorded.
        avg = streaming.StreamingAverage()
        avg.compute = mock.MagicMock()

        metric = metrics.StreamingMetric(func, avg)
        metric.compute(targets, predictions, weights, scope=scope)

        func.assert_called_once_with(targets, predictions, weights)
        avg.compute.assert_called_once()
        args, kwargs = avg.compute.call_args
        act_values, act_weights_out = args
        # Identity check: the very tensor returned by the wrapped function
        # must be forwarded, regardless of how `==` is defined on tensors.
        self.assertIs(act_values, values)
        self.assertIsNone(act_weights_out)
        # Index instead of pop() so the recorded call is left untouched.
        self.assertIn('scope', kwargs)
        self.assertEqual(kwargs['scope'].name, scope)
Esempio n. 3
0
    def test_default(self):
        """Default test for the StreamingAverage class.

        Feeds two weighted batches, checking after each update both the
        batch value and the streaming value; then resets the average and
        feeds the second batch again as if it were the first one.
        """

        # Input batches: every entry with weight 0 holds the value 1000.
        first_values = np.asarray([[1, 2, 3], [4, 5, 1000]], dtype=np.float32)  # pylint: disable=I0011,E1101
        first_weights = np.asarray([[1, 1, 1], [1, 1, 0]], dtype=np.float32)  # pylint: disable=I0011,E1101
        second_values = np.asarray([[1000, 1000, 8], [9, 10, 1000]], dtype=np.float32)  # pylint: disable=I0011,E1101
        second_weights = np.asarray([[0, 0, 1], [1, 1, 0]], dtype=np.float32)  # pylint: disable=I0011,E1101

        # Expected results: per-batch averages and the average over both.
        first_avg = 3.0
        second_avg = 9.0
        overall_avg = 5.25

        # Graph construction.
        values_ph = tf.placeholder(
            dtype=tf.float32, shape=[2, 3], name='values')
        weights_ph = tf.placeholder(
            dtype=tf.float32, shape=[2, 3], name='weights')
        average = streaming.StreamingAverage(name='StreamingAvg')
        average.compute(values_ph, weights_ph)

        first_feed = {values_ph: first_values, weights_ph: first_weights}
        second_feed = {values_ph: second_values, weights_ph: second_weights}

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            def update_and_check(feed, expected_batch, expected_total):
                """Run one update, then check batch and streaming values."""
                sess.run(average.update_op, feed)
                self.assertEqual(
                    expected_batch, sess.run(average.batch_value, feed))
                self.assertEqual(
                    expected_total, sess.run(average.value, feed))

            # First batch: the streaming value equals the batch value.
            update_and_check(first_feed, first_avg, first_avg)

            # Second batch: the streaming value covers both batches.
            update_and_check(second_feed, second_avg, overall_avg)

            # After the reset both the value and the count are zero.
            sess.run(average.reset_op)
            self.assertEqual(0.0, sess.run(average.value))
            self.assertEqual(0, sess.run(average.count))

            # Second batch fed as the first one after the reset.
            update_and_check(second_feed, second_avg, second_avg)
Esempio n. 4
0
 def __init__(self, func, average=None, name=None):
     """Initialize the metric.

     Arguments:
       func: the wrapped object, stored as-is; presumably the callable
         invoked when the metric is computed (TODO confirm signature).
       average: the object used to average the values; if `None`, a new
         `streaming.StreamingAverage` instance is created.
       name: forwarded to the parent class constructor.
     """
     super(StreamingMetric, self).__init__(name=name)
     self._func = func
     # Compare against None explicitly: `average or ...` would silently
     # discard a caller-supplied object that happens to evaluate as falsy.
     if average is None:
         average = streaming.StreamingAverage()
     self._avg = average