예제 #1
0
    def process(self):
        """Read the raw graph, optionally pre-transform it, and save it.

        Loading options are taken from ``self.meta_info[self.name]``, where
        they are stored as strings: ``"True"`` enables inverse edges, and the
        additional node/edge file entries are either the literal ``'None'``
        or a comma-separated list of names.  The collated result is written
        to ``self.processed_paths[0]``.
        """
        meta = self.meta_info[self.name]
        add_inverse_edge = meta["add_inverse_edge"] == "True"

        # The literal string 'None' means "no additional files".
        node_spec = meta["additional node files"]
        additional_node_files = (
            [] if node_spec == 'None' else node_spec.split(','))

        edge_spec = meta["additional edge files"]
        additional_edge_files = (
            [] if edge_spec == 'None' else edge_spec.split(','))

        # Both readers take the same arguments; only the first element of
        # what they return is used here.
        reader = (read_csv_heterograph_pyg
                  if self.is_hetero else read_csv_graph_pyg)
        graphs = reader(
            self.raw_dir,
            add_inverse_edge=add_inverse_edge,
            additional_node_files=additional_node_files,
            additional_edge_files=additional_edge_files)
        data = graphs[0]

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        print('Saving...')
        torch.save(self.collate([data]), self.processed_paths[0])
예제 #2
0
    def process(self):
        """Read the raw graph, attach node labels, and save the result.

        Loading options are taken from ``self.meta_info[self.name]``, where
        they are stored as strings: ``"True"`` enables inverse edges, and the
        additional node/edge file entries are either the literal ``'None'``
        or a comma-separated list of names.

        Labels are stored as ``torch.long`` when ``self.task_type`` contains
        ``"classification"`` and as ``torch.float32`` otherwise.  For
        heterogeneous graphs they go into ``data.y_dict`` (one entry per node
        type); otherwise they are read from ``node-label.csv.gz`` in
        ``self.raw_dir`` and stored in ``data.y``.

        If ``self.pre_transform`` is set, its return value replaces ``data``
        before the collated result is written to ``self.processed_paths[0]``.
        """
        meta = self.meta_info[self.name]
        add_inverse_edge = meta["add_inverse_edge"] == "True"

        # The literal string 'None' means "no additional files".
        node_spec = meta["additional node files"]
        additional_node_files = (
            [] if node_spec == 'None' else node_spec.split(','))

        edge_spec = meta["additional edge files"]
        additional_edge_files = (
            [] if edge_spec == 'None' else edge_spec.split(','))

        if self.is_hetero:
            data = read_csv_heterograph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            node_label_dict = read_node_label_hetero(self.raw_dir)

            label_dtype = (torch.long if "classification" in self.task_type
                           else torch.float32)
            data.y_dict = {
                nodetype: torch.from_numpy(node_label).to(label_dtype)
                for nodetype, node_label in node_label_dict.items()
            }

        else:
            data = read_csv_graph_pyg(
                self.raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            # Adding the prediction target.
            node_label = pd.read_csv(osp.join(self.raw_dir,
                                              'node-label.csv.gz'),
                                     compression="gzip",
                                     header=None).values

            label_dtype = (torch.long if "classification" in self.task_type
                           else torch.float32)
            data.y = torch.from_numpy(node_label).to(label_dtype)

        # Keep the transform's return value: the original evaluated this as a
        # bare expression, so a transform returning a new object was ignored.
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        print('Saving...')
        torch.save(self.collate([data]), self.processed_paths[0])