import numpy as np
import tensorflow as tf
import sonnet as snt


def learned_position_encoding(inputs, mask, embed_dim):
    """Learned (trainable) position embeddings, zeroed out on padded positions."""
    T = tf.shape(inputs)[1]
    outputs = tf.range(T)                                   # (T_q)
    outputs = tf.expand_dims(outputs, 0)                    # (1, T_q)
    outputs = tf.tile(outputs, [tf.shape(inputs)[0], 1])    # (N, T_q)
    # `embed_seq` is an embedding helper defined elsewhere in this codebase.
    outputs = embed_seq(outputs, T, embed_dim, zero_pad=False, scale=False)
    return tf.expand_dims(tf.to_float(mask), -1) * outputs

def _erase_and_write(memory, address, reset_weights, values):
    """Module to erase and write in the external memory.

    Erase operation:
      M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)

    Add operation:
      M_t(i) = M_t'(i) + w_t(i) * a_t

    where e are the reset_weights, w the write weights and a the values.

    Args:
      memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
      address: 3-D tensor `[batch_size, num_writes, memory_size]`.
      reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
      values: 3-D tensor `[batch_size, num_writes, word_size]`.

    Returns:
      The updated memory, a 3-D tensor of shape
      `[batch_size, memory_size, word_size]`.
    """
    with tf.name_scope('erase_memory', values=[memory, address, reset_weights]):
        expand_address = tf.expand_dims(address, 3)
        reset_weights = tf.expand_dims(reset_weights, 2)
        weighted_resets = expand_address * reset_weights
        reset_gate = tf.reduce_prod(1 - weighted_resets, [1])
        memory *= reset_gate

    with tf.name_scope('additive_write', values=[memory, address, values]):
        add_matrix = tf.matmul(address, values, adjoint_a=True)
        memory += add_matrix

    return memory

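# Minimal shape-check sketch for `_erase_and_write` (an illustration, not part of
# the original module); it assumes TF 1.x graph mode as used throughout this file.
# The sizes are arbitrary: batch_size=2, num_writes=1, memory_size=4, word_size=3.
def _erase_and_write_shape_check():
    memory = tf.ones([2, 4, 3])                                # [batch, memory, word]
    address = tf.nn.softmax(tf.random_normal([2, 1, 4]))       # write weights w_t
    reset_weights = tf.sigmoid(tf.random_normal([2, 1, 3]))    # erase vector e_t
    values = tf.random_normal([2, 1, 3])                       # add vector a_t
    new_memory = _erase_and_write(memory, address, reset_weights, values)
    return new_memory                                          # shape [2, 4, 3]
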
def sinusoidal_position_encoding(inputs, mask, repr_dim):
    """Fixed sinusoidal position encoding, zeroed out on padded positions."""
    T = tf.shape(inputs)[1]
    pos = tf.reshape(tf.range(0.0, tf.to_float(T), dtype=tf.float32), [-1, 1])
    i = np.arange(0, repr_dim, 2, np.float32)
    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])
    enc = tf.expand_dims(
        tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1), 0)
    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1]) * tf.expand_dims(
        tf.to_float(mask), -1)

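# Usage sketch for `sinusoidal_position_encoding` (illustrative values, not part
# of the original code): a batch of 2 padded sequences of length 5 with a 0/1
# padding mask, and repr_dim=8 matching the embedding dimension.
def _sinusoidal_position_encoding_example():
    token_ids = tf.constant([[5, 9, 2, 0, 0],
                             [7, 1, 4, 3, 0]])         # 0 marks padding
    mask = tf.sign(token_ids)                          # (N, T) padding mask
    embeddings = tf.random_normal([2, 5, 8])           # (N, T, repr_dim)
    pos_enc = sinusoidal_position_encoding(embeddings, mask, 8)
    return embeddings + pos_enc                        # positions added, pads zeroed
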
def write_allocation_weights(self, usage, write_gates, num_writes):
    """Calculates freeness-based locations for writing to.

    This finds unused memory by ranking the memory locations by usage, for
    each write head. (For more than one write head, we use a "simulated new
    usage" which takes into account the fact that the previous write head
    will increase the usage in that area of the memory.)

    Args:
      usage: A tensor of shape `[batch_size, memory_size]` representing
        current memory usage.
      write_gates: A tensor of shape `[batch_size, num_writes]` with values in
        the range [0, 1] indicating how much each write head does writing
        based on the address returned here (and hence how much usage
        increases).
      num_writes: The number of write heads to calculate write weights for.

    Returns:
      tensor of shape `[batch_size, num_writes, memory_size]` containing the
      freeness-based write locations. Note that this isn't scaled by
      `write_gate`; this scaling must be applied externally.
    """
    with tf.name_scope('write_allocation_weights'):
        # expand gatings over memory locations
        write_gates = tf.expand_dims(write_gates, -1)

        allocation_weights = []
        for i in range(num_writes):
            allocation_weights.append(self._allocation(usage))
            # update usage to take into account writing to this new allocation
            usage += ((1 - usage) * write_gates[:, i, :] * allocation_weights[i])

        # Pack the allocation weights for the write heads into one tensor.
        return tf.stack(allocation_weights, axis=1)

def _read_weights(self, inputs, memory, prev_read_weights, link):
    """Calculates read weights for each read head.

    The read weights are a combination of following the link graphs in the
    forward or backward directions from the previous read position, and doing
    content-based lookup. The interpolation between these different modes is
    done by `inputs['read_mode']`.

    Args:
      inputs: Controls for this access module. This contains the content-based
        keys to lookup, and the weightings for the different read modes.
      memory: A tensor of shape `[batch_size, memory_size, word_size]`
        containing the current memory contents to do content-based lookup.
      prev_read_weights: A tensor of shape `[batch_size, num_reads,
        memory_size]` containing the previous read locations.
      link: A tensor of shape `[batch_size, num_writes, memory_size,
        memory_size]` containing the temporal write transition graphs.

    Returns:
      A tensor of shape `[batch_size, num_reads, memory_size]` containing the
      read weights for each read head.
    """
    with tf.name_scope(
        'read_weights', values=[inputs, memory, prev_read_weights, link]):
        # c_t^{r, i} - The content weightings for each read head.
        content_weights = self._read_content_weights_mod(
            memory, inputs['read_content_keys'],
            inputs['read_content_strengths'])

        # Calculates f_t^i and b_t^i.
        forward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, forward=True)
        backward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, forward=False)

        backward_mode = inputs['read_mode'][:, :, :self._num_writes]
        forward_mode = (
            inputs['read_mode'][:, :, self._num_writes:2 * self._num_writes])
        content_mode = inputs['read_mode'][:, :, 2 * self._num_writes]

        read_weights = (
            tf.expand_dims(content_mode, 2) * content_weights +
            tf.reduce_sum(tf.expand_dims(forward_mode, 3) * forward_weights, 2) +
            tf.reduce_sum(tf.expand_dims(backward_mode, 3) * backward_weights, 2))

        return read_weights

def __init__(
    self,
    learning_rate,
    num_layers,
    size,
    size_layer,
    output_size,
    forget_bias=0.1,
    lambda_coeff=0.5,
):
    # Despite the name, the cell used is a GRU.
    def lstm_cell(size_layer):
        return tf.nn.rnn_cell.GRUCell(size_layer)

    rnn_cells = tf.nn.rnn_cell.MultiRNNCell(
        [lstm_cell(size_layer) for _ in range(num_layers)],
        state_is_tuple=False,
    )
    self.X = tf.placeholder(tf.float32, (None, None, size))
    self.Y = tf.placeholder(tf.float32, (None, output_size))
    # `forget_bias` is used as the dropout keep probability.
    drop = tf.contrib.rnn.DropoutWrapper(
        rnn_cells, output_keep_prob=forget_bias
    )
    self.hidden_layer = tf.placeholder(
        tf.float32, (None, num_layers * size_layer)
    )
    _, last_state = tf.nn.dynamic_rnn(
        drop, self.X, initial_state=self.hidden_layer, dtype=tf.float32
    )
    # Reparameterization trick: z = mu + sigma * epsilon, with
    # z_log_sigma = log(sigma).
    self.z_mean = tf.layers.dense(last_state, size)
    self.z_log_sigma = tf.layers.dense(last_state, size)
    epsilon = tf.random_normal(tf.shape(self.z_log_sigma))
    self.z_vector = self.z_mean + tf.exp(self.z_log_sigma) * epsilon

    with tf.variable_scope('decoder', reuse=False):
        rnn_cells_dec = tf.nn.rnn_cell.MultiRNNCell(
            [lstm_cell(size_layer) for _ in range(num_layers)],
            state_is_tuple=False,
        )
        drop_dec = tf.contrib.rnn.DropoutWrapper(
            rnn_cells_dec, output_keep_prob=forget_bias
        )
        # Prepend the latent vector as the first decoder input step.
        x = tf.concat([tf.expand_dims(self.z_vector, axis=1), self.X], axis=1)
        self.outputs, self.last_state = tf.nn.dynamic_rnn(
            drop_dec, x, initial_state=last_state, dtype=tf.float32
        )

    # Predict from the last time step of the decoder outputs.
    self.logits = tf.layers.dense(self.outputs[:, -1], output_size)
    self.lambda_coeff = lambda_coeff
    self.kl_loss = -0.5 * tf.reduce_sum(
        1.0 + 2 * self.z_log_sigma - self.z_mean ** 2
        - tf.exp(2 * self.z_log_sigma), 1)
    self.kl_loss = tf.scalar_mul(self.lambda_coeff, self.kl_loss)
    # Reconstruction error plus the weighted KL term.
    self.cost = tf.reduce_mean(
        tf.square(self.Y - self.logits)) + tf.reduce_mean(self.kl_loss)
    self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)

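# Instantiation sketch for the VAE-style GRU model above (illustrative only).
# The enclosing class is not shown in this file, so `Model` is a hypothetical
# name for whatever class owns this __init__; hyperparameters are arbitrary.
def _build_vae_gru_example():
    tf.reset_default_graph()
    model = Model(                                   # hypothetical class name
        learning_rate=1e-3,
        num_layers=2,
        size=8,            # input feature dimension per timestep
        size_layer=64,     # GRU units per layer
        output_size=1,
    )
    init_state = np.zeros((1, 2 * 64), dtype=np.float32)  # matches hidden_layer
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    cost, _ = sess.run(
        [model.cost, model.optimizer],
        feed_dict={
            model.X: np.random.randn(1, 20, 8).astype(np.float32),
            model.Y: np.random.randn(1, 1).astype(np.float32),
            model.hidden_layer: init_state,
        },
    )
    return cost
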
def multihead_attn(queries, keys, q_masks, k_masks, future_binding,
                   num_units, num_heads):
    """Multi-head attention with key/query padding masks and optional causal
    (future-blinding) masking, followed by a residual connection and layer norm."""
    T_q = tf.shape(queries)[1]
    T_k = tf.shape(keys)[1]

    # Linear projections; keys and values share one projection, then split.
    Q = tf.layers.dense(queries, num_units, name='Q')
    K_V = tf.layers.dense(keys, 2 * num_units, name='K_V')
    K, V = tf.split(K_V, 2, -1)

    # Split heads along the feature axis and fold them into the batch axis.
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)

    # Scaled dot-product attention scores.
    align = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))
    align = align / np.sqrt(K_.get_shape().as_list()[-1])

    # Mask out padded key positions.
    paddings = tf.fill(tf.shape(align), float('-inf'))
    key_masks = tf.tile(k_masks, [num_heads, 1])
    key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, T_q, 1])
    align = tf.where(tf.equal(key_masks, 0), paddings, align)

    if future_binding:
        # Causal mask: each position may only attend to itself and the past.
        lower_tri = tf.ones([T_q, T_k])
        lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()
        masks = tf.tile(tf.expand_dims(lower_tri, 0), [tf.shape(align)[0], 1, 1])
        align = tf.where(tf.equal(masks, 0), paddings, align)

    align = tf.nn.softmax(align)

    # Zero out attention coming from padded query positions.
    query_masks = tf.to_float(q_masks)
    query_masks = tf.tile(query_masks, [num_heads, 1])
    query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, T_k])
    align *= query_masks

    outputs = tf.matmul(align, V_)
    # Restore the head dimension, then apply residual connection and layer norm.
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)
    outputs += queries
    outputs = layer_norm(outputs)
    return outputs

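# Self-attention usage sketch for `multihead_attn` (illustrative values only).
# It relies on the `layer_norm` helper referenced above being defined elsewhere
# in this codebase; the 0/1 masks mark non-padded positions.
def _multihead_attn_example():
    token_ids = tf.constant([[4, 8, 1, 0, 0],
                             [3, 2, 9, 6, 0]])        # 0 marks padding
    masks = tf.sign(token_ids)                        # (N, T)
    x = tf.random_normal([2, 5, 64])                  # (N, T, num_units)
    return multihead_attn(
        queries=x, keys=x,
        q_masks=masks, k_masks=masks,
        future_binding=True,                          # causal self-attention
        num_units=64, num_heads=8)
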
def decoder_block(inp, n_hidden, filter_size):
    """Causal 1-D convolution block: left-pads in time so position t only sees
    inputs up to t. `filter_size` is a 2-tuple, e.g. (kernel_width, 1)."""
    inp = tf.expand_dims(inp, 2)
    inp = tf.pad(inp, [[0, 0], [filter_size[0] - 1, 0], [0, 0], [0, 0]])
    conv = tf.layers.conv2d(
        inp, n_hidden, filter_size, padding='VALID', activation=None)
    conv = tf.squeeze(conv, 2)
    return conv

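# Usage sketch for `decoder_block` (illustrative shapes, not part of the original
# code): a batch of 2 sequences of length 10 with 32 channels, convolved causally
# with kernel width 3.
def _decoder_block_example():
    x = tf.random_normal([2, 10, 32])          # (batch, time, channels)
    y = decoder_block(x, n_hidden=64, filter_size=(3, 1))
    return y                                   # (2, 10, 64), no future leakage
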
def position_encoding(inputs):
    """Sinusoidal position encoding using the static last dimension of `inputs`;
    returned for all positions without masking."""
    T = tf.shape(inputs)[1]
    repr_dim = inputs.get_shape()[-1].value
    pos = tf.reshape(tf.range(0.0, tf.to_float(T), dtype=tf.float32), [-1, 1])
    i = np.arange(0, repr_dim, 2, np.float32)
    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])
    enc = tf.expand_dims(
        tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1), 0)
    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1])

def _write_weights(self, inputs, memory, usage):
    """Calculates the memory locations to write to.

    This uses a combination of content-based lookup and finding an unused
    location in memory, for each write head.

    Args:
      inputs: Collection of inputs to the access module, including controls for
        how to choose memory writing, such as the content to look-up and the
        weighting between content-based and allocation-based addressing.
      memory: A tensor of shape `[batch_size, memory_size, word_size]`
        containing the current memory contents.
      usage: Current memory usage, which is a tensor of shape `[batch_size,
        memory_size]`, used for allocation-based addressing.

    Returns:
      tensor of shape `[batch_size, num_writes, memory_size]` indicating where
      to write to (if anywhere) for each write head.
    """
    with tf.name_scope('write_weights', values=[inputs, memory, usage]):
        # c_t^{w, i} - The content-based weights for each write head.
        write_content_weights = self._write_content_weights_mod(
            memory, inputs['write_content_keys'],
            inputs['write_content_strengths'])

        # a_t^i - The allocation weights for each write head.
        write_allocation_weights = self._freeness.write_allocation_weights(
            usage=usage,
            write_gates=(inputs['allocation_gate'] * inputs['write_gate']),
            num_writes=self._num_writes)

        # Expands gates over memory locations.
        allocation_gate = tf.expand_dims(inputs['allocation_gate'], -1)
        write_gate = tf.expand_dims(inputs['write_gate'], -1)

        # w_t^{w, i} - The write weightings for each write head.
        return write_gate * (allocation_gate * write_allocation_weights +
                             (1 - allocation_gate) * write_content_weights)

def _link(self, prev_link, prev_precedence_weights, write_weights):
    """Calculates the new link graphs.

    For each write head, the link is a directed graph (represented by a matrix
    with entries in range [0, 1]) whose vertices are the memory locations, and
    an edge indicates temporal ordering of writes.

    Args:
      prev_link: A tensor of shape `[batch_size, num_writes, memory_size,
        memory_size]` representing the previous link graphs for each write
        head.
      prev_precedence_weights: A tensor of shape `[batch_size, num_writes,
        memory_size]` which is the previous "aggregated" write weights for
        each write head.
      write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
        containing the new locations in memory written to.

    Returns:
      A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
      containing the new link graphs for each write head.
    """
    with tf.name_scope('link'):
        batch_size = prev_link.get_shape()[0].value
        write_weights_i = tf.expand_dims(write_weights, 3)
        write_weights_j = tf.expand_dims(write_weights, 2)
        prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2)
        prev_link_scale = 1 - write_weights_i - write_weights_j
        new_link = write_weights_i * prev_precedence_weights_j
        link = prev_link_scale * prev_link + new_link
        # Return the link with the diagonal set to zero, to remove self-looping
        # edges.
        return tf.matrix_set_diag(
            link,
            tf.zeros([batch_size, self._num_writes, self._memory_size],
                     dtype=link.dtype))

def weighted_softmax(activations, strengths, strengths_op):
    """Returns softmax over activations multiplied by positive strengths.

    Args:
      activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of
        activations to be transformed. Softmax is taken over the last dimension.
      strengths: A tensor of shape `[batch_size, num_heads]` containing
        strengths to multiply by the activations prior to the softmax.
      strengths_op: An operation to transform strengths before softmax.

    Returns:
      A tensor of same shape as `activations` with weighted softmax applied.
    """
    transformed_strengths = tf.expand_dims(strengths_op(strengths), -1)
    sharp_activations = activations * transformed_strengths
    softmax = snt.BatchApply(module_or_op=tf.nn.softmax)
    return softmax(sharp_activations)

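# Usage sketch for `weighted_softmax` (illustrative shapes; uses the Sonnet 1.x
# `snt` import above). `tf.nn.softplus` keeps the transformed strengths positive,
# so larger strengths produce a more sharply peaked distribution.
def _weighted_softmax_example():
    activations = tf.random_normal([2, 3, 16])   # [batch, num_heads, memory_size]
    strengths = tf.random_normal([2, 3])         # raw, unconstrained strengths
    return weighted_softmax(activations, strengths, tf.nn.softplus)
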
def _usage_after_read(self, prev_usage, free_gate, read_weights):
    """Calculates the new usage after reading and freeing from memory.

    Args:
      prev_usage: tensor of shape `[batch_size, memory_size]`.
      free_gate: tensor of shape `[batch_size, num_reads]` with entries in the
        range [0, 1] indicating the amount that locations read from can be
        freed.
      read_weights: tensor of shape `[batch_size, num_reads, memory_size]`.

    Returns:
      New usage, a tensor of shape `[batch_size, memory_size]`.
    """
    with tf.name_scope('usage_after_read'):
        free_gate = tf.expand_dims(free_gate, -1)
        free_read_weights = free_gate * read_weights
        phi = tf.reduce_prod(1 - free_read_weights, [1], name='phi')
        return prev_usage * phi

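# Numeric illustration of the retention vector phi computed above (not part of
# the original module): with one read head, free_gate = 1.0 and read weights
# [0.5, 0.5, 0.0], phi = prod(1 - f * w) = [0.5, 0.5, 1.0], so usage of the two
# locations that were read (and freed) is halved while the third is unchanged.
def _usage_after_read_numeric_check():
    prev_usage = np.array([[1.0, 1.0, 0.3]], np.float32)        # [batch, memory]
    free_gate = np.array([[1.0]], np.float32)                   # [batch, num_reads]
    read_weights = np.array([[[0.5, 0.5, 0.0]]], np.float32)    # [batch, reads, memory]
    phi = np.prod(1 - free_gate[..., None] * read_weights, axis=1)
    return prev_usage * phi                                      # [[0.5, 0.5, 0.3]]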