Example #1
0
    def test_reduce_max(self, some_seq_batch):
        """reduce_max rejects empty sequences and ignores masked positions."""
        # The fixture batch is expected to trigger ValueError (empty sequence).
        with pytest.raises(ValueError):
            SequenceBatch.reduce_max(some_seq_batch)

        # Row 0: only position 0 is unmasked, so the larger values at
        #     positions 1-2 must not leak into the result.
        # Row 1: positions 0-1 are unmasked; their column-1 entries (-4, -5)
        #     are all negative, so a masked slot must not be treated as 0.
        raw_values = torch.FloatTensor([
            [[1, 2], [4, 5], [4, 4]],
            [[0, -4], [43, -5], [-1, -20]],
        ])
        raw_mask = torch.FloatTensor([
            [1, 0, 0],
            [1, 1, 0],
        ])
        batch = SequenceBatch(GPUVariable(raw_values), GPUVariable(raw_mask))

        expected = [
            [1, 2],
            [43, -4],
        ]
        assert_tensor_equal(SequenceBatch.reduce_max(batch), expected)
Example #2
0
 def forward(self, old_embeds, neighbors, rels):
     """Concatenate each embedding with a max-pool over its neighbors.

     Rows of ``old_embeds`` named by ``neighbors.values`` are gathered and
     reshaped into one sequence per row. They are max-pooled under
     ``neighbors.mask`` via ``SequenceBatch.reduce_max`` and concatenated
     with the original row along dim 1. The result then goes through
     dropout, ``self._proj`` and ReLU.

     Args:
         old_embeds: embeddings indexed along dim 0.
         neighbors: object with ``values`` (indices into dim 0 of
             ``old_embeds``; the view below assumes one row of indices per
             embedding) and ``mask`` (valid neighbor slots).
         rels: accepted but not used.

     Returns:
         ``F.relu(self._proj(self._dropout(cat(old_embeds, pooled))))``.
     """
     num_rows = len(old_embeds)
     neighbor_idx = neighbors.values

     # Flatten the index table, gather, then restore the
     # (num_rows, neighbors.values.shape[1], -1) layout.
     gathered = torch.index_select(old_embeds, 0, neighbor_idx.view(-1))
     gathered = gathered.view(num_rows, neighbor_idx.shape[1], -1)

     # NOTE(review): confirm how SequenceBatch.reduce_max handles a row whose
     # neighbor mask is all zeros (e.g. an embedding with no neighbors).
     pooled = SequenceBatch.reduce_max(SequenceBatch(gathered, neighbors.mask))

     features = torch.cat((old_embeds, pooled), dim=1)
     return F.relu(self._proj(self._dropout(features)))
Example #3
0
 def forward(self, old_embeds, neighbors, rels):
     """Project every row, then max-pool it together with its neighbors.

     Each row of ``old_embeds`` goes through dropout, ``self._proj`` and
     ReLU. The projected rows named by ``neighbors.values`` are gathered, and
     the row's own projection is prepended as position 0 with an always-on
     mask entry. ``SequenceBatch.reduce_max`` then pools over that sequence.

     Args:
         old_embeds: embeddings indexed along dim 0.
         neighbors: object with ``values`` (indices into dim 0 of
             ``old_embeds``; the view below assumes one row of indices per
             embedding) and ``mask`` (valid neighbor slots).
         rels: accepted but not used.

     Returns:
         Result of ``SequenceBatch.reduce_max`` over [self, neighbors...].
     """
     num_rows = len(old_embeds)
     projected = F.relu(self._proj(self._dropout(old_embeds)))

     neighbor_idx = neighbors.values
     gathered = torch.index_select(projected, 0, neighbor_idx.view(-1))
     gathered = gathered.view(num_rows, neighbor_idx.shape[1], -1)

     # Position 0 of every sequence is the row itself and its mask entry is
     # 1, so no sequence is ever fully masked.
     sequence = torch.cat((projected.unsqueeze(1), gathered), dim=1)
     # NOTE(review): confirm V(...) yields the same device/dtype as
     # neighbors.mask, otherwise torch.cat will fail.
     sequence_mask = torch.cat((V(torch.ones(num_rows, 1)), neighbors.mask), dim=1)
     return SequenceBatch.reduce_max(SequenceBatch(sequence, sequence_mask))
