Example #1
    def test_hact_viz(self):
        """Test hierarchical visualization."""

        # 1. load the corresponding image
        image = np.array(
            Image.open(os.path.join(self.image_path, self.image_name)))

        # 2. load tissue graph
        tissue_graph, _ = load_graphs(
            os.path.join(self.tissue_graph_path, self.graph_name))
        tissue_graph = tissue_graph[0]

        # 3. load cell graph
        cell_graph, _ = load_graphs(
            os.path.join(self.cell_graph_path, self.graph_name))
        cell_graph = cell_graph[0]

        # 4. run the visualization
        visualizer = HACTVisualization()
        out = visualizer.process(
            image,
            cell_graph=cell_graph,
            tissue_graph=tissue_graph,
        )

        # 5. save output image
        out.save(
            os.path.join(
                self.out_path,
                self.image_name.replace(".png", "") + "_hierarchical_overlay.png",
            ),
            quality=95,
        )
Example #2
    def pre_process(self):
        processed_dir = os.path.join(self.root, 'processed')
        pre_processed_file_path = os.path.join(processed_dir, 'dgl_data_processed')

        if os.path.exists(pre_processed_file_path):
            self.graphs, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
            self.ids = label_dict['ids']

        else:
            url = self.meta_info[self.name]["dgl url"]
            if decide_download(url):
                path = download_url(url, self.original_root)
                extract_zip(path, self.original_root)
                os.unlink(path)
                # delete the folder if it already exists
                try:
                    shutil.rmtree(self.root)
                except OSError:
                    pass
                shutil.move(osp.join(self.original_root, self.download_name), self.root)
            else:
                print("Stop download.")
                exit(-1)

            self.graphs, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
            self.ids = label_dict['ids']
Example #3
 def load(self):
     """ step 5 """
     graph_path = f'{self.save_path}/dgl_graph.bin'
     line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
     self.graphs, label_dict = load_graphs(graph_path)
     self.line_graphs, _ = load_graphs(line_graph_path)
     self.label = torch.stack([label_dict[key] for key in self.label_keys],
                              dim=1)
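The label dictionary passed to save_graphs may hold several tensors keyed by name, and the load above stacks them back into a single target matrix. A minimal sketch of that round trip, with illustrative file name and keys:

import torch
import dgl
from dgl.data.utils import save_graphs, load_graphs

g = dgl.rand_graph(4, 8)  # 4 nodes, 8 random edges
label_dict = {'y1': torch.zeros(1), 'y2': torch.ones(1)}
save_graphs('multi_label.bin', [g], label_dict)

graphs, labels = load_graphs('multi_label.bin')
target = torch.stack([labels[k] for k in ('y1', 'y2')], dim=1)  # shape (1, 2)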
Example #4
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            self.graph, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']

        else:
            ### check download
            if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
                url = self.meta_info[self.name]["url"]
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete the folder if it already exists
                    try:
                        shutil.rmtree(self.root)
                    except OSError:
                        pass
                    shutil.move(osp.join(self.original_root, self.download_name), self.root)
                else:
                    print("Stop download.")
                    exit(-1)

            raw_dir = osp.join(self.root, "raw")

            ### pre-process and save
            add_inverse_edge = self.meta_info[self.name]["add_inverse_edge"] == "True"

            if self.meta_info[self.name]["additional node files"] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info[self.name]["additional node files"].split(',')

            if self.meta_info[self.name]["additional edge files"] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info[self.name]["additional edge files"].split(',')

            graph = read_csv_graph_dgl(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            ### adding prediction target
            node_label = pd.read_csv(osp.join(raw_dir, 'node-label.csv.gz'),
                                     compression="gzip",
                                     header=None).values
            if "classification" in self.task_type:
                node_label = torch.tensor(node_label, dtype=torch.long)
            else:
                node_label = torch.tensor(node_label, dtype=torch.float32)

            label_dict = {"labels": node_label}

            print('Saving...')
            save_graphs(pre_processed_file_path, graph, label_dict)

            self.graph, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
Example #5
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            self.graph, _ = load_graphs(pre_processed_file_path)

        else:
            ### check download
            if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
                url = self.meta_info[self.name]["url"]
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete the folder if it already exists
                    try:
                        shutil.rmtree(self.root)
                    except OSError:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print("Stop download.")
                    exit(-1)

            raw_dir = osp.join(self.root, "raw")

            add_inverse_edge = self.meta_info[
                self.name]["add_inverse_edge"] == "True"

            ### pre-process and save
            if self.meta_info[self.name]["additional node files"] == 'None':
                additional_node_files = []
            else:
                additional_node_files = self.meta_info[
                    self.name]["additional node files"].split(',')

            if self.meta_info[self.name]["additional edge files"] == 'None':
                additional_edge_files = []
            else:
                additional_edge_files = self.meta_info[
                    self.name]["additional edge files"].split(',')

            graph = read_csv_graph_dgl(
                raw_dir,
                add_inverse_edge=add_inverse_edge,
                additional_node_files=additional_node_files,
                additional_edge_files=additional_edge_files)[0]

            print('Saving...')
            save_graphs(pre_processed_file_path, graph, {})

            self.graph, _ = load_graphs(pre_processed_file_path)
Example #6
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if osp.exists(pre_processed_file_path):
            self.graph, _ = load_graphs(pre_processed_file_path)

        else:
            ### check download
            if not osp.exists(osp.join(self.root, "raw", "edge.csv.gz")):
                url = self.meta_info[self.name]["url"]
                if decide_download(url):
                    path = download_url(url, self.original_root)
                    extract_zip(path, self.original_root)
                    os.unlink(path)
                    # delete the folder if it already exists
                    try:
                        shutil.rmtree(self.root)
                    except OSError:
                        pass
                    shutil.move(
                        osp.join(self.original_root, self.download_name),
                        self.root)
                else:
                    print("Stop download.")
                    exit(-1)

            raw_dir = osp.join(self.root, "raw")

            file_names = ["edge"]
            if self.meta_info[self.name]["has_node_attr"] == "True":
                file_names.append("node-feat")
            if self.meta_info[self.name]["has_edge_attr"] == "True":
                file_names.append("edge-feat")
            raw_file_names = [
                file_name + ".csv.gz" for file_name in file_names
            ]

            ### pre-process and save
            add_inverse_edge = self.meta_info[
                self.name]["add_inverse_edge"] == "True"
            graph = read_csv_graph_dgl(raw_dir,
                                       raw_file_names,
                                       add_inverse_edge=add_inverse_edge)

            save_graphs(pre_processed_file_path, graph, {})

            self.graph, _ = load_graphs(pre_processed_file_path)
Example #7
 def __init__(self,
              root=None,
              structures=[None],
              targets=[None],
              cutoff=5.,
              transform=None):
     self.transform = transform
     self.cutoff = cutoff
     assert len(structures) == len(targets), \
         "number of structures must equal number of targets"
     self.structures = structures
     self.targets = targets
     self.root = root
     if self.root is None:
         self._load()
     else:
         if not os.path.isdir(root):
             os.mkdir(root)
         cache_path = os.path.join(root, "processed.bin")
         if os.path.isfile(cache_path):
             try:
                 graphs, labels = load_graphs(cache_path)
                 self.graphs = graphs
                 self.labels = labels['labels']
                 print(len(self.graphs), "loaded!")
             except Exception:
                 self._load()
                 labels_dict = {"labels": self.labels}
                 save_graphs(cache_path, self.graphs, labels_dict)
                 print(len(self.graphs), "saved!")
         else:
             self._load()
             labels_dict = {"labels": self.labels}
             save_graphs(cache_path, self.graphs, labels_dict)
             print(len(self.graphs), "saved!")
Example #8
def train(args):
    set_random_seed(args.seed)
    g = load_graphs(os.path.join(args.data_path, 'neighbor_graph.bin'))[0][0]
    feats = load_info(os.path.join(args.data_path, 'in_feats.pkl'))

    model = HetGNN(feats['author'].shape[-1], args.num_hidden, g.ntypes)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    neg_sampler = RatioNegativeSampler()
    for epoch in range(args.epochs):
        model.train()
        embeds = model(g, feats)
        score = model.calc_score(g, embeds)
        neg_g = construct_neg_graph(g, neg_sampler)
        neg_score = model.calc_score(neg_g, embeds)
        logits = torch.cat([score, neg_score])  # (2A*E,)
        labels = torch.cat(
            [torch.ones(score.shape[0]),
             torch.zeros(neg_score.shape[0])])
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch, loss.item()))
    with torch.no_grad():
        final_embeds = model(g, feats)
        with open(args.save_node_embed_path, 'wb') as f:
            pickle.dump(final_embeds, f)
        print('Final node embeddings saved to', args.save_node_embed_path)
Example #9
    def __getitem__(self, item) -> Tuple[DGLGraph, torch.Tensor]:
        graph_filename, label_filename, start_index, end_index = \
            self.batch_description[item]

        if label_filename != self.current_label_file:
            self.current_label_file = label_filename
            with open(label_filename, 'rb') as pkl_file:
                pkl_data = load(pkl_file)
            self.labels = torch.tensor(pkl_data['labels'].T,
                                       device=self.device).detach()

        graphs, _ = load_graphs(graph_filename,
                                list(range(start_index, end_index)))
        graphs, mask = zip(*[(g, i) for i, g in enumerate(graphs)
                             if self._is_tree_suitable(g)])

        if self.invert_edges:
            graphs = [g.reverse(share_ndata=True) for g in graphs]

        graph = batch(graphs)
        graph.ndata['token'] = graph.ndata['token'].to(self.device).detach()
        graph.ndata['type'] = graph.ndata['type'].to(self.device).detach()
        # [sequence len, batch size]
        labels = self.labels[:, start_index:end_index][:, list(mask)]

        return graph, labels
Example #10
def read_graph(graph_path):
    """
    Read graph data from path.
    """
    graph_list, _ = load_graphs(graph_path)
    graph = graph_list[0]
    return graph
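load_graphs always returns a pair: a list of graphs and a label dictionary, which is why the helper above indexes the list with [0]. A minimal self-contained round trip, with illustrative names:

import torch
import dgl
from dgl.data.utils import save_graphs, load_graphs

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
save_graphs('single_graph.bin', [g], {'labels': torch.tensor([1])})

graphs, label_dict = load_graphs('single_graph.bin')
graph = graphs[0]  # a list comes back even when the file holds one graph
labels = label_dict['labels']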
Example #11
 def _process(  # type: ignore[override]
         self, path: Union[str, Path]) -> dgl.DGLGraph:
     graph_path = str(path)  # DGL cannot handle pathlib.Path
     graphs, _ = load_graphs(graph_path)
     if len(graphs) == 1:
         return graphs[0]
     return graphs
Example #12
    def read_ast(self, id_):
        try:
            graph, _ = load_graphs(os.path.join(self.functions_path, str(id_)))
        except Exception:
            return None

        return graph[0]
Example #13
def test_graph_serialize_without_feature():
    num_graphs = 100
    g_list = [generate_rand_graph(30) for _ in range(num_graphs)]

    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    save_graphs(path, g_list)

    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
    loadg_list, _ = load_graphs(path, idx_list)

    idx = idx_list[0]
    load_g = loadg_list[0]

    assert F.allclose(load_g.nodes(), g_list[idx].nodes())

    load_edges = load_g.all_edges('uv', 'eid')
    g_edges = g_list[idx].all_edges('uv', 'eid')
    assert F.allclose(load_edges[0], g_edges[0])
    assert F.allclose(load_edges[1], g_edges[1])

    os.unlink(path)
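The idx_list argument exercised above makes load_graphs deserialize only the requested indices, which keeps memory bounded when a file holds many graphs. A minimal sketch, again assuming the file was written with save_graphs:

import dgl
from dgl.data.utils import save_graphs, load_graphs

g_list = [dgl.rand_graph(10, 20) for _ in range(100)]
save_graphs('all_graphs.bin', g_list)

# deserialize only graphs 0, 5 and 7 instead of all 100
subset, _ = load_graphs('all_graphs.bin', [0, 5, 7])
assert len(subset) == 3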
Example #14
 def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
     if args['preprocess']['dataset_impl'] == "raw":
         in_file = file_name(input_prefix, lang)
         out_dir = args['preprocess']['destdir']
         os.makedirs(out_dir, exist_ok=True)
         LOGGER.info('Copying {} into {}'.format(in_file, out_dir))
         shutil.copy(src=in_file, dst=args['preprocess']['destdir'])
     else:
         in_file = file_name(input_prefix, lang)
         out_file = dest_path(output_prefix, lang)
         os.makedirs(os.path.dirname(out_file), exist_ok=True)
         offsets = find_offsets(in_file, num_workers)
         with Pool(num_workers) as mpool:
             results = [
                 mpool.apply_async(
                     build_dgl_graph,
                     (vocab, in_file, f'{out_file}{worker_id}.mmap',
                      offsets[worker_id], offsets[worker_id + 1]),
                 ) for worker_id in range(num_workers)
             ]
             results = [res.get() for res in results]
         graph_batch = []
         for worker_id in range(num_workers):
             sub_file = f'{out_file}{worker_id}.mmap'
             glist, _ = load_graphs(sub_file)
             graph_batch.extend(glist)
             os.remove(sub_file)
         save_graphs(f'{out_file}.mmap', graph_batch)
Example #15
    def test_pretrained_bracs_tggnn_3_classes_gin(self):
        """Test bracs_tggnn_3_classes_gin model."""

        # 1. Load a tissue graph
        graph, _ = load_graphs(os.path.join(self.graph_path, self.graph_name))
        graph = graph[0]
        graph = set_graph_on_cuda(graph) if IS_CUDA else graph
        node_dim = graph.ndata['feat'].shape[1]

        # 2. Load model with pre-trained weights
        config_fname = os.path.join(self.current_path, 'config',
                                    'tg_bracs_tggnn_3_classes_gin.yml')
        with open(config_fname, 'r') as file:
            config = yaml.safe_load(file)

        model = TissueGraphModel(
            gnn_params=config['gnn_params'],
            classification_params=config['classification_params'],
            node_dim=node_dim,
            num_classes=3,
            pretrained=True).to(DEVICE)

        # 3. forward pass
        logits = model(graph)

        self.assertIsInstance(logits, torch.Tensor)
        self.assertEqual(logits.shape[0], 1)
        self.assertEqual(logits.shape[1], 3)
Example #16
def read_input(input_folder):
    X = pd.read_csv(f'{input_folder}/X.csv')
    y = pd.read_csv(f'{input_folder}/y.csv')

    categorical_columns = []
    if os.path.exists(f'{input_folder}/cat_features.txt'):
        with open(f'{input_folder}/cat_features.txt') as f:
            for line in f:
                if line.strip():
                    categorical_columns.append(line.strip())

    cat_features = None
    if categorical_columns:
        columns = X.columns
        cat_features = np.where(columns.isin(categorical_columns))[0]

        for col in list(columns[cat_features]):
            X[col] = X[col].astype(str)

    gs, _ = load_graphs(f'{input_folder}/graph.dgl')
    graph = gs[0]

    with open(f'{input_folder}/masks.json') as f:
        masks = json.load(f)

    return graph, X, y, cat_features, masks
Example #17
    def test_graphlrp(self):
        """
        Test Graph LRP.
        """

        # 1. load a graph
        graph, _ = load_graphs(os.path.join(self.graph_path, self.graph_name))
        graph = graph[0]
        graph = set_graph_on_cuda(graph) if IS_CUDA else graph
        node_dim = graph.ndata['feat'].shape[1]

        # 2. create model
        config_fname = os.path.join(self.current_path, 'config',
                                    'cg_bracs_cggnn_3_classes_gin.yml')
        with open(config_fname, 'r') as file:
            config = yaml.safe_load(file)

        model = CellGraphModel(
            gnn_params=config['gnn_params'],
            classification_params=config['classification_params'],
            node_dim=node_dim,
            num_classes=3)

        # 3. run the explainer
        explainer = GraphLRPExplainer(model=model)
        importance_scores, logits = explainer.process(graph)

        # 4. tests
        self.assertIsInstance(importance_scores, np.ndarray)
        self.assertIsInstance(logits, np.ndarray)
        self.assertEqual(graph.number_of_nodes(), importance_scores.shape[0])
Example #18
    def test_tissue_graph_model_with_batch(self):
        """Test tissue graph model with batch."""

        # 1. Load a tissue graph
        graph, _ = load_graphs(os.path.join(self.graph_path, self.graph_name))
        graph = graph[0]
        graph = set_graph_on_cuda(graph) if IS_CUDA else graph
        node_dim = graph.ndata['feat'].shape[1]

        # 2. load config
        config_fname = os.path.join(self.current_path, 'config',
                                    'tg_model.yml')
        with open(config_fname, 'r') as file:
            config = yaml.safe_load(file)

        model = TissueGraphModel(
            gnn_params=config['gnn_params'],
            classification_params=config['classification_params'],
            node_dim=node_dim,
            num_classes=3).to(DEVICE)

        # 3. forward pass
        logits = model(dgl.batch([graph, graph]))

        self.assertIsInstance(logits, torch.Tensor)
        self.assertEqual(logits.shape[0], 2)
        self.assertEqual(logits.shape[1], 3)
Example #19
 def load(self):
     graphs, _ = load_graphs(os.path.join(self.save_path, self.name + '_dgl_graph.bin'))
     self.g = graphs[0]
     info = load_info(os.path.join(self.save_path, self.name + '_info.pkl'))
     self.author_names = info['author_names']
     self.paper_titles = info['paper_titles']
     self.conf_names = info['conf_names']
Example #20
def test_serialize_heterograph():
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()
    g_list0 = create_heterographs2("int64") + create_heterographs2("int32")
    save_graphs(path, g_list0)

    g_list, _ = load_graphs(path)
    assert g_list[0].idtype == F.int64
    assert len(g_list[0].canonical_etypes) == 3
    for i in range(len(g_list0)):
        for j, etypes in enumerate(g_list0[i].canonical_etypes):
            assert g_list[i].canonical_etypes[j] == etypes
    assert g_list[1].restrict_format() == 'any'
    assert g_list[2].restrict_format() == 'csr'
    assert g_list[3].idtype == F.int32
    assert np.allclose(F.asnumpy(g_list[2].nodes['user'].data['hh']),
                       np.ones((4, 5)))
    assert np.allclose(F.asnumpy(g_list[5].nodes['user'].data['hh']),
                       np.ones((4, 5)))
    edges = g_list[0]['follows'].edges()
    assert np.allclose(F.asnumpy(edges[0]), np.array([0, 1, 2]))
    assert np.allclose(F.asnumpy(edges[1]), np.array([1, 2, 3]))
    for i in range(len(g_list)):
        assert g_list[i].ntypes == g_list0[i].ntypes
        assert g_list[i].etypes == g_list0[i].etypes

    os.unlink(path)
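Heterographs round-trip through the same save_graphs/load_graphs pair, with node and edge types plus feature tensors preserved, as the test above checks. A minimal sketch, assuming a reasonably recent DGL version and an illustrative file name:

import torch
import dgl
from dgl.data.utils import save_graphs, load_graphs

hg = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 2]), torch.tensor([0, 1])),
})
hg.nodes['user'].data['h'] = torch.ones(3, 4)
save_graphs('hetero.bin', [hg])

loaded, _ = load_graphs('hetero.bin')
assert loaded[0].canonical_etypes == hg.canonical_etypes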
Example #21
def test_graph_serialize_with_labels(is_hetero):
    num_graphs = 100
    g_list = [generate_rand_graph(30, is_hetero) for _ in range(num_graphs)]
    labels = {"label": F.zeros((num_graphs, 1))}

    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    save_graphs(path, g_list, labels)

    idx_list = np.random.permutation(np.arange(num_graphs)).tolist()
    loadg_list, l_labels0 = load_graphs(path, idx_list)
    l_labels = load_labels(path)
    assert F.allclose(l_labels['label'], labels['label'])
    assert F.allclose(l_labels0['label'], labels['label'])

    idx = idx_list[0]
    load_g = loadg_list[0]

    assert F.allclose(load_g.nodes(), g_list[idx].nodes())

    load_edges = load_g.all_edges('uv', 'eid')
    g_edges = g_list[idx].all_edges('uv', 'eid')
    assert F.allclose(load_edges[0], g_edges[0])
    assert F.allclose(load_edges[1], g_edges[1])

    os.unlink(path)
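load_labels, used above alongside load_graphs, reads back only the label dictionary without materializing any graphs, which is convenient for inspecting dataset metadata. A minimal sketch with illustrative names:

import torch
import dgl
from dgl.data.utils import save_graphs, load_labels

g_list = [dgl.rand_graph(5, 10) for _ in range(3)]
save_graphs('labeled.bin', g_list, {'label': torch.zeros(3, 1)})

labels = load_labels('labeled.bin')  # no graph structures are deserialized
print(labels['label'].shape)         # torch.Size([3, 1])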
Example #22
 def load(self):
     graphs, _ = load_graphs(
         os.path.join(self.save_path, self.name + '_dgl_graph.bin'))
     self.g = graphs[0]
     for k in ('train_mask', 'val_mask', 'test_mask'):
         self.g.nodes['movie'].data[k] = self.g.nodes['movie'].data[k].bool()
Example #23
 def _process_and_save(  # type: ignore[override]
     self,
     instance_map: np.ndarray,
     features: torch.Tensor,
     annotation: Optional[np.ndarray] = None,
     output_name: str = None,
 ) -> dgl.DGLGraph:
     """Process and save in provided directory
     Args:
         output_name (str): Name of output file
         instance_map (np.ndarray): Instance map depicting tissue components
                                    (eg nuclei, tissue superpixels)
         features (torch.Tensor): Features of each node. Shape (nr_nodes, nr_features)
         annotation (Optional[np.ndarray], optional): Optional node-level annotation
                                                      to include. Defaults to None.
     Returns:
         dgl.DGLGraph: [description]
     """
     assert (
         self.save_path is not None
     ), "Can only save intermediate output if base_path was not None during construction"
     output_path = self.output_dir / f"{output_name}.bin"
     if output_path.exists():
         logging.info(
             f"Output of {output_name} already exists, using it instead of recomputing"
         )
         graphs, _ = load_graphs(str(output_path))
         assert len(graphs) == 1
         graph = graphs[0]
     else:
         graph = self._process(instance_map=instance_map,
                               features=features,
                               annotation=annotation)
         save_graphs(str(output_path), [graph])
     return graph
Example #24
 def load(self):
     graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
     graphs, _ = load_graphs(graph_path)
     self._graph = graphs[0]
     self._graph.ndata['train_mask'] = generate_mask_tensor(self._graph.ndata['train_mask'].numpy())
     self._graph.ndata['val_mask'] = generate_mask_tensor(self._graph.ndata['val_mask'].numpy())
     self._graph.ndata['test_mask'] = generate_mask_tensor(self._graph.ndata['test_mask'].numpy())
     self._print_info()
Example #25
def convert_holdout(holdout_path: str, output_path: str, batch_size: int,
                    token_to_id: Dict, type_to_id: Dict, label_to_id: Dict,
                    tokens_to_leaves: bool = False, is_split: bool = False,
                    max_token_len: int = -1, max_label_len: int = -1, wrap_tokens: bool = False,
                    wrap_labels: bool = False, delimiter: str = '|', shuffle: bool = True, n_jobs: int = -1) -> str:
    log_file = os.path.join('logs', f"convert_{datetime.now().strftime('%Y_%m_%d_%H:%M:%S')}.txt")

    projects_paths = [os.path.join(holdout_path, project, 'java') for project in os.listdir(holdout_path)]

    print("converting projects...")
    for project_path in tqdm(projects_paths):
        print(f"converting {project_path}")
        _convert_project_safe(
            project_path, log_file, token_to_id=token_to_id, type_to_id=type_to_id, label_to_id=label_to_id,
            token_to_leaves=tokens_to_leaves, is_split=is_split, max_token_len=max_token_len,
            max_label_len=max_label_len, wrap_tokens=wrap_tokens, wrap_labels=wrap_labels, delimiter=delimiter,
            n_jobs=n_jobs
        )

    graphs = []
    labels = []
    source_paths = []
    print("load graphs to memory...")
    for project_path in tqdm(projects_paths):
        graph_path = os.path.join(project_path, 'converted.dgl')
        labels_path = os.path.join(project_path, 'converted.pkl')
        if not os.path.exists(graph_path) or not os.path.exists(labels_path):
            with open(log_file, 'a') as file:
                file.write(f"can't load graphs for {project_path} project\n")
            continue
        cur_graphs, _ = load_graphs(graph_path)
        with open(labels_path, 'rb') as pkl_file:
            pkl_data = load(pkl_file)
            cur_labels, cur_source_paths = pkl_data['labels'], pkl_data['source_paths']
        graphs += cur_graphs
        labels += cur_labels.tolist()
        source_paths += cur_source_paths.tolist()
    assert len(graphs) == len(labels), "unequal lengths of graphs and labels"
    assert len(graphs) == len(source_paths), "unequal lengths of graphs and source paths"
    print(f"total number of graphs: {len(graphs)}")

    if shuffle:
        order = np.random.permutation(len(graphs))
        graphs = [graphs[i] for i in order]
        labels = [labels[i] for i in order]
        source_paths = [source_paths[i] for i in order]

    print("saving batches...")
    n_batches = len(graphs) // batch_size + (1 if len(graphs) % batch_size > 0 else 0)
    for batch_num in tqdm(range(n_batches)):
        current_slice = slice(batch_num * batch_size, min((batch_num + 1) * batch_size, len(graphs)))
        output_graph_path = os.path.join(output_path, f'batch_{batch_num}.dgl')
        output_labels_path = os.path.join(output_path, f'batch_{batch_num}.pkl')
        save_graphs(output_graph_path, graphs[current_slice])
        with open(output_labels_path, 'wb') as pkl_file:
            dump({
                'labels': np.array(labels[current_slice]), 'source_paths': source_paths[current_slice]
            }, pkl_file)
Example #26
    def test_hact_model_bracs_hact_5_classes_pna(self):
        """Test HACT bracs_hact_5_classes_pna model."""

        # 1. Load a cell graph
        cell_graph, _ = load_graphs(
            os.path.join(self.cg_graph_path, self.cg_graph_name))
        cell_graph = cell_graph[0]
        cell_graph = set_graph_on_cuda(cell_graph) if IS_CUDA else cell_graph
        cg_node_dim = cell_graph.ndata['feat'].shape[1]

        tissue_graph, _ = load_graphs(
            os.path.join(self.tg_graph_path, self.tg_graph_name))
        tissue_graph = tissue_graph[0]
        tissue_graph = set_graph_on_cuda(tissue_graph) if IS_CUDA else tissue_graph
        tg_node_dim = tissue_graph.ndata['feat'].shape[1]

        assignment_matrix = torch.randint(
            2, (tissue_graph.number_of_nodes(),
                cell_graph.number_of_nodes())).float()
        assignment_matrix = assignment_matrix.cuda() if IS_CUDA else assignment_matrix
        assignment_matrix = [assignment_matrix]  # i.e., batch size is 1

        # 2. load config and build model with pretrained weights
        config_fname = os.path.join(self.current_path, 'config',
                                    'bracs_hact_5_classes_pna.yml')
        with open(config_fname, 'r') as file:
            config = yaml.safe_load(file)

        model = HACTModel(
            cg_gnn_params=config['cg_gnn_params'],
            tg_gnn_params=config['tg_gnn_params'],
            classification_params=config['classification_params'],
            cg_node_dim=cg_node_dim,
            tg_node_dim=tg_node_dim,
            num_classes=5,
            pretrained=True).to(DEVICE)

        # 3. forward pass
        logits = model(cell_graph, tissue_graph, assignment_matrix)

        self.assertIsInstance(logits, torch.Tensor)
        self.assertEqual(logits.shape[0], 1)
        self.assertEqual(logits.shape[1], 5)
Example #27
    def pre_process(self):
        processed_dir = osp.join(self.root, 'processed')
        raw_dir = osp.join(self.root, 'raw')
        pre_processed_file_path = osp.join(processed_dir, 'dgl_data_processed')

        if os.path.exists(pre_processed_file_path):
            self.graphs, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']

        else:
            ### download
            url = self.meta_info[self.name]["url"]
            if decide_download(url):
                path = download_url(url, self.original_root)
                extract_zip(path, self.original_root)
                os.unlink(path)
                # delete the folder if it already exists
                try:
                    shutil.rmtree(self.root)
                except OSError:
                    pass
                shutil.move(osp.join(self.original_root, self.download_name),
                            self.root)
            else:
                print("Stop download.")
                exit(-1)

            ### preprocess
            add_inverse_edge = self.meta_info[
                self.name]["add_inverse_edge"] == "True"
            graphs = read_csv_graph_dgl(raw_dir,
                                        add_inverse_edge=add_inverse_edge)
            labels = torch.tensor(
                pd.read_csv(osp.join(raw_dir, "graph-label.csv.gz"),
                            compression="gzip",
                            header=None).values)

            print('Saving...')
            save_graphs(pre_processed_file_path,
                        graphs,
                        labels={'labels': labels})

            ### load preprocessed files
            self.graphs, label_dict = load_graphs(pre_processed_file_path)
            self.labels = label_dict['labels']
Example #28
 def get_hetgnn_graph(self, length, walks, restart_prob):
     fname = './{}_het.bin'.format(self.n_dataset)
     if os.path.exists(fname):
         g, _ = load_graphs(fname)
         return g[0]
     else:
         g = self.build_hetgnn_graph(length, walks, restart_prob)
         save_graphs(fname, g)
         return g
Example #29
 def _load(self):
     if self.load_mode == 'read':
         self.graphs, tour_dict = load_graphs(self.file_name)
         self.tours = tour_dict['tsp_tours']
     else:
         print('Start generating dataset...')
         self.graphs, tour_dict = self.generate_data(
             self.num_samples, self.num_nodes, self.node_dim,
             self.file_name, self.seed)
         self.tours = tour_dict['tsp_tours']
Example #30
    def load(self):
        graph_path = os.path.join(self.save_path, 'tu_{}.bin'.format(self.name))
        info_path = os.path.join(self.save_path, 'tu_{}.pkl'.format(self.name))
        graphs, label_dict = load_graphs(str(graph_path))
        info_dict = load_info(str(info_path))

        self.graph_lists = graphs
        self.graph_labels = label_dict['labels']
        self.max_num_node = info_dict['max_num_node']
        self.num_labels = info_dict['num_labels']