def testReadWeights(self):
  memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5)
  prev_read_weights = np.random.rand(BATCH_SIZE, NUM_READS, MEMORY_SIZE)
  prev_read_weights /= prev_read_weights.sum(2, keepdims=True) + 1
  link = np.random.rand(BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE)
  # Row and column sums should be at most 1:
  link /= np.maximum(link.sum(2, keepdims=True), 1)
  link /= np.maximum(link.sum(3, keepdims=True), 1)

  # Query the memory with the contents of location 3, using a large query
  # strength, and select the content-based read mode.
  read_content_keys = np.random.rand(BATCH_SIZE, NUM_READS, WORD_SIZE)
  read_content_keys[0, 0] = memory[0, 3]
  read_content_strengths = tf.constant(
      100., shape=[BATCH_SIZE, NUM_READS], dtype=tf.float64)
  read_mode = np.random.rand(BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES)
  read_mode[0, 0, :] = util.one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES)
  inputs = {
      'read_content_keys': tf.constant(read_content_keys),
      'read_content_strengths': read_content_strengths,
      'read_mode': tf.constant(read_mode),
  }
  read_weights = self.module._read_weights(inputs, memory, prev_read_weights,
                                           link)
  with self.test_session():
    read_weights = read_weights.eval()

  # read_weights for batch 0, read head 0 should be memory location 3.
  self.assertAllClose(
      read_weights[0, 0, :], util.one_hot(MEMORY_SIZE, 3), atol=1e-3)
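
# For reference, a minimal NumPy sketch (not the module's implementation) of
# the content-based addressing the test above relies on: cosine similarity
# against a key, sharpened by a strength and normalised with a softmax. With a
# very large strength and a key copied from location 3, the weighting collapses
# to approximately a one-hot on location 3. The helper name is illustrative
# only; e.g. _example_content_weights(memory[0], memory[0, 3], 100.) should be
# close to util.one_hot(MEMORY_SIZE, 3).
def _example_content_weights(memory, key, strength):
  """memory: [memory_size, word_size]; key: [word_size]; strength: scalar."""
  norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-6
  similarity = memory.dot(key) / norms
  scores = strength * similarity
  scores -= scores.max()  # numerical stability before exponentiating
  weights = np.exp(scores)
  return weights / weights.sum()
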
def testWriteAllocationWeights(self):
  batch_size = 7
  memory_size = 23
  num_writes = 5
  module = addressing.Freeness(memory_size)

  usage = np.random.rand(batch_size, memory_size)
  write_gates = np.random.rand(batch_size, num_writes)

  # Turn off gates for heads 1 and 3 in batch 0. This doesn't scale down the
  # weighting, but it means that the usage doesn't change, so we should get
  # the same allocation weightings for heads (1, 2) and (3, 4) (all other
  # pairs being different).
  write_gates[0, 1] = 0
  write_gates[0, 3] = 0
  # and turn heads 0 and 2 on for full effect.
  write_gates[0, 0] = 1
  write_gates[0, 2] = 1

  # In batch 1, make one of the usages 0 and another almost 0, so that these
  # entries get most of the allocation weights for the first and second heads.
  usage[1] = usage[1] * 0.9 + 0.1  # make sure all entries are in [0.1, 1]
  usage[1][4] = 0  # write head 0 should get allocated to position 4
  usage[1][3] = 1e-4  # write head 1 should get allocated to position 3
  write_gates[1, 0] = 1  # write head 0 fully on
  write_gates[1, 1] = 1  # write head 1 fully on

  weights = module.write_allocation_weights(
      usage=tf.constant(usage),
      write_gates=tf.constant(write_gates),
      num_writes=num_writes)

  with self.test_session():
    weights = weights.eval()

  # Check that all weights are between 0 and 1.
  self.assertGreaterEqual(weights.min(), 0)
  self.assertLessEqual(weights.max(), 1)

  # Check that weights sum to close to 1.
  self.assertAllClose(
      np.sum(weights, axis=2), np.ones([batch_size, num_writes]), atol=1e-3)

  # Check the same / different allocation weight pairs as described above.
  self.assertGreater(
      np.abs(weights[0, 0, :] - weights[0, 1, :]).max(), 0.1)
  self.assertAllEqual(weights[0, 1, :], weights[0, 2, :])
  self.assertGreater(
      np.abs(weights[0, 2, :] - weights[0, 3, :]).max(), 0.1)
  self.assertAllEqual(weights[0, 3, :], weights[0, 4, :])

  self.assertAllClose(weights[1][0], util.one_hot(memory_size, 4), atol=1e-3)
  self.assertAllClose(weights[1][1], util.one_hot(memory_size, 3), atol=1e-3)
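
# A hedged sketch of the allocation weighting from the DNC paper that this
# test exercises: the least-used slot receives almost all of the weight. This
# is a single-head, single-batch illustration; write_allocation_weights applies
# it once per write head, updating the usage in between, which is why head 1
# in batch 1 lands on position 3 after head 0 has claimed position 4. The
# helper name is illustrative only.
def _example_allocation_weights(usage):
  """usage: [memory_size] array in [0, 1]; returns allocation weights."""
  phi = np.argsort(usage)  # slot indices ordered from least to most used
  sorted_usage = usage[phi]
  # Exclusive cumulative product of the usages of all less-used slots.
  prod_less_used = np.cumprod(np.concatenate(([1.0], sorted_usage[:-1])))
  sorted_allocation = (1 - sorted_usage) * prod_less_used
  allocation = np.zeros_like(usage)
  allocation[phi] = sorted_allocation
  return allocation
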
def testWriteWeights(self):
  memory = 10 * (np.random.rand(BATCH_SIZE, MEMORY_SIZE, WORD_SIZE) - 0.5)
  usage = np.random.rand(BATCH_SIZE, MEMORY_SIZE)

  allocation_gate = np.random.rand(BATCH_SIZE, NUM_WRITES)
  write_gate = np.random.rand(BATCH_SIZE, NUM_WRITES)
  write_content_keys = np.random.rand(BATCH_SIZE, NUM_WRITES, WORD_SIZE)
  write_content_strengths = np.random.rand(BATCH_SIZE, NUM_WRITES)

  # Check that turning the allocation gate fully on brings the write weights to
  # the allocation weighting (which we control here by controlling the usage).
  usage[:, 3] = 0
  allocation_gate[:, 0] = 1
  write_gate[:, 0] = 1

  inputs = {
      'allocation_gate': tf.constant(allocation_gate),
      'write_gate': tf.constant(write_gate),
      'write_content_keys': tf.constant(write_content_keys),
      'write_content_strengths': tf.constant(write_content_strengths)
  }

  weights = self.module._write_weights(inputs,
                                       tf.constant(memory),
                                       tf.constant(usage))

  with self.test_session():
    weights = weights.eval()

  # Check the weights sum to their target gating.
  self.assertAllClose(np.sum(weights, axis=2), write_gate, atol=5e-2)

  # Check that we fully allocated to memory location 3.
  weights_0_0_target = util.one_hot(MEMORY_SIZE, 3)
  self.assertAllClose(weights[0, 0], weights_0_0_target, atol=1e-3)
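
# A hedged sketch (single write head, scalar gates) of the gating the
# assertions above assume, following the DNC paper's write weighting
# w = write_gate * (allocation_gate * allocation + (1 - allocation_gate) *
# content); the module's _write_weights batches this over heads. With
# write_gate = allocation_gate = 1 it reduces to the allocation weighting,
# which is why weights[0, 0] is expected to be one-hot on location 3. The
# helper name is illustrative only.
def _example_write_weights(write_gate, allocation_gate, allocation_weights,
                           content_weights):
  """Scalar gates in [0, 1]; weight vectors of shape [memory_size]."""
  return write_gate * (allocation_gate * allocation_weights +
                       (1 - allocation_gate) * content_weights)
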
def testModule(self):
  batch_size = 7
  memory_size = 4
  num_reads = 11
  num_writes = 5
  module = addressing.TemporalLinkage(
      memory_size=memory_size, num_writes=num_writes)

  prev_link_in = tf.placeholder(
      tf.float32, (batch_size, num_writes, memory_size, memory_size))
  prev_precedence_weights_in = tf.placeholder(
      tf.float32, (batch_size, num_writes, memory_size))
  write_weights_in = tf.placeholder(tf.float32,
                                    (batch_size, num_writes, memory_size))

  state = addressing.TemporalLinkageState(
      link=np.zeros([batch_size, num_writes, memory_size, memory_size]),
      precedence_weights=np.zeros([batch_size, num_writes, memory_size]))

  calc_state = module(write_weights_in,
                      addressing.TemporalLinkageState(
                          link=prev_link_in,
                          precedence_weights=prev_precedence_weights_in))

  with self.test_session() as sess:
    num_steps = 5
    for i in xrange(num_steps):
      write_weights = np.random.rand(batch_size, num_writes, memory_size)
      write_weights /= write_weights.sum(2, keepdims=True) + 1

      # Simulate (in the final steps) link 0-->1 in head 0 and 3-->2 in head 1.
      if i == num_steps - 2:
        write_weights[0, 0, :] = util.one_hot(memory_size, 0)
        write_weights[0, 1, :] = util.one_hot(memory_size, 3)
      elif i == num_steps - 1:
        write_weights[0, 0, :] = util.one_hot(memory_size, 1)
        write_weights[0, 1, :] = util.one_hot(memory_size, 2)

      state = sess.run(
          calc_state,
          feed_dict={
              prev_link_in: state.link,
              prev_precedence_weights_in: state.precedence_weights,
              write_weights_in: write_weights
          })

  # link should be bounded in range [0, 1]
  self.assertGreaterEqual(state.link.min(), 0)
  self.assertLessEqual(state.link.max(), 1)

  # link diagonal should be zero
  self.assertAllEqual(
      state.link[:, :, range(memory_size), range(memory_size)],
      np.zeros([batch_size, num_writes, memory_size]))

  # link rows and columns should sum to at most 1
  self.assertLessEqual(state.link.sum(2).max(), 1)
  self.assertLessEqual(state.link.sum(3).max(), 1)

  # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
  self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1))
  self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2))

  # Now test calculation of forward and backward read weights.
  prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
  prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0)  # read 5, posn 0
  prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2)  # read 6, posn 2
  forward_read_weights = module.directional_read_weights(
      tf.constant(state.link),
      tf.constant(prev_read_weights, dtype=tf.float32),
      forward=True)
  backward_read_weights = module.directional_read_weights(
      tf.constant(state.link),
      tf.constant(prev_read_weights, dtype=tf.float32),
      forward=False)

  with self.test_session():
    forward_read_weights = forward_read_weights.eval()
    backward_read_weights = backward_read_weights.eval()

  # Check directional weights calculated correctly.
  self.assertAllEqual(
      forward_read_weights[0, 5, 0, :],  # read=5, write=0
      util.one_hot(memory_size, 1))
  self.assertAllEqual(
      backward_read_weights[0, 6, 1, :],  # read=6, write=1
      util.one_hot(memory_size, 3))
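
# A hedged sketch (single batch element, single write head) of the directional
# read weights checked above, following the DNC paper's f = L w (forward) and
# b = L^T w (backward); directional_read_weights batches this over read and
# write heads. The helper name is illustrative only; e.g.
# _example_directional_read_weights(state.link[0, 0], prev_read_weights[0, 5],
# forward=True) should match util.one_hot(memory_size, 1).
def _example_directional_read_weights(link, prev_read_weights, forward):
  """link: [memory_size, memory_size]; prev_read_weights: [memory_size]."""
  matrix = link if forward else link.T
  return matrix.dot(prev_read_weights)
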