def testDeviceObject(self):
    """_set_device accepts both a device string and a pydev.Device."""
    # NOTE(review): `testDeviceObject` is defined more than once in this
    # chunk; if both copies live in the same class only the later one runs.
    gpu_op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
    gpu_op._set_device("/job:goo/device:GPU:0")
    self.assertProtoEquals(
        "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
        gpu_op.node_def)

    cpu_op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
    # The Device object is expected to show up in its string form below.
    cpu_op._set_device(
        pydev.Device(job="muu", device_type="CPU", device_index=0))
    self.assertProtoEquals(
        "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
        cpu_op.node_def)
def _WeMustGoDeeper(self, msg):
    """Raise an UnauthenticatedError for an op that has an original_op.

    The whole construction happens inside `assertRaisesOpError(msg)`, so the
    caller's expectation `msg` is checked against the raised error (exact
    matching rules live in `assertRaisesOpError`, not shown here).

    Args:
      msg: expectation forwarded to `assertRaisesOpError`.
    """
    with self.assertRaisesOpError(msg):
        derived_def = ops._NodeDef("op_type", "name")
        source_def = ops._NodeDef("op_type_orig", "orig")
        source_op = ops.Operation(source_def, ops.get_default_graph())
        # The raised op points back at `source_op` via original_op.
        derived_op = ops.Operation(
            derived_def, ops.get_default_graph(), original_op=source_op)
        raise errors.UnauthenticatedError(derived_def, derived_op, "true_err")
def testInputsAndOutputs(self):
    """Check op outputs, consumer bookkeeping and NodeDef input naming."""
    # NOTE(review): `testInputsAndOutputs` is defined more than once in this
    # chunk; if both copies live in the same class only the later one runs.
    graph = ops.Graph()

    first = ops.Operation(
        ops._NodeDef("noop", "myop1"), graph, [], [dtypes.float32])
    self.assertEqual(1, len(first.values()))
    (first_float,) = first.values()

    second = ops.Operation(
        ops._NodeDef("reop", "myop2"), graph, [],
        [dtypes.float32, dtypes.string])
    self.assertEqual(2, len(second.values()))
    second_float, second_str = second.values()

    # second_str is passed twice on purpose so it records two consumers.
    third = ops.Operation(
        ops._NodeDef("add", "myop3"), graph,
        [first_float, second_str, second_str],
        [dtypes.float32, dtypes.int32])
    self.assertEqual(2, len(third.values()))

    self.assertEqual(1, len(first_float._consumers))
    self.assertEqual(third, first_float._consumers[0])
    self.assertEqual(0, len(second_float._consumers))
    self.assertEqual(2, len(second_str._consumers))
    self.assertEqual(third, second_str._consumers[0])
    self.assertEqual(third, second_str._consumers[1])

    # A non-zero output index appears as "<op>:<index>" in the NodeDef.
    self.assertProtoEquals(
        """ op:'add' name:'myop3' input:'myop1' input:'myop2:1' input:'myop2:1' """,
        third.node_def)
def testInputsAndOutputs(self):
    """An op fed by two producers tracks consumers and NodeDef inputs."""
    graph = ops.Graph()

    single = ops.Operation(
        ops._NodeDef("noop", "myop1"), graph, [], [dtypes.float32])
    self.assertEqual(1, len(single.values()))
    (lone_float,) = single.values()

    pair = ops.Operation(
        ops._NodeDef("reop", "myop2"), graph, [],
        [dtypes.float32, dtypes.string])
    self.assertEqual(2, len(pair.values()))
    unused_float, str_out = pair.values()

    # Note that str_out is consumed twice here.
    adder = ops.Operation(
        ops._NodeDef("add", "myop3"), graph,
        [lone_float, str_out, str_out],
        [dtypes.float32, dtypes.int32])
    self.assertEqual(2, len(adder.values()))

    self.assertEqual(1, len(lone_float._consumers))
    self.assertEqual(adder, lone_float._consumers[0])
    self.assertEqual(0, len(unused_float._consumers))
    self.assertEqual(2, len(str_out._consumers))
    for consumer in str_out._consumers:
        self.assertEqual(adder, consumer)

    self.assertProtoEquals(
        """ op:'add' name:'myop3' input:'myop1' input:'myop2:1' input:'myop2:1' """,
        adder.node_def)
def testInvalidNames(self):
    """Operation raises ValueError for each malformed node name."""
    graph = ops.Graph()
    # Empty, and names starting with '_', '-' or '/'.
    for bad_name in ("", "_invalid", "-invalid", "/invalid"):
        with self.assertRaises(ValueError):
            ops.Operation(ops._NodeDef("op", bad_name), graph)
