예제 #1
0
    def get_sample(self, idx, test):
        """Return the history window and labels for the sample at step ``idx``.

        Visits the ``self.args.num_hist_steps + 1`` time steps from
        ``idx - self.args.num_hist_steps`` up to and including ``idx``.

        For each step it:

        - builds a weighted sparse adjacency with ``tu.get_sp_adj``, passing
          ``self.args.adj_mat_time_window``;
        - derives the node mask and the node features
          (``self.get_node_feats(step, adj)``) from that raw adjacency;
        - stores the ``tu.normalize_adj`` result.

        Labels come from ``self.get_node_labels(idx)``.

        ``test`` is part of the shared interface and is not used here.

        Returns:
            dict with keys ``'idx'``, ``'hist_adj_list'``,
            ``'hist_ndFeats_list'``, ``'label_sp'`` and ``'node_mask_list'``.
            The three lists are ordered oldest step first.
        """
        num_nodes = self.data.num_nodes
        window_start = idx - self.args.num_hist_steps

        adjs, feats, masks = [], [], []
        for step in range(window_start, idx + 1):
            # A time window is passed so only recent edges are kept, instead of
            # every edge since the beginning (per the original author's note).
            raw = tu.get_sp_adj(
                edges=self.data.edges,
                time=step,
                weighted=True,
                time_window=self.args.adj_mat_time_window,
            )
            # Mask and features are computed on the un-normalized adjacency.
            masks.append(tu.get_node_mask(raw, num_nodes))
            feats.append(self.get_node_feats(step, raw))
            adjs.append(tu.normalize_adj(adj=raw, num_nodes=num_nodes))

        labels = self.get_node_labels(idx)

        return {
            'idx': idx,
            'hist_adj_list': adjs,
            'hist_ndFeats_list': feats,
            'label_sp': labels,
            'node_mask_list': masks,
        }
예제 #2
0
	def get_sample(self, idx, test, **kwargs):
		"""Build a link-prediction sample whose history ends at step ``idx``.

		History: for every step from ``idx - self.args.num_hist_steps`` to
		``idx`` (inclusive), a weighted, time-windowed sparse adjacency is
		fetched with ``tu.get_sp_adj``. Its node mask and node features
		(``self.get_node_feats(adj)``) are computed from that raw adjacency,
		which is then normalized with ``tu.normalize_adj``.

		Labels: start from the unweighted windowed adjacency at ``idx + 1``,
		then append non-existing pairs:

		- all of them (``tu.get_all_non_existing_edges``) when
		  ``all_edges=True`` is passed;
		- otherwise ``len(label vals) * neg_mult`` pairs sampled with
		  ``tu.get_non_existing_edges``.

		``neg_mult`` is ``self.args.negative_mult_test`` when ``test`` is
		truthy, else ``self.args.negative_mult_training``. When
		``self.args.smart_neg_sampling`` is set, the concatenated
		``unique()`` values of each history adjacency's ``'idx'`` tensor
		(presumably node ids) are passed to the sampler as ``existing_nodes``.

		Returns:
			dict with keys ``'idx'``, ``'hist_adj_list'``,
			``'hist_ndFeats_list'``, ``'label_sp'`` (label pairs first, then
			the non-existing pairs) and ``'node_mask_list'``.
		"""
		smart = self.args.smart_neg_sampling

		hist_adjs, hist_feats, hist_masks = [], [], []
		seen_nodes = []
		for step in range(idx - self.args.num_hist_steps, idx + 1):
			raw_adj = tu.get_sp_adj(
				edges=self.data.edges,
				time=step,
				weighted=True,
				time_window=self.args.adj_mat_time_window)

			if not smart:
				seen_nodes = None
			else:
				seen_nodes.append(raw_adj['idx'].unique())

			hist_masks.append(tu.get_node_mask(raw_adj, self.data.num_nodes))
			hist_feats.append(self.get_node_feats(raw_adj))
			hist_adjs.append(tu.normalize_adj(adj=raw_adj, num_nodes=self.data.num_nodes))

		# Label pairs: every edge in the time window at the step after the history.
		label_adj = tu.get_sp_adj(
			edges=self.data.edges,
			time=idx + 1,
			weighted=False,
			time_window=self.args.adj_mat_time_window)
		neg_mult = self.args.negative_mult_test if test else self.args.negative_mult_training

		if smart:
			seen_nodes = torch.cat(seen_nodes)

		# Explicit ``== True`` keeps the original acceptance rule for the flag.
		want_all = kwargs.get('all_edges', False) == True
		if want_all:
			negatives = tu.get_all_non_existing_edges(adj=label_adj, tot_nodes=self.data.num_nodes)
		else:
			negatives = tu.get_non_existing_edges(
				adj=label_adj,
				number=label_adj['vals'].size(0) * neg_mult,
				tot_nodes=self.data.num_nodes,
				smart_sampling=smart,
				existing_nodes=seen_nodes)

		label_adj['idx'] = torch.cat([label_adj['idx'], negatives['idx']])
		label_adj['vals'] = torch.cat([label_adj['vals'], negatives['vals']])

		return {
			'idx': idx,
			'hist_adj_list': hist_adjs,
			'hist_ndFeats_list': hist_feats,
			'label_sp': label_adj,
			'node_mask_list': hist_masks,
		}
