def test_update_points(self):
    """Check ``update_points`` on a fully connected 3-node graph.

    Node features are [[1], [2], [3]], there are six edge rows (5..10) and
    the global attribute is [[21]]. The supplied update function reduces
    each row of the tensor that ``update_points`` hands it to a length-1
    vector holding that row's sum. The expected values therefore depend on
    how ``DynamicAdjacentMatrix`` assembles that per-node input, which is
    not visible here.
    """
    with self.test_session() as sess:
        # Adjacency with ones everywhere except the diagonal: every node is
        # connected to every other node, no self-loops.
        adj_mt = tf.convert_to_tensor(
            np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=np.int32))
        points_data = np.array([[1], [2], [3]], dtype=np.float32)
        edges_data = tf.constant(
            np.array([[5], [6], [7], [8], [9], [10]], dtype=np.float32))
        A = DynamicAdjacentMatrix(adj_mt=adj_mt, points_data=points_data,
                                  edges_data=edges_data)
        A.global_attr = tf.convert_to_tensor(
            np.array([[21]], dtype=np.float32))

        def update_points(x):
            # Map each row of x to a length-1 tensor containing its sum.
            def fn(row):
                return tf.reshape(tf.reduce_sum(row), [1])
            return tf.map_fn(fn, elems=x)

        A.update_points(update_points)
        # Initialize any variables the graph object may have created
        # (a no-op if there are none).
        sess.run(tf.global_variables_initializer())
        updated_points = A.points_data.eval()
        self.assertAllClose(updated_points, [[35.5], [38.0], [40.5]],
                            atol=1e-5)
def test_update_edges_independent(self):
    """Check ``update_edges_independent`` with a row-doubling update.

    Each edge row (5..10) is passed through a function that multiplies it
    by two, so the stored edge data should equal the input edges times two.
    """
    with self.test_session() as sess:
        adj_mt = tf.convert_to_tensor(
            np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=np.int32))
        points_data = np.array([[1], [2], [3]], dtype=np.float32)
        # Keep a concrete ndarray for the expected value; edges_data itself
        # is a symbolic graph-mode Tensor.
        edges_np = np.array([[5], [6], [7], [8], [9], [10]],
                            dtype=np.float32)
        edges_data = tf.constant(edges_np)
        A = DynamicAdjacentMatrix(adj_mt=adj_mt, points_data=points_data,
                                  edges_data=edges_data)
        A.global_attr = tf.convert_to_tensor(
            np.array([[21]], dtype=np.float32))

        def update_edges(x):
            # Double every row of x independently.
            def fn(row):
                return row * 2
            return tf.map_fn(fn, elems=x)

        A.update_edges_independent(update_edges)
        # Initialize any variables the graph object may have created
        # (a no-op if there are none).
        sess.run(tf.global_variables_initializer())
        # Compare against an ndarray rather than the symbolic Tensor
        # `edges_data * 2`. Older assertAllClose implementations convert the
        # expected value with np.array() and cannot evaluate graph tensors.
        self.assertAllClose(A.edges_data.eval(), edges_np * 2, atol=1e-5)
def test_update_globals(self):
    """``update_global`` with a per-row sum update should give [[30.5]].

    The graph is fully connected over three nodes, with node features
    [[1], [2], [3]], six edge rows (5..10) and an initial global attribute
    of [[21]]. The expected value depends on how ``DynamicAdjacentMatrix``
    builds the input it passes to the update function, which is not
    visible here.
    """
    with self.test_session():
        connectivity = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]],
                                dtype=np.int32)
        node_feats = np.array([[1], [2], [3]], dtype=np.float32)
        edge_feats = np.array([[5], [6], [7], [8], [9], [10]],
                              dtype=np.float32)
        graph = DynamicAdjacentMatrix(
            adj_mt=tf.convert_to_tensor(connectivity),
            points_data=node_feats,
            edges_data=tf.constant(edge_feats),
        )
        graph.global_attr = tf.convert_to_tensor(
            np.array([[21]], dtype=np.float32))

        def row_sum(row):
            # Collapse one row to a length-1 tensor holding its sum.
            return tf.reshape(tf.reduce_sum(row), [1])

        def update_global(batch):
            return tf.map_fn(row_sum, elems=batch)

        graph.update_global(update_global)
        result = graph.global_attr.eval()
        self.assertAllClose(result, [[30.5]], atol=1e-5)