예제 #1
0
 def testWhileContext(self):
   """Checks that while-loop control flow contexts survive a proto round trip.

   Builds a while_loop that increments a counter from 0 while it is below 10.
   For every op in the session graph that has a control flow context, the
   context is serialized with to_proto(), rebuilt through
   WhileContext.from_proto(), serialized again, and the two protos are
   asserted to compare equal under compare.ProtoEq().
   """
   with self.test_session() as sess:
     i = constant_op.constant(0)
     cond = lambda i: math_ops.less(i, 10)
     body = lambda i: math_ops.add(i, 1)
     control_flow_ops.while_loop(cond, body, [i])
     for op in sess.graph.get_operations():
       context = op._get_control_flow_context()
       if context:
         # ProtoEq() only returns a bool; it must be asserted on, otherwise a
         # round-trip mismatch would go unnoticed.
         self.assertTrue(
             compare.ProtoEq(
                 context.to_proto(),
                 control_flow_ops.WhileContext.from_proto(
                     context.to_proto()).to_proto()))
예제 #2
0
    def _AssertProtoEquals(self, a, b):
        """Asserts that protos a and b are equal.

        The comparison is attempted with ProtoEq() first because, per the
        original note on this helper, it returns correct results for floating
        point attributes. Only when that reports a difference is
        assertProtoEqual() invoked, since it provides good error messages.

        Args:
          a: a proto.
          b: another proto.
        """
        if compare.ProtoEq(a, b):
            # Fast path: nothing more to do when the protos already match.
            return
        compare.assertProtoEqual(self, a, b, normalize_numbers=True)
예제 #3
0
 def testCondContext(self):
   """Checks that cond control flow contexts survive a proto round trip.

   Builds a cond() over two constants (x=2, y=5) whose branches multiply x by
   17 or add 23 to y. For every op in the session graph that has a control
   flow context, the context is serialized with to_proto(), rebuilt through
   CondContext.from_proto(), serialized again, and the two protos are
   asserted to compare equal under compare.ProtoEq().
   """
   with self.test_session() as sess:
     x = constant_op.constant(2)
     y = constant_op.constant(5)
     control_flow_ops.cond(
         math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
         lambda: math_ops.add(y, 23))
     for op in sess.graph.get_operations():
       context = op._get_control_flow_context()
       if context:
         # ProtoEq() only returns a bool; it must be asserted on, otherwise a
         # round-trip mismatch would go unnoticed.
         self.assertTrue(
             compare.ProtoEq(
                 context.to_proto(),
                 control_flow_ops.CondContext.from_proto(
                     context.to_proto()).to_proto()))
예제 #4
0
  def testControlContextImportScope(self):
    """Checks import_scope / export_scope handling of ControlFlowContext protos.

    Creates constants named "a" and "b" both at the top level and under
    "test_scope/", populates a ControlFlowContext's _values and
    _external_values with the unscoped names, and rebuilds it via
    _from_proto() with import_scope="test_scope". The rebuilt context must
    refer to the scoped names, and exporting it with export_scope="test_scope"
    must yield a proto equal to the original one.
    """
    with self.test_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = control_flow_ops.ControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = control_flow_ops.ControlFlowContext._from_proto(
          c._to_proto(), import_scope="test_scope")

      # _values and _external_values should have the scope prepended.
      # assertEqual is used because the assertEquals alias is deprecated and
      # was removed from unittest in Python 3.12.
      self.assertEqual(
          c_with_scope._values, {"test_scope/a", "test_scope/b"})
      self.assertEqual(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_proto() with export_scope should remove "test_scope".
      # ProtoEq() only returns a bool, so its result has to be asserted on.
      self.assertTrue(
          compare.ProtoEq(
              c._to_proto(),
              c_with_scope._to_proto(export_scope="test_scope")))
예제 #5
0
 def testPrimitives(self):
   """Checks ProtoEq() on plain strings: equal inputs give True, unequal False."""
   # The base-class assertEqual is invoked explicitly, exactly as before;
   # presumably this bypasses assertion overrides on the test class (confirm
   # against the enclosing class).
   base_assert_equal = googletest.TestCase.assertEqual

   identical = compare.ProtoEq('a', 'a')
   base_assert_equal(self, True, identical)

   different = compare.ProtoEq('b', 'a')
   base_assert_equal(self, False, different)
예제 #6
0
 def assertEquals(self, a, b):
   """Asserts that ProtoEq says a == b.

   Both arguments are first passed through LargePbs(); the ProtoEq() result
   for the returned pair is then required to equal True.

   Args:
     a: first value to compare.
     b: second value to compare.
   """
   a, b = LargePbs(a, b)
   # Delegate to the base class's assertEqual rather than its assertEquals
   # alias: that alias is deprecated in unittest and removed in Python 3.12
   # (assumes googletest.TestCase derives from unittest.TestCase).
   googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), True)
예제 #7
0
 def __eq__(self, other: '_DatasetFeatureStatisticsComparatorWrapper'):
     """Compares two wrappers by their normalized protos using ProtoEq().

     Args:
       other: the object to compare against; expected to expose a
         `_normalized` attribute like this wrapper does.

     Returns:
       The result of compare.ProtoEq() on the two `_normalized` values, or
       NotImplemented when `other` has no `_normalized` attribute so that
       Python can fall back to its default comparison handling instead of
       raising AttributeError.
     """
     try:
         other_normalized = other._normalized  # pylint: disable=protected-access
     except AttributeError:
         return NotImplemented
     return compare.ProtoEq(self._normalized, other_normalized)
예제 #8
0
 def __eq__(self, other):
     """Compares two wrappers by their normalized protos using ProtoEq().

     Args:
       other: the object to compare against; expected to expose a
         `_normalized` attribute like this object does.

     Returns:
       The result of compare.ProtoEq() on the two `_normalized` values, or
       NotImplemented when `other` has no `_normalized` attribute so that
       Python can fall back to its default comparison handling instead of
       raising AttributeError.
     """
     try:
         other_normalized = other._normalized  # pylint: disable=protected-access
     except AttributeError:
         return NotImplemented
     return compare.ProtoEq(self._normalized, other_normalized)