def testRefIdentityShape(self):
  """Checks that RefIdentity of a variable reports the variable's shape.

  Builds a 2x3 int32 variable and asserts that both the variable and the
  result of `gen_array_ops.ref_identity` on it have static shape [2, 3].
  """
  with self.cached_session():
    shape = [2, 3]
    tensor = variables.VariableV1(
        constant_op.constant([[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
    # assertEqual, not the deprecated assertEquals alias (removed in
    # Python 3.12's unittest).
    self.assertEqual(shape, tensor.get_shape())
    self.assertEqual(shape, gen_array_ops.ref_identity(tensor).get_shape())
Example #2
0
 def testRefIdentityShape(self):
   """Checks that RefIdentity of a variable reports the variable's shape.

   Builds a 2x3 int32 variable and asserts that both the variable and the
   result of `gen_array_ops.ref_identity` on it have static shape [2, 3].
   """
   with self.cached_session():
     shape = [2, 3]
     tensor = variables.VariableV1(
         constant_op.constant([[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
     # assertEqual, not the deprecated assertEquals alias (removed in
     # Python 3.12's unittest).
     self.assertEqual(shape, tensor.get_shape())
     self.assertEqual(shape,
                      gen_array_ops.ref_identity(tensor).get_shape())
Example #3
0
 def testColocationContraints(self):
   """Checks the colocation groups of a variable assigned via RefIdentity.

   An Assign writes through `gen_array_ops.ref_identity(v)`, and the Assign is
   registered in the TRAIN_OP collection. The test builds a grappler Item from
   the graph's MetaGraphDef and expects exactly one colocation group,
   containing 'Assign', 'RefIdentity', 'Variable' and 'Variable/Assign'.
   """
   with ops.Graph().as_default() as g:
     c = constant_op.constant([10])
     # NOTE(review): ref_identity expects a ref-typed input, and the expected
     # op names below match a ref variable. `variables.Variable` may produce
     # a resource variable (e.g. under TF2 behavior), so use VariableV1 as
     # the other version of this test in this file does.
     v = variables.VariableV1([3], dtype=dtypes.int32)
     i = gen_array_ops.ref_identity(v)
     a = state_ops.assign(i, c)
     # Put the Assign in TRAIN_OP so it is included in the grappler Item.
     train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
     train_op.append(a)
     mg = meta_graph.create_meta_graph_def(graph=g)
     grappler_item = item.Item(mg)
     groups = grappler_item.GetColocationGroups()
     self.assertEqual(len(groups), 1)
     self.assertItemsEqual(
         groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])
 def testColocationContraints(self):
   """Verifies grappler's colocation grouping for an Assign via RefIdentity.

   A VariableV1 is wrapped in RefIdentity and assigned through it; the Assign
   goes into the TRAIN_OP collection. The resulting graph must yield a single
   colocation group made of 'Assign', 'RefIdentity', 'Variable' and
   'Variable/Assign'.
   """
   graph = ops.Graph()
   with graph.as_default():
     new_value = constant_op.constant([10])
     ref_var = variables.VariableV1([3], dtype=dtypes.int32)
     # ref_identity is evaluated before assign, keeping op creation order.
     assign_op = state_ops.assign(gen_array_ops.ref_identity(ref_var),
                                  new_value)
     ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(assign_op)

     meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
     colocation_groups = item.Item(meta_graph_def).GetColocationGroups()

     expected_members = [
         'Assign', 'RefIdentity', 'Variable', 'Variable/Assign'
     ]
     self.assertEqual(len(colocation_groups), 1)
     self.assertItemsEqual(colocation_groups[0], expected_members)