Exemplo n.º 1
0
    def test_graphlrp(self):
        """
        Test Graph LRP.

        Loads the first graph stored at ``self.graph_path/self.graph_name``,
        builds a 3-class CellGraphModel from the
        ``cg_bracs_cggnn_3_classes_gin.yml`` config (no ``pretrained`` flag is
        passed) and checks that ``GraphLRPExplainer.process`` returns numpy
        arrays with one importance score per node.
        """
        # Step 1: load the test graph, moving it to the GPU when IS_CUDA is set.
        graph_fname = os.path.join(self.graph_path, self.graph_name)
        graphs, _ = load_graphs(graph_fname)
        graph = graphs[0]
        if IS_CUDA:
            graph = set_graph_on_cuda(graph)
        feat_dim = graph.ndata['feat'].shape[1]

        # Step 2: build the model from its YAML config.
        config_path = os.path.join(
            self.current_path, 'config', 'cg_bracs_cggnn_3_classes_gin.yml')
        with open(config_path, 'r') as config_file:
            cfg = yaml.safe_load(config_file)
        model = CellGraphModel(
            gnn_params=cfg['gnn_params'],
            classification_params=cfg['classification_params'],
            node_dim=feat_dim,
            num_classes=3)

        # Step 3: run the explainer.
        scores, logits = GraphLRPExplainer(model=model).process(graph)

        # Step 4: both outputs are numpy arrays, with one score per node.
        self.assertIsInstance(scores, np.ndarray)
        self.assertIsInstance(logits, np.ndarray)
        self.assertEqual(graph.number_of_nodes(), scores.shape[0])
Exemplo n.º 2
0
    def test_pretrained_bracs_tggnn_3_classes_gin(self):
        """Test bracs_tggnn_3_classes_gin model.

        Builds a TissueGraphModel with ``pretrained=True`` from the
        ``tg_bracs_tggnn_3_classes_gin.yml`` config and checks that a forward
        pass on a single tissue graph yields a logits tensor of shape (1, 3).
        """
        # Step 1: load the tissue graph (on the GPU when IS_CUDA is set).
        graphs, _ = load_graphs(
            os.path.join(self.graph_path, self.graph_name))
        tissue_graph = graphs[0]
        if IS_CUDA:
            tissue_graph = set_graph_on_cuda(tissue_graph)
        feat_dim = tissue_graph.ndata['feat'].shape[1]

        # Step 2: build the model with pre-trained weights.
        config_path = os.path.join(
            self.current_path, 'config', 'tg_bracs_tggnn_3_classes_gin.yml')
        with open(config_path, 'r') as config_file:
            cfg = yaml.safe_load(config_file)
        model = TissueGraphModel(
            gnn_params=cfg['gnn_params'],
            classification_params=cfg['classification_params'],
            node_dim=feat_dim,
            num_classes=3,
            pretrained=True,
        ).to(DEVICE)

        # Step 3: a forward pass on one graph gives one row of 3 class logits.
        out = model(tissue_graph)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape[0], 1)
        self.assertEqual(out.shape[1], 3)
Exemplo n.º 3
0
    def test_tissue_graph_model_with_batch(self):
        """Test tissue graph model with batch.

        Batches the same graph twice with ``dgl.batch`` and checks that a
        TissueGraphModel built from ``tg_model.yml`` returns one row of
        3 logits per graph in the batch.
        """
        # Step 1: load the test graph (on the GPU when IS_CUDA is set).
        graphs, _ = load_graphs(
            os.path.join(self.graph_path, self.graph_name))
        graph = graphs[0]
        if IS_CUDA:
            graph = set_graph_on_cuda(graph)
        feat_dim = graph.ndata['feat'].shape[1]

        # Step 2: build the model from its YAML config.
        config_path = os.path.join(
            self.current_path, 'config', 'tg_model.yml')
        with open(config_path, 'r') as config_file:
            cfg = yaml.safe_load(config_file)
        model = TissueGraphModel(
            gnn_params=cfg['gnn_params'],
            classification_params=cfg['classification_params'],
            node_dim=feat_dim,
            num_classes=3,
        ).to(DEVICE)

        # Step 3: forward a batch of two copies -> logits of shape (2, 3).
        batch = dgl.batch([graph, graph])
        out = model(batch)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape[0], 2)
        self.assertEqual(out.shape[1], 3)
    def test_hact_model_bracs_hact_5_classes_pna(self):
        """Test HACT bracs_hact_5_classes_pna model.

        Loads one cell graph and one tissue graph, pairs them with a random
        0/1 assignment matrix, builds a pretrained 5-class HACTModel from
        ``bracs_hact_5_classes_pna.yml`` and checks the (1, 5) logits shape.
        """
        # Step 1a: cell graph (on the GPU when IS_CUDA is set).
        cell_graphs, _ = load_graphs(
            os.path.join(self.cg_graph_path, self.cg_graph_name))
        cell_graph = cell_graphs[0]
        if IS_CUDA:
            cell_graph = set_graph_on_cuda(cell_graph)
        cg_node_dim = cell_graph.ndata['feat'].shape[1]

        # Step 1b: tissue graph (on the GPU when IS_CUDA is set).
        tissue_graphs, _ = load_graphs(
            os.path.join(self.tg_graph_path, self.tg_graph_name))
        tissue_graph = tissue_graphs[0]
        if IS_CUDA:
            tissue_graph = set_graph_on_cuda(tissue_graph)
        tg_node_dim = tissue_graph.ndata['feat'].shape[1]

        # Step 1c: random 0/1 assignment matrix of shape
        # (#tissue-graph nodes, #cell-graph nodes). The model receives a list
        # holding a single matrix, i.e. a batch of size 1.
        n_tissue_nodes = tissue_graph.number_of_nodes()
        n_cell_nodes = cell_graph.number_of_nodes()
        assignment = torch.randint(2, (n_tissue_nodes, n_cell_nodes)).float()
        if IS_CUDA:
            assignment = assignment.cuda()
        assignment_matrices = [assignment]

        # Step 2: build the model with pretrained weights.
        config_path = os.path.join(
            self.current_path, 'config', 'bracs_hact_5_classes_pna.yml')
        with open(config_path, 'r') as config_file:
            cfg = yaml.safe_load(config_file)
        model = HACTModel(
            cg_gnn_params=cfg['cg_gnn_params'],
            tg_gnn_params=cfg['tg_gnn_params'],
            classification_params=cfg['classification_params'],
            cg_node_dim=cg_node_dim,
            tg_node_dim=tg_node_dim,
            num_classes=5,
            pretrained=True,
        ).to(DEVICE)

        # Step 3: forward pass -> one row of 5 class logits.
        out = model(cell_graph, tissue_graph, assignment_matrices)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape[0], 1)
        self.assertEqual(out.shape[1], 5)
