def test_reduce_max(self, some_seq_batch):
    """Check SequenceBatch.reduce_max error handling and mask-aware pooling.

    First, reduce_max on the `some_seq_batch` fixture must raise ValueError
    (presumably the fixture contains an empty sequence -- confirm against the
    fixture definition). Then a hand-built batch verifies that positions with
    a zero mask are ignored when taking the maximum.
    """
    # should complain about empty sequence
    with pytest.raises(ValueError):
        SequenceBatch.reduce_max(some_seq_batch)

    raw_values = [
        # actual max is in later elements, but shd be suppressed by mask
        [[1, 2], [4, 5], [4, 4]],
        # note that all elements in 2nd dim are negative
        [[0, -4], [43, -5], [-1, -20]],
    ]
    raw_mask = [
        [1, 0, 0],
        [1, 1, 0],
    ]
    seq_batch = SequenceBatch(
        GPUVariable(torch.FloatTensor(raw_values)),
        GPUVariable(torch.FloatTensor(raw_mask)),
    )

    # Row 0 only sees its first element; row 1 sees its first two elements.
    expected = [
        [1, 2],
        [43, -4],
    ]
    assert_tensor_equal(SequenceBatch.reduce_max(seq_batch), expected)
def forward(self, old_embeds, neighbors, rels):
    """Max-pool each row's neighbor embeddings, then project.

    Args:
        old_embeds: current embeddings, indexed along dim 0 by the entries of
            `neighbors.values` (assumed 2-D, (num_rows, dim) -- TODO confirm).
        neighbors: object exposing `.values` (neighbor indices; second dim is
            the number of neighbor slots) and `.mask` (passed to SequenceBatch
            as the mask).
        rels: unused by this implementation.

    Returns:
        relu(proj(dropout(cat(old_embeds, pooled_neighbors)))).
    """
    num_rows = len(old_embeds)
    num_slots = neighbors.values.shape[1]

    # Look up every neighbor's embedding in one flat gather, then restore the
    # (rows, slots, dim) layout.
    flat_indices = neighbors.values.view(-1)
    gathered = torch.index_select(old_embeds, 0, flat_indices)
    gathered = gathered.view(num_rows, num_slots, -1)

    # Max over the neighbor axis, honoring the neighbor mask.
    pooled = SequenceBatch.reduce_max(SequenceBatch(gathered, neighbors.mask))

    features = torch.cat((old_embeds, pooled), dim=1)
    return F.relu(self._proj(self._dropout(features)))
def forward(self, old_embeds, neighbors, rels):
    """Project the embeddings, then max-pool each row with its neighbors.

    Args:
        old_embeds: current embeddings; after projection they are indexed
            along dim 0 by the entries of `neighbors.values` (assumed 2-D,
            (num_rows, dim) -- TODO confirm).
        neighbors: object exposing `.values` (neighbor indices; second dim is
            the number of neighbor slots) and `.mask`.
        rels: unused by this implementation.

    Returns:
        The masked max over each row's own projected embedding and the
        projected embeddings of its neighbors.
    """
    num_rows = len(old_embeds)
    hidden = F.relu(self._proj(self._dropout(old_embeds)))

    # Gather projected neighbor embeddings and restore (rows, slots, dim).
    flat_indices = neighbors.values.view(-1)
    neighbor_hidden = torch.index_select(hidden, 0, flat_indices)
    neighbor_hidden = neighbor_hidden.view(
        num_rows, neighbors.values.shape[1], -1)

    # Prepend each row's own embedding as an extra, always-unmasked slot.
    stacked = torch.cat((hidden.unsqueeze(1), neighbor_hidden), dim=1)
    self_mask = V(torch.ones(num_rows, 1))
    full_mask = torch.cat((self_mask, neighbors.mask), dim=1)

    return SequenceBatch.reduce_max(SequenceBatch(stacked, full_mask))
def embed(self, sequences):
    """Embed a batch of token sequences by pooling their token embeddings.

    Args:
        sequences: iterable of token sequences; none may be empty.

    Returns:
        The result of `self.transform` applied to the pooled token
        embeddings, of size (batch_size, embed_dim).

    Raises:
        ValueError: if any sequence is empty, or if `self.pool` is not one of
            'sum', 'mean' or 'max'.
    """
    if any(len(seq) == 0 for seq in sequences):
        raise ValueError("Cannot embed empty sequence.")

    token_indices = SequenceBatch.from_sequences(
        sequences, self.vocab, min_seq_length=1)
    # SequenceBatch of size (batch_size, max_seq_length, word_dim)
    token_embeds = self.token_embedder.embed_seq_batch(token_indices)

    # Pick the reduction over the sequence axis named by self.pool.
    if self.pool == 'sum':
        reduce_fn = SequenceBatch.reduce_sum
    elif self.pool == 'mean':
        reduce_fn = SequenceBatch.reduce_mean
    elif self.pool == 'max':
        reduce_fn = SequenceBatch.reduce_max
    else:
        raise ValueError(self.pool)
    pooled_token_embeds = reduce_fn(token_embeds)  # (batch_size, word_dim)

    seq_embeds = self.transform(pooled_token_embeds)  # (batch_size, embed_dim)
    assert seq_embeds.size()[1] == self.embed_dim
    return seq_embeds
