def _StateToArgs(state):
  """Returns a list of FProp args from a NestedMap.

  For each position `idx` of the enclosing-scope sequence `args`, looks up
  key '_s{idx}' in `state` (None when the key is absent). Recovered NestedMap
  values get their leaf shapes set from the matching `args` entry via
  `py_utils.SetShapes`; recovered tf.Tensor values get `args[idx].shape`
  applied with `set_shape`.

  NOTE(review): `args` is a free variable here, presumably the original FProp
  args captured by an enclosing function — confirm against the caller.

  Args:
    state: A NestedMap holding the packed values under keys '_s0', '_s1', ...

  Returns:
    A list with one entry per element of `args`.
  """
  arg_list = []
  for idx, template in enumerate(args):
    attr = '_s{}'.format(idx)
    value = state[attr] if attr in state else None
    if isinstance(value, py_utils.NestedMap):
      # A NestedMap in the state must correspond to a NestedMap in `args`.
      assert isinstance(template, py_utils.NestedMap)
      py_utils.SetShapes(value, template)
    elif isinstance(value, tf.Tensor):
      # isinstance already rules out None, so no separate None check is needed.
      value.set_shape(template.shape)
    arg_list.append(value)
  return arg_list
def testSetShape(self):
  """Checks that SetShapes applies each `src` leaf's static shape to `dst`."""
  dst = py_utils.NestedMap(
      a=tf.placeholder(tf.int32, shape=None),
      b=py_utils.NestedMap(
          b1=tf.placeholder(tf.int32, shape=None),
          b2=tf.placeholder(tf.int32, shape=None)))
  src = py_utils.NestedMap(
      a=tf.constant(0, shape=[2, 4], dtype=tf.int32),
      b=py_utils.NestedMap(
          b1=tf.constant(0, shape=[1, 3], dtype=tf.int32),
          b2=tf.constant(0, shape=[5, 8], dtype=tf.int32)))
  py_utils.SetShapes(dst, src)
  # Shapes are integer lists, so compare them exactly rather than with a
  # floating-point tolerance.
  self.assertAllEqual([2, 4], py_utils.GetShape(dst.a, 2))
  self.assertAllEqual([1, 3], py_utils.GetShape(dst.b.b1, 2))
  self.assertAllEqual([5, 8], py_utils.GetShape(dst.b.b2, 2))