# Ejemplo n.º 1
    def __init__(self, data, device):
        """Build the TGN components (memory, GNN, link predictor) for `data`.

        Args:
            data: Temporal-graph event data with `num_nodes`, `msg`, `dst`.
            device: Torch device every module and buffer is placed on.
        """
        super().__init__()

        # One shared width for memory state, time encoding and embeddings.
        self.embedding_dim = 100
        memory_dim = self.embedding_dim
        time_dim = self.embedding_dim

        self.data = data
        self.device = device

        # Raw event-message feature dimensionality, reused below.
        raw_msg_dim = self.data.msg.size(-1)

        # Per-node memory updated from event messages; keeps only the most
        # recent message per node (LastAggregator).
        self.memory = TGNMemory(
            self.data.num_nodes,
            raw_msg_dim,
            memory_dim,
            time_dim,
            message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
            aggregator_module=LastAggregator(),
        ).to(self.device)

        # Attention-based embedding over the temporal neighbourhood; shares
        # the memory's time encoder.
        self.gnn = GraphAttentionEmbedding(
            in_channels=memory_dim,
            out_channels=self.embedding_dim,
            msg_dim=raw_msg_dim,
            time_enc=self.memory.time_enc,
        ).to(self.device)

        self.link_pred = LinkPredictor_khop(
            in_channels=self.embedding_dim).to(self.device)

        # Keeps the 10 most recent neighbours per node.
        self.neighbor_loader = LastNeighborLoader(self.data.num_nodes,
                                                  size=10,
                                                  device=self.device)

        self.criterion = torch.nn.BCEWithLogitsLoss()

        # Scratch buffer mapping global node ids -> local batch positions.
        self.assoc = torch.empty(self.data.num_nodes,
                                 dtype=torch.long,
                                 device=self.device)

        # Destination-id range, used for negative sampling.
        dst = self.data.dst
        self.min_dst_idx = int(dst.min())
        self.max_dst_idx = int(dst.max())

        self.adj_list = {}