def _score_actions(self, states, force_dom_attn=None, force_type_values=None): """Score actions. Args: states (list[State]) force_dom_attn (list[DOMElement]): a batch of DOMElements. If not None, forces the second DOM attention to select the specified elements. force_type_values (list[unicode | None]): force these values to be scored if possible (optional) Returns: action_scores_batch (list[ActionScores]) """ if len(states) == 0: return [] # no actions to score # ===== Embed entries of the query: (key, value) pairs ===== # concatenated keys and values of structured query query_entries = self._query_entries(states) query_embeds = self._query_embeds(states, query_entries) # ===== Embed DOM elements ===== # dom_embeds: SequenceBatch of shape (batch_size, num_elems, dom_dim) # dom_elems: list[list[DOMElement]] of shape (batch_size, num_elems) # It is padded with DOMElementPAD objects. dom_embeds, dom_elems = self._dom_embeds(states) # ===== Embed agent state ===== # (batch_size, dom_dim) dom_embeds_max = SequenceBatch.reduce_max(dom_embeds) if self._dom_attention_for_state: first_dom_attn_query = dom_embeds_max first_dom_attn = self._first_dom_attention( dom_embeds, first_dom_attn_query) state_embeds = first_dom_attn.context # (batch_size, dom_dim) else: state_embeds = dom_embeds_max # ===== Attend over entries of the query ===== # use both fields attention heads fields_attn = [ # List of AttentionOutput objects self._fields_attention_heads[0](query_embeds, state_embeds), self._fields_attention_heads[1](query_embeds, state_embeds), ] # (batch_size, attn_dim * 2) attended_query_embeds = torch.cat( [fields_attn[0].context, fields_attn[1].context], 1) # ===== Attend over DOM elements ===== elem_query = torch.cat([attended_query_embeds, state_embeds], 1) # compute state values using elem_query state_values = self._value_function_layer(elem_query) # (batch_size, 1) state_values = torch.squeeze(state_values, 1) # (batch_size,) # TODO(kelvin): clean this up # two DOM 
attention heads second_dom_attn = [ # contexts have shape (batch_size, dom_dim) self._second_dom_attention_heads[0](dom_embeds, elem_query), self._second_dom_attention_heads[1](dom_embeds, elem_query) ] # ===== Compute DOM probs from field weights ===== dom_head_weights = F.softmax( # Weight per each head self._second_dom_attn_head_weights(attended_query_embeds)) first_head_weights = torch.index_select( dom_head_weights, 1, GPUVariable(torch.LongTensor([0]))).expand_as( second_dom_attn[0].weights) second_head_weights = torch.index_select( dom_head_weights, 1, GPUVariable(torch.LongTensor([1]))).expand_as( second_dom_attn[1].weights) # DOM probs = # dom_head_weights[0] * first_head + dom_head_weights[1] * second_head dom_probs = first_head_weights * second_dom_attn[0].weights + \ second_head_weights * second_dom_attn[1].weights # ===== Decide whether to click or type ===== HARD_DOM_ATTN = True # TODO: Need to fix this for Best First Search if HARD_DOM_ATTN: # TODO: Bring back the test time flag? 
selector = lambda probs: self._sample(probs) if force_dom_attn: elem_indices = [] for batch_idx, force_dom_elem in enumerate(force_dom_attn): refs = [elem.ref for elem in dom_elems[batch_idx]] elem_indices.append(refs.index(force_dom_elem.ref)) # this selector just ignores probs and returns the indices of # the forced elements selector = lambda probs: elem_indices dom_selection = Selection(selector, dom_probs, candidates=None) batch_size, num_dom_elems, dom_dim = dom_embeds.values.size() selected_dom_indices = dom_selection.indices # (batch_size, 1) selected_dom_indices = torch.unsqueeze(selected_dom_indices, 1) # (batch_size, 1, 1) selected_dom_indices = torch.unsqueeze(selected_dom_indices, 1) selected_dom_indices = selected_dom_indices.expand( batch_size, 1, dom_dim) # (batch_size, 1, dom_dim) # (batch_size, 1, dom_dim) selected_dom_embeds = torch.gather( dom_embeds.values, 1, selected_dom_indices) # (batch_size, dom_dim) selected_dom_embeds = torch.squeeze(selected_dom_embeds, 1) else: selected_dom_embeds = torch.cat( [second_dom_attn[0].context, second_dom_attn[1].context], 1) # (batch_size, context_dim) dom_contexts = self._context_embedder( selected_dom_embeds, attended_query_embeds) # ===== Decide what value to type ===== type_values, type_value_probs = self._type_values_and_probs( states, query_embeds, dom_contexts, force_type_values) # (batch_size, 2) (index 0 corresponds to click) click_or_type_probs = F.softmax( self._click_or_type_linear(dom_contexts)) action_scores_batch = self._compute_action_scores( dom_selection.indices.data.cpu().numpy(), dom_elems, dom_probs, click_or_type_probs, type_values, type_value_probs, state_values) # add justifications for batch_idx, action_score in enumerate(action_scores_batch): justif = MiniWoBPolicyJustification( dom_elements=dom_elems[batch_idx], element_probs=dom_probs[batch_idx], click_or_type_probs=click_or_type_probs[batch_idx], query_entries=query_entries[batch_idx], 
fields_attentions=[fields_attn[0].weights[batch_idx], fields_attn[1].weights[batch_idx]], type_values=type_values[batch_idx], type_value_probs=type_value_probs[batch_idx], state_value=state_values[batch_idx]) action_score.justification = justif return action_scores_batch