Example #1
import numpy as np
import tensorflow as tf
from numpy.testing import assert_almost_equal
# SequenceBatch is the project's (values, mask) container; its import path
# depends on the surrounding package.

    def test(self):
        values = tf.constant([
            [1, -8, 5],
            [0, 2, 7],
            [2, -8, 6],
        ], dtype=tf.float32)

        float_mask = tf.constant([
            [1, 1, 1],
            [0, 0, 1],
            [1, 1, 0],
        ], dtype=tf.float32)

        bool_mask = tf.constant([
            [True, True, True],
            [False, False, True],
            [True, True, False],
        ], dtype=tf.bool)

        ninf = float('-inf')
        correct = np.array([
            [1, -8, 5],
            [ninf, ninf, 7],
            [2, -8, ninf],
        ], dtype=np.float32)

        seq_batch0 = SequenceBatch(values, float_mask)
        seq_batch1 = SequenceBatch(values, bool_mask)

        with tf.Session():  # graph-mode evaluation (pre-TF-2 Session API)
            assert_almost_equal(seq_batch0.with_pad_value(ninf).values.eval(), correct)
            assert_almost_equal(seq_batch1.with_pad_value(ninf).values.eval(), correct)
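The test pins down the contract of with_pad_value: entries where the mask is 0/False are replaced by the pad value, and float and boolean masks behave identically. For readers without the project source, here is a minimal sketch of a SequenceBatch that would satisfy this test; the class name matches the snippets, but the implementation is an assumption, not the project's actual code.

import tensorflow as tf

class SequenceBatch(object):
    """Sketch: values plus a mask of the same shape (float {0, 1} or bool)."""

    def __init__(self, values, mask):
        self.values = values
        self.mask = mask

    def with_pad_value(self, val):
        """Return a copy whose masked-out (pad) positions hold `val`."""
        keep = tf.cast(self.mask, tf.bool)  # accept float or boolean masks
        pad = tf.fill(tf.shape(self.values), tf.cast(val, self.values.dtype))
        # element-wise select: keep real entries, overwrite pad positions
        # (tf.where with three args needs TF >= 1.0; older TF used tf.select)
        return SequenceBatch(tf.where(keep, self.values, pad), self.mask)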
Example #2
import tensorflow as tf
# Dense is a Keras-style fully connected layer; SequenceBatch and the class
# that owns this __init__ come from the surrounding project.

    def __init__(self, memory_cells, query, project_query=False):
        """Define Attention.

        Args:
            memory_cells (SequenceBatch): a SequenceBatch containing a Tensor of shape (batch_size, num_cells, cell_dim)
            query (Tensor): a tensor of shape (batch_size, query_dim).
            project_query (bool): defaults to False. If True, the query goes through an extra projection layer to
                coerce it to cell_dim.
        """
        cell_dim = memory_cells.values.get_shape().as_list()[2]
        if project_query:
            # project the query up/down to cell_dim
            self._projection_layer = Dense(cell_dim, activation='linear')
            query = self._projection_layer(query)  # (batch_size, cell_dim)

        memory_values, memory_mask = memory_cells.values, memory_cells.mask

        # batch matrix multiply to compute logit scores for all cells in all batches
        # (tf.batch_matmul is the pre-1.0 TensorFlow name; TF >= 1.0 folds this into tf.matmul)
        query = tf.expand_dims(query, 2)  # (batch_size, cell_dim, 1)
        logit_values = tf.batch_matmul(memory_values, query)  # (batch_size, num_cells, 1)
        logit_values = tf.squeeze(logit_values, [2])  # (batch_size, num_cells)

        # set all pad logits to negative infinity
        logits = SequenceBatch(logit_values, memory_mask)
        logits = logits.with_pad_value(-float('inf'))

        # normalize to get probs
        probs = tf.nn.softmax(logits.values)  # (batch_size, num_cells)

        # attention-weighted read: convex combination of the memory cells
        retrieved = tf.batch_matmul(tf.expand_dims(probs, 1), memory_values)  # (batch_size, 1, cell_dim)
        retrieved = tf.squeeze(retrieved, [1])  # (batch_size, cell_dim)

        self._logits = logits.values
        self._probs = probs
        self._retrieved = retrieved
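A hypothetical usage sketch follows. The shapes are illustrative, the class is assumed to be named Attention (per its docstring), and the attributes set at the end of __init__ are assumed to be how callers read out results (the real class likely wraps them in properties).

import tensorflow as tf

batch_size, num_cells, cell_dim, query_dim = 4, 7, 10, 16

memory = SequenceBatch(
    values=tf.placeholder(tf.float32, [batch_size, num_cells, cell_dim]),
    mask=tf.placeholder(tf.float32, [batch_size, num_cells]),
)
query = tf.placeholder(tf.float32, [batch_size, query_dim])

# project_query=True inserts a Dense layer mapping query_dim -> cell_dim,
# so the dot product against each memory cell is well-defined.
attn = Attention(memory, query, project_query=True)

# Padded cells received -inf logits, so softmax assigns them probability 0
# and the retrieved vector (batch_size, cell_dim) mixes real cells only.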