def _decoding_function_outer(step_target_ids, current_time_step, memories):
    """Single-step decoding function (outer version).

    This is a wrapper around _decoding_function_inner() that does some
    housekeeping before calling that function to do the actual work.

    Args:
        step_target_ids: Tensor with shape (batch_size)
        current_time_step: scalar Tensor.
        memories: dictionary (see top-level class description)

    Returns:
        A pair (logits, memories): the logits Tensor produced by
        _decoding_function_inner() for this step, and the memories
        dictionary, whose 'base_states' and 'high_states' entries have
        been replaced in place with the updated decoder states.
    """
    with tf.name_scope(self._scope):
        # Only the first dimension (batch) is checked here; the inner
        # function is responsible for validating the state tensors.
        shapes = {step_target_ids: ('batch_size', )}
        tf_utils.assert_shapes(shapes)
        # NOTE: this mutates the caller-supplied dict as well as
        # returning it.
        logits, memories['base_states'], memories['high_states'] = \
            _decoding_function_inner(
                step_target_ids,
                memories['base_states'],
                memories['high_states'],
                current_time_step)
        return logits, memories
def gather_states(states):
    """Reorders the rows of a flat (beam * batch_x) state tensor so each
    batch entry keeps only its selected beam hypotheses."""
    tf_utils.assert_shapes({states: ('batch_size', self._config.state_size)})
    original_shape = tf.shape(states)
    depth = original_shape[1]
    # View the flat batch as (beam, batch_x, depth), move batch_x to the
    # front so gather_coordinates can index it, pick the requested beam
    # entries, then restore the original flat layout.
    beam_major = tf.reshape(states, [beam_size, batch_size_x, depth])
    batch_major = tf.transpose(beam_major, [1, 0, 2])
    selected = tf.gather_nd(batch_major, gather_coordinates)
    restored = tf.transpose(selected, [1, 0, 2])
    return tf.reshape(restored, original_shape)
def gather_attn(attn):
    """Gathers selected beam entries of a 4-D attention tensor."""
    # TODO Specify the second and third dimensions of the expected shape?
    tf_utils.assert_shapes({attn: ('batch_size', None, None)})
    dims = tf_utils.get_shape_list(attn)
    # Split the leading axis into (beam, batch_x), reorder so batch_x
    # leads for gather_nd, select the beam entries, then undo both steps.
    split = tf.reshape(attn, [beam_size, batch_size_x] + dims[1:])
    batch_major = tf.transpose(a=split, perm=[1, 0, 2, 3])
    chosen = tf.gather_nd(batch_major, gather_coordinates)
    beam_major = tf.transpose(a=chosen, perm=[1, 0, 2, 3])
    return tf.reshape(beam_major, dims)
def generate_initial_memories(self, batch_size, beam_size):
    """Builds the initial memories dictionary for decoding.

    Args:
        batch_size: unused here; kept for interface compatibility.
        beam_size: unused here; kept for interface compatibility.

    Returns:
        Dictionary with 'base_states' (the decoder's initial state) and
        'high_states' (one reference to the initial state per GRU in the
        high stack, or an empty list when there is no high GRU stack).
    """
    with tf.name_scope(self._scope):
        decoder = self._model.decoder
        # NOTE(review): siblings use self._config; confirm self.config
        # is the intended attribute here.
        tf_utils.assert_shapes(
            {decoder.init_state: ('batch_size', self.config.state_size)})
        if decoder.high_gru_stack is None:
            stack_depth = 0
        else:
            stack_depth = len(decoder.high_gru_stack.grus)
        return {
            'base_states': decoder.init_state,
            'high_states': [decoder.init_state] * stack_depth,
        }
def gather_memories(self, memories, gather_coordinates):
    """Gathers layer-wise memory tensors for selected beam entries.

    Args:
        memories: dictionary (see top-level class description)
        gather_coordinates: Tensor with shape [batch_size_x, beam_size, 2]

    Returns:
        Dictionary containing gathered memories.
    """
    with tf.compat.v1.name_scope(self._scope):
        tf_utils.assert_shapes(
            {gather_coordinates: ('batch_size_x', 'beam_size', 2)})
        coord_dims = tf.shape(input=gather_coordinates)
        batch_size_x = coord_dims[0]
        beam_size = coord_dims[1]

        def gather_attn(attn):
            # TODO Specify the second and third dims of the shape check?
            tf_utils.assert_shapes({attn: ('batch_size', None, None)})
            attn_dims = tf_utils.get_shape_list(attn)
            # Split the flat batch axis into (beam, batch_x), make batch_x
            # leading for gather_nd, select beam entries, then restore.
            split = tf.reshape(
                attn, [beam_size, batch_size_x] + attn_dims[1:])
            batch_major = tf.transpose(a=split, perm=[1, 0, 2, 3])
            selected = tf.gather_nd(batch_major, gather_coordinates)
            beam_major = tf.transpose(a=selected, perm=[1, 0, 2, 3])
            return tf.reshape(beam_major, attn_dims)

        # Apply the gather to every attention tensor in every layer,
        # preserving the two-level dictionary structure.
        return {
            layer_key: {
                attn_key: gather_attn(attn_tensor)
                for attn_key, attn_tensor in layer_dict.items()
            }
            for layer_key, layer_dict in memories.items()
        }
def gather_memories(self, memories, gather_coordinates):
    """Gathers memories for selected beam entries.

    Args:
        memories: dictionary (see top-level class description)
        gather_coordinates: Tensor with shape [batch_size_x, beam_size, 2]

    Returns:
        Dictionary containing gathered memories.
    """
    with tf.name_scope(self._scope):
        tf_utils.assert_shapes(
            {gather_coordinates: ('batch_size_x', 'beam_size', 2)})
        coord_dims = tf.shape(gather_coordinates)
        batch_size_x = coord_dims[0]
        beam_size = coord_dims[1]

        def gather_states(states):
            # Reorder a flat (beam * batch_x, depth) state tensor so each
            # batch entry keeps only its selected beam hypotheses.
            tf_utils.assert_shapes(
                {states: ('batch_size', self._config.state_size)})
            original_shape = tf.shape(states)
            depth = original_shape[1]
            beam_major = tf.reshape(
                states, [beam_size, batch_size_x, depth])
            batch_major = tf.transpose(beam_major, [1, 0, 2])
            selected = tf.gather_nd(batch_major, gather_coordinates)
            restored = tf.transpose(selected, [1, 0, 2])
            return tf.reshape(restored, original_shape)

        return {
            'base_states': gather_states(memories['base_states']),
            'high_states': [
                gather_states(s) for s in memories['high_states']
            ],
        }