def _erase_and_write(memory, address, reset_weights, values):
    """Module to erase and write in the external memory.

    Erase operation:
      M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)

    Add operation:
      M_t(i) = M_t'(i) + w_t(i) * a_t

    where e are the reset_weights, w the write weights and a the values.
    With several write heads, the per-head erase factors are combined with
    `util.reduce_prod` over the `num_writes` axis. The per-head additive
    contributions are summed by the matmul over `num_writes`.

    Args:
      memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
      address: 3-D tensor `[batch_size, num_writes, memory_size]`.
      reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
      values: 3-D tensor `[batch_size, num_writes, word_size]`.

    Returns:
      3-D tensor of shape `[batch_size, memory_size, word_size]`: the updated
      memory, with the same shape as the `memory` argument.
    """
    with tf.compat.v1.name_scope('erase_memory',
                                 values=[memory, address, reset_weights]):
        # [batch_size, num_writes, memory_size, 1]
        expand_address = tf.expand_dims(address, 3)
        # [batch_size, num_writes, 1, word_size]
        reset_weights = tf.expand_dims(reset_weights, 2)
        # Per-head outer product w(i) * e:
        # [batch_size, num_writes, memory_size, word_size].
        weighted_resets = expand_address * reset_weights
        # Fold the heads (axis 1) with the project helper util.reduce_prod,
        # presumably a product over that axis — TODO confirm against util.
        # Result: [batch_size, memory_size, word_size].
        reset_gate = util.reduce_prod(1 - weighted_resets, 1)
        memory *= reset_gate

    with tf.compat.v1.name_scope('additive_write',
                                 values=[memory, address, values]):
        # address^T @ values sums every head's write into each location:
        # [batch_size, memory_size, num_writes] x [batch_size, num_writes,
        # word_size] -> [batch_size, memory_size, word_size].
        # adjoint_a also conjugates; that is a no-op for real-valued weights.
        add_matrix = tf.matmul(address, values, adjoint_a=True)
        memory += add_matrix

    return memory
def _usage_after_write(self, prev_usage, write_weights):
    """Calculates the memory usage after the write heads have written.

    The write weights of all heads are merged into a single per-location
    weight w = 1 - prod_k(1 - w_k). Usage u then moves towards 1 by that
    amount: u' = u + (1 - u) * w.

    Args:
      prev_usage: tensor of shape `[batch_size, memory_size]`.
      write_weights: tensor of shape `[batch_size, num_writes, memory_size]`.

    Returns:
      New usage, a tensor of shape `[batch_size, memory_size]`.
    """
    with tf.name_scope('usage_after_write'):
        # Fraction of each location left untouched by every head, folded over
        # the num_writes axis (axis 1) by the project helper util.reduce_prod.
        untouched = util.reduce_prod(1 - write_weights, 1)
        combined_write = 1 - untouched
        headroom = 1 - prev_usage
        return prev_usage + headroom * combined_write
def _usage_after_read(self, prev_usage, free_gate, read_weights):
    """Calculates the memory usage after read locations have been freed.

    Usage is scaled by a retention factor phi = prod_r(1 - f_r * w_r),
    where r ranges over the read heads. f_r is that head's free gate and w_r
    its read weighting.

    Args:
      prev_usage: tensor of shape `[batch_size, memory_size]`.
      free_gate: tensor of shape `[batch_size, num_reads]` with entries in the
        range [0, 1] indicating the amount that locations read from can be
        freed.
      read_weights: tensor of shape `[batch_size, num_reads, memory_size]`.

    Returns:
      New usage, a tensor of shape `[batch_size, memory_size]`.
    """
    with tf.name_scope('usage_after_read'):
        # Add a trailing axis so each head's gate broadcasts over memory_size.
        gate = tf.expand_dims(free_gate, -1)
        freed = gate * read_weights
        # Fold the num_reads axis (axis 1) with the project helper.
        retention = util.reduce_prod(1 - freed, 1, name='phi')
        return prev_usage * retention