def testMemoryMaskWithNonuniformLogits(self):
  memory = np.random.randn(2, 3, 10)
  logits = np.array([[-1, 1, 0], [-1, 1, 0]])
  mask = np.array([[True, True, True], [True, True, False]])

  # Calculate expected output: the masked slot gets zero weight and the
  # remaining weights are the softmax of the unmasked logits.
  expected_weights = np.exp(logits)
  expected_weights[1, 2] = 0
  expected_weights /= np.sum(expected_weights, axis=1, keepdims=True)
  expected_output = np.matmul(expected_weights[:, np.newaxis, :], memory)[:, 0]

  # Run attention model.
  attention = snt.AttentiveRead(
      lambda _: tf.constant(logits.reshape([6, 1]), dtype=tf.float32))
  attention_output = attention(
      memory=tf.constant(memory, dtype=tf.float32),
      query=tf.constant(np.zeros([2, 5]), dtype=tf.float32),
      memory_mask=tf.constant(mask))
  with self.test_session() as sess:
    actual = sess.run(attention_output)

  # Check output.
  self.assertAllClose(actual.read, expected_output)
  self.assertAllClose(actual.weights, expected_weights)

  # The actual logit for the masked value should be tiny. First check that the
  # unmasked logits are unchanged, then check the masked entry separately.
  masked_actual_weight_logits = np.array(actual.weight_logits, copy=True)
  masked_actual_weight_logits[1, 2] = logits[1, 2]
  self.assertAllClose(masked_actual_weight_logits, logits)
  self.assertLess(actual.weight_logits[1, 2], -1e35)
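# NOTE: A minimal sketch (not the Sonnet implementation) of how memory masking
# is conventionally realized: a very large negative value replaces the logits
# of masked slots, so the subsequent softmax gives them ~zero weight. This is
# consistent with the assertion above that the masked weight logit is below
# -1e35. `_masked_softmax_sketch` and the -1e38 constant are assumptions for
# illustration only.
def _masked_softmax_sketch(logits, mask, mask_value=-1e38):
  """Returns (masked_logits, weights) for a boolean mask over memory slots."""
  masked_logits = np.where(mask, logits, mask_value)
  weights = np.exp(masked_logits - np.max(masked_logits, axis=1, keepdims=True))
  weights /= np.sum(weights, axis=1, keepdims=True)
  return masked_logits, weights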
def testAttentionLogitsModuleShape(self, output_rank):
  # attention_logit_mod must produce a rank 2 Tensor.
  attention_mod = snt.AttentiveRead(ConstantZero(output_rank=output_rank))
  with self.assertRaises(snt.IncompatibleShapeError):
    attention_mod(self._memory, self._query)
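# NOTE: A hedged sketch of the shape contract checked above, inferred from
# testMemoryMaskWithNonuniformLogits (whose logits are reshaped to
# [batch_size * memory_size, 1]): the attention-logit module is expected to
# return a rank 2 Tensor with one logit per (example, memory slot) pair.
# `_conforming_logit_module_sketch` is illustrative and not part of the test.
def _conforming_logit_module_sketch(flattened_inputs):
  # Rank 2 output: one logit per row of the flattened input.
  return tf.zeros([tf.shape(flattened_inputs)[0], 1], dtype=flattened_inputs.dtype)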
def testWorksWithCommonModules(self, attention_logit_mod):
  # In the academic literature, attentive reads are most commonly implemented
  # with Linear or MLP modules. This integration test ensures that
  # AttentiveRead works safely with these.
  attention_mod = snt.AttentiveRead(attention_logit_mod)
  x = attention_mod(self._memory, self._query)
  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(x)
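# NOTE: A hedged usage sketch of the pattern exercised above: wiring
# AttentiveRead with a Linear module as the attention-logit function. The
# choice of snt.Linear(output_size=1) is illustrative and may differ from the
# modules this test is actually parameterized with.
def _attentive_read_with_linear_sketch(memory, query):
  attention_mod = snt.AttentiveRead(snt.Linear(output_size=1))
  return attention_mod(memory=memory, query=query)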
def setUp(self):
  super(AttentiveReadTest, self).setUp()

  self._batch_size = 3
  self._memory_size = 4
  self._memory_word_size = 1
  self._query_word_size = 2
  self._memory = tf.reshape(
      tf.cast(tf.range(0, 3 * 4 * 1), dtype=tf.float32), shape=[3, 4, 1])
  self._query = tf.reshape(
      tf.cast(tf.range(0, 3 * 2), dtype=tf.float32), shape=[3, 2])
  self._memory_mask = tf.convert_to_tensor(
      [
          [True, True, True, True],
          [True, True, True, False],
          [True, True, False, False],
      ],
      dtype=tf.bool)
  self._attention_logit_mod = ConstantZero()
  self._attention_mod = snt.AttentiveRead(self._attention_logit_mod)
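# NOTE: ConstantZero is defined elsewhere in this test file. The class below is
# only a minimal sketch of what such a helper might look like, inferred from
# its use above (ConstantZero(output_rank=r) should emit an all-zero Tensor of
# rank r); it is an assumption, not the original definition.
class _ConstantZeroSketch(snt.AbstractModule):
  """Outputs a single zero logit per row, at a configurable output rank."""

  def __init__(self, output_rank=2, name="constant_zero_sketch"):
    super(_ConstantZeroSketch, self).__init__(name=name)
    self._output_rank = output_rank

  def _build(self, inputs):
    # Keep the leading (output_rank - 1) dims of `inputs` and append a 1.
    output_shape = tf.concat(
        [tf.shape(inputs)[:self._output_rank - 1], [1]], axis=0)
    return tf.zeros(output_shape, dtype=inputs.dtype)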