def forward(self, pattern, pattern_len, graph, graph_len):
    """Encode pattern and graph with residual RNN stacks and predict.

    Both branches run their embeddings through a stack of recurrent
    layers with additive residual connections, zero out padded (and
    gate-filtered) positions, optionally prepend fixed encodings, and
    hand the results to ``self.predict_net``.
    """
    bsz = pattern_len.size(0)

    # Optional vertex-label filter: graph positions whose gate is 0 are
    # zeroed in every graph-side representation below.
    gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len)
    zero_mask = None if gate is None else (gate == 0).unsqueeze(-1)

    pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len)
    if zero_mask is not None:
        graph_emb.masked_fill_(zero_mask, 0.0)

    # Pattern branch: residual recurrent layers, then zero padded steps.
    p_out = pattern_emb
    for layer in self.p_net:
        p_out = layer(p_out) + p_out
    p_pad = (batch_convert_len_to_mask(pattern_len) == 0).unsqueeze(-1)
    p_out.masked_fill_(p_pad, 0.0)

    # Graph branch: same residual stack; re-apply the filter gate before
    # zeroing padded steps.
    g_out = graph_emb
    for layer in self.g_net:
        g_out = layer(g_out) + g_out
    if zero_mask is not None:
        g_out.masked_fill_(zero_mask, 0.0)
    g_pad = (batch_convert_len_to_mask(graph_len) == 0).unsqueeze(-1)
    g_out.masked_fill_(g_pad, 0.0)

    # Optionally concatenate fixed encodings along the feature dimension.
    if self.add_enc:
        pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len)
        if zero_mask is not None:
            graph_enc.masked_fill_(zero_mask, 0.0)
        p_out = torch.cat([pattern_enc, p_out], dim=2)
        g_out = torch.cat([graph_enc, g_out], dim=2)

    return self.predict_net(p_out, pattern_len, g_out, graph_len)
def forward(self, pattern, pattern_len, graph, graph_len):
    """Memory-based interaction: initialize a memory from the graph,
    alternately attend to pattern and graph, then predict from the
    flattened memory plus length features.

    Bug fix: the ``g_mask is None`` branch previously passed an
    undefined name ``keyvalue`` to ``init_mem`` (a guaranteed
    NameError); it now uses ``g``, mirroring the masked branch.
    """
    bsz = pattern_len.size(0)
    p_len, g_len = pattern.size(1), graph.size(1)
    plf, glf = pattern_len.float(), graph_len.float()
    inv_plf, inv_glf = 1.0 / plf, 1.0 / glf

    # Masks are only built when the padded size matches the batch max
    # length — presumably the helper requires that; TODO confirm.
    p_mask = batch_convert_len_to_mask(pattern_len) if p_len == pattern_len.max() else None
    g_mask = batch_convert_len_to_mask(graph_len) if g_len == graph_len.max() else None
    p, g = pattern, graph

    if g_mask is not None:
        # Initialize the memory per group of equal-length graphs so that
        # trailing padding can be sliced away before init_mem.
        mem = list()
        mem_mask = list()
        for idx in gather_indices_by_lens(graph_len):
            seg = g[idx, :graph_len[idx[0]]]
            seg_mask = g_mask[idx, :graph_len[idx[0]]]
            if self.mem_init.endswith("attn"):
                m, mk = init_mem(seg, seg_mask, mem_len=self.mem_len, mem_init=self.mem_init, attn=self.m_layer)
            elif self.mem_init.endswith("lstm"):
                m, mk = init_mem(seg, seg_mask, mem_len=self.mem_len, mem_init=self.mem_init, lstm=self.m_layer)
            else:
                m, mk = init_mem(seg, seg_mask, mem_len=self.mem_len, mem_init=self.mem_init, post_proj=self.m_layer)
            mem.append(m)
            mem_mask.append(mk)
        mem = torch.cat(mem, dim=0)
        mem_mask = torch.cat(mem_mask, dim=0)
    else:
        # FIX: was init_mem(keyvalue, ...) with `keyvalue` undefined.
        if self.mem_init.endswith("attn"):
            mem, mem_mask = init_mem(g, None, mem_len=self.mem_len, mem_init=self.mem_init, attn=self.m_layer)
        elif self.mem_init.endswith("lstm"):
            mem, mem_mask = init_mem(g, None, mem_len=self.mem_len, mem_init=self.mem_init, lstm=self.m_layer)
        else:
            mem, mem_mask = init_mem(g, None, mem_len=self.mem_len, mem_init=self.mem_init, post_proj=self.m_layer)

    # Alternate attention over pattern and graph for a fixed number of steps.
    for i in range(self.recurrent_steps):
        mem = self.p_attn(mem, p, p, p_mask)
        mem = self.g_attn(mem, g, g, g_mask)

    # Flatten the memory and append raw/inverse length features.
    mem = mem.view(bsz, -1)
    y = self.pred_layer1(torch.cat([mem, plf, glf, inv_plf, inv_glf], dim=1))
    y = self.act(y)
    y = self.pred_layer2(torch.cat([y, plf, glf, inv_plf, inv_glf], dim=1))
    return y
