Example #1

import numpy as np
import torch
import torch.nn as nn

# MeanAggregator comes from the project's own aggregator module; the exact
# import path depends on the repository layout.
class DArtNet(nn.Module):
    def __init__(self,
                 num_nodes,
                 h_dim,
                 num_rels,
                 dropout=0,
                 model=0,
                 seq_len=10,
                 num_k=10,
                 gamma=1):
        super(DArtNet, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.num_rels = num_rels
        self.model = model
        self.seq_len = seq_len
        self.num_k = num_k
        self.gamma = gamma
        self.rel_embeds = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.rel_embeds,
                                gain=nn.init.calculate_gain('relu'))

        self.ent_embeds = nn.Parameter(torch.Tensor(num_nodes, h_dim))
        nn.init.xavier_uniform_(self.ent_embeds,
                                gain=nn.init.calculate_gain('relu'))
        self.ent_embeds_attribute = nn.Parameter(torch.Tensor(
            num_nodes, h_dim))
        nn.init.xavier_uniform_(self.ent_embeds_attribute,
                                gain=nn.init.calculate_gain('relu'))

        self.dropout = nn.Dropout(dropout)
        # GRU encoders over the aggregated 3*h_dim history features: one for
        # object prediction, one for attribute prediction.
        self.sub_encoder = nn.GRU(3 * h_dim, h_dim, batch_first=True)
        # self.ob_encoder = self.sub_encoder

        self.att_encoder = nn.GRU(3 * h_dim, h_dim, batch_first=True)

        self.aggregator_s = MeanAggregator(h_dim, dropout, seq_len)
        # self.aggregator_o = self.aggregator_s

        # f1 predicts a scalar attribute from [entity attribute embedding ; attribute history state].
        self.f1 = nn.Linear(2 * self.h_dim, 1)
        # f2 scores every candidate object from [subject embedding ; history state ; relation embedding].
        self.f2 = nn.Linear(3 * h_dim, num_nodes)

        # Projections used inside the aggregator: W1 lifts a scalar attribute
        # to h_dim; W3 and W4 compress concatenated features back to h_dim.
        self.W1 = nn.Linear(1, self.h_dim)
        # self.W2 = nn.Linear(2 * self.h_dim, self.h_dim)
        self.W3 = nn.Linear(3 * self.h_dim, self.h_dim)
        self.W4 = nn.Linear(2 * self.h_dim, self.h_dim)

        # For recording history in inference

        self.entity_s_his_test = None
        self.att_s_his_test = None
        self.rel_s_his_test = None
        self.self_att_s_his_test = None

        # self.entity_o_his_test = None
        # self.att_o_his_test = None
        # self.rel_o_his_test = None
        # self.self_att_o_his_test = None

        self.entity_s_his_cache = None
        self.att_s_his_cache = None
        self.rel_s_his_cache = None
        self.self_att_s_his_cache = None

        # self.entity_o_his_cache = None
        # self.att_o_his_cache = None
        # self.rel_o_his_cache = None
        # self.self_att_o_his_cache = None

        self.att_s_dict = {}
        # self.att_o_dict = {}

        self.latest_time = 0

        self.criterion = nn.CrossEntropyLoss()
        self.att_criterion = nn.MSELoss()

    """
    Prediction function in training. 
    This should be different from testing because in testing we don't use ground-truth history.
    """

    def forward(self,
                triplets,
                s_hist,
                rel_s_hist,
                att_s_hist,
                self_att_s_hist,
                o_hist,
                rel_o_hist,
                att_o_hist,
                self_att_o_hist,
                predict_both=True):
        s = triplets[:, 0].type(torch.cuda.LongTensor)
        r = triplets[:, 1].type(torch.cuda.LongTensor)
        o = triplets[:, 2].type(torch.cuda.LongTensor)
        a_s = triplets[:, 3].type(torch.cuda.FloatTensor)
        a_o = triplets[:, 4].type(torch.cuda.FloatTensor)

        batch_size = len(s)

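        # Sort entities by history length (descending), as packed GRU inputs
        # require; s_idx is the permutation that reorders the batch.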
        s_hist_len = torch.LongTensor(list(map(len, s_hist))).cuda()
        s_len, s_idx = s_hist_len.sort(0, descending=True)
        # o_hist_len = torch.LongTensor(list(map(len, o_hist))).cuda()
        # o_len, o_idx = o_hist_len.sort(0, descending=True)
        s_packed_input, att_s_packed_input = self.aggregator_s(
            s_hist, rel_s_hist, att_s_hist, self_att_s_hist, s, r,
            self.ent_embeds, self.ent_embeds_attribute, self.rel_embeds,
            self.W1, self.W3, self.W4)

        # o_packed_input, att_o_packed_input = self.aggregator_o(
        #     o_hist, rel_o_hist, att_o_hist, self_att_o_hist, o, r,
        #     self.ent_embeds, self.rel_embeds[self.num_rels:], self.W1, self.W2,
        #     self.W3, self.W4)

        if predict_both:
            _, s_h = self.sub_encoder(s_packed_input)
            s_h = s_h.squeeze()
            s_h = torch.cat(
                (s_h, torch.zeros(len(s) - len(s_h), self.h_dim).cuda()),
                dim=0)
            ob_pred = self.f2(
                self.dropout(
                    torch.cat((self.ent_embeds[s[s_idx]], s_h,
                               self.rel_embeds[r[s_idx]]),
                              dim=1)))
            loss_sub = self.criterion(ob_pred, o[s_idx])
        else:
            ob_pred = None
            loss_sub = 0

        # _, o_h = self.ob_encoder(o_packed_input)

        _, att_s_h = self.att_encoder(att_s_packed_input)
        # _, att_o_h = self.att_encoder(att_o_packed_input)

        # o_h = o_h.squeeze()
        att_s_h = att_s_h.squeeze()
        # att_o_h = att_o_h.squeeze()

        # o_h = torch.cat(
        #     (o_h, torch.zeros(len(o) - len(o_h), self.h_dim).cuda()), dim=0)
        att_s_h = torch.cat(
            (att_s_h, torch.zeros(len(s) - len(att_s_h), self.h_dim).cuda()),
            dim=0)
        # att_o_h = torch.cat(
        #     (att_o_h, torch.zeros(len(o) - len(att_o_h), self.h_dim).cuda()),
        #     dim=0)

        sub_att_pred = self.f1(
            self.dropout(
                torch.cat((self.ent_embeds_attribute[s[s_idx]], att_s_h),
                          dim=1))).squeeze()

        # sub_pred = self.f2(
        #     self.dropout(
        #         torch.cat((self.ent_embeds[o[o_idx]], o_h,
        #                    self.rel_embeds[self.num_rels:][r[o_idx]]),
        #                   dim=1)))

        # ob_att_pred = self.f1(
        #     self.dropout(torch.cat((self.ent_embeds[o[o_idx]], att_o_h),
        #                            dim=1))).squeeze()

        # loss_ob = self.criterion(sub_pred, s[o_idx])

        loss_att_sub = self.att_criterion(sub_att_pred, a_s[s_idx])
        # loss_att_ob = self.att_criterion(ob_att_pred, a_o[o_idx])

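        # Joint objective: entity cross-entropy plus the gamma-weighted attribute MSE.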
        loss = loss_sub + self.gamma * loss_att_sub

        return loss, loss_att_sub, ob_pred, sub_att_pred, s_idx

    def init_history(self):
        self.entity_s_his_test = [[] for _ in range(self.num_nodes)]
        self.att_s_his_test = [[] for _ in range(self.num_nodes)]
        self.rel_s_his_test = [[] for _ in range(self.num_nodes)]
        self.self_att_s_his_test = [[] for _ in range(self.num_nodes)]

        # self.entity_o_his_test = [[] for _ in range(self.num_nodes)]
        # self.att_o_his_test = [[] for _ in range(self.num_nodes)]
        # self.rel_o_his_test = [[] for _ in range(self.num_nodes)]
        # self.self_att_o_his_test = [[] for _ in range(self.num_nodes)]

        self.entity_s_his_cache = [[] for _ in range(self.num_nodes)]
        self.att_s_his_cache = [[] for _ in range(self.num_nodes)]
        self.rel_s_his_cache = [[] for _ in range(self.num_nodes)]
        self.self_att_s_his_cache = [[] for _ in range(self.num_nodes)]

        # self.entity_o_his_cache = [[] for _ in range(self.num_nodes)]
        # self.att_o_his_cache = [[] for _ in range(self.num_nodes)]
        # self.rel_o_his_cache = [[] for _ in range(self.num_nodes)]
        # self.self_att_o_his_cache = [[] for _ in range(self.num_nodes)]

    def get_loss(self, triplets, s_hist, rel_s_hist, att_s_hist,
                 self_att_s_hist, o_hist, rel_o_hist, att_o_hist,
                 self_att_o_hist):
        loss, loss_att_sub, _, _, _ = self.forward(triplets, s_hist,
                                                   rel_s_hist, att_s_hist,
                                                   self_att_s_hist, o_hist,
                                                   rel_o_hist, att_o_hist,
                                                   self_att_o_hist)
        return loss, loss_att_sub

    """
    Prediction function in testing
    """

    def predict(self, triplets, s_hist, rel_s_hist, att_s_hist,
                self_att_s_hist, o_hist, rel_o_hist, att_o_hist,
                self_att_o_hist):

        self.att_s_dict = {}
        # self.att_o_dict = {}
        self.att_residual_dict = {}

        _, loss_att_sub, _, sub_att_pred, s_idx = self.forward(
            triplets, s_hist, rel_s_hist, att_s_hist, self_att_s_hist, o_hist,
            rel_o_hist, att_o_hist, self_att_o_hist, False)
        indices = {}
        for i in range(len(triplets)):
            s = triplets[s_idx[i], 0].type(torch.LongTensor).item()
            o = triplets[s_idx[i], 2].type(torch.LongTensor).item()
            t = triplets[s_idx[i], 5].type(torch.LongTensor).item()
            s_att = sub_att_pred[i].cpu().item()

            if s not in self.att_s_dict:
                self.att_s_dict[s] = s_att
                indices[s] = i
            else:
                assert (self.att_s_dict[s] == s_att)

            # s = triplets[o_idx[i], 0].type(torch.LongTensor).item()
            # o = triplets[o_idx[i], 2].type(torch.LongTensor).item()
            # t = triplets[o_idx[i], 5].type(torch.LongTensor).item()
            # o_att = ob_att_pred[i].cpu().item()

            # if o not in self.att_o_dict:
            #     self.att_o_dict[o] = o_att
            # else:
            #     assert (self.att_o_dict[o] == o_att)

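        # Entities with no attribute prediction at this step get a residual
        # estimate from a zero history state.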
        for i in range(self.num_nodes):
            if i not in self.att_s_dict:  # and i not in self.att_o_dict:
                s_h = torch.zeros(1, self.h_dim).cuda()
                sub_att_pred = self.f1(
                    torch.cat((self.ent_embeds_attribute[[i]], s_h),
                              dim=1)).squeeze()
                self.att_residual_dict[i] = sub_att_pred

        return loss_att_sub

    def predict_single(self, triplet, s_hist, rel_s_hist, att_s_hist,
                       self_att_s_hist, o_hist, rel_o_hist, att_o_hist,
                       self_att_o_hist):
        s = triplet[0].type(torch.cuda.LongTensor)
        r = triplet[1].type(torch.cuda.LongTensor)
        o = triplet[2].type(torch.cuda.LongTensor)
        a_s = triplet[3].type(torch.cuda.FloatTensor)
        a_o = triplet[4].type(torch.cuda.FloatTensor)
        t = triplet[5].cpu()
        # When the timestamp advances, flush the per-entity caches built at the
        # previous step into the test histories, evicting the oldest entries.
        if self.latest_time != t:

            for ee in range(self.num_nodes):
                if len(self.entity_s_his_cache[ee]) != 0:
                    if len(self.entity_s_his_test[ee]) >= self.seq_len:
                        self.entity_s_his_test[ee].pop(0)
                        self.att_s_his_test[ee].pop(0)
                        self.self_att_s_his_test[ee].pop(0)
                        self.rel_s_his_test[ee].pop(0)

                    self.entity_s_his_test[ee].append(
                        self.entity_s_his_cache[ee].clone())
                    self.att_s_his_test[ee].append(
                        self.att_s_his_cache[ee].clone())
                    self.self_att_s_his_test[ee].append(
                        self.self_att_s_his_cache[ee])
                    self.rel_s_his_test[ee].append(
                        self.rel_s_his_cache[ee].clone())

                    self.entity_s_his_cache[ee] = []
                    self.att_s_his_cache[ee] = []
                    self.self_att_s_his_cache[ee] = []
                    self.rel_s_his_cache[ee] = []

                # if len(self.entity_o_his_cache[ee]) != 0:
                #     if len(self.entity_o_his_test[ee]) >= self.seq_len:
                #         self.entity_o_his_test[ee].pop(0)
                #         self.att_o_his_test[ee].pop(0)
                #         self.self_att_o_his_test[ee].pop(0)
                #         self.rel_o_his_test[ee].pop(0)

                #     self.entity_o_his_test[ee].append(
                #         self.entity_o_his_cache[ee].clone())
                #     self.att_o_his_test[ee].append(
                #         self.att_o_his_cache[ee].clone())
                #     self.self_att_o_his_test[ee].append(
                #         self.self_att_o_his_cache[ee])
                #     self.rel_o_his_test[ee].append(
                #         self.rel_o_his_cache[ee].clone())

                #     self.entity_o_his_cache[ee] = []
                #     self.att_o_his_cache[ee] = []
                #     self.self_att_o_his_cache[ee] = []
                #     self.rel_o_his_cache[ee] = []

            self.latest_time = t

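        # With no history for this subject, fall back to a zero hidden state.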
        if len(s_hist) == 0:
            s_h = torch.zeros(self.h_dim).cuda()

        else:
            if len(self.entity_s_his_test[s]) == 0:
                self.entity_s_his_test[s] = s_hist.copy()
                self.rel_s_his_test[s] = rel_s_hist.copy()
                self.att_s_his_test[s] = att_s_hist.copy()
                self.self_att_s_his_test[s] = self_att_s_hist

            s_history = self.entity_s_his_test[s]
            rel_s_history = self.rel_s_his_test[s]
            att_s_history = self.att_s_his_test[s]
            self_att_s_history = self.self_att_s_his_test[s]

            inp_s, _ = self.aggregator_s.predict(
                s_history, rel_s_history, att_s_history, self_att_s_history, s,
                r, self.ent_embeds, self.ent_embeds_attribute, self.rel_embeds,
                self.W1, self.W3, self.W4)

            _, s_h = self.sub_encoder(
                inp_s.view(1, len(s_history), 3 * self.h_dim))
            s_h = s_h.squeeze()

        # if len(o_hist) == 0:
        #     o_h = torch.zeros(self.h_dim).cuda()
        # else:
        #     if len(self.entity_o_his_test[o]) == 0:
        #         self.entity_o_his_test[o] = o_hist.copy()
        #         self.rel_o_his_test[o] = rel_o_hist.copy()
        #         self.att_o_his_test[o] = att_o_hist.copy()
        #         self.self_att_o_his_test[o] = self_att_o_hist

        #     o_history = self.entity_o_his_test[o]
        #     rel_o_history = self.rel_o_his_test[o]
        #     att_o_history = self.att_o_his_test[o]
        #     self_att_o_history = self.self_att_o_his_test[o]

        #     inp_o, _ = self.aggregator_o.predict(
        #         o_history, rel_o_history, att_o_history, self_att_o_history, o,
        #         r, self.ent_embeds, self.rel_embeds[self.num_rels:], self.W1,
        #         self.W2, self.W3, self.W4)

        #     _, o_h = self.ob_encoder(
        #         inp_o.view(1, len(o_history), 3 * self.h_dim))
        #     o_h = o_h.squeeze()

        ob_pred = self.f2(
            torch.cat((self.ent_embeds[s], s_h, self.rel_embeds[r]), dim=0))
        # sub_pred = self.f2(
        #     torch.cat(
        #         (self.ent_embeds[o], o_h, self.rel_embeds[self.num_rels:][r]),
        #         dim=0))

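        # Keep the top-k scored objects as pseudo-events for growing the history cache.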
        _, o_candidate = torch.topk(ob_pred, self.num_k)
        # _, s_candidate = torch.topk(sub_pred, self.num_k)

        (self.entity_s_his_cache[s], self.rel_s_his_cache[s],
         self.att_s_his_cache[s],
         self.self_att_s_his_cache[s]) = self.update_cache(
             self.entity_s_his_cache[s], self.rel_s_his_cache[s],
             self.att_s_his_cache[s], self.self_att_s_his_cache[s],
             s.cpu(), r.cpu(), o_candidate.cpu())
        # self.entity_o_his_cache[o], self.rel_o_his_cache[
        #     o], self.att_o_his_cache[o], self.self_att_o_his_cache[
        #         o] = self.update_cache(self.entity_o_his_cache[o],
        #                                self.rel_o_his_cache[o],
        #                                self.att_o_his_cache[o],
        #                                self.self_att_o_his_cache[o], o.cpu(),
        #                                r.cpu(), s_candidate.cpu())

        # loss_sub = self.criterion(ob_pred.view(1, -1), o.view(-1))
        # loss_ob = self.criterion(sub_pred.view(1, -1), s.view(-1))

        # loss = loss_sub + loss_ob

        return ob_pred

    def evaluate_filter(self, triplet, s_hist, rel_s_hist, att_s_hist,
                        self_att_s_hist, o_hist, rel_o_hist, att_o_hist,
                        self_att_o_hist, all_triplets):
        s = triplet[0].type(torch.cuda.LongTensor)
        r = triplet[1].type(torch.cuda.LongTensor)
        o = triplet[2].type(torch.cuda.LongTensor)
        ob_pred = self.predict_single(triplet, s_hist, rel_s_hist, att_s_hist,
                                      self_att_s_hist, o_hist, rel_o_hist,
                                      att_o_hist, self_att_o_hist)
        o_label = o
        s_label = s
        # sub_pred = torch.sigmoid(sub_pred)
        ob_pred = torch.sigmoid(ob_pred)

        ground = ob_pred[o].clone()

        s_id = torch.nonzero(
            all_triplets[:, 0].type(torch.cuda.LongTensor) == s).view(-1)
        idx = torch.nonzero(
            all_triplets[s_id, 1].type(torch.cuda.LongTensor) == r).view(-1)
        idx = s_id[idx]
        idx = all_triplets[idx, 2].type(torch.cuda.LongTensor)
        ob_pred[idx] = 0
        ob_pred[o_label] = ground

        ob_pred_comp1 = (ob_pred > ground).data.cpu().numpy()
        ob_pred_comp2 = (ob_pred == ground).data.cpu().numpy()
        rank_ob = np.sum(ob_pred_comp1) + (
            (np.sum(ob_pred_comp2) - 1.0) / 2) + 1

        # ground = sub_pred[s].clone()

        # o_id = torch.nonzero(
        #     all_triplets[:, 2].type(torch.cuda.LongTensor) == o).view(-1)
        # idx = torch.nonzero(
        #     all_triplets[o_id, 1].type(torch.cuda.LongTensor) == r).view(-1)
        # idx = o_id[idx]
        # idx = all_triplets[idx, 0].type(torch.cuda.LongTensor)
        # sub_pred[idx] = 0
        # sub_pred[s_label] = ground

        # sub_pred_comp1 = (sub_pred > ground).data.cpu().numpy()
        # sub_pred_comp2 = (sub_pred == ground).data.cpu().numpy()
        # rank_sub = np.sum(sub_pred_comp1) + (
        #     (np.sum(sub_pred_comp2) - 1.0) / 2) + 1
        return np.array([rank_ob])

    def update_cache(self, s_his_cache, r_his_cache, att_his_cache,
                     self_att_his_cache, s, r, o_candidate):
        if len(s_his_cache) == 0:
            s_his_cache = o_candidate.view(-1)
            r_his_cache = r.repeat(len(o_candidate), 1).view(-1)
            att_his_cache = []
            for key in s_his_cache:
                k = key.item()
                if k in self.att_s_dict:
                    att_his_cache.append(self.att_s_dict[k])
                # elif k in self.att_o_dict:
                #     att_his_cache.append(self.att_o_dict[k])
                else:
                    att_his_cache.append(self.att_residual_dict[k])

            if s.item() in self.att_s_dict:
                self_att_his_cache = self.att_s_dict[s.item()]
            # elif s.item() in self.att_o_dict:
            #     self_att_his_cache = self.att_o_dict[s.item()]
            else:
                self_att_his_cache = self.att_residual_dict[s.item()]

            if not isinstance(att_his_cache, torch.Tensor):
                att_his_cache = torch.FloatTensor(att_his_cache)
        else:
            ent_list = s_his_cache[torch.nonzero(r_his_cache == r).view(-1)]
            tem = []
            for i in range(len(o_candidate)):
                if o_candidate[i] not in ent_list:
                    tem.append(i)

            if len(tem) != 0:
                forward = o_candidate[torch.LongTensor(tem)].view(-1)
                forward2 = r.repeat(len(tem), 1).view(-1)

                s_his_cache = torch.cat(
                    (torch.LongTensor(s_his_cache), forward), dim=0)
                r_his_cache = torch.cat(
                    (torch.LongTensor(r_his_cache), forward2), dim=0)
                att_his_cache = torch.cat((torch.FloatTensor(att_his_cache),
                                           forward2.type(torch.FloatTensor)),
                                          dim=0)
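                # NOTE: the relation ids (forward2) are appended as float
                # placeholders for the new attribute entries; the loop below
                # only refreshes entries whose entity already appears in
                # ent_list.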
                # self_att_his_cache = torch.cat((self_att_his_cache, forward2),
                #                                dim=0)
                for i in range(len(s_his_cache)):
                    if s_his_cache[i] in ent_list:
                        if s_his_cache[i].item() in self.att_s_dict:
                            att_his_cache[i] = self.att_s_dict[
                                s_his_cache[i].item()]
                        # elif s_his_cache[i].item() in self.att_o_dict:
                        #     att_his_cache[i] = self.att_o_dict[
                        #         s_his_cache[i].item()]
                        else:
                            att_his_cache[i] = self.att_residual_dict[
                                s_his_cache[i].item()]

                if s.item() in self.att_s_dict:
                    self_att_his_cache = self.att_s_dict[s.item()]
                # elif s.item() in self.att_o_dict:
                #     self_att_his_cache = self.att_o_dict[s.item()]
                else:
                    self_att_his_cache = self.att_residual_dict[s.item()]

        return s_his_cache, r_his_cache, att_his_cache, self_att_his_cache
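
Below is a minimal smoke-test sketch for DArtNet. It assumes a CUDA device is
available (the model hard-codes .cuda() calls) and that MeanAggregator is
importable from the project; all dimensions are illustrative, not taken from
the repository's configs.

if __name__ == '__main__':
    # Hypothetical sizes for a quick instantiation check.
    model = DArtNet(num_nodes=100, h_dim=32, num_rels=20,
                    dropout=0.2, seq_len=10, num_k=10, gamma=1).cuda()
    model.init_history()  # reset per-entity test histories and caches
    print(model)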
Example #2

# AttnAggregator, MeanAggregator, and RGCNAggregator below come from the
# project's own aggregator modules.
class RENet(nn.Module):
    def __init__(self,
                 in_dim,
                 h_dim,
                 num_rels,
                 dropout=0,
                 model=0,
                 seq_len=10,
                 num_k=10):
        super(RENet, self).__init__()
        self.in_dim = in_dim
        self.num_nodes = in_dim
        self.h_dim = h_dim
        self.num_rels = num_rels
        self.model = model
        self.seq_len = seq_len
        self.num_k = num_k
        self.rel_embeds = nn.Parameter(torch.Tensor(2 * num_rels, h_dim))
        nn.init.xavier_uniform_(
            self.rel_embeds, gain=nn.init.calculate_gain('relu'))

        self.ent_embeds = nn.Parameter(torch.Tensor(in_dim + 1, h_dim))
        nn.init.xavier_uniform_(
            self.ent_embeds, gain=nn.init.calculate_gain('relu'))

        self.dropout = nn.Dropout(dropout)
        self.sub_encoder = nn.GRU(3 * h_dim, h_dim, batch_first=True)
        self.ob_encoder = self.sub_encoder

        if model == 0:  # Attentive Aggregator
            self.aggregator_s = AttnAggregator(h_dim, dropout, seq_len)
        elif model == 1:  # Mean Aggregator
            self.aggregator_s = MeanAggregator(
                h_dim, dropout, seq_len, gcn=False)
        elif model == 2:  # Pooling Aggregator
            self.aggregator_s = MeanAggregator(
                h_dim, dropout, seq_len, gcn=True)
        elif model == 3:  # RGCN Aggregator
            self.aggregator_s = RGCNAggregator(h_dim, dropout, in_dim,
                                               num_rels, 100, model, seq_len)
        self.aggregator_o = self.aggregator_s

        self.linear_sub = nn.Linear(3 * h_dim, in_dim)
        self.linear_ob = self.linear_sub

        # For recording history in inference
        if model == 3:
            self.s_hist_test = None
            self.o_hist_test = None
            self.s_hist_test_t = None
            self.o_hist_test_t = None
            self.s_his_cache = None
            self.o_his_cache = None
            self.s_his_cache_t = None
            self.o_his_cache_t = None
            self.graph_dict = None
            self.data = None

        else:
            self.s_hist_test = None
            self.o_hist_test = None
            self.s_his_cache = None
            self.o_his_cache = None
        self.latest_time = 0

        self.criterion = nn.CrossEntropyLoss()
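
A similar hedged instantiation sketch for RENet; the sizes are illustrative,
and model=1 selects the mean aggregator per the branch in the constructor.

renet = RENet(in_dim=100, h_dim=32, num_rels=20,
              dropout=0.2, model=1, seq_len=10, num_k=10)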