Пример #1
0
 def testWhileContext(self):
   """Checks that every while-loop control flow context round-trips via proto.

   Builds a simple counting while_loop, then for each op that carries a
   control flow context, serializes the context, rebuilds it with
   WhileContext.from_proto, re-serializes it and asserts both protos match.
   """
   with self.test_session() as sess:
     i = constant_op.constant(0)
     c = lambda i: math_ops.less(i, 10)
     b = lambda i: math_ops.add(i, 1)
     control_flow_ops.while_loop(c, b, [i])
     for op in sess.graph.get_operations():
       # Use a distinct name so the loop-condition lambda `c` is not shadowed.
       control_flow_context = op._get_control_flow_context()
       if control_flow_context:
         # The comparison must be asserted; a bare ProtoEq() call discards
         # its result and lets the test pass unconditionally.
         self.assertProtoEquals(
             control_flow_context.to_proto(),
             control_flow_ops.WhileContext.from_proto(
                 control_flow_context.to_proto()).to_proto())
Пример #2
0
    def _AssertProtoEquals(self, a, b):
        """Assert that two protos are equal.

        compare.ProtoEq() is consulted first because it handles floating
        point attributes correctly. Only when it reports a mismatch is
        compare.assertProtoEqual() invoked, since that helper produces a
        readable failure message.

        Args:
          a: a proto.
          b: another proto.
        """
        if compare.ProtoEq(a, b):
            # Fast path: the protos already compare equal.
            return
        compare.assertProtoEqual(self, a, b, normalize_numbers=True)
Пример #3
0
 def testCondContext(self):
   """Checks that every cond control flow context round-trips via proto.

   Builds a cond whose branches multiply or add constants, then for each op
   that carries a control flow context, serializes the context, rebuilds it
   with CondContext.from_proto, re-serializes it and asserts both protos match.
   """
   with self.test_session() as sess:
     x = constant_op.constant(2)
     y = constant_op.constant(5)
     control_flow_ops.cond(
         math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
         lambda: math_ops.add(y, 23))
     for op in sess.graph.get_operations():
       control_flow_context = op._get_control_flow_context()
       if control_flow_context:
         # The comparison must be asserted; a bare ProtoEq() call discards
         # its result and lets the test pass unconditionally.
         self.assertProtoEquals(
             control_flow_context.to_proto(),
             control_flow_ops.CondContext.from_proto(
                 control_flow_context.to_proto()).to_proto())
Пример #4
0
  def testControlContextImportScope(self):
    """Checks that import_scope/export_scope rewrite ControlFlowContext names.

    A context is rebuilt with _from_proto(import_scope="test_scope"), which
    should prefix its value names. Exporting it again with
    export_scope="test_scope" should reproduce the original proto.
    """
    with self.test_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = control_flow_ops.ControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = control_flow_ops.ControlFlowContext._from_proto(
          c._to_proto(), import_scope="test_scope")

      # _values and _external_values should have the scope prepended.
      # (assertEqual: the assertEquals alias was removed in Python 3.12.)
      self.assertEqual(
          c_with_scope._values, {"test_scope/a", "test_scope/b"})
      self.assertEqual(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_proto() with export_scope should remove "test_scope".
      # This must be asserted; a bare ProtoEq() call discards its result.
      self.assertProtoEquals(
          c._to_proto(),
          c_with_scope._to_proto(export_scope="test_scope"))
Пример #5
0
 def testPrimitives(self):
   """Checks compare.ProtoEq on primitive (string) arguments."""
   cases = (
       ('a', 'a', True),
       ('b', 'a', False),
   )
   for lhs, rhs, expected in cases:
     # Base-class assertEqual is called explicitly, presumably to bypass
     # assertion overrides defined on this test class.
     googletest.TestCase.assertEqual(self, expected, compare.ProtoEq(lhs, rhs))
Пример #6
0
 def assertEquals(self, a, b):
   """Asserts that ProtoEq says a == b.

   Args:
     a: first value; passed through LargePbs before comparison.
     b: second value; passed through LargePbs before comparison.
   """
   a, b = LargePbs(a, b)
   # Delegate to the base-class assertEqual rather than assertEquals: the
   # latter is a deprecated unittest alias that was removed in Python 3.12.
   googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), True)
Пример #7
0
 def __eq__(self, other: object) -> bool:
     """Compares two wrappers by their normalized protos via compare.ProtoEq.

     Args:
       other: the object to compare against.

     Returns:
       compare.ProtoEq(self._normalized, other._normalized) when `other` is
       an instance of this wrapper's class. Otherwise NotImplemented, so
       Python falls back to its default comparison instead of raising
       AttributeError on the missing `_normalized` attribute.
     """
     if not isinstance(other, type(self)):
         return NotImplemented
     return compare.ProtoEq(self._normalized, other._normalized)  # pylint: disable=protected-access
Пример #8
0
 def __eq__(self, other):
     """Compares two wrappers by their normalized protos via compare.ProtoEq.

     Args:
       other: the object to compare against.

     Returns:
       compare.ProtoEq(self._normalized, other._normalized) when `other` is
       an instance of this wrapper's class. Otherwise NotImplemented, so
       Python falls back to its default comparison instead of raising
       AttributeError on the missing `_normalized` attribute.
     """
     if not isinstance(other, type(self)):
         return NotImplemented
     return compare.ProtoEq(self._normalized, other._normalized)  # pylint: disable=protected-access