def __init__(self,
             memory_size=128,
             word_size=20,
             num_reads=1,
             num_writes=1,
             name='memory_access'):
  """Creates a MemoryAccess module.

  Args:
    memory_size: The number of memory slots (N in the DNC paper).
    word_size: The width of each memory slot (W in the DNC paper).
    num_reads: The number of read heads (R in the DNC paper).
    num_writes: The number of write heads (fixed at 1 in the paper).
    name: The name of the module.
  """
  super(MemoryAccess, self).__init__(name=name)
  self._memory_size = memory_size
  self._word_size = word_size
  self._num_reads = num_reads
  self._num_writes = num_writes

  self._write_content_weights_mod = addressing.CosineWeights(
      num_writes, word_size, name='write_content_weights')
  self._read_content_weights_mod = addressing.CosineWeights(
      num_reads, word_size, name='read_content_weights')

  self._linkage = addressing.TemporalLinkage(memory_size, num_writes)
  self._freeness = addressing.Freeness(memory_size)
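# Reference sketch (an illustrative assumption, not part of this module):
# the CosineWeights modules constructed above implement the DNC paper's
# content-based addressing, C(M, k, beta) = softmax over memory slots of
# beta * cosine_similarity(k, M[i, :]). The hypothetical NumPy helper below
# restates that for a single batch and head, only to document the intent.
import numpy as np

def cosine_weights_reference(memory, key, strength, eps=1e-6):
  # memory: [memory_size, word_size]; key: [word_size]; strength: scalar.
  scores = strength * memory.dot(key) / (
      np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + eps)
  exp = np.exp(scores - scores.max())  # numerically stable softmax
  return exp / exp.sum()  # shape [memory_size], sums to 1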
def testPrecedenceWeights(self):
  batch_size = 7
  memory_size = 3
  num_writes = 5
  module = addressing.TemporalLinkage(
      memory_size=memory_size, num_writes=num_writes)

  prev_precedence_weights = np.random.rand(batch_size, num_writes,
                                           memory_size)
  write_weights = np.random.rand(batch_size, num_writes, memory_size)

  # These should sum to at most 1 for each write head in each batch.
  write_weights /= write_weights.sum(2, keepdims=True) + 1
  prev_precedence_weights /= prev_precedence_weights.sum(2, keepdims=True) + 1

  write_weights[0, 1, :] = 0  # batch 0 head 1: no writing
  write_weights[1, 2, :] /= write_weights[1, 2, :].sum()  # b1 h2: all writing

  precedence_weights = module._precedence_weights(
      prev_precedence_weights=tf.constant(prev_precedence_weights),
      write_weights=tf.constant(write_weights))

  with self.test_session():
    precedence_weights = precedence_weights.eval()

  # precedence weights should be bounded in range [0, 1]
  self.assertGreaterEqual(precedence_weights.min(), 0)
  self.assertLessEqual(precedence_weights.max(), 1)

  # no writing in batch 0, head 1
  self.assertAllClose(precedence_weights[0, 1, :],
                      prev_precedence_weights[0, 1, :])

  # all writing in batch 1, head 2
  self.assertAllClose(precedence_weights[1, 2, :], write_weights[1, 2, :])
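# Reference sketch (assumption: this restates the DNC paper's precedence
# update, p_t = (1 - sum_j w_t[j]) * p_{t-1} + w_t, which is what
# _precedence_weights is expected to compute; the NumPy helper is
# hypothetical and exists only to document the invariants tested above).
import numpy as np

def precedence_weights_reference(prev_precedence_weights, write_weights):
  # Inputs have shape [batch, writes, memory]; the write sum broadcasts
  # per head, so each head interpolates independently.
  write_sum = write_weights.sum(axis=2, keepdims=True)
  return (1 - write_sum) * prev_precedence_weights + write_weights

# When a head writes nowhere, write_sum is 0 and the previous precedence
# weights pass through unchanged; when its weights sum to 1, write_sum is 1
# and the new precedence equals write_weights. These are exactly the two
# cases asserted in testPrecedenceWeights.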
def testModule(self):
  batch_size = 7
  memory_size = 4
  num_reads = 11
  num_writes = 5
  module = addressing.TemporalLinkage(
      memory_size=memory_size, num_writes=num_writes)

  prev_link_in = tf.placeholder(
      tf.float32, (batch_size, num_writes, memory_size, memory_size))
  prev_precedence_weights_in = tf.placeholder(
      tf.float32, (batch_size, num_writes, memory_size))
  write_weights_in = tf.placeholder(
      tf.float32, (batch_size, num_writes, memory_size))

  state = addressing.TemporalLinkageState(
      link=np.zeros([batch_size, num_writes, memory_size, memory_size]),
      precedence_weights=np.zeros([batch_size, num_writes, memory_size]))

  calc_state = module(
      write_weights_in,
      addressing.TemporalLinkageState(
          link=prev_link_in,
          precedence_weights=prev_precedence_weights_in))

  with self.test_session() as sess:
    num_steps = 5
    for i in xrange(num_steps):
      write_weights = np.random.rand(batch_size, num_writes, memory_size)
      write_weights /= write_weights.sum(2, keepdims=True) + 1

      # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1
      if i == num_steps - 2:
        write_weights[0, 0, :] = util.one_hot(memory_size, 0)
        write_weights[0, 1, :] = util.one_hot(memory_size, 3)
      elif i == num_steps - 1:
        write_weights[0, 0, :] = util.one_hot(memory_size, 1)
        write_weights[0, 1, :] = util.one_hot(memory_size, 2)

      state = sess.run(
          calc_state,
          feed_dict={
              prev_link_in: state.link,
              prev_precedence_weights_in: state.precedence_weights,
              write_weights_in: write_weights
          })

  # link should be bounded in range [0, 1]
  self.assertGreaterEqual(state.link.min(), 0)
  self.assertLessEqual(state.link.max(), 1)

  # link diagonal should be zero
  self.assertAllEqual(
      state.link[:, :, range(memory_size), range(memory_size)],
      np.zeros([batch_size, num_writes, memory_size]))

  # link rows and columns should sum to at most 1
  self.assertLessEqual(state.link.sum(2).max(), 1)
  self.assertLessEqual(state.link.sum(3).max(), 1)

  # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
  self.assertAllEqual(state.link[0, 0, :, 0], util.one_hot(memory_size, 1))
  self.assertAllEqual(state.link[0, 1, :, 3], util.one_hot(memory_size, 2))

  # Now test calculation of forward and backward read weights
  prev_read_weights = np.random.rand(batch_size, num_reads, memory_size)
  prev_read_weights[0, 5, :] = util.one_hot(memory_size, 0)  # read 5, posn 0
  prev_read_weights[0, 6, :] = util.one_hot(memory_size, 2)  # read 6, posn 2

  forward_read_weights = module.directional_read_weights(
      tf.constant(state.link),
      tf.constant(prev_read_weights, dtype=tf.float32),
      forward=True)
  backward_read_weights = module.directional_read_weights(
      tf.constant(state.link),
      tf.constant(prev_read_weights, dtype=tf.float32),
      forward=False)

  with self.test_session():
    forward_read_weights = forward_read_weights.eval()
    backward_read_weights = backward_read_weights.eval()

  # Check directional weights calculated correctly.
  self.assertAllEqual(
      forward_read_weights[0, 5, 0, :],  # read=5, write=0
      util.one_hot(memory_size, 1))
  self.assertAllEqual(
      backward_read_weights[0, 6, 1, :],  # read=6, write=1
      util.one_hot(memory_size, 3))
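# Reference sketch (assumption: this restates the DNC paper's temporal
# linkage update, L_t[i, j] = (1 - w_t[i] - w_t[j]) * L_{t-1}[i, j]
# + w_t[i] * p_{t-1}[j] with a zeroed diagonal, and the directional reads
# f = L w_r (forward) and b = L^T w_r (backward). The helpers are
# hypothetical, single-batch and single-head, and only document what
# testModule asserts).
import numpy as np

def link_reference(prev_link, prev_precedence_weights, write_weights):
  # prev_link: [memory, memory]; the other inputs: [memory].
  w_i = write_weights[:, np.newaxis]  # write strength at destination i
  w_j = write_weights[np.newaxis, :]  # write strength at source j
  link = ((1 - w_i - w_j) * prev_link +
          w_i * prev_precedence_weights[np.newaxis, :])
  np.fill_diagonal(link, 0)  # a slot is never its own successor
  return link

def directional_read_weights_reference(link, prev_read_weights, forward):
  # Following links forward multiplies by L; backward uses its transpose.
  matrix = link if forward else link.T
  return matrix.dot(prev_read_weights)

# With a one-hot write to slot 0 followed by a one-hot write to slot 1,
# link[1, 0] becomes 1, so a read focused on slot 0 moves forward to slot 1;
# likewise, after head 1 writes 3 then 2, a read focused on slot 2 moves
# backward to slot 3 -- the transitions checked at the end of testModule.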