コード例 #1
0
    def __init__(self) -> None:
        """
        Build the graph-network blocks of this model.

        - ``pre_process_block``: node update that runs ``PreProcessCNN`` on each node's
          ``'desktop_img'`` entry.
        - ``main_block``: ``SenderGlobalAveragePoolingEdgeUpdate`` edges, ``MainNodeUpdate``
          nodes, ``AvgAggregation`` as ``rho_ev``.
        - ``extraction_block``: like ``main_block`` but with ``ExtractionNodeUpdate`` and an
          additional ``NodeAggregationGlobalStateUpdate`` global update (``AvgAggregation``
          for both ``rho_ev`` and ``rho_vu``).
        """
        super().__init__()

        # ``typing.Any`` is needed for the annotation below: the builtin ``any`` is a
        # function, not a type, and static type checkers reject it.
        from typing import Any

        self.pre_process_cnn = PreProcessCNN()

        def first_node_update_fn(node: Dict[str, Any]) -> Attribute:
            """
            Run ``self.pre_process_cnn`` on one node's ``'desktop_img'``.

            A leading dimension of size 1 is added first (presumably a batch axis). The CNN
            output is returned unchanged, so that size-1 dimension is kept.
            NOTE(review): confirm downstream updates expect the leading dimension.
            """
            desktop_img: torch.Tensor = node['desktop_img']

            hidden_layer_out = self.pre_process_cnn(desktop_img.unsqueeze(0))

            return hidden_layer_out

        self.pre_process_block = GNBlock(
            phi_v=IndependentNodeUpdate(first_node_update_fn))

        self.main_block = GNBlock(phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
                                  phi_v=MainNodeUpdate(),
                                  rho_ev=AvgAggregation())

        self.extraction_block = GNBlock(
            phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
            phi_v=ExtractionNodeUpdate(),
            phi_u=NodeAggregationGlobalStateUpdate(),
            rho_ev=AvgAggregation(),
            rho_vu=AvgAggregation())
コード例 #2
0
    def __init__(self, num_core_blocks: int, drop_p: float):
        """
        Build the screenshot feature extractor and the extraction/encoder/core/decoder blocks.

        :param num_core_blocks: Stored as ``self.num_core_blocks``; not used inside this
            constructor (presumably how many times ``core`` is applied — verify in ``forward``).
        :param drop_p: Stored as ``self.drop_p`` and passed to ``ScreenshotsFeatureExtractor``
            and to the core block's edge, node and global-state updates (presumably a dropout
            probability).
        """
        super().__init__()

        # ``typing.Any`` is needed for the annotation below: the builtin ``any`` is a
        # function, not a type, and static type checkers reject it.
        from typing import Any

        self.num_core_blocks = num_core_blocks
        self.drop_p = drop_p

        self.screenshot_feature_extractor = ScreenshotsFeatureExtractor(drop_p)

        def node_extractor_fn(node: Dict[str, Any]) -> Tensor:
            """Map one node's desktop and mobile images to a single flattened feature tensor."""
            desktop_img: Tensor = node['desktop_img']
            mobile_img: Tensor = node['mobile_img']

            # Desktop and mobile feature vectors; each image gets a leading size-1 axis first
            # (presumably a batch axis).
            x1, x2 = self.screenshot_feature_extractor(
                desktop_img.unsqueeze(0), mobile_img.unsqueeze(0))

            # Concatenate along dim 1, then flatten to a 1-D tensor.
            x = torch.cat((x1, x2), dim=1).view(-1)

            return x

        self.extraction_block = GNBlock(
            phi_v=IndependentNodeUpdate(node_extractor_fn))

        self.encoder = GNBlock(phi_e=EncoderEdgeUpdate(),
                               phi_u=EncoderGlobalStateUpdate(),
                               rho_eu=AvgAggregation())
        self.core = GNBlock(phi_e=CoreEdgeUpdate(self.drop_p),
                            phi_v=CoreNodeUpdate(self.drop_p),
                            phi_u=CoreGlobalStateUpdate(self.drop_p),
                            rho_ev=AvgAggregation(),
                            rho_vu=AvgAggregation(),
                            rho_eu=AvgAggregation())
        self.decoder = GNBlock(phi_u=DecoderGlobalStateUpdate())
