def test(self):
    """Padded positions should be overwritten with -inf by `with_pad_value`.

    The same values tensor is wrapped once with a float mask and once with an
    equivalent bool mask; both must yield identical padded output, where
    masked-out positions (0 / False) become -inf and the rest are untouched.
    """
    neg_inf = float('-inf')

    values = tf.constant([
        [1, -8, 5],
        [0, 2, 7],
        [2, -8, 6],
    ], dtype=tf.float32)

    # Two encodings of the same mask.
    float_mask = tf.constant([
        [1, 1, 1],
        [0, 0, 1],
        [1, 1, 0],
    ], dtype=tf.float32)
    bool_mask = tf.constant([
        [True, True, True],
        [False, False, True],
        [True, True, False],
    ], dtype=tf.bool)

    expected = np.array([
        [1, -8, 5],
        [neg_inf, neg_inf, 7],
        [2, -8, neg_inf],
    ], dtype=np.float32)

    batches = [SequenceBatch(values, mask) for mask in (float_mask, bool_mask)]

    with tf.Session():
        for batch in batches:
            padded = batch.with_pad_value(neg_inf)
            assert_almost_equal(padded.values.eval(), expected)
def __init__(self, memory_cells, query, project_query=False):
    """Define Attention.

    Scores every memory cell by its dot product with the query, forces the
    scores of padded cells to -inf, normalizes the scores with a softmax and
    returns the probability-weighted sum of the memory cells.

    Args:
        memory_cells (SequenceBatch): a SequenceBatch containing a Tensor of
            shape (batch_size, num_cells, cell_dim)
        query (Tensor): a tensor of shape (batch_size, query_dim).
        project_query (bool): defaults to False. If True, the query goes
            through an extra projection layer to coerce it to cell_dim.
            If False, query_dim must already equal cell_dim.

    Raises:
        ValueError: if project_query is False and the statically known
            query_dim differs from cell_dim.

    Note:
        NOTE(review): if every cell in some batch row is padding, all of that
        row's logits are -inf and the softmax over that row presumably yields
        NaN probabilities (and a NaN retrieved vector). Confirm whether
        callers can produce fully-padded rows.
    """
    cell_dim = memory_cells.values.get_shape().as_list()[2]

    # Always define the attribute; None means no projection was applied.
    self._projection_layer = None
    if project_query:
        # project the query up/down to cell_dim
        self._projection_layer = Dense(cell_dim, activation='linear')
        query = self._projection_layer(query)  # (batch_size, cell_dim)
    else:
        # Without a projection the dot product needs query_dim == cell_dim;
        # fail early with a clear message when both dims are statically known.
        query_shape = query.get_shape()
        if query_shape.ndims is not None:
            query_dim = query_shape.as_list()[-1]
            if (query_dim is not None and cell_dim is not None
                    and query_dim != cell_dim):
                raise ValueError(
                    'query_dim ({}) must equal cell_dim ({}) when '
                    'project_query is False'.format(query_dim, cell_dim))

    memory_values, memory_mask = memory_cells.values, memory_cells.mask

    # batch matrix multiply to compute logit scores for all cells in all batches
    query = tf.expand_dims(query, 2)  # (batch_size, cell_dim, 1)
    logit_values = tf.batch_matmul(memory_values, query)  # (batch_size, num_cells, 1)
    logit_values = tf.squeeze(logit_values, [2])  # (batch_size, num_cells)

    # set all pad logits to negative infinity so they receive zero probability
    logits = SequenceBatch(logit_values, memory_mask)
    logits = logits.with_pad_value(-float('inf'))

    # normalize to get probs
    probs = tf.nn.softmax(logits.values)  # (batch_size, num_cells)

    # probability-weighted sum of the memory cells
    retrieved = tf.batch_matmul(tf.expand_dims(probs, 1), memory_values)  # (batch_size, 1, cell_dim)
    retrieved = tf.squeeze(retrieved, [1])  # (batch_size, cell_dim)

    self._logits = logits.values
    self._probs = probs
    self._retrieved = retrieved