Example #1
0
    def __init__(self) -> None:
        """
        Build the pre-processing, main and extraction graph blocks.

        - ``pre_process_block`` applies ``PreProcessCNN`` independently to every node's ``'desktop_img'``.
        - ``main_block`` updates edges and nodes (``rho_ev=AvgAggregation``).
        - ``extraction_block`` additionally updates the global state with
          ``NodeAggregationGlobalStateUpdate`` (``rho_ev`` and ``rho_vu`` are ``AvgAggregation``).
        """
        # `any` is the builtin function, not a type; `typing.Any` is the correct annotation.
        from typing import Any

        super().__init__()

        self.pre_process_cnn = PreProcessCNN()

        def first_node_update_fn(node: Dict[str, Any]) -> Attribute:
            """Run the pre-processing CNN on a single node's desktop screenshot."""
            desktop_img: torch.Tensor = node['desktop_img']

            # unsqueeze(0) prepends a size-1 dimension (presumably the batch axis the CNN expects).
            hidden_layer_out = self.pre_process_cnn(desktop_img.unsqueeze(0))

            return hidden_layer_out

        self.pre_process_block = GNBlock(
            phi_v=IndependentNodeUpdate(first_node_update_fn))

        self.main_block = GNBlock(phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
                                  phi_v=MainNodeUpdate(),
                                  rho_ev=AvgAggregation())

        self.extraction_block = GNBlock(
            phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
            phi_v=ExtractionNodeUpdate(),
            phi_u=NodeAggregationGlobalStateUpdate(),
            rho_ev=AvgAggregation(),
            rho_vu=AvgAggregation())
Example #2
0
    def __init__(self, num_core_blocks: int, drop_p: float):
        """
        Build a screenshot feature extractor followed by extraction, encoder, core and decoder graph blocks.

        :param num_core_blocks: Stored as ``self.num_core_blocks``. This constructor builds a single
                                ``core`` block, so how the count is used is decided elsewhere.
        :param drop_p: Dropout probability. It is passed to the screenshot feature extractor and to the
                       core edge, node and global-state updates.
        """
        # `any` is the builtin function, not a type; `typing.Any` is the correct annotation.
        from typing import Any

        super().__init__()

        self.num_core_blocks = num_core_blocks
        self.drop_p = drop_p

        self.screenshot_feature_extractor = ScreenshotsFeatureExtractor(drop_p)

        def node_extractor_fn(node: Dict[str, Any]) -> Tensor:
            """Return the flattened concatenation of a node's desktop and mobile feature vectors."""
            desktop_img: Tensor = node['desktop_img']
            mobile_img: Tensor = node['mobile_img']

            # Desktop and mobile feature vectors; unsqueeze(0) prepends a size-1 (presumably batch) axis.
            x1, x2 = self.screenshot_feature_extractor(
                desktop_img.unsqueeze(0), mobile_img.unsqueeze(0))

            # Join the two feature tensors along dim 1, then flatten to a 1-D vector.
            x = torch.cat((x1, x2), dim=1).view(-1)

            return x

        self.extraction_block = GNBlock(
            phi_v=IndependentNodeUpdate(node_extractor_fn))

        self.encoder = GNBlock(phi_e=EncoderEdgeUpdate(),
                               phi_u=EncoderGlobalStateUpdate(),
                               rho_eu=AvgAggregation())
        self.core = GNBlock(phi_e=CoreEdgeUpdate(self.drop_p),
                            phi_v=CoreNodeUpdate(self.drop_p),
                            phi_u=CoreGlobalStateUpdate(self.drop_p),
                            rho_ev=AvgAggregation(),
                            rho_vu=AvgAggregation(),
                            rho_eu=AvgAggregation())
        self.decoder = GNBlock(phi_u=DecoderGlobalStateUpdate())