Example #4
0
    def embed(self, sequences):
        """Embed each token sequence into a single vector.

        Tokens are indexed with ``self.vocab`` and embedded with
        ``self.token_embedder``. They are pooled over the sequence dimension
        according to ``self.pool`` ('sum', 'mean' or 'max') and finally
        passed through ``self.transform``.

        Args:
            sequences: a batch of token sequences; each must be non-empty.

        Returns:
            seq_embeds: output of ``self.transform``; its second dimension is
                asserted to equal ``self.embed_dim``.

        Raises:
            ValueError: if any sequence is empty, or if ``self.pool`` is not
                a supported pooling mode (checked after embedding).
        """
        if any(len(seq) == 0 for seq in sequences):
            raise ValueError("Cannot embed empty sequence.")

        indices = SequenceBatch.from_sequences(sequences, self.vocab, min_seq_length=1)
        # SequenceBatch of size (batch_size, max_seq_length, word_dim)
        embedded = self.token_embedder.embed_seq_batch(indices)

        reducers = {
            'sum': SequenceBatch.reduce_sum,
            'mean': SequenceBatch.reduce_mean,
            'max': SequenceBatch.reduce_max,
        }
        if self.pool not in reducers:
            raise ValueError(self.pool)
        pooled = reducers[self.pool](embedded)  # (batch_size, word_dim)

        seq_embeds = self.transform(pooled)  # (batch_size, embed_dim)
        assert seq_embeds.size()[1] == self.embed_dim
        return seq_embeds