コード例 #3
0
def get_extraction_block(model: torch.nn.Module):
    """
    Create a new extraction block which operates on a raw graph as provided by the dataset.

    The block applies ``model`` to each node: the node's ``'desktop_img'`` and ``'mobile_img'``
    tensors each get a leading dimension of size 1 (presumably a batch axis) and are passed
    to ``model``. ``model`` must return two tensors. These are concatenated along dim 1 and
    flattened into a single 1-D node feature tensor.

    :param model: Module called as ``model(desktop_img, mobile_img)``. It must return a pair of
        tensors that can be concatenated along dim 1.
    :return: A ``GNBlock`` whose node update is the per-node extractor described above.
    """
    # ``typing.Any`` is needed for the annotation below: the builtin ``any`` is a function,
    # not a type, and static type checkers reject it.
    from typing import Any

    def node_extractor_fn(node: Dict[str, Any]) -> Tensor:
        desktop_img: Tensor = node['desktop_img']
        mobile_img: Tensor = node['mobile_img']

        # Desktop and mobile feature vectors.
        x1, x2 = model(desktop_img.unsqueeze(0), mobile_img.unsqueeze(0))

        # Concatenate along dim 1, then flatten to a 1-D tensor.
        x = torch.cat((x1, x2), dim=1).view(-1)

        return x

    return GNBlock(phi_v=IndependentNodeUpdate(node_extractor_fn))
コード例 #4
0
    def __init__(self):
        """
        Build a single graph block, ``graph_block``, around a desktop-screenshot extractor.

        - Edges: ``IdentityEdgeUpdate``, with ``ConstantAggregation`` for ``rho_ev`` and ``rho_eu``.
        - Nodes: ``DesktopScreenshotFeatureExtractor`` applied to each node's ``'desktop_img'``.
        - Global state: ``NodeAggregationGlobalStateUpdate``, with ``AvgAggregation`` for ``rho_vu``.
        """
        super().__init__()

        # ``typing.Any`` is needed for the annotation below: the builtin ``any`` is a
        # function, not a type, and static type checkers reject it.
        from typing import Any

        self.desktop_screenshot_extractor = DesktopScreenshotFeatureExtractor()

        def node_update_fn(node: Dict[str, Any]) -> Attribute:
            """
            Run the desktop screenshot extractor on one node's ``'desktop_img'``.

            A leading size-1 axis (presumably a batch axis) is added before the call.
            NOTE(review): the name ``rank_scalar`` suggests a single score per node — confirm
            against ``DesktopScreenshotFeatureExtractor``'s output shape.
            """
            desktop_img: torch.Tensor = node['desktop_img']

            rank_scalar = self.desktop_screenshot_extractor(
                desktop_img.unsqueeze(0))

            return rank_scalar

        self.graph_block = GNBlock(phi_e=IdentityEdgeUpdate(),
                                   phi_v=IndependentNodeUpdate(node_update_fn),
                                   phi_u=NodeAggregationGlobalStateUpdate(),
                                   rho_ev=ConstantAggregation(),
                                   rho_vu=AvgAggregation(),
                                   rho_eu=ConstantAggregation())
コード例 #5
0
    def __init__(self, e_config: OptLinearConfig = None, v_config: OptLinearConfig = None,
                 u_config: OptLinearConfig = None) -> None:
        """
        Wrap one GNBlock that optionally applies a linear layer to edges, nodes and global state.

        Any of the three configurations may be None, in which case that part is left unchanged
        (identity update). Otherwise the configuration is unpacked into ``nn.Linear(*config)``.
        All three aggregations are ``ConstantAggregation``.

        :param e_config: Configuration of the linear layer applied to edges, or None for identity
        :param v_config: Configuration of the linear layer applied to nodes, or None for identity
        :param u_config: Configuration of the linear layer applied to the global state, or None
            for identity
        """
        super().__init__()

        def pick_update(config, identity_cls, independent_cls):
            # None -> pass-through update; otherwise a linear layer built from the config.
            if config is None:
                return identity_cls()
            return independent_cls(nn.Linear(*config))

        # Arguments are evaluated left to right: edge, node, global update, then aggregations.
        self.block = GNBlock(
            pick_update(e_config, IdentityEdgeUpdate, IndependentEdgeUpdate),
            pick_update(v_config, IdentityNodeUpdate, IndependentNodeUpdate),
            pick_update(u_config, IdentityGlobalStateUpdate, IndependentGlobalStateUpdate),
            rho_ev=ConstantAggregation(),
            rho_vu=ConstantAggregation(),
            rho_eu=ConstantAggregation())
コード例 #6
0
 def __init__(self):
     """
     Build a linear node block followed by a global-state decoder block.

     ``core`` wraps ``dense`` (``nn.Linear(64, 1)``) in an ``IndependentNodeUpdate``.
     ``dec`` uses ``NodeAggregationGlobalStateUpdate`` with ``MaxAggregation`` as ``rho_vu``.
     NOTE(review): node features are presumably 64-dimensional — confirm against the input graph.
     """
     super().__init__()

     # 64 input features -> 1 output feature.
     self.dense = nn.Linear(64, 1)
     linear_node_update = IndependentNodeUpdate(self.dense)
     self.core = GNBlock(phi_v=linear_node_update)

     max_aggregation = MaxAggregation()
     global_update = NodeAggregationGlobalStateUpdate()
     self.dec = GNBlock(rho_vu=max_aggregation, phi_u=global_update)