Example #3
0
    def __init__(self,
                 drop_p: float,
                 num_core_blocks: int,
                 edge_mode: str,
                 shared_weights: bool = False):
        """
        Deep graph network for domain rank estimation.

        Architecture: an encoder block, ``num_core_blocks`` stacked core blocks, and a decoder block that
        maps the global state from a vector to a scalar.

        :param drop_p: Dropout probability. It is passed to every core edge, node and global-state update.
        :param num_core_blocks: Number of stacked core blocks, >= 0.
        :param edge_mode: One of 'default', 'bi_directional', 'no_edges' or 'all_edges'. It selects whether
                          to keep the graph edges, make them bi-directional, remove them altogether, or
                          (presumably) connect all nodes. In any case, the existence of reflexive edges
                          is ensured.
        :param shared_weights: If True, every core block is the same ``GNBlock`` instance, so all core
                               layers share one set of weights. If False (default), each core block is
                               constructed separately.
        :raises AssertionError: If ``edge_mode`` is not a known mode or ``num_core_blocks`` is negative.
        """

        super().__init__()

        self.drop_p = drop_p

        self.edge_fns = {
            'default': GNDeep.default,
            'bi_directional': GNDeep.bi_directional,
            'no_edges': GNDeep.no_edges,
            'all_edges': GNDeep.all_edges
        }

        # Validate all arguments before any sub-module is built. Explicit raises (rather than `assert`)
        # keep the checks active under `python -O`. AssertionError is kept for backward compatibility.
        if edge_mode not in self.edge_fns:
            raise AssertionError("Invalid edge mode; not in [default, bi_directional, no_edges, all_edges]")
        if num_core_blocks < 0:
            raise AssertionError("num_core_blocks must be >= 0")
        self.edge_mode = edge_mode

        self.enc = GNBlock(phi_e=EncoderEdgeUpdate(),
                           phi_u=EncoderGlobalStateUpdate(),
                           rho_eu=AvgAggregation())

        core_blocks = []
        for i in range(num_core_blocks):
            if shared_weights and i > 0:
                # Reuse the first block so that every layer shares the same parameters.
                block = core_blocks[0]
            else:
                block = GNBlock(phi_e=CoreEdgeUpdate(self.drop_p),
                                phi_v=CoreNodeUpdate(self.drop_p),
                                phi_u=CoreGlobalStateUpdate(self.drop_p),
                                rho_ev=AvgAggregation(),
                                rho_vu=AvgAggregation(),
                                rho_eu=AvgAggregation())
            core_blocks.append(block)
        self.core_blocks = ListModule(*core_blocks)

        # Maps the global state from a vector to a scalar.
        self.dec = GNBlock(phi_u=DecoderGlobalStateUpdate())
    def test_basic(self):
        """
        Basic test w/o PyTorch: all attributes are scalars, and the input graph's edges have no attributes.
        Feeds a graph through a basic graph block twice and compares to the target values after both passes.
        """

        # create data structure: nodes 1, 10, 20 with edges 0->1, 0->2, 1->2
        v_0, v_1, v_2 = Node(Attribute(1)), Node(Attribute(10)), Node(
            Attribute(20))
        vs = [v_0, v_1, v_2]  # nodes
        es = [Edge(v_0, v_1), Edge(v_0, v_2), Edge(v_1, v_2)]
        g_0 = Graph(nodes=vs, edges=es, attr=Attribute(0))

        # create block w/ functions
        block = GNBlock(phi_e=SenderIdentityEdgeUpdate(),
                        phi_v=EdgeNodeSumNodeUpdate(),
                        phi_u=MixedGlobalStateUpdate(),
                        rho_ev=ScalarSumAggregation(),
                        rho_vu=ScalarSumAggregation(),
                        rho_eu=ScalarSumAggregation())

        g_1 = block(g_0)

        # Expected after pass 1: each edge carries its sender's attribute, and each node adds the sum
        # of its incoming edge attributes to its previous value.
        v_0, v_1, v_2 = Node(Attribute(1)), Node(Attribute(10 + 1)), Node(
            Attribute(20 + 11))
        vs = [v_0, v_1, v_2]  # nodes
        es = [
            Edge(v_0, v_1, Attribute(1)),
            Edge(v_0, v_2, Attribute(1)),
            Edge(v_1, v_2, Attribute(10))
        ]
        # Global value as produced by MixedGlobalStateUpdate (formula defined elsewhere).
        g_1_target = Graph(nodes=vs, edges=es, attr=Attribute(35))

        # assertEqual performs the same `==` check but reports both graphs on failure.
        self.assertEqual(g_1, g_1_target)

        g_2 = block(g_1)

        # Expected after pass 2: same rules applied to g_1's values.
        v_0, v_1, v_2 = Node(Attribute(1)), Node(Attribute(10 + 2)), Node(
            Attribute(20 + 11 + 12))
        vs = [v_0, v_1, v_2]  # nodes
        es = [
            Edge(v_0, v_1, Attribute(1)),
            Edge(v_0, v_2, Attribute(1)),
            Edge(v_1, v_2, Attribute(11))
        ]
        g_2_target = Graph(nodes=vs,
                           edges=es,
                           attr=Attribute(1 + 12 + 43 - 35))

        self.assertEqual(g_2, g_2_target)
