def testDeviceObject(self):
     """_set_device accepts both a device string and a pydev.Device spec."""
     # A device given as a plain string shows up verbatim in the NodeDef.
     str_op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], [])
     str_op._set_device("/job:goo/device:GPU:0")
     self.assertProtoEquals(
         "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' ",
         str_op.node_def)

     # A pydev.Device spec shows up in the NodeDef in its string form.
     spec_op = ops.Operation(ops._NodeDef("noop", "op2"), ops.Graph(), [], [])
     cpu_spec = pydev.Device(job="muu", device_type="CPU", device_index=0)
     spec_op._set_device(cpu_spec)
     self.assertProtoEquals(
         "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'",
         spec_op.node_def)
Beispiel #2
0
 def _WeMustGoDeeper(self, msg):
   """Raise an UnauthenticatedError whose op carries an original_op.

   Args:
     msg: expectation forwarded unchanged to self.assertRaisesOpError; the
       raised error must satisfy it.
   """
   with self.assertRaisesOpError(msg):
     failing_def = ops._NodeDef("op_type", "name")
     origin_def = ops._NodeDef("op_type_orig", "orig")
     origin_op = ops.Operation(origin_def, ops.get_default_graph())
     # The op attached to the error is built with origin_op as original_op.
     failing_op = ops.Operation(failing_def, ops.get_default_graph(),
                                original_op=origin_op)
     raise errors.UnauthenticatedError(failing_def, failing_op, "true_err")
Beispiel #3
0
 def _WeMustGoDeeper(self, msg):
     """Check that an error raised from an op with an original_op matches msg.

     Args:
       msg: passed straight through to self.assertRaisesOpError.
     """
     with self.assertRaisesOpError(msg):
         def_new = ops._NodeDef("op_type", "name")
         def_old = ops._NodeDef("op_type_orig", "orig")
         op_old = ops.Operation(def_old, ops.get_default_graph())
         op_new = ops.Operation(
             def_new, ops.get_default_graph(), original_op=op_old)
         # Raised inside the context manager so assertRaisesOpError sees it.
         raise errors.UnauthenticatedError(def_new, op_new, "true_err")
