def graph_conv_cheby(self, x, W, L, lmax, Fout, K):
    """Chebyshev spectral graph convolution.

    Expands the input signal in the first ``K`` Chebyshev terms of the
    rescaled Laplacian (T_0 x = x, T_1 x = L x,
    T_k x = 2 L T_{k-1} x - T_{k-2} x), then linearly combines the resulting
    ``Fin * K`` features per vertex into ``Fout`` output features with ``W``.

    Args:
        x: dense tensor of static shape (B, V, Fin): B = batch size,
            V = number of vertices, Fin = number of input features. Every
            dimension must be statically known, because ``int()`` is applied
            to each one.
        W: weight tensor multiplied onto a (B*V, Fin*K) matrix, so it must
            have shape (Fin*K, Fout). Row index is ``fin * K + k``.
        L: graph Laplacian, presumably a scipy.sparse matrix of shape (V, V).
            After ``rescale_L`` the result must support ``.tocoo()``.
        lmax: largest eigenvalue of ``L`` used for rescaling. If ``None``, it
            is computed here with ``lmaxX(L)``.
        Fout: number of output features.
        K: Chebyshev order, i.e. number of basis terms / filter support size.

    Returns:
        Tensor of shape (B, V, Fout).
    """
    B, V, Fin = x.get_shape()
    B, V, Fin = int(B), int(V), int(Fin)

    # Rescale the Laplacian (presumably into [-1, 1], where the Chebyshev
    # recurrence is meant to run). Callers typically precompute lmax per
    # coarsening level. Only fall back to the eigenvalue computation when it
    # was not supplied, instead of always discarding the argument.
    if lmax is None:
        lmax = lmaxX(L)
    # rescale_L is defined elsewhere and may rescale its argument in place.
    # Work on a copy so the caller's shared Laplacian is never mutated.
    L = rescale_L(L.copy(), lmax)

    # Convert the scipy sparse matrix into a TF sparse tensor.
    L = L.tocoo()
    indices = np.column_stack((L.row, L.col))
    L = tf.SparseTensor(indices, L.data, L.shape)
    # TF sparse ops expect indices in canonical row-major order.
    L = tf.sparse_reorder(L)

    # Transform to Chebyshev basis. Fold features and batch into columns so
    # that each basis term is one sparse-dense matmul.
    x0 = tf.transpose(x, perm=[1, 2, 0])  # V x Fin x B
    x0 = tf.reshape(x0, [V, Fin * B])  # V x Fin*B
    x = tf.expand_dims(x0, 0)  # 1 x V x Fin*B

    def concat(x, x_):
        """Append one V x Fin*B basis term along a new leading axis."""
        x_ = tf.expand_dims(x_, 0)  # 1 x V x Fin*B
        return tf.concat([x, x_], 0)  # (terms so far + 1) x V x Fin*B

    if K > 1:
        x1 = tf.sparse_tensor_dense_matmul(L, x0)
        x = concat(x, x1)
    for _ in range(2, K):
        x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0
        x = concat(x, x2)
        x0, x1 = x1, x2
    # x is now K x V x Fin*B

    x = tf.reshape(x, [K, V, Fin, B])  # K x V x Fin x B
    x = tf.transpose(x, perm=[3, 1, 2, 0])  # B x V x Fin x K
    x = tf.reshape(x, [B * V, Fin * K])  # B*V x Fin*K

    # Compose linearly Fin*K features to get Fout features.
    x = tf.matmul(x, W)  # B*V x Fout
    x = tf.reshape(x, [B, V, Fout])  # B x V x Fout
    return x
# Construct graph
t_start = time.time()
grid_side = 28
number_edges = 8
metric = 'euclidean'
A = grid_graph(grid_side, number_edges, metric)  # create graph of Euclidean grid

# Compute coarsened graphs
coarsening_levels = 4
num_vertices, L, perm = coarsen(A, coarsening_levels)

# Compute max eigenvalue of graph Laplacians, one per coarsening level.
# NOTE(review): only L[0..coarsening_levels-1] get an lmax. If coarsen()
# returns coarsening_levels + 1 graphs, the coarsest one has none. Confirm
# that no convolution runs on it.
lmax = [lmaxX(L[i]) for i in range(coarsening_levels)]
print('lmax: ' + str(lmax))

# Evaluate the data tensors into numpy arrays. Each evaluation keeps its own
# session as before, but the sessions are now closed instead of leaked.
with tf.Session() as sess:
    train_data = sess.run(train_data)
# Comment out the following conversion for non-rotated test images
# (presumably test_data is then already a numpy array).
with tf.Session() as sess:
    test_data = sess.run(test_data)

# Reindex nodes to satisfy a binary tree structure
train_data = perm_data(train_data, perm)
val_data = perm_data(val_data, perm)
test_data = perm_data(test_data, perm)

print(train_data.shape)
print(val_data.shape)
print(test_data.shape)