Example 1
 def forward(
     self, word_nl_query: PaddedSequenceWithMask,
     char_nl_query: PaddedSequenceWithMask
 ) -> Tuple[PaddedSequenceWithMask, PaddedSequenceWithMask]:
     e_word = self.word_embed(word_nl_query.data)
     e_char = self.char_embed(char_nl_query.data)
     return (PaddedSequenceWithMask(e_word, word_nl_query.mask),
             PaddedSequenceWithMask(e_char, char_nl_query.mask))
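
The pattern here is that only the .data tensor is transformed while the .mask is carried through unchanged. Below is a minimal, self-contained sketch of that pattern; SimplePaddedSequence is a hypothetical stand-in for the real PaddedSequenceWithMask, and the sizes are arbitrary.

from typing import NamedTuple

import torch
import torch.nn as nn


class SimplePaddedSequence(NamedTuple):
    data: torch.Tensor  # (L, N, ...) padded values
    mask: torch.Tensor  # (L, N), 1 for real tokens, 0 for padding


word_embed = nn.Embedding(num_embeddings=100, embedding_dim=8)

# Two word-ID sequences padded to length 3, in (L, N) layout.
ids = torch.tensor([[1, 2], [3, 4], [5, 0]])
mask = torch.tensor([[1, 1], [1, 1], [1, 0]])
query = SimplePaddedSequence(ids, mask)

# As in the forward above: embed .data, reuse .mask unchanged.
embedded = SimplePaddedSequence(word_embed(query.data), query.mask)
print(embedded.data.shape)  # torch.Size([3, 2, 8])
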
Example 2
    def forward(
            self, previous_actions: PaddedSequenceWithMask
    ) -> PaddedSequenceWithMask:
        """
        Parameters
        ----------
        previous_actions: PaddedSequenceWithMask
            The shape is (L, N, 3), where L is the sequence length and
            N is the batch size. [:, :, 0] represents the rule IDs,
            [:, :, 1] represents the token IDs, and [:, :, 2] represents
            the indexes of the queries. The padding value should be -1.

        Returns
        -------
        seq_embed: PaddedSequenceWithMask
            The shape is (L, N, embedding_dim), where L is the sequence length
            and N is the batch size.
        """
        L, N = previous_actions.data.shape[:2]

        rule_seq = previous_actions.data[:, :, 0]

        token_seq = previous_actions.data[:, :, 1]
        """
        # TODO this decreases the performance of CSG pbe significantly
        reference_seq = (token_seq == -1) * (previous_actions.data[:, :, 2] != -1)
        # reference_seq => self.token_num
        token_seq = token_seq + reference_seq * (self.n_token + 1)
        """

        embedding = self.rule_embed(rule_seq) + self.token_embed(token_seq)
        return PaddedSequenceWithMask(embedding, previous_actions.mask)
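
To make the (L, N, 3) encoding described in the docstring concrete, here is a small illustrative sketch of how such a tensor is laid out and sliced; the IDs are made up and no embedding modules are involved.

import torch

L, N = 4, 2
previous_actions = torch.full((L, N, 3), -1, dtype=torch.long)
previous_actions[0, :, 0] = 5   # step 0 applies rule 5
previous_actions[1, :, 1] = 7   # step 1 inserts token 7
previous_actions[2, :, 2] = 0   # step 2 copies word 0 from the query

rule_seq = previous_actions[:, :, 0]    # (L, N) rule IDs, -1 where unused
token_seq = previous_actions[:, :, 1]   # (L, N) token IDs, -1 where unused
print(rule_seq.shape, token_seq.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
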
Example 3
    def forward(
        self, previous_actions: PaddedSequenceWithMask,
        previous_action_rules: PaddedSequenceWithMask
    ) -> Tuple[PaddedSequenceWithMask, PaddedSequenceWithMask]:
        """
        Parameters
        ----------
        previous_actions: rnn.PaddedSequenceWithMask
            The previous action sequence.
            The encoded tensor with the shape of
            (len(action_sequence) + 1, 3). Each action will be encoded by
            the tuple of (ID of the applied rule, ID of the inserted token,
            the index of the word copied from the reference).
            The padding value should be -1.
        previous_action_rules: rnn.PaddedSequenceWithMask
            The rule of previous action sequence.
            The shape of each sequence is
            (action_length, max_arity + 1, 3).

        Returns
        -------
        e_action: PaddedSequenceWithMask
            (L_ast, N, embedding_size) where L_ast is the sequence length,
            N is the batch_size.
        e_rule_action: PaddedSequenceWithMask
            (L_ast, N, rule_embedding_size) where L_ast is the sequence length,
            N is the batch_size.
        """
        e_action = self.action_embed(previous_actions)
        e_rule_action = self.elem_embed(previous_action_rules.data)
        return (e_action,
                PaddedSequenceWithMask(e_rule_action,
                                       previous_action_rules.mask))
Example 4
    def forward(self,
                test_case_tensor: torch.Tensor,
                variables_tensor: PaddedSequenceWithMask,
                test_case_feature: torch.Tensor
                ) -> Tuple[PaddedSequenceWithMask, torch.Tensor]:
        # (B, N, c)
        processed_input = test_case_tensor
        # (L, B, N, c)
        variables = variables_tensor
        # (B, N, C)
        in_feature = test_case_feature

        B, N = in_feature.shape[:2]
        C = in_feature.shape[2:]

        if len(variables.data) != 0:
            # (L, B, N, c)
            processed_input = \
                processed_input.unsqueeze(0).expand(variables.data.shape)
        else:
            # (L, B, N, c)
            processed_input = torch.zeros_like(variables.data)
        # (L, B, N, 2c)
        f = torch.cat([processed_input, variables.data], dim=3)
        L = f.shape[0]
        # (L, B, N, C)
        vfeatures = self.module(f.reshape(L * B * N, *f.shape[3:]))
        vfeatures = vfeatures.reshape(L, B, N, *vfeatures.shape[1:])

        # reduce n_test_cases
        # (L, B, C)
        vfeatures = vfeatures.float().mean(dim=2)
        # (B, C)
        in_feature = in_feature.float().mean(dim=1)

        # Instantiate PaddedSequenceWithMask
        vmask = variables.mask
        for _ in range(len(C)):
            vmask = vmask.reshape(*vmask.shape, 1)
        features = PaddedSequenceWithMask(vfeatures * vmask, variables.mask)
        if features.data.numel() == 0:
            features.data = torch.zeros(0, B, *C, device=in_feature.device,
                                        dtype=in_feature.dtype)

        reduced_feature = features.data.sum(dim=0)  # reduce sequence length
        return features, torch.cat([in_feature, reduced_feature], dim=1)
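
The core of this forward is the broadcast of the shared test-case tensor over every variable, followed by a mean over the test-case axis. A shape-only sketch of that pattern, with arbitrary sizes:

import torch

L, B, N, c = 3, 2, 4, 5              # variables, batch, test cases, channels
test_cases = torch.randn(B, N, c)    # shared across variables
variables = torch.randn(L, B, N, c)  # one entry per variable

# Broadcast the test cases to every variable and pair them up.
expanded = test_cases.unsqueeze(0).expand(L, B, N, c)
f = torch.cat([expanded, variables], dim=3)  # (L, B, N, 2c)

# Reduce over the test-case axis, as mean(dim=2) does in the forward above.
per_variable = f.mean(dim=2)                 # (L, B, 2c)
print(per_variable.shape)                    # torch.Size([3, 2, 10])
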
Example 5
    def forward(self, input: PaddedSequenceWithMask,
                char_embed: torch.Tensor) -> \
            Tuple[PaddedSequenceWithMask, torch.Tensor]:
        """
        Parameters
        ----------
        input: PaddedSequenceWithMask
            (L, N, hidden_size) where L is the sequence length,
            N is the batch size.
        char_embed: torch.Tensor
            (L, N, char_embed_size) where L is the sequence length,
            N is the batch size.

        Returns
        -------
        output: PaddedSequenceWithMask
            (L, N, hidden_size) where L is the sequence length,
            N is the batch size.
        attn_weights: torch.Tensor
            (N, L, L) where N is the batch size and L is the sequence length.
        """
        L, N, _ = input.data.shape
        h_in = input.data
        h = h_in + index_embeddings(h_in, self.block_idx)
        h, attn = self.attention(h,
                                 h,
                                 h,
                                 key_padding_mask=input.mask.permute(1,
                                                                     0) == 0)
        h = h + h_in
        h = self.norm1(h)

        h_in = h
        h = self.gating(h, char_embed)
        h = self.dropout(h)
        h = h + h_in
        h = self.norm2(h)

        h_in = h
        h = h * input.mask.to(h.dtype).reshape(L, N, 1)
        h = lne_to_nel(h)
        h = self.conv1(h)
        h = self.dropout(h)
        h = gelu(h)
        h = h * input.mask.to(h.dtype).reshape(L, N, 1).permute(1, 2, 0)
        h = self.conv2(h)
        h = self.dropout(h)
        h = gelu(h)
        h = nel_to_lne(h)
        h = h + h_in
        h = self.norm3(h)

        return PaddedSequenceWithMask(h, input.mask), attn
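
The permute(1, 0) == 0 in the attention call converts the (L, N) mask into the (N, L) boolean key_padding_mask that PyTorch attention expects, where True marks positions to ignore. A minimal sketch with a plain nn.MultiheadAttention (the real self.attention module may differ):

import torch
import torch.nn as nn

L, N, H = 4, 2, 8
attn = nn.MultiheadAttention(embed_dim=H, num_heads=2)  # expects (L, N, H)

h = torch.randn(L, N, H)
mask = torch.tensor([[1, 1], [1, 1], [1, 0], [1, 0]])   # (L, N), 0 = padding

# key_padding_mask must be (N, L) with True at padded positions.
out, attn_weights = attn(h, h, h, key_padding_mask=mask.permute(1, 0) == 0)
print(out.shape, attn_weights.shape)  # torch.Size([4, 2, 8]) torch.Size([2, 4, 4])
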
Example 6
    def test_to(self) -> None:
        class X:
            def to(self, *args, **kwargs):
                self.args = (args, kwargs)
                return self

        e = Environment()
        e["key"] = X()
        e["x"] = torch.tensor(0)
        e["y"] = PaddedSequenceWithMask(torch.tensor(0.0), torch.tensor(True))
        e["z"] = 10
        e.to(device=torch.device("cpu"))
        assert e["key"].args == ((), {"device": torch.device("cpu")})
Example 7
 def forward(
     self, x: Union[torch.LongTensor, PaddedSequenceWithMask]
 ) -> Union[torch.Tensor, PaddedSequenceWithMask]:
     if isinstance(x, PaddedSequenceWithMask):
         data = x.data
     else:
         data = x
     y = torch.where(data == self.ignore_id, torch.zeros_like(data), data)
     embedding = super().forward(y)
     out = torch.where((data == self.ignore_id).unsqueeze(-1),
                       torch.zeros_like(embedding), embedding)
     if isinstance(x, PaddedSequenceWithMask):
         return PaddedSequenceWithMask(out, x.mask)
     else:
         return out
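
The same ignore-ID trick can be reproduced with a plain nn.Embedding: map the padding value to a valid index before the lookup, then zero out the resulting rows. A minimal sketch (sizes and the ignore_id value are illustrative):

import torch
import torch.nn as nn

ignore_id = -1
embed = nn.Embedding(num_embeddings=10, embedding_dim=4)

ids = torch.tensor([[1, 2, -1], [3, -1, -1]])  # -1 marks padding
safe_ids = torch.where(ids == ignore_id, torch.zeros_like(ids), ids)
out = embed(safe_ids)                          # (2, 3, 4)
out = torch.where((ids == ignore_id).unsqueeze(-1),
                  torch.zeros_like(out), out)  # zero the padded rows
print(out[0, 2].abs().sum())                   # tensor(0., grad_fn=...)
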
Example 8
 def forward(self, env):
     prev_mode = self.training
     self.eval()
     with torch.no_grad():
         sequence_output, pooled_output = self.model(
             input_ids=env.states["input_ids"],
             attention_mask=env.states["input_mask"],
             token_type_ids=env.states["segment_ids"])
         padded_sequence_output = PaddedSequenceWithMask(
             sequence_output.permute(1, 0, 2),
             env.states["input_mask"].permute(1, 0))
         # TODO should we use pooled_output?
         env.states["nl_query_features"] = padded_sequence_output
         env.states["reference_features"] = padded_sequence_output
     self.train(mode=prev_mode)
     return env
Example 9
    def forward(
        self,
        actions: PaddedSequenceWithMask,
        previous_actions: PaddedSequenceWithMask,
    ) -> PaddedSequenceWithMask:
        """
        Parameters
        ----------
        actions: rnn.PaddedSequenceWithMask
            The input sequence of actions. Each action is represented by
            the tuple of (ID of the node types, ID of the parent-action's
            rule, the index of the parent action).
            The padding value should be -1.
        previous_actions: rnn.PaddedSequenceWithMask
            The input sequence of previous actions. Each action is
            represented by the tuple of (ID of the applied rule, ID of
            the inserted token, the index of the word copied from
            the reference).
            The padding value should be -1.

        Returns
        -------
        action_features: PaddedSequenceWithMask
            Padded sequence containing the output hidden states.
        """
        L_a, B, _ = actions.data.shape

        node_types, parent_rule, parent_index = torch.split(
            actions.data, 1, dim=2)  # (L_a, B, 1)
        node_types = node_types.reshape([L_a, B])
        parent_rule = parent_rule.reshape([L_a, B])

        # Embed previous actions
        prev_action_embed = self.previous_actions_embed(previous_actions).data
        # Embed action
        node_type_embed = self.node_type_embed(node_types)
        parent_rule_embed = self.previous_actions_embed.rule_embed(parent_rule)

        # Decode embeddings
        feature = torch.cat(
            [prev_action_embed, node_type_embed, parent_rule_embed],
            dim=2)  # (L_a, B, input_size)
        return PaddedSequenceWithMask(feature, actions.mask)
Example 10
    def forward(
        self,
        action_queries: PaddedSequenceWithMask,
    ) -> PaddedSequenceWithMask:
        """
        Parameters
        ----------
        action_queries: PaddedSequenceWithMask
            (L_ast, N, max_depth) where L_ast is the sequence length,
            N is the batch size.
            This tensor encodes the path from the root node to the target node.
            The padding value should be -1.

        Returns
        -------
        action_query_features: PaddedSequenceWithMask
            (L_ast, N, hidden_size) where L_ast is the sequence length,
            N is the batch_size.
        """
        embed = self.query_embed(action_queries.data)
        return PaddedSequenceWithMask(embed, action_queries.mask)
Example 11
    def forward(
        self, reference_features: PaddedSequenceWithMask,
        action_features: PaddedSequenceWithMask,
        action_contexts: PaddedSequenceWithMask
    ) -> Tuple[PaddedSequenceWithMask, PaddedSequenceWithMask,
               PaddedSequenceWithMask]:
        """
        Parameters
        ----------
        reference_features: PaddedSequenceWithMask
            (L_nl, N, nl_feature_size) where L_nl is the sequence length,
            N is the batch size.
        action_features: PaddedSequenceWithMask
            Padded sequence containing the output hidden states.
        action_contexts: PaddedSequenceWithMask
            Padded sequence containing the context vectors.

        Returns
        -------
        rule_probs: PaddedSequenceWithMask
            (L_ast, N, rule_size) where L_ast is the sequence length,
            N is the batch_size.
        token_probs: PaddedSequenceWithMask
            (L_ast, N, token_size) where L_ast is the sequence length,
            N is the batch_size.
        reference_probs: PaddedSequenceWithMask
            (L_ast, N, L_nl) where L_ast is the sequence length,
            N is the batch_size.
        """
        L_q, B, _ = reference_features.data.shape

        # Decode embeddings
        # (L_a, B, hidden_size + query_size)
        dc = torch.cat([action_features.data, action_contexts.data], dim=2)

        # Calculate probabilities
        # (L_a, B, embedding_size)
        rule_pred = torch.tanh(self._l_rule(action_features.data))
        rule_pred = self._rule_embed_inv(rule_pred,
                                         self.embedding.previous_actions_embed.
                                         rule_embed)  # (L_a, B, num_rules)
        rule_pred = torch.softmax(rule_pred, dim=2)  # (L_a, B, num_rules)

        token_pred = torch.tanh(self._l_token(dc))  # (L_a, B, embedding_size)
        token_pred = self._token_embed_inv(
            token_pred, self.embedding.previous_actions_embed.token_embed
        )  # (L_a, B, num_tokens)
        # last index represents reference (copy)
        token_pred = torch.softmax(token_pred[:, :, :-1],
                                   dim=2)  # (L_a, B, num_tokens)

        # (L_a, B, query_length)
        reference_pred = self._pointer_net(dc, reference_features)
        reference_pred = torch.exp(reference_pred)
        reference_pred = reference_pred * \
            reference_features.mask.permute(1, 0).view(1, B, L_q)\
            .to(reference_pred.dtype)

        generate_pred = torch.softmax(self._l_generate(action_features.data),
                                      dim=2)  # (L_a, B, 2)
        rule, token, reference = \
            torch.split(generate_pred, 1, dim=2)  # (L_a, B, 1)

        rule_pred = rule * rule_pred
        token_pred = token * token_pred  # (L_a, B, num_tokens)
        reference_pred = reference * reference_pred  # (L_a, B, query_length)

        if self.training:
            rule_probs = PaddedSequenceWithMask(rule_pred,
                                                action_features.mask)
            token_probs = PaddedSequenceWithMask(token_pred,
                                                 action_features.mask)
            reference_probs = PaddedSequenceWithMask(reference_pred,
                                                     action_features.mask)
        else:
            rule_probs = rule_pred[-1, :, :]
            token_probs = token_pred[-1, :, :]
            reference_probs = reference_pred[-1, :, :]
        return rule_probs, token_probs, reference_probs
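
The generate gate is a 3-way softmax over (rule, token, copy); multiplying each per-category distribution by its gate weight yields one joint distribution over all outputs. A small numeric sketch (all sizes are arbitrary):

import torch

L_a, B = 2, 1
rule_pred = torch.softmax(torch.randn(L_a, B, 5), dim=2)       # num_rules = 5
token_pred = torch.softmax(torch.randn(L_a, B, 7), dim=2)      # num_tokens = 7
reference_pred = torch.softmax(torch.randn(L_a, B, 4), dim=2)  # query_length = 4

generate_pred = torch.softmax(torch.randn(L_a, B, 3), dim=2)   # (L_a, B, 3)
rule, token, reference = torch.split(generate_pred, 1, dim=2)  # each (L_a, B, 1)

joint = torch.cat([rule * rule_pred,
                   token * token_pred,
                   reference * reference_pred], dim=2)
print(joint.sum(dim=2))  # ~1.0 at every (L_a, B) position
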
Example 12
    def forward(self, input: PaddedSequenceWithMask,
                depth: torch.Tensor,
                rule_embed: torch.Tensor,
                adjacency_matrix: torch.Tensor) -> \
            Tuple[PaddedSequenceWithMask, torch.Tensor]:
        """
        Parameters
        ----------
        input: PaddedSequenceWithMask
            (L, N, hidden_size) where L is the sequence length,
            N is the batch size.
        depth: torch.Tensor
            (L, N) where L is the sequence length,
            N is the batch size.
        rule_embed: torch.Tensor
            (L, N, rule_embed_size) where L is the sequence length,
            N is the batch size.
        adjacency_matrix: torch.Tensor
            (N, L, L) where N is the batch size, L is the sequence length.

        Returns
        -------
        output: PaddedSequenceWithMask
            (L, N, hidden_size) where L is the sequence length,
            N is the batch size.
        attn_weights: torch.Tensor
            (N, L, L) where N is the batch size and L is the sequence length.
        """
        L, N, hidden_size = input.data.shape
        device = input.data.device
        h_in = input.data
        h = h_in + \
            index_embeddings(h_in, self.block_idx) + \
            position_embeddings(depth, self.block_idx, hidden_size)
        attn_mask = \
            torch.nn.Transformer.generate_square_subsequent_mask(None, L)\
            .to(device=device)
        h, attn = self.attention(h,
                                 h,
                                 h,
                                 key_padding_mask=input.mask.permute(1,
                                                                     0) == 0,
                                 attn_mask=attn_mask)
        h = h + h_in
        h = self.norm1(h)

        h_in = h
        h = self.gating(h, rule_embed)
        h = self.dropout(h)
        h = h + h_in
        h = self.norm2(h)

        h_in = h
        h = h * input.mask.to(h.dtype).reshape(L, N, 1)
        h = lne_to_nel(h)
        h = self.conv1(h, adjacency_matrix)
        h = self.dropout(h)
        h = gelu(h)
        h = h * input.mask.to(h.dtype).reshape(L, N, 1).permute(1, 2, 0)
        h = self.conv2(h, adjacency_matrix)
        h = self.dropout(h)
        h = gelu(h)
        h = nel_to_lne(h)
        h = h + h_in
        h = self.norm3(h)

        return PaddedSequenceWithMask(h, input.mask), attn
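
The attn_mask built here is the standard causal (square subsequent) mask. The code calls the generator as an unbound method with None for self, which older PyTorch versions required; in recent releases it is exposed as a static method, so it can be called directly with the size. A quick sketch of what the mask looks like:

import torch

L = 4
# In recent PyTorch, generate_square_subsequent_mask is a static method.
attn_mask = torch.nn.Transformer.generate_square_subsequent_mask(L)
print(attn_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
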
Example 13
    def forward(self, query: PaddedSequenceWithMask,
                nl_feature: PaddedSequenceWithMask,
                ast_feature: PaddedSequenceWithMask) -> \
            Tuple[PaddedSequenceWithMask, torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        query: PaddedSequenceWithMask
            (L_ast, N, query_size) where L_ast is the sequence length,
            N is the batch size.
        nl_feature: PaddedSequenceWithMask
            (L_nl, N, nl_feature_size) where L_nl is the sequence length,
            N is the batch size.
        ast_feature: PaddedSequenceWithMask
            (L_ast, N, ast_feature_size) where L_ast is the sequence length,
            N is the batch size.

        Returns
        -------
        output: PaddedSequenceWithMask
            (L_ast, N, out_size) where L_ast is the sequence length,
            N is the batch_size.
        nl_attn_weights: torch.Tensor
            (N, L_nl, L_ast) where N is the batch size,
            L_nl is the sequence length of NL query,
            L_ast is the ast sequence length.
        ast_attn_weights: torch.Tensor
            (N, L_ast, L_ast) where N is the batch size,
            L_ast is the sequence length of ast.
        """
        L_ast, N, _ = query.data.shape
        device = query.data.device
        attn_mask = \
            torch.nn.Transformer.generate_square_subsequent_mask(None, L_ast)\
            .to(device=device)
        h_in = query.data
        h, ast_attn = self.ast_attention(
            key=ast_feature.data,
            query=h_in,
            value=ast_feature.data,
            key_padding_mask=ast_feature.mask.permute(1, 0) == 0,
            attn_mask=attn_mask)
        h = h + h_in
        h = self.norm1(h)

        h_in = h
        h, nl_attn = self.nl_attention(
            key=nl_feature.data,
            query=h,
            value=nl_feature.data,
            key_padding_mask=nl_feature.mask.permute(1, 0) == 0)
        h = h + h_in
        h = self.norm2(h)

        h_in = h
        h = h * query.mask.to(h.dtype).reshape(L_ast, N, 1)
        h = self.fc1(h.view(L_ast * N, -1))
        h = self.dropout(h)
        h = gelu(h)
        h = self.fc2(h)
        h = self.dropout(h)
        h = h.view(L_ast, N, -1)
        h = h * query.mask.to(h.dtype).reshape(L_ast, N, 1)
        h = h + h_in
        h = self.norm3(h)

        return PaddedSequenceWithMask(h, query.mask), nl_attn, ast_attn