Beispiel #4
0
    def testInputsAndOutputs(self):
        """Check op outputs and how a consuming op registers on its inputs.

        myop3 consumes myop1:0 once and myop2:1 twice; each use appears as a
        separate entry in the input tensor's _consumers list.
        """
        graph = ops.Graph()

        src_a = ops.Operation(ops._NodeDef("noop", "myop1"), graph, [],
                              [dtypes.float32])
        self.assertEqual(1, len(src_a.values()))
        (a_float,) = src_a.values()

        src_b = ops.Operation(ops._NodeDef("reop", "myop2"), graph, [],
                              [dtypes.float32, dtypes.string])
        self.assertEqual(2, len(src_b.values()))
        b_float, b_label = src_b.values()

        # Note that we consume b_label twice here.
        sink = ops.Operation(ops._NodeDef("add", "myop3"), graph,
                             [a_float, b_label, b_label],
                             [dtypes.float32, dtypes.int32])
        self.assertEqual(2, len(sink.values()))

        # a_float feeds the sink exactly once.
        self.assertEqual(1, len(a_float._consumers))
        self.assertEqual(sink, a_float._consumers[0])

        # b_float is never used as an input.
        self.assertEqual(0, len(b_float._consumers))

        # Both uses of b_label are recorded, each pointing at the sink.
        self.assertEqual(2, len(b_label._consumers))
        self.assertEqual(sink, b_label._consumers[0])
        self.assertEqual(sink, b_label._consumers[1])

        self.assertProtoEquals(
            """
    op:'add' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, sink.node_def)
Beispiel #5
0
  def testInputsAndOutputs(self):
    """Output counts plus per-use consumer bookkeeping across three ops."""
    g = ops.Graph()

    first = ops.Operation(ops._NodeDef("noop", "myop1"), g, [],
                          [dtypes.float32])
    self.assertEqual(1, len(first.values()))
    (first_out,) = first.values()

    second = ops.Operation(ops._NodeDef("reop", "myop2"), g, [],
                           [dtypes.float32, dtypes.string])
    self.assertEqual(2, len(second.values()))
    unused_out, str_out = second.values()

    # Note that we consume str_out twice here.
    adder_inputs = [first_out, str_out, str_out]
    adder = ops.Operation(ops._NodeDef("add", "myop3"), g, adder_inputs,
                          [dtypes.float32, dtypes.int32])
    self.assertEqual(2, len(adder.values()))

    self.assertEqual(1, len(first_out._consumers))
    self.assertEqual(adder, first_out._consumers[0])

    # The float output of myop2 is never consumed.
    self.assertEqual(0, len(unused_out._consumers))

    # One consumer entry per use of str_out.
    self.assertEqual(2, len(str_out._consumers))
    self.assertEqual(adder, str_out._consumers[0])
    self.assertEqual(adder, str_out._consumers[1])

    self.assertProtoEquals("""
    op:'add' name:'myop3'
    input:'myop1' input:'myop2:1' input:'myop2:1'
    """, adder.node_def)
Beispiel #6
0
 def testInvalidNames(self):
   """Constructing an Operation with an empty or badly-prefixed name fails."""
   graph = ops.Graph()
   # Each name is tried in turn on the same graph; every one must raise.
   for bad_name in ("", "_invalid", "-invalid", "/invalid"):
     with self.assertRaises(ValueError):
       ops.Operation(ops._NodeDef("op", bad_name), graph)
Beispiel #7
0
 def testInvalidNames(self):
     """Operation rejects names that are empty or start with _, - or /."""
     g = ops.Graph()
     rejected = ["", "_invalid", "-invalid", "/invalid"]
     for name in rejected:
         node_def = ops._NodeDef("op", name)
         with self.assertRaises(ValueError):
             ops.Operation(node_def, g)
Beispiel #8
0
 def testDeviceObject(self):
     """A device string or a pydev.Device both end up in node_def.device."""
     # (op name, value passed to _set_device, expected NodeDef text proto)
     cases = [
         ("myop", "/job:goo/device:GPU:0",
          "op:'noop' name:'myop' device:'/job:goo/device:GPU:0' "),
         ("op2", pydev.Device(job="muu", device_type="CPU", device_index=0),
          "op:'noop' name:'op2' device:'/job:muu/device:CPU:0'"),
     ]
     for op_name, device, expected in cases:
         op = ops.Operation(ops._NodeDef("noop", op_name), ops.Graph(), [], [])
         op._set_device(device)
         self.assertProtoEquals(expected, op.node_def)
 def testReferenceInput(self):
     """Ref and non-ref consumers of the same outputs list identical inputs.

     op2 passes input_types explicitly and op3 does not; both serialize to
     the same NodeDef input list.
     """
     graph = ops.Graph()
     producer = ops.Operation(ops._NodeDef("noop", "op1"), graph, [],
                              [types.float32_ref, types.float32])
     self.assertProtoEquals("op:'noop' name:'op1'", producer.node_def)
     ref_out, plain_out = producer.values()

     # NOTE(mrry): Must specify input_types to preserve ref-typed input.
     ref_consumer = ops.Operation(
         ops._NodeDef("refop", "op2"), graph, [ref_out, plain_out], [],
         input_types=[types.float32_ref, types.float32])
     self.assertProtoEquals(
         "op:'refop' name:'op2' input:'op1' input:'op1:1'",
         ref_consumer.node_def)

     plain_consumer = ops.Operation(
         ops._NodeDef("nonrefop", "op3"), graph, [ref_out, plain_out], [])
     self.assertProtoEquals(
         "op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
         plain_consumer.node_def)
    def testNoOutputs(self):
        """An op that only consumes a tensor has no outputs and one input.

        Also checks that the consumed tensor records the op as its single
        consumer and that both NodeDefs serialize as expected.
        """
        # NOTE(review): `types` is presumably the legacy TF dtypes module;
        # sibling tests use `dtypes` — confirm against this file's imports.
        g = ops.Graph()
        op1 = ops.Operation(ops._NodeDef("noop", "myop1"), g, [],
                            [types.float32])
        float_t, = op1.values()
        op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], [])
        # assertEqual, not the assertEquals alias removed in Python 3.12.
        self.assertEqual(0, len(op2.values()))
        self.assertEqual(1, len(op2.inputs))
        self.assertIs(float_t, op2.inputs[0])

        # op2 is registered as the only consumer of its input tensor.
        self.assertEqual(1, len(float_t._consumers))
        self.assertEqual(op2, float_t._consumers[0])

        self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def)
        self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
                               op2.node_def)
Beispiel #11
0
 def testShape(self):
     """An output starts with unknown shape and reports the shape set on it."""
     out = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                         [dtypes.float32]).outputs[0]
     self.assertEqual(tensor_shape.unknown_shape(), out.get_shape())
     dims = [1, 2, 3]
     out.set_shape(dims)
     self.assertEqual([1, 2, 3], out.get_shape())
Beispiel #12
0
 def testReferenceInput(self):
     """NodeDef inputs are the same whether or not input_types is given."""
     g = ops.Graph()
     src = ops.Operation(ops._NodeDef("noop", "op1"), g, [],
                         [dtypes.float32_ref, dtypes.float32])
     self.assertProtoEquals("op:'noop' name:'op1'", src.node_def)
     ref_t, val_t = src.values()

     # NOTE(mrry): Must specify input_types to preserve ref-typed input.
     explicit_types = [dtypes.float32_ref, dtypes.float32]
     typed = ops.Operation(ops._NodeDef("refop", "op2"), g, [ref_t, val_t],
                           [], input_types=explicit_types)
     self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'",
                            typed.node_def)

     untyped = ops.Operation(ops._NodeDef("nonrefop", "op3"), g,
                             [ref_t, val_t], [])
     self.assertProtoEquals(
         "op:'nonrefop' name:'op3' input:'op1' input:'op1:1'",
         untyped.node_def)
Beispiel #13
0
 def testShape(self):
   """After set_shape, get_shape reports the dims that were set."""
   node_def = ops._NodeDef("noop", "myop")
   producer = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
   tensor = producer.outputs[0]
   # Nothing about the shape is known before set_shape is called.
   self.assertEqual(tensor_shape.unknown_shape(), tensor.get_shape())
   tensor.set_shape([1, 2, 3])
   self.assertEqual([1, 2, 3], tensor.get_shape())
Beispiel #14
0
    def testNoOutputs(self):
        """An op that only consumes a tensor has no outputs and one input."""
        graph = ops.Graph()
        source = ops.Operation(ops._NodeDef("noop", "myop1"), graph, [],
                               [dtypes.float32])
        (feed,) = source.values()
        sink = ops.Operation(ops._NodeDef("reop", "myop2"), graph, [feed], [])

        # The sink exposes no outputs and exactly the one input it was given.
        self.assertEqual(0, len(sink.values()))
        self.assertEqual(1, len(sink.inputs))
        self.assertIs(feed, sink.inputs[0])

        # The fed tensor records the sink as its only consumer.
        self.assertEqual(1, len(feed._consumers))
        self.assertEqual(sink, feed._consumers[0])

        self.assertProtoEquals("op:'noop' name:'myop1'", source.node_def)
        self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'",
                               sink.node_def)
Beispiel #15
0
 def testNoShapeFunction(self):
     """_apply_op on graph `g` yields an output whose shape is unknown."""
     g = ops.Graph()
     # The Operation is created on `g` but never inspected afterwards.
     ops.Operation(ops._NodeDef("op", "an_op"), g,
                   output_types=[dtypes.float32])
     expected_shape = tensor_shape.unknown_shape()
     actual_shape = _apply_op(g, "an_op", [], [dtypes.float32]).get_shape()
     self.assertEqual(expected_shape, actual_shape)
Beispiel #16
0
 def testIterable(self):
   """Iterating over a Tensor raises a TypeError mentioning "not iterable"."""
   producer = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                            [dtypes.float32])
   tensor = producer.outputs[0]
   self.assertIsInstance(tensor, ops.Tensor)
   with self.assertRaisesRegexp(TypeError, "not iterable"):
     for _ in tensor:
       pass
Beispiel #17
0
 def testRegisteredNode(self):
   """get_stats_for_node_def returns values for op type "a".

   The expected numbers (10 weight parameters, 20 flops) presumably come from
   statistics registered for op type "a" elsewhere in this file — confirm.
   """
   graph = ops.Graph()
   node = ops._NodeDef("a", "an_a")
   weight_params = ops.get_stats_for_node_def(graph, node, "weight_parameters")
   self.assertEqual(10, weight_params.value)
   flops = ops.get_stats_for_node_def(graph, node, "flops")
   self.assertEqual(20, flops.value)
   # An unknown statistic name yields a stats object whose value is None.
   missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
   self.assertIsNone(missing_stat.value)
Beispiel #18
0
 def testRegisteredNode(self):
     """get_stats_for_node_def returns values for op type "a".

     The expected numbers (10 weight parameters, 20 flops) presumably come
     from statistics registered for op type "a" elsewhere — confirm.
     """
     graph = ops.Graph()
     node = ops._NodeDef("a", "an_a")
     weight_params = ops.get_stats_for_node_def(graph, node,
                                                "weight_parameters")
     self.assertEqual(10, weight_params.value)
     flops = ops.get_stats_for_node_def(graph, node, "flops")
     self.assertEqual(20, flops.value)
     # An unknown statistic name yields a stats object whose value is None.
     missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
     self.assertIsNone(missing_stat.value)
    def testNoInputs(self):
        """A source op exposes typed outputs with correct indices and names."""
        # NOTE(review): `types` is presumably the legacy TF dtypes module;
        # sibling tests use `dtypes` — confirm against this file's imports.
        op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                           [types.float32, types.string])
        # assertEqual, not the assertEquals alias removed in Python 3.12.
        self.assertEqual(2, len(op.values()))
        self.assertEqual(0, len(op.inputs))
        self.assertEqual("myop", op.name)

        float_t, label_str_t = op.values()
        # Output 0 is referenced by the bare op name in NodeDef inputs.
        self.assertEqual(types.float32, float_t.dtype)
        self.assertEqual(op, float_t.op)
        self.assertEqual(0, float_t._value_index)
        self.assertEqual(0, len(float_t._consumers))
        self.assertEqual("myop", float_t._as_node_def_input())

        # Output 1 is referenced as "<op name>:1".
        self.assertEqual(types.string, label_str_t.dtype)
        self.assertEqual(op, label_str_t.op)
        self.assertEqual(1, label_str_t._value_index)
        self.assertEqual(0, len(label_str_t._consumers))
        self.assertEqual("myop:1", label_str_t._as_node_def_input())

        self.assertProtoEquals("op:'noop' name:'myop'", op.node_def)
Beispiel #20
0
    def testNoInputs(self):
        """Each output of an input-less op reports its dtype, op and index."""
        source = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [],
                               [dtypes.float32, dtypes.string])
        self.assertEqual(2, len(source.values()))
        self.assertEqual(0, len(source.inputs))
        self.assertEqual("myop", source.name)

        # (dtype, output index, NodeDef input reference) for each output.
        expected = [
            (dtypes.float32, 0, "myop"),
            (dtypes.string, 1, "myop:1"),
        ]
        for out, (dtype, index, ref) in zip(source.values(), expected):
            self.assertEqual(dtype, out.dtype)
            self.assertEqual(source, out.op)
            self.assertEqual(index, out._value_index)
            self.assertEqual(0, len(out._consumers))
            self.assertEqual(ref, out._as_node_def_input())

        self.assertProtoEquals("op:'noop' name:'myop'", source.node_def)
Beispiel #21
0
 def testNoShapeFunction(self):
   """The tensor returned by _apply_op on `g` has an unknown shape."""
   graph = ops.Graph()
   # Created on the graph only; the resulting Operation is not examined.
   ops.Operation(ops._NodeDef("op", "an_op"), graph,
                 output_types=[dtypes.float32])
   unknown = tensor_shape.unknown_shape()
   result = _apply_op(graph, "an_op", [], [dtypes.float32])
   self.assertEqual(unknown, result.get_shape())
Beispiel #22
0
 def testNoArgs(self):
     """A NodeDef built from only an op type and a name matches op+name."""
     expected = "op: 'noop' name: 'bar'"
     self.assertProtoEquals(expected, ops._NodeDef("noop", "bar"))
Beispiel #23
0
 def testUnregisteredNode(self):
   """Stats lookup for op type "b" (expected unregistered) yields None.

   The statistic queried is "weight_parameters", the same name that
   testRegisteredNode resolves to a concrete value for op type "a". A None
   result here can therefore only come from op type "b", not from querying
   a statistic name that no op uses. The previous "weight_params" spelling
   made this test pass vacuously.
   """
   graph = ops.Graph()
   node = ops._NodeDef("b", "a_b")
   weight_params = ops.get_stats_for_node_def(graph, node, "weight_parameters")
   self.assertIsNone(weight_params.value)
Beispiel #24
0
 def testArgs(self):
   """The device keyword accepts a raw string or a pydev.Device."""
   from_string = ops._NodeDef("foo", "bar", device="/device:baz:*")
   self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
                          from_string)
   # A Device spec with only a job is rendered as '/job:<name>'.
   from_spec = ops._NodeDef("foo", "bar", device=pydev.Device(job="j"))
   self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", from_spec)
Beispiel #25
0
 def testArgs(self):
     """_NodeDef stores a device given either as a string or a pydev.Device."""
     # (device argument, expected NodeDef text proto)
     cases = (
         ("/device:baz:*", "op:'foo' name:'bar' device:'/device:baz:*'"),
         (pydev.Device(job="j"), "op:'foo' name:'bar' device:'/job:j'"),
     )
     for device, expected in cases:
         self.assertProtoEquals(expected,
                                ops._NodeDef("foo", "bar", device=device))
Beispiel #26
0
 def testNoArgs(self):
   """Without keyword arguments, the NodeDef matches just op and name."""
   node = ops._NodeDef("noop", "bar")
   self.assertProtoEquals("op: 'noop' name: 'bar'", node)
Beispiel #27
0
 def testUnregisteredNode(self):
     """Stats lookup for op type "b" (expected unregistered) yields None.

     Queries "weight_parameters", the statistic name that testRegisteredNode
     resolves for op type "a". The None result is then attributable to op
     type "b" rather than to a statistic name no op uses. The previous
     "weight_params" spelling made this test pass vacuously.
     """
     graph = ops.Graph()
     node = ops._NodeDef("b", "a_b")
     weight_params = ops.get_stats_for_node_def(graph, node,
                                                "weight_parameters")
     self.assertIsNone(weight_params.value)