def __init__(self, dtype, shape, send_device, recv_device, name=None):
  """Construct a channel.

  Args:
    dtype: The dtype of tensors sent through the channel.
    shape: The shape of tensors sent through the channel. Must be a fully
      defined shape for TPUs.
    send_device: A fully-specified tensorflow device.
    recv_device: A fully-specified tensorflow device.
    name: A name for the channel (optional).
  """
  current_graph = tf.get_default_graph()
  assert current_graph, "A channel is scoped within a tf.Graph"
  self._dtype = dtype
  self._send_device = send_device
  self._recv_device = recv_device
  self._name = current_graph.unique_name(name if name else "channel")

  assert shape is not None
  shape = tf.TensorShape(shape)

  self._shape = shape
  self._send_tpu_core = _TpuCore(send_device)
  self._recv_tpu_core = _TpuCore(recv_device)
  self._send_called = False
  self._recv_op = None
  assert ((self._send_tpu_core == -1) == (self._recv_tpu_core == -1)), (
      "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))
  if self._send_tpu_core >= 0:
    assert self._shape.is_fully_defined(), (
        "TPU channel must have fully defined shape. Name: %s, shape: %s" %
        (self._name, self._shape))
    assert self._send_tpu_core != self._recv_tpu_core, (
        "TPU send/recv must be cross-core: %s and %s" %
        (send_device, recv_device))
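# Minimal, self-contained sketch (not part of the source) of the shape
# requirement the constructor enforces for TPU devices: only fully defined
# shapes pass the assertion, so a shape with an unknown dimension would be
# rejected for a TPU-to-TPU channel.
#
#   import tensorflow as tf
#
#   assert tf.TensorShape([16, 1024]).is_fully_defined()
#   assert not tf.TensorShape([16, None]).is_fully_defined()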
def _get_input_shapes(self, *args):
  p = self.params
  if p.nested_map_fprop:
    assert len(args) == 1
    assert isinstance(args[0], py_utils.NestedMap)
    input_tensors = py_utils.Flatten(args[0])
  else:
    input_tensors = _ToTuple(args)
  # Get batch size from the first tensor which is not None.
  mini_batch_size = None
  for input_tensor in input_tensors:
    if input_tensor is not None:
      mini_batch_size = input_tensor.get_shape().as_list()[p.batch_dim]
  assert mini_batch_size is not None
  micro_batch_size = p.micro_batch_size
  if not micro_batch_size:
    if p.num_micro_batches > mini_batch_size:
      p.num_micro_batches = mini_batch_size
    micro_batch_size = mini_batch_size // p.num_micro_batches
  if mini_batch_size is not None:
    if micro_batch_size * p.num_micro_batches != mini_batch_size:
      raise ValueError('micro_batch_size * num_micro_batches != batch_size.')

  input_shapes = ()
  for input_tensor in input_tensors:
    if input_tensor is not None:
      input_shape = input_tensor.get_shape().as_list()
      input_shape[p.batch_dim] = micro_batch_size
      input_shapes += (tf.TensorShape(input_shape),)
    else:
      input_shapes += (None,)

  if p.nested_map_fprop:
    input_shapes = py_utils.Pack(args[0], input_shapes)
  return input_shapes
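# Worked example (illustrative only, with assumed numbers) of the micro-batch
# shape arithmetic above: a mini-batch of 32 split into num_micro_batches=4
# keeps every dimension except batch_dim, which becomes 8.
#
#   import tensorflow as tf
#
#   mini_batch_size, num_micro_batches, batch_dim = 32, 4, 0
#   micro_batch_size = mini_batch_size // num_micro_batches
#   assert micro_batch_size * num_micro_batches == mini_batch_size
#
#   full_shape = [mini_batch_size, 80, 512]
#   micro_shape = list(full_shape)
#   micro_shape[batch_dim] = micro_batch_size
#   print(tf.TensorShape(micro_shape))  # (8, 80, 512)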
def ToTensorShape(self):
  """Converts to a possibly partially specified tf.TensorShape."""
  dims = []
  for d in self._shape:
    if d.is_number and d.is_integer:
      dims.append(int(d))
    else:
      dims.append(None)
  return tf.TensorShape(dims)
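# Hedged sketch of the conversion above. It assumes the dims are symbolic
# expressions (e.g. sympy, which is only inferred here from the
# `is_number`/`is_integer` checks): concrete integer dims become ints and
# symbolic ones become None in the resulting partially specified TensorShape.
#
#   import sympy
#   import tensorflow as tf
#
#   batch = sympy.Symbol('batch')
#   dims = [sympy.Integer(80), batch * 2]
#   converted = [
#       int(d) if (d.is_number and d.is_integer) else None for d in dims
#   ]
#   print(tf.TensorShape(converted))  # (80, None)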
def _GetShapes(tensors, none_shapes=False):
  """Util for getting nested structure of shapes from structure of tensors.

  Args:
    tensors: Structure of Tensors to get shapes for.
    none_shapes: Returns None shapes if true.

  Returns:
    The same structure as tensors but of corresponding `TensorShape` objects.
  """
  shapes = []
  for t in tf.nest.flatten(tensors):
    shape = t.get_shape() if isinstance(t, tf.Tensor) else None
    if none_shapes:
      if shape:
        shapes.append(tf.TensorShape([None] * len(shape)))
      else:
        shapes.append(tf.TensorShape(None))
    else:
      shapes.append(tf.TensorShape(shape))
  return type(tensors)(tf.nest.pack_sequence_as(tensors, shapes))
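# Quick illustration (not from the source) of _GetShapes on a tuple of
# tensors: with none_shapes=True every known dimension is relaxed to an
# unknown one, which is what the while_loop shape_invariants below rely on.
#
#   import tensorflow as tf
#
#   tensors = (tf.zeros([4, 7]), tf.zeros([3]))
#   _GetShapes(tensors)                     # shapes (4, 7) and (3,)
#   _GetShapes(tensors, none_shapes=True)   # shapes (None, None) and (None,)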
def _CellFn(unused_theta, unused_state0, inputs):
  """Recurrent cell function wrapper of body.FProp."""
  # Sets shapes for both theta and inputs to self.body.FProp.
  for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                      list(args) + theta_stack.Flatten()):
    if src is not None:
      dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

  # Runs the actual body.FProp
  fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
  fprop_outputs = _ToTuple(fprop_outputs)
  assert len(fprop_outputs) == len(out_shapes)

  # Passes fprop outputs to the next layer through state.
  state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
  return state1, py_utils.NestedMap()
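# Small illustration (not from the source) of the set_shape pattern above:
# `src` carries a leading stacking dimension (e.g. [p.repeat, ...]) while
# `dst` is a single slice seen inside the cell function, so the shape copied
# onto `dst` drops that leading dimension.
#
#   import tensorflow as tf
#
#   src = tf.zeros([5, 4, 16])   # stacked over p.repeat = 5
#   dst = tf.zeros([4, 16])      # one slice inside the cell function
#   dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))  # -> (4, 16)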
def _AssignVar(self, var_op):
  size = var_op.get_attr('dtype').size
  shape = tf.TensorShape(var_op.get_attr('shape'))
  assert self._var_space_pq, ('No ps devices to use.')
  allocated, device = heapq.heappop(self._var_space_pq)
  if shape.num_elements() is None:
    assert var_op.name.endswith(
        'wb/var'), 'Unexpected name pattern: %s' % var_op.name
    # CuDNN RNN var shapes aren't known statically; use a constant estimate
    # to avoid introducing more complexity.
    allocated += 10 * 1024**2 * size
  else:
    allocated += shape.num_elements() * size
  heapq.heappush(self._var_space_pq, (allocated, device))
  tf.logging.info('Place variable %s on %s %d', var_op.name, device,
                  allocated)
  return device
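# Self-contained sketch (assumed setup, not from the source) of the greedy
# least-loaded placement above: the priority queue always pops the parameter
# server with the fewest allocated bytes, charges it for the new variable,
# and pushes it back.
#
#   import heapq
#
#   var_space_pq = [(0, '/job:ps/task:0'), (0, '/job:ps/task:1')]
#   heapq.heapify(var_space_pq)
#
#   def place(num_elements, dtype_size=4):
#     allocated, device = heapq.heappop(var_space_pq)
#     allocated += num_elements * dtype_size
#     heapq.heappush(var_space_pq, (allocated, device))
#     return device
#
#   place(1024 * 1024)  # '/job:ps/task:0'
#   place(16)           # '/job:ps/task:1' (now the least loaded)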
def _CellFn(unused_theta, state0, theta_i):
  """Recurrent cell function wrapper of body.FProp."""
  # Retrieves fprop arguments from state and sets shapes.
  fprop_inputs = _StateToArgs(state0)

  # Sets shapes for theta_i as well.
  for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
    if src is not None:
      dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

  # Runs the actual body.FProp
  fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
  fprop_outputs = _ToTuple(fprop_outputs)
  assert len(fprop_outputs) == len(fprop_inputs)

  # Passes fprop outputs to the next layer through state.
  state1 = _ArgsToState(fprop_outputs)
  return state1, py_utils.NestedMap()
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Performs greedy-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
    src_batch_size.

    - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
      token is encountered during search.
    - hyp_lens: [num_hyps].
    - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta,
      encoder_outputs,
      1  # num_hyps_per_beam
  )

  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids, unused_hyp_lens,
                   done_hyps, unused_other_states_list):
    return tf.math.logical_and(
        cur_step < max_steps,
        tf.math.logical_not(tf.reduce_all(done_hyps)))

  def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               other_states_list):
    (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
     new_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
         done_hyps, other_states.Pack(other_states_list),
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        tf.TensorShape(hyp_ids.get_shape()),
                        tf.TensorShape(hyp_lens.get_shape()),
                        tf.TensorShape(done_hyps.get_shape()),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # transpose hyp_ids so it matches BeamSearchDecode's output
  final_hyp_ids = tf.transpose(final_hyp_ids)
  return final_hyp_ids, final_hyp_lens, final_done_hyps
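# Minimal standalone sketch (illustrative, not from the source) of the
# shape_invariants pattern used above: fixed-shape loop variables get their
# static shapes, while variables whose shape may change or be unknown across
# iterations get relaxed (None) invariants, which is what
# _GetShapes(..., none_shapes=True) provides for the decoder's other states.
#
#   import tensorflow as tf
#
#   step = tf.constant(0)
#   ids = tf.zeros([4, 1], tf.int32)
#
#   _, final_ids = tf.while_loop(
#       lambda step, ids: step < 5,
#       lambda step, ids: (step + 1, tf.concat([ids, ids[:, -1:]], axis=1)),
#       loop_vars=(step, ids),
#       shape_invariants=(step.get_shape(), tf.TensorShape([4, None])))
#   # final_ids is (4, None) statically, (4, 6) at runtime.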
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks. Mostly opaque to BeamSearchHelper, except that it should
      contain either a 'seq_lengths' field of shape [source_batch_size] or a
      'paddings' field of shape [source_max_lengths, source_batch_size].
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps, tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Assume that `paddings` has shape [source_max_lengths, source_batch_size]
  # by default, and compute `encoded_seq_lengths` accordingly. This can be
  # overridden by directly passing `seq_lengths` in the `encoder_outputs`
  # NestedMap.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      encoded_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(
                  1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
          tf.int32)
    else:
      encoded_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
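# Worked example (illustrative, assumed values) of the padding-to-length
# conversion used above: `paddings` is [source_max_lengths,
# source_batch_size] with 1.0 at padded positions, so summing (1 - padding)
# over time after transposing yields one integer length per batch element.
#
#   import tensorflow as tf
#
#   paddings = tf.constant([[0., 0.],
#                           [0., 1.],
#                           [1., 1.]])  # time-major: 3 steps, batch of 2
#   lengths = tf.cast(
#       tf.round(tf.reduce_sum(1.0 - tf.transpose(paddings), 1)), tf.int32)
#   # lengths == [2, 1]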
def _BeamSearchDecodeIds(self,
                         theta,
                         encoder_outputs,
                         num_hyps_per_beam,
                         init_beam_search_state=None,
                         pre_beam_search_step_callback=None,
                         post_beam_search_step_callback=None,
                         max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap computed by encoder.
    num_hyps_per_beam: Number of hyps per beam.
    init_beam_search_state: The InitBeamSearchState callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The PreBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The PostBeamSearchStepCallback callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    hyps: A tensor of shape [time, b * k] with ids of the token selected.
    prev_hyps: A tensor of shape [time, b * k] with index to the previous
      hyps which was selected.
    done_hyps: A boolean tensor of shape [time, b * k] where value indicates
      if hyps was terminated.
    scores: A tensor of shape [time, b * k] with scores of the token
      selected.
    atten_probs: A tensor of shape [time, b * k, seq_len] which contain the
      attention probabilities over the source words against word in the
      previous hyps.
    eos_scores: A tensor of shape [time, b * k] with scores of the eos token
      selected.
    eos_atten_probs: A tensor of shape [time, b * k, seq_len] which contain
      the attention probabilities over the source words against word in the
      previous hyps.
    source_seq_lengths: A tensor of shape [time] containing the source
      seq_lengths.
    flat_final_other_states: An array of tensors that are part of other
      states.
  """
  p = self.params
  source_paddings = encoder_outputs.padding

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  # We cache the NestedMap as a member variable so that we can use it to pack
  # the final outputs. TPU rewrite methods force us to strictly pass in and
  # output Tensors.
  self._other_states = other_states

  step_ids = tf.fill([num_hyps, 1],
                     tf.constant(p.target_sos_id, dtype=tf.int32))
  min_score = -1e36
  fprop_dtype = py_utils.FPropDtype(p)
  best_scores = (tf.zeros(shape=[num_beams], dtype=fprop_dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=fprop_dtype)
  histories = tf.zeros(shape=[num_hyps], dtype=tf.int32)
  in_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_prev_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_done_hyps = tf.TensorArray(dtype=tf.int32, size=max_steps)
  in_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_scores = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  in_eos_atten_probs = tf.TensorArray(dtype=fprop_dtype, size=max_steps)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)

  # States for beam search that are inputs into Beam search step.
  accum_bs_states = [best_scores, cumulative_scores, histories]
  # States that are not accumulators.
  non_accum_bs_states = [
      in_scores,
      in_hyps,
      in_prev_hyps,
      in_done_hyps,
      in_atten_probs,
      in_eos_scores,
      in_eos_atten_probs,
  ]
  core_bs_states = tuple(accum_bs_states + non_accum_bs_states)

  flat_other_states = other_states.Flatten()

  # If there is an optimized implementation for short sequences,
  # LoopBodyShort will run first for short_seq_limit steps (after which
  # LoopBodyShort no longer has a performance benefit). Then LoopBodyLong
  # (the default implementation) is used to continue the rest of the steps.
  # For decoders which do not have the short-sequence-specific
  # implementation, only LoopBodyLong (the default implementation) will run.
  if p.short_seq_limit > 0:

    def LoopContinueShort(cur_step, all_done, unused_step_ids,
                          unused_core_bs_states, unused_other_states_list):
      """Use short_seq optimization when cur_step is smaller than limit."""
      return tf.math.logical_and(cur_step < p.short_seq_limit,
                                 tf.math.logical_not(all_done))

    def LoopBodyShort(cur_step, unused_all_done, step_ids, core_bs_states,
                      other_states_list):
      """Loop body of short_seq optimization.

      Instead of doing computation for the entire padded sequence, a while
      loop with early exit is used within each _BeamSearchStep to do
      computation for only the actual sequence (seq_length <= cur_step).
      use_short_seq_opt is used as the flag to pass this information down to
      the decoder implementation.

      Args:
        cur_step: A scalar int tensor, the current time step, 0-based.
        unused_all_done: A tf.bool, indicating whether the decoding finishes.
        step_ids: An int32 tensor of shape [num_hyps, 1]. The input ids to
          the current search step.
        core_bs_states: A tuple of core beam search states.
        other_states_list: A flattened NestedMap of other beam search states.

      Returns:
        The updated input tuple, with the same shape.
      """
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta,
           encoder_outputs,
           cur_step,
           step_ids,
           core_bs_states,
           other_states.Pack(other_states_list),
           num_hyps_per_beam,
           pre_beam_search_step_callback,
           post_beam_search_step_callback,
           use_short_seq_opt=True)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    (cur_step, all_done, step_ids, core_bs_states,
     flat_other_states) = tf.while_loop(
         LoopContinueShort,
         LoopBodyShort,
         loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                    flat_other_states),
         parallel_iterations=10,
         back_prop=False,
         swap_memory=False,
         shape_invariants=(
             tf.TensorShape(cur_step.get_shape()),
             tf.TensorShape(all_done.get_shape()),
             tf.TensorShape(step_ids.get_shape()),
             tuple(
                 list(_GetShapes(accum_bs_states)) +
                 list(_GetShapes(non_accum_bs_states, none_shapes=True))),
             _GetShapes(flat_other_states, none_shapes=True)),
         maximum_iterations=max_steps)

  def LoopContinueLong(cur_step, all_done, unused_step_ids,
                       unused_core_bs_states, unused_other_states_list):
    """Continue default implementation until decoding finishes."""
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBodyLong(cur_step, unused_all_done, step_ids, core_bs_states,
                   other_states_list):
    """Loop body of default long_seq implementation."""
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta,
         encoder_outputs,
         cur_step,
         step_ids,
         core_bs_states,
         other_states.Pack(other_states_list),
         num_hyps_per_beam,
         pre_beam_search_step_callback,
         post_beam_search_step_callback,
         use_short_seq_opt=False)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinueLong,
      LoopBodyLong,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(
          tf.TensorShape(cur_step.get_shape()),
          tf.TensorShape(all_done.get_shape()),
          tf.TensorShape(step_ids.get_shape()),
          tuple(
              list(_GetShapes(accum_bs_states)) +
              list(_GetShapes(non_accum_bs_states, none_shapes=True))),
          _GetShapes(flat_other_states, none_shapes=False)),
      maximum_iterations=max_steps)

  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)),
        dtype=tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        dtype=tf.int32)

  # Concatenate all outputs on axis=0.
  scores = final_bs_states[3].stack()
  hyps = final_bs_states[4].stack()
  prev_hyps = final_bs_states[5].stack()
  done_hyps = tf.cast(final_bs_states[6].stack(), tf.bool)
  atten_probs = final_bs_states[7].stack()
  eos_scores = final_bs_states[8].stack()
  eos_atten_probs = final_bs_states[9].stack()
  rets = (hyps, prev_hyps, done_hyps, scores, atten_probs, eos_scores,
          eos_atten_probs, source_seq_lengths)

  # TODO(rohananil): Only send a single R1 tensor to host instead of 3 after
  # b/111131551 is resolved.
  # Canonical shapes for tensors of various ranks.
  r_shapes = [
      py_utils.GetShape(source_seq_lengths),
      py_utils.GetShape(hyps),
      py_utils.GetShape(atten_probs)
  ]
  # Reshape all tensors to [-1] to avoid cost of copy due to padding.
  rets_r1 = [tf.reshape(r, [-1]) for r in rets]

  return tuple(r_shapes) + tuple(rets_r1) + tuple(flat_final_other_states)
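# Illustrative sketch (assumed values, not from the source) of why the R1
# reshape plus canonical shapes at the end of this function is reversible:
# the caller can restore each rank-N tensor from its flat form using the
# shape that was shipped alongside it.
#
#   import tensorflow as tf
#
#   hyps = tf.zeros([7, 12], tf.int32)   # [time, b * k]
#   r_shape = tf.shape(hyps)             # canonical shape, sent as R1
#   flat = tf.reshape(hyps, [-1])        # avoids copy cost due to padding
#   restored = tf.reshape(flat, r_shape)
#   assert restored.shape == hyps.shape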
def FProp(self, theta, *args):
  """Runs p.repeat copies of self.body.FProp independently.

  Args:
    theta: Layer model parameters. The shape of each variable in theta is
      always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of
      the i-th copy of self.body.
    *args: Input arguments. The shape of each tensor in args is always
      [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
      to the i-th copy of self.body.FProp.

  Returns:
    The accumulated output_tensors. Each tensor t in the return has the shape
    [p.repeat, ....] and the tuple (t[i] for i in output_tensors) is the
    return tuple of the i-th self.body.FProp.
  """
  p = self.params
  for arg in args:
    if arg is not None:
      arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
  inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
  # Infer out_shapes from FPropMeta.
  out_shapes = self._InferOutShapes(args)

  def _CellFn(unused_theta, unused_state0, inputs):
    """Recurrent cell function wrapper of body.FProp."""
    # Sets shapes for both theta and inputs to self.body.FProp.
    for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                        list(args) + theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

    # Runs the actual body.FProp
    fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
    fprop_outputs = _ToTuple(fprop_outputs)
    assert len(fprop_outputs) == len(out_shapes)

    # Passes fprop outputs to the next layer through state.
    state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Initiate state0 with inferred output shapes.
    state0 = py_utils.NestedMap(
        outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
    # Runs body.FProp p.repeat times using Recurrent.
    acc_states, _ = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=inputs,
        cell_fn=_CellFn)

    # Retrieves fprop outputs from state1 and sets shapes.
    output_tensors = tuple(acc_states.outputs)
    for out_idx in range(len(output_tensors)):
      output_tensors[out_idx].set_shape(
          tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))

    return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
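# Hedged sketch (not the actual Recurrent-based implementation) of the
# contract documented above: every input carries a leading [p.repeat]
# dimension, slice i is fed to the i-th copy of the body, and the outputs
# are stacked back along that leading dimension. `body_fprop` below is a
# stand-in for self.body.FProp.
#
#   import tensorflow as tf
#
#   repeat = 3
#   stacked_x = tf.random.normal([repeat, 4, 8])  # [p.repeat, batch, dims]
#
#   def body_fprop(x):
#     return tf.nn.relu(x)
#
#   outputs = tf.stack([body_fprop(stacked_x[i]) for i in range(repeat)])
#   # outputs.shape == (3, 4, 8)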