コード例 #1
0
def _erase_and_write(memory, address, reset_weights, values):
    """Apply an erase step followed by an additive write to external memory.

    Erase, for each memory location i:
        M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)

    Add, for each memory location i:
        M_t(i) = M_t'(i) + w_t(i) * a_t

    In these formulas `w` is `address` (the write weights), `e` is
    `reset_weights` and `a` is `values`.

    Args:
      memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
      address: 3-D tensor of shape `[batch_size, num_writes, memory_size]`.
      reset_weights: 3-D tensor of shape
        `[batch_size, num_writes, word_size]`.
      values: 3-D tensor of shape `[batch_size, num_writes, word_size]`.

    Returns:
      The updated memory. With the argument shapes above this is a 3-D
      tensor of shape `[batch_size, memory_size, word_size]`.
    """
    erase_scope_inputs = [memory, address, reset_weights]
    with tf.compat.v1.name_scope('erase_memory', values=erase_scope_inputs):
        # [batch, writes, memory, 1] * [batch, writes, 1, word] broadcasts
        # to one erase strength per write head, location and word element.
        address_col = tf.expand_dims(address, axis=3)
        erase_row = tf.expand_dims(reset_weights, axis=2)
        erase_strength = address_col * erase_row
        # Reduce over axis 1 (the write heads) so a single gate remains per
        # location; util.reduce_prod is presumably a product reduction,
        # matching the erase formula -- confirm against util.
        keep_gate = util.reduce_prod(1 - erase_strength, 1)
        memory *= keep_gate

    write_scope_inputs = [memory, address, values]
    with tf.compat.v1.name_scope('additive_write', values=write_scope_inputs):
        # adjoint_a transposes `address`, so the product sums w_t(i) * a_t
        # over the write heads for every memory location.
        written = tf.matmul(address, values, adjoint_a=True)
        memory += written

    return memory
コード例 #2
0
    def _usage_after_write(self, prev_usage, write_weights):
        """Compute the memory usage that results from a write.

        Args:
          prev_usage: tensor of shape `[batch_size, memory_size]`.
          write_weights: tensor of shape
            `[batch_size, num_writes, memory_size]`.

        Returns:
          New usage, a tensor of shape `[batch_size, memory_size]`.
        """
        with tf.name_scope('usage_after_write'):
            # Collapse the write heads (axis 1) into a single aggregate
            # write weight per memory location.
            untouched = util.reduce_prod(1 - write_weights, 1)
            combined_write = 1 - untouched
            # Only the currently unused fraction of a location can be
            # claimed by the write.
            unused = 1 - prev_usage
            return prev_usage + unused * combined_write
コード例 #3
0
    def _usage_after_read(self, prev_usage, free_gate, read_weights):
        """Compute the memory usage after reading from and freeing memory.

        Args:
          prev_usage: tensor of shape `[batch_size, memory_size]`.
          free_gate: tensor of shape `[batch_size, num_reads]` with entries
            in the range [0, 1] indicating the amount that locations read
            from can be freed.
          read_weights: tensor of shape
            `[batch_size, num_reads, memory_size]`.

        Returns:
          New usage, a tensor of shape `[batch_size, memory_size]`.
        """
        with tf.name_scope('usage_after_read'):
            # Add a trailing axis so each read head's gate broadcasts across
            # every memory location of that head's read weights.
            gate_per_head = tf.expand_dims(free_gate, axis=-1)
            freed = gate_per_head * read_weights
            # Reduce over axis 1 (the read heads) to get one retention
            # factor per memory location.
            retention = util.reduce_prod(1 - freed, 1, name='phi')
            return prev_usage * retention