import torch

# Helpers referenced below (GPUVariable, conditional, expand_dims_for_broadcast,
# SequenceBatch, AttentionOutput) are assumed to come from the surrounding module.


@classmethod
def reduce_prod(cls, seq_batch):
    """Compute the product of each sequence in a SequenceBatch.

    If a sequence is empty, we return a product of 1.

    Args:
        seq_batch (SequenceBatch): of shape (batch_size, seq_length, X1, X2, ...)

    Returns:
        Tensor: of shape (batch_size, X1, X2, ...)
    """
    mask = seq_batch.mask
    values = seq_batch.values

    # We set all pad values = 1, so that taking the log will not produce -inf
    # (where mask == 0, 1 - mask == 1)
    mask_bcast = expand_dims_for_broadcast(mask, values).expand(values.size())  # (batch_size, seq_length, X1, X2, ...)
    values = conditional(mask_bcast, values, 1 - mask_bcast)

    # prod(x) == exp(sum(log(x))), with pad positions contributing log(1) == 0
    logged = SequenceBatch(torch.log(values), seq_batch.mask)  # (batch_size, seq_length, X1, X2, ...)
    log_sum = SequenceBatch.reduce_sum(logged)  # (batch_size, X1, X2, ...)
    prod = torch.exp(log_sum)
    return prod
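# Usage sketch (illustrative, not from the original source; assumes SequenceBatch
# can be constructed directly from (values, mask), as reduce_prod itself does):
#
#   values = GPUVariable(torch.FloatTensor([[2., 3., 5.],
#                                           [4., 9., 9.]]))
#   mask = GPUVariable(torch.FloatTensor([[1., 1., 1.],
#                                         [1., 0., 0.]]))  # row 2 has length 1
#   prods = SequenceBatch.reduce_prod(SequenceBatch(values, mask))
#   # prods ~= [30., 4.]; padded positions contribute a factor of 1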
def gated_update(h, h_new, update):
    """If update == 1.0, return h_new; if update == 0.0, return h.

    Applies this logic independently to each element in the batch.

    Args:
        h (Variable): of shape (batch_size, hidden_dim)
        h_new (Variable): of shape (batch_size, hidden_dim)
        update (Variable): of shape (batch_size, 1)

    Returns:
        Variable: of shape (batch_size, hidden_dim)
    """
    batch_size, hidden_dim = h.size()
    gate = update.expand(batch_size, hidden_dim)  # broadcast the gate across hidden_dim
    return conditional(gate, h_new, h)
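# Usage sketch (illustrative): with batch_size = 2, the first example keeps its
# old state and the second takes the updated one.
#
#   h = GPUVariable(torch.zeros(2, 3))
#   h_new = GPUVariable(torch.ones(2, 3))
#   update = GPUVariable(torch.FloatTensor([[0.], [1.]]))
#   gated_update(h, h_new, update)  # row 0 == h[0], row 1 == h_new[1]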
@classmethod
def _mask_weights(cls, weights, mask):
    # if a given row has no memory cells, its weights should be all zeros
    no_cells = cls._no_cells(mask)
    all_zeros = GPUVariable(torch.zeros(*mask.size()))
    weights = conditional(no_cells, all_zeros, weights)
    return weights
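# Note: _no_cells is not shown here; judging from the inline computation in
# forward() below, it presumably looks like:
#
#   @classmethod
#   def _no_cells(cls, mask):
#       # (1 - mask) is 1 exactly at padded positions, so the row-wise product
#       # is 1 iff the row contains no memory cells at all
#       return (1 - mask).prod(1).expand_as(mask)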
def forward(self, memory_cells, query):
    """Generates a density over a set of elements w.r.t. the query vector.

        Et(i) = tanh(Hi * Wh + St * Ws) * v
        At = softmax(Et)

    Dimensions:
        Hi: (batch_size x memory_dim)
        St: (batch_size x query_dim)
        Wh: (memory_dim x attn_dim)
        Ws: (query_dim x attn_dim)
        v:  (attn_dim x 1)

        tanh(Hi * Wh + St * Ws):     (batch_size x attn_dim)
        tanh(Hi * Wh + St * Ws) * v: (batch_size x 1)
        At = softmax(Et):            (batch_size x num_cells)

    Args:
        memory_cells (SequenceBatch): of shape (batch_size x num_cells x memory_dim)
        query (Variable): St, of shape (batch_size x query_dim)

    Returns:
        AttentionOutput: weights of shape (batch_size x num_cells) and
            context of shape (batch_size x memory_dim)
    """
    transformed_query = torch.mm(query, self.query_transform)  # (batch_size, attn_dim)

    batch_size, num_cells = memory_cells.mask.size()
    memory_cells_ = torch.transpose(memory_cells.values, 0, 1)  # (num_cells, batch_size, memory_dim)
    expanded_transformed_query = transformed_query.expand(num_cells, batch_size, self.attn_dim)
    expanded_memory_transform = self.memory_transform.expand(num_cells, self.memory_dim, self.attn_dim)
    expanded_v_transform = self.v_transform.expand(num_cells, self.attn_dim, 1)

    # (num_cells, batch_size, attn_dim)
    attn_embeds = torch.bmm(memory_cells_, expanded_memory_transform) + expanded_transformed_query
    attn_embeds = self.tanh(attn_embeds)
    attn_embeds = torch.bmm(attn_embeds, expanded_v_transform)  # (num_cells, batch_size, 1)
    logits = torch.transpose(attn_embeds.squeeze(2), 0, 1)  # (batch_size, num_cells)

    mask = memory_cells.mask

    # no_cells is a FloatTensor with shape (batch_size, num_cells)
    # no_cells[i, j] = 1 if example i has NO memory cells, 0 otherwise
    no_cells = (1 - mask).prod(1).expand_as(mask)
    # TODO(kelvin): check for numerical stability. Product of 1's does not
    # necessarily equal 1 exactly, which we need

    suppress = GPUVariable(torch.zeros(*mask.size()))
    suppress[mask == 0] = float('-inf')  # send the logits of pad cells to -infinity
    suppress[no_cells == 1] = 0.0  # but if an entire row has no cells, leave its logits alone

    logits = logits + suppress  # -inf + anything = -inf

    # compute normalized weights
    weights = self.softmax(logits)  # (batch_size, num_cells)

    # if a given row has no memory cells, its weights should be all zeros
    all_zeros = GPUVariable(torch.zeros(*mask.size()))
    weights = conditional(no_cells, all_zeros, weights)

    context = torch.bmm(weights.unsqueeze(1), memory_cells.values)  # (batch_size, 1, memory_dim)
    context = context.squeeze(1)  # (batch_size, memory_dim)

    return AttentionOutput(weights=weights, context=context)
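# Usage sketch (illustrative; the class name and constructor signature below
# are hypothetical, and the module is assumed to be an nn.Module so calling it
# invokes forward):
#
#   attn = Attention(memory_dim=8, query_dim=8, attn_dim=4)
#   cells = SequenceBatch(GPUVariable(torch.randn(2, 5, 8)),
#                         GPUVariable(torch.ones(2, 5)))
#   out = attn(cells, GPUVariable(torch.randn(2, 8)))
#   out.weights  # (2, 5); each row sums to 1 where memory cells exist
#   out.context  # (2, 8); attention-weighted sum of the memory cells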