예제 #3
0
    def get_sample(self, idx, test):
        """Collect the adjacency history ending at ``idx`` plus its edge labels.

        For each time step ``t`` from ``idx - self.args.num_hist_steps`` to
        ``idx`` (inclusive):

        - a weighted sparse adjacency is built with ``tu.get_sp_adj`` using
          ``self.args.adj_mat_time_window``;
        - its node mask and node features (``self.get_node_feats(adj)``) are
          taken from the raw adjacency;
        - the ``tu.normalize_adj`` result is stored.

        Labels are ``tu.get_edge_labels`` for step ``idx``. ``test`` is
        accepted for interface compatibility and is ignored.

        Returns:
            dict with keys ``'idx'``, ``'hist_adj_list'``,
            ``'hist_ndFeats_list'``, ``'label_sp'`` and ``'node_mask_list'``.
        """
        first = idx - self.args.num_hist_steps
        hist_adjs = []
        hist_feats = []
        hist_masks = []

        for t in range(first, idx + 1):
            adj_t = tu.get_sp_adj(
                edges=self.data.edges,
                time=t,
                weighted=True,
                time_window=self.args.adj_mat_time_window,
            )
            hist_masks.append(tu.get_node_mask(adj_t, self.data.num_nodes))
            hist_feats.append(self.get_node_feats(adj_t))
            hist_adjs.append(
                tu.normalize_adj(adj=adj_t, num_nodes=self.data.num_nodes))

        edge_labels = tu.get_edge_labels(edges=self.data.edges, time=idx)
        return {
            'idx': idx,
            'hist_adj_list': hist_adjs,
            'hist_ndFeats_list': hist_feats,
            'label_sp': edge_labels,
            'node_mask_list': hist_masks,
        }