def testDeviceObject(self):
    """_set_device records string and pydev.Device specs in the NodeDef."""
    # String spec: expected verbatim in the NodeDef.
    first = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
    first._set_device("/job:goo/device:GPU:0")
    expected_first = "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' "
    self.assertProtoEquals(expected_first, first.node_def)

    # Device-object spec: expected in its string form.
    second = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
    device = pydev.Device(job="muu", device_type="CPU", device_index=0)
    second._set_device(device)
    expected_second = "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'"
    self.assertProtoEquals(expected_second, second.node_def)
def testReferenceInput(self):
    """Ref-typed and plain outputs both appear as NodeDef inputs."""
    # Uses `dtypes` (as the rest of this file does) rather than the legacy
    # `types` module name.
    # NOTE(review): `testReferenceInput` is defined more than once in this
    # chunk; if both copies live in the same class only the later one runs.
    g = ops.Graph()
    op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [],
                        [dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def)
    ref_t, nonref_t = op1.values()
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    op2 = ops.Operation(
        ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals(
        "op:'refop' name:'op2' input:'op1' input:'op1:1'", op2.node_def)
    op3 = ops.Operation(
        ops._NodeDef("nonrefop", "op3"), g, [ref_t, nonref_t], [])
    self.assertProtoEquals(
        "op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", op3.node_def)
def testNoOutputs(self):
    """An op with one input and no outputs records input and consumer."""
    # assertEquals is a deprecated alias removed from unittest in Python
    # 3.12; assertEqual is the supported name. `dtypes` replaces the legacy
    # `types` module name to match the rest of this file.
    # NOTE(review): `testNoOutputs` is defined more than once in this chunk;
    # if both copies live in the same class only the later one runs.
    g = ops.Graph()
    op1 = ops.Operation(ops._NodeDef("noop", "myop1"), g, [], [dtypes.float32])
    float_t, = op1.values()
    op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
    self.assertEqual(0, len(op2.values()))
    self.assertEqual(1, len(op2.inputs))
    self.assertIs(float_t, op2.inputs[0])
    self.assertEqual(1, len(float_t._consumers))
    self.assertEqual(op2, float_t._consumers[0])
    self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)
    self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
                           op2.node_def)
def testShape(self):
    """A fresh output has unknown shape until set_shape is called."""
    output = ops.Operation(
        ops._NodeDef("noop", "myop"), ops.Graph(), [],
        [dtypes.float32]).outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), output.get_shape())

    output.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], output.get_shape())
