コード例 #1
0
    def __init__(self,
                 neighbor_finder,
                 node_features,
                 edge_features,
                 device,
                 dropout=0.1,
                 memory_update_at_start=True,
                 message_dimension=100,
                 memory_dimension=200,
                 n_neighbors=None,
                 aggregator_type="last",
                 mean_time_shift_src=0,
                 std_time_shift_src=1,
                 mean_time_shift_dst=0,
                 std_time_shift_dst=1,
                 threshold=2):
        """Build the DGNN model: raw feature tensors, two memories and their helpers.

        Two parallel stacks are built, suffixed ``_s`` and ``_g``. Each stack
        has a ``Memory``, a ``MemoryUpdater``, a ``Propagater`` and a
        projection matrix. A shared ``MemoryMerge``, message aggregator and
        ``MessageFunction`` are also created. What ``_s`` / ``_g`` stand for
        is not visible here (TODO confirm).

        Args:
            neighbor_finder: Neighbor lookup object. It is stored and handed to
                both ``Propagater`` instances.
            node_features: 2-D array with one row per node. It is cast to
                float32 and moved to ``device``; its row count defines
                ``n_nodes`` and its column count defines ``n_node_features``.
            edge_features: 2-D array with one row per edge. It is cast to
                float32 and moved to ``device``; its column count defines
                ``n_edge_features``.
            device: Torch device on which the tensors and submodules are placed.
            dropout: Accepted for interface compatibility; this constructor
                does not use it.
            memory_update_at_start: Stored as ``self.memory_update_at_start``.
            message_dimension: Message width given to both memories, the
                message function, the updaters and the propagaters. It is also
                stored as ``message_dimension`` and as the alias ``message_dim``.
            memory_dimension: Per-node memory width. It also sizes the
                ``(memory_dimension, memory_dimension // 2)`` matrices ``W_s``
                and ``W_g``.
            n_neighbors: Neighbor count passed to both propagaters.
            aggregator_type: Name passed to ``get_message_aggregator``.
            mean_time_shift_src: Stored. Half of it is given to the ``_s``
                updater and propagater.
            std_time_shift_src: Stored only.
            mean_time_shift_dst: Stored. Half of it is given to the ``_g``
                updater and propagater.
            std_time_shift_dst: Stored only.
            threshold: Stored and passed to both propagaters as ``tau``.
        """
        super(DGNN, self).__init__()
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        # Raw features as float32 tensors on ``device``. Both are indexed as
        # 2-D below (rows = nodes / edges, columns = feature channels).
        self.node_raw_features = torch.from_numpy(
            node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(
            edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors

        self.threshold = threshold
        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst
        self.memory_dimension = memory_dimension
        self.memory_update_at_start = memory_update_at_start
        self.message_dimension = message_dimension

        # Submodules are created in a fixed order. Keep it: it determines how
        # parameters are randomly initialised and their state_dict order.
        self.memory_merge = MemoryMerge(self.memory_dimension, self.device)
        self.memory_s = Memory(n_nodes=self.n_nodes,
                               memory_dimension=self.memory_dimension,
                               message_dimension=message_dimension,
                               device=device)
        self.memory_g = Memory(n_nodes=self.n_nodes,
                               memory_dimension=self.memory_dimension,
                               message_dimension=message_dimension,
                               device=device)
        # Alias of ``message_dimension``, kept for callers that read this name.
        self.message_dim = message_dimension
        self.message_aggregator = get_message_aggregator(
            aggregator_type=aggregator_type, device=device)
        self.message_function = MessageFunction(
            memory_dimension=memory_dimension,
            message_dimension=message_dimension,
            edge_dimension=self.n_edge_features,
            device=self.device)
        # NOTE(review): each stack receives half of its mean time shift, and
        # the ``_g`` stack gets the *dst* value through the keyword
        # ``mean_time_shift_src``. The intent is not visible here; confirm.
        self.memory_updater_s = MemoryUpdater(
            memory=self.memory_s,
            message_dimension=message_dimension,
            memory_dimension=self.memory_dimension,
            mean_time_shift_src=self.mean_time_shift_src / 2,
            device=self.device)
        self.memory_updater_g = MemoryUpdater(
            memory=self.memory_g,
            message_dimension=message_dimension,
            memory_dimension=self.memory_dimension,
            mean_time_shift_src=self.mean_time_shift_dst / 2,
            device=self.device)
        self.propagater_s = Propagater(
            memory=self.memory_s,
            message_dimension=message_dimension,
            memory_dimension=self.memory_dimension,
            mean_time_shift_src=self.mean_time_shift_src / 2,
            neighbor_finder=self.neighbor_finder,
            n_neighbors=self.n_neighbors,
            tau=self.threshold,
            device=self.device)
        self.propagater_g = Propagater(
            memory=self.memory_g,
            message_dimension=message_dimension,
            memory_dimension=self.memory_dimension,
            mean_time_shift_src=self.mean_time_shift_dst / 2,
            neighbor_finder=self.neighbor_finder,
            n_neighbors=self.n_neighbors,
            tau=self.threshold,
            device=self.device)

        # Trainable (memory_dimension, memory_dimension // 2) projections,
        # allocated directly on ``device`` (no CPU allocation then copy).
        # NOTE(review): both start as all zeros. An earlier in-code hint
        # ("nn.xavier_") suggests Xavier init (nn.init.xavier_uniform_) may
        # have been intended. Confirm before changing, because it alters
        # training results.
        self.W_s = nn.Parameter(
            torch.zeros((memory_dimension, memory_dimension // 2),
                        device=self.device))
        self.W_g = nn.Parameter(
            torch.zeros((memory_dimension, memory_dimension // 2),
                        device=self.device))
コード例 #2
0
ファイル: tgn.py プロジェクト: jacob-heglund/tgn
    def __init__(self,
                 neighbor_finder,
                 node_features,
                 edge_features,
                 device,
                 n_layers=2,
                 n_heads=2,
                 dropout=0.1,
                 use_memory=False,
                 memory_update_at_start=True,
                 message_dimension=100,
                 memory_dimension=500,
                 embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0,
                 std_time_shift_src=1,
                 mean_time_shift_dst=0,
                 std_time_shift_dst=1,
                 n_neighbors=None,
                 aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        """Build the TGN model: feature tensors, optional memory stack,
        embedding module and edge-scoring head.

        Args:
            neighbor_finder: Neighbor lookup object. It is stored and passed
                to the embedding module.
            node_features: 2-D array with one row per node. It is cast to
                float32 and moved to ``device``; its shape defines ``n_nodes``
                and ``n_node_features``.
            edge_features: 2-D array with one row per edge. It is cast to
                float32 and moved to ``device``; its column count defines
                ``n_edge_features``.
            device: Torch device for the feature tensors and submodules.
            n_layers: Number of embedding layers given to the embedding module.
            n_heads: Passed to the embedding module.
            dropout: Passed to the embedding module.
            use_memory: When True, build ``memory``, ``message_aggregator``,
                ``message_function`` and ``memory_updater``. When False,
                ``self.memory`` stays ``None`` and those attributes (and
                ``memory_dimension`` / ``memory_update_at_start``) are not set.
            memory_update_at_start: Stored only when ``use_memory`` is True.
            message_dimension: Message width. It is ignored (replaced by the
                raw message width) when ``message_function == "identity"``.
            memory_dimension: Per-node memory width (used only with memory).
            embedding_module_type: Name passed to ``get_embedding_module``.
            message_function: Name passed to ``get_message_function``.
            mean_time_shift_src: Stored only in this constructor.
            std_time_shift_src: Stored only in this constructor.
            mean_time_shift_dst: Stored only in this constructor.
            std_time_shift_dst: Stored only in this constructor.
            n_neighbors: Passed to the embedding module.
            aggregator_type: Name passed to ``get_message_aggregator``.
            memory_updater_type: Name passed to ``get_memory_updater``.
            use_destination_embedding_in_message: Stored as a flag only in
                this constructor.
            use_source_embedding_in_message: Stored as a flag only in this
                constructor.
            dyrep: Stored as a flag only in this constructor.
        """
        super(TGN, self).__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        # Raw features as float32 tensors on ``device``. Both are indexed as
        # 2-D below (rows = nodes / edges, columns = feature channels).
        self.node_raw_features = torch.from_numpy(
            node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(
            edge_features.astype(np.float32)).to(device)
        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        # Embeddings share the node-feature width.
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        # Remains None unless the memory stack is built below. It is handed
        # to the embedding module in either case.
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        # Submodules are created in a fixed order. Keep it: it determines how
        # parameters are randomly initialised and their state_dict order.
        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # Raw message width: 2 * memory width + edge features + time
            # encoding (presumably source and destination memories are
            # concatenated; confirm in the message-building code).
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                                    self.time_encoder.dimension
            # An "identity" message function presumably keeps the raw width,
            # so the message size is forced to match it.
            message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(
                aggregator_type=aggregator_type, device=device)
            self.message_function = get_message_function(
                module_type=message_function,
                raw_message_dimension=raw_message_dimension,
                message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(
                module_type=memory_updater_type,
                memory=self.memory,
                message_dimension=message_dimension,
                memory_dimension=self.memory_dimension,
                device=device)

        # ``self.embedding_module_type`` was already set above; the former
        # second, identical assignment here was redundant and is removed.
        self.embedding_module = get_embedding_module(
            module_type=embedding_module_type,
            node_features=self.node_raw_features,
            edge_features=self.edge_raw_features,
            memory=self.memory,
            neighbor_finder=self.neighbor_finder,
            time_encoder=self.time_encoder,
            n_layers=self.n_layers,
            n_node_features=self.n_node_features,
            n_edge_features=self.n_edge_features,
            n_time_features=self.n_node_features,
            embedding_dimension=self.embedding_dimension,
            device=self.device,
            n_heads=n_heads,
            dropout=dropout,
            use_memory=use_memory,
            n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features,
                                         self.n_node_features,
                                         self.n_node_features, 1)