def __init__(self, neighbor_finder, node_features, edge_features, device,
             dropout=0.1, memory_update_at_start=True, message_dimension=100,
             memory_dimension=200, n_neighbors=None, aggregator_type="last",
             mean_time_shift_src=0, std_time_shift_src=1,
             mean_time_shift_dst=0, std_time_shift_dst=1, threshold=2):
    """Construct the DGNN model and all of its sub-modules.

    Builds float32 tensors from the raw node/edge feature arrays, two
    parallel memories (``memory_s`` and ``memory_g``), a shared message
    aggregator and message function, one ``MemoryUpdater`` and one
    ``Propagater`` per memory, and two weight parameters ``W_s`` / ``W_g``
    of shape ``(memory_dimension, memory_dimension // 2)``.

    Args:
        neighbor_finder: Stored on the instance and handed to both
            Propagaters.
        node_features: Array converted with ``astype(np.float32)``; its
            ``shape[0]`` becomes ``n_nodes`` and its ``shape[1]`` becomes
            ``n_node_features`` (also used as ``embedding_dimension``).
        edge_features: Array converted the same way; its ``shape[1]``
            becomes ``n_edge_features``.
        device: Torch device for the feature tensors, the weights and every
            sub-module.
        dropout: Accepted but not referenced anywhere in this constructor.
        memory_update_at_start: Stored only.
        message_dimension: Message width given to the memories, the message
            function, the updaters and the propagaters.
        memory_dimension: Memory width; also fixes the shape of ``W_s`` and
            ``W_g``.
        n_neighbors: Forwarded to both Propagaters.
        aggregator_type: Selects the aggregator via ``get_message_aggregator``.
        mean_time_shift_src, mean_time_shift_dst: Stored; half of the src
            value goes to the ``_s`` updater/propagater and half of the dst
            value to the ``_g`` ones (both as ``mean_time_shift_src=``).
        std_time_shift_src, std_time_shift_dst: Stored only.
        threshold: Stored and passed to both Propagaters as ``tau``.
    """
    super(DGNN, self).__init__()

    # --- plain configuration ---------------------------------------------
    self.neighbor_finder = neighbor_finder
    self.device = device
    self.logger = logging.getLogger(__name__)

    def as_float_tensor(array):
        # NumPy array -> float32 torch tensor on the target device.
        return torch.from_numpy(array.astype(np.float32)).to(device)

    self.node_raw_features = as_float_tensor(node_features)
    self.edge_raw_features = as_float_tensor(edge_features)

    node_shape = self.node_raw_features.shape
    self.n_node_features = node_shape[1]
    self.n_nodes = node_shape[0]
    self.n_edge_features = self.edge_raw_features.shape[1]
    self.embedding_dimension = self.n_node_features
    self.n_neighbors = n_neighbors

    # Placeholders; both are replaced by Memory instances further down.
    self.memory_s = self.memory_g = None

    self.threshold = threshold
    self.mean_time_shift_src, self.std_time_shift_src = (
        mean_time_shift_src, std_time_shift_src)
    self.mean_time_shift_dst, self.std_time_shift_dst = (
        mean_time_shift_dst, std_time_shift_dst)
    self.memory_dimension = memory_dimension
    self.memory_update_at_start = memory_update_at_start
    self.message_dimension = message_dimension

    # --- sub-modules (creation order kept stable on purpose: it governs
    # --- parameter registration order and RNG consumption) ----------------
    self.memory_merge = MemoryMerge(self.memory_dimension, self.device)

    def build_memory():
        return Memory(n_nodes=self.n_nodes,
                      memory_dimension=self.memory_dimension,
                      message_dimension=message_dimension,
                      device=device)

    self.memory_s = build_memory()
    self.memory_g = build_memory()

    # Same value as ``message_dimension``, exposed under a second name too.
    self.message_dim = message_dimension
    self.message_aggregator = get_message_aggregator(
        aggregator_type=aggregator_type, device=device)
    self.message_function = MessageFunction(
        memory_dimension=memory_dimension,
        message_dimension=message_dimension,
        edge_dimension=self.n_edge_features,
        device=self.device)

    def build_updater(memory, mean_shift):
        return MemoryUpdater(memory=memory,
                             message_dimension=message_dimension,
                             memory_dimension=self.memory_dimension,
                             mean_time_shift_src=mean_shift / 2,
                             device=self.device)

    self.memory_updater_s = build_updater(self.memory_s,
                                          self.mean_time_shift_src)
    self.memory_updater_g = build_updater(self.memory_g,
                                          self.mean_time_shift_dst)

    def build_propagater(memory, mean_shift):
        return Propagater(memory=memory,
                          message_dimension=message_dimension,
                          memory_dimension=self.memory_dimension,
                          mean_time_shift_src=mean_shift / 2,
                          neighbor_finder=self.neighbor_finder,
                          n_neighbors=self.n_neighbors,
                          tau=self.threshold,
                          device=self.device)

    self.propagater_s = build_propagater(self.memory_s,
                                         self.mean_time_shift_src)
    self.propagater_g = build_propagater(self.memory_g,
                                         self.mean_time_shift_dst)

    # NOTE(review): W_s / W_g start as all-zero matrices; a leftover
    # "nn.xavier_" remark suggests Xavier init may have been intended --
    # confirm before changing, since it alters training from step one.
    def zero_projection():
        return nn.Parameter(
            torch.zeros((memory_dimension, memory_dimension // 2)).to(
                self.device))

    self.W_s = zero_projection()
    self.W_g = zero_projection()
def __init__(self, neighbor_finder, node_features, edge_features, device,
             n_layers=2, n_heads=2, dropout=0.1, use_memory=False,
             memory_update_at_start=True, message_dimension=100,
             memory_dimension=500, embedding_module_type="graph_attention",
             message_function="mlp", mean_time_shift_src=0,
             std_time_shift_src=1, mean_time_shift_dst=0,
             std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
             memory_updater_type="gru",
             use_destination_embedding_in_message=False,
             use_source_embedding_in_message=False, dyrep=False):
    """Construct the TGN model.

    Builds float32 tensors from the raw node/edge feature arrays and a
    ``TimeEncode`` whose width equals the node-feature width. It then builds
    the memory machinery (memory, aggregator, message function, updater),
    but only when ``use_memory`` is true. It always builds the embedding
    module and an affinity ``MergeLayer``.

    Args:
        neighbor_finder: Stored and forwarded to the embedding module.
        node_features: Array converted with ``astype(np.float32)``; its
            ``shape[0]`` becomes ``n_nodes`` and its ``shape[1]`` becomes
            ``n_node_features`` (also the embedding and time-feature width).
        edge_features: Array converted the same way; its ``shape[1]``
            becomes ``n_edge_features``.
        device: Torch device for the tensors and sub-modules.
        n_layers, n_heads, dropout: Forwarded to ``get_embedding_module``.
        use_memory: When false, ``self.memory`` stays ``None`` and the
            memory-related attributes (``memory_dimension``,
            ``memory_update_at_start``, ``message_aggregator``,
            ``message_function``, ``memory_updater``) are not set.
        memory_update_at_start: Stored only when ``use_memory`` is true.
        message_dimension: Message width; ignored (replaced by the raw
            message width) when ``message_function == "identity"``.
        memory_dimension: Memory width (used only when ``use_memory``).
        embedding_module_type: Selects the embedding module; also stored.
        message_function: Name passed to ``get_message_function``.
        mean_time_shift_src, std_time_shift_src, mean_time_shift_dst,
        std_time_shift_dst: Stored only.
        n_neighbors: Stored and forwarded to the embedding module.
        aggregator_type: Name passed to ``get_message_aggregator``.
        memory_updater_type: Name passed to ``get_memory_updater``.
        use_destination_embedding_in_message,
        use_source_embedding_in_message, dyrep: Stored only.
    """
    super(TGN, self).__init__()

    # --- configuration and raw inputs ------------------------------------
    self.n_layers = n_layers
    self.neighbor_finder = neighbor_finder
    self.device = device
    self.logger = logging.getLogger(__name__)

    def to_device_tensor(array):
        # NumPy array -> float32 torch tensor on the target device.
        return torch.from_numpy(array.astype(np.float32)).to(device)

    self.node_raw_features = to_device_tensor(node_features)
    self.edge_raw_features = to_device_tensor(edge_features)

    node_shape = self.node_raw_features.shape
    self.n_node_features = node_shape[1]
    self.n_nodes = node_shape[0]
    self.n_edge_features = self.edge_raw_features.shape[1]
    self.embedding_dimension = self.n_node_features
    self.n_neighbors = n_neighbors

    self.embedding_module_type = embedding_module_type
    self.use_destination_embedding_in_message = (
        use_destination_embedding_in_message)
    self.use_source_embedding_in_message = use_source_embedding_in_message
    self.dyrep = dyrep
    self.use_memory = use_memory

    # Time-encoding width is tied to the node-feature width.
    self.time_encoder = TimeEncode(dimension=self.n_node_features)

    # Remains None unless use_memory is set; handed to the embedding
    # module in either case.
    self.memory = None

    self.mean_time_shift_src, self.std_time_shift_src = (
        mean_time_shift_src, std_time_shift_src)
    self.mean_time_shift_dst, self.std_time_shift_dst = (
        mean_time_shift_dst, std_time_shift_dst)

    # --- optional memory machinery (construction order kept stable: it
    # --- governs parameter registration order and RNG consumption) -------
    if self.use_memory:
        self.memory_dimension = memory_dimension
        self.memory_update_at_start = memory_update_at_start

        # Un-projected message width:
        # 2 * memory_dimension + edge-feature width + time-encoding width.
        raw_message_dimension = (2 * self.memory_dimension
                                 + self.n_edge_features
                                 + self.time_encoder.dimension)
        # With the "identity" message function the configured
        # message_dimension is ignored and the raw width is used instead.
        if message_function == "identity":
            message_dimension = raw_message_dimension

        self.memory = Memory(n_nodes=self.n_nodes,
                             memory_dimension=self.memory_dimension,
                             input_dimension=message_dimension,
                             message_dimension=message_dimension,
                             device=device)
        self.message_aggregator = get_message_aggregator(
            aggregator_type=aggregator_type, device=device)
        self.message_function = get_message_function(
            module_type=message_function,
            raw_message_dimension=raw_message_dimension,
            message_dimension=message_dimension)
        self.memory_updater = get_memory_updater(
            module_type=memory_updater_type,
            memory=self.memory,
            message_dimension=message_dimension,
            memory_dimension=self.memory_dimension,
            device=device)

    # --- always-present modules ------------------------------------------
    self.embedding_module = get_embedding_module(
        module_type=embedding_module_type,
        node_features=self.node_raw_features,
        edge_features=self.edge_raw_features,
        memory=self.memory,
        neighbor_finder=self.neighbor_finder,
        time_encoder=self.time_encoder,
        n_layers=self.n_layers,
        n_node_features=self.n_node_features,
        n_edge_features=self.n_edge_features,
        n_time_features=self.n_node_features,
        embedding_dimension=self.embedding_dimension,
        device=self.device,
        n_heads=n_heads,
        dropout=dropout,
        use_memory=use_memory,
        n_neighbors=self.n_neighbors)

    # MLP that scores an edge (per the original note: its probability)
    # from two node embeddings.
    width = self.n_node_features
    self.affinity_score = MergeLayer(width, width, width, 1)