Example #1
        def _decoding_function_outer(step_target_ids, current_time_step,
                                     memories):
            """Single-step decoding function (outer version).

            This is a wrapper around _decoding_function_inner() that does some
            housekeeping before calling that function to do the actual work.

            Args:
                step_target_ids: Tensor with shape (batch_size,)
                current_time_step: scalar Tensor.
                memories: dictionary (see top-level class description)

            Returns:
                Tuple containing the logits for the current timestep and
                the updated memories dictionary.
            """
            with tf.name_scope(self._scope):

                shapes = {step_target_ids: ('batch_size', )}
                tf_utils.assert_shapes(shapes)

                logits, memories['base_states'], memories['high_states'] = \
                    _decoding_function_inner(
                        step_target_ids, memories['base_states'],
                        memories['high_states'], current_time_step)

                return logits, memories
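
A side note on the assertion helper used throughout these examples: `tf_utils.assert_shapes` is a project-local utility, but TensorFlow ships the same kind of named-dimension check as `tf.debugging.assert_shapes`. A minimal sketch with the stock API (the tensors and sizes below are illustrative, not taken from the source):

    import tensorflow as tf

    step_target_ids = tf.constant([3, 7, 1])   # shape: (batch_size,)
    logits = tf.zeros([3, 512])                # shape: (batch_size, vocab_size)

    # Symbolic dimension names must be consistent across all listed
    # tensors; a mismatch raises an error.
    tf.debugging.assert_shapes([
        (step_target_ids, ('batch_size',)),
        (logits, ('batch_size', 'vocab_size')),
    ])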
Example #2
def gather_states(states):
    shapes = {states: ('batch_size', self._config.state_size)}
    tf_utils.assert_shapes(shapes)
    states_shape = tf.shape(states)
    state_size = states_shape[1]
    # Restore the beam dimension, move it behind the batch dimension,
    # select the surviving hypotheses, then undo both steps.
    tmp = tf.reshape(states, [beam_size, batch_size_x, state_size])
    flat_tensor = tf.transpose(tmp, [1, 0, 2])
    tmp = tf.gather_nd(flat_tensor, gather_coordinates)
    tmp = tf.transpose(tmp, [1, 0, 2])
    gathered_values = tf.reshape(tmp, states_shape)
    return gathered_values
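
The reshape / transpose / gather_nd / transpose / reshape sequence above is the standard trick for reordering beam entries when states are kept flattened as (beam_size * batch_size_x, state_size). A self-contained sketch of the same pattern, with all sizes and coordinates made up for illustration:

    import tensorflow as tf

    beam_size, batch_size_x, state_size = 3, 2, 4
    states = tf.reshape(
        tf.range(beam_size * batch_size_x * state_size, dtype=tf.float32),
        [beam_size * batch_size_x, state_size])

    # One (sentence index, parent beam index) pair per surviving hypothesis.
    gather_coordinates = tf.constant([[[0, 0], [0, 0], [0, 2]],
                                      [[1, 1], [1, 0], [1, 0]]])

    tmp = tf.reshape(states, [beam_size, batch_size_x, state_size])
    tmp = tf.transpose(tmp, [1, 0, 2])            # -> (batch, beam, state)
    tmp = tf.gather_nd(tmp, gather_coordinates)   # reorder the beam entries
    tmp = tf.transpose(tmp, [1, 0, 2])            # -> (beam, batch, state)
    gathered = tf.reshape(tmp, tf.shape(states))  # back to the flat layout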
Example #3
def gather_attn(attn):
    # TODO Specify second and third?
    shapes = {attn: ('batch_size', None, None)}
    tf_utils.assert_shapes(shapes)
    attn_dims = tf_utils.get_shape_list(attn)
    new_shape = [beam_size, batch_size_x] + attn_dims[1:]
    # Restore the beam dimension, move it behind the batch dimension,
    # select the surviving hypotheses, then undo both steps.
    tmp = tf.reshape(attn, new_shape)
    flat_tensor = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
    tmp = tf.gather_nd(flat_tensor, gather_coordinates)
    tmp = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
    gathered_values = tf.reshape(tmp, attn_dims)
    return gathered_values
Example #4
    def generate_initial_memories(self, batch_size, beam_size):
        """Generates the initial memories dictionary used in decoding.

        Args:
            batch_size: size of the current batch.
            beam_size: number of entries maintained per beam.

        Returns:
            Dictionary containing the initial base and high GRU states.
        """
        with tf.name_scope(self._scope):
            d = self._model.decoder

            shapes = {d.init_state: ('batch_size', self._config.state_size)}
            tf_utils.assert_shapes(shapes)

            high_depth = 0 if d.high_gru_stack is None \
                           else len(d.high_gru_stack.grus)

            initial_memories = {}
            initial_memories['base_states'] = d.init_state
            initial_memories['high_states'] = [d.init_state] * high_depth
            return initial_memories
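
One detail worth noting in the `high_states` initialisation: list multiplication produces `high_depth` references to the same initial-state tensor, and a depth of zero produces an empty list, so a decoder without a high GRU stack simply carries no high states. A tiny illustration with a stand-in value:

    init_state = 'init'                 # stand-in for d.init_state
    assert [init_state] * 0 == []       # no high GRU stack -> no high states
    assert [init_state] * 2 == ['init', 'init']  # one entry per stacked GRU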
Example #5
    def gather_memories(self, memories, gather_coordinates):
        """ Gathers layer-wise memory tensors for selected beam entries.

        Args:
            memories: dictionary (see top-level class description)
            gather_coordinates: Tensor with shape [batch_size_x, beam_size, 2]

        Returns:
            Dictionary containing gathered memories.
        """
        with tf.compat.v1.name_scope(self._scope):

            shapes = {gather_coordinates: ('batch_size_x', 'beam_size', 2)}
            tf_utils.assert_shapes(shapes)

            coords_shape = tf.shape(input=gather_coordinates)
            batch_size_x, beam_size = coords_shape[0], coords_shape[1]

            def gather_attn(attn):
                # TODO Specify second and third?
                shapes = {attn: ('batch_size', None, None)}
                tf_utils.assert_shapes(shapes)
                attn_dims = tf_utils.get_shape_list(attn)
                new_shape = [beam_size, batch_size_x] + attn_dims[1:]
                tmp = tf.reshape(attn, new_shape)
                flat_tensor = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
                tmp = tf.gather_nd(flat_tensor, gather_coordinates)
                tmp = tf.transpose(a=tmp, perm=[1, 0, 2, 3])
                gathered_values = tf.reshape(tmp, attn_dims)
                return gathered_values

            gathered_memories = dict()

            for layer_key in memories.keys():
                layer_dict = memories[layer_key]
                gathered_memories[layer_key] = dict()

                for attn_key in layer_dict.keys():
                    attn_tensor = layer_dict[attn_key]
                    gathered_memories[layer_key][attn_key] = \
                        gather_attn(attn_tensor)

            return gathered_memories
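
The two nested loops can equivalently be written as nested dict comprehensions; the behaviour is identical, so this is purely a stylistic alternative:

    gathered_memories = {
        layer_key: {attn_key: gather_attn(attn_tensor)
                    for attn_key, attn_tensor in layer_dict.items()}
        for layer_key, layer_dict in memories.items()
    }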
Example #6
    def gather_memories(self, memories, gather_coordinates):
        """Gathers memories for selected beam entries.

        Args:
            memories: dictionary (see top-level class description)
            gather_coordinates: Tensor with shape [batch_size_x, beam_size, 2]

        Returns:
            Dictionary containing gathered memories.
        """
        with tf.name_scope(self._scope):

            shapes = {gather_coordinates: ('batch_size_x', 'beam_size', 2)}
            tf_utils.assert_shapes(shapes)

            coords_shape = tf.shape(gather_coordinates)
            batch_size_x, beam_size = coords_shape[0], coords_shape[1]

            def gather_states(states):
                shapes = {states: ('batch_size', self._config.state_size)}
                tf_utils.assert_shapes(shapes)
                states_shape = tf.shape(states)
                state_size = states_shape[1]
                tmp = tf.reshape(states, [beam_size, batch_size_x, state_size])
                flat_tensor = tf.transpose(tmp, [1, 0, 2])
                tmp = tf.gather_nd(flat_tensor, gather_coordinates)
                tmp = tf.transpose(tmp, [1, 0, 2])
                gathered_values = tf.reshape(tmp, states_shape)
                return gathered_values

            gathered_memories = {}

            gathered_memories['base_states'] = \
                gather_states(memories['base_states'])

            gathered_memories['high_states'] = [
                gather_states(states) for states in memories['high_states']
            ]

            return gathered_memories
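
For context, a hedged usage sketch: during beam search the coordinates would typically be assembled from the parent indices of the surviving hypotheses before each call. The `parent_beam_ids` name below is illustrative, not taken from the source:

    # parent_beam_ids: (batch_size_x, beam_size) int32 Tensor holding, for
    # each beam slot, the beam slot it was expanded from.
    batch_ids = tf.tile(tf.range(batch_size_x)[:, None], [1, beam_size])
    gather_coordinates = tf.stack([batch_ids, parent_beam_ids], axis=-1)
    memories = self.gather_memories(memories, gather_coordinates)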