Example #1
0
def _stack_hypotheses(tensor, batch_size, beam_width, output_dim):
    '''
    stack the hypotheses

    merges axis 1 (beam_width) and axis 2 (output_dim) into a single axis.
    TensorArrays are processed element-wise; tensors of rank < 3 are returned
    unchanged.

    args:
        tensor: a [batch_size x beam_width x output_dim x ...] tensor
        batch_size: the batch size
        beam_width: the beam_width
        output_dim: the output_dim

    returns: a [batch_size x beam_width*output_dim x ...] tensor in which all
        trailing axes (axis 3 and up) are preserved
    '''
    if isinstance(tensor, tf.TensorArray):
        return map_ta(
            partial(_stack_hypotheses,
                    batch_size=batch_size,
                    beam_width=beam_width,
                    output_dim=output_dim), tensor)

    if tensor.shape.ndims < 3:
        return tensor

    # static sizes of the axes after output_dim (empty list for rank 3)
    trailing = tensor.shape[3:].as_list()
    unknown = [i for i, dim in enumerate(trailing) if dim is None]
    if len(unknown) == 1:
        # a single unknown size can be inferred by reshape itself
        trailing[unknown[0]] = -1
    elif unknown:
        # reshape accepts only one -1, so read every unknown size from the
        # runtime shape instead of collapsing several axes into one
        dynamic_shape = tf.shape(tensor)
        trailing = [dynamic_shape[3 + i] if dim is None else dim
                    for i, dim in enumerate(trailing)]

    return tf.reshape(tensor, [batch_size, beam_width * output_dim] + trailing)
Example #2
0
def _unstack(tensor, batch_size, beam_width):
    '''
    split the merged batch/beam axis of a rank-2 tensor

    TensorArrays are processed element-wise. Only rank-2 tensors are
    reshaped; tensors of any other (or unknown) rank are returned unchanged.

    args:
        tensor: a [batch_size * beam_width x dim] tensor
        batch_size: the batch size
        beam_width: the beam_width

    returns: a [batch_size x beam_width x dim] tensor, or the input itself
        when its rank is not 2
    '''
    if isinstance(tensor, tf.TensorArray):
        unstack_element = partial(
            _unstack, batch_size=batch_size, beam_width=beam_width)
        return map_ta(unstack_element, tensor)

    if tensor.shape.ndims == 2:
        last_dim = tensor.shape[-1]
        # keep the static feature size when it is known so the result has a
        # fully defined shape; otherwise let reshape infer it
        if last_dim.value is None:
            last_dim = -1
        return tf.reshape(tensor, [batch_size, beam_width, last_dim])

    return tensor
Example #3
0
def _gather_state(state, indices):
    '''
    gather entries of a state with tf.gather_nd

    args:
        state: the state tensor, or a TensorArray whose elements are
            gathered one by one
        indices: the indices passed to tf.gather_nd

    returns: the gathered state (op named 'prune_state'); scalar states are
        returned unchanged
    '''
    if isinstance(state, tf.TensorArray):
        gather_element = partial(_gather_state, indices=indices)
        return map_ta(gather_element, state)

    # a scalar has no axis to index into, so it is passed through as-is
    if state.shape.ndims != 0:
        state = tf.gather_nd(state, indices, name='prune_state')
    return state
Example #4
0
def _tile_state(state, output_dim):
    '''
    repeat a rank-3 state output_dim times along a new axis 2

    TensorArrays are processed element-wise; states whose rank is not 3
    (including unknown rank) are returned unchanged.

    args:
        state: the state to tile, a [d0 x d1 x d2] tensor (presumably
            [batch_size x beam_width x dim] -- confirm against callers)
        output_dim: the number of copies to make along the new axis

    returns: a [d0 x d1 x output_dim x d2] tensor for a rank-3 state,
        otherwise the input itself
    '''
    if isinstance(state, tf.TensorArray):
        tile_element = partial(_tile_state, output_dim=output_dim)
        return map_ta(tile_element, state)

    if state.shape.ndims == 3:
        # insert a size-1 axis at position 2 and repeat only along it
        expanded = tf.expand_dims(state, 2)
        multiples = [1, 1, output_dim, 1]
        return tf.tile(expanded, multiples)

    return state