def testReferenceInput(self):
    """Ref-typed and plain outputs both appear as NodeDef inputs."""
    graph = ops.Graph()
    src = ops.Operation(
        ops._NodeDef("noop", "op1"), graph, [],
        [dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals("op:'noop' name:'op1'", src.node_def)
    ref_in, plain_in = src.values()

    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = ops.Operation(
        ops._NodeDef("refop", "op2"), graph, [ref_in, plain_in], [],
        input_types=[dtypes.float32_ref, dtypes.float32])
    self.assertProtoEquals(
        "op:'refop' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)

    # Without input_types the NodeDef inputs are expected to be the same.
    plain_consumer = ops.Operation(
        ops._NodeDef("nonrefop", "op3"), graph, [ref_in, plain_in], [])
    self.assertProtoEquals(
        "op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
        plain_consumer.node_def)
def testNoOutputs(self):
    """An op consuming one tensor and producing none is wired correctly."""
    graph = ops.Graph()
    source = ops.Operation(
        ops._NodeDef("noop", "myop1"), graph, [], [dtypes.float32])
    (value,) = source.values()
    sink = ops.Operation(ops._NodeDef("reop", "myop2"), graph, [value], [])

    self.assertEqual(0, len(sink.values()))
    self.assertEqual(1, len(sink.inputs))
    self.assertIs(value, sink.inputs[0])
    self.assertEqual(1, len(value._consumers))
    self.assertEqual(sink, value._consumers[0])

    self.assertProtoEquals("op:'noop' name:'myop1'", source.node_def)
    self.assertProtoEquals(
        "op:'reop' name:'myop2' input:'myop1'", sink.node_def)
def testNoShapeFunction(self):
    """Output of an op with no shape function (per the name) is unknown."""
    # NOTE(review): `testNoShapeFunction` is defined more than once in this
    # chunk; if both copies live in the same class only the later one runs.
    graph = ops.Graph()
    # NOTE(review): this Operation is never referenced again; the value whose
    # shape is checked comes from `_apply_op` below.
    unused_op = ops.Operation(
        ops._NodeDef("op", "an_op"), graph, output_types=[dtypes.float32])
    unknown = tensor_shape.unknown_shape()
    produced = _apply_op(graph, "an_op", [], [dtypes.float32])
    self.assertEqual(unknown, produced.get_shape())
def testIterable(self):
    """Iterating over a Tensor raises TypeError mentioning "not iterable"."""
    op = ops.Operation(
        ops._NodeDef("noop", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(t, ops.Tensor)
    # assertRaisesRegexp is a deprecated alias removed from unittest in
    # Python 3.12; assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(TypeError, "not iterable"):
        for _ in t:
            pass
def testRegisteredNode(self):
    """Stats for an "a" node: known stats have values, unknown ones None.

    The expected values (10 weight parameters, 20 flops) presumably come from
    a statistics registration for op type "a" elsewhere in the file.
    """
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_parameters")
    self.assertEqual(10, weight_params.value)
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    # assertIsNone checks identity and gives a clearer failure message.
    self.assertIsNone(missing_stat.value)
def testNoInputs(self):
    """An op with no inputs exposes correctly typed and indexed outputs."""
    # assertEquals is a deprecated alias removed from unittest in Python
    # 3.12; assertEqual is the supported name. `dtypes` replaces the legacy
    # `types` module name to match the rest of this file.
    # NOTE(review): `testNoInputs` is defined more than once in this chunk;
    # if both copies live in the same class only the later one runs.
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                       [dtypes.float32, dtypes.string])
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    float_t, label_str_t = op.values()
    self.assertEqual(dtypes.float32, float_t.dtype)
    self.assertEqual(op, float_t.op)
    self.assertEqual(0, float_t._value_index)
    self.assertEqual(0, len(float_t._consumers))
    self.assertEqual("myop", float_t._as_node_def_input())

    self.assertEqual(dtypes.string, label_str_t.dtype)
    self.assertEqual(op, label_str_t.op)
    self.assertEqual(1, label_str_t._value_index)
    self.assertEqual(0, len(label_str_t._consumers))
    self.assertEqual("myop:1", label_str_t._as_node_def_input())

    self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)
def testNoInputs(self):
    """An op with no inputs exposes correctly typed and indexed outputs."""
    op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                       [dtypes.float32, dtypes.string])
    self.assertEqual(2, len(op.values()))
    self.assertEqual(0, len(op.inputs))
    self.assertEqual("myop", op.name)

    first, second = op.values()
    # (tensor, expected dtype, output index, expected NodeDef input string)
    expectations = (
        (first, dtypes.float32, 0, "myop"),
        (second, dtypes.string, 1, "myop:1"),
    )
    for tensor, dtype, index, node_input in expectations:
        self.assertEqual(dtype, tensor.dtype)
        self.assertEqual(op, tensor.op)
        self.assertEqual(index, tensor._value_index)
        self.assertEqual(0, len(tensor._consumers))
        self.assertEqual(node_input, tensor._as_node_def_input())

    self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)
def testNoShapeFunction(self):
    """Output of an op with no shape function (per the name) is unknown."""
    graph = ops.Graph()
    # NOTE(review): `op` is not used below; `_apply_op` builds the value
    # whose shape is checked.
    op = ops.Operation(ops._NodeDef("op", "an_op"), graph,
                       output_types=[dtypes.float32])
    expected = tensor_shape.unknown_shape()
    actual = _apply_op(graph, "an_op", [], [dtypes.float32]).get_shape()
    self.assertEqual(expected, actual)
def testNoArgs(self):
    """_NodeDef built from only an op type and a name."""
    self.assertProtoEquals("op: 'noop' name: 'bar'",
                           ops._NodeDef("noop", "bar"))
def testUnregisteredNode(self):
    """Stats requested for a "b" node are expected to have value None.

    Per the test name, op type "b" presumably has no registered statistics.
    """
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    # assertIsNone checks identity and gives a clearer failure message.
    self.assertIsNone(weight_params.value)
def testArgs(self):
    """_NodeDef accepts the device as a string or as a pydev.Device."""
    from_string = ops._NodeDef("foo", "bar", device="/device:baz:*")
    self.assertProtoEquals(
        "op:'foo' name:'bar' device:'/device:baz:*'", from_string)

    from_device = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
    self.assertProtoEquals(
        "op:'foo' name:'bar' device:'/job:j'", from_device)