def test_compute_boundary_ts_1(self):
    """ge.compute_boundary_ts on the ops producing tensors g and h.

    The returned sequences are compared as ordered lists, so the expected
    boundary is exact: inputs (c, a, f), output (h,), inside tensor (g,).
    The graph itself is built by the fixture (not shown here).
    """
    selection = [self.g.op, self.h.op]
    boundary = ge.compute_boundary_ts(selection)
    inputs, outputs, inside = boundary
    self.assertEqual(list(inputs), [self.c, self.a, self.f])
    self.assertEqual(list(outputs), [self.h])
    self.assertEqual(list(inside), [self.g])
def test_compute_boundary_ts_2(self):
    """ge.compute_boundary_ts on a small, self-contained graph.

    Builds constants a and b, c = add(a, b), and one more consumer a + c
    that is left out of the selection. Selecting the ops of a and c is
    expected to give input (b,), outputs (a, c) and inside tensor (a,),
    compared as ordered lists.
    """
    g = ops_lib.Graph()
    with g.as_default():
        t_a = constant_op.constant(1, name="a")
        t_b = constant_op.constant(1, name="b")
        t_c = math_ops.add(t_a, t_b, name="c")
        # Consumer outside the selection: it is why a and c are expected to
        # show up among the outputs. The result itself is unused.
        _ = t_a + t_c
    inputs, outputs, inside = ge.compute_boundary_ts([t_a.op, t_c.op])
    self.assertEqual(list(inputs), [t_b])
    self.assertEqual(list(outputs), [t_a, t_c])
    self.assertEqual(list(inside), [t_a])
def test_compute_boundary_ts_0(self):
    """ge.compute_boundary_ts called with a single op, not a list of ops.

    For the op producing g alone, the expected boundary is inputs (c, a),
    output (g,) and no inside tensors, compared as ordered lists.
    """
    lone_op = self.g.op
    inputs, outputs, inside = ge.compute_boundary_ts(lone_op)
    self.assertEqual(list(inputs), [self.c, self.a])
    self.assertEqual(list(outputs), [self.g])
    self.assertEqual(list(inside), [])