# Example #1
# 0
 def _StateToArgs(state):
   """Returns a list of FProp args from a NestedMap.

   For every element of `args` (taken from the enclosing scope), looks up the
   key '_s{idx}' in `state`, using None when the key is absent. Where the
   looked-up value is a NestedMap or a tf.Tensor, its static shape is restored
   in place from the matching element of `args`.

   Args:
     state: a NestedMap whose keys are expected to be '_s0', '_s1', ...

   Returns:
     A list with one entry per element of `args`, in the same order.
   """
   arg_list = []
   for idx, template in enumerate(args):
     attr = '_s{}'.format(idx)
     value = state[attr] if attr in state else None
     if isinstance(value, py_utils.NestedMap):
       # SetShapes walks both structures together, so the template for this
       # position must also be a NestedMap.
       assert isinstance(template, py_utils.NestedMap), (
           '{} is a NestedMap but args[{}] is {}'.format(
               attr, idx, type(template)))
       py_utils.SetShapes(value, template)
     elif isinstance(value, tf.Tensor):
       # isinstance() already rules out None, so no separate None check.
       value.set_shape(template.shape)
     arg_list.append(value)
   return arg_list
# Example #2
# 0
 def testSetShape(self):
   """SetShapes copies static shapes from `src` into `dst`, including nested maps.

   Every leaf of `dst` starts as a placeholder with no shape (shape=None).
   After SetShapes, each leaf, including those inside the nested NestedMap `b`,
   must report the static shape of the matching leaf in `src`.
   """
   dst = py_utils.NestedMap(
       a=tf.placeholder(tf.int32, shape=None),
       b=py_utils.NestedMap(
           b1=tf.placeholder(tf.int32, shape=None),
           b2=tf.placeholder(tf.int32, shape=None)))
   src = py_utils.NestedMap(
       a=tf.constant(0, shape=[2, 4], dtype=tf.int32),
       b=py_utils.NestedMap(
           b1=tf.constant(0, shape=[1, 3], dtype=tf.int32),
           b2=tf.constant(0, shape=[5, 8], dtype=tf.int32)))
   py_utils.SetShapes(dst, src)
   # Shapes are integer-valued, so compare them exactly rather than with a
   # floating-point tolerance.
   self.assertAllEqual([2, 4], py_utils.GetShape(dst.a, 2))
   self.assertAllEqual([1, 3], py_utils.GetShape(dst.b.b1, 2))
   self.assertAllEqual([5, 8], py_utils.GetShape(dst.b.b2, 2))