def __init__(self,
             memory_size=128,
             word_size=20,
             num_reads=1,
             num_writes=1,
             name='memory_access'):
  """Creates a MemoryAccess module.

  Args:
    memory_size: The number of memory slots (N in the DNC paper).
    word_size: The width of each memory slot (W in the DNC paper).
    num_reads: The number of read heads (R in the DNC paper).
    num_writes: The number of write heads (fixed at 1 in the paper).
    name: The name of the module.
  """
  super(MemoryAccess, self).__init__(name=name)
  self._memory_size = memory_size
  self._word_size = word_size
  self._num_reads = num_reads
  self._num_writes = num_writes

  self._write_content_weights_mod = addressing.CosineWeights(
      num_writes, word_size, name='write_content_weights')
  self._read_content_weights_mod = addressing.CosineWeights(
      num_reads, word_size, name='read_content_weights')

  self._linkage = addressing.TemporalLinkage(memory_size, num_writes)
  self._freeness = addressing.Freeness(memory_size)
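# A minimal NumPy sketch of the content-based weighting that the
# CosineWeights modules above compute: a softmax over cosine similarities
# sharpened by softplus(strength). The formula is the one checked in
# testValues below; cosine_weights_ref is a hypothetical reference helper,
# not part of this module, and it assumes non-zero memory rows and keys.
import numpy as np


def cosine_weights_ref(memory, keys, strengths):
  """Reference content weighting of shape [batch, num_heads, memory_size].

  Args:
    memory: [batch, memory_size, word_size] array.
    keys: [batch, num_heads, word_size] array.
    strengths: [batch, num_heads] array.
  """
  mem_norm = memory / np.linalg.norm(memory, axis=2, keepdims=True)
  key_norm = keys / np.linalg.norm(keys, axis=2, keepdims=True)
  # Cosine similarity of every head's key against every memory row.
  similarity = np.einsum('bhw,bmw->bhm', key_norm, mem_norm)
  logits = np.log1p(np.exp(strengths))[..., np.newaxis] * similarity
  logits -= logits.max(axis=2, keepdims=True)  # for numerical stability
  exp = np.exp(logits)
  return exp / exp.sum(axis=2, keepdims=True)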
def testDivideByZero(self):
  batch_size = 5
  num_heads = 4
  memory_size = 10
  word_size = 2

  module = addressing.CosineWeights(num_heads, word_size)
  keys = tf.random_normal([batch_size, num_heads, word_size])
  strengths = tf.random_normal([batch_size, num_heads])

  # First row of memory is non-zero to concentrate attention on this
  # location. Remaining rows are all zero.
  first_row_ones = tf.ones([batch_size, 1, word_size], dtype=tf.float32)
  remaining_zeros = tf.zeros(
      [batch_size, memory_size - 1, word_size], dtype=tf.float32)
  mem = tf.concat((first_row_ones, remaining_zeros), 1)

  output = module(mem, keys, strengths)
  gradients = tf.gradients(output, [mem, keys, strengths])

  with self.test_session() as sess:
    output, gradients = sess.run([output, gradients])
    self.assertFalse(np.any(np.isnan(output)))
    self.assertFalse(np.any(np.isnan(gradients[0])))
    self.assertFalse(np.any(np.isnan(gradients[1])))
    self.assertFalse(np.any(np.isnan(gradients[2])))
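# The NaN-free outputs and gradients checked above depend on the norms
# inside addressing.CosineWeights being stabilised so that all-zero memory
# rows never cause a division by zero. A sketch of such a guard, reusing
# the file's numpy import; the exact epsilon value and its placement in
# the real implementation are assumptions here:


def stable_vector_norms(m, epsilon=1e-6):
  """Row-wise L2 norms that stay strictly positive even for zero rows."""
  return np.sqrt(np.sum(m * m, axis=2, keepdims=True) + epsilon)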
def testValues(self):
  batch_size = 5
  num_heads = 4
  memory_size = 10
  word_size = 2

  mem_data = np.random.randn(batch_size, memory_size, word_size)
  np.copyto(mem_data[0, 0], [1, 2])
  np.copyto(mem_data[0, 1], [3, 4])
  np.copyto(mem_data[0, 2], [5, 6])

  keys_data = np.random.randn(batch_size, num_heads, word_size)
  np.copyto(keys_data[0, 0], [5, 6])
  np.copyto(keys_data[0, 1], [1, 2])
  np.copyto(keys_data[0, 2], [5, 6])
  np.copyto(keys_data[0, 3], [3, 4])
  strengths_data = np.random.randn(batch_size, num_heads)

  module = addressing.CosineWeights(num_heads, word_size)
  mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size])
  keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size])
  strengths = tf.placeholder(tf.float32, [batch_size, num_heads])
  weights = module(mem, keys, strengths)

  with self.test_session() as sess:
    result = sess.run(
        weights,
        feed_dict={mem: mem_data, keys: keys_data, strengths: strengths_data})

  # Manually checks results against the expected weighting: a softmax over
  # cosine similarities, sharpened by softplus(strength).
  strengths_softplus = np.log(1 + np.exp(strengths_data))
  similarity = np.zeros(memory_size)

  for b in xrange(batch_size):
    for h in xrange(num_heads):
      key = keys_data[b, h]
      key_norm = np.linalg.norm(key)

      # Cosine similarity between this head's key and every memory row.
      for m in xrange(memory_size):
        row = mem_data[b, m]
        similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row))

      similarity = np.exp(similarity * strengths_softplus[b, h])
      similarity /= similarity.sum()

      self.assertAllClose(result[b, h], similarity, atol=1e-4, rtol=1e-4)
def testShape(self):
  batch_size = 5
  num_heads = 3
  memory_size = 7
  word_size = 2

  module = addressing.CosineWeights(num_heads, word_size)
  mem = tf.placeholder(tf.float32, [batch_size, memory_size, word_size])
  keys = tf.placeholder(tf.float32, [batch_size, num_heads, word_size])
  strengths = tf.placeholder(tf.float32, [batch_size, num_heads])
  weights = module(mem, keys, strengths)

  self.assertTrue(weights.get_shape().is_compatible_with(
      [batch_size, num_heads, memory_size]))