Exemplo n.º 1
0
  def test_permutation_invariance(self):
    """Checks that permuting the input nodes permutes the GNN output alike.

    A random 0/1 adjacency matrix and random node features are drawn, the
    nodes are relabelled with a random permutation, and the model output on
    the relabelled graph is compared with the relabelled output on the
    original graph.
    """
    num_nodes = 4
    num_features = 2
    # Give every random draw its own subkey; reusing a single key for several
    # draws makes them correlated.
    adjacency_key, feats_key, perm_key, init_key = random.split(
        random.PRNGKey(0), 4)

    # Generate random graph.
    adjacency = random.randint(adjacency_key, (num_nodes, num_nodes), 0, 2)
    node_feats = random.normal(feats_key, (num_nodes, num_features))
    sources, targets = jnp.where(adjacency)

    # Get permuted graph: the same permutation is applied to the rows and to
    # the columns, so adjacency_perm[i, j] == adjacency[perm[i], perm[j]].
    perm = random.permutation(perm_key, jnp.arange(num_nodes))
    node_feats_perm = node_feats[perm]
    adjacency_perm = adjacency[perm][:, perm]
    sources_perm, targets_perm = jnp.where(adjacency_perm)

    # Create GNN.
    _, initial_params = GNN.init(
        init_key, node_x=node_feats, edge_x=None, sources=sources,
        targets=targets)
    model = nn.Model(GNN, initial_params)

    # Feedforward both original and permuted graph.
    logits = model(node_feats, None, sources, targets)
    logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

    self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
Exemplo n.º 2
0
 def test_single_train_step(self):
   """Runs one training step on a three-node graph; expects a positive loss."""
   # Toy graph, with each node's label in parentheses:
   #       0 (0)
   #      / \
   # (0) 1 - 2 (1)
   # NOTE(review): the sketch shows an edge 0-1, but the list below holds the
   # self-loop (0, 0) instead -- confirm which one is intended.
   graph_edges = [(0, 0), (1, 2), (2, 0)]
   labels = [0, 0, 1]
   features, _, senders, receivers = train.create_graph_data(
       edge_list=graph_edges, node_labels=labels)

   # Initialise the model parameters and wrap the model in an Adam optimizer.
   init_key = random.PRNGKey(0)
   _, params = GNN.init(
       init_key, node_x=features, edge_x=None, sources=senders,
       targets=receivers)
   gnn = nn.Model(GNN, params)
   adam = optim.Adam(learning_rate=0.01).create(gnn)

   # Only the loss returned by the step is inspected.
   _, step_loss = train.train_step(
       optimizer=adam, node_feats=features, sources=senders,
       targets=receivers)

   self.assertGreater(step_loss, 0.0)