Ejemplo n.º 1
0
  def testPrecedenceWeights(self):
    """Checks TemporalLinkage._precedence_weights on random inputs.

    Random previous precedence weights and write weights are scaled so that
    each (batch, head) row sums to less than 1 over the memory slots. Two
    rows are then overridden: batch 0 / head 1 writes nothing, and batch 1 /
    head 2 writes with a total weight of exactly 1. The result must lie in
    [0, 1]. The first special row must match the previous precedence, and the
    second must match the write weights.
    """
    batch_size, memory_size, num_writes = 7, 3, 5
    module = addressing.TemporalLinkage(
        memory_size=memory_size, num_writes=num_writes)

    # Draw the previous precedence first and the writes second (same RNG order).
    prev_precedence = np.random.rand(batch_size, num_writes, memory_size)
    writes = np.random.rand(batch_size, num_writes, memory_size)

    # Dividing by (row sum + 1) leaves every row summing to strictly below 1.
    writes /= writes.sum(axis=2, keepdims=True) + 1
    prev_precedence /= prev_precedence.sum(axis=2, keepdims=True) + 1

    idle = (0, 1)  # batch 0, head 1: no writing at all
    full = (1, 2)  # batch 1, head 2: write weights renormalised to sum to 1
    writes[idle] = 0
    writes[full] /= writes[full].sum()

    result_tensor = module._precedence_weights(
        prev_precedence_weights=tf.constant(prev_precedence),
        write_weights=tf.constant(writes))

    with self.test_session():
      result = result_tensor.eval()

    # Every precedence weight must stay inside [0, 1].
    self.assertGreaterEqual(result.min(), 0)
    self.assertLessEqual(result.max(), 1)

    # A head that writes nothing keeps its previous precedence.
    self.assertAllClose(result[idle], prev_precedence[idle])

    # A head whose writes sum to 1 takes the write weights as its precedence.
    self.assertAllClose(result[full], writes[full])
Ejemplo n.º 2
0
    def __init__(self,
                 memory_size=128,
                 word_size=20,
                 num_reads=1,
                 num_writes=1,
                 name='memory_access'):
        """Creates a MemoryAccess module.

        Args:
          memory_size: The number of memory slots (N in the DNC paper).
          word_size: The width of each memory slot (W in the DNC paper).
          num_reads: The number of read heads (R in the DNC paper).
          num_writes: The number of write heads (fixed at 1 in the paper).
          name: The name of the module.
        """
        super(MemoryAccess, self).__init__(name=name)

        # Record the configuration on the instance (assigned left to right).
        self._memory_size, self._word_size = memory_size, word_size
        self._num_reads, self._num_writes = num_reads, num_writes

        # One CosineWeights submodule per head group, built over words of
        # width `word_size`: write heads first, then read heads.
        self._write_content_weights_mod = addressing.CosineWeights(
            num_writes, word_size, name='write_content_weights')
        self._read_content_weights_mod = addressing.CosineWeights(
            num_reads, word_size, name='read_content_weights')

        # Temporal link tracking is per write head; the freeness tracker only
        # needs the number of memory slots.
        self._linkage = addressing.TemporalLinkage(memory_size, num_writes)
        self._freeness = addressing.Freeness(memory_size)
Ejemplo n.º 3
0
  def testModule(self):
    """Checks TemporalLinkage state updates and directional read weights.

    The module runs for several steps on random write weights. The last two
    steps force one-hot writes so that batch 0 records the transition 0->1 in
    head 0 and 3->2 in head 1. The test then checks the link-matrix
    invariants (bounded, zero diagonal, rows/columns summing to at most 1)
    and the recorded transitions. Finally it checks that forward and backward
    read weights follow those transitions.
    """
    batch_size = 7
    memory_size = 4
    num_reads = 11
    num_writes = 5
    module = addressing.TemporalLinkage(
        memory_size=memory_size, num_writes=num_writes)

    prev_link_in = tf.placeholder(
        tf.float32, (batch_size, num_writes, memory_size, memory_size))
    prev_precedence_weights_in = tf.placeholder(
        tf.float32, (batch_size, num_writes, memory_size))
    write_weights_in = tf.placeholder(tf.float32,
                                      (batch_size, num_writes, memory_size))

    # Start from an all-zero link matrix and all-zero precedence weights.
    state = addressing.TemporalLinkageState(
        link=np.zeros([batch_size, num_writes, memory_size, memory_size]),
        precedence_weights=np.zeros([batch_size, num_writes, memory_size]))

    calc_state = module(write_weights_in,
                        addressing.TemporalLinkageState(
                            link=prev_link_in,
                            precedence_weights=prev_precedence_weights_in))

    with self.test_session() as sess:
      num_steps = 5
      # `range` rather than the Python-2-only `xrange`. It iterates identically
      # here and needs no compat import to run on Python 3.
      for i in range(num_steps):
        write_weights = np.random.rand(batch_size, num_writes, memory_size)
        # Scale each (batch, head) row so it sums to strictly below 1.
        write_weights /= write_weights.sum(2, keepdims=True) + 1

        # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1
        if i == num_steps - 2:
          write_weights[0, 0, :] = util.one_hot(memory_size, 0)
          write_weights[0, 1, :] = util.one_hot(memory_size, 3)
        elif i == num_steps - 1:
          write_weights[0, 0, :] = util.one_hot(memory_size, 1)
          write_weights[0, 1, :] = util.one_hot(memory_size, 2)

        # Feed the previous step's numpy state back in as this step's input.
        state = sess.run(
            calc_state,
            feed_dict={
                prev_link_in: state.link,
                prev_precedence_weights_in: state.precedence_weights,
                write_weights_in: write_weights
            })

    # link should be bounded in range [0, 1]
    self.assertGreaterEqual(state.link.min(), 0)
    self.assertLessEqual(state.link.max(), 1)

    # link diagonal should be zero
    self.assertAllEqual(
        state.link[:, :, range(memory_size), range(memory_size)],
        np.zeros([batch_size, num_writes, memory_size]))

    # link rows and columns should sum to at most 1
    self.assertLessEqual(state.link.sum(2).max(), 1)
    self.assertLessEqual(state.link.sum(3).max(), 1)

    # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
    self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1))
    self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2))

    # Now test calculation of forward and backward read weights
    prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
    prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0)  # read 5, posn 0
    prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2)  # read 6, posn 2
    forward_read_weights = module.directional_read_weights(
        tf.constant(state.link),
        tf.constant(prev_read_weights, dtype=tf.float32),
        forward=True)
    backward_read_weights = module.directional_read_weights(
        tf.constant(state.link),
        tf.constant(prev_read_weights, dtype=tf.float32),
        forward=False)

    with self.test_session():
      forward_read_weights = forward_read_weights.eval()
      backward_read_weights = backward_read_weights.eval()

    # Check directional weights calculated correctly.
    self.assertAllEqual(
        forward_read_weights[0, 5, 0, :],  # read=5, write=0
        util.one_hot(memory_size, 1))
    self.assertAllEqual(
        backward_read_weights[0, 6, 1, :],  # read=6, write=1
        util.one_hot(memory_size, 3))