def forward(self, pattern, pattern_len, graph, graph_len):
    """Cross- and self-attention between pattern and graph, max-pooling,
    then a two-layer MLP over combined features to produce the prediction.
    """
    bsz = pattern_len.size(0)
    p_len, g_len = pattern.size(1), graph.size(1)
    plf, glf = pattern_len.float(), graph_len.float()
    inv_plf, inv_glf = 1.0 / plf, 1.0 / glf

    # Masks are only built when the padded size equals the batch max length.
    p_mask = batch_convert_len_to_mask(pattern_len) if p_len == pattern_len.max() else None
    g_mask = batch_convert_len_to_mask(graph_len) if g_len == graph_len.max() else None

    p_feat, g_feat = pattern, graph
    # Let the graph attend to the pattern, then to itself, repeatedly.
    for _ in range(self.recurrent_steps):
        g_feat = self.p_attn(g_feat, p_feat, p_mask)
        g_feat = self.g_attn(g_feat, g_feat, g_mask)

    # Max-pool the pattern first, project, then drop; graph is projected
    # per position and pooled afterwards (matching the original order).
    p_pooled = torch.max(p_feat, dim=1, keepdim=True)[0]
    p_feat = self.drop(self.p_layer(p_pooled)).squeeze(1)
    g_feat = self.drop(self.g_layer(g_feat))
    g_feat = torch.max(g_feat, dim=1)[0]

    # Combine pooled features with difference/product interactions and
    # raw/inverse length features.
    feats = torch.cat(
        [p_feat, g_feat, g_feat - p_feat, g_feat * p_feat, plf, glf, inv_plf, inv_glf],
        dim=1)
    y = self.act(self.pred_layer1(feats))
    y = self.pred_layer2(torch.cat([y, plf, glf, inv_plf, inv_glf], dim=1))
    return y
def encoder_forward(self, enc_inp, enc_len, enc_txl, enc_params, mems=None):
    """Run one encoder pass of the segment-recurrent (Transformer-XL style)
    stack, extending the attention mask to cover the recurrent memory.

    Robustness fix: the memory length is now derived with a truthiness
    check — the original ``mems is not None`` test raised IndexError on
    ``mems[0]`` when ``mems`` was an empty list.
    """
    qlen = enc_inp.size(1)
    # Memory length along dim 1 of the first layer's memory; 0 when no
    # memory is present (None or empty).
    mlen = mems[0].size(1) if mems else 0
    # Mask spans both the memory and the current segment.
    enc_attn_mask = batch_convert_len_to_mask(enc_len + mlen, max_seq_len=qlen + mlen)
    return self._forward(enc_inp, enc_len, enc_txl, enc_params,
                         attn_mask=enc_attn_mask, mems=mems)
