# The test methods below are excerpted from an AttentiveReadTest test case;
# the surrounding class definition is omitted. They assume:
import numpy as np
import sonnet as snt
import tensorflow as tf


  def testMemoryMaskWithNonuniformLogits(self):
    memory = np.random.randn(2, 3, 10)
    logits = np.array([[-1, 1, 0], [-1, 1, 0]])
    mask = np.array([[True, True, True], [True, True, False]])

    # Calculate expected output.
    expected_weights = np.exp(logits)
    expected_weights[1, 2] = 0
    expected_weights /= np.sum(expected_weights, axis=1, keepdims=True)
    expected_output = np.matmul(expected_weights[:, np.newaxis, :],
                                memory)[:, 0]

    # Run attention model.
    attention = snt.AttentiveRead(
        lambda _: tf.constant(logits.reshape([6, 1]), dtype=tf.float32))
    attention_output = attention(
        memory=tf.constant(memory, dtype=tf.float32),
        query=tf.constant(np.zeros([2, 5]), dtype=tf.float32),
        memory_mask=tf.constant(mask))
    with self.test_session() as sess:
      actual = sess.run(attention_output)

    # Check output.
    self.assertAllClose(actual.read, expected_output)
    self.assertAllClose(actual.weights, expected_weights)
    # The logit for the masked entry should be a huge negative number. First
    # check that every other entry still matches the input logits, then check
    # the masked entry separately.
    masked_actual_weight_logits = np.array(actual.weight_logits, copy=True)
    masked_actual_weight_logits[1, 2] = logits[1, 2]
    self.assertAllClose(masked_actual_weight_logits, logits)
    self.assertLess(actual.weight_logits[1, 2], -1e35)
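The expected values above are a hand-rolled masked softmax. For reference, a minimal NumPy sketch of that computation; the helper name and the -1e38 fill value are illustrative assumptions rather than AttentiveRead's actual internals (the test only asserts that the masked logit is below -1e35).

def masked_softmax(logits, mask, fill_value=-1e38):
  # Push masked logits to a huge negative value so they get ~zero weight,
  # then zero them out explicitly before normalizing.
  filled = np.where(mask, logits, fill_value)
  weights = np.exp(filled - np.max(filled, axis=-1, keepdims=True))
  weights = np.where(mask, weights, 0.0)
  return weights / np.sum(weights, axis=-1, keepdims=True)

# E.g. masked_softmax(logits, mask)[1, 2] == 0.0 for the arrays in the test.
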
Example #2
  def testAttentionLogitsModuleShape(self, output_rank):
    # attention_logit_mod must produce a rank 2 Tensor.
    attention_mod = snt.AttentiveRead(
        ConstantZero(output_rank=output_rank))
    with self.assertRaises(snt.IncompatibleShapeError):
      attention_mod(self._memory, self._query)
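ConstantZero, used above, is a test helper that is not shown in this excerpt. A hedged sketch of what such a module might look like, assuming Sonnet 1's snt.AbstractModule API; the constructor signature is inferred from the call above and the body is an assumption.

class ConstantZero(snt.AbstractModule):
  """Ignores its input and emits an all-zeros tensor of a configurable rank."""

  def __init__(self, output_rank=2, name="constant_zero"):
    super(ConstantZero, self).__init__(name=name)
    self._output_rank = output_rank

  def _build(self, inputs):
    # Keep the leading (batch-like) dimension and pad with size-1 trailing
    # dimensions, so any output_rank other than 2 should trip AttentiveRead's
    # rank check.
    shape = tf.stack([tf.shape(inputs)[0]] + [1] * (self._output_rank - 1))
    return tf.zeros(shape, dtype=inputs.dtype)
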
Example #3
  def testWorksWithCommonModules(self, attention_logit_mod):
    # In the academic literature, attentive reads are most commonly implemented
    # with Linear or MLP modules. This integration test ensures that
    # AttentiveRead works safely with these.
    attention_mod = snt.AttentiveRead(attention_logit_mod)
    x = attention_mod(self._memory, self._query)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(x)
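The attention_logit_mod parameter above would typically be an instance along these lines; the sizes are illustrative assumptions, not necessarily what the test suite parameterizes over.

# A trailing size of 1 matches the rank-2, one-logit-per-memory-slot shape
# produced by the lambda in the first example.
linear_attention = snt.AttentiveRead(snt.Linear(output_size=1))
mlp_attention = snt.AttentiveRead(snt.nets.MLP(output_sizes=[16, 1]))
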
  def setUp(self):
    super(AttentiveReadTest, self).setUp()

    self._batch_size = 3
    self._memory_size = 4
    self._memory_word_size = 1
    self._query_word_size = 2
    self._memory = tf.reshape(
        tf.cast(tf.range(0, 3 * 4 * 1), dtype=tf.float32), shape=[3, 4, 1])
    self._query = tf.reshape(
        tf.cast(tf.range(0, 3 * 2), dtype=tf.float32), shape=[3, 2])
    self._memory_mask = tf.convert_to_tensor(
        [
            [True, True, True, True],
            [True, True, True, False],
            [True, True, False, False],
        ],
        dtype=tf.bool)
    self._attention_logit_mod = ConstantZero()
    self._attention_mod = snt.AttentiveRead(self._attention_logit_mod)
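
  # A hedged sketch (not part of the original test) of how these fixtures
  # combine, following the call pattern from the first example.
  def exampleAttentiveRead(self):
    # Hypothetical helper: with ConstantZero's uniform logits, the read is an
    # average over the unmasked memory slots of each batch element.
    return self._attention_mod(
        self._memory, self._query, memory_mask=self._memory_mask)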