def test_rgcn():
    """Smoke-test ``nn.RelGraphConv`` output shapes.

    Covers the "basis" and "bdd" regularizers with dense node features,
    first without and then with an edge ``norm`` argument, and finally the
    "basis" layer fed integer node ids instead of dense features. Every
    output must have shape ``[100, O]``.

    NOTE(review): this SOURCE defines ``test_rgcn`` twice; if both live in
    the same module the later definition shadows this one and pytest will
    not collect it — confirm which copy is intended.
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    R = 5  # number of relation (edge) types
    B = 2  # number of bases / blocks passed to the regularizer
    I = 10  # input feature size; also the id range for the id-input case
    O = 8  # output feature size
    # Cycle edge types 0..R-1 over the edges; derived from R (not a literal)
    # so the type ids always stay within the layer's relation count.
    etype = [i % R for i in range(g.number_of_edges())]
    r = tf.constant(etype)

    # Dense features, without and then with a (zero) per-edge norm.
    norm = tf.zeros((g.number_of_edges(), 1))
    for extra_args in ((), (norm,)):
        for regularizer in ("basis", "bdd"):
            conv = nn.RelGraphConv(I, O, R, regularizer, B)
            h = tf.random.normal((100, I))
            h_new = conv(g, h, r, *extra_args)
            assert list(h_new.shape) == [100, O]

    # Integer input: one id in [0, I) per node instead of a feature row.
    rgc_basis = nn.RelGraphConv(I, O, R, "basis", B)
    h = tf.constant(np.random.randint(0, I, (100,)))
    h_new = rgc_basis(g, h, r)
    assert list(h_new.shape) == [100, O]
def test_rgcn():
    """Check ``nn.RelGraphConv`` shapes and low-memory equivalence.

    For each configuration a regular layer and a ``low_mem=True`` layer are
    built, and the low-memory layer is pointed at the regular layer's weight
    variables. Both are then run on the same inputs. Each output must have
    shape ``[100, O]``, and the two outputs must be ``F.allclose``.

    Configurations: the "basis" and "bdd" regularizers with dense features,
    both again with an edge ``norm``, and "basis" with integer node-id input.
    """
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1),
                     readonly=True).to(F.ctx())
    R = 5  # number of relation (edge) types
    B = 2  # number of bases / blocks passed to the regularizer
    I = 10  # input feature size; also the id range for the id-input case
    O = 8  # output feature size
    # Cycle edge types 0..R-1 over the edges; derived from R (not a literal)
    # so the type ids always stay within the layer's relation count.
    etype = [i % R for i in range(g.number_of_edges())]

    def make_pair(regularizer):
        # Build a regular layer and a low_mem twin that reuses its weights, so
        # any output difference can only come from the low_mem code path.
        # NOTE(review): only `weight` (plus `w_comp` for "basis") is shared;
        # if the layer creates any other randomly initialised variable (e.g.
        # a self-loop weight) the allclose check would fail — confirm.
        conv = nn.RelGraphConv(I, O, R, regularizer, B)
        conv_low = nn.RelGraphConv(I, O, R, regularizer, B, low_mem=True)
        conv_low.weight = conv.weight
        if regularizer == "basis":
            conv_low.w_comp = conv.w_comp
        return conv, conv_low

    def check(regularizer, h, r, *extra_args):
        # Run both layers on identical inputs and compare shapes and values.
        conv, conv_low = make_pair(regularizer)
        h_new = conv(g, h, r, *extra_args)
        h_new_low = conv_low(g, h, r, *extra_args)
        assert list(h_new.shape) == [100, O]
        assert list(h_new_low.shape) == [100, O]
        assert F.allclose(h_new, h_new_low)

    # Dense features, without and then with a (zero) per-edge norm.
    norm = tf.zeros((g.number_of_edges(), 1))
    for extra_args in ((), (norm,)):
        for regularizer in ("basis", "bdd"):
            h = tf.random.normal((100, I))
            r = tf.constant(etype)
            check(regularizer, h, r, *extra_args)

    # Integer input: one id in [0, I) per node instead of a feature row.
    # The `* 1` is kept from the original. It presumably forces an op so the
    # tensor is materialised on the current device — TODO confirm.
    h = tf.constant(np.random.randint(0, I, (100, ))) * 1
    r = tf.constant(etype) * 1
    check("basis", h, r)