# Ejemplo n.º 2
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Score a candidate link from source/destination node embeddings."""
        hidden = self.lin_src(z_src) + self.lin_dst(z_dst)
        return self.lin_final(hidden.relu())


# Hyperparameters: one shared width (100) for the memory state, the time
# encoding, and the output node embeddings.
memory_dim = time_dim = embedding_dim = 100

# Per-node memory updated from raw event messages; LastAggregator keeps only
# the most recent message per node within a batch.
memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),  # raw event-message feature dimensionality
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# Attention-based embedding layer over the temporal neighbourhood; it reuses
# the memory's time encoder so both see the same time representation.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

# MLP scoring candidate (src, dst) embedding pairs.
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# One optimizer over all three modules; the set-union de-duplicates any
# parameters shared between them (e.g. the time encoder).
optimizer = torch.optim.Adam(set(memory.parameters()) | set(gnn.parameters())
                             | set(link_pred.parameters()),
                             lr=0.0001)
# Ejemplo n.º 3
def main(memory_type: str,
         path: str,
         dataset: str,
         embedding_dim: int,
         learning_rate: float = 0.0001,
         epochs: int = 25,
         run: 'wandb.Run' = None,
         regularize: bool = False,
         **kwargs):
    """Train a temporal link-prediction model on a JODIE dataset.

    Args:
        memory_type: Which memory module to build: 'tgn' or 'tsam'.
        path: Root directory where the JODIE dataset is stored/downloaded.
        dataset: JODIE dataset name (e.g. 'wikipedia').
        embedding_dim: Shared width for memory state, time encoding and
            output embeddings.
        learning_rate: Adam learning rate.
        epochs: Number of training epochs, forwarded to the Trainer.
        run: Optional wandb run handle, forwarded to the Trainer for logging.
        regularize: Forwarded to the Trainer.
        **kwargs: Extra options. 'expire_span' (int, default 0): when > 0,
            wraps the TGN memory with an ExpireSpan of that maximum time.

    Raises:
        ValueError: If `memory_type` is not 'tgn' or 'tsam'.
    """
    dataset = JODIEDataset(path, name=dataset)
    data = dataset[0]

    memory_dim = time_dim = embedding_dim
    raw_msg_dim = data.msg.size(-1)

    if memory_type == 'tgn':
        msg_module = IdentityMessage(raw_msg_dim, memory_dim, time_dim)
        # BUG FIX: the original read `args.expire_span`, but no `args` exists
        # in this function's scope (NameError at runtime). Source the option
        # from **kwargs instead, defaulting to "disabled".
        expire_span_len = kwargs.get('expire_span', 0)
        expire_span = (expire_span_len > 0) and ExpireSpan(
            dim=msg_module.out_channels, max_time=expire_span_len)
        memory = TGNMemory(
            data.num_nodes,
            raw_msg_dim,
            memory_dim,
            time_dim,
            message_module=msg_module,
            aggregator_module=AttentionAggregator(
                msg_module.out_channels),  # LastAggregator(),
            expire_span=expire_span)
    elif memory_type == 'tsam':
        memory = TSAM(
            data.num_nodes,
            raw_msg_dim,
            memory_dim,
            time_dim,
            message_module=TSAMMessage(raw_msg_dim, memory_dim, time_dim),
            aggregator_module=AttentionAggregator(
                memory_dim)  #TSAMAggregator(),
        )
    else:
        raise ValueError(f'Invalid memory_type {memory_type}.')

    # Attention embedding over the temporal neighbourhood, sharing the
    # memory's time encoder.
    gnn = GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=raw_msg_dim,
        time_enc=memory.time_enc,
    )

    link_pred = LinkPredictor(in_channels=embedding_dim)

    # Single optimizer over all modules; the set-union de-duplicates any
    # shared parameters (e.g. the time encoder).
    optimizer = torch.optim.Adam(set(memory.parameters())
                                 | set(gnn.parameters())
                                 | set(link_pred.parameters()),
                                 lr=learning_rate)

    trainer = Trainer(dataset,
                      memory,
                      gnn,
                      link_pred,
                      optimizer,
                      run=run,
                      epochs=epochs,
                      regularize=regularize)
    trainer.trial()
# Ejemplo n.º 4
from modules.denoiserMemory import DenoiserMemory,  LastAggregator


wandb.init(project="denoiser")

# Generate a clean graph G, a noisy copy G_noisy, and C (semantics defined
# by graphGenerators — presumably communities/corruptions; verify there).
G, G_noisy, C = graphGenerators.getRandomStartingNxGraphs(8, 30, 30)
torch_G = torch_geometric.utils.from_networkx(G_noisy)
# Pristine copy kept before any in-place modification of torch_G.
torch_G_base = deepcopy(torch_G)

# One shared width (100) for memory state, time encoding and embeddings.
memory_dim = time_dim = embedding_dim = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Debugging aid: catches NaN/inf in backward passes at a large speed cost —
# consider disabling for real training runs.
torch.autograd.set_detect_anomaly(True)

# Memory sized for 2x the node count (presumably extra slots for added /
# denoised nodes — confirm against DenoiserMemory's usage). Raw message
# dimensionality is 0: events carry no features, only topology and time.
memory = DenoiserMemory(
    torch_G.num_nodes*2,
    0,
    memory_dim,
    time_dim,
    message_module=IdentityMessage(0, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# Attention embedding over the temporal neighbourhood; shares the memory's
# time encoder. msg_dim=0 matches the feature-less events above.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=0,
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)
# NOTE(review): unlike the modules above, these two are not moved to
# `device` — confirm whether that is intentional (CPU-only relation scoring)
# or a missing `.to(device)`.
rel_subgraph = RelationModule(100, 5)
rel_node = RelationModule(100, 5)

# Keeps the 25 most recent neighbours per node.
neighbor_loader = LastNeighborLoader(torch_G.num_nodes, size=25, device=device)