Exemplo n.º 5
0
def explain_cell_graphs(cell_graph_path, image_path):
    """
    Generate an explanation for all the cell graphs in cell path dir.

    For every ``*.bin`` cell graph in ``cell_graph_path``, the image in
    ``image_path`` with the same base name (``<name>.png``) is loaded. Node
    importance scores are computed with GraphGradCAMExplainer on a pretrained
    3-class CellGraphModel, and the graph is overlaid on the image with nodes
    coloured by importance. The result is saved to
    ``output/explainer/<name>.png``.

    Args:
        cell_graph_path (str): Directory containing the cell graphs (``*.bin``).
        image_path (str): Directory containing the matching images (``*.png``).

    Raises:
        FileNotFoundError: If a cell graph has no image with the same base name.
    """

    # 1. get cell graph & image paths. Images are indexed by their exact base
    # name (without extension) so each graph is paired with exactly one image.
    # A substring test could pick the wrong file (e.g. "a.bin" vs "ba.png").
    cg_fnames = glob(os.path.join(cell_graph_path, '*.bin'))
    images_by_stem = {
        os.path.splitext(os.path.basename(fname))[0]: fname
        for fname in glob(os.path.join(image_path, '*.png'))
    }

    # 2. create model
    config_fname = os.path.join(os.path.dirname(__file__), 'config',
                                'cg_bracs_cggnn_3_classes_gin.yml')
    with open(config_fname, 'r') as file:
        # safe_load: yaml.load() without an explicit Loader is unsafe and
        # raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(file)

    model = CellGraphModel(
        gnn_params=config['gnn_params'],
        classification_params=config['classification_params'],
        node_dim=NODE_DIM,
        num_classes=3,
        pretrained=True).to(DEVICE)

    # 3. define the explainer
    explainer = GraphGradCAMExplainer(model=model)

    # 4. define graph visualizer
    visualizer = OverlayGraphVisualization(
        instance_visualizer=InstanceImageVisualization(),
        colormap='jet',
        node_style="fill")

    # The output directory must exist before canvas.save() is called.
    output_dir = os.path.join('output', 'explainer')
    os.makedirs(output_dir, exist_ok=True)

    # 5. process all the images
    for cg_path in tqdm(cg_fnames):

        # a. find the corresponding image first, so a missing pair fails
        # before the graph is loaded.
        _, graph_name = os.path.split(cg_path)
        stem = os.path.splitext(graph_name)[0]
        image_fname = images_by_stem.get(stem)
        if image_fname is None:
            raise FileNotFoundError(
                "No image '{}.png' found in '{}' for cell graph '{}'".format(
                    stem, image_path, cg_path))
        _, image_name = os.path.split(image_fname)

        # b. load the graph (on the GPU when IS_CUDA is set) and the image
        graph, _ = load_graphs(cg_path)
        graph = graph[0]
        graph = set_graph_on_cuda(graph) if IS_CUDA else graph
        image = np.array(Image.open(image_fname))

        # c. run explainer
        importance_scores, _ = explainer.process(graph)

        # d. visualize and save the output
        node_attrs = {"color": importance_scores}
        canvas = visualizer.process(image, graph, node_attributes=node_attrs)
        canvas.save(os.path.join(output_dir, image_name))