Example #1
0
    def test_stateful_metrics(self):
        """Checks a custom stateful metric layer against a NumPy reference.

        Defines `BinaryTruePositives`, a layer that keeps a running
        true-positive count in a backend variable. It round-trips the layer
        through `metrics.serialize` / `metrics.deserialize` and compiles it
        into a one-unit sigmoid model alongside `'acc'`. It then compares the
        metric's reported value (index 2 of the `evaluate` outputs) with the
        NumPy reference `ref_true_pos` for:
          * array-based `fit` / `evaluate` / `predict`, and
          * generator-based `fit_generator` / `evaluate_generator` /
            `predict_generator`,
        on both training and validation data. It also checks the last
        `val_true_positives` entry in the `fit` history.
        """
        with self.cached_session():
            # Fixed seed so the NumPy-generated data below is reproducible.
            np.random.seed(1334)

            class BinaryTruePositives(layers.Layer):
                """Stateful metric counting true positives across batches.

                Assumes predictions and targets of shape `(samples, 1)`.
                A prediction counts as positive when it rounds to 1.

                Arguments:
                    name: String, name for the metric. The test below reads
                        the validation value back from the fit history as
                        `val_true_positives`, matching the default name.
                    **kwargs: Forwarded to `layers.Layer.__init__`.
                """
                def __init__(self, name='true_positives', **kwargs):
                    super(BinaryTruePositives, self).__init__(name=name,
                                                              **kwargs)
                    # Running count, held in a backend variable so it persists
                    # between calls; incremented via `assign_add` in __call__.
                    self.true_positives = K.variable(value=0, dtype='int32')
                    # Marks the layer as stateful. Presumably this is what
                    # makes Keras call `reset_states()` between epochs and
                    # evaluations -- confirm against the Keras version in use.
                    self.stateful = True

                def reset_states(self):
                    """Zeroes the running true-positive count."""
                    K.set_value(self.true_positives, 0)

                def __call__(self, y_true, y_pred):
                    """Computes the number of true positives in a batch.

                    Args:
                        y_true: Tensor, batch-wise labels; assumed binary (0/1).
                        y_pred: Tensor, batch-wise predictions; rounded to an
                            integer before being compared with `y_true`.

                    Returns:
                        The running count held in `self.true_positives` plus
                        this batch's true-positive count. The `assign_add` that
                        folds this batch into the running count is registered
                        with `add_update`, not executed directly here.
                    """
                    y_true = math_ops.cast(y_true, 'int32')
                    y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
                    # 1 where the rounded prediction equals the label.
                    correct_preds = math_ops.cast(
                        math_ops.equal(y_pred, y_true), 'int32')
                    # Multiplying by y_true keeps only correct positives.
                    true_pos = math_ops.cast(
                        math_ops.reduce_sum(correct_preds * y_true), 'int32')
                    # `* 1` produces a tensor computed from the variable instead
                    # of returning the variable object itself.
                    # NOTE(review): nothing here orders this read before the
                    # `assign_add` registered below, so whether it sees the
                    # pre-update value depends on execution order -- confirm.
                    current_true_pos = self.true_positives * 1
                    # Register the state update on the layer; `inputs=` ties it
                    # to this call's tensors (presumably so Keras treats it as
                    # input-conditional).
                    self.add_update(state_ops.assign_add(
                        self.true_positives, true_pos),
                                    inputs=[y_true, y_pred])
                    return current_true_pos + true_pos

            # Round-trip through (de)serialization; the custom class must be
            # supplied via custom_objects for deserialize to resolve it.
            metric_fn = BinaryTruePositives()
            config = metrics.serialize(metric_fn)
            metric_fn = metrics.deserialize(
                config,
                custom_objects={'BinaryTruePositives': BinaryTruePositives})

            # Minimal model: 2 inputs -> 1 sigmoid unit, so outputs have shape
            # (samples, 1) as the metric assumes.
            inputs = layers.Input(shape=(2, ))
            outputs = layers.Dense(1, activation='sigmoid')(inputs)
            model = Model(inputs, outputs)
            model.compile(optimizer='sgd',
                          loss='binary_crossentropy',
                          metrics=['acc', metric_fn])

            # Test fit, evaluate on random features with random binary labels.
            samples = 100
            x = np.random.random((samples, 2))
            y = np.random.randint(2, size=(samples, 1))
            val_samples = 10
            val_x = np.random.random((val_samples, 2))
            val_y = np.random.randint(2, size=(val_samples, 1))

            history = model.fit(x,
                                y,
                                epochs=1,
                                batch_size=10,
                                validation_data=(val_x, val_y))
            outs = model.evaluate(x, y, batch_size=10)
            preds = model.predict(x)

            # NumPy reference: predictions above 0.5 (i.e. those that round
            # to 1) whose label is 1.
            def ref_true_pos(y_true, y_pred):
                return np.sum(np.logical_and(y_pred > 0.5, y_true == 1))

            # Test correctness (e.g. updates should have been run).
            # outs is presumably [loss, acc, true_positives], following the
            # compile() metrics order.
            self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)

            # Test correctness of the validation metric computation, and that
            # it matches the last validation value recorded by fit().
            val_preds = model.predict(val_x)
            val_outs = model.evaluate(val_x, val_y, batch_size=10)
            self.assertAllClose(val_outs[2],
                                ref_true_pos(val_y, val_preds),
                                atol=1e-5)
            self.assertAllClose(val_outs[2],
                                history.history['val_true_positives'][-1],
                                atol=1e-5)

            # Test with generators: one (x, y) pair per step, so step counts
            # equal the sample counts and each iterator is consumed once.
            gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
            val_gen = [(np.array([x0]), np.array([y0]))
                       for x0, y0 in zip(val_x, val_y)]
            history = model.fit_generator(iter(gen),
                                          epochs=1,
                                          steps_per_epoch=samples,
                                          validation_data=iter(val_gen),
                                          validation_steps=val_samples)
            outs = model.evaluate_generator(iter(gen), steps=samples)
            preds = model.predict_generator(iter(gen), steps=samples)

            # Test correctness of the metric results
            self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)

            # Test correctness of the validation metric computation
            val_preds = model.predict_generator(iter(val_gen),
                                                steps=val_samples)
            val_outs = model.evaluate_generator(iter(val_gen),
                                                steps=val_samples)
            self.assertAllClose(val_outs[2],
                                ref_true_pos(val_y, val_preds),
                                atol=1e-5)
            self.assertAllClose(val_outs[2],
                                history.history['val_true_positives'][-1],
                                atol=1e-5)
