Exemplo n.º 1
0
    def write(self, data):
        """ Produce a filtered write patch from the data.

        The data is linearly projected to one value per write-grid cell,
        multiplied on both sides by the filters returned from
        ``self.get_filter``, flattened, and scaled by the reciprocal of gamma.

        Raises ValueError if ``self.write_grid`` is falsy.
        """
        grid = self.write_grid
        if not grid:
            raise ValueError('Writing is not supported')

        fx, fy, gamma = self.get_filter(data, grid, scope='write/filter')
        # Swap the last two axes of the y filter so it can be applied from the left.
        fy_t = tf.transpose(fy, [0, 2, 1])

        # One linear output per grid cell, laid out as (-1, size[1], size[0]).
        cell_count = reduce_prod(grid.size)
        projected = layers.linear(data, cell_count)
        projected = tf.reshape(projected, (-1, grid.size[1], grid.size[0]))

        right_filtered = tf.matmul(projected, fx)
        patch = tf.matmul(fy_t, right_filtered)

        # Clamp gamma from below by epsilon before taking its reciprocal.
        inverse_gamma = tf.reciprocal(tf.maximum(gamma, self.epsilon))
        return inverse_gamma * layers.flatten(patch)
Exemplo n.º 2
0
def create_attention(attention_type, data_shape, read_size=None, write_size=None):
    """ Build the attention object selected by ``attention_type``.

    Recognised values are 'none', 'sigmoid', 'softmax', 'content',
    'gaussian', 'cauchy' and 'grid'. Only the last three receive
    ``read_size`` and ``write_size``.

    Raises TypeError for any other value.
    """
    # Variants that are constructed without read/write sizes.
    if attention_type == 'none':
        return NoAttention(reduce_prod(data_shape))

    if attention_type == 'sigmoid':
        return SimpleAttention(tf.sigmoid, scope='SigmoidAttention')

    if attention_type == 'softmax':
        return SimpleAttention(tf.nn.softmax, scope='SoftmaxAttention')

    if attention_type == 'content':
        return ContentAttention(tuple(data_shape[:2]))

    # Remaining variants share one constructor signature; pick the class first.
    if attention_type == 'gaussian':
        attention_factory = GaussianAttention
    elif attention_type == 'cauchy':
        attention_factory = CauchyAttention
    elif attention_type == 'grid':
        attention_factory = GridAttention
    else:
        raise TypeError('Unknown attention type: "{0}"'.format(attention_type))

    leading_dims = tuple(data_shape[:2])
    return attention_factory(leading_dims, read_size, write_size)
Exemplo n.º 3
0
 def read_size(self, data):
     """ Read size as a 1-tuple: the product of every dimension of data after the first """
     trailing_dims = data.get_shape().as_list()[1:]
     flat_size = reduce_prod(trailing_dims)
     return (flat_size, )
Exemplo n.º 4
0
 def read_size(self, data):
     """ Read size as a 1-tuple: the product of the read grid's size (data is not consulted) """
     grid_cells = reduce_prod(self.read_grid.size)
     return (grid_cells, )