def __init__(self) -> None:
    """
    Construct the three graph blocks of this network.

    - ``pre_process_block`` applies ``self.pre_process_cnn`` independently to
      every node's ``'desktop_img'``.
    - ``main_block`` uses a sender global-average-pooling edge update, the
      main node update and average aggregation from edges to nodes.
    - ``extraction_block`` additionally aggregates node values into the
      global state (average aggregation for both edges->nodes and
      nodes->global).

    How these blocks are chained is decided outside this method.
    """
    super().__init__()
    # Registered first so the CNN's parameters come first in parameters().
    self.pre_process_cnn = PreProcessCNN()

    def embed_desktop_screenshot(node: Dict[str, any]) -> Attribute:
        # The CNN is looked up on ``self`` at call time (not captured here),
        # so the closure always uses the currently registered submodule.
        # unsqueeze(0) adds a leading dim of size 1 — presumably the batch
        # dimension the CNN expects; TODO confirm.
        batched_img = node['desktop_img'].unsqueeze(0)
        return self.pre_process_cnn(batched_img)

    self.pre_process_block = GNBlock(
        phi_v=IndependentNodeUpdate(embed_desktop_screenshot))

    self.main_block = GNBlock(
        phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
        phi_v=MainNodeUpdate(),
        rho_ev=AvgAggregation(),
    )

    self.extraction_block = GNBlock(
        phi_e=SenderGlobalAveragePoolingEdgeUpdate(),
        phi_v=ExtractionNodeUpdate(),
        phi_u=NodeAggregationGlobalStateUpdate(),
        rho_ev=AvgAggregation(),
        rho_vu=AvgAggregation(),
    )
def __init__(self, num_core_blocks: int, drop_p: float):
    """
    Construct the screenshot-feature graph network.

    :param num_core_blocks: Stored as ``self.num_core_blocks``. Only a single
        core GNBlock is built here; how often it is applied is decided
        outside this method.
    :param drop_p: Dropout probability. Passed to the screenshot feature
        extractor and to the core block's edge, node and global updates.
    """
    super().__init__()
    self.drop_p = drop_p
    self.num_core_blocks = num_core_blocks

    self.screenshot_feature_extractor = ScreenshotsFeatureExtractor(drop_p)

    def embed_screenshots(node: Dict[str, any]) -> Tensor:
        # Each image gets a leading dim of size 1 (presumably the batch dim
        # the extractor expects — TODO confirm). The extractor returns one
        # feature tensor for the desktop image and one for the mobile image.
        desktop_feats, mobile_feats = self.screenshot_feature_extractor(
            node['desktop_img'].unsqueeze(0),
            node['mobile_img'].unsqueeze(0),
        )
        # Join along dim 1, then flatten into a single 1-D node vector.
        joined = torch.cat((desktop_feats, mobile_feats), dim=1)
        return joined.view(-1)

    self.extraction_block = GNBlock(
        phi_v=IndependentNodeUpdate(embed_screenshots))

    self.encoder = GNBlock(
        phi_e=EncoderEdgeUpdate(),
        phi_u=EncoderGlobalStateUpdate(),
        rho_eu=AvgAggregation(),
    )

    p = self.drop_p
    self.core = GNBlock(
        phi_e=CoreEdgeUpdate(p),
        phi_v=CoreNodeUpdate(p),
        phi_u=CoreGlobalStateUpdate(p),
        rho_ev=AvgAggregation(),
        rho_vu=AvgAggregation(),
        rho_eu=AvgAggregation(),
    )

    self.decoder = GNBlock(phi_u=DecoderGlobalStateUpdate())