예제 #4
0
    def get_sample(self, idx, test, **kwargs):
        """Build a link-prediction sample whose history ends at time step ``idx``.

        History: for each step from ``idx - self.args.num_hist_steps`` to
        ``idx`` (inclusive), the following are computed:

        - the weighted, time-windowed adjacency (``tu.get_sp_adj``);
        - its node mask;
        - the node features, read from the precomputed
          ``self.all_node_feats_dic[step]``;
        - the normalized adjacency (``tu.normalize_adj``).

        Labels: the unweighted windowed adjacency at ``idx + 1`` supplies the
        existing pairs. Non-existing pairs come from
        ``tu.get_all_non_existing_edges`` when ``all_edges=True`` is passed,
        and otherwise from ``tu.get_non_existing_edges`` with
        ``len(existing) * neg_mult`` requested pairs.

        When ``self.args.sport == 'football'``, both sets are randomly
        subsampled to ``floor(2% of their size)`` to bound memory.

        Args:
            idx: last history time step; labels are taken from ``idx + 1``.
            test: selects ``self.args.negative_mult_test`` (truthy) or
                ``self.args.negative_mult_training`` (falsy) as ``neg_mult``.
            **kwargs: ``all_edges=True`` requests every non-existing pair
                instead of a sample.

        Returns:
            dict with the following keys:

            - ``'idx'``;
            - ``'hist_adj_list'``;
            - ``'hist_ndFeats_list'``;
            - ``'label_sp'``: existing pairs followed by non-existing pairs;
            - ``'node_mask_list'``;
            - ``'weight'``.

            When ``self.args.adapt`` is set, ``'weight'`` is
            ``[existing_fraction, non_existing_fraction]`` of the (possibly
            subsampled) label set, or ``[0.5, 0.5]`` if that set is empty.
            Otherwise it is ``self.args.class_weights``.
        """
        # Share of label pairs kept for football data, whose full label set
        # does not fit in memory.
        football_keep_fraction = 0.02

        def _subsample(adj, fraction):
            # Keep a uniformly random subset of floor(len * fraction) pairs,
            # indexing the 'idx' rows and 'vals' entries at the same positions.
            n_total = len(adj['vals'])
            n_keep = int(np.floor(n_total * fraction))
            keep = torch.LongTensor(random.sample(range(n_total), n_keep))
            adj['idx'] = adj['idx'][keep, :]
            adj['vals'] = adj['vals'][keep]

        smart = self.args.smart_neg_sampling
        hist_adj_list = []
        hist_ndFeats_list = []
        hist_mask_list = []
        node_id_sets = []
        for step in range(idx - self.args.num_hist_steps, idx + 1):
            cur_adj = tu.get_sp_adj(edges=self.data.edges,
                                    time=step,
                                    weighted=True,
                                    time_window=self.args.adj_mat_time_window)
            if smart:
                # Unique 'idx' values seen in the history (presumably node
                # ids); handed to the negative sampler below.
                node_id_sets.append(cur_adj['idx'].unique())

            node_mask = tu.get_node_mask(cur_adj, self.data.num_nodes)
            # Node features were precomputed per time step.
            node_feats = self.all_node_feats_dic[step]
            cur_adj = tu.normalize_adj(adj=cur_adj,
                                       num_nodes=self.data.num_nodes)

            hist_adj_list.append(cur_adj)
            hist_ndFeats_list.append(node_feats)
            hist_mask_list.append(node_mask)

        # Labels: all edges in the time window at the step after the history.
        label_adj = tu.get_sp_adj(edges=self.data.edges,
                                  time=idx + 1,
                                  weighted=False,
                                  time_window=self.args.adj_mat_time_window)
        neg_mult = (self.args.negative_mult_test if test
                    else self.args.negative_mult_training)
        existing_nodes = torch.cat(node_id_sets) if smart else None

        # Explicit ``== True`` keeps the original acceptance rule for the flag.
        if kwargs.get('all_edges', False) == True:
            non_existing_adj = tu.get_all_non_existing_edges(
                adj=label_adj, tot_nodes=self.data.num_nodes)
        else:
            non_existing_adj = tu.get_non_existing_edges(
                adj=label_adj,
                number=label_adj['vals'].size(0) * neg_mult,
                tot_nodes=self.data.num_nodes,
                smart_sampling=smart,
                existing_nodes=existing_nodes)

        if self.args.sport == 'football':
            # Same order as before: existing pairs are sampled first, so the
            # random-number stream is consumed identically.
            _subsample(label_adj, football_keep_fraction)
            _subsample(non_existing_adj, football_keep_fraction)

        if self.args.adapt:
            # Fractions are only needed (and only computed) in adaptive mode;
            # computing them unconditionally divided by zero on empty steps.
            n_pos = len(label_adj['vals'])
            n_neg = len(non_existing_adj['vals'])
            n_all = n_pos + n_neg
            if n_all:
                # NOTE(review): presumably weight[c] scales label class c in
                # the loss, so the rarer class gets the larger weight — confirm.
                weight = [n_pos / n_all, n_neg / n_all]
            else:
                # No labelled pairs at this step: fractions are undefined.
                weight = [0.5, 0.5]
        else:
            weight = self.args.class_weights

        label_adj['idx'] = torch.cat(
            [label_adj['idx'], non_existing_adj['idx']])
        label_adj['vals'] = torch.cat(
            [label_adj['vals'], non_existing_adj['vals']])
        return {
            'idx': idx,
            'hist_adj_list': hist_adj_list,
            'hist_ndFeats_list': hist_ndFeats_list,
            'label_sp': label_adj,
            'node_mask_list': hist_mask_list,
            'weight': weight
        }