def __init__(self, inputs, sequence_length, sampling_probability, time_major=False, seed=None, next_inputs_fn=None, auxiliary_inputs=None, name=None): """Initializer. Args: inputs: A (structure) of input tensors. sequence_length: An int32 vector tensor. sampling_probability: A 0D `float32` tensor: the probability of sampling from the outputs instead of reading directly from the inputs. time_major: Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. seed: The sampling seed. next_inputs_fn: (Optional) callable to apply to the RNN outputs to create the next input when sampling. If `None` (default), the RNN outputs will be used as the next inputs. auxiliary_inputs: An optional (structure of) auxiliary input tensors with a shape that matches `inputs` in all but (potentially) the final dimension. These tensors will be concatenated to the sampled output or the `inputs` when not sampling for use as the next input. name: Name scope for any created operations. Raises: ValueError: if `sampling_probability` is not a scalar or vector. """ with ops.name_scope(name, "ScheduledOutputTrainingHelper", [inputs, auxiliary_inputs, sampling_probability]): self._sampling_probability = ops.convert_to_tensor( sampling_probability, name="sampling_probability") if self._sampling_probability.get_shape().ndims not in (0, 1): raise ValueError( "sampling_probability must be either a scalar or a vector. " "saw shape: %s" % (self._sampling_probability.get_shape())) if auxiliary_inputs is None: maybe_concatenated_inputs = inputs else: inputs = ops.convert_to_tensor(inputs, name="inputs") auxiliary_inputs = ops.convert_to_tensor( auxiliary_inputs, name="auxiliary_inputs") maybe_concatenated_inputs = nest.map_structure( lambda x, y: array_ops.concat((x, y), -1), inputs, auxiliary_inputs) if not time_major: auxiliary_inputs = nest.map_structure( _transpose_batch_time, auxiliary_inputs) self._auxiliary_input_tas = ( nest.map_structure(_unstack_ta, auxiliary_inputs) if auxiliary_inputs is not None else None) self._seed = seed self._next_inputs_fn = next_inputs_fn super(ScheduledOutputTrainingHelper, self).__init__( inputs=maybe_concatenated_inputs, sequence_length=sequence_length, time_major=time_major, name=name)
def body(time, outputs_ta, state, inputs, finished, sequence_lengths): """Internal while_loop body. Args: time: scalar int32 tensor. outputs_ta: structure of TensorArray. state: (structure of) state tensors and TensorArrays. inputs: (structure of) input tensors. finished: bool tensor (keeping track of what's finished). sequence_lengths: int32 tensor (keeping track of time of finish). Returns: `(time + 1, outputs_ta, next_state, next_inputs, next_finished, next_sequence_lengths)`. """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) next_finished = math_ops.logical_or(decoder_finished, finished) if maximum_iterations is not None: next_finished = math_ops.logical_or( next_finished, time + 1 >= maximum_iterations) next_sequence_lengths = array_ops.where( math_ops.logical_and(math_ops.logical_not(finished), next_finished), array_ops.fill(array_ops.shape(sequence_lengths), time + 1), sequence_lengths) nest.assert_same_structure(state, decoder_state) nest.assert_same_structure(outputs_ta, next_outputs) nest.assert_same_structure(inputs, next_inputs) # Zero out output values past finish if impute_finished: emit = nest.map_structure( lambda out, zero: array_ops.where(finished, zero, out), next_outputs, zero_outputs) else: emit = next_outputs # Copy through states past finish def _maybe_copy_state(new, cur): # TensorArrays and scalar states get passed through. if isinstance(cur, tensor_array_ops.TensorArray): pass_through = True else: new.set_shape(cur.shape) pass_through = (new.shape.ndims == 0) return new if pass_through else array_ops.where(finished, cur, new) if impute_finished: next_state = nest.map_structure( _maybe_copy_state, decoder_state, state) else: next_state = decoder_state outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit) return (time + 1, outputs_ta, next_state, next_inputs, next_finished, next_sequence_lengths)
def body(time, outputs_ta, state, inputs, finished): """Internal while_loop body. Args: time: scalar int32 tensor. outputs_ta: structure of TensorArray. state: (structure of) state tensors and TensorArrays. inputs: (structure of) input tensors. finished: 1-D bool tensor. Returns: `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`. """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step( time, inputs, state) next_finished = math_ops.logical_or(decoder_finished, finished) nest.assert_same_structure(state, decoder_state) nest.assert_same_structure(outputs_ta, next_outputs) nest.assert_same_structure(inputs, next_inputs) # Zero out output values past finish emit = nest.map_structure( lambda out, zero: array_ops.where(finished, zero, out), next_outputs, zero_outputs) # Copy through states past finish def _maybe_copy_state(new, cur): return (new if isinstance(cur, tensor_array_ops.TensorArray) else array_ops.where(finished, cur, new)) next_state = nest.map_structure(_maybe_copy_state, decoder_state, state) outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), outputs_ta, emit) return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
def get_next(self, name=None): """Return PerIteration with `iterations x batches_per_iteration` inputs.""" data = [] for _ in range(self._batches_per_iteration): batch = [] for _ in range(self._iterations): batch.append(self._dataset_iterator.get_next(name=name)) data.append(batch) # Here is an example. Suppose each get_next returns a tuple of two tensors. # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] # # After the first `map_structure` it gets transformed to: # [(Batches(a, A), Batches(z, Z)), # (Batches(b, B), Batches(y, Y)), # (Batches(c, C), Batches(x, X))] # # After the second `map_structure` it gets transformed to a tuple of: # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) data = nest.map_structure(Batches, *data) data = nest.map_structure(PerIteration, *data) return data
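# A minimal sketch of the double transposition performed by the two map_structure
# calls above, using plain Python values instead of tensors. `Batches` and
# `PerIteration` here are hypothetical stand-in wrapper classes (nest must treat
# them as leaves, not as sequences), so this only illustrates the data layout.
import tensorflow as tf


class Batches(object):
  """Stand-in leaf wrapper: one value per batches_per_iteration index."""

  def __init__(self, *values):
    self.values = values


class PerIteration(object):
  """Stand-in leaf wrapper: one Batches object per iteration."""

  def __init__(self, *values):
    self.values = values


# 3 iterations x 2 batches_per_iteration; each get_next() returns a 2-tuple.
data = [[('a', 'z'), ('b', 'y'), ('c', 'x')],
        [('A', 'Z'), ('B', 'Y'), ('C', 'X')]]
step1 = tf.nest.map_structure(Batches, *data)
# step1 ~ [(Batches(a, A), Batches(z, Z)), (Batches(b, B), Batches(y, Y)),
#          (Batches(c, C), Batches(x, X))]
step2 = tf.nest.map_structure(PerIteration, *step1)
# step2 ~ (PerIteration(Batches(a, A), Batches(b, B), Batches(c, C)),
#          PerIteration(Batches(z, Z), Batches(y, Y), Batches(x, X)))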
def testMemoryIsFreed(self): # Note: we use `set` values for components and metadata because we need # to construct weakrefs to them. Other builtin types, such as `list` and # `tuple`, do not support weakrefs. ct1 = CT(set([1, 2]), set(['no', 'leaks'])) ct2 = CT(set([3, 4]), set(['no', 'leaks'])) ct3 = CT(set([5, 6]), set(['other', 'metadata'])) # Note: map_structure exercises flatten, pack_sequence_as, and # assert_same_structure. func = lambda x, y: x | y ct4 = nest.map_structure(func, ct1, ct2, expand_composites=True) # Check that the exception-raising path in assert_same_structure # doesn't leak any objects. with self.assertRaisesRegexp(ValueError, ".*don't have the same nested structure.*"): nest.map_structure(func, ct2, ct3, expand_composites=True) if hasattr(sys, 'exc_clear'): sys.exc_clear() # Remove any references in exception stack traces. refs = [] for ct in [ct1, ct2, ct3, ct4]: refs.append(weakref.ref(ct)) refs.append(weakref.ref(ct.components)) refs.append(weakref.ref(ct.metadata)) del ct # pylint: disable=undefined-loop-variable for ref in refs: self.assertIsNotNone(ref()) del ct1, ct2, ct3, ct4 gc.collect() for ref in refs: self.assertIsNone(ref())
def generate_synthetic_data( input_shape, input_value=0, input_dtype=None, label_shape=None, label_value=0, label_dtype=None): """Create a repeating dataset with constant values. Args: input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of the input data. input_value: Value of each input element. input_dtype: Input dtype. If None, will be inferred by the input value. label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of the label data. label_value: Value of each input element. label_dtype: Input dtype. If None, will be inferred by the target value. Returns: Dataset of tensors or tuples of tensors (if label_shape is set). """ # TODO(kathywu): Replace with SyntheticDataset once it is in contrib. element = input_element = nest.map_structure( lambda s: tf.constant(input_value, input_dtype, s), input_shape) if label_shape: label_element = nest.map_structure( lambda s: tf.constant(label_value, label_dtype, s), label_shape) element = (input_element, label_element) return tf.data.Dataset.from_tensors(element).repeat()
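# A small usage sketch of generate_synthetic_data, assuming the TF 1.x-style tf.data
# API used above; the shapes and values are made up for illustration.
import tensorflow as tf

dataset = generate_synthetic_data(
    input_shape=tf.TensorShape([32, 32, 3]), input_value=0.5,
    label_shape=tf.TensorShape([]), label_value=1, label_dtype=tf.int32)
features, labels = dataset.make_one_shot_iterator().get_next()
# features: float32 tensor of shape [32, 32, 3] filled with 0.5 (dtype inferred).
# labels: scalar int32 tensor equal to 1; the dataset repeats these pairs forever.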
def __init__(self, inputs, sequence_length, time_major=False, name=None): """Initializer. Args: inputs: A (structure of) input tensors. sequence_length: An int32 vector tensor. time_major: Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. name: Name scope for any created operations. Raises: ValueError: if `sequence_length` is not a 1D tensor. """ with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]): inputs = ops.convert_to_tensor(inputs, name="inputs") if not time_major: inputs = nest.map_structure(_transpose_batch_time, inputs) self._input_tas = nest.map_structure(_unstack_ta, inputs) self._sequence_length = ops.convert_to_tensor( sequence_length, name="sequence_length") if self._sequence_length.get_shape().ndims != 1: raise ValueError( "Expected sequence_length to be a vector, but received shape: %s" % self._sequence_length.get_shape()) self._zero_inputs = nest.map_structure( lambda inp: array_ops.zeros_like(inp[0, :]), inputs) self._batch_size = array_ops.size(sequence_length)
def step(self, time, inputs, state, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor inputs: A (structure of) input tensors state: A (structure of) state tensors and TensorArrays name: Name scope for any created operations Returns: outputs: An instance of `BeamSearchDecoderOutput` next_state: A (structure of) state tensors and TensorArrays next_inputs: The tensor that should be used as input for the next step finished: A boolean tensor telling whether the sequence is complete, for each sequence in the batch """ print('===== step (beam search) =====') decoder_state, beam_state = state # Call the original decoder decoder_output, decoder_state, _, _ = self.decoder.step(time, inputs, decoder_state) # Perform a step of beam search beam_search_output, beam_state = beam_search_step( time=time, logits=decoder_output.logits, beam_state=beam_state, beam_width=self.beam_width, vocab_size=self.vocab_size, eos_index=self.eos_index, length_penalty_weight=self.length_penalty_weight, choose_successors_fn=self.choose_successors_fn) # Shuffle everything according to beam search result decoder_state = nest.map_structure( lambda x: tf.gather(x, beam_search_output.beam_parent_ids), decoder_state) decoder_output = nest.map_structure( lambda x: tf.gather(x, beam_search_output.beam_parent_ids), decoder_output) next_state = (decoder_state, beam_state) outputs = BeamSearchDecoderOutput( logits=tf.zeros([self.beam_width, self.vocab_size]), predicted_ids=beam_search_output.predicted_ids, log_probs=beam_state.log_probs, scores=beam_search_output.scores, beam_parent_ids=beam_search_output.beam_parent_ids, original_outputs=decoder_output) finished, next_inputs, next_state = self.decoder.helper.next_inputs( time=time, outputs=decoder_output, state=next_state, sample_ids=beam_search_output.predicted_ids) next_inputs.set_shape([self.batch_size, None]) return outputs, next_state, next_inputs, finished
def _build(self, inputs, prev_state): """Connects the DeepRNN module into the graph. If this is not the first time the module has been connected to the graph, the Tensors provided as input_ and state must have the same final dimension, in order for the existing variables to be the correct size for their corresponding multiplications. The batch size may differ for each connection. Args: inputs: a nested tuple of Tensors of arbitrary dimensionality, with at least an initial batch dimension. prev_state: a tuple of `prev_state`s that corresponds to the state of each one of the cores of the `DeepCore`. Returns: output: a nested tuple of Tensors of arbitrary dimensionality, with at least an initial batch dimension. next_state: a tuple of `next_state`s that corresponds to the updated state of each one of the cores of the `DeepCore`. Raises: ValueError: if connecting the module into the graph any time after the first time, and the inferred size of the inputs does not match previous invocations. This may happen if one connects a module any time after the first time that does not have the configuration of skip connections as the first time. """ current_input = inputs next_states = [] outputs = [] recurrent_idx = 0 concatenate = lambda *args: tf.concat(args, axis=-1) for i, core in enumerate(self._cores): if self._skip_connections and i > 0: current_input = nest.map_structure(concatenate, inputs, current_input) # Determine if this core in the stack is recurrent or not and call # accordingly. if self._is_recurrent_list[i]: current_input, next_state = core(current_input, prev_state[recurrent_idx]) next_states.append(next_state) recurrent_idx += 1 else: current_input = core(current_input) if self._skip_connections: outputs.append(current_input) if self._skip_connections and self._concat_final_output_if_skip: output = nest.map_structure(concatenate, *outputs) else: output = current_input self._last_output_size = _get_shape_without_batch_dimension(output) return output, tuple(next_states)
def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): """Convert to tensor and possibly mask `memory`. Args: memory: `Tensor`, shaped `[batch_size, max_time, ...]`. memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. check_inner_dims_defined: Python boolean. If `True`, the `memory` argument's shape is checked to ensure all but the two outermost dimensions are fully defined. Returns: A (possibly masked), checked, new `memory`. Raises: ValueError: If `check_inner_dims_defined` is `True` and not `memory.shape[2:].is_fully_defined()`. """ memory = nest.map_structure( lambda m: ops.convert_to_tensor(m, name="memory"), memory) if memory_sequence_length is not None: memory_sequence_length = ops.convert_to_tensor( memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): raise ValueError("Expected memory %s to have fully defined inner dims, " "but saw shape: %s" % (m.name, m.get_shape())) nest.map_structure(_check_dims, memory) if memory_sequence_length is None: seq_len_mask = None else: seq_len_mask = array_ops.sequence_mask( memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) seq_len_batch_size = ( memory_sequence_length.shape[0].value or array_ops.shape(memory_sequence_length)[0]) def _maybe_mask(m, seq_len_mask): rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) m_batch_size = m.shape[0].value or array_ops.shape(m)[0] if memory_sequence_length is not None: message = ("memory_sequence_length and memory tensor batch sizes do not " "match.") with ops.control_dependencies([ check_ops.assert_equal( seq_len_batch_size, m_batch_size, message=message)]): seq_len_mask = array_ops.reshape( seq_len_mask, array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0)) return m * seq_len_mask else: return m return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
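# Worked sketch of the masking step in _prepare_memory with made-up sizes
# (batch_size=2, max_time=3, depth=4): the sequence mask is reshaped to
# [batch, time, 1] so it broadcasts over the remaining inner dimensions.
import tensorflow as tf

memory = tf.ones([2, 3, 4])
memory_sequence_length = tf.constant([2, 1])
seq_len_mask = tf.sequence_mask(memory_sequence_length, maxlen=3, dtype=memory.dtype)
mask = tf.reshape(seq_len_mask, tf.concat([tf.shape(seq_len_mask), [1]], 0))
masked = memory * mask  # time steps at or beyond each sequence length become zero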
def testAssertions(self): a = tracking.Checkpointable() a.l = {"k": [numpy.zeros([2, 2])]} self.assertAllEqual(nest.flatten({"k": [numpy.zeros([2, 2])]}), nest.flatten(a.l)) self.assertAllClose({"k": [numpy.zeros([2, 2])]}, a.l) nest.map_structure(self.assertAllClose, a.l, {"k": [numpy.zeros([2, 2])]}) a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]} self.assertAllClose({"k": [numpy.ones([2, 2]), numpy.zeros([3, 3])]}, self.evaluate(a.tensors))
def test_autolambda(self, model_fn): inputs, outputs = model_fn() model = keras.Model(inputs, outputs) model.compile( adam.Adam(0.001), 'mse', run_eagerly=testing_utils.should_run_eagerly()) np_inputs = nest.map_structure(lambda x: np.ones((10, 10), 'float32'), inputs) np_outputs = nest.map_structure(lambda x: np.ones((10, 10), 'float32'), outputs) model.fit(np_inputs, np_outputs, batch_size=2)
def _reset_padding(self, memory, memory_sequence_length, check_inner_dims_defined=True): """Reset the padding part for encoder inputs. This function comes from TensorFlow's `_prepare_memory` function. """ memory = nest.map_structure( lambda m: ops.convert_to_tensor(m, name="memory"), memory) if memory_sequence_length is not None: memory_sequence_length = ops.convert_to_tensor( memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): raise ValueError( "Expected memory %s to have fully defined inner dims, " "but saw shape: %s" % (m.name, m.get_shape())) nest.map_structure(_check_dims, memory) if memory_sequence_length is None: seq_len_mask = None else: seq_len_mask = array_ops.sequence_mask( memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) seq_len_batch_size = (memory_sequence_length.shape[0].value or array_ops.shape(memory_sequence_length)[0]) def _maybe_mask(m, seq_len_mask): rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) m_batch_size = m.shape[0].value or array_ops.shape(m)[0] if memory_sequence_length is not None: message = ("memory_sequence_length and memory tensor " "batch sizes do not match.") with ops.control_dependencies([ check_ops.assert_equal( seq_len_batch_size, m_batch_size, message=message) ]): seq_len_mask = array_ops.reshape( seq_len_mask, array_ops.concat( (array_ops.shape(seq_len_mask), extra_ones), 0)) return m * seq_len_mask else: return m return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
def step(self, time, inputs, state, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. name: Name scope for any created operations. Returns: `(outputs, next_state, next_inputs, finished)`. """ batch_size = self._batch_size beam_width = self._beam_width end_token = self._end_token length_penalty_weight = self._length_penalty_weight with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state inputs = nest.map_structure(self._merge_batch_beams, inputs) cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state) try: cell_outputs, next_cell_state = self._cell( inputs, cell_state, tiling_factor=beam_width) except TypeError as e: if "unexpected keyword argument 'tiling_factor'" in str(e): cell_outputs, next_cell_state = self._cell(inputs, cell_state) else: raise cell_outputs = nest.map_structure(self._split_batch_beams, cell_outputs) next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state) if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs) beam_search_output, beam_search_state = _beam_search_step( time=time, logits=cell_outputs, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids next_inputs = control_flow_ops.cond( math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids)) return (beam_search_output, beam_search_state, next_inputs, finished)
def maybe_concatenate_auxiliary_inputs(outputs_, indices=None): """Concatenate outputs with auxiliary inputs, if they exist.""" if self._auxiliary_input_tas is None: return outputs_ next_time = time + 1 auxiliary_inputs = nest.map_structure( lambda ta: ta.read(next_time), self._auxiliary_input_tas) if indices is not None: auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices) return nest.map_structure( lambda x, y: array_ops.concat((x, y), -1), outputs_, auxiliary_inputs)
def step(self, time, inputs, state, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. name: Name scope for any created operations. Returns: `(outputs, next_state, next_inputs, finished)`. """ batch_size = self._batch_size beam_width = self._beam_width end_token = self._end_token length_penalty_weight = self._length_penalty_weight coverage_penalty_weight = self._coverage_penalty_weight with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state inputs = nest.map_structure( lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size) cell_outputs, next_cell_state = self._cell(inputs, cell_state) cell_outputs = nest.map_structure( lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) next_cell_state = nest.map_structure( self._maybe_split_batch_beams, next_cell_state, self._cell.state_size) if self._output_layer is not None: cell_outputs = self._output_layer(cell_outputs) beam_search_output, beam_search_state = _beam_search_step( time=time, logits=cell_outputs, next_cell_state=next_cell_state, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight, coverage_penalty_weight=coverage_penalty_weight) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids next_inputs = control_flow_ops.cond( math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids)) return (beam_search_output, beam_search_state, next_inputs, finished)
def mark_checked(tensors): """Marks that these Tensors should not be tracked. This prevents Layers from attempting to create TensorFlowOpLayers for these Tensors. Arguments: tensors: An arbitrary structure of Tensors. """ def _mark_checked(tensor): tensor._keras_history_checked = True # pylint: disable=protected-access nest.map_structure(_mark_checked, tensors)
def _prepare_feed_values(model, inputs, targets, sample_weights, mode): """Prepare feed values to the model execution function. Arguments: model: Model to prepare feed values for. inputs: List or dict of model inputs. targets: Optional list of model targets. sample_weights: Optional list of sample weight arrays. mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT. Returns: Feed values for the model in the given mode. """ strategy = model._distribution_strategy inputs, targets, sample_weights = _get_input_from_iterator(inputs, model) inputs = flatten_perdevice_values(strategy, inputs) targets = flatten_perdevice_values(strategy, targets) # Expand 1-dimensional inputs. # TODO(b/124535720): Remove once this standardize data logic is shared with # main flow. inputs, targets = nest.map_structure(training_utils.standardize_single_array, (inputs, targets)) if mode == ModeKeys.PREDICT: sample_weights = [] targets = [] else: sample_weights = [ None for _ in range(len(model.outputs) * strategy.num_replicas_in_sync) ] ins = inputs + targets + sample_weights if mode == ModeKeys.TRAIN and not isinstance(K.symbolic_learning_phase(), int): ins += [True] return ins
def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): if self._initial_cell_state is not None: cell_state = self._initial_cell_state else: cell_state = self._cell.zero_state(batch_size, dtype) error_message = ( "When calling zero_state of AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and the requested batch size. Are you using " "the BeamSearchDecoder? If so, make sure your encoder output has " "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " "the batch_size= argument passed to zero_state is " "batch_size * beam_width.") with ops.control_dependencies( [check_ops.assert_equal(batch_size, self._attention_mechanism.batch_size, message=error_message)]): cell_state = nest.map_structure( lambda s: array_ops.identity(s, name="checked_cell_state"), cell_state) if self._alignment_history: alignment_history = tensor_array_ops.TensorArray( dtype=dtype, size=0, dynamic_size=True) else: alignment_history = () return AttentionWrapperState( cell_state=cell_state, time=array_ops.zeros([], dtype=dtypes.int32), attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), alignments=self._attention_mechanism.initial_alignments( batch_size, dtype), alignment_history=alignment_history)
def _enumerated_map_structure(map_fn, *args, **kwargs): ix = [0] def enumerated_fn(*inner_args, **inner_kwargs): r = map_fn(ix[0], *inner_args, **inner_kwargs) ix[0] += 1 return r return nest.map_structure(enumerated_fn, *args, **kwargs)
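# A tiny illustration of _enumerated_map_structure on plain Python values: the
# closure threads a running index through map_structure in flatten order (dicts
# are flattened by sorted key).
structure = {'a': 10, 'b': (20, 30)}
labeled = _enumerated_map_structure(lambda i, x: (i, x), structure)
# labeled == {'a': (0, 10), 'b': ((1, 20), (2, 30))}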
def _model_start_state_placeholders( self, batch_size_tensor, static_batch_size=None): """Creates placeholders with zeroed start state for the current model.""" gathered_state = {} # Models may not know the shape of their state without creating some # variables/ops. Avoid polluting the default graph by making a new one. We # use only static metadata from the returned Tensors. with ops.Graph().as_default(): self._model.initialize_graph() # Evaluate the initial state as same-dtype "zero" values. These zero # constants aren't used, but are necessary for feeding to # placeholder_with_default for the "cold start" case where state is not # fed to the model. def _zeros_like_constant(tensor): return tensor_util.constant_value(array_ops.zeros_like(tensor)) start_state = nest.map_structure( _zeros_like_constant, self._model.get_start_state()) for prefixed_state_name, state in ts_head_lib.state_to_dictionary( start_state).items(): state_shape_with_batch = tensor_shape.TensorShape( (static_batch_size,)).concatenate(state.shape) default_state_broadcast = array_ops.tile( state[None, ...], multiples=array_ops.concat( [batch_size_tensor[None], array_ops.ones(len(state.shape), dtype=dtypes.int32)], axis=0)) gathered_state[prefixed_state_name] = array_ops.placeholder_with_default( input=default_state_broadcast, name=prefixed_state_name, shape=state_shape_with_batch) return gathered_state
def _RawNestedShape(nested_shape): def _RawShape(shape): if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: return [x.value for x in shape] else: return None return nest.map_structure(_RawShape, nested_shape)
def _gather_beams(nested, beam_indices, batch_size, new_beam_size): """Gather beams from nested structure of tensors. Each tensor in nested represents a batch of beams, where beam refers to a single search state (beam search involves searching through multiple states in parallel). This function is used to gather the top beams, specified by beam_indices, from the nested tensors. Args: nested: Nested structure (tensor, list, tuple or dict) containing tensors with shape [batch_size, beam_size, ...]. beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each value in beam_indices must be between [0, beam_size), and are not necessarily unique. batch_size: int size of batch new_beam_size: int number of beams to be pulled from the nested tensors. Returns: Nested structure containing tensors with shape [batch_size, new_beam_size, ...] """ # Computes the i'th coodinate that contains the batch index for gather_nd. # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size]) # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor # with shape [batch_size, beam_size, 2], where the last dimension contains # the (i, j) gathering coordinates. coordinates = tf.stack([batch_pos, beam_indices], axis=2) return nest.map_structure( lambda state: tf.gather_nd(state, coordinates), nested)
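# Worked example of the gather coordinates with hypothetical sizes: batch_size=2,
# beam_size=4, new_beam_size=3. batch_pos becomes [[0, 0, 0], [1, 1, 1]], so each
# (batch, beam) coordinate pair picks the requested beam within its own batch entry.
import tensorflow as tf

batch_size, new_beam_size = 2, 3
batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
beam_indices = tf.constant([[2, 0, 1], [1, 1, 0]])
coordinates = tf.stack([batch_pos, beam_indices], axis=2)
# coordinates[0] == [[0, 2], [0, 0], [0, 1]]: beams 2, 0, 1 of batch entry 0.
nested = {'ids': tf.reshape(tf.range(2 * 4 * 5), [2, 4, 5])}
gathered = tf.nest.map_structure(lambda s: tf.gather_nd(s, coordinates), nested)
# gathered['ids'] has shape [batch_size, new_beam_size, 5] == [2, 3, 5].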
def output_size(self): # Return the cell output and the id prepend_beam_width = ( lambda s: tensor_shape.TensorShape([self._beam_width]).concatenate(s)) return BeamSearchDecoderOutput( rnn_output=nest.map_structure( prepend_beam_width, self._rnn_output_size()))
def _update_non_slot(self, colocate_with, fn, args, kwargs, group): del colocate_with result = fn(*args, **kwargs) if group: return result else: return nest.map_structure(self._unwrap, result)
def body_infer(time, inputs, caches, outputs_tas, finished, log_probs, lengths, bs_stat_ta, predicted_ids): """Internal while_loop body. Args: time: Scalar int32 Tensor. inputs: A list of inputs Tensors. caches: A dict of decoder states. outputs_tas: A list of TensorArrays. finished: A bool tensor (keeping track of what's finished). log_probs: The log probability Tensor. lengths: The decoding length Tensor. bs_stat_ta: structure of TensorArray. predicted_ids: A Tensor. Returns: `(time + 1, next_inputs, next_caches, next_outputs_tas, next_finished, next_log_probs, next_lengths, next_infer_status_ta)`. """ # step decoder def _decoding(_decoder, _input, _cache, _decoder_output_remover, _outputs_ta, _outputs_to_logits_fn): with tf.variable_scope(_decoder.name): _output, _next_cache = _decoder.step(_input, _cache) _decoder_top_features = _decoder.merge_top_features(_output) _ta = nest.map_structure(lambda _ta_ms, _output_ms: _ta_ms.write(time, _output_ms), _outputs_ta, _decoder_output_remover.apply(_output)) _logit = _outputs_to_logits_fn(_decoder_top_features) return _output, _next_cache, _ta, _logit outputs, next_caches, next_outputs_tas, logits = repeat_n_times( num_models, _decoding, decoders, inputs, caches, decoder_output_removers, outputs_tas, outputs_to_logits_fns) # sample next symbols sample_ids, beam_ids, next_log_probs, next_lengths \ = helper.sample_symbols(logits, log_probs, finished, lengths, time=time) for c in next_caches: c["decoding_states"] = gather_states(c["decoding_states"], beam_ids) infer_status = BeamSearchStateSpec( log_probs=next_log_probs, beam_ids=beam_ids) bs_stat_ta = nest.map_structure(lambda ta, out: ta.write(time, out), bs_stat_ta, infer_status) predicted_ids = gather_states(tf.reshape(predicted_ids, [-1, time + 1]), beam_ids) next_predicted_ids = tf.concat([predicted_ids, tf.expand_dims(sample_ids, axis=1)], axis=1) next_predicted_ids = tf.reshape(next_predicted_ids, [-1]) next_predicted_ids.set_shape([None]) next_finished, next_input_symbols = helper.next_symbols(time=time, sample_ids=sample_ids) next_inputs = repeat_n_times(num_models, target_to_embedding_fns, next_input_symbols, time + 1) next_finished = tf.logical_or(next_finished, finished) return time + 1, next_inputs, next_caches, next_outputs_tas, \ next_finished, next_log_probs, next_lengths, bs_stat_ta, \ next_predicted_ids
def tile_batch(t, multiplier, name=None): """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. For each tensor t in a (possibly nested structure) of tensors, this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated `multiplier` times. Args: t: `Tensor` shaped `[batch_size, ...]`. multiplier: Python int. name: Name scope for any created operations. Returns: A (possibly nested structure of) `Tensor` shaped `[batch_size * multiplier, ...]`. Raises: ValueError: if tensor(s) `t` do not have a statically known rank or the rank is < 1. """ flat_t = nest.flatten(t) with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
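# Minimal usage sketch of tile_batch before beam search, with made-up sizes: a
# [batch_size=2, depth=4] encoder output tiled by beam_width=3.
import tensorflow as tf

memory = tf.reshape(tf.range(2 * 4, dtype=tf.float32), [2, 4])
tiled = tile_batch(memory, multiplier=3)
# tiled has shape [6, 4] with rows ordered m[0], m[0], m[0], m[1], m[1], m[1],
# matching the batch_size * multiplier layout expected by BeamSearchDecoder.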
def finalize(self, outputs, final_state, sequence_lengths): """Finalize and return the predicted_ids. Args: outputs: An instance of BeamSearchDecoderOutput. final_state: An instance of BeamSearchDecoderState. Passed through to the output. sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. The sequence lengths determined for each beam during decode. **NOTE** These are ignored; the updated sequence lengths are stored in `final_state.lengths`. Returns: outputs: An instance of `FinalBeamSearchDecoderOutput` where the predicted_ids are the result of calling _gather_tree. final_state: The same input instance of `BeamSearchDecoderState`. """ del sequence_lengths # Get max_sequence_length across all beams for each batch. max_sequence_lengths = math_ops.to_int32( math_ops.reduce_max(final_state.lengths, axis=1)) predicted_ids = beam_search_ops.gather_tree( outputs.predicted_ids, outputs.parent_ids, max_sequence_lengths=max_sequence_lengths, end_token=self._end_token) if self._reorder_tensor_arrays: final_state = final_state._replace(cell_state=nest.map_structure( lambda t: self._maybe_sort_array_beams( t, outputs.parent_ids, final_state.lengths), final_state.cell_state)) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state
def compute_output_shape(self, input_shape): if self._output_shape is None: # Make use of existing autocomputation but provide Lambda-specific # error message. This is always safe to run even when the outer context # is Graph mode because Lambda layers don't have side effects such as # `add_loss`. with context.eager_mode(): try: return super(Lambda, self).compute_output_shape(input_shape) except NotImplementedError: raise NotImplementedError( 'We could not automatically infer the shape of the Lambda\'s ' 'output. Please specify `output_shape` for this Lambda.') if callable(self._output_shape): output_shapes = self._output_shape(input_shape) return tf_utils.convert_shapes(output_shapes, to_tuples=False) # Output shapes are passed directly and don't include batch dimension. input_tensor_shape = tf_utils.convert_shapes(input_shape, to_tuples=False) batch_size = nest.flatten(input_tensor_shape)[0][0] if input_shape else None def _add_batch(shape): return tensor_shape.TensorShape([batch_size] + shape.as_list()) output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) return nest.map_structure(_add_batch, output_shapes)
def output_dtype(self): # Assume the dtype of the cell is the output_size structure # containing the input_state's first component's dtype. # Return that structure and int32 (the id) dtype = nest.flatten(self._initial_cell_state)[0].dtype return BeamSearchDecoderOutput( rnn_output=nest.map_structure(lambda _: dtype, self._rnn_output_size()))
def __init__(self, inputs, sequence_length, sampling_probability, time_major=False, seed=None, next_input_layer=None, auxiliary_inputs=None, name=None): """Initializer. Args: inputs: A (structure) of input tensors. sequence_length: An int32 vector tensor. sampling_probability: A 0D `float32` tensor: the probability of sampling from the outputs instead of reading directly from the inputs. time_major: Python bool. Whether the tensors in `inputs` are time major. If `False` (default), they are assumed to be batch major. seed: The sampling seed. next_input_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer to apply to the RNN output to create the next input. auxiliary_inputs: An optional (structure of) auxiliary input tensors with a shape that matches `inputs` in all but (potentially) the final dimension. These tensors will be concatenated to the sampled output or the `inputs` when not sampling for use as the next input. name: Name scope for any created operations. Raises: ValueError: if `sampling_probability` is not a scalar or vector. """ with ops.name_scope(name, "ScheduledOutputTrainingHelper", [inputs, auxiliary_inputs, sampling_probability]): self._sampling_probability = ops.convert_to_tensor( sampling_probability, name="sampling_probability") if self._sampling_probability.get_shape().ndims not in (0, 1): raise ValueError( "sampling_probability must be either a scalar or a vector. " "saw shape: %s" % (self._sampling_probability.get_shape())) if auxiliary_inputs is None: maybe_concatenated_inputs = inputs else: inputs = ops.convert_to_tensor(inputs, name="inputs") auxiliary_inputs = ops.convert_to_tensor( auxiliary_inputs, name="auxiliary_inputs") maybe_concatenated_inputs = nest.map_structure( lambda x, y: array_ops.concat((x, y), -1), inputs, auxiliary_inputs) if not time_major: auxiliary_inputs = nest.map_structure( _transpose_batch_time, auxiliary_inputs) self._auxiliary_input_tas = ( nest.map_structure(_unstack_ta, auxiliary_inputs) if auxiliary_inputs is not None else None) self._seed = seed if (next_input_layer is not None and not isinstance(next_input_layer, layers_base.Layer)): raise TypeError("next_input_layer must be a Layer, received: %s" % type(next_input_layer)) self._next_input_layer = next_input_layer super(ScheduledOutputTrainingHelper, self).__init__( inputs=maybe_concatenated_inputs, sequence_length=sequence_length, time_major=time_major, name=name)
def _graph_mode_decorator(f, args, kwargs): """Implement custom gradient decorator for graph mode.""" # TODO(rsepassi): Add support for kwargs if kwargs: raise ValueError( "The custom_gradient decorator currently supports keywords " "arguments only when eager execution is enabled.") name = generate_name() args = nest.map_structure(ops.convert_to_tensor, args) # Checking global and local variables attempts to ensure that no non-resource # Variables are added to the graph. current_var_scope = variable_scope.get_variable_scope() before_vars = set([ v.ref() for v in current_var_scope.global_variables() + current_var_scope.local_variables() ]) with tape_lib.VariableWatcher() as variable_watcher: result, grad_fn = f(*args) args = nest.flatten(args) flat_result = nest.flatten(result) flat_result_len = len(flat_result) after_vars = set([ v.ref() for v in current_var_scope.global_variables() + current_var_scope.local_variables() ]) new_vars = after_vars - before_vars new_vars_list = [v.deref() for v in new_vars] for v in new_vars_list: if not resource_variable_ops.is_resource_variable(v): raise TypeError( "All variables used by a function wrapped with @custom_gradient must " "be `ResourceVariable`s. Ensure that no `variable_scope` is created " "with `use_resource=False`.") # The variables that grad_fn needs to return gradients for are the set of # variables used that are *not* part of the inputs. variables_in_tape = frozenset( [v.ref() for v in variable_watcher.watched_variables()]) graphs = {getattr(o, "graph", None) for o in flat_result} # Not all results may be tensors. However, we want to ensure all tensor # outputs are from the same graph and get a list of captured inputs for # variable search graphs.discard(None) # Discard non-graph outputs if graphs: if len(graphs) > 1: raise ValueError( "All custom_gradient outputs should be from the same graph") output_graph = graphs.pop() filtered_input_tensors = [] for i in args: if i.graph == output_graph: filtered_input_tensors.append(i) else: filtered_input_tensors = args variables_in_subgraph = frozenset([ v.ref() for v in _get_dependent_variables(input_ops=filtered_input_tensors, output_ops=flat_result) ]) variables = sorted( [v.deref() for v in variables_in_subgraph.union(variables_in_tape)], key=lambda v: v.name) grad_argspec = tf_inspect.getfullargspec(grad_fn) variables_in_signature = ("variables" in grad_argspec.args or "variables" in grad_argspec.kwonlyargs or grad_argspec.varkw) if variables and not variables_in_signature: raise TypeError( "@tf.custom_gradient grad_fn must accept keyword argument 'variables', " "since function uses variables: {}".format(variables)) if variables_in_signature and not variables: # User seems to intend to use variables but none were captured. logging.warn( "@custom_gradient grad_fn has 'variables' in signature, but " "no ResourceVariables were used on the forward pass.") all_tensors = flat_result + args + variables def tape_grad_fn(*result_grads): """Custom grad fn wrapper.""" result_grads = result_grads[:flat_result_len] if variables: input_grads, variable_grads = grad_fn(*result_grads, variables=variables) if len(variable_grads) != len(variables): raise ValueError("Must return gradient for each variable from " "@custom_gradient grad_fn.") else: input_grads = grad_fn(*result_grads) variable_grads = [] # Need to return one value per input to the IdentityN, so pad the # gradients of the inputs of the custom_gradient function with the # gradients of the outputs as well. 
input_grads = nest.flatten(input_grads) return ([None] * flat_result_len) + input_grads + variable_grads @ops.RegisterGradient(name) def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable """Custom grad fn wrapper.""" return tape_grad_fn(*result_grads) original_tensors = all_tensors with ops.get_default_graph().gradient_override_map({"IdentityN": name}): all_tensors = array_ops.identity_n(all_tensors) original_tensors = [ops.convert_to_tensor(x) for x in original_tensors] # Propagate handle data for happier shape inference for resource variables. for i, t in enumerate(original_tensors): if t.dtype == dtypes.resource and hasattr(t, "_handle_data"): all_tensors[i]._handle_data = t._handle_data # pylint: disable=protected-access tape_lib.record_operation(f.__name__, all_tensors, original_tensors, tape_grad_fn) for ot, t in zip(original_tensors, all_tensors): copy_handle_data(ot, t) return nest.pack_sequence_as(structure=result, flat_sequence=all_tensors[:flat_result_len])
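# A short usage sketch of the decorator that _graph_mode_decorator implements: the
# standard tf.custom_gradient pattern, here with an arbitrary gradient clip purely
# for illustration.
import tensorflow as tf


@tf.custom_gradient
def clipped_square(x):
  y = tf.square(x)

  def grad(dy):
    # Gradient of x**2 is 2*x; clip it to [-1, 1] before passing it upstream.
    return tf.clip_by_value(dy * 2.0 * x, -1.0, 1.0)

  return y, grad


x = tf.constant(3.0)
y = clipped_square(x)  # in graph mode, this call routes through _graph_mode_decorator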
def default_residual_fn(inputs, outputs): nest.assert_same_structure(inputs, outputs) nest.map_structure(assert_shape_match, inputs, outputs) return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
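# Small sketch of default_residual_fn on a nested structure (plain constants; the
# assert_shape_match helper referenced above is assumed to verify matching shapes).
import tensorflow as tf

inputs = (tf.ones([2, 3]), {'h': tf.zeros([2, 5])})
outputs = (tf.fill([2, 3], 2.0), {'h': tf.ones([2, 5])})
combined = default_residual_fn(inputs, outputs)
# combined[0] is filled with 3.0 and combined[1]['h'] with 1.0; the structure of
# `inputs` is preserved element-wise.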
def _model_loss(model, inputs, targets, output_loss_metrics=None, sample_weights=None, training=False): """Calculates the loss for a given model. Arguments: model: The model on which metrics are being calculated. inputs: Either a dictionary of inputs to the model or a list of input arrays. targets: List of target arrays. output_loss_metrics: List of metrics that are used to aggregated output loss values. sample_weights: Optional list of sample weight arrays. training: Whether the model should be run in inference or training mode. Returns: Returns the model output, total loss, loss value calculated using the specified loss function and masks for each output. The total loss includes regularization losses and applies masking and sample weighting to the loss value. """ # TODO(psv): Dedup code here with graph mode prepare_total_loss() fn. # Used to keep track of the total loss value (stateless). # eg., total_loss = loss_weight_1 * output_1_loss_fn(...) + # loss_weight_2 * output_2_loss_fn(...) + # layer losses. total_loss = 0 kwargs = {} if model._expects_training_arg: kwargs['training'] = training if len(inputs) == 1 and not isinstance(inputs, dict): inputs = inputs[0] # Allow mixed `NumPy` and `EagerTensor` input here. if any( isinstance(input_t, (np.ndarray, float, int)) for input_t in nest.flatten(inputs)): inputs = nest.map_structure(ops.convert_to_tensor_v2, inputs) outs = model(inputs, **kwargs) outs = nest.flatten(outs) if targets: targets = training_utils.cast_if_floating_dtype_and_mismatch( targets, outs) # TODO(sallymatson/psv): check if we should do same mismatch fix for weights if sample_weights: sample_weights = [ training_utils.cast_if_floating_dtype( ops.convert_to_tensor_v2(val)) if val is not None else None for val in sample_weights ] masks = [getattr(t, '_keras_mask', None) for t in outs] targets = nest.flatten(targets) # Used to keep track of individual output losses. output_losses = [] with backend.name_scope('loss'): loss_fns = [ loss_fn for loss_fn in model.loss_functions if loss_fn is not None ] custom_losses = model.losses # Regularization losses if not loss_fns and not custom_losses: if training: raise ValueError('The model cannot be trained ' 'because it has no loss to optimize.') else: raise ValueError('The model cannot be evaluated ' 'because it has no loss to compute.') for i, loss_fn in enumerate(loss_fns): weights = sample_weights[i] if sample_weights else None mask = masks[i] with backend.name_scope(model.output_names[i] + '_loss'): if mask is not None: mask = math_ops.cast(mask, outs[i].dtype) # Update weights with mask. if weights is None: weights = mask else: # Update dimensions of weights to match with mask if possible. weights = math_ops.cast(weights, outs[i].dtype) mask, _, weights = ( tf_losses_utils.squeeze_or_expand_dimensions( mask, sample_weight=weights)) weights *= mask if hasattr(loss_fn, 'reduction'): per_sample_losses = loss_fn.call(targets[i], outs[i]) weighted_losses = losses_utils.compute_weighted_loss( per_sample_losses, sample_weight=weights, reduction=losses_utils.ReductionV2.NONE) loss_reduction = loss_fn.reduction # `AUTO` loss reduction defaults to `SUM_OVER_BATCH_SIZE` for all # compile use cases. if loss_reduction == losses_utils.ReductionV2.AUTO: loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE # Compute the stateless loss value. output_loss = losses_utils.reduce_weighted_loss( weighted_losses, reduction=loss_reduction) else: # Compute the stateless loss value for a custom loss class. 
# Here we assume that the class takes care of loss reduction # because if this class returns a vector value we cannot # differentiate between use case where a custom optimizer # expects a vector loss value vs unreduced per-sample loss value. output_loss = loss_fn(targets[i], outs[i], sample_weight=weights) loss_reduction = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE # If the number of outputs is 1 then we don't append the loss metric # associated with each model output. When there are multiple outputs # associated with a model, each output's loss is calculated and returned # as part of the loss_metrics. if len(model.outputs) > 1: # Keep track of the stateful output loss result. output_losses.append(output_loss_metrics[i](output_loss)) # Scale output loss for distribution. For custom losses we assume # reduction was mean. if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE: output_loss = losses_utils.scale_loss_for_distribution( output_loss) total_loss += model._loss_weights_list[i] * output_loss # Add regularization losses if custom_losses: total_loss += losses_utils.scale_loss_for_distribution( math_ops.add_n(custom_losses)) return outs, total_loss, output_losses, masks
def step(self, time, inputs, state, name=None): """Perform a decoding step. Args: time: scalar `int32` tensor. inputs: A (structure of) input tensors. state: A (structure of) state tensors and TensorArrays. name: Name scope for any created operations. Returns: `(outputs, next_state, next_inputs, finished)`. """ batch_size = self._batch_size beam_width = self._beam_width end_token = self._end_token length_penalty_weight = self._length_penalty_weight with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): cell_state = state.cell_state inputs = nest.map_structure( lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, self._cell.state_size) cell_outputs, next_cell_state = self._cell( inputs, cell_state) # cell_outputs and next_cell_state for the current time step # ====================================================================== gen_log_probs = tf.nn.log_softmax( self._output_layer(cell_outputs)) # log-prob distribution over generic words emo_log_probs = tf.nn.log_softmax( self._emo_output_layer(cell_outputs)) # log-prob distribution over emotion words alphas = tf.nn.sigmoid( self._emo_choice_layer(cell_outputs)) # probability of choosing an emotion word gen_log_probs = gen_log_probs + tf.log( 1 - alphas) # rescale the generic-word distribution by the probability of not choosing an emotion word emo_log_probs = emo_log_probs + tf.log( alphas) # rescale the emotion-word distribution by the probability of choosing an emotion word raw_log_probs = tf.concat([gen_log_probs, emo_log_probs], axis=-1) # concatenation yields a distribution over 2 * vocab_size entries raw_log_probs = self._split_batch_beams( raw_log_probs, raw_log_probs.shape[1:] ) # [batch_size*beam_width, s] -> [batch_size, beam_width, s] # ====================================================================== next_cell_state = nest.map_structure(self._maybe_split_batch_beams, next_cell_state, self._cell.state_size) real_vocab_size = gen_log_probs.shape[-1].value # the real vocabulary size beam_search_output, beam_search_state = _beam_search_step( real_vocab_size=real_vocab_size, # pass in the real vocabulary size time=time, logits=raw_log_probs, # the logits here are raw_log_probs next_cell_state=next_cell_state, beam_state=state, batch_size=batch_size, beam_width=beam_width, end_token=end_token, length_penalty_weight=length_penalty_weight) finished = beam_search_state.finished sample_ids = beam_search_output.predicted_ids next_inputs = control_flow_ops.cond( math_ops.reduce_all(finished), lambda: self._start_inputs, lambda: self._embedding_fn(sample_ids)) return (beam_search_output, beam_search_state, next_inputs, finished)
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=True, scope=None): """Creates a recurrent neural network specified by RNNCell `cell`. Performs fully dynamic unrolling of `inputs`. Example: ```python # create a BasicRNNCell rnn_cell = tf.compat.v1.nn.rnn_cell.BasicRNNCell(hidden_size) # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] # defining initial state initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) # 'state' is a tensor of shape [batch_size, cell_state_size] outputs, state = tf.compat.v1.nn.dynamic_rnn(rnn_cell, input_data, initial_state=initial_state, dtype=tf.float32) ``` ```python # create 2 LSTMCells rnn_layers = [tf.compat.v1.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] # create a RNN cell composed sequentially of a number of RNNCells multi_rnn_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(rnn_layers) # 'outputs' is a tensor of shape [batch_size, max_time, 256] # 'state' is a N-tuple where N is the number of LSTMCells containing a # tf.nn.rnn_cell.LSTMStateTuple for each cell outputs, state = tf.compat.v1.nn.dynamic_rnn(cell=multi_rnn_cell, inputs=data, dtype=tf.float32) ``` Args: cell: An instance of RNNCell. inputs: The RNN inputs. If `time_major == False` (default), this must be a `Tensor` of shape: `[batch_size, max_time, ...]`, or a nested tuple of such elements. If `time_major == True`, this must be a `Tensor` of shape: `[max_time, batch_size, ...]`, or a nested tuple of such elements. This may also be a (possibly nested) tuple of Tensors satisfying this property. The first two dimensions must match across all the inputs, but otherwise the ranks and other shape components may differ. In this case, input to `cell` at each time-step will replicate the structure of these tuples, except for the time dimension (from which the time is taken). The input to `cell` at each time step will be a `Tensor` or (possibly nested) tuple of Tensors each with dimensions `[batch_size, ...]`. sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used to copy-through state and zero-out outputs when past a batch element's sequence length. So it's more for performance than correctness. initial_state: (optional) An initial state for the RNN. If `cell.state_size` is an integer, this must be a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this should be a tuple of tensors having shapes `[batch_size, s] for s in cell.state_size`. dtype: (optional) The data type for the initial state and expected output. Required if initial_state is not provided or RNN state has a heterogeneous dtype. parallel_iterations: (Default: 32). The number of iterations to run in parallel. Those operations which do not have any temporal dependency and can be run in parallel, will be. This parameter trades off time for space. Values >> 1 use more memory but take less time, while smaller values use less memory but computations take longer. swap_memory: Transparently swap the tensors produced in forward inference but needed for back prop from GPU to CPU. This allows training RNNs which would typically not fit on a single GPU, with very minimal (or no) performance penalty. time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 
Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. scope: VariableScope for the created subgraph; defaults to "rnn". Returns: A pair (outputs, state) where: outputs: The RNN output `Tensor`. If time_major == False (default), this will be a `Tensor` shaped: `[batch_size, max_time, cell.output_size]`. If time_major == True, this will be a `Tensor` shaped: `[max_time, batch_size, cell.output_size]`. Note, if `cell.output_size` is a (possibly nested) tuple of integers or `TensorShape` objects, then `outputs` will be a tuple having the same structure as `cell.output_size`, containing Tensors having shapes corresponding to the shape data in `cell.output_size`. state: The final state. If `cell.state_size` is an int, this will be shaped `[batch_size, cell.state_size]`. If it is a `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. If it is a (possibly nested) tuple of ints or `TensorShape`, this will be a tuple having the corresponding shapes. If cells are `LSTMCells` `state` will be a tuple containing a `LSTMStateTuple` for each cell. Raises: TypeError: If `cell` is not an instance of RNNCell. ValueError: If inputs is None or an empty list. RuntimeError: If not using control flow v2. """ # Currently only support time_major == True case. assert time_major # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or # TfLiteRNNCells. rnn_cell_impl.assert_like_rnncell("cell", cell) if not control_flow_util.ENABLE_CONTROL_FLOW_V2: raise RuntimeError("OpHint dynamic rnn only supports control flow v2.") parent_first_child_input = [{ "parent_ophint_input_index": 0, "first_child_ophint_input_index": 0 }] parent_last_child_output = [{ "parent_output_index": 0, # For LstmCell, the index is 2. # For RnnCell, the index is 1. # So we use -1 meaning it's the last one. "child_output_index": -1 }] internal_children_input_output = [{ "child_input_index": 0, # For LstmCell, the index is 2. # For RnnCell, the index is 1. # So we use -1 meaning it's the last one. "child_output_index": -1 }] inputs_outputs_mappings = { "parent_first_child_input": parent_first_child_input, "parent_last_child_output": parent_last_child_output, "internal_children_input_output": internal_children_input_output } tflite_wrapper = op_hint.OpHint( "TfLiteDynamicRnn", level=2, children_inputs_mappings=inputs_outputs_mappings) with vs.variable_scope(scope or "rnn") as varscope: # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. 
if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0) # By default, time_major==False and inputs are batch-major: shaped # [batch, time, depth] # For internal calculations, we transpose to [time, batch, depth] flat_input = nest.flatten(inputs) if not time_major: # (batch, time, depth) => (time, batch, depth) flat_input = [ ops.convert_to_tensor(input_) for input_ in flat_input ] flat_input = tuple( _transpose_batch_time(input_) for input_ in flat_input) parallel_iterations = parallel_iterations or 32 if sequence_length is not None: sequence_length = math_ops.cast(sequence_length, dtypes.int32) if sequence_length.shape.rank not in (None, 1): raise ValueError( "sequence_length must be a vector of length batch_size, " "but saw shape: %s" % sequence_length.shape) sequence_length = array_ops.identity( # Just to find it in the graph. sequence_length, name="sequence_length") batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: state = initial_state else: if not dtype: raise ValueError( "If there is no initial_state, you must give a dtype.") if getattr(cell, "get_initial_state", None) is not None: state = cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=dtype) else: state = cell.zero_state(batch_size, dtype) def _assert_has_shape(x, shape): x_shape = array_ops.shape(x) packed_shape = array_ops.stack(shape) return control_flow_ops.Assert( math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ "Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape ]) if not context.executing_eagerly() and sequence_length is not None: # Perform some shape validation with ops.control_dependencies( [_assert_has_shape(sequence_length, [batch_size])]): sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen") inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) outputs, final_state = _dynamic_rnn_loop( cell, inputs, state, parallel_iterations=parallel_iterations, swap_memory=swap_memory, sequence_length=sequence_length, dtype=dtype) # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. # If we are performing batch-major calculations, transpose output back # to shape [batch, time, depth] if not time_major: # (time, batch, depth) => (batch, time, depth) outputs = nest.map_structure(_transpose_batch_time, outputs) outputs = tflite_wrapper.add_output(outputs, name="outputs") return outputs, final_state
def from_value(cls, value): return cls( nest.map_structure(type_spec.type_spec_from_value, value.nest))
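For illustration, a hedged sketch of the same pattern using the public `tf.type_spec_from_value` endpoint; the structure and values below are made up:

```python
import tensorflow as tf

value = {"ids": tf.constant([[1, 2, 3]]),
         "mask": tf.constant([[True, True, False]])}
# Mirror the nested value with a structure of TypeSpecs describing each leaf.
specs = tf.nest.map_structure(tf.type_spec_from_value, value)
print(specs["ids"])  # TensorSpec(shape=(1, 3), dtype=tf.int32, name=None)
```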
def __init__(self, cell, attention_mechanism, is_manual_attention, manual_alignments, attention_layer_size=None, alignment_history=False, cell_input_fn=None, output_attention=True, initial_cell_state=None, name=None): """Construct the `AttentionWrapper`. Args: cell: An instance of `RNNCell`. attention_mechanism: A list of `AttentionMechanism` instances or a single instance. attention_layer_size: A list of Python integers or a single Python integer, the depth of the attention (output) layer(s). If None (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). cell_input_fn: (optional) A `callable`. The default is: `lambda inputs, attention: tf.concat([inputs, attention], -1)`. output_attention: Python bool. If `True` (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If `False`, the output at each time step is the output of `cell`. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the `attention` tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output. initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`. Note that if this value is provided now, and the user uses a `batch_size` argument of `zero_state` which does not match the batch size of `initial_cell_state`, proper behavior is not guaranteed. name: Name to use when creating ops. Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` is a list but `attention_layer_size` is not; or vice versa). ValueError: if `attention_layer_size` is not None, `attention_mechanism` is a list, and its length does not match that of `attention_layer_size`.
""" super(AttentionWrapper, self).__init__(name=name) self.is_manual_attention = is_manual_attention self.manual_alignments = manual_alignments if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True attention_mechanisms = attention_mechanism for attention_mechanism in attention_mechanisms: if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must contain only instances of " "AttentionMechanism, saw type: %s" % type(attention_mechanism).__name__) else: self._is_multi = False if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must be an AttentionMechanism or list of " "multiple AttentionMechanism instances, saw type: %s" % type(attention_mechanism).__name__) attention_mechanisms = (attention_mechanism, ) if cell_input_fn is None: cell_input_fn = ( lambda inputs, attention: tf.concat([inputs, attention], -1)) else: if not callable(cell_input_fn): raise TypeError( "cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__) if attention_layer_size is not None: attention_layer_sizes = tuple(attention_layer_size if isinstance( attention_layer_size, (list, tuple)) else (attention_layer_size, )) if len(attention_layer_sizes) != len(attention_mechanisms): raise ValueError( "If provided, attention_layer_size must contain exactly one " "integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms))) self._attention_layers = tuple( tf.layers.Dense(attention_layer_size, name="attention_layer", use_bias=False) for attention_layer_size in attention_layer_sizes) self._attention_layer_size = sum(attention_layer_sizes) else: self._attention_layers = None self._attention_layer_size = sum( attention_mechanism.values.get_shape()[-1].value for attention_mechanism in attention_mechanisms) self._cell = cell self._attention_mechanisms = attention_mechanisms self._cell_input_fn = cell_input_fn self._output_attention = output_attention self._alignment_history = alignment_history with tf.name_scope(name, "AttentionWrapperInit"): if initial_cell_state is None: self._initial_cell_state = None else: final_state_tensor = nest.flatten(initial_cell_state)[-1] state_batch_size = (final_state_tensor.shape[0].value or tf.shape(final_state_tensor)[0]) error_message = ( "When constructing AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and initial_cell_state. Are you using " "the BeamSearchDecoder? You may need to tile your initial state " "via the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with tf.control_dependencies( self._batch_size_checks(state_batch_size, error_message)): self._initial_cell_state = nest.map_structure( lambda s: tf.identity(s, name="check_initial_cell_state"), initial_cell_state)
def tpu_fn(): activations = mid_level.dequeue() mid_level.apply_gradients( nest.map_structure(array_ops.ones_like, activations)) return activations
def while_loop(cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, maximum_iterations=None, name=None, return_same_structure=True): """Like tf.while_loop, except emits a single While op.""" # Keep the original loop_vars around to know which args were TensorArrays. orig_loop_vars = loop_vars # Cache its length since we use it at multiple places below. len_orig_loop_vars = len(orig_loop_vars) # Convert TensorArrays to their flow variables. These get converted back to # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and # `wrapped_body` below. loop_vars = list(_tensor_array_to_flow(orig_loop_vars)) loop_vars = nest.map_structure( ops.internal_convert_to_tensor_or_indexed_slices, loop_vars) if shape_invariants is not None: nest.assert_same_structure(orig_loop_vars, shape_invariants) else: shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars) if not name: name = "while" with ops.name_scope(name) as scope: with ops.name_scope(None): cond_name = util.unique_fn_name(scope, "cond") body_name = util.unique_fn_name(scope, "body") maximum_iterations_loop_var = _build_maximum_iterations_loop_var( maximum_iterations) loop_counter = constant_op.constant( 0, dtype=maximum_iterations_loop_var.dtype if maximum_iterations is not None else None, name="loop_counter") # Add loop counter needed for computing gradients. loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars shape_invariants = type(shape_invariants)( [tensor_shape.scalar(), tensor_shape.scalar()]) + shape_invariants # Automatic control dependencies are added in defuns, but not in v1 # graphs. Propagate that behavior here. add_control_dependencies = ops.get_default_graph( )._add_control_dependencies # Build a `cond` wrapper that can handle the extra counter loop_var. def wrapped_cond(loop_counter, maximum_iterations_arg, *args): # Convert the flow variables in `args` to TensorArrays. `args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. if maximum_iterations is None: return cond(*_pack_sequence_as(orig_loop_vars, args)) else: return math_ops.logical_and( loop_counter < maximum_iterations_arg, cond(*_pack_sequence_as(orig_loop_vars, args))) # NOTE(skyewm): we set collections to the outer graph's collections for # compatibility with TPUEstimator. cond_graph = func_graph_module.func_graph_from_py_func( cond_name, wrapped_cond, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileCondFuncGraph( cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) def wrapped_body(loop_counter, maximum_iterations_arg, *args): """Loop body augmented with counter update. Args: loop_counter: Loop counter which needs to be incremented in the body. maximum_iterations_arg: Maximum iterations of the loop. *args: List of args Returns: A list of tensors the same length as args. """ # Capture the tensors already captured in cond_graph so that they appear # in the same order in body_graph.external_captures. for t in cond_graph.external_captures: ops.get_default_graph().capture(t) # Convert the flow variables in `args` to TensorArrays. 
`args` should # already have the same structure as `orig_loop_vars` but currently there # is no nest.zip so we call `_pack_sequence_as` which flattens both # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays # and packs it into the structure of `orig_loop_vars`. outputs = body(*_pack_sequence_as(orig_loop_vars, args)) if not nest.is_sequence(outputs): outputs = [outputs] # Compare the structure of input and output of body converting the # top-level tuples to list to be compatible with legacy while_loop. nest.assert_same_structure(list(outputs), list(orig_loop_vars)) outputs = _tensor_array_to_flow(outputs) # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. return [loop_counter + 1, maximum_iterations_arg] + list(outputs) body_graph = func_graph_module.func_graph_from_py_func( body_name, wrapped_body, [], # We provide signature instead of args. {}, signature=_build_signature(loop_vars, shape_invariants), func_graph=util.WhileBodyFuncGraph( body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access add_control_dependencies=add_control_dependencies) # Add external captures of body to the list of loop vars. # Note that external tensors will be treated as loop invariants, i.e., # the value of that tensor in each iteration is the same as it was at the # beginning of the loop execution. loop_vars = loop_vars + body_graph.external_captures # TODO(srbs): Update lowering code to create _Enter nodes with # is_constant=True for inputs that are directly passed to outputs. body_graph.outputs.extend(body_graph.internal_captures) # Capture the extra `external_captures` of `body_graph` in `cond_graph` so # that it expects to receive those as arguments. with cond_graph.as_default(): num_cond_captures = len(cond_graph.external_captures) assert (cond_graph.external_captures == body_graph.external_captures[:num_cond_captures]) for body_capture in body_graph.external_captures[ num_cond_captures:]: assert body_capture not in cond_graph.captures cond_graph.capture(body_capture) # Make sure that the shapes of the loop outputs are compatible with the # shape invariants, or the shapes of the loop vars if the invariants are not # specified. num_flattened_outputs = len(nest.flatten(orig_loop_vars)) # First var is loop counter and second var is maximum_iterations. first_loop_var_index = 2 _check_shapes_compat( body_graph.outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs], nest.flatten( shape_invariants[first_loop_var_index:first_loop_var_index + len_orig_loop_vars]), nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index + len_orig_loop_vars])) flattened_loop_vars = nest.flatten(loop_vars) _check_num_inputs_outputs(cond_graph, body_graph, len(flattened_loop_vars)) with ops.control_dependencies( list(cond_graph.control_captures) + list(body_graph.control_captures)): outputs = gen_functional_ops._while( flattened_loop_vars, util.create_new_tf_function(cond_graph), util.create_new_tf_function(body_graph), output_shapes=[t.shape for t in body_graph.outputs], parallel_iterations=parallel_iterations, name=scope) _copy_handle_data(body_graph.outputs, outputs) util.maybe_set_lowering_attr(outputs[0].op) util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op) # Return identities for each output of the While op, rather than the output # of the While op directly. 
This makes pruning work if the output of # while_loop() is fetched: the lowering pass converts the While outputs into # IdentityN outputs, which if fetched will cause all ops in the body to be # run (since it takes all exit ops as input). After lowering, each output # identity op will end up with only the appropriate exit op as input. outputs = tuple(array_ops.identity(t) for t in outputs) outputs = _pack_sequence_as( orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index + num_flattened_outputs]) if return_same_structure: return outputs flattened_outputs = nest.flatten(outputs) if len(flattened_outputs) == 1: return flattened_outputs[0] else: return outputs
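A small usage sketch of the public `tf.while_loop` endpoint that this implementation backs, showing how `maximum_iterations` caps the loop independently of `cond` (values are illustrative):

```python
import tensorflow as tf

def cond(i, acc):
    return i < 100          # would run 100 steps on its own

def body(i, acc):
    return i + 1, acc + tf.cast(i, tf.float32)

# maximum_iterations stops the loop after 10 steps regardless of `cond`.
i, acc = tf.while_loop(cond, body,
                       loop_vars=[tf.constant(0), tf.constant(0.0)],
                       maximum_iterations=10)
print(int(i), float(acc))   # 10 45.0  (sum of 0..9)
```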
def _testWithMaybeMultiAttention(self, is_multi, create_attention_mechanisms, expected_final_output, expected_final_state, attention_mechanism_depths, alignment_history=False, expected_final_alignment_history=None, attention_layer_sizes=None, attention_layers=None, create_query_layer=False, create_memory_layer=True, create_attention_kwargs=None): # Allow is_multi to be True with a single mechanism to enable test for # passing in a single mechanism in a list. assert len(create_attention_mechanisms) == 1 or is_multi encoder_sequence_length = [3, 2, 3, 1, 1] decoder_sequence_length = [2, 0, 1, 2, 3] batch_size = 5 encoder_max_time = 8 decoder_max_time = 4 input_depth = 7 encoder_output_depth = 10 cell_depth = 9 create_attention_kwargs = create_attention_kwargs or {} if attention_layer_sizes is not None: # Compute sum of attention_layer_sizes. Use encoder_output_depth if None. attention_depth = sum(attention_layer_size or encoder_output_depth for attention_layer_size in attention_layer_sizes) elif attention_layers is not None: # Compute sum of attention_layers output depth. attention_depth = sum( attention_layer.compute_output_shape( [batch_size, cell_depth + encoder_output_depth]).dims[-1].value for attention_layer in attention_layers) else: attention_depth = encoder_output_depth * len(create_attention_mechanisms) decoder_inputs = np.random.randn(batch_size, decoder_max_time, input_depth).astype(np.float32) encoder_outputs = np.random.randn(batch_size, encoder_max_time, encoder_output_depth).astype(np.float32) attention_mechanisms = [] for creator, depth in zip(create_attention_mechanisms, attention_mechanism_depths): # Create a memory layer with deterministic initializer to avoid randomness # in the test between graph and eager. if create_query_layer: create_attention_kwargs["query_layer"] = keras.layers.Dense( depth, kernel_initializer="ones", use_bias=False) if create_memory_layer: create_attention_kwargs["memory_layer"] = keras.layers.Dense( depth, kernel_initializer="ones", use_bias=False) attention_mechanisms.append( creator( units=depth, memory=encoder_outputs, memory_sequence_length=encoder_sequence_length, **create_attention_kwargs)) with self.cached_session(use_gpu=True): attention_layer_size = attention_layer_sizes attention_layer = attention_layers if not is_multi: if attention_layer_size is not None: attention_layer_size = attention_layer_size[0] if attention_layer is not None: attention_layer = attention_layer[0] cell = keras.layers.LSTMCell(cell_depth, recurrent_activation="sigmoid", kernel_initializer="ones", recurrent_initializer="ones") cell = wrapper.AttentionWrapper( cell, attention_mechanisms if is_multi else attention_mechanisms[0], attention_layer_size=attention_layer_size, alignment_history=alignment_history, attention_layer=attention_layer) if cell._attention_layers is not None: for layer in cell._attention_layers: if getattr(layer, "kernel_initializer") is None: layer.kernel_initializer = initializers.glorot_uniform(seed=1337) sampler = sampler_py.TrainingSampler() my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) initial_state = cell.get_initial_state( dtype=dtypes.float32, batch_size=batch_size) final_outputs, final_state, _ = my_decoder( decoder_inputs, initial_state=initial_state, sequence_length=decoder_sequence_length) self.assertIsInstance(final_outputs, basic_decoder.BasicDecoderOutput) self.assertIsInstance(final_state, wrapper.AttentionWrapperState) expected_time = ( expected_final_state.time if context.executing_eagerly() else None) 
self.assertEqual((batch_size, expected_time, attention_depth), tuple(final_outputs.rnn_output.get_shape().as_list())) self.assertEqual((batch_size, expected_time), tuple(final_outputs.sample_id.get_shape().as_list())) self.assertEqual((batch_size, attention_depth), tuple(final_state.attention.get_shape().as_list())) self.assertEqual((batch_size, cell_depth), tuple(final_state.cell_state[0].get_shape().as_list())) self.assertEqual((batch_size, cell_depth), tuple(final_state.cell_state[1].get_shape().as_list())) if alignment_history: if is_multi: state_alignment_history = [] for history_array in final_state.alignment_history: history = history_array.stack() self.assertEqual((expected_time, batch_size, encoder_max_time), tuple(history.get_shape().as_list())) state_alignment_history.append(history) state_alignment_history = tuple(state_alignment_history) else: state_alignment_history = final_state.alignment_history.stack() self.assertEqual((expected_time, batch_size, encoder_max_time), tuple(state_alignment_history.get_shape().as_list())) nest.assert_same_structure(cell.state_size, cell.zero_state(batch_size, dtypes.float32)) # Remove the history from final_state for purposes of the # remainder of the tests. final_state = final_state._replace(alignment_history=()) # pylint: disable=protected-access else: state_alignment_history = () self.evaluate(variables.global_variables_initializer()) eval_result = self.evaluate({ "final_outputs": final_outputs, "final_state": final_state, "state_alignment_history": state_alignment_history, }) final_output_info = nest.map_structure(get_result_summary, eval_result["final_outputs"]) final_state_info = nest.map_structure(get_result_summary, eval_result["final_state"]) print("final_output_info: ", final_output_info) print("final_state_info: ", final_state_info) nest.map_structure(self.assertAllCloseOrEqual, expected_final_output, final_output_info) nest.map_structure(self.assertAllCloseOrEqual, expected_final_state, final_state_info) if alignment_history: # by default, the wrapper emits attention as output final_alignment_history_info = nest.map_structure( get_result_summary, eval_result["state_alignment_history"]) print("final_alignment_history_info: ", final_alignment_history_info) nest.map_structure( self.assertAllCloseOrEqual, # outputs are batch major but the stacked TensorArray is time major expected_final_alignment_history, final_alignment_history_info)
def output_dtype(self): dtype = nest.flatten(self._initial_state)[0].dtype return CustomDecoderOutput( nest.map_structure(lambda _: dtype, self._rnn_output_size()), tf.float32, self._helper.sample_ids_dtype)
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, beam_width, end_token, length_penalty_weight): """Performs a single step of Beam Search Decoding. Args: time: Beam search time step, should start at 0. At time 0 we assume that all beams are equal and consider only the first beam for continuations. logits: Logits at the current time step. A tensor of shape `[batch_size, beam_width, vocab_size]` next_cell_state: The next state from the cell, e.g. an instance of AttentionWrapperState if the cell is attentional. beam_state: Current state of the beam search. An instance of `BeamSearchDecoderState`. batch_size: The batch size for this input. beam_width: Python int. The size of the beams. end_token: The int32 end token. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. Returns: A new beam state. """ static_batch_size = tensor_util.constant_value(batch_size) # Calculate the current lengths of the predictions prediction_lengths = beam_state.lengths previously_finished = beam_state.finished # Calculate the total log probs for the new hypotheses # Final Shape: [batch_size, beam_width, vocab_size] step_log_probs = nn_ops.log_softmax(logits) step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished) total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs # Calculate the continuation lengths by adding to all continuing beams. vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] lengths_to_add = array_ops.one_hot( indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) # Calculate the scores for each beam scores = _get_scores( log_probs=total_probs, sequence_lengths=new_prediction_lengths, length_penalty_weight=length_penalty_weight) time = ops.convert_to_tensor(time, name="time") # During the first time step we only consider the initial beam scores_flat = array_ops.reshape(scores, [batch_size, -1]) # Pick the next beams according to the specified successors function next_beam_size = ops.convert_to_tensor( beam_width, dtype=dtypes.int32, name="beam_width") next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) next_beam_scores.set_shape([static_batch_size, beam_width]) word_indices.set_shape([static_batch_size, beam_width]) # Pick out the probs, beam_ids, and states according to the chosen predictions next_beam_probs = _tensor_gather_helper( gather_indices=word_indices, gather_from=total_probs, batch_size=batch_size, range_size=beam_width * vocab_size, gather_shape=[-1], name="next_beam_probs") # Note: just doing the following # math_ops.to_int32(word_indices % vocab_size, # name="next_beam_word_ids") # would be a lot cleaner but for reasons unclear, that hides the results of # the op which prevents capturing it with tfdbg debug ops. 
raw_next_word_ids = math_ops.mod( word_indices, vocab_size, name="next_beam_word_ids") next_word_ids = math_ops.to_int32(raw_next_word_ids) next_beam_ids = math_ops.to_int32( word_indices / vocab_size, name="next_beam_parent_ids") # Append new ids to current predictions previously_finished = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=previously_finished, batch_size=batch_size, range_size=beam_width, gather_shape=[-1]) next_finished = math_ops.logical_or( previously_finished, math_ops.equal(next_word_ids, end_token), name="next_beam_finished") # Calculate the length of the next predictions. # 1. Finished beams remain unchanged. # 2. Beams that are now finished (EOS predicted) have their length # increased by 1. # 3. Beams that are not yet finished have their length increased by 1. lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, batch_size=batch_size, range_size=beam_width, gather_shape=[-1]) next_prediction_len += lengths_to_add # Pick out the cell_states according to the next_beam_ids. We use a # different gather_shape here because the cell_state tensors, i.e. # the tensors that would be gathered from, all have dimension # greater than two and we need to preserve those dimensions. # pylint: disable=g-long-lambda next_cell_state = nest.map_structure( lambda gather_from: _maybe_tensor_gather_helper( gather_indices=next_beam_ids, gather_from=gather_from, batch_size=batch_size, range_size=beam_width, gather_shape=[batch_size * beam_width, -1]), next_cell_state) # pylint: enable=g-long-lambda next_state = BeamSearchDecoderState( cell_state=next_cell_state, log_probs=next_beam_probs, lengths=next_prediction_len, finished=next_finished) output = BeamSearchDecoderOutput( scores=next_beam_scores, predicted_ids=next_word_ids, parent_ids=next_beam_ids) return output, next_state
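The index arithmetic above is the core of the step: after top-k over the flattened `[batch, beam_width * vocab_size]` scores, `index % vocab_size` recovers the chosen token and `index // vocab_size` recovers the parent beam. A self-contained sketch with made-up sizes:

```python
import tensorflow as tf

batch_size, beam_width, vocab_size = 2, 3, 5
scores = tf.random.uniform([batch_size, beam_width, vocab_size])

# Flatten beams and vocabulary into one axis and pick the best beam_width ids.
scores_flat = tf.reshape(scores, [batch_size, -1])
next_scores, indices = tf.math.top_k(scores_flat, k=beam_width)

next_word_ids = indices % vocab_size    # which token each new beam emits
next_beam_ids = indices // vocab_size   # which old beam each new beam extends
print(next_word_ids.shape, next_beam_ids.shape)  # (2, 3) (2, 3)
```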
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, initial_state_fw=None, initial_state_bw=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None): """Creates a dynamic version of bidirectional recurrent neural network. Takes input and builds independent forward and backward RNNs. The input_size of forward and backward cell must match. The initial state for both directions is zero by default (but can be set optionally) and no intermediate states are ever returned -- the network is fully unrolled for the given (passed in) length(s) of the sequence(s) or completely unrolled if length(s) is not given. Args: cell_fw: An instance of RNNCell, to be used for forward direction. cell_bw: An instance of RNNCell, to be used for backward direction. inputs: The RNN inputs. If time_major == False (default), this must be a tensor of shape: `[batch_size, max_time, ...]`, or a nested tuple of such elements. If time_major == True, this must be a tensor of shape: `[max_time, batch_size, ...]`, or a nested tuple of such elements. sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, containing the actual lengths for each of the sequences in the batch. If not provided, all batch entries are assumed to be full sequences; and time reversal is applied from time `0` to `max_time` for each sequence. initial_state_fw: (optional) An initial state for the forward RNN. This must be a tensor of appropriate type and shape `[batch_size, cell_fw.state_size]`. If `cell_fw.state_size` is a tuple, this should be a tuple of tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. initial_state_bw: (optional) Same as for `initial_state_fw`, but using the corresponding properties of `cell_bw`. dtype: (optional) The data type for the initial states and expected output. Required if initial_states are not provided or RNN states have a heterogeneous dtype. parallel_iterations: (Default: 32). The number of iterations to run in parallel. Those operations which do not have any temporal dependency and can be run in parallel, will be. This parameter trades off time for space. Values >> 1 use more memory but take less time, while smaller values use less memory but computations take longer. swap_memory: Transparently swap the tensors produced in forward inference but needed for back prop from GPU to CPU. This allows training RNNs which would typically not fit on a single GPU, with very minimal (or no) performance penalty. time_major: The shape format of the `inputs` and `outputs` Tensors. If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. scope: VariableScope for the created subgraph; defaults to "bidirectional_rnn" Returns: A tuple (outputs, output_states) where: outputs: A tuple (output_fw, output_bw) containing the forward and the backward rnn output `Tensor`. If time_major == False (default), output_fw will be a `Tensor` shaped: `[batch_size, max_time, cell_fw.output_size]` and output_bw will be a `Tensor` shaped: `[batch_size, max_time, cell_bw.output_size]`. 
If time_major == True, output_fw will be a `Tensor` shaped: `[max_time, batch_size, cell_fw.output_size]` and output_bw will be a `Tensor` shaped: `[max_time, batch_size, cell_bw.output_size]`. It returns a tuple instead of a single concatenated `Tensor`, unlike in the `bidirectional_rnn`. If the concatenated one is preferred, the forward and backward outputs can be concatenated as `tf.concat(outputs, 2)`. output_states: A tuple (output_state_fw, output_state_bw) containing the forward and the backward final states of bidirectional rnn. Raises: TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. """ rnn_cell_impl.assert_like_rnncell("cell_fw", cell_fw) rnn_cell_impl.assert_like_rnncell("cell_bw", cell_bw) with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction with vs.variable_scope("fw") as fw_scope: output_fw, output_state_fw = dynamic_rnn( cell=cell_fw, inputs=inputs, sequence_length=sequence_length, initial_state=initial_state_fw, dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, time_major=time_major, scope=fw_scope) # Backward direction if not time_major: time_axis = 1 batch_axis = 0 else: time_axis = 0 batch_axis = 1 def _reverse(input_, seq_lengths, seq_axis, batch_axis): if seq_lengths is not None: return array_ops.reverse_sequence(input=input_, seq_lengths=seq_lengths, seq_axis=seq_axis, batch_axis=batch_axis) else: return array_ops.reverse(input_, axis=[seq_axis]) with vs.variable_scope("bw") as bw_scope: def _map_reverse(inp): return _reverse(inp, seq_lengths=sequence_length, seq_axis=time_axis, batch_axis=batch_axis) inputs_reverse = nest.map_structure(_map_reverse, inputs) tmp, output_state_bw = dynamic_rnn( cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, initial_state=initial_state_bw, dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, time_major=time_major, scope=bw_scope) output_bw = _reverse(tmp, seq_lengths=sequence_length, seq_axis=time_axis, batch_axis=batch_axis) outputs = (output_fw, output_bw) output_states = (output_state_fw, output_state_bw) return (outputs, output_states)
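A hedged usage sketch of the API documented above, written against the `tf.compat.v1` endpoints; graph mode is required for these legacy RNN calls, and the shapes are illustrative:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

inputs = tf.placeholder(tf.float32, [None, 20, 16])   # [batch, time, depth]
seq_len = tf.placeholder(tf.int32, [None])
cell_fw = tf.nn.rnn_cell.LSTMCell(32)
cell_bw = tf.nn.rnn_cell.LSTMCell(32)

(out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, sequence_length=seq_len, dtype=tf.float32)

# If a single concatenated output is preferred, join on the depth axis.
outputs = tf.concat([out_fw, out_bw], 2)              # [batch, time, 64]
```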
def select_device(device, structured): """Specialize a nest of regular & per-device values for one device.""" def _get(x): return x.get(device) if isinstance(x, DistributedValues) else x return nest.map_structure(_get, structured)
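The same specialization pattern can be sketched without the internal `DistributedValues` class; the `PerDevice` stand-in below is purely illustrative:

```python
import tensorflow as tf

class PerDevice(object):
    """Stand-in for a per-device value container (not a nest structure)."""
    def __init__(self, values):
        self._values = values
    def get(self, device):
        return self._values[device]

def select_device(device, structured):
    # PerDevice objects are leaves to tf.nest, so map_structure specializes them.
    _get = lambda x: x.get(device) if isinstance(x, PerDevice) else x
    return tf.nest.map_structure(_get, structured)

structured = {"grads": PerDevice({"/gpu:0": tf.ones([2]), "/gpu:1": tf.zeros([2])}),
              "step": 7}
print(select_device("/gpu:1", structured))  # {'grads': <zeros tensor>, 'step': 7}
```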
def __init__(self, cell, embedding, start_tokens, end_token, initial_state, beam_width, output_layer=None, length_penalty_weight=0.0, reorder_tensor_arrays=True): """Initialize the BeamSearchDecoder. Args: cell: An `RNNCell` instance. embedding: A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`. start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. end_token: `int32` scalar, the token that marks end of decoding. initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell state will be reordered according to the beam search path. If the `TensorArray` can be reordered, the stacked form will be returned. Otherwise, the `TensorArray` will be returned as is. Set this flag to `False` if the cell state contains `TensorArray`s that are not amenable to reordering. Raises: TypeError: if `cell` is not an instance of `RNNCell`, or `output_layer` is not an instance of `tf.layers.Layer`. ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ rnn_cell_impl.assert_like_rnncell("cell", cell) # pylint: disable=protected-access if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): raise TypeError( "output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer self._reorder_tensor_arrays = reorder_tensor_arrays if callable(embedding): self._embedding_fn = embedding else: self._embedding_fn = ( lambda ids: embedding_ops.embedding_lookup(embedding, ids)) self._start_tokens = ops.convert_to_tensor( start_tokens, dtype=dtypes.int32, name="start_tokens") if self._start_tokens.get_shape().ndims != 1: raise ValueError("start_tokens must be a vector") self._end_token = ops.convert_to_tensor( end_token, dtype=dtypes.int32, name="end_token") if self._end_token.get_shape().ndims != 0: raise ValueError("end_token must be a scalar") self._batch_size = array_ops.size(start_tokens) self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._initial_cell_state = nest.map_structure( self._maybe_split_batch_beams, initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) self._finished = array_ops.one_hot( array_ops.zeros([self._batch_size], dtype=dtypes.int32), depth=self._beam_width, on_value=False, off_value=True, dtype=dtypes.bool)
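The start-token handling at the end mirrors this small sketch: a `[batch_size]` vector of start ids is tiled to `[batch_size, beam_width]` so every beam begins from the same token (ids and sizes below are made up):

```python
import tensorflow as tf

batch_size, beam_width = 4, 3
start_tokens = tf.fill([batch_size], 1)   # e.g. a <sos> id of 1
tiled = tf.tile(tf.expand_dims(start_tokens, 1), [1, beam_width])
print(tiled.shape)  # (4, 3): one start token per (batch, beam) slot
```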
def _fast_decode(self, features, decode_length, beam_size=1, top_beams=1, alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff beam_size > 1, otherwise beam search related arguments are ignored. Args: features: a map of string to model features. decode_length: an integer. How many additional timesteps to decode. beam_size: number of beams. top_beams: an integer. How many of the beams to return. alpha: Float that controls the length penalty. The larger the alpha, the stronger the preference for longer translations. Returns: samples: an integer `Tensor`. Top samples from the beam search Raises: NotImplementedError: If there are multiple data shards. """ if self._num_datashards != 1: raise NotImplementedError( "Fast decoding only supports a single shard.") dp = self._data_parallelism hparams = self._hparams encoder_name = "omni" if hparams.share_ed else "encoder" decoder_name = "omni" if hparams.share_ed else "decoder" inputs = features["inputs"] batch_size = tf.shape(inputs)[0] target_modality = self._problem_hparams.target_modality if t2t_model.is_class_modality(target_modality): decode_length = 1 else: decode_length = tf.shape(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) if len(inputs.shape) < 5: inputs = tf.expand_dims(inputs, axis=4) s = tf.shape(inputs) inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) # _shard_features called to ensure that the variable names match inputs = self._shard_features({"inputs": inputs})["inputs"] input_modality = self._problem_hparams.input_modality["inputs"] with tf.variable_scope(input_modality.name): inputs = input_modality.bottom_sharded(inputs, dp) with tf.variable_scope("body"): encoder_output, encoder_decoder_attention_bias = dp( self.omni_encoder, inputs, features["target_space_id"], hparams, encoder_name) encoder_output = encoder_output[0][0] encoder_decoder_attention_bias = encoder_decoder_attention_bias[0][0] if hparams.pos == "timing": timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. This includes: - Embedding the ids. - Flattening to 3D tensor. - Optionally adding timing signals. Args: targets: input ids to the decoder. [batch_size, 1] i: scalar, Step number of the decoding loop. Returns: Processed targets [batch_size, 1, hidden_dim] """ # _shard_features called to ensure that the variable names match targets = self._shard_features({"targets": targets})["targets"] with tf.variable_scope(target_modality.name): targets = target_modality.targets_bottom_sharded(targets, dp)[0] targets = common_layers.flatten4d3d(targets) # TODO(llion): Explain! Is this even needed?
targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(decode_length)) if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( decode_length) def symbols_to_logits_fn(ids, i, cache): """Go from ids to logits for next symbol.""" ids = ids[:, -1:] targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): body_outputs = dp(self.decode, targets, cache["encoder_output"], cache["encoder_decoder_attention_bias"], bias, hparams, cache, decoder_name, True if hparams.share_ed else None, None) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] return tf.squeeze(logits, axis=[1, 2, 3]), cache key_channels = hparams.attention_key_channels or hparams.hidden_size value_channels = hparams.attention_value_channels or hparams.hidden_size num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers cache = { "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), } for layer in range(num_layers) } # Set 2nd dim to None since it's not invariant in the tf.while_loop # Note: Tensor.set_shape() does not work here since it merges shape info. # TODO(llion); Find a more robust solution. # pylint: disable=protected-access for layer in cache: cache[layer]["k"]._shape = tf.TensorShape( [None, None, key_channels]) cache[layer]["v"]._shape = tf.TensorShape( [None, None, value_channels]) # pylint: enable=protected-access cache["encoder_output"] = encoder_output cache[ "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias if beam_size > 1: # Beam Search target_modality = ( self._hparams.problems[self._problem_idx].target_modality) vocab_size = target_modality.top_dimensionality initial_ids = tf.zeros([batch_size], dtype=tf.int32) decoded_ids, scores = beam_search.beam_search( symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha, states=cache, stop_early=(top_beams == 1)) if top_beams == 1: decoded_ids = decoded_ids[:, 0, 1:] else: decoded_ids = decoded_ids[:, :top_beams, 1:] else: # Greedy def inner_loop(i, next_id, decoded_ids, cache): logits, cache = symbols_to_logits_fn(next_id, i, cache) temperature = (0.0 if hparams.sampling_method == "argmax" else hparams.sampling_temp) next_id = tf.expand_dims(common_layers.sample_with_temperature( logits, temperature), axis=1) decoded_ids = tf.concat([decoded_ids, next_id], axis=1) return i + 1, next_id, decoded_ids, cache decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) scores = None next_id = tf.zeros([batch_size, 1], dtype=tf.int64) _, _, decoded_ids, _ = tf.while_loop( # TODO(llion): Early stopping. lambda i, *_: tf.less(i, decode_length), inner_loop, [tf.constant(0), next_id, decoded_ids, cache], shape_invariants=[ tf.TensorShape([]), tf.TensorShape([None, None]), tf.TensorShape([None, None]), nest.map_structure(lambda t: tf.TensorShape(t.shape), cache), ]) return decoded_ids, scores
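The cache handling above needs a shape invariant whose time dimension is unconstrained, since keys and values grow by one position per step. A minimal sketch of that trick with the public `tf.while_loop` and `tf.nest.map_structure` (the layer name, batch size, and depth are made up):

```python
import tensorflow as tf

cache = {"layer_0": {"k": tf.zeros([2, 0, 8]), "v": tf.zeros([2, 0, 8])}}

def body(i, cache):
    new = tf.ones([2, 1, 8])  # this step's key/value slice
    cache = tf.nest.map_structure(lambda t: tf.concat([t, new], axis=1), cache)
    return i + 1, cache

_, final_cache = tf.while_loop(
    lambda i, _: i < 4, body, [tf.constant(0), cache],
    shape_invariants=[
        tf.TensorShape([]),
        # Leave the time axis unconstrained so the cache may grow each step.
        tf.nest.map_structure(lambda t: tf.TensorShape([2, None, 8]), cache),
    ])
print(final_cache["layer_0"]["k"].shape)  # (2, 4, 8)
```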
def _verify_tf_cond_vars(body_outputs, orelse_outputs, basic_symbol_names, composite_symbol_names): """Verifies variables manipulated by a conditional for consistency.""" # The whole point of _verify_tf_cond_vars is to give a more useful error message # than tf-level exception by including variable names. If it's not available, # there is no point in performing this verification here. As of 2019-07-31, # conditional expression does not pass the names. if basic_symbol_names is None: return output_symbol_names = basic_symbol_names + composite_symbol_names basic_body_outputs, composite_body_outputs = body_outputs basic_orelse_outputs, composite_orelse_outputs = orelse_outputs assert isinstance(composite_body_outputs, tuple) assert isinstance(composite_orelse_outputs, tuple) # TODO(kkimlabs): Make this more consistent. # The basic outputs should always be a tuple. if not isinstance(basic_body_outputs, tuple): basic_body_outputs = (basic_body_outputs, ) if not isinstance(basic_orelse_outputs, tuple): basic_orelse_outputs = (basic_orelse_outputs, ) body_outputs = basic_body_outputs + composite_body_outputs orelse_outputs = basic_orelse_outputs + composite_orelse_outputs for body_output, orelse_output, name in zip(body_outputs, orelse_outputs, output_symbol_names): try: nest.assert_same_structure(body_output, orelse_output, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError( '"{}" does not have the same nested structure in the TRUE and FALSE' ' branches.\n\n{}'.format(name, str(e))) def _check_same_type(name, body_output_var, orelse_output_var): """Verifies that body_output_var and orelse_output_var have the same dtype.""" if isinstance(body_output_var, (bool, int, float, str)): body_output_var = ops.convert_to_tensor_v2(body_output_var) if isinstance(orelse_output_var, (bool, int, float, str)): orelse_output_var = ops.convert_to_tensor_v2(orelse_output_var) if (not tensor_util.is_tensor(body_output_var) or not tensor_util.is_tensor(orelse_output_var)): return # TODO(mdan): Properly account for CompositeTensors. if (not hasattr(body_output_var, 'dtype') or not hasattr(orelse_output_var, 'dtype')): return if body_output_var.dtype != orelse_output_var.dtype: raise TypeError( '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE' ' branch. TensorFlow control flow requires that they are the' ' same.'.format(name, body_output_var.dtype.name, orelse_output_var.dtype.name)) nest.map_structure(functools.partial(_check_same_type, name), body_output, orelse_output)
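The per-symbol structural check relies on `tf.nest.assert_same_structure`, which raises when the two branches disagree; a tiny illustration with made-up branch values:

```python
import tensorflow as tf

true_branch = {"logits": tf.zeros([3]), "state": (tf.zeros([2]),)}
false_branch = {"logits": tf.zeros([3]), "state": tf.zeros([2])}  # not a tuple

try:
    tf.nest.assert_same_structure(true_branch, false_branch,
                                  expand_composites=True)
except (ValueError, TypeError) as e:
    print("structure mismatch:", e)
```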
def testNestMapStructure(self): k = object_identity._ObjectIdentityWrapper('k') v1 = object_identity._ObjectIdentityWrapper('v1') v2 = object_identity._ObjectIdentityWrapper('v2') struct = nest.map_structure(lambda a, b: (a, b), {k: v1}, {k: v2}) self.assertEqual(struct, {k: (v1, v2)})
def __init__(self, cell, embedding, start_tokens, end_token, initial_state, beam_width, output_layer=None, emo_output_layer=None, emo_choice_layer=None, length_penalty_weight=0.0): """Initialize the ECMBeamSearchDecoder. Args: cell: An `RNNCell` instance. embedding: A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`. start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. end_token: `int32` scalar, the token that marks end of decoding. initial_state: A (possibly nested tuple of...) tensors and TensorArrays. beam_width: Python integer, the number of beams. output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., `tf.layers.Dense`. Optional layer to apply to the RNN output prior to storing the result or sampling. length_penalty_weight: Float weight to penalize length. Disabled with 0.0. Raises: TypeError: if `cell` is not an instance of `RNNCell`, or `output_layer` is not an instance of `tf.layers.Layer`. ValueError: If `start_tokens` is not a vector or `end_token` is not a scalar. """ if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): raise TypeError("output_layer must be a Layer, received: %s" % type(output_layer)) self._cell = cell self._output_layer = output_layer # projection over the regular vocabulary # ECM output layer self._emo_output_layer = emo_output_layer # projection over the emotion vocabulary self._emo_choice_layer = emo_choice_layer # projection for the probability of choosing an emotion word; outputs a probability in (0, 1) if callable(embedding): self._embedding_fn = embedding else: self._embedding_fn = ( lambda ids: embedding_ops.embedding_lookup(embedding, ids)) self._start_tokens = ops.convert_to_tensor(start_tokens, dtype=dtypes.int32, name="start_tokens") if self._start_tokens.get_shape().ndims != 1: raise ValueError("start_tokens must be a vector") self._end_token = ops.convert_to_tensor(end_token, dtype=dtypes.int32, name="end_token") if self._end_token.get_shape().ndims != 0: raise ValueError("end_token must be a scalar") self._batch_size = array_ops.size(start_tokens) self._beam_width = beam_width self._length_penalty_weight = length_penalty_weight self._initial_cell_state = nest.map_structure( self._maybe_split_batch_beams, initial_state, self._cell.state_size) self._start_tokens = array_ops.tile( array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) self._start_inputs = self._embedding_fn(self._start_tokens) self._finished = array_ops.zeros([self._batch_size, self._beam_width], dtype=dtypes.bool)
def __init__(self, cell, attention_mechanism, attention_layer_size=None, alignment_history=False, cell_input_fn=None, attention_input_fn=None, output_attention=True, initial_cell_state=None, name=None): """Construct the `AttentionWrapper`. Args: cell: An instance of `RNNCell`. attention_mechanism: An instance of `AttentionMechanism`. attention_layer_size: Python integer, the depth of the attention (output) layer. If None (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major `TensorArray` on which you must call `stack()`). cell_input_fn: (optional) A `callable`. The default is: `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`. output_attention: Python bool. If `True` (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If `False`, the output at each time step is the output of `cell`. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the `attention` tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output. initial_cell_state: The initial state value to use for the cell when the user calls `zero_state()`. Note that if this value is provided now, and the user uses a `batch_size` argument of `zero_state` which does not match the batch size of `initial_cell_state`, proper behavior is not guaranteed. name: Name to use when creating ops.
""" super(AttentionWrapper, self).__init__(name=name) if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("cell must be an RNNCell, saw type: %s" % type(cell).__name__) if not isinstance(attention_mechanism, AttentionMechanism): raise TypeError( "attention_mechanism must be a AttentionMechanism, saw type: %s" % type(attention_mechanism).__name__) # -- what gets inputed to the core RNN cell we're wrapping around if cell_input_fn is None: cell_input_fn = (lambda inputs, attention: array_ops.concat( [inputs, attention], -1)) else: if not callable(cell_input_fn): raise TypeError( "cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__) ########### ADDED TO ALLOW DIFFERENT INPUTS TO ATTENTION MECHANISM ############# # what the attention unit gets as the query if attention_input_fn is None: attention_input_fn = (lambda _, state: state) else: if not callable(attention_input_fn): raise TypeError( "attention_input_fn must be callable, saw type: %s" % type(attention_input_fn).__name__) ############## DONE #################################################### if attention_layer_size is not None: self._attention_layer = layers_core.Dense(attention_layer_size, name="attention_layer", use_bias=False) self._attention_size = attention_layer_size else: self._attention_layer = None self._attention_size = attention_mechanism.values.get_shape( )[-1].value self._cell = cell self._attention_mechanism = attention_mechanism self._cell_input_fn = cell_input_fn self._attention_input_fn = attention_input_fn self._output_attention = output_attention self._alignment_history = alignment_history with ops.name_scope(name, "AttentionWrapperInit"): if initial_cell_state is None: self._initial_cell_state = None else: final_state_tensor = nest.flatten(initial_cell_state)[-1] state_batch_size = (final_state_tensor.shape[0].value or array_ops.shape(final_state_tensor)[0]) error_message = ( "When constructing AttentionWrapper %s: " % self._base_name + "Non-matching batch sizes between the memory " "(encoder output) and initial_cell_state. Are you using " "the BeamSearchDecoder? You may need to tile your initial state " "via the tf.contrib.seq2seq.tile_batch function with argument " "multiple=beam_width.") with ops.control_dependencies([ check_ops.assert_equal( state_batch_size, self._attention_mechanism.batch_size, message=error_message) ]): self._initial_cell_state = nest.map_structure( lambda s: array_ops.identity( s, name="check_initial_cell_state"), initial_cell_state)
def _verify_tf_loop_vars(init_loop_vars, first_iter_vars, basic_symbol_names, composite_symbol_names, include_shapes=True): """Verifies loop variables for consistency.""" # The whole point of _verify_tf_loop_vars is to give a more useful error message # than tf-level exception by including variable names. If it's not available, # there is no point in performing this verification here. As of 2019-07-31, # operators:control_flow_test does not pass the names. if basic_symbol_names is None: return output_symbol_names = basic_symbol_names + composite_symbol_names assert len(init_loop_vars) == len(first_iter_vars) == len( output_symbol_names) for init_loop_var, first_iter_var, name in zip(init_loop_vars, first_iter_vars, output_symbol_names): try: nest.assert_same_structure(init_loop_var, first_iter_var, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError( '"{}" does not have the same nested structure after one' ' iteration.\n\n{}'.format(name, e)) def _check_same_type(name, init_loop_var, first_iter_var): """Ensures init_loop_var and first_iter_var are consistent.""" if isinstance(init_loop_var, (bool, int, float, str)): init_loop_var = ops.convert_to_tensor_v2(init_loop_var) if isinstance(first_iter_var, (bool, int, float, str)): first_iter_var = ops.convert_to_tensor_v2(first_iter_var) if (not tensor_util.is_tensor(init_loop_var) or not tensor_util.is_tensor(first_iter_var)): return # TODO(mdan): Properly account for CompositeTensors. if (not hasattr(init_loop_var, 'dtype') or not hasattr(first_iter_var, 'dtype')): return if (not hasattr(init_loop_var, 'shape') or not hasattr(first_iter_var, 'shape')): return if init_loop_var.dtype != first_iter_var.dtype: raise TypeError( '"{}" has dtype {} before the loop, but dtype {} after one' ' iteration. TensorFlow control flow requires it stays the' ' same.'.format( name, init_loop_var.dtype.name, first_iter_var.dtype.name, )) if include_shapes: init_shape = init_loop_var.shape first_iter_shape = first_iter_var.shape # TODO(b/135183013): Update needed once we support shape_invariants. if not _shape_greater_than_or_equal(init_shape, first_iter_shape): raise ValueError( '"{}" has shape {} before the loop, but shape {} after one' ' iteration. TensorFlow control flow requires it stays the' ' same or be more specific.'.format( name, init_shape, first_iter_shape)) nest.map_structure(functools.partial(_check_same_type, name), init_loop_var, first_iter_var)
def dynamic_rnn(cell, inputs, att_scores=None, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None): with vs.variable_scope(scope or "rnn") as varscope: # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. if _should_cache(): if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) # By default, time_major==False and inputs are batch-major: shaped # [batch, time, depth] # For internal calculations, we transpose to [time, batch, depth] flat_input = nest.flatten(inputs) if not time_major: # (B,T,D) => (T,B,D) flat_input = [ ops.convert_to_tensor(input_) for input_ in flat_input ] flat_input = tuple( _transpose_batch_time(input_) for input_ in flat_input) parallel_iterations = parallel_iterations or 32 if sequence_length is not None: sequence_length = math_ops.cast(sequence_length, dtypes.int32) if sequence_length.get_shape().rank not in (None, 1): raise ValueError( "sequence_length must be a vector of length batch_size, " "but saw shape: %s" % sequence_length.get_shape()) sequence_length = array_ops.identity( # Just to find it in the graph. sequence_length, name="sequence_length") batch_size = _best_effort_input_batch_size(flat_input) if initial_state is not None: state = initial_state else: if not dtype: raise ValueError( "If there is no initial_state, you must give a dtype.") if getattr(cell, "get_initial_state", None) is not None: state = cell.get_initial_state(inputs=None, batch_size=batch_size, dtype=dtype) else: state = cell.zero_state(batch_size, dtype) def _assert_has_shape(x, shape): x_shape = array_ops.shape(x) packed_shape = array_ops.stack(shape) return control_flow_ops.Assert( math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [ "Expected shape for Tensor %s is " % x.name, packed_shape, " but saw shape: ", x_shape ]) if not context.executing_eagerly() and sequence_length is not None: # Perform some shape validation with ops.control_dependencies( [_assert_has_shape(sequence_length, [batch_size])]): sequence_length = array_ops.identity(sequence_length, name="CheckSeqLen") inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) (outputs, final_state) = _dynamic_rnn_loop( cell, inputs, state, parallel_iterations=parallel_iterations, swap_memory=swap_memory, att_scores=att_scores, sequence_length=sequence_length, dtype=dtype) # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. # If we are performing batch-major calculations, transpose output back # to shape [batch, time, depth] if not time_major: # (T,B,D) => (B,T,D) outputs = nest.map_structure(_transpose_batch_time, outputs) return (outputs, final_state)
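For reference, a hedged usage sketch of the standard `tf.compat.v1.nn.dynamic_rnn` (without the `att_scores` extension added above), showing how `sequence_length` zeroes outputs and copies state through past each example's true length; shapes are illustrative:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

inputs = tf.placeholder(tf.float32, [None, 10, 8])  # [batch, time, depth]
seq_len = tf.placeholder(tf.int32, [None])          # true length per example
cell = tf.nn.rnn_cell.GRUCell(16)

outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=seq_len, dtype=tf.float32)
# outputs: [batch, 10, 16], zeroed past seq_len; final_state: state at seq_len.
```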
def print_v2(*inputs, **kwargs): """Print the specified inputs. Returns an operator that prints the specified inputs to a desired output stream or logging level. The inputs may be dense or sparse Tensors, primitive python objects, data structures that contain Tensors, and printable python objects. Printed tensors will recursively show the first and last `summarize` elements of each dimension. With eager execution enabled and/or inside a `tf.contrib.eager.defun` this operator will automatically execute, and users only need to call `tf.print` without using the return value. When constructing graphs outside of a `tf.contrib.eager.defun`, one must either include the returned op in the input to `session.run`, or use the operator as a control dependency for executed ops by specifying `with tf.control_dependencies([print_op])`. @compatibility(python2) In python 2.7, make sure to import the following: `from __future__ import print_function` @end_compatibility Example: Single-input usage: ```python tf.enable_eager_execution() tensor = tf.range(10) tf.print(tensor, output_stream=sys.stderr) ``` (This prints "[0 1 2 ... 7 8 9]" to sys.stderr) Multi-input usage: ```python tf.enable_eager_execution() tensor = tf.range(10) tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout) ``` (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to sys.stdout) Usage in a defun: ```python tf.enable_eager_execution() @tf.contrib.eager.defun def f(): tensor = tf.range(10) tf.print(tensor, output_stream=sys.stderr) return tensor range_tensor = f() ``` (This prints "[0 1 2 ... 7 8 9]" to sys.stderr) Usage when constructing graphs: ```python sess = tf.Session() with sess.as_default(): tensor = tf.range(10) print_op = tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout) with tf.control_dependencies([print_op]): tripled_tensor = tensor * 3 sess.run(tripled_tensor) ``` (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to sys.stdout) Note: This op is only partially compatible with Jupyter notebooks and colabs. Because it prints to the C++ standard out / standard error, this will go in the notebook kernel's console output, not in the notebook cell output. Args: *inputs: Positional arguments that are the inputs to print. Inputs in the printed output will be separated by spaces. Inputs may be python primitives, tensors, data structures such as dicts and lists that may contain tensors (with the data structures possibly nested in arbitrary ways), and printable python objects. output_stream: The output stream, logging level, or file to print to. Defaults to sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and tf.logging.error are also supported. To print to a file, pass a string started with "file://" followed by the file path, e.g., "file:///tmp/foo.out". summarize: The first and last `summarize` elements within each dimension are recursively printed per Tensor. If None, then the first 3 and last 3 elements of each dimension are printed for each tensor. If set to -1, it will print all elements of every tensor. name: A name for the operation (optional). Returns: A print operator that prints the specified inputs in the specified output stream or logging level. Raises: ValueError: If an unsupported output stream is specified. """ # Because we are using arbitrary-length positional arguments, python 2 # does not support explicitly specifying the keyword arguments in the # function definition. 
So, we manually get the keyword arguments w/ default # values here. output_stream = kwargs.pop("output_stream", sys.stderr) name = kwargs.pop("name", None) summarize = kwargs.pop("summarize", 3) if kwargs: raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs) format_name = None if name: format_name = name + "_format" # Match the C++ string constants representing the different output streams. # Keep this updated! output_stream_to_constant = { sys.stdout: "stdout", sys.stderr: "stderr", tf_logging.INFO: "log(info)", tf_logging.info: "log(info)", tf_logging.WARN: "log(warning)", tf_logging.warning: "log(warning)", tf_logging.warn: "log(warning)", tf_logging.ERROR: "log(error)", tf_logging.error: "log(error)", } if _is_filepath(output_stream): output_stream_string = output_stream else: output_stream_string = output_stream_to_constant.get(output_stream) if not output_stream_string: raise ValueError( "Unsupported output stream, logging level, or file." + str(output_stream) + ". Supported streams are sys.stdout, " "sys.stderr, tf.logging.info, " "tf.logging.warning, tf.logging.error. " + "File needs to be in the form of 'file://<filepath>'.") # If we are only printing a single string scalar, there is no need to format if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0]) and (not isinstance(inputs[0], sparse_tensor.SparseTensor)) and (inputs[0].shape.ndims == 0)and (inputs[0].dtype == dtypes.string)): formatted_string = inputs[0] # Otherwise, we construct an appropriate template for the tensors we are # printing, and format the template using those tensors. else: # For each input to this print function, we extract any nested tensors, # and construct an appropriate template to format representing the # printed input. templates = [] tensors = [] tensor_free_structure = nest.map_structure( lambda x: "" if tensor_util.is_tensor(x) else x, inputs) tensor_free_template = " ".join(pprint.pformat(x) for x in tensor_free_structure) placeholder = _generate_placeholder_string(tensor_free_template) for input_ in inputs: placeholders = [] # Use the nest utilities to flatten & process any nested elements in this # input. The placeholder for a tensor in the template should be the # placeholder string, and the placeholder for a non-tensor can just be # the printed value of the non-tensor itself. for x in nest.flatten(input_): # support sparse tensors if isinstance(x, sparse_tensor.SparseTensor): tensors.extend([x.indices, x.values, x.dense_shape]) placeholders.append( "SparseTensor(indices={}, values={}, shape={})".format( placeholder, placeholder, placeholder) ) elif tensor_util.is_tensor(x): tensors.append(x) placeholders.append(placeholder) else: placeholders.append(x) if isinstance(input_, six.string_types): # If the current input to format/print is a normal string, that string # can act as the template. cur_template = input_ else: # We pack the placeholders into a data structure that matches the # input data structure format, then format that data structure # into a string template. # # NOTE: We must use pprint.pformat here for building the template for # unordered data structures such as `dict`, because `str` doesn't # guarantee orderings, while pprint prints in sorted order. pprint # will match the ordering of `nest.flatten`. # This even works when nest.flatten reorders OrderedDicts, because # pprint is printing *after* the OrderedDicts have been reordered. 
cur_template = pprint.pformat( nest.pack_sequence_as(input_, placeholders)) templates.append(cur_template) # We join the templates for the various inputs into a single larger # template. We also remove all quotes surrounding the placeholders, so that # the formatted/printed output will not contain quotes around tensors. # (example of where these quotes might appear: if we have added a # placeholder string into a list, then pretty-formatted that list) template = " ".join(templates) template = template.replace("'" + placeholder + "'", placeholder) formatted_string = string_ops.string_format( inputs=tensors, template=template, placeholder=placeholder, summarize=summarize, name=format_name) return gen_logging_ops.print_v2(formatted_string, output_stream=output_stream_string, name=name)
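# Hedged sketch (not from the original module): printing to a file via the
# "file://" prefix described in the docstring above. The path /tmp/foo.out is
# illustrative only; summarize=-1 prints every element of the tensor.
import tensorflow as tf

tf.enable_eager_execution()
tensor = tf.range(10)
tf.print(tensor, output_stream="file:///tmp/foo.out", summarize=-1)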
def execution_function(input_fn): # `numpy` translates Tensors to values in Eager mode. return nest.map_structure(_non_none_constant_value, distributed_function(input_fn))
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state, all_reduce_sum_gradients): """`apply_gradients` using a `DistributionStrategy`.""" if all_reduce_sum_gradients: reduced_grads = self._aggregate_gradients(distribution, grads_and_vars) var_list = [v for _, v in grads_and_vars] grads_and_vars = zip(reduced_grads, var_list) def apply_grad_to_update_var(var, grad): """Apply gradient to variable.""" if isinstance(var, ops.Tensor): raise NotImplementedError("Trying to update a Tensor ", var) apply_kwargs = {} if isinstance(grad, ops.IndexedSlices): if var.constraint is not None: raise RuntimeError( "Cannot use a constraint function on a sparse variable.") if "apply_state" in self._sparse_apply_args: apply_kwargs["apply_state"] = apply_state return self._resource_apply_sparse_duplicate_indices( grad.values, var, grad.indices, **apply_kwargs) if "apply_state" in self._dense_apply_args: apply_kwargs["apply_state"] = apply_state update_op = self._resource_apply_dense(grad, var, **apply_kwargs) if var.constraint is not None: with ops.control_dependencies([update_op]): return var.assign(var.constraint(var)) else: return update_op eagerly_outside_functions = ops.executing_eagerly_outside_functions() update_ops = [] with ops.name_scope(name or self._name, skip_on_eager=True): for grad, var in grads_and_vars: # TODO(crccw): It's not allowed to assign PerReplica value to # MirroredVariable. Remove this after we relax this restriction. def _assume_mirrored(grad): if isinstance(grad, ds_values.PerReplica): return ds_values.Mirrored(grad.values) return grad grad = nest.map_structure(_assume_mirrored, grad) # Colocate the update with variables to avoid unnecessary communication # delays. See b/136304694. with distribution.extended.colocate_vars_with(var): with ops.name_scope("update" if eagerly_outside_functions else "update_" + var.op.name, skip_on_eager=True): update_ops.extend(distribution.extended.update( var, apply_grad_to_update_var, args=(grad,), group=False)) any_symbolic = any(isinstance(i, ops.Operation) or tf_utils.is_symbolic_tensor(i) for i in update_ops) if not context.executing_eagerly() or any_symbolic: # If the current context is graph mode or any of the update ops are # symbolic then the step update should be carried out under a graph # context. (eager updates execute immediately) with ops._get_graph_from_inputs(update_ops).as_default(): # pylint: disable=protected-access with ops.control_dependencies(update_ops): return self._iterations.assign_add(1).op return self._iterations.assign_add(1)
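# Hedged sketch of the public path into _distributed_apply: apply_gradients
# called under a tf.distribute strategy. Assumes TF 2.2+ style APIs; the
# variable, loss, and strategy choice are illustrative only.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  var = tf.Variable(2.0)
  opt = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def step():
  with tf.GradientTape() as tape:
    loss = var * var
  grads = tape.gradient(loss, [var])
  # In a replica context this routes through merge_call and eventually
  # _distributed_apply, which colocates each update with its variable.
  opt.apply_gradients(zip(grads, [var]))

strategy.run(step)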
def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined): """Convert to tensor and possibly mask `memory`. Args: memory: `Tensor`, shaped `[batch_size, max_time, ...]`. memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`. check_inner_dims_defined: Python boolean. If `True`, the `memory` argument's shape is checked to ensure all but the two outermost dimensions are fully defined. Returns: A (possibly masked), checked, new `memory`. Raises: ValueError: If `check_inner_dims_defined` is `True` and not `memory.shape[2:].is_fully_defined()`. """ memory = nest.map_structure( lambda m: ops.convert_to_tensor(m, name="memory"), memory) def _maybe_mask(m, seq_len_mask): rank = m.get_shape().ndims rank = rank if rank is not None else array_ops.rank(m) extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32) m_batch_size = m.shape[0].value or array_ops.shape(m)[0] if memory_sequence_length is not None: message = ( "memory_sequence_length and memory tensor batch sizes do not " "match.") with ops.control_dependencies([ check_ops.assert_equal(seq_len_batch_size, m_batch_size, message=message) ]): seq_len_mask = array_ops.reshape( seq_len_mask, array_ops.concat( (array_ops.shape(seq_len_mask), extra_ones), 0)) return m * seq_len_mask else: return m if memory_sequence_length is not None: memory_sequence_length = ops.convert_to_tensor( memory_sequence_length, name="memory_sequence_length") if check_inner_dims_defined: def _check_dims(m): if not m.get_shape()[2:].is_fully_defined(): raise ValueError( "Expected memory %s to have fully defined inner dims, " "but saw shape: %s" % (m.name, m.get_shape())) nest.map_structure(_check_dims, memory) if memory_sequence_length is None: seq_len_mask = None else: seq_len_mask = array_ops.sequence_mask( memory_sequence_length, maxlen=array_ops.shape(nest.flatten(memory)[0])[1], dtype=nest.flatten(memory)[0].dtype) seq_len_batch_size = (memory_sequence_length.shape[0].value or array_ops.shape(memory_sequence_length)[0]) return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
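# Hedged illustration (not from the original module) of the masking performed by
# _maybe_mask above: positions past memory_sequence_length are zeroed out. The
# concrete shapes and values are made up.
import tensorflow as tf

memory = tf.ones([2, 4, 3])                    # [batch_size, max_time, depth]
memory_sequence_length = tf.constant([2, 4])   # valid length per batch entry
seq_len_mask = tf.sequence_mask(
    memory_sequence_length, maxlen=4, dtype=memory.dtype)   # [2, 4]
masked = memory * tf.expand_dims(seq_len_mask, -1)          # broadcast over depth
# masked[0, 2:, :] is now all zeros, while masked[1] is unchanged.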
def build_decoder_cell(self): encoder_outputs = self.encoder_outputs encoder_last_state = self.encoder_last_state encoder_inputs_length = self.encoder_inputs_length # To use BeamSearchDecoder, encoder_outputs, encoder_last_state, encoder_inputs_length # needs to be tiled so that: [batch_size, .., ..] -> [batch_size x beam_width, .., ..] if self.use_beamsearch_decode: print("use beamsearch decoding...") encoder_outputs = seq2seq.tile_batch(self.encoder_outputs, multiplier=self.beam_width) encoder_last_state = nest.map_structure( lambda s: seq2seq.tile_batch(s, self.beam_width), self.encoder_last_state) encoder_inputs_length = seq2seq.tile_batch( self.encoder_inputs_length, multiplier=self.beam_width) # Building attention mechanism: Default Bahdanau # 'Bahdanau' style attention: https://arxiv.org/abs/1409.0473 self.attention_mechanism = attention_wrapper.BahdanauAttention( num_units=self.hidden_units, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length, ) # 'Luong' style attention: https://arxiv.org/abs/1508.04025 if self.attention_type.lower() == 'luong': self.attention_mechanism = attention_wrapper.LuongAttention( num_units=self.hidden_units, memory=encoder_outputs, memory_sequence_length=encoder_inputs_length, ) # Building decoder_cell self.decoder_cell_list = [ self.build_single_cell() for _ in range(self.depth) ] decoder_initial_state = encoder_last_state # AttentionWrapper wraps RNNCell with the attention_mechanism # Note: We implement Attention mechanism only on the top decoder layer if self.use_joint_attention: def attn_decoder_input_fn(inputs, encoder_attention, decoder_attention): if not self.attn_input_feeding: return inputs # Essential when use_residual=True _input_layer = Dense(self.hidden_units, dtype=self.dtype, name='attn_input_feeding') print('before concat', inputs, encoder_attention, decoder_attention) return _input_layer( array_ops.concat( [inputs, encoder_attention, decoder_attention], -1)) self.decoder_cell_list[-1] = joint_attention.JointAttentionWrapper( cell=self.decoder_cell_list[-1], attention_mechanism=self.attention_mechanism, attention_layer_size=self.attention_units, cell_input_fn=attn_decoder_input_fn, initial_cell_state=decoder_initial_state[-1], output_attention=False, alignment_history=True, name='attention_wrapper') print('Wrapper init') else: def attn_decoder_input_fn(inputs, attention): if not self.attn_input_feeding: return inputs _input_layer = Dense(self.hidden_units, dtype=self.dtype, name='attn_input_feeding') print('before concat', inputs, attention) return _input_layer(array_ops.concat([inputs, attention], -1)) self.decoder_cell_list[-1] = attention_wrapper.AttentionWrapper( cell=self.decoder_cell_list[-1], attention_mechanism=self.attention_mechanism, attention_layer_size=self.attention_units, cell_input_fn=attn_decoder_input_fn, initial_cell_state=decoder_initial_state[-1], alignment_history=True, name='attention_wrapper') # To be compatible with AttentionWrapper, the encoder last state # of the top layer should be converted into the AttentionWrapperState form # We can easily do this by calling AttentionWrapper.zero_state # Also if beamsearch decoding is used, the batch_size argument in .zero_state # should be ${decoder_beam_width} times to the original batch_size batch_size = self.batch_size if not self.use_beamsearch_decode \ else self.batch_size * self.beam_width initial_state = [state for state in encoder_last_state] initial_state[-1] = self.decoder_cell_list[-1].zero_state( batch_size=batch_size, dtype=self.dtype) 
decoder_initial_state = tuple(initial_state) return MultiRNNCell(self.decoder_cell_list), decoder_initial_state
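# Hedged sketch of the beam-search tiling step used above: each batch entry is
# repeated beam_width times so downstream shapes become
# [batch_size * beam_width, ...]. Assumes the `seq2seq` module in this file is
# tf.contrib.seq2seq (TF 1.x); the concrete shapes are illustrative.
import tensorflow as tf
from tensorflow.contrib import seq2seq

beam_width = 3
encoder_outputs = tf.ones([2, 5, 8])            # [batch, time, units]
encoder_inputs_length = tf.constant([5, 3])

tiled_outputs = seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)         # [6, 5, 8]
tiled_lengths = seq2seq.tile_batch(encoder_inputs_length, multiplier=beam_width)   # [6]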
def prune(self, feeds, fetches, name=None, input_signature=None): """Extract a subgraph of this function's underlying graph. Wraps the subgraph in a new `WrappedFunction` object. Args: feeds: Input tensors to the subgraph to extract, as `Tensor` objects. fetches: Possibly-nested Python data structure containing information about outputs of the target subgraph. Each entry can either be a `Tensor` object (for data outputs), an `Operation` object (for control outputs), or a `TensorInfo` proto. Any additional shape/dtype information provided in a `TensorInfo` and not present in the original graph will be added to the returned subgraph. name: (optional) Name to give to the underlying `FuncGraph` of the returned object. If no name is provided, the graph's name will be `"pruned"`. input_signature: (optional) possibly-nested Python data structure containing `TensorSpec` objects, with which to populate the returned function's `FuncGraph`'s `structured_input_signature` field. Returns: A new `WrappedFunction` object containing a copy of the portion of this object's graph that goes from `feeds` to `fetches`. """ # TODO(b/129646028): Add support for CompositeTensors. name = name or "pruned" flat_feeds = nest.flatten(feeds, expand_composites=True) flat_feeds = [self.graph.as_graph_element(t) for t in flat_feeds] for f in flat_feeds: if not isinstance(f, ops.Tensor): raise ValueError( "All members of argument `feeds` must be tensors. " f"Got {f} with type {type(f)}.") # Ignoring all feeds that are captures allows prune to be called # using wrapped_func.inputs even when it uses variables internal_captures = {id(c) for c in self.graph.internal_captures} flat_feeds = [f for f in flat_feeds if id(f) not in internal_captures] operation_fetches = [] tensor_fetches = [] tensor_infos = [] def _fetch_preprocessing_callback(fetch): """Extract out lists of ops, tensors, and tensor type info. Turns TensorInfos into Tensors in the original `fetches` structure. Also extracts ops from `fetches`. Args: fetch: The fetch to preprocess: Tensor, TensorInfo, or Operation, or string identifying a Tensor or Operation. Returns: `fetch` converted to a Tensor. """ if isinstance(fetch, ops.Operation): operation_fetches.append(fetch) return fetch elif isinstance(fetch, meta_graph_pb2.TensorInfo): tensor_infos.append(fetch) decoded = _get_element_from_tensor_info( fetch, self._func_graph) if (tensor_util.is_tf_type(decoded) or isinstance( decoded, composite_tensor.CompositeTensor)): tensor_fetches.append(decoded) else: operation_fetches.append(decoded) return decoded elif isinstance(fetch, (ops.Tensor, composite_tensor.CompositeTensor)): tensor_fetches.append(fetch) return fetch else: graph_element = self.graph.as_graph_element(fetch) return _fetch_preprocessing_callback(graph_element) fetches = nest.map_structure(_fetch_preprocessing_callback, fetches) # Expand composite tensors into their component dense Tensors. tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True) for f in flat_feeds + tensor_fetches + operation_fetches: if f.graph is not self._func_graph: raise ValueError( "Can only prune function whose feeds and fetches " f"are from graph {self._func_graph}. 
Input " f"{f} is from a different graph {f.graph}.") with self._func_graph.as_default(): pruned_graph = func_graph.FuncGraph(name) lift_map = lift_to_graph.lift_to_graph( operation_fetches + tensor_fetches, pruned_graph, sources=flat_feeds + self.graph.internal_captures, base_graph=self._func_graph) # Note that we add the component tensors of any composite tensors to the # returned function's outputs list; the list must contain these component # tensors, or the function's sparse outputs won't work properly. pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches) pruned_graph.control_outputs.extend( [lift_map[operation] for operation in operation_fetches]) pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds) for external_capture, internal_capture in self.graph.captures: pruned_graph.add_capture(external_capture, lift_map[internal_capture]) for ti in tensor_infos: if ti.WhichOneof("encoding") == "name": # Dense tensors only t = pruned_graph.as_graph_element(ti.name) if tensor_util.is_tf_type(t): t.set_shape(tensor_shape.TensorShape(ti.tensor_shape)) # pylint: disable=protected-access for f in self.graph._functions.values(): pruned_graph._add_function(f) # pylint: enable=protected-access pruned_graph.variables = self.graph.variables def _structured_output_mapping(fetched): """callback for `nest.map_structure()`""" lifted = lift_map[fetched] if isinstance(lifted, ops.Operation): return None return lifted # expand_composites=True here causes composite tensors to be expanded # into their component dense Tensors, mapped to the new graph, and then # reconstituted into their original composite form. pruned_graph.structured_outputs = nest.map_structure( _structured_output_mapping, fetches, expand_composites=True) pruned_graph.structured_input_signature = input_signature pruned_fn = WrappedFunction(pruned_graph, variable_holder=self._variable_holder) pruned_fn._num_positional_args = len(flat_feeds) # pylint: disable=protected-access # TODO(kathywu): Enable keyword arguments if an input signature is specified pruned_fn._arg_keywords = [tensor.op.name for tensor in flat_feeds] # pylint: disable=protected-access return pruned_fn
def _clone_functional_model(model, input_tensors=None, share_weights=False): """Clone a functional `Model` instance. Model cloning is similar to calling a model on new inputs, except that it creates new layers (and thus new weights) instead of sharing the weights of the existing layers. Arguments: model: Instance of `Model`. input_tensors: optional list of input tensors to build the model upon. If not provided, placeholders will be created. share_weights: flag to enable sharing of non-input layers between the cloned and original model. Note this still clones the input layers. This is required when we create a per-replica copy of the model with distribution strategy; we want the weights to be shared but still feed inputs separately so we create new input layers. Returns: An instance of `Model` reproducing the behavior of the original model, on top of new inputs tensors, using newly instantiated weights. Raises: ValueError: in case of invalid `model` argument value. """ if not isinstance(model, Model): raise ValueError( 'Expected `model` argument ' 'to be a `Model` instance, got ', model) if isinstance(model, Sequential): raise ValueError( 'Expected `model` argument ' 'to be a functional `Model` instance, ' 'got a `Sequential` instance instead:', model) layer_map = {} # Cache for created layers. tensor_map = {} # Map {reference_tensor: corresponding_tensor} if input_tensors is None: # Create placeholders to build the model on top of. input_layers = [] input_tensors = [] for layer in model._input_layers: input_tensor = Input(batch_shape=layer._batch_input_shape, dtype=layer.dtype, sparse=layer.sparse, name=layer.name) input_tensors.append(input_tensor) # Cache newly created input layer. newly_created_input_layer = input_tensor._keras_history.layer layer_map[layer] = newly_created_input_layer for original_input_layer, cloned_input_layer in zip( model._input_layers, input_layers): layer_map[original_input_layer] = cloned_input_layer else: # Make sure that all input tensors come from a Keras layer. # If tensor comes from an input layer: cache the input layer. input_tensors = nest.flatten(input_tensors) input_tensors_ = [] for i in range(len(input_tensors)): input_tensor = input_tensors[i] if not K.is_keras_tensor(input_tensor): original_input_layer = model._input_layers[i] name = original_input_layer.name input_tensor = Input(tensor=input_tensor, name='input_wrapper_for_' + name) input_tensors_.append(input_tensor) # Cache newly created input layer. newly_created_input_layer = input_tensor._keras_history.layer layer_map[original_input_layer] = newly_created_input_layer else: input_tensors_.append(input_tensor) input_tensors = input_tensors_ for x, y in zip(model.inputs, input_tensors): tensor_map[x] = y # Iterated over every node in the reference model, in depth order. depth_keys = list(model._nodes_by_depth.keys()) depth_keys.sort(reverse=True) for depth in depth_keys: nodes = model._nodes_by_depth[depth] for node in nodes: # Recover the corresponding layer. layer = node.outbound_layer # Get or create layer. if layer not in layer_map: if not share_weights: # Clone layer. new_layer = _clone_layer(layer) layer_map[layer] = new_layer layer = new_layer else: # Reuse previously cloned layer. layer = layer_map[layer] # Don't call InputLayer multiple times. if isinstance(layer, InputLayer): continue # If all previous input tensors are available in tensor_map, # then call node.inbound_layer on them. 
if all(tensor in tensor_map for tensor in nest.flatten(node.input_tensors)): computed_tensors = nest.map_structure(lambda t: tensor_map[t], node.input_tensors) # Call layer. kwargs = node.arguments or {} output_tensors = layer(computed_tensors, **kwargs) for x, y in zip(nest.flatten(node.output_tensors), nest.flatten(output_tensors)): tensor_map[x] = y # Check that we did compute the model outputs, # then instantiate a new model from inputs and outputs. output_tensors = [] for x in model.outputs: assert x in tensor_map, 'Could not compute output ' + str(x) output_tensors.append(tensor_map[x]) input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors) output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors) return Model(input_tensors, output_tensors, name=model.name)
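# Hedged sketch of the public wrapper around this helper:
# tf.keras.models.clone_model creates new layers (and fresh weights) with the
# same architecture. The toy functional model below is illustrative only.
import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
out = tf.keras.layers.Dense(2)(inp)
model = tf.keras.Model(inp, out)

clone = tf.keras.models.clone_model(model)   # new input layers, new weights
clone_on_given_inputs = tf.keras.models.clone_model(
    model, input_tensors=tf.keras.Input(shape=(4,)))  # rebuild on provided inputs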