Example #2
0
  def test_stateful_metrics(self):
    """Checks a custom stateful metric layer against a NumPy reference.

    Defines `BinaryTruePositives`, a layer that keeps a running true-positive
    count in a backend variable. It round-trips the layer through
    `metrics.serialize` / `metrics.deserialize` and compiles it into a
    one-unit sigmoid model alongside `'acc'`. It then compares the metric's
    reported value (index 2 of the `evaluate` outputs) with the NumPy
    reference `ref_true_pos` for array-based and generator-based
    fit/evaluate/predict, on both training and validation data. It also
    checks the last `val_true_positives` entry in the `fit` history.
    """
    # NOTE(review): newer TensorFlow versions deprecate `test_session()` in
    # favor of `cached_session()`; switch if this file's TF version has it.
    with self.test_session():
      # Fixed seed so the NumPy-generated data below is reproducible.
      np.random.seed(1334)

      class BinaryTruePositives(layers.Layer):
        """Stateful metric counting true positives across batches.

        Assumes predictions and targets of shape `(samples, 1)`.
        A prediction counts as positive when it rounds to 1.

        Arguments:
            name: String, name for the metric. The test below reads the
                validation value back from the fit history as
                `val_true_positives`, matching the default name.
            **kwargs: Forwarded to `layers.Layer.__init__`.
        """

        def __init__(self, name='true_positives', **kwargs):
          super(BinaryTruePositives, self).__init__(name=name, **kwargs)
          # Running count, held in a backend variable so it persists between
          # calls; incremented via `assign_add` in __call__.
          self.true_positives = K.variable(value=0, dtype='int32')
          # Marks the layer as stateful. Presumably this is what makes Keras
          # call `reset_states()` between epochs and evaluations -- confirm.
          self.stateful = True

        def reset_states(self):
          """Zeroes the running true-positive count."""
          K.set_value(self.true_positives, 0)

        def __call__(self, y_true, y_pred):
          """Computes the number of true positives in a batch.

          Args:
              y_true: Tensor, batch-wise labels; assumed binary (0/1).
              y_pred: Tensor, batch-wise predictions; rounded to an integer
                  before being compared with `y_true`.

          Returns:
              The running count held in `self.true_positives` plus this
              batch's true-positive count. The `assign_add` that folds this
              batch into the running count is registered with `add_update`,
              not executed directly here.
          """
          y_true = math_ops.cast(y_true, 'int32')
          y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
          # 1 where the rounded prediction equals the label.
          correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
          # Multiplying by y_true keeps only correct positives.
          true_pos = math_ops.cast(
              math_ops.reduce_sum(correct_preds * y_true), 'int32')
          # `* 1` produces a tensor computed from the variable instead of
          # returning the variable object itself.
          # NOTE(review): nothing here orders this read before the
          # `assign_add` registered below, so whether it sees the pre-update
          # value depends on execution order -- confirm.
          current_true_pos = self.true_positives * 1
          # Register the state update on the layer; `inputs=` ties it to this
          # call's tensors (presumably so Keras treats it as
          # input-conditional).
          self.add_update(
              state_ops.assign_add(self.true_positives, true_pos),
              inputs=[y_true, y_pred])
          return current_true_pos + true_pos

      # Round-trip through (de)serialization; the custom class must be
      # supplied via custom_objects for deserialize to resolve it.
      metric_fn = BinaryTruePositives()
      config = metrics.serialize(metric_fn)
      metric_fn = metrics.deserialize(
          config, custom_objects={'BinaryTruePositives': BinaryTruePositives})

      # Minimal model: 2 inputs -> 1 sigmoid unit, so outputs have shape
      # (samples, 1) as the metric assumes.
      inputs = layers.Input(shape=(2,))
      outputs = layers.Dense(1, activation='sigmoid')(inputs)
      model = Model(inputs, outputs)
      model.compile(optimizer='sgd',
                    loss='binary_crossentropy',
                    metrics=['acc', metric_fn])

      # Test fit, evaluate on random features with random binary labels.
      samples = 100
      x = np.random.random((samples, 2))
      y = np.random.randint(2, size=(samples, 1))
      val_samples = 10
      val_x = np.random.random((val_samples, 2))
      val_y = np.random.randint(2, size=(val_samples, 1))

      history = model.fit(x, y,
                          epochs=1,
                          batch_size=10,
                          validation_data=(val_x, val_y))
      outs = model.evaluate(x, y, batch_size=10)
      preds = model.predict(x)

      # NumPy reference: predictions above 0.5 (i.e. those that round to 1)
      # whose label is 1.
      def ref_true_pos(y_true, y_pred):
        return np.sum(np.logical_and(y_pred > 0.5, y_true == 1))

      # Test correctness (e.g. updates should have been run).
      # outs is presumably [loss, acc, true_positives], following the
      # compile() metrics order.
      self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)

      # Test correctness of the validation metric computation, and that it
      # matches the last validation value recorded by fit().
      val_preds = model.predict(val_x)
      val_outs = model.evaluate(val_x, val_y, batch_size=10)
      self.assertAllClose(
          val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
      self.assertAllClose(
          val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)

      # Test with generators: one (x, y) pair per step, so step counts equal
      # the sample counts and each iterator is consumed once.
      gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
      val_gen = [(np.array([x0]), np.array([y0]))
                 for x0, y0 in zip(val_x, val_y)]
      history = model.fit_generator(iter(gen),
                                    epochs=1,
                                    steps_per_epoch=samples,
                                    validation_data=iter(val_gen),
                                    validation_steps=val_samples)
      outs = model.evaluate_generator(iter(gen), steps=samples)
      preds = model.predict_generator(iter(gen), steps=samples)

      # Test correctness of the metric results
      self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)

      # Test correctness of the validation metric computation
      val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
      val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
      self.assertAllClose(
          val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
      self.assertAllClose(
          val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)