Example #1
    def __init__(self,
                 memory_size=128,
                 word_size=20,
                 num_reads=1,
                 num_writes=1,
                 name='memory_access'):
        """Creates a MemoryAccess module.

    Args:
      memory_size: The number of memory slots (N in the DNC paper).
      word_size: The width of each memory slot (W in the DNC paper)
      num_reads: The number of read heads (R in the DNC paper).
      num_writes: The number of write heads (fixed at 1 in the paper).
      name: The name of the module.
    """
        super(MemoryAccess, self).__init__(name=name)
        self._memory_size = memory_size
        self._word_size = word_size
        self._num_reads = num_reads
        self._num_writes = num_writes

        self._write_content_weights_mod = addressing.CosineWeights(
            num_writes, word_size, name='write_content_weights')
        self._read_content_weights_mod = addressing.CosineWeights(
            num_reads, word_size, name='read_content_weights')

        self._linkage = addressing.TemporalLinkage(memory_size, num_writes)
        self._freeness = addressing.Freeness(memory_size)
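
For orientation, a minimal construction sketch. This assumes DeepMind's open-source dnc package layout, where the class above is exposed as dnc.access.MemoryAccess; the parameter values are only illustrative.

# Construction-only sketch; nothing here calls the module on inputs.
from dnc import access

memory_access = access.MemoryAccess(
    memory_size=128,  # N: number of memory slots
    word_size=20,     # W: width of each slot
    num_reads=4,      # R: number of read heads
    num_writes=1)     # the paper fixes this at 1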
Example #2
  def testPrecedenceWeights(self):
    batch_size = 7
    memory_size = 3
    num_writes = 5
    module = addressing.TemporalLinkage(
        memory_size=memory_size, num_writes=num_writes)

    prev_precedence_weights = np.random.rand(batch_size, num_writes,
                                             memory_size)
    write_weights = np.random.rand(batch_size, num_writes, memory_size)

    # These should sum to at most 1 for each write head in each batch.
    write_weights /= write_weights.sum(2, keepdims=True) + 1
    prev_precedence_weights /= prev_precedence_weights.sum(2, keepdims=True) + 1

    write_weights[0, 1, :] = 0  # batch 0 head 1: no writing
    write_weights[1, 2, :] /= write_weights[1, 2, :].sum()  # b1 h2: all writing

    precedence_weights = module._precedence_weights(
        prev_precedence_weights=tf.constant(prev_precedence_weights),
        write_weights=tf.constant(write_weights))

    with self.test_session():
      precedence_weights = precedence_weights.eval()

    # precedence weights should be bounded in range [0, 1]
    self.assertGreaterEqual(precedence_weights.min(), 0)
    self.assertLessEqual(precedence_weights.max(), 1)

    # no writing in batch 0, head 1
    self.assertAllClose(precedence_weights[0, 1, :],
                        prev_precedence_weights[0, 1, :])

    # all writing in batch 1, head 2
    self.assertAllClose(precedence_weights[1, 2, :], write_weights[1, 2, :])
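
The closing assertions follow from the precedence update in the DNC paper, p_t = (1 - sum_i w_t[i]) * p_{t-1} + w_t. As a reference, a NumPy sketch of that update (the function name is ours, not part of the module's API):

import numpy as np

def precedence_weights_reference(prev_precedence_weights, write_weights):
    # p_t = (1 - sum_i w_t[i]) * p_{t-1} + w_t, per batch and write head.
    # With no writing (w_t = 0) the old precedence survives unchanged;
    # with full writing (sum_i w_t[i] = 1) it is replaced by w_t.
    write_sum = write_weights.sum(axis=2, keepdims=True)
    return (1 - write_sum) * prev_precedence_weights + write_weights

Since both inputs lie in [0, 1] with per-head sums of at most 1, the result also stays in [0, 1], which is exactly what the first two assertions check.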
Example #3
    def testModule(self):
        batch_size = 7
        memory_size = 4
        num_reads = 11
        num_writes = 5
        module = addressing.TemporalLinkage(memory_size=memory_size,
                                            num_writes=num_writes)

        prev_link_in = tf.placeholder(
            tf.float32, (batch_size, num_writes, memory_size, memory_size))
        prev_precedence_weights_in = tf.placeholder(
            tf.float32, (batch_size, num_writes, memory_size))
        write_weights_in = tf.placeholder(
            tf.float32, (batch_size, num_writes, memory_size))

        state = addressing.TemporalLinkageState(
            link=np.zeros([batch_size, num_writes, memory_size, memory_size]),
            precedence_weights=np.zeros([batch_size, num_writes, memory_size]))

        calc_state = module(
            write_weights_in,
            addressing.TemporalLinkageState(
                link=prev_link_in,
                precedence_weights=prev_precedence_weights_in))

        with self.test_session() as sess:
            num_steps = 5
            for i in range(num_steps):
                write_weights = np.random.rand(batch_size, num_writes,
                                               memory_size)
                write_weights /= write_weights.sum(2, keepdims=True) + 1

                # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1
                if i == num_steps - 2:
                    write_weights[0, 0, :] = util.one_hot(memory_size, 0)
                    write_weights[0, 1, :] = util.one_hot(memory_size, 3)
                elif i == num_steps - 1:
                    write_weights[0, 0, :] = util.one_hot(memory_size, 1)
                    write_weights[0, 1, :] = util.one_hot(memory_size, 2)

                state = sess.run(calc_state,
                                 feed_dict={
                                     prev_link_in: state.link,
                                     prev_precedence_weights_in:
                                     state.precedence_weights,
                                     write_weights_in: write_weights
                                 })

        # link should be bounded in range [0, 1]
        self.assertGreaterEqual(state.link.min(), 0)
        self.assertLessEqual(state.link.max(), 1)

        # link diagonal should be zero
        self.assertAllEqual(
            state.link[:, :, range(memory_size), range(memory_size)],
            np.zeros([batch_size, num_writes, memory_size]))

        # link rows and columns should sum to at most 1
        self.assertLessEqual(state.link.sum(2).max(), 1)
        self.assertLessEqual(state.link.sum(3).max(), 1)

        # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
        self.assertAllEqual(state.link[0, 0, :, 0],
                            util.one_hot(memory_size, 1))
        self.assertAllEqual(state.link[0, 1, :, 3],
                            util.one_hot(memory_size, 2))

        # Now test calculation of forward and backward read weights
        prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
        prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0)  # read 5, posn 0
        prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2)  # read 6, posn 2
        forward_read_weights = module.directional_read_weights(
            tf.constant(state.link),
            tf.constant(prev_read_weights, dtype=tf.float32),
            forward=True)
        backward_read_weights = module.directional_read_weights(
            tf.constant(state.link),
            tf.constant(prev_read_weights, dtype=tf.float32),
            forward=False)

        with self.test_session():
            forward_read_weights = forward_read_weights.eval()
            backward_read_weights = backward_read_weights.eval()

        # Check directional weights calculated correctly.
        self.assertAllEqual(
            forward_read_weights[0, 5, 0, :],  # read=5, write=0
            util.one_hot(memory_size, 1))
        self.assertAllEqual(
            backward_read_weights[0, 6, 1, :],  # read=6, write=1
            util.one_hot(memory_size, 3))
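
For reference, the behaviour this test pins down matches the DNC paper's equations: the link update L_t[i, j] = (1 - w_t[i] - w_t[j]) * L_{t-1}[i, j] + w_t[i] * p_{t-1}[j] with a zeroed diagonal, and directional reads that multiply the previous read weights by the link matrix (forward) or its transpose (backward). A NumPy sketch under the shapes used above, link [batch, writes, N, N] and prev_read_weights [batch, reads, N]; the function names are ours, not the module's API:

import numpy as np

def link_update_reference(prev_link, prev_precedence_weights, write_weights):
    # L_t[i, j] = (1 - w_t[i] - w_t[j]) * L_{t-1}[i, j] + w_t[i] * p_{t-1}[j];
    # for write weights summing to at most 1 this keeps entries in [0, 1].
    # The diagonal (self-links) is zeroed.
    w_i = write_weights[..., :, None]            # [batch, writes, N, 1]
    w_j = write_weights[..., None, :]            # [batch, writes, 1, N]
    p_j = prev_precedence_weights[..., None, :]  # [batch, writes, 1, N]
    link = (1 - w_i - w_j) * prev_link + w_i * p_j
    n = link.shape[-1]
    link[..., np.arange(n), np.arange(n)] = 0
    return link

def directional_read_weights_reference(link, prev_read_weights, forward):
    # Forward reads follow links out of the previously read slots:
    #   f[b, r, w, i] = sum_j link[b, w, i, j] * prev_read_weights[b, r, j]
    # so after the writes 0 -> 1 above, reading slot 0 forwards yields
    # one_hot(1). Backward reads use the transposed link matrix instead.
    mat = link if forward else np.transpose(link, (0, 1, 3, 2))
    return np.einsum('bwij,brj->brwi', mat, prev_read_weights)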