Example #5
0
    def _score_actions(self, states, force_dom_attn=None,
                       force_type_values=None):
        """Score actions.

        Embeds the structured query entries and the DOM elements of each
        state, then attends over them. It selects one DOM element per batch
        item (sampled with ``self._sample`` unless forced) and scores
        click-vs-type and candidate type values for that element. A
        MiniWoBPolicyJustification is attached to each returned ActionScores.

        Args:
            states (list[State])
            force_dom_attn (list[DOMElement]): a batch of DOMElements. If not None,
                forces the second DOM attention to select the specified elements.
                Elements are matched to candidates by their ``ref`` attribute;
                ``list.index`` raises ValueError if a ref is not among that
                state's DOM elements. An empty list is treated like None.
            force_type_values (list[unicode | None]): force these values to
                be scored if possible (optional)

        Returns:
            action_scores_batch (list[ActionScores]): one entry per state, or
                an empty list when ``states`` is empty.
        """
        if len(states) == 0:
            return []  # no actions to score

        # ===== Embed entries of the query: (key, value) pairs =====
        # concatenated keys and values of structured query
        query_entries = self._query_entries(states)
        query_embeds = self._query_embeds(states, query_entries)

        # ===== Embed DOM elements =====
        # dom_embeds: SequenceBatch of shape (batch_size, num_elems, dom_dim)
        # dom_elems: list[list[DOMElement]] of shape (batch_size, num_elems)
        #     It is padded with DOMElementPAD objects.
        dom_embeds, dom_elems = self._dom_embeds(states)

        # ===== Embed agent state =====
        # Max-pool the DOM element embeddings over the element dimension.
        # (batch_size, dom_dim)
        dom_embeds_max = SequenceBatch.reduce_max(dom_embeds)

        if self._dom_attention_for_state:
            # Refine the state embedding by attending over the DOM elements,
            # using the max-pooled embedding as the attention query.
            first_dom_attn_query = dom_embeds_max
            first_dom_attn = self._first_dom_attention(
                    dom_embeds, first_dom_attn_query)
            state_embeds = first_dom_attn.context  # (batch_size, dom_dim)
        else:
            state_embeds = dom_embeds_max

        # ===== Attend over entries of the query =====
        # use both fields attention heads
        fields_attn = [  # List of AttentionOutput objects
            self._fields_attention_heads[0](query_embeds, state_embeds),
            self._fields_attention_heads[1](query_embeds, state_embeds),
        ]

        # (batch_size, attn_dim * 2)
        attended_query_embeds = torch.cat(
            [fields_attn[0].context, fields_attn[1].context], 1)

        # ===== Attend over DOM elements =====
        # Element-attention query: attended query fields + state embedding.
        elem_query = torch.cat([attended_query_embeds, state_embeds], 1)

        # compute state values using elem_query
        state_values = self._value_function_layer(elem_query)  # (batch_size, 1)
        state_values = torch.squeeze(state_values, 1)  # (batch_size,)
        # TODO(kelvin): clean this up

        # two DOM attention heads
        second_dom_attn = [  # contexts have shape (batch_size, dom_dim)
            self._second_dom_attention_heads[0](dom_embeds, elem_query),
            self._second_dom_attention_heads[1](dom_embeds, elem_query)
        ]

        # ===== Compute DOM probs from field weights =====
        # NOTE(review): F.softmax is called without an explicit dim. Columns 0
        # and 1 are read below, so dim=1 appears intended; older PyTorch infers
        # dim=1 for 2-D input, newer versions warn. Confirm before upgrading.
        dom_head_weights = F.softmax(  # Weight per each head
                self._second_dom_attn_head_weights(attended_query_embeds))

        # Take column 0 / column 1 of dom_head_weights (shape (batch_size, 1))
        # and broadcast each to the shape of the matching head's weights.
        first_head_weights = torch.index_select(
                dom_head_weights, 1,
                GPUVariable(torch.LongTensor([0]))).expand_as(
                        second_dom_attn[0].weights)
        second_head_weights = torch.index_select(
                dom_head_weights, 1,
                GPUVariable(torch.LongTensor([1]))).expand_as(
                        second_dom_attn[1].weights)

        # DOM probs =
        # dom_head_weights[0] * first_head + dom_head_weights[1] * second_head
        dom_probs = first_head_weights * second_dom_attn[0].weights + \
                    second_head_weights * second_dom_attn[1].weights

        # ===== Decide whether to click or type =====
        # NOTE(review): dom_selection is only bound inside the HARD_DOM_ATTN
        # branch, but it is used unconditionally in the
        # _compute_action_scores call below. Setting HARD_DOM_ATTN = False
        # would raise NameError there.
        HARD_DOM_ATTN = True
        # TODO: Need to fix this for Best First Search
        if HARD_DOM_ATTN:
            # Default selector: sample an element index from the DOM probs.
            # TODO: Bring back the test time flag?
            selector = lambda probs: self._sample(probs)

            if force_dom_attn:
                # Locate each forced element among this state's DOM elements
                # by matching .ref (list.index raises ValueError if absent).
                elem_indices = []
                for batch_idx, force_dom_elem in enumerate(force_dom_attn):
                    refs = [elem.ref for elem in dom_elems[batch_idx]]
                    elem_indices.append(refs.index(force_dom_elem.ref))

                # this selector just ignores probs and returns the indices of
                # the forced elements
                selector = lambda probs: elem_indices

            dom_selection = Selection(selector, dom_probs, candidates=None)
            batch_size, num_dom_elems, dom_dim = dom_embeds.values.size()
            selected_dom_indices = dom_selection.indices
            # (batch_size, 1)
            selected_dom_indices = torch.unsqueeze(selected_dom_indices, 1)
            # (batch_size, 1, 1)
            selected_dom_indices = torch.unsqueeze(selected_dom_indices, 1)
            selected_dom_indices = selected_dom_indices.expand(
                    batch_size, 1, dom_dim)  # (batch_size, 1, dom_dim)

            # Gather the embedding of the selected element in each batch row.
            # (batch_size, 1, dom_dim)
            selected_dom_embeds = torch.gather(
                dom_embeds.values, 1, selected_dom_indices)
            # (batch_size, dom_dim)
            selected_dom_embeds = torch.squeeze(selected_dom_embeds, 1)
        else:
            # Soft attention: concatenate both heads' attention contexts.
            selected_dom_embeds = torch.cat(
                [second_dom_attn[0].context, second_dom_attn[1].context], 1)

        # (batch_size, context_dim)
        dom_contexts = self._context_embedder(
                selected_dom_embeds, attended_query_embeds)

        # ===== Decide what value to type =====
        type_values, type_value_probs = self._type_values_and_probs(
                states, query_embeds, dom_contexts, force_type_values)

        # (batch_size, 2) (index 0 corresponds to click)
        click_or_type_probs = F.softmax(
                self._click_or_type_linear(dom_contexts))

        # Selected element indices are moved to a CPU numpy array for scoring.
        action_scores_batch = self._compute_action_scores(
                dom_selection.indices.data.cpu().numpy(), dom_elems, dom_probs,
                click_or_type_probs, type_values, type_value_probs,
                state_values)

        # add justifications (per-example diagnostics of the scoring above)
        for batch_idx, action_score in enumerate(action_scores_batch):
            justif = MiniWoBPolicyJustification(
                dom_elements=dom_elems[batch_idx],
                element_probs=dom_probs[batch_idx],
                click_or_type_probs=click_or_type_probs[batch_idx],
                query_entries=query_entries[batch_idx],
                fields_attentions=[fields_attn[0].weights[batch_idx],
                                   fields_attn[1].weights[batch_idx]],
                type_values=type_values[batch_idx],
                type_value_probs=type_value_probs[batch_idx],
                state_value=state_values[batch_idx])
            action_score.justification = justif

        return action_scores_batch