Code Example #1
File: link_pred_tasker.py Project: NCTU-MLLab/temp
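Builds one link-prediction sample: a window of num_hist_steps + 1 normalized sparse adjacencies with per-step node features and node masks, labeled by the edges at idx + 1 plus sampled non-existing edges (negatives drawn only from nodes seen in the window when smart_neg_sampling is on).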
	def get_sample(self,idx,test, **kwargs):
		hist_adj_list = []
		hist_ndFeats_list = []
		hist_mask_list = []
		existing_nodes = []
		for i in range(idx - self.args.num_hist_steps, idx+1):
			cur_adj = tu.get_sp_adj(edges = self.data.edges, 
								   time = i,
								   weighted = True,
								   time_window = self.args.adj_mat_time_window)

			if self.args.smart_neg_sampling:
				existing_nodes.append(cur_adj['idx'].unique())
			else:
				existing_nodes = None

			node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

			node_feats = self.get_node_feats(cur_adj)

			cur_adj = tu.normalize_adj(adj = cur_adj, num_nodes = self.data.num_nodes)

			hist_adj_list.append(cur_adj)
			hist_ndFeats_list.append(node_feats)
			hist_mask_list.append(node_mask)

		# This would be if we were training on all the edges in the time_window
		label_adj = tu.get_sp_adj(edges = self.data.edges, 
								  time = idx+1,
								  weighted = False,
								  time_window =  self.args.adj_mat_time_window)
		if test:
			neg_mult = self.args.negative_mult_test
		else:
			neg_mult = self.args.negative_mult_training
			
		if self.args.smart_neg_sampling:
			existing_nodes = torch.cat(existing_nodes)

		
		if kwargs.get('all_edges', False):
			non_exisiting_adj = tu.get_all_non_existing_edges(adj = label_adj, tot_nodes = self.data.num_nodes)
		else:
			non_exisiting_adj = tu.get_non_existing_edges(adj = label_adj,
													  number = label_adj['vals'].size(0) * neg_mult,
													  tot_nodes = self.data.num_nodes,
													  smart_sampling = self.args.smart_neg_sampling,
													  existing_nodes = existing_nodes)

		# label_adj = tu.get_sp_adj_only_new(edges = self.data.edges,
		# 								   weighted = False,
		# 								   time = idx)
		
		label_adj['idx'] = torch.cat([label_adj['idx'],non_exisiting_adj['idx']])
		label_adj['vals'] = torch.cat([label_adj['vals'],non_exisiting_adj['vals']])
		return {'idx': idx,
				'hist_adj_list': hist_adj_list,
				'hist_ndFeats_list': hist_ndFeats_list,
				'label_sp': label_adj,
				'node_mask_list': hist_mask_list}
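A minimal calling sketch, assuming a constructed tasker over this dataset; the driver loop, its bounds, and the variable names here are illustrative, not part of the project:

    # hypothetical driver: one sample per prediction step
    for idx in range(args.num_hist_steps, data.max_time - 1):
        sample = tasker.get_sample(idx, test=False)
        # sample['hist_adj_list']    : num_hist_steps + 1 normalized sparse adjacencies
        # sample['hist_ndFeats_list']: per-step node features
        # sample['label_sp']         : edges at idx + 1 plus sampled non-edges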
Code Example #2
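A dense-history variant: node features fall back to a constant when the dataset provides none, the windowed adjacencies are merged into a binary concate_adj, and the stacked history is exposed as a (num_nodes, num_nodes, num_steps) edge_feature tensor, with node labels from get_node_labels.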
    def get_sample(self, idx, test):
        hist_adj_list = []
        # fall back to a constant scalar feature when the dataset has none
        if self.data.node_feature:
            node_feature = self.data.node_feature
        else:
            node_feature = 1
        for i in range(idx - self.args.num_hist_steps, idx + 1):
            # adjacency restricted to the sliding time window
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)

            cur_adj = tu.normalize_adj(adj=cur_adj,
                                       num_nodes=self.data.num_nodes)
            # densify so the history can be stacked and aggregated below
            cur_adj = torch.sparse.FloatTensor(
                cur_adj['idx'].T, cur_adj['vals'],
                torch.Size([self.data.num_nodes,
                            self.data.num_nodes])).to_dense()

            hist_adj_list.append(cur_adj)

        label_adj = self.get_node_labels(idx)
        # binary union of the adjacencies across the window
        concate_adj = torch.stack(hist_adj_list).sum(dim=0)
        concate_adj[concate_adj > 0] = 1
        # (num_nodes, num_nodes, num_steps) edge-feature tensor
        edge_feature = torch.stack(hist_adj_list).permute(1, 2, 0)
        return {
            'idx': idx,
            'concate_adj': concate_adj,
            'edge_feature': edge_feature,
            'label_sp': label_adj,
            'node_feature': node_feature
        }
Code Example #3
File: node_cls_tasker.py Project: NCTU-MLLab/temp
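Node-classification tasker: the same windowed history of normalized adjacencies, per-step node features (indexed by time), and node masks, with the node labels at idx as the target.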
    def get_sample(self, idx, test):
        hist_adj_list = []
        hist_ndFeats_list = []
        hist_mask_list = []

        for i in range(idx - self.args.num_hist_steps, idx + 1):
            # adjacency restricted to the sliding time window
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)

            node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

            node_feats = self.get_node_feats(i, cur_adj)

            cur_adj = tu.normalize_adj(adj=cur_adj,
                                       num_nodes=self.data.num_nodes)

            hist_adj_list.append(cur_adj)
            hist_ndFeats_list.append(node_feats)
            hist_mask_list.append(node_mask)

        label_adj = self.get_node_labels(idx)

        return {
            'idx': idx,
            'hist_adj_list': hist_adj_list,
            'hist_ndFeats_list': hist_ndFeats_list,
            'label_sp': label_adj,
            'node_mask_list': hist_mask_list
        }
Code Example #4
File: edge_cls_tasker.py Project: NCTU-MLLab/temp
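Edge-classification tasker: identical history construction, but the labels come from tu.get_edge_labels at time idx rather than from negative sampling.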
    def get_sample(self, idx, test):
        hist_adj_list = []
        hist_ndFeats_list = []
        hist_mask_list = []

        for i in range(idx - self.args.num_hist_steps, idx + 1):
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)
            node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
            node_feats = self.get_node_feats(cur_adj)
            cur_adj = tu.normalize_adj(adj=cur_adj,
                                       num_nodes=self.data.num_nodes)

            hist_adj_list.append(cur_adj)
            hist_ndFeats_list.append(node_feats)
            hist_mask_list.append(node_mask)

        label_adj = tu.get_edge_labels(edges=self.data.edges, time=idx)

        return {
            'idx': idx,
            'hist_adj_list': hist_adj_list,
            'hist_ndFeats_list': hist_ndFeats_list,
            'label_sp': label_adj,
            'node_mask_list': hist_mask_list
        }
Code Example #5
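A slimmer node-task variant: returns the raw windowed sparse adjacencies (no normalization) as edge_feature and passes the dataset's static node features through unchanged.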
    def get_sample(self, idx, test):
        hist_adj_list = []
        node_feature = torch.FloatTensor(self.data.nodes_feats)
        for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
            # adjacency restricted to the sliding time window
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)

            hist_adj_list.append(cur_adj)

        label_adj = self.get_node_labels(idx)

        return {'idx': idx,
                'edge_feature': hist_adj_list,
                'label_sp': label_adj,
                'node_feature': node_feature}
Code Example #6
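Link prediction with an optional spectral path: when args.fft is set, the full history up to idx is densified, stacked time-major, transformed with an orthonormal DCT (dctn) along the time axis, and truncated to the first num_hist_steps coefficients; otherwise the last num_hist_steps dense adjacencies are stacked directly. Labels are built as in Example #1.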
    def get_sample(self, idx, test, **kwargs):
        hist_adj_list = []
        existing_nodes = []
        if self.args.fft:
            for i in range(idx + 1):
                cur_adj = tu.get_sp_adj(
                    edges=self.data.edges,
                    time=i,
                    weighted=True,
                    time_window=self.args.adj_mat_time_window)
                if self.args.smart_neg_sampling:
                    existing_nodes.append(cur_adj['idx'].unique())
                else:
                    existing_nodes = None

                # node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                cur_adj = tu.normalize_adj(adj=cur_adj,
                                           num_nodes=self.data.num_nodes)
                cur_adj = torch.sparse.FloatTensor(
                    cur_adj['idx'].T, cur_adj['vals']).to_dense().numpy()
                hist_adj_list.append(cur_adj)
            hist_adj_list = np.concatenate(hist_adj_list).reshape(
                (-1, cur_adj.shape[0], cur_adj.shape[1]))
            #print(1, hist_adj_list.shape)
            f_adj = dctn(hist_adj_list, axes=0, norm='ortho')
            edge_feature = torch.from_numpy(
                f_adj[:self.args.num_hist_steps, :, :])
            #print(2, edge_feature.size())

        else:
            for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                cur_adj = tu.get_sp_adj(
                    edges=self.data.edges,
                    time=i,
                    weighted=True,
                    time_window=self.args.adj_mat_time_window)
                if self.args.smart_neg_sampling:
                    existing_nodes.append(cur_adj['idx'].unique())
                else:
                    existing_nodes = None

                #node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                cur_adj = tu.normalize_adj(adj=cur_adj,
                                           num_nodes=self.data.num_nodes)
                cur_adj = torch.sparse.FloatTensor(cur_adj['idx'].T,
                                                   cur_adj['vals']).to_dense()
                hist_adj_list.append(cur_adj)

            edge_feature = torch.cat(hist_adj_list).view(
                -1, cur_adj.size(0), cur_adj.size(1))
        concate_adj = torch.sum(edge_feature, dim=0)
        edge_feature = edge_feature.permute(1, 2, 0)
        concate_adj[concate_adj > 0] = 1

        # This would be if we were training on all the edges in the time_window
        label_adj = tu.get_sp_adj(edges=self.data.edges,
                                  time=idx + 1,
                                  weighted=False,
                                  time_window=self.args.adj_mat_time_window)
        if test:
            neg_mult = self.args.negative_mult_test
        else:
            neg_mult = self.args.negative_mult_training

        if self.args.smart_neg_sampling:
            existing_nodes = torch.cat(existing_nodes)

        if kwargs.get('all_edges', False):
            non_exisiting_adj = tu.get_all_non_existing_edges(
                adj=label_adj, tot_nodes=self.data.num_nodes)
        else:
            non_exisiting_adj = tu.get_non_existing_edges(
                adj=label_adj,
                number=label_adj['vals'].size(0) * neg_mult,
                tot_nodes=self.data.num_nodes,
                smart_sampling=self.args.smart_neg_sampling,
                existing_nodes=existing_nodes)

        label_adj['idx'] = torch.cat(
            [label_adj['idx'], non_exisiting_adj['idx']])
        label_adj['vals'] = torch.cat(
            [label_adj['vals'], non_exisiting_adj['vals']])
        return {
            'idx': idx,
            'concate_adj': concate_adj,
            'edge_feature': edge_feature,
            'label_sp': label_adj,
            'node_feature': 1
        }
Code Example #7
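The most configurable variant. For the egnnc model, edge features come from a DTFT (u.DTFT), a DCT (scipy's dct), or the plain stacked history; for other models, from u.DTFTSp or a list of sparse adjacencies. Negative sampling and label construction follow Example #1.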
    def get_sample(self, idx, test, **kwargs):
        hist_adj_list = []
        existing_nodes = []
        if self.args.model == 'egnnc':
            if self.args.fft:
                for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                    cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                            time=i,
                                            weighted=True,
                                            time_window=self.args.adj_mat_time_window)
                    if self.args.smart_neg_sampling:
                        existing_nodes.append(cur_adj['idx'].unique())
                    else:
                        existing_nodes = None

                    # node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
                    #cur_adj = tu.normalize_adj(adj=cur_adj, num_nodes=self.data.num_nodes)

                    cur_adj = torch.sparse.FloatTensor(
                        cur_adj['idx'].t(),
                        cur_adj['vals'].type(torch.float),
                        torch.Size([self.data.num_nodes,
                                    self.data.num_nodes])).to_dense()
                    hist_adj_list.append(cur_adj)

                # stack time-major: (num_steps, num_nodes, num_nodes)
                hist_adj_list = torch.cat(hist_adj_list).view(
                    -1, self.data.num_nodes, self.data.num_nodes).numpy()
                #print(1, hist_adj_list.shape)
                edge_feature = u.DTFT(hist_adj_list, self.args.fft_num_steps)
                #print(2, edge_feature.size())
            elif self.args.dft:
                for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                    cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                            time=i,
                                            weighted=True,
                                            time_window=self.args.adj_mat_time_window)
                    if self.args.smart_neg_sampling:
                        existing_nodes.append(cur_adj['idx'].unique())
                    else:
                        existing_nodes = None

                    # node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                    cur_adj = tu.normalize_adj(adj=cur_adj, num_nodes=self.data.num_nodes)
                    cur_adj = torch.sparse.FloatTensor(cur_adj['idx'].T, cur_adj['vals']).to_dense()
                    hist_adj_list.append(cur_adj)
                    
                hist_adj_list = torch.cat(hist_adj_list).view(-1,self.data.num_nodes,self.data.num_nodes).numpy()
                #print(1, hist_adj_list.shape)
                edge_feature = torch.from_numpy(dct(hist_adj_list, n=self.args.fft_num_steps, axis=0, norm='ortho'))
                #print(2, edge_feature.size())
            else:
                for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                    cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                            time=i,
                                            weighted=True,
                                            time_window=self.args.adj_mat_time_window)
                    if self.args.smart_neg_sampling:
                        existing_nodes.append(cur_adj['idx'].unique())
                    else:
                        existing_nodes = None

                    #node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                    cur_adj = tu.normalize_adj(adj=cur_adj, num_nodes=self.data.num_nodes)
                    cur_adj = torch.sparse.FloatTensor(cur_adj['idx'].T, cur_adj['vals']).to_dense()
                    hist_adj_list.append(cur_adj)

                edge_feature = torch.cat(hist_adj_list).view(-1,self.data.num_nodes,self.data.num_nodes)
            edge_feature = edge_feature.permute(1, 2, 0)

            # This would be if we were training on all the edges in the time_window
            label_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=idx + 1,
                                    weighted=False,
                                    time_window=self.args.adj_mat_time_window)
            if test:
                neg_mult = self.args.negative_mult_test
            else:
                neg_mult = self.args.negative_mult_training

            if self.args.smart_neg_sampling:
                existing_nodes = torch.cat(existing_nodes)

            if kwargs.get('all_edges', False):
                non_exisiting_adj = tu.get_all_non_existing_edges(adj=label_adj, tot_nodes=self.data.num_nodes)
            else:
                non_exisiting_adj = tu.get_non_existing_edges(adj=label_adj,
                                                            number=label_adj['vals'].size(0) * neg_mult,
                                                            tot_nodes=self.data.num_nodes,
                                                            smart_sampling=self.args.smart_neg_sampling,
                                                            existing_nodes=existing_nodes)


            label_adj['idx'] = torch.cat([label_adj['idx'], non_exisiting_adj['idx']])
            label_adj['vals'] = torch.cat([label_adj['vals'], non_exisiting_adj['vals']])
            return {'idx': idx,
                    'edge_feature': edge_feature,
                    'label_sp': label_adj,
                    'node_feature': 1}

        else:
            if self.args.fft:
                for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                    cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                            time=i,
                                            weighted=True,
                                            time_window=self.args.adj_mat_time_window)
                    if self.args.smart_neg_sampling:
                        existing_nodes.append(cur_adj['idx'].unique())
                    else:
                        existing_nodes = None

                    # node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                    cur_adj = torch.sparse.FloatTensor(cur_adj['idx'].t(),cur_adj['vals'].type(torch.float),torch.Size([self.data.num_nodes,self.data.num_nodes])).to_dense()
                    hist_adj_list.append(cur_adj)
                hist_adj_list = torch.cat(hist_adj_list).view(self.data.num_nodes,self.data.num_nodes,-1).numpy()
                #print(1, hist_adj_list.shape)
                fft_hist_adj_list = u.DTFTSp(hist_adj_list, self.args.fft_num_steps)

                # This would be if we were training on all the edges in the time_window
                label_adj = tu.get_sp_adj(edges=self.data.edges,
                                        time=idx + 1,
                                        weighted=False,
                                        time_window=self.args.adj_mat_time_window)
                if test:
                    neg_mult = self.args.negative_mult_test
                else:
                    neg_mult = self.args.negative_mult_training

                if self.args.smart_neg_sampling:
                    existing_nodes = torch.cat(existing_nodes)

                if kwargs.get('all_edges', False):
                    non_exisiting_adj = tu.get_all_non_existing_edges(adj=label_adj, tot_nodes=self.data.num_nodes)
                else:
                    non_exisiting_adj = tu.get_non_existing_edges(adj=label_adj,
                                                                number=label_adj['vals'].size(0) * neg_mult,
                                                                tot_nodes=self.data.num_nodes,
                                                                smart_sampling=self.args.smart_neg_sampling,
                                                                existing_nodes=existing_nodes)


                label_adj['idx'] = torch.cat([label_adj['idx'], non_exisiting_adj['idx']])
                label_adj['vals'] = torch.cat([label_adj['vals'], non_exisiting_adj['vals']])
                return {'idx': idx,
                        'edge_feature': fft_hist_adj_list,
                        'label_sp': label_adj,
                        'node_feature': 1}
            else:
                for i in range(idx - self.args.num_hist_steps + 1, idx + 1):
                    cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                            time=i,
                                            weighted=True,
                                            time_window=self.args.adj_mat_time_window)
                    if self.args.smart_neg_sampling:
                        existing_nodes.append(cur_adj['idx'].unique())
                    else:
                        existing_nodes = None

                    #node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

                    cur_adj = torch.sparse.LongTensor(cur_adj['idx'].T, cur_adj['vals'])
                    hist_adj_list.append(cur_adj)

                # This would be if we were training on all the edges in the time_window
                label_adj = tu.get_sp_adj(edges=self.data.edges,
                                        time=idx + 1,
                                        weighted=False,
                                        time_window=self.args.adj_mat_time_window)
                if test:
                    neg_mult = self.args.negative_mult_test
                else:
                    neg_mult = self.args.negative_mult_training

                if self.args.smart_neg_sampling:
                    existing_nodes = torch.cat(existing_nodes)

                if kwargs.get('all_edges', False):
                    non_exisiting_adj = tu.get_all_non_existing_edges(adj=label_adj, tot_nodes=self.data.num_nodes)
                else:
                    non_exisiting_adj = tu.get_non_existing_edges(adj=label_adj,
                                                                number=label_adj['vals'].size(0) * neg_mult,
                                                                tot_nodes=self.data.num_nodes,
                                                                smart_sampling=self.args.smart_neg_sampling,
                                                                existing_nodes=existing_nodes)


                label_adj['idx'] = torch.cat([label_adj['idx'], non_exisiting_adj['idx']])
                label_adj['vals'] = torch.cat([label_adj['vals'], non_exisiting_adj['vals']])
                return {'idx': idx,
                        'edge_feature': hist_adj_list,
                        'label_sp': label_adj,
                        'node_feature': 1}
Code Example #8
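Precomputes node features for every time step with node2vec-style biased random walks: each window's adjacency becomes a StellarGraph, the walks are embedded with Word2Vec, and the embeddings are packed into a sparse (num_nodes, feats_per_node) tensor, keyed by time step in feats_dic.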
    def build_get_node_feats(self, args, dataset):
        def get_node_feats(adj):  # input is cur_adj

            edgelist = adj['idx'].cpu().data.numpy()
            source = edgelist[:, 0]
            target = edgelist[:, 1]
            weight = np.ones(len(source))

            G = pd.DataFrame({
                'source': source,
                'target': target,
                'weight': weight
            })
            G = StellarGraph(edges=G)
            rw = BiasedRandomWalk(G)

            weighted_walks = rw.run(
                nodes=list(G.nodes()),  # root nodes
                length=2,  # maximum length of a random walk
                n=5,  # number of random walks per root node
                p=1,  # defines the (unnormalised) probability, 1/p, of returning to the source node
                q=0.5,  # defines the (unnormalised) probability, 1/q, of moving away from the source node
                weighted=True,  # for weighted random walks
                seed=42,  # random seed fixed for reproducibility
            )

            str_walks = [[str(n) for n in walk] for walk in weighted_walks]
            # gensim < 4.0 API: size/iter became vector_size/epochs in 4.x
            weighted_model = Word2Vec(str_walks,
                                      size=self.feats_per_node,
                                      window=5,
                                      min_count=0,
                                      sg=1,
                                      workers=1,
                                      iter=1)

            # Retrieve node embeddings and corresponding subjects
            node_ids = weighted_model.wv.index2word  # node IDs (gensim < 4.0; index_to_key in 4.x)
            # change to integer
            for i in range(0, len(node_ids)):
                node_ids[i] = int(node_ids[i])

            weighted_node_embeddings = (
                weighted_model.wv.vectors
            )  # numpy.ndarray of size number of nodes times embeddings dimensionality

            # map node id -> embedding, in ascending node-id order
            dic = dict(zip(node_ids, weighted_node_embeddings.tolist()))
            dic = dict(sorted(dic.items()))
            # sparse (num_nodes, feats_per_node) feature matrix
            adj_mat = sp.lil_matrix((self.data.num_nodes, self.feats_per_node))

            for row_idx in node_ids:
                adj_mat[row_idx, :] = dic[row_idx]

            adj_mat = adj_mat.tocoo()
            coords = np.vstack((adj_mat.row, adj_mat.col)).transpose()
            values = adj_mat.data
            row = list(coords[:, 0])
            col = list(coords[:, 1])
            indexx = torch.LongTensor([row, col])
            tensor_size = torch.Size(
                [self.data.num_nodes, self.feats_per_node])
            degs_out = torch.sparse.FloatTensor(indexx,
                                                torch.FloatTensor(values),
                                                tensor_size)

            hot_1 = {
                'idx': degs_out._indices().t(),
                'vals': degs_out._values()
            }

            return hot_1

        # precompute a node-features dict keyed by time step
        feats_dic = {}

        for i in range(self.data.max_time):
            if i % 30 == 0:
                print('current i to make embeddings:', i)
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)

            feats_dic[i] = get_node_feats(cur_adj)

        return feats_dic
Code Example #9
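Link-prediction variant that looks up precomputed per-step node features (all_node_feats_dic, e.g. built as in Example #8), subsamples positives and negatives to 2% for the football dataset to fit memory, and returns class-balance weights for the loss (the exact positive/negative proportions when args.adapt is set).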
    def get_sample(self, idx, test, **kwargs):
        hist_adj_list = []
        hist_ndFeats_list = []
        hist_mask_list = []
        existing_nodes = []
        for i in range(idx - self.args.num_hist_steps, idx + 1):
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=i,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)

            if self.args.smart_neg_sampling:
                existing_nodes.append(cur_adj['idx'].unique())
            else:
                existing_nodes = None

            node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)

            # get node features from the dictionary (already created)
            node_feats = self.all_node_feats_dic[i]

            cur_adj = tu.normalize_adj(adj=cur_adj,
                                       num_nodes=self.data.num_nodes)

            hist_adj_list.append(cur_adj)
            hist_ndFeats_list.append(node_feats)
            hist_mask_list.append(node_mask)

        # This would be if we were training on all the edges in the time_window
        label_adj = tu.get_sp_adj(edges=self.data.edges,
                                  time=idx + 1,
                                  weighted=False,
                                  time_window=self.args.adj_mat_time_window)
        if test:
            neg_mult = self.args.negative_mult_test
        else:
            neg_mult = self.args.negative_mult_training

        if self.args.smart_neg_sampling:
            existing_nodes = torch.cat(existing_nodes)

        if kwargs.get('all_edges', False):
            non_exisiting_adj = tu.get_all_non_existing_edges(
                adj=label_adj, tot_nodes=self.data.num_nodes)
        else:
            non_exisiting_adj = tu.get_non_existing_edges(
                adj=label_adj,
                number=label_adj['vals'].size(0) * neg_mult,
                tot_nodes=self.data.num_nodes,
                smart_sampling=self.args.smart_neg_sampling,
                existing_nodes=existing_nodes)

        # For football data, we need to sample due to memory constraints
        if self.args.sport == 'football':
            # Sampling label_adj
            num_sample = int(np.floor(len(label_adj['vals']) * 0.02))
            indice = random.sample(range(len(label_adj['vals'])), num_sample)
            indice = torch.LongTensor(indice)
            label_adj['idx'] = label_adj['idx'][indice, :]
            label_adj['vals'] = label_adj['vals'][indice]

            # Sampling non_exisiting_adj
            num_sample = int(np.floor(len(non_exisiting_adj['vals']) * 0.02))
            indice = random.sample(range(len(non_exisiting_adj['vals'])),
                                   num_sample)
            indice = torch.LongTensor(indice)
            non_exisiting_adj['idx'] = non_exisiting_adj['idx'][indice, :]
            non_exisiting_adj['vals'] = non_exisiting_adj['vals'][indice]

        all_len = len(label_adj['vals']) + len(non_exisiting_adj['vals'])
        pos = len(label_adj['vals']) / all_len
        neg = len(non_exisiting_adj['vals']) / all_len

        # with args.adapt, weight the loss by the exact positive/negative proportions
        if self.args.adapt:
            weight = [pos, neg]
        else:
            weight = self.args.class_weights

        label_adj['idx'] = torch.cat(
            [label_adj['idx'], non_exisiting_adj['idx']])
        label_adj['vals'] = torch.cat(
            [label_adj['vals'], non_exisiting_adj['vals']])
        return {
            'idx': idx,
            'hist_adj_list': hist_adj_list,
            'hist_ndFeats_list': hist_ndFeats_list,
            'label_sp': label_adj,
            'node_mask_list': hist_mask_list,
            'weight': weight
        }
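A sketch of how the returned weights might enter a weighted loss; model, the call site, and the loss wiring are illustrative assumptions, not part of these snippets (the [pos, neg] ordering follows the code above):

    import torch
    import torch.nn.functional as F

    sample = tasker.get_sample(idx, test=False)  # hypothetical tasker instance
    labels = sample['label_sp']['vals'].float()  # real edges first, then sampled non-edges
    logits = model(sample)                       # hypothetical: one logit per candidate edge
    pos_w, neg_w = sample['weight']
    # weight each candidate edge by its class weight, in the [pos, neg] order above
    w = torch.where(labels > 0,
                    torch.full_like(labels, pos_w),
                    torch.full_like(labels, neg_w))
    loss = F.binary_cross_entropy_with_logits(logits, labels, weight=w)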