def __init__(self, drop_p: float, num_core_blocks: int, edge_mode: str,
             shared_weights: bool = False):
    """
    Deep graph network for domain rank estimation.

    Builds an encoder block, ``num_core_blocks`` stacked core blocks and a
    decoder block.

    :param drop_p: Dropout probability, passed to the edge, node and global
        update functions of every core block.
    :param num_core_blocks: Number of stacked core blocks, >= 0.
    :param edge_mode: One of 'default', 'bi_directional', 'no_edges' or
        'all_edges'. Selects the edge transformation
        (``GNDeep.default``, ``GNDeep.bi_directional``, ``GNDeep.no_edges``,
        ``GNDeep.all_edges``): keep the graph edges, make them
        bi-directional, remove them altogether, or apply the 'all_edges'
        transformation. In any case, the existence of reflexive edges is
        ensured.
    :param shared_weights: If True, every core block is the same GNBlock
        instance, so all core blocks share a single set of parameters.
        If False (default), each core block is constructed independently.
    :raises ValueError: If ``edge_mode`` is not one of the modes above, or if
        ``num_core_blocks`` is negative.
    """
    super().__init__()
    self.drop_p = drop_p
    self.edge_fns = {
        'default': GNDeep.default,
        'bi_directional': GNDeep.bi_directional,
        'no_edges': GNDeep.no_edges,
        'all_edges': GNDeep.all_edges
    }
    # Explicit checks instead of ``assert``: assert statements are removed
    # under ``python -O``, which would let invalid arguments through silently.
    if edge_mode not in self.edge_fns:
        raise ValueError(
            f"Invalid edge mode {edge_mode!r}; "
            "not in [default, bi_directional, no_edges, all_edges]")
    if num_core_blocks < 0:
        raise ValueError(
            f"num_core_blocks must be >= 0, got {num_core_blocks}")
    self.edge_mode = edge_mode

    self.enc = GNBlock(phi_e=EncoderEdgeUpdate(),
                       phi_u=EncoderGlobalStateUpdate(),
                       rho_eu=AvgAggregation())

    core_blocks = []
    for i in range(num_core_blocks):
        if shared_weights and i > 0:
            # Reuse the first block so all core blocks share its parameters.
            block = core_blocks[0]
        else:
            block = GNBlock(phi_e=CoreEdgeUpdate(self.drop_p),
                            phi_v=CoreNodeUpdate(self.drop_p),
                            phi_u=CoreGlobalStateUpdate(self.drop_p),
                            rho_ev=AvgAggregation(),
                            rho_vu=AvgAggregation(),
                            rho_eu=AvgAggregation())
        core_blocks.append(block)
    self.core_blocks = ListModule(*core_blocks)

    self.dec = GNBlock(phi_u=DecoderGlobalStateUpdate()
                       )  # maps global state from vec to scalar
def test_basic(self):
    """
    Basic test w/o PyTorch: all attributes are scalars and the input graph's
    edges carry no attributes.

    One GNBlock is applied twice to a three-node graph with edges
    0->1, 0->2 and 1->2. After each pass the result is compared with a
    hand-written target graph.
    """

    def build_graph(node_values, edge_values, global_value):
        # Fixed topology 0->1, 0->2, 1->2. With edge_values=None the edges
        # are created without an attribute argument, as in the raw input.
        n0, n1, n2 = (Node(Attribute(value)) for value in node_values)
        endpoints = [(n0, n1), (n0, n2), (n1, n2)]
        if edge_values is None:
            edges = [Edge(src, dst) for src, dst in endpoints]
        else:
            edges = [Edge(src, dst, Attribute(value))
                     for (src, dst), value in zip(endpoints, edge_values)]
        return Graph(nodes=[n0, n1, n2], edges=edges,
                     attr=Attribute(global_value))

    input_graph = build_graph((1, 10, 20), None, 0)

    gn_block = GNBlock(phi_e=SenderIdentityEdgeUpdate(),
                       phi_v=EdgeNodeSumNodeUpdate(),
                       phi_u=MixedGlobalStateUpdate(),
                       rho_ev=ScalarSumAggregation(),
                       rho_vu=ScalarSumAggregation(),
                       rho_eu=ScalarSumAggregation())

    # First pass. In the expected graph every edge holds its sender's value,
    # and every node has gained the sum of its incoming edge values.
    after_first = gn_block(input_graph)
    expected_first = build_graph((1, 10 + 1, 20 + 11), (1, 1, 10), 35)
    self.assertTrue(after_first == expected_first)

    # Second pass, fed with the output of the first pass.
    after_second = gn_block(after_first)
    # NOTE(review): this global target is written as sum(new node values) -
    # previous global (1 + 12 + 43 - 35). The first-pass target (35) does not
    # follow that same formula (1 + 11 + 31 - 0 = 43); confirm against
    # MixedGlobalStateUpdate's definition.
    expected_second = build_graph((1, 10 + 2, 20 + 11 + 12), (1, 1, 11),
                                  1 + 12 + 43 - 35)
    self.assertTrue(after_second == expected_second)
