Example #1
    def forward(self, pattern, pattern_len, graph, graph_len):
        bsz = pattern_len.size(0)

        # Vertex-label filter: graph positions whose gate is exactly 0 are
        # masked out of every graph representation below.
        gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len)
        zero_mask = (gate == 0).unsqueeze(-1) if gate is not None else None
        pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len)
        if zero_mask is not None:
            graph_emb.masked_fill_(zero_mask, 0.0)

        pattern_output = pattern_emb
        for p_rnn in self.p_net:
            o = p_rnn(pattern_output)
            pattern_output = o + pattern_output  # residual connection
        pattern_mask = (batch_convert_len_to_mask(pattern_len) == 0).unsqueeze(-1)
        pattern_output.masked_fill_(pattern_mask, 0.0)

        graph_output = graph_emb
        for g_rnn in self.g_net:
            o = g_rnn(graph_output)
            graph_output = o + graph_output  # residual connection
            if zero_mask is not None:
                graph_output.masked_fill_(zero_mask, 0.0)
        graph_mask = (batch_convert_len_to_mask(graph_len) == 0).unsqueeze(-1)
        graph_output.masked_fill_(graph_mask, 0.0)
        
        if self.add_enc:
            # Optionally concatenate the raw encodings with the RNN outputs.
            pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len)
            if zero_mask is not None:
                graph_enc.masked_fill_(zero_mask, 0.0)
            pattern_output = torch.cat([pattern_enc, pattern_output], dim=2)
            graph_output = torch.cat([graph_enc, graph_output], dim=2)
        
        pred = self.predict_net(pattern_output, pattern_len, graph_output, graph_len)
        
        return pred
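
All of these forward passes lean on batch_convert_len_to_mask to turn a length tensor into a padding mask; judging from the `mask == 0` comparisons above, valid positions are non-zero. A minimal sketch of such a helper, assuming that convention (the actual implementation in the source codebase may differ):

import torch

def batch_convert_len_to_mask(lens, max_seq_len=None):
    # Illustrative only: (bsz, max_len) boolean mask, True at valid positions
    # and False at padding, so that `mask == 0` selects the padding.
    if max_seq_len is None:
        max_seq_len = int(lens.max())
    positions = torch.arange(max_seq_len, device=lens.device).unsqueeze(0)
    return positions < lens.view(-1, 1)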
Example #2
    def forward(self, pattern, pattern_len, graph, graph_len):
        bsz = pattern_len.size(0)
        p_len, g_len = pattern.size(1), graph.size(1)
        plf, glf = pattern_len.float(), graph_len.float()
        inv_plf, inv_glf = 1.0 / plf, 1.0 / glf
        p_mask = batch_convert_len_to_mask(pattern_len) if p_len == pattern_len.max() else None
        g_mask = batch_convert_len_to_mask(graph_len) if g_len == graph_len.max() else None

        p, g = pattern, graph
        # init_mem expects self.m_layer under a keyword that matches the
        # chosen initializer (attention-, LSTM-, or projection-based).
        if self.mem_init.endswith("attn"):
            m_kwargs = {"attn": self.m_layer}
        elif self.mem_init.endswith("lstm"):
            m_kwargs = {"lstm": self.m_layer}
        else:
            m_kwargs = {"post_proj": self.m_layer}
        if g_mask is not None:
            # Graphs differ in length: initialize memory per group of graphs
            # that share a length, then re-assemble along the batch dimension.
            mem = list()
            mem_mask = list()
            for idx in gather_indices_by_lens(graph_len):
                m, mk = init_mem(g[idx, :graph_len[idx[0]]], g_mask[idx, :graph_len[idx[0]]],
                    mem_len=self.mem_len, mem_init=self.mem_init, **m_kwargs)
                mem.append(m)
                mem_mask.append(mk)
            mem = torch.cat(mem, dim=0)
            mem_mask = torch.cat(mem_mask, dim=0)
        else:
            # All graphs share the same length, so no mask is needed.
            mem, mem_mask = init_mem(g, None,
                mem_len=self.mem_len, mem_init=self.mem_init, **m_kwargs)
        for i in range(self.recurrent_steps):
            # Alternately read from the pattern and the graph to update memory.
            mem = self.p_attn(mem, p, p, p_mask)
            mem = self.g_attn(mem, g, g, g_mask)

        mem = mem.view(bsz, -1)
        # (Inverse) length statistics are appended as extra scalar features.
        y = self.pred_layer1(torch.cat([mem, plf, glf, inv_plf, inv_glf], dim=1))
        y = self.act(y)
        y = self.pred_layer2(torch.cat([y, plf, glf, inv_plf, inv_glf], dim=1))

        return y
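
The grouped initialization above only works if gather_indices_by_lens yields groups of batch indices that share a single graph length, since graph_len[idx[0]] is used to slice the whole group. A hypothetical sketch under that assumption:

def gather_indices_by_lens(lens):
    # Illustrative only: bucket batch indices by length value so that every
    # group can be sliced with one common length.
    groups = {}
    for i, l in enumerate(lens.view(-1).tolist()):
        groups.setdefault(l, []).append(i)
    return list(groups.values())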
Example #3
    def forward(self, pattern, pattern_len, graph, graph_len):
        bsz = pattern_len.size(0)
        p_len, g_len = pattern.size(1), graph.size(1)
        plf, glf = pattern_len.float(), graph_len.float()
        inv_plf, inv_glf = 1.0 / plf, 1.0 / glf
        p_mask = batch_convert_len_to_mask(pattern_len) if p_len == pattern_len.max() else None
        g_mask = batch_convert_len_to_mask(graph_len) if g_len == graph_len.max() else None

        p, g = pattern, graph
        for i in range(self.recurrent_steps):
            # Cross-attention from graph to pattern, then graph self-attention.
            g = self.p_attn(g, p, p_mask)
            g = self.g_attn(g, g, g_mask)

        # Max-pool the pattern over its nodes, then project both representations.
        p = self.drop(self.p_layer(torch.max(p, dim=1, keepdim=True)[0]))
        g = self.drop(self.g_layer(g))

        p = p.squeeze(1)
        g = torch.max(g, dim=1)[0]
        # Heuristic matching features: raw vectors, their difference, their
        # elementwise product, and (inverse) length statistics.
        y = self.pred_layer1(torch.cat([p, g, g - p, g * p, plf, glf, inv_plf, inv_glf], dim=1))
        y = self.act(y)
        y = self.pred_layer2(torch.cat([y, plf, glf, inv_plf, inv_glf], dim=1))

        return y
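
For the concatenations above to type-check, the two prediction layers must budget for the four matching blocks plus the four scalar length features, which implies pattern_len and graph_len are (bsz, 1) tensors. A sketch of compatible layer shapes with a hypothetical hidden size (hidden_dim is an assumption, not taken from the source):

import torch.nn as nn

hidden_dim = 128  # assumed hidden size, for illustration only
pred_layer1 = nn.Linear(4 * hidden_dim + 4, hidden_dim)  # [p, g, g - p, g * p] + 4 scalars
act = nn.ReLU()
pred_layer2 = nn.Linear(hidden_dim + 4, 1)  # [y] + the same 4 scalars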
Example #4
    def encoder_forward(self,
                        enc_inp,
                        enc_len,
                        enc_txl,
                        enc_params,
                        mems=None):
        qlen = enc_inp.size(1)
        # Transformer-XL style: prepend the cached memory length so that the
        # attention mask covers both the memory and the current segment.
        mlen = mems[0].size(1) if mems is not None else 0
        enc_attn_mask = batch_convert_len_to_mask(enc_len + mlen,
                                                  max_seq_len=qlen + mlen)

        return self._forward(enc_inp,
                             enc_len,
                             enc_txl,
                             enc_params,
                             attn_mask=enc_attn_mask,
                             mems=mems)
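
encoder_forward consumes per-layer memories whose second dimension is the memory length (mems[0].size(1)). The comment in Example #6 below notes that size-0 mems must be created inside forward for nn.DataParallel; a plausible method sketch under those assumptions:

def init_mems(self, n_layers, sample):
    # Illustrative only: one zero-length memory per layer, allocated with the
    # device and dtype of the current segment so nn.DataParallel can scatter it.
    return [sample.new_zeros(sample.size(0), 0, sample.size(-1))
            for _ in range(n_layers)]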
Example #5
    def get_filter_gate(self, pattern, pattern_len, graph, graph_len):
        gate = None
        if self.vl_flt is not None:
            # Compare pattern and graph vertex labels to decide which graph
            # vertices can possibly participate in a match.
            gate = self.vl_flt(
                split_and_batchify_graph_feats(pattern.ndata["label"].unsqueeze(-1), pattern_len)[0],
                split_and_batchify_graph_feats(graph.ndata["label"].unsqueeze(-1), graph_len)[0])
        if gate is not None:
            bsz = graph_len.size(0)
            max_g_len = graph_len.max()
            if bsz * max_g_len != graph.number_of_nodes():
                # Padded batch: drop gate values at padding positions so the
                # result aligns with the flat (total_nodes, 1) node layout.
                graph_mask = batch_convert_len_to_mask(graph_len)  # bsz x max_len
                gate = gate.masked_select(graph_mask.unsqueeze(-1)).view(-1, 1)
            else:
                gate = gate.view(-1, 1)
        return gate
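
get_filter_gate indexes [0] into the result of split_and_batchify_graph_feats, which suggests the helper returns a padded (bsz, max_len, dim) feature tensor plus something else, likely a mask. A hypothetical sketch under that reading, reusing the mask helper sketched after Example #1:

import torch
from torch.nn.utils.rnn import pad_sequence

def split_and_batchify_graph_feats(feats, lens):
    # Illustrative only: split the flat (total_nodes, dim) node features of a
    # batched graph by per-graph node counts and pad to (bsz, max_len, dim).
    chunks = torch.split(feats, lens.view(-1).tolist(), dim=0)
    return pad_sequence(chunks, batch_first=True), batch_convert_len_to_mask(lens)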
Example #6
    def forward(self, pattern, pattern_len, graph, graph_len):
        # data, target, *mems
        # nn.DataParallel does not allow size(0) tensors to be broadcast, so we
        # have to initialize size(0) mems inside the model forward. Moreover,
        # we have to return new_mems so that nn.DataParallel can piece them
        # together.
        bsz = pattern_len.size(0)

        gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len)
        zero_mask = (gate == 0).unsqueeze(-1) if gate is not None else None
        pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph,
                                              graph_len)
        if zero_mask is not None:
            graph_emb.masked_fill_(zero_mask, 0.0)

        pattern_emb = self.p_emb_proj(pattern_emb).mul_(self.emb_scale)
        graph_emb = self.g_emb_proj(graph_emb).mul_(self.emb_scale)

        # Split both sequences into fixed-length segments (Transformer-XL
        # style), together with the per-segment valid lengths.
        pattern_segments = segment_data(pattern_emb, self.tgt_len)
        pattern_seg_lens = segment_length(pattern_len, self.tgt_len)
        graph_segments = segment_data(graph_emb, self.tgt_len)
        graph_seg_lens = segment_length(graph_len, self.tgt_len)

        pattern_outputs = list()
        for i, (pattern_seg, pattern_seg_len) in enumerate(
                zip(pattern_segments, pattern_seg_lens)):
            if i == 0:
                pattern_mems = self.init_mems(len(self.p_net), pattern_seg)
            pattern_output, pattern_mems = self.encoder_forward(
                pattern_seg,
                pattern_seg_len,
                self.p_net,
                self.p_params,
                mems=pattern_mems)
            pattern_outputs.append(pattern_output)
        pattern_output = torch.cat(pattern_outputs,
                                   dim=1)[:, :pattern_emb.size(1)]
        # some segments may contain only padded elements, so zero them out manually
        pattern_mask = (batch_convert_len_to_mask(
            pattern_len,
            max_seq_len=pattern_output.size(1)) == 0).unsqueeze(-1)
        pattern_output.masked_fill_(pattern_mask, 0.0)

        graph_outputs = list()
        for i, (graph_seg,
                graph_seg_len) in enumerate(zip(graph_segments,
                                                graph_seg_lens)):
            if i == 0:
                graph_mems = self.init_mems(len(self.g_net), graph_seg)
            graph_output, graph_mems = self.encoder_forward(graph_seg,
                                                            graph_seg_len,
                                                            self.g_net,
                                                            self.g_params,
                                                            mems=graph_mems)
            graph_outputs.append(graph_output)
        graph_output = torch.cat(graph_outputs, dim=1)[:, :graph_emb.size(1)]
        # some segments may contain only padded elements, so zero them out manually
        graph_mask = (batch_convert_len_to_mask(
            graph_len, max_seq_len=graph_output.size(1)) == 0).unsqueeze(-1)
        graph_output.masked_fill_(graph_mask, 0.0)

        if self.add_enc:
            pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph,
                                                  graph_len)
            if zero_mask is not None:
                graph_enc.masked_fill_(zero_mask, 0.0)
            pattern_output = torch.cat([pattern_enc, pattern_output], dim=2)
            graph_output = torch.cat([graph_enc, graph_output], dim=2)

        pred = self.predict_net(pattern_output, pattern_len, graph_output,
                                graph_len)

        return pred
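
The segment loops above assume segment_data and segment_length return aligned lists: fixed-size chunks of the embedded sequence, and the number of valid positions inside each chunk. A minimal sketch under that assumption (and assuming the padded sequence length equals lens.max()):

import math
import torch

def segment_data(x, seg_len):
    # Illustrative only: pad the sequence dimension up to a multiple of
    # seg_len, then split into (bsz, seg_len, dim) chunks.
    bsz, seq_len, dim = x.size()
    pad = (seg_len - seq_len % seg_len) % seg_len
    if pad > 0:
        x = torch.cat([x, x.new_zeros(bsz, pad, dim)], dim=1)
    return list(torch.split(x, seg_len, dim=1))

def segment_length(lens, seg_len):
    # Illustrative companion: valid length of each segment, clamped into
    # [0, seg_len]; segments past a sequence's end get length 0.
    n_segments = math.ceil(int(lens.max()) / seg_len)
    return [torch.clamp(lens - i * seg_len, 0, seg_len) for i in range(n_segments)]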