Exemplo n.º 1
0
            test=test,
            correctness_function=self.default_correctness_function)

    def test_uniform_random(self):
        self._uniform_random_ops(test=True)

    def test_tensor_name_error(self):
        with self.assertRaises(AssertionError):
            self._uniform_random_ops(test=True, wrong_name=True)

    @unittest.skipIf(keras_utils.is_v2_0(),
                     "TODO:(b/136010138) Fails on TF 2.0.")
    def test_tensor_shape_error(self):
        """Require the helper to raise AssertionError when wrong_shape=True.

        Skipped whenever ``keras_utils.is_v2_0()`` is true; the skip reason
        records that it fails on TF 2.0 (tracked as b/136010138).
        """
        self.assertRaises(AssertionError, self._uniform_random_ops,
                          test=True, wrong_shape=True)

    def test_incorrectness_function(self):
        with self.assertRaises(AssertionError):
            self._uniform_random_ops(test=True, bad_function=True)

    def test_dense(self):
        self._dense_ops(test=True)

    def regenerate(self):
        """Call each op helper with test=False, uniform-random first.

        NOTE(review): presumably test=False writes fresh reference data
        instead of comparing against it -- confirm in the helpers.
        """
        # getattr inside the loop keeps the original lookup order: the dense
        # helper is only looked up after the uniform-random helper has run.
        for helper_name in ("_uniform_random_ops", "_dense_ops"):
            getattr(self, helper_name)(test=False)


if __name__ == "__main__":
    # Delegate to the shared reference_data entry point, passing this
    # module's test class and the raw command-line arguments.
    reference_data.main(test_class=GoldenBaseTest, argv=sys.argv)
Exemplo n.º 2
0
  def test_tensor_name_error(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, wrong_name=True)

  def test_tensor_shape_error(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, wrong_shape=True)

  @unittest.skipIf(sys.version_info[0] == 2,
                   "catch_warning doesn't catch tf.logging.warn in py 2.")
  def test_bad_seed(self):
    with warnings.catch_warnings(record=True) as warn_catch:
      self._uniform_random_ops(test=True, bad_seed=True)
      assert len(warn_catch) == 1, "Test did not warn of minor graph change."

  def test_incorrectness_function(self):
    with self.assertRaises(AssertionError):
      self._uniform_random_ops(test=True, bad_function=True)

  def test_dense(self):
    self._dense_ops(test=True)

  def regenerate(self):
    """Invoke both op helpers with test=False, uniform-random then dense.

    NOTE(review): presumably test=False writes fresh reference data
    instead of comparing against it -- confirm in the helpers.
    """
    # Looking each helper up inside the loop preserves the original order
    # of attribute access as well as the order of calls.
    for helper_name in ("_uniform_random_ops", "_dense_ops"):
      getattr(self, helper_name)(test=False)


if __name__ == "__main__":
  # Delegate to the shared reference_data entry point, passing this
  # module's test class and the raw command-line arguments.
  reference_data.main(test_class=GoldenBaseTest, argv=sys.argv)