def get_filter_gate(self, pattern, pattern_len, graph, graph_len):
    """Compute the vertex-label filter gate, flattened to one value per
    graph node, or None when no filter module is configured.
    """
    if self.vl_flt is None:
        return None

    # Batchify node labels for pattern and graph, then run the filter.
    p_labels = split_and_batchify_graph_feats(
        pattern.ndata["label"].unsqueeze(-1), pattern_len)[0]
    g_labels = split_and_batchify_graph_feats(
        graph.ndata["label"].unsqueeze(-1), graph_len)[0]
    gate = self.vl_flt(p_labels, g_labels)
    if gate is None:
        return None

    bsz = graph_len.size(0)
    max_g_len = graph_len.max()
    if bsz * max_g_len != graph.number_of_nodes():
        # The batch is padded: keep only entries for real nodes.
        graph_mask = batch_convert_len_to_mask(graph_len)  # bsz x max_len
        return gate.masked_select(graph_mask.unsqueeze(-1)).view(-1, 1)
    return gate.view(-1, 1)
def forward(self, pattern, pattern_len, graph, graph_len):
    # data, target, *mems
    # nn.DataParallel does not allow size(0) tensors to be broadcasted.
    # So, have to initialize size(0) mems inside the model forward.
    # Moreover, have to return new_mems to allow nn.DataParallel to piece
    # them together.
    """Segment-recurrent (Transformer-XL style) encoding of pattern and
    graph embeddings, followed by the prediction network.
    """
    bsz = pattern_len.size(0)

    # Optional vertex-label filter: gated graph positions are zeroed in
    # the embeddings (and encodings, if used).
    gate = self.get_filter_gate(pattern, pattern_len, graph, graph_len)
    zero_mask = None if gate is None else (gate == 0).unsqueeze(-1)

    pattern_emb, graph_emb = self.get_emb(pattern, pattern_len, graph, graph_len)
    if zero_mask is not None:
        graph_emb.masked_fill_(zero_mask, 0.0)

    # Project and scale embeddings before the transformer stacks.
    pattern_emb = self.p_emb_proj(pattern_emb).mul_(self.emb_scale)
    graph_emb = self.g_emb_proj(graph_emb).mul_(self.emb_scale)

    def _encode_segments(emb, full_len, net, params):
        """Encode tgt_len-sized segments with recurrent memory, then
        re-assemble, truncate to the original length, and zero padding."""
        segments = segment_data(emb, self.tgt_len)
        seg_lens = segment_length(full_len, self.tgt_len)
        outputs = list()
        for step, (seg, seg_len) in enumerate(zip(segments, seg_lens)):
            if step == 0:
                mems = self.init_mems(len(net), seg)
            out, mems = self.encoder_forward(seg, seg_len, net, params, mems=mems)
            outputs.append(out)
        joined = torch.cat(outputs, dim=1)[:, :emb.size(1)]
        # some segments may only have padded elements, we need to set
        # them as 0 manually
        pad = (batch_convert_len_to_mask(
            full_len, max_seq_len=joined.size(1)) == 0).unsqueeze(-1)
        joined.masked_fill_(pad, 0.0)
        return joined

    pattern_output = _encode_segments(pattern_emb, pattern_len, self.p_net, self.p_params)
    graph_output = _encode_segments(graph_emb, graph_len, self.g_net, self.g_params)

    # Optionally prepend fixed encodings along the feature dimension.
    if self.add_enc:
        pattern_enc, graph_enc = self.get_enc(pattern, pattern_len, graph, graph_len)
        if zero_mask is not None:
            graph_enc.masked_fill_(zero_mask, 0.0)
        pattern_output = torch.cat([pattern_enc, pattern_output], dim=2)
        graph_output = torch.cat([graph_enc, graph_output], dim=2)

    return self.predict_net(pattern_output, pattern_len, graph_output, graph_len)