def get_extraction_block(model: torch.nn.Module):
    """
    Build a GNBlock that applies ``model`` to every node of a raw graph.

    The graph is expected to be in the raw form provided by the dataset.
    Each node is read as a dict with ``'desktop_img'`` and ``'mobile_img'``
    entries. ``model`` is called with both images and must return a pair of
    feature tensors (desktop, mobile). These are concatenated along dim 1 and
    flattened into the node's new 1-D value.

    :param model: Feature extractor taking (desktop_img, mobile_img).
    :return: A GNBlock whose only update is an independent node update.
    """
    def extract_node_features(node: Dict[str, any]) -> Tensor:
        # unsqueeze(0) adds a leading dim of size 1 — presumably the batch
        # dimension the model expects; TODO confirm.
        desktop_batch = node['desktop_img'].unsqueeze(0)
        mobile_batch = node['mobile_img'].unsqueeze(0)
        desktop_feats, mobile_feats = model(desktop_batch, mobile_batch)
        return torch.cat((desktop_feats, mobile_feats), dim=1).view(-1)

    node_update = IndependentNodeUpdate(extract_node_features)
    return GNBlock(phi_v=node_update)
def __init__(self):
    """
    Construct a single-block graph network over desktop screenshots.

    ``self.graph_block`` does the following:
    - applies ``DesktopScreenshotFeatureExtractor`` independently to each
      node's ``'desktop_img'``;
    - passes edges through ``IdentityEdgeUpdate`` with ``ConstantAggregation``
      for both edges->nodes and edges->global;
    - averages node values (``AvgAggregation``) into the global state via
      ``NodeAggregationGlobalStateUpdate``.
    """
    super().__init__()
    self.desktop_screenshot_extractor = DesktopScreenshotFeatureExtractor()

    def score_desktop_screenshot(node: Dict[str, any]) -> Attribute:
        # Extractor is resolved on ``self`` at call time. unsqueeze(0) adds a
        # leading dim of size 1, presumably the batch dim; TODO confirm.
        return self.desktop_screenshot_extractor(
            node['desktop_img'].unsqueeze(0))

    # Update/aggregation objects are created in the same order the original
    # keyword arguments were evaluated.
    edge_update = IdentityEdgeUpdate()
    node_update = IndependentNodeUpdate(score_desktop_screenshot)
    global_update = NodeAggregationGlobalStateUpdate()
    edges_to_nodes = ConstantAggregation()
    nodes_to_global = AvgAggregation()
    edges_to_global = ConstantAggregation()

    self.graph_block = GNBlock(phi_e=edge_update,
                               phi_v=node_update,
                               phi_u=global_update,
                               rho_ev=edges_to_nodes,
                               rho_vu=nodes_to_global,
                               rho_eu=edges_to_global)
def __init__(self):
    """
    Construct a per-node linear head followed by a decoder.

    - ``self.core`` applies a single ``nn.Linear(64, 1)`` (also registered as
      ``self.dense``) independently to every node.
    - ``self.dec`` aggregates node values with ``MaxAggregation`` (nodes ->
      global) and feeds them to ``NodeAggregationGlobalStateUpdate``.
    """
    super().__init__()
    # The same Linear instance is registered as ``dense`` and used by the
    # core block's node update, so both refer to one set of weights.
    linear = nn.Linear(64, 1)
    self.dense = linear
    self.core = GNBlock(phi_v=IndependentNodeUpdate(linear))

    node_pool = MaxAggregation()
    global_update = NodeAggregationGlobalStateUpdate()
    self.dec = GNBlock(rho_vu=node_pool, phi_u=global_update)