Example #1
    def testReadWeights(self):
        memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) -
                       0.5)
        prev_read_weights = np.random.rand(BATCH_SIZE, NUM_READS, MEMORY_SIZE)
        prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + 1

        link = np.random.rand(BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE)
        # Row and column sums should be at most 1:
        link /= np.maximum(link.sum(2, keepdims=True), 1)
        link /= np.maximum(link.sum(3, keepdims=True), 1)

        # Query memory location 3, give the query a large strength, and select
        # the content-based read mode.
        read_content_keys = np.random.rand(BATCH_SIZE, NUM_READS, WORD_SIZE)
        read_content_keys[0, 0] = memory[0, 3]
        read_content_strengths = tf.constant(100.,
                                             shape=[BATCH_SIZE, NUM_READS],
                                             dtype=tf.float64)
        read_mode = np.random.rand(BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES)
        read_mode[0, 0, :] = util.one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES)
        inputs = {
            'read_content_keys': tf.constant(read_content_keys),
            'read_content_strengths': read_content_strengths,
            'read_mode': tf.constant(read_mode),
        }
        read_weights = self.module._read_weights(inputs, memory,
                                                 prev_read_weights, link)
        with self.test_session():
            read_weights = read_weights.eval()

        # The read weights for batch 0, read head 0 should be one-hot at
        # memory location 3.
        self.assertAllClose(read_weights[0, 0, :],
                            util.one_hot(MEMORY_SIZE, 3),
                            atol=1e-3)
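
Note that these snippets are lifted from a larger test class, so the imports (numpy as np, tensorflow as tf, the addressing and util modules) and the module-level constants BATCH_SIZE, MEMORY_SIZE, WORD_SIZE, NUM_READS and NUM_WRITES are defined elsewhere in that file. The util.one_hot helper used throughout is assumed to behave like this minimal numpy sketch:

import numpy as np

def one_hot(length, index):
    """Return a length-`length` float vector that is 1 at `index`, 0 elsewhere."""
    result = np.zeros(length)
    result[index] = 1
    return result

With that reading, util.one_hot(MEMORY_SIZE, 3) is the target read weighting concentrated entirely on memory location 3.
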
Example #2
    def testWriteAllocationWeights(self):
        batch_size = 7
        memory_size = 23
        num_writes = 5
        module = addressing.Freeness(memory_size)

        usage = np.random.rand(batch_size, memory_size)
        write_gates = np.random.rand(batch_size, num_writes)

        # Turn off the gates for heads 1 and 3 in batch 0. This doesn't scale
        # down their allocation weightings, but it does mean the usage is not
        # updated after those heads, so heads (1, 2) and heads (3, 4) should
        # get identical allocation weightings (while all other pairs differ).
        write_gates[0, 1] = 0
        write_gates[0, 3] = 0
        # and turn heads 0 and 2 on for full effect.
        write_gates[0, 0] = 1
        write_gates[0, 2] = 1

        # In batch 1, make one of the usages 0 and another almost 0, so that these
        # entries get most of the allocation weights for the first and second heads.
        usage[1] = usage[1] * 0.9 + 0.1  # make sure all entries are in [0.1, 1]
        usage[1][4] = 0  # write head 0 should get allocated to position 4
        usage[1][3] = 1e-4  # write head 1 should get allocated to position 3
        write_gates[1, 0] = 1  # write head 0 fully on
        write_gates[1, 1] = 1  # write head 1 fully on

        weights = module.write_allocation_weights(
            usage=tf.constant(usage),
            write_gates=tf.constant(write_gates),
            num_writes=num_writes)

        with self.test_session():
            weights = weights.eval()

        # Check that all weights are between 0 and 1
        self.assertGreaterEqual(weights.min(), 0)
        self.assertLessEqual(weights.max(), 1)

        # Check that weights sum to close to 1
        self.assertAllClose(np.sum(weights, axis=2),
                            np.ones([batch_size, num_writes]),
                            atol=1e-3)

        # Check the same / different allocation weight pairs as described above.
        self.assertGreater(
            np.abs(weights[0, 0, :] - weights[0, 1, :]).max(), 0.1)
        self.assertAllEqual(weights[0, 1, :], weights[0, 2, :])
        self.assertGreater(
            np.abs(weights[0, 2, :] - weights[0, 3, :]).max(), 0.1)
        self.assertAllEqual(weights[0, 3, :], weights[0, 4, :])

        self.assertAllClose(weights[1][0],
                            util.one_hot(memory_size, 4),
                            atol=1e-3)
        self.assertAllClose(weights[1][1],
                            util.one_hot(memory_size, 3),
                            atol=1e-3)
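
For context, write_allocation_weights is expected to implement the DNC-style "freeness" allocation: memory slots are sorted by usage, and each slot receives (1 - usage) times the product of the usages of all less-used slots, so a slot with zero usage soaks up essentially the whole weighting. A rough single-head numpy sketch of that rule (an illustration, not the library code):

import numpy as np

def allocation_weighting(usage):
    """Allocation weighting for one write head: favor the least-used slots."""
    order = np.argsort(usage)            # least-used slots first
    weights = np.zeros_like(usage)
    free_so_far = 1.0
    for slot in order:
        weights[slot] = (1.0 - usage[slot]) * free_so_far
        free_so_far *= usage[slot]
    return weights

With several write heads, the module presumably updates the usage between heads in proportion to each head's write gate, which is why a zero write gate leaves the usage, and therefore the next head's allocation weighting, unchanged in the test above.
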
Example #3
    def testWriteWeights(self):
        memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) -
                       0.5)
        usage = np.random.rand(BATCH_SIZE, MEMORY_SIZE)

        allocation_gate = np.random.rand(BATCH_SIZE, NUM_WRITES)
        write_gate = np.random.rand(BATCH_SIZE, NUM_WRITES)
        write_content_keys = np.random.rand(BATCH_SIZE, NUM_WRITES, WORD_SIZE)
        write_content_strengths = np.random.rand(BATCH_SIZE, NUM_WRITES)

        # Check that fully opening the allocation gate brings the write
        # weighting to the allocation weighting (which we control via usage).
        usage[:, 3] = 0
        allocation_gate[:, 0] = 1
        write_gate[:, 0] = 1

        inputs = {
            'allocation_gate': tf.constant(allocation_gate),
            'write_gate': tf.constant(write_gate),
            'write_content_keys': tf.constant(write_content_keys),
            'write_content_strengths': tf.constant(write_content_strengths)
        }

        weights = self.module._write_weights(inputs, tf.constant(memory),
                                             tf.constant(usage))

        with self.test_session():
            weights = weights.eval()

        # Check the weights sum to their target gating.
        self.assertAllClose(np.sum(weights, axis=2), write_gate, atol=5e-2)

        # Check that we fully allocated to row 3 (where usage was set to zero).
        weights_0_0_target = util.one_hot(MEMORY_SIZE, 3)
        self.assertAllClose(weights[0, 0], weights_0_0_target, atol=1e-3)
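
The expectations in this test follow from the usual DNC write-weighting combination, where the write gate scales an allocation-gate-controlled interpolation between the allocation weighting and the content-based weighting. A rough per-head sketch of that combination (illustrative only; the module's internals may differ in detail):

import numpy as np

def combine_write_weights(write_gate, allocation_gate,
                          allocation_weights, content_weights):
    """w_write = g_write * (g_alloc * a + (1 - g_alloc) * c)."""
    return write_gate * (allocation_gate * allocation_weights +
                         (1.0 - allocation_gate) * content_weights)

With write_gate and allocation_gate set to 1 for head 0, and usage zero only at row 3, the allocation weighting is one-hot at row 3, so weights[0, 0] should match util.one_hot(MEMORY_SIZE, 3), and each head's weights should sum to roughly its write gate.
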
Example #4
  def testModule(self):
    batch_size = 7
    memory_size = 4
    num_reads = 11
    num_writes = 5
    module = addressing.TemporalLinkage(
        memory_size=memory_size, num_writes=num_writes)

    prev_link_in = tf.placeholder(
        tf.float32, (batch_size, num_writes, memory_size, memory_size))
    prev_precedence_weights_in = tf.placeholder(
        tf.float32, (batch_size, num_writes, memory_size))
    write_weights_in = tf.placeholder(tf.float32,
                                      (batch_size, num_writes, memory_size))

    state = addressing.TemporalLinkageState(
        link=np.zeros([batch_size, num_writes, memory_size, memory_size]),
        precedence_weights=np.zeros([batch_size, num_writes, memory_size]))

    calc_state = module(write_weights_in,
                        addressing.TemporalLinkageState(
                            link=prev_link_in,
                            precedence_weights=prev_precedence_weights_in))

    with self.test_session() as sess:
      num_steps = 5
      for i in range(num_steps):
        write_weights = np.random.rand(batch_size, num_writes, memory_size)
        write_weights /= write_weights.sum(2, keepdims=True) + 1

        # In the final two steps, simulate writes that create the links
        # 0-->1 in head 0 and 3-->2 in head 1.
        if i == num_steps - 2:
          write_weights[0, 0, :] = util.one_hot(memory_size, 0)
          write_weights[0, 1, :] = util.one_hot(memory_size, 3)
        elif i == num_steps - 1:
          write_weights[0, 0, :] = util.one_hot(memory_size, 1)
          write_weights[0, 1, :] = util.one_hot(memory_size, 2)

        state = sess.run(
            calc_state,
            feed_dict={
                prev_link_in: state.link,
                prev_precedence_weights_in: state.precedence_weights,
                write_weights_in: write_weights
            })

    # link should be bounded in range [0, 1]
    self.assertGreaterEqual(state.link.min(), 0)
    self.assertLessEqual(state.link.max(), 1)

    # link diagonal should be zero
    self.assertAllEqual(
        state.link[:, :, range(memory_size), range(memory_size)],
        np.zeros([batch_size, num_writes, memory_size]))

    # link rows and columns should sum to at most 1
    self.assertLessEqual(state.link.sum(2).max(), 1)
    self.assertLessEqual(state.link.sum(3).max(), 1)

    # Check that the link records the batch-0 transitions: head 0: 0->1 and
    # head 1: 3->2.
    self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1))
    self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2))

    # Now test calculation of forward and backward read weights
    prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
    prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0)  # read 5, posn 0
    prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2)  # read 6, posn 2
    forward_read_weights = module.directional_read_weights(
        tf.constant(state.link),
        tf.constant(prev_read_weights, dtype=tf.float32),
        forward=True)
    backward_read_weights = module.directional_read_weights(
        tf.constant(state.link),
        tf.constant(prev_read_weights, dtype=tf.float32),
        forward=False)

    with self.test_session():
      forward_read_weights = forward_read_weights.eval()
      backward_read_weights = backward_read_weights.eval()

    # Check directional weights calculated correctly.
    self.assertAllEqual(
        forward_read_weights[0, 5, 0, :],  # read=5, write=0
        util.one_hot(memory_size, 1))
    self.assertAllEqual(
        backward_read_weights[0, 6, 1, :],  # read=6, write=1
        util.one_hot(memory_size, 3))
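
The final assertions rely on the standard temporal-linkage reading: link[i, j] records "location i was written immediately after location j" (built from the write weights and precedence weights, with the diagonal zeroed), so writing location 0 then 1 with head 0 leaves link[1, 0] close to 1. Forward read weights then follow the link from the previous read weighting, and backward read weights follow its transpose. A minimal per-write-head numpy sketch of that last step (my own illustration, assuming this convention):

import numpy as np

def directional_read_weights(link, prev_read_weights, forward=True):
    """link: [memory_size, memory_size]; prev_read_weights: [memory_size]."""
    matrix = link if forward else link.T
    return matrix.dot(prev_read_weights)

Under that convention, a previous read weighting of one_hot(0) pushed forward through head 0's link lands on one_hot(1), and a previous read weighting of one_hot(2) pushed backward through head 1's link lands on one_hot(3), which is exactly what the test checks.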