Example #5
0
def get_extraction_block(model: torch.nn.Module):
    """
    Create a new extraction block which operates on a raw graph as provided by the dataset.

    The block applies ``model`` independently to each node.

    :param model: Module called as ``model(desktop_batch, mobile_batch)``. It must return a pair of
                  feature tensors that can be concatenated along dim 1.
    :return: A ``GNBlock`` whose node update yields, for each node, the flattened concatenation of the
             desktop and mobile feature vectors.
    """
    # `any` is the builtin function, not a type; `typing.Any` is the correct annotation.
    from typing import Any

    def node_extractor_fn(node: Dict[str, Any]) -> Tensor:
        """Return the flattened concatenation of a node's desktop and mobile feature vectors."""
        desktop_img: Tensor = node['desktop_img']
        mobile_img: Tensor = node['mobile_img']

        # Desktop and mobile feature vectors; unsqueeze(0) prepends a size-1 (presumably batch) axis.
        x1, x2 = model(desktop_img.unsqueeze(0), mobile_img.unsqueeze(0))

        # Join along dim 1, then flatten to a 1-D vector.
        x = torch.cat((x1, x2), dim=1).view(-1)

        return x

    return GNBlock(phi_v=IndependentNodeUpdate(node_extractor_fn))
    def __init__(self):
        """
        Build a single graph block that applies a desktop screenshot extractor to every node.

        Each node's ``'desktop_img'`` is passed independently through
        ``DesktopScreenshotFeatureExtractor``. The global state is updated by
        ``NodeAggregationGlobalStateUpdate`` with ``rho_vu=AvgAggregation``. Edges use
        ``IdentityEdgeUpdate``, with ``ConstantAggregation`` for ``rho_ev`` and ``rho_eu``.
        """
        # `any` is the builtin function, not a type; `typing.Any` is the correct annotation.
        from typing import Any

        super().__init__()

        self.desktop_screenshot_extractor = DesktopScreenshotFeatureExtractor()

        def node_update_fn(node: Dict[str, Any]) -> Attribute:
            """Run the desktop screenshot extractor on a single node's desktop image."""
            desktop_img: torch.Tensor = node['desktop_img']

            # unsqueeze(0) prepends a size-1 (presumably batch) axis.
            rank_scalar = self.desktop_screenshot_extractor(
                desktop_img.unsqueeze(0))

            return rank_scalar

        self.graph_block = GNBlock(phi_e=IdentityEdgeUpdate(),
                                   phi_v=IndependentNodeUpdate(node_update_fn),
                                   phi_u=NodeAggregationGlobalStateUpdate(),
                                   rho_ev=ConstantAggregation(),
                                   rho_vu=AvgAggregation(),
                                   rho_eu=ConstantAggregation())
Example #7
0
 def __init__(self):
     """
     Build a per-node linear readout (64 -> 1) and a decoder block for the global state.

     The decoder aggregates nodes with ``MaxAggregation`` and updates the global state via
     ``NodeAggregationGlobalStateUpdate``.
     """
     super().__init__()
     # Linear layer applied independently to each node's features.
     self.dense = nn.Linear(64, 1)
     per_node_update = IndependentNodeUpdate(self.dense)
     self.core = GNBlock(phi_v=per_node_update)
     # Construct the aggregation and the global update in the original order.
     node_to_global = MaxAggregation()
     global_update = NodeAggregationGlobalStateUpdate()
     self.dec = GNBlock(rho_vu=node_to_global, phi_u=global_update)