def write(self, data):
    """Write ``data`` onto the write grid and return the flattened result.

    ``data`` is projected with ``layers.linear`` to one value per write-grid
    cell and reshaped to ``(-1, size[1], size[0])``. It is then multiplied on
    the left by the transposed Y filter and on the right by the X filter
    obtained from ``self.get_filter``. The flattened product is scaled by
    ``1 / max(gamma, self.epsilon)``.

    Raises:
        ValueError: if ``self.write_grid`` is falsy (writing not supported).
    """
    if not self.write_grid:
        raise ValueError('Writing is not supported')

    # Filters for the write grid plus a per-write scale term ``gamma``.
    fx, fy, gamma = self.get_filter(data, self.write_grid, scope='write/filter')
    # Swap the last two axes of the Y filter (the perm implies fy is rank 3).
    fy_t = tf.transpose(fy, [0, 2, 1])

    grid_size = self.write_grid.size
    cells = layers.linear(data, reduce_prod(grid_size))
    # Note the axis order: size[1] rows by size[0] columns.
    cells = tf.reshape(cells, (-1, grid_size[1], grid_size[0]))

    right = tf.matmul(cells, fx)
    patch = tf.matmul(fy_t, right)

    # Clamp gamma from below by epsilon before inverting it, so the scale
    # stays finite (assuming self.epsilon > 0).
    inv_gamma = tf.reciprocal(tf.maximum(gamma, self.epsilon))
    return inv_gamma * layers.flatten(patch)
def create_attention(attention_type, data_shape, read_size=None, write_size=None):
    """Build the attention object selected by ``attention_type``.

    Args:
        attention_type: one of ``'none'``, ``'sigmoid'``, ``'softmax'``,
            ``'content'``, ``'gaussian'``, ``'cauchy'`` or ``'grid'``.
        data_shape: shape of the input data. ``reduce_prod(data_shape)`` is
            passed to ``NoAttention``; ``tuple(data_shape[:2])`` is passed to
            the content and grid-based attentions.
        read_size: forwarded only to the grid-based attentions
            (gaussian / cauchy / grid); ignored otherwise.
        write_size: forwarded only to the grid-based attentions; ignored
            otherwise.

    Returns:
        A newly constructed attention object.

    Raises:
        TypeError: if ``attention_type`` matches none of the names above.
    """
    def image_size():
        return tuple(data_shape[:2])

    # Zero-argument factories, so nothing (not even data_shape slicing) is
    # evaluated unless that attention type is actually selected.
    factories = (
        ('none', lambda: NoAttention(reduce_prod(data_shape))),
        ('sigmoid', lambda: SimpleAttention(tf.sigmoid, scope='SigmoidAttention')),
        ('softmax', lambda: SimpleAttention(tf.nn.softmax, scope='SoftmaxAttention')),
        ('content', lambda: ContentAttention(image_size())),
        ('gaussian', lambda: GaussianAttention(image_size(), read_size, write_size)),
        ('cauchy', lambda: CauchyAttention(image_size(), read_size, write_size)),
        ('grid', lambda: GridAttention(image_size(), read_size, write_size)),
    )

    for name, factory in factories:
        # Linear search with == (not a dict lookup) so unhashable values
        # fall through to the same TypeError message below.
        if attention_type == name:
            return factory()

    raise TypeError('Unknown attention type: "{0}"'.format(attention_type))
def read_size(self, data):
    """Return the read size for ``data`` as a one-element tuple.

    The first dimension of ``data``'s static shape is dropped (presumably the
    batch axis; confirm against callers), and ``reduce_prod`` is applied to
    the remaining dimensions.
    """
    static_dims = data.get_shape().as_list()
    per_item_dims = static_dims[1:]
    return (reduce_prod(per_item_dims),)
def read_size(self, data):
    """Return the read size as a one-element tuple.

    ``data`` is not used here. The size comes from applying ``reduce_prod``
    to ``self.read_grid.size``.
    """
    grid_dims = self.read_grid.size
    return (reduce_prod(grid_dims),)