Example No. 1
    def __init__(self,
                 n_node_features,
                 n_neighbors_features,
                 n_edge_features,
                 time_dim,
                 output_dimension,
                 n_head=2,
                 dropout=0.1):
        super(TemporalAttentionLayer, self).__init__()

        self.n_head = n_head

        self.feat_dim = n_node_features
        self.time_dim = time_dim

        # The query combines the target node's features with its time encoding;
        # keys/values additionally carry the edge features of each neighbor.
        self.query_dim = n_node_features + time_dim
        self.key_dim = n_neighbors_features + time_dim + n_edge_features

        self.merger = MergeLayer(self.query_dim, n_node_features,
                                 n_node_features, output_dimension)

        # Attention over the temporal neighborhood: queries have query_dim,
        # keys/values have key_dim.
        self.multi_head_target = nn.MultiheadAttention(
            embed_dim=self.query_dim,
            kdim=self.key_dim,
            vdim=self.key_dim,
            num_heads=n_head,
            dropout=dropout)
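
A minimal instantiation sketch, not from the original source: the dimensions below are hypothetical, and TemporalAttentionLayer (with its MergeLayer dependency) is assumed to be importable from the surrounding module.

# Hypothetical dimensions; embed_dim (= n_node_features + time_dim) must be
# divisible by n_head for nn.MultiheadAttention to accept it.
attn = TemporalAttentionLayer(n_node_features=172,
                              n_neighbors_features=172,
                              n_edge_features=172,
                              time_dim=172,
                              output_dimension=172,
                              n_head=2,
                              dropout=0.1)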
Example No. 2
    def __init__(self,
                 n_node_features,
                 n_neighbors_features,
                 n_edge_features,
                 time_dim,
                 output_dimension,
                 n_head=2,
                 n_relations=5,
                 dropout=0.1,
                 aggregate='stack'):
        super(TemporalAttentionLayer, self).__init__()

        self.n_head = n_head

        self.n_relations = n_relations

        self.feat_dim = n_node_features
        self.time_dim = time_dim

        self.query_dim = n_node_features + time_dim
        self.key_dim = n_neighbors_features + time_dim + n_edge_features

        self.aggregate = aggregate
        if self.aggregate == 'stack':
            # 'stack' concatenates the per-relation attention outputs, so the
            # merger input grows to query_dim * n_relations.
            self.merger = MergeLayer(self.query_dim * n_relations,
                                     n_node_features, n_node_features,
                                     output_dimension)
        elif self.aggregate == 'sum':
            # 'sum' adds the per-relation outputs, so the merger input stays query_dim.
            self.merger = MergeLayer(self.query_dim, n_node_features,
                                     n_node_features, output_dimension)
        else:
            raise ValueError("Aggregation method {} not supported".format(
                self.aggregate))

        # One multi-head attention module per relation type.
        self.multi_head_target = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=self.query_dim,
                                  kdim=self.key_dim,
                                  vdim=self.key_dim,
                                  num_heads=n_head,
                                  dropout=dropout) for _ in range(n_relations)
        ])
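
A similar sketch for this multi-relation variant, again with hypothetical dimensions; with aggregate='stack' the merger expects the concatenation of all n_relations attention outputs.

# Hypothetical dimensions; one nn.MultiheadAttention module is created per relation.
attn = TemporalAttentionLayer(n_node_features=172,
                              n_neighbors_features=172,
                              n_edge_features=172,
                              time_dim=172,
                              output_dimension=172,
                              n_head=2,
                              n_relations=5,
                              aggregate='stack',
                              dropout=0.1)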
Example No. 3
    def __init__(self,
                 neighbor_finder,
                 node_features,
                 edge_features,
                 device,
                 n_layers=2,
                 n_heads=2,
                 dropout=0.1,
                 use_memory=False,
                 memory_update_at_start=True,
                 message_dimension=100,
                 memory_dimension=500,
                 embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0,
                 std_time_shift_src=1,
                 mean_time_shift_dst=0,
                 std_time_shift_dst=1,
                 n_neighbors=None,
                 aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 dyrep=False):
        super(TGN, self).__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(
            node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(
            edge_features.astype(np.float32)).to(device)
        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message
        self.dyrep = dyrep

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        # Memory components: raw messages concatenate the two endpoint memories,
        # the edge features and a time encoding; they are aggregated, optionally
        # transformed by a message function, and then used to update node memory.
        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                                    self.time_encoder.dimension
            message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(
                aggregator_type=aggregator_type, device=device)
            self.message_function = get_message_function(
                module_type=message_function,
                raw_message_dimension=raw_message_dimension,
                message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(
                module_type=memory_updater_type,
                memory=self.memory,
                message_dimension=message_dimension,
                memory_dimension=self.memory_dimension,
                device=device)

        self.embedding_module = get_embedding_module(
            module_type=embedding_module_type,
            node_features=self.node_raw_features,
            edge_features=self.edge_raw_features,
            memory=self.memory,
            neighbor_finder=self.neighbor_finder,
            time_encoder=self.time_encoder,
            n_layers=self.n_layers,
            n_node_features=self.n_node_features,
            n_edge_features=self.n_edge_features,
            n_time_features=self.n_node_features,
            embedding_dimension=self.embedding_dimension,
            device=self.device,
            n_heads=n_heads,
            dropout=dropout,
            use_memory=use_memory,
            n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features,
                                         self.n_node_features,
                                         self.n_node_features, 1)
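
A hedged usage sketch with toy random features; the neighbor_finder argument is assumed to have been built elsewhere from the temporal interaction data, and all sizes below are hypothetical.

import numpy as np
import torch

# Toy inputs: 1000 nodes and 10000 interactions with 172-dimensional features.
node_features = np.random.rand(1000, 172).astype(np.float32)
edge_features = np.random.rand(10000, 172).astype(np.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tgn = TGN(neighbor_finder=neighbor_finder,  # assumed to be constructed beforehand
          node_features=node_features,
          edge_features=edge_features,
          device=device,
          n_layers=2,
          n_heads=2,
          use_memory=True,
          memory_dimension=172,
          message_dimension=100,
          n_neighbors=10)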