Exemplo n.º 1
0
  def test_config(self):
    """Check that name and dtype are kept on construction and via get_config."""
    metric_cls = metrics.KullbackLeiblerDivergence

    # Directly constructed instance must report the arguments it was given.
    original = metric_cls(name='kld', dtype=dtypes.int32)
    self.assertEqual(original.name, 'kld')
    self.assertEqual(original._dtype, dtypes.int32)

    # An instance rebuilt from the serialized config must report the same.
    serialized = original.get_config()
    restored = metric_cls.from_config(serialized)
    self.assertEqual(restored.name, 'kld')
    self.assertEqual(restored._dtype, dtypes.int32)
Exemplo n.º 2
0
    def test_unweighted(self):
        """Compare the unweighted metric with sum(expected_results) / batch_size."""
        # NOTE(review): setup() is defined outside this view; presumably it
        # populates y_true, y_pred, expected_results and batch_size -- confirm.
        self.setup()
        kld_metric = metrics.KullbackLeiblerDivergence()
        init_op = variables.variables_initializer(kld_metric.variables)
        self.evaluate(init_op)

        # Run a single state update with no sample weights, then read the value.
        self.evaluate(kld_metric.update_state(self.y_true, self.y_pred))
        actual = kld_metric.result()

        total = np.sum(self.expected_results)
        expected_mean = total / self.batch_size
        self.assertAllClose(actual, expected_mean, atol=1e-3)
Exemplo n.º 3
0
  def test_weighted(self):
    """Compare the sample-weighted metric with a NumPy weighted average."""
    # NOTE(review): setup() is defined outside this view; presumably it
    # populates y_true, y_pred and expected_results -- confirm.
    self.setup()
    kld_metric = metrics.KullbackLeiblerDivergence()
    init_op = variables.variables_initializer(kld_metric.variables)
    self.evaluate(init_op)

    # One weight per row, passed to the metric as a (2, 1) tensor.
    row_weights = constant_op.constant([1.2, 3.4], shape=(2, 1))
    actual = kld_metric(self.y_true, self.y_pred, sample_weight=row_weights)

    # The same two weights spelled out per element as a (2, 3) array for the
    # NumPy reference computation.
    elementwise_weights = np.asarray([[1.2, 1.2, 1.2], [3.4, 3.4, 3.4]])
    weighted_terms = np.multiply(self.expected_results, elementwise_weights)
    weight_total = 1.2 + 3.4
    expected_value = np.sum(weighted_terms) / weight_total
    self.assertAllClose(self.evaluate(actual), expected_value, atol=1e-3)