def test_dynamic_shapes(self):
  """Exercises `jax2tf.shape_as_value` with a dynamic leading dimension.

  The same function is pushed through jax2tf conversion (plain, jitted and
  wrapped in `tf.function`), forward- and reverse-mode differentiation,
  `jax.vmap` and `jax.mask`; every transformation is expected to work.
  """

  def f(x):
    # Sum along axis 0, scaled by the (possibly symbolic) length of axis 0.
    return jnp.sum(x, axis=0) * jax2tf.shape_as_value(x)[0]

  def convert_poly(fun):
    # Converts `fun` with its single input's leading dimension symbolic ("b").
    return jax2tf.convert(fun, polymorphic_shapes=["(b,)"])

  x = np.arange(3.)

  # sum([0., 1., 2.]) * 3 == 9.
  self.assertAllClose(9., convert_poly(f)(x))
  self.assertAllClose(9., convert_poly(jax.jit(f))(x))
  self.assertAllClose(9., tf.function(convert_poly(f))(x))

  # Forward mode: the tangent is sum(x_tangent) * 3 == 1.8.
  x_tangent = np.array([0.1, 0.2, 0.3])
  jvp_converted = jax2tf.convert(
      lambda primal, tangent: jax.jvp(f, (primal,), (tangent,)),
      polymorphic_shapes=["b", "b"])
  res_primal, res_tangent = jvp_converted(x, x_tangent)
  self.assertAllClose((9., 1.8), (res_primal, res_tangent))

  # Reverse mode: every element contributes with weight 3.
  grad_converted = jax2tf.convert(jax.grad(f), polymorphic_shapes=["b"])
  self.assertAllClose(np.array([3., 3., 3.]), grad_converted(x))

  # vmap over axis 1 must agree with applying f slice-by-slice.
  xv = np.arange(24.).reshape((2, 3, 4))
  res_vmap = jax.vmap(f, in_axes=1)(xv)
  res_iter = jnp.stack([f(xv[:, col, :]) for col in range(xv.shape[1])])
  self.assertAllClose(res_iter, res_vmap)

  # Masking: only the first `b` entries are logical.
  #   b=2 -> sum([0., 1.]) * 2 == 2.;  b=3 -> sum([0., 1., 2.]) * 3 == 9.
  for logical_b, expected in ((2, 2.), (3, 9.)):
    res_mask, _ = jax.mask(f, polymorphic_shapes=["(b,)"])(
        [x], dict(b=logical_b))
    self.assertAllClose(expected, res_mask)
def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
          dtypes, rng, rtol=None, atol=None):
  """Checks that `mask(fun, in_shapes, out_shape)` agrees with `fun`.

  `fun` is shape-checked, then masked and run on randomly generated padded
  arguments. The logical (unpadded) prefix of each output is compared with
  `fun` evaluated directly on the logical prefix of each argument. Finally,
  the jitted masked function must reproduce the eager padded outputs.

  Args:
    fun: the function under test.
    in_shapes: polymorphic shape specs for the inputs (e.g. ``['n', '']``).
    out_shape: polymorphic shape spec (pytree of specs) for the output.
    logical_env: mapping from shape variables to logical sizes
      (e.g. ``{'n': 3}``).
    padded_in_shapes: concrete padded shapes used to generate the arguments.
    dtypes: one dtype per argument, passed to `rng`.
    rng: callable ``rng(shape, dtype)`` producing an argument array.
    rtol: relative tolerance forwarded to `assertAllClose`.
    atol: absolute tolerance forwarded to `assertAllClose`.
  """
  shapecheck(in_shapes, out_shape)(fun)
  masked_fun = mask(fun, in_shapes, out_shape)
  padded_args = [rng(shape, dtype)
                 for shape, dtype in zip(padded_in_shapes, dtypes)]
  padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))

  def logical_slices(specs, padded_shapes):
    # Parse each spec, resolve it against its concrete padded shape, evaluate
    # it in `logical_env`, and build the index selecting the logical prefix.
    specs = map(parse_spec, specs)
    specs = map(finalize_spec, specs, padded_shapes)
    logical_shapes = [eval_poly_shape(s, logical_env) for s in specs]
    return [tuple(map(slice, s)) for s in logical_shapes]

  out_specs, _ = tree_flatten(out_shape)
  out_slices = logical_slices(out_specs, map(np.shape, padded_outs))
  logical_outs = [o[s] for o, s in zip(padded_outs, out_slices)]

  in_slices = logical_slices(in_shapes, padded_in_shapes)
  logical_args = [a[s] for a, s in zip(padded_args, in_slices)]
  logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
  # A TestCase assertion, unlike a bare `assert`, is not stripped by
  # `python -O`, and it reports both trees on failure.
  self.assertEqual(outs_tree, logical_outs_tree)
  self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
                      atol=atol, rtol=rtol)

  # Check that abstract evaluation works
  padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
  self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
                      atol=atol, rtol=rtol)
def test_add(self):
  """Masked `lax.add`: correct logical results and a shape-mismatch error."""
  self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()],
             ['float_', 'float_'], jtu.rand_default(self.rng()))

  addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')

  # Inputs are padded to length 6; with n=3 only the first 3 outputs count.
  lhs = jnp.array([3, 1, 4, 1, 5, 9])
  rhs = jnp.array([2, 6, 5, 3, 5, 8])
  ans = addvecs([lhs, rhs], dict(n=3))
  self.assertAllClose(ans[:3], np.array([5, 7, 9]), check_dtypes=False)

  # Both inputs are declared with shape variable 'n' but padded to different
  # lengths (5 vs 6); this is expected to raise ShapeError.
  with self.assertRaisesRegex(ShapeError, ""):
    addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
def thunk():
  """Applies a masked ``p.bind`` to a length-3 array with logical size n=2.

  ``p`` is a free variable resolved from an enclosing or module scope that is
  not visible here. The result is discarded; NOTE(review): presumably this
  thunk is handed to an exception assertion by its caller — confirm.
  """
  masked_bind = mask(p.bind, ['n'], 'n')
  padded_input = np.arange(3)
  masked_bind([padded_input], {'n': 2})
def run():
  """Runs the basic jraph example.

  The example proceeds in these steps:
    * Builds `jraph.GraphsTuple` instances: a single graph, a graph with
      nested features, an implicitly batched graph, a padded graph and an
      explicitly batched graph.
    * Converts between them with `jraph.batch`, `jraph.unbatch`,
      `jraph.pad_with_graphs` and `jraph.unpad_with_graphs`.
    * Runs an identity `jraph.GraphNetwork` over them, both eagerly and under
      `jax.jit`.

  Every intermediate result is logged; nothing is returned.

  The explicitly batched examples rely on `jax.mask`. If any part of them
  fails, `MASK_BROKEN_MSG` is logged as a warning instead of the exception
  propagating.
  """
  # Creating graph tuples.

  # Creates a GraphsTuple from scratch containing a single graph.
  # The graph has 3 nodes and 2 edges.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  single_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]),
      n_edge=np.asarray([2]),
      nodes=np.ones((3, 4)),
      edges=np.ones((2, 5)),
      globals=np.ones((1, 6)),
      senders=np.array([0, 1]),
      receivers=np.array([2, 2]))
  logging.info("Single graph %r", single_graph)

  # Creates a GraphsTuple from scratch containing a single graph with nested
  # feature vectors.
  # The graph has 3 nodes and 2 edges.
  # The feature vector can be arbitrary nested types of dict, list and tuple,
  # or any other type you registered with jax.tree_util.register_pytree_node.
  nested_graph = jraph.GraphsTuple(
      n_node=np.asarray([3]),
      n_edge=np.asarray([2]),
      nodes={"a": np.ones((3, 4))},
      edges={"b": np.ones((2, 5))},
      globals={"c": np.ones((1, 6))},
      senders=np.array([0, 1]),
      receivers=np.array([2, 2]))
  logging.info("Nested graph %r", nested_graph)

  # Creates a GraphsTuple from scratch containing 2 graphs using an implicit
  # batch dimension.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  # Node indices run across the whole batch: the second graph's only node is
  # node 3, so its single edge goes 3 -> 3.
  implicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([3, 1]),
      n_edge=np.asarray([2, 1]),
      nodes=np.ones((4, 4)),
      edges=np.ones((3, 5)),
      globals=np.ones((2, 6)),
      senders=np.array([0, 1, 3]),
      receivers=np.array([2, 2, 3]))
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates a GraphsTuple from two existing GraphsTuples using an implicit
  # batch dimension.
  # The GraphsTuple will contain three graphs.
  implicitly_batched_graph = jraph.batch(
      [single_graph, implicitly_batched_graph])
  logging.info("Implicitly batched graph %r", implicitly_batched_graph)

  # Creates multiple GraphsTuples from an existing GraphsTuple with an implicit
  # batch dimension.
  graph_1, graph_2, graph_3 = jraph.unbatch(implicitly_batched_graph)
  logging.info("Unbatched graphs %r %r %r", graph_1, graph_2, graph_3)

  # Creates a padded GraphsTuple from an existing GraphsTuple.
  # The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
  # Three graphs are added for the padding:
  # first a dummy graph which contains the padding nodes and edges, and then
  # two empty graphs without nodes or edges to pad out the graphs.
  padded_graph = jraph.pad_with_graphs(
      single_graph, n_node=10, n_edge=5, n_graph=4)
  logging.info("Padded graph %r", padded_graph)

  # Creates a GraphsTuple from an existing padded GraphsTuple.
  # The previously added padding is removed.
  single_graph = jraph.unpad_with_graphs(padded_graph)
  logging.info("Unpadded graph %r", single_graph)

  # Creates a GraphsTuple containing 2 graphs using an explicit batch
  # dimension.
  # An explicit batch dimension requires more memory, but can simplify
  # the definition of functions operating on the graph.
  # Explicitly batched graphs require the GraphNetwork to be transformed
  # by jax.mask followed by jax.vmap.
  # Using an explicit batch requires padding all feature vectors to
  # the maximum size of nodes and edges.
  # The first graph has 3 nodes and 2 edges.
  # The second graph has 1 node and 1 edge.
  # Each node has a 4-dimensional feature vector.
  # Each edge has a 5-dimensional feature vector.
  # The graph itself has a 6-dimensional feature vector.
  # NOTE: the -1 entries in senders/receivers presumably fill the second
  # graph's unused edge slot (it has only 1 real edge out of 2 slots).
  explicitly_batched_graph = jraph.GraphsTuple(
      n_node=np.asarray([[3], [1]]),
      n_edge=np.asarray([[2], [1]]),
      nodes=np.ones((2, 3, 4)),
      edges=np.ones((2, 2, 5)),
      globals=np.ones((2, 1, 6)),
      senders=np.array([[0, 1], [0, -1]]),
      receivers=np.array([[2, 2], [0, -1]]))
  logging.info("Explicitly batched graph %r", explicitly_batched_graph)

  # Running graph propagation steps.
  # First define the update functions for the edges, nodes and globals.
  # In this example we use the identity everywhere.
  # For Graph neural networks, each update function is typically a neural
  # network.
  def update_edge_fn(
      edge_features,
      sender_node_features,
      receiver_node_features,
      globals_):
    """Returns the updated edge features (identity)."""
    del sender_node_features
    del receiver_node_features
    del globals_
    return edge_features

  def update_node_fn(
      node_features,
      aggregated_sender_edge_features,
      aggregated_receiver_edge_features,
      globals_):
    """Returns the updated node features (identity)."""
    del aggregated_sender_edge_features
    del aggregated_receiver_edge_features
    del globals_
    return node_features

  def update_globals_fn(
      aggregated_node_features,
      aggregated_edge_features,
      globals_):
    """Returns the updated global features (identity)."""
    del aggregated_node_features
    del aggregated_edge_features
    return globals_

  # Optionally define custom aggregation functions.
  # In this example we use the defaults (so no need to define them explicitly).
  aggregate_edges_for_nodes_fn = jax.ops.segment_sum
  aggregate_nodes_for_globals_fn = jax.ops.segment_sum
  aggregate_edges_for_globals_fn = jax.ops.segment_sum

  # Optionally define attention logit function and attention reduce function.
  # This can be used for graph attention.
  # The attention function calculates attention weights, and the apply
  # attention function calculates the new edge feature given the weights.
  # We don't use graph attention here, and just pass the defaults.
  attention_logit_fn = None
  attention_reduce_fn = None

  # Creates a new GraphNetwork in its most general form.
  # Most of the arguments have defaults and can be omitted if a feature
  # is not used.
  # There are also predefined GraphNetworks available (see models.py)
  network = jraph.GraphNetwork(
      update_edge_fn=update_edge_fn,
      update_node_fn=update_node_fn,
      update_global_fn=update_globals_fn,
      attention_logit_fn=attention_logit_fn,
      aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
      aggregate_nodes_for_globals_fn=aggregate_nodes_for_globals_fn,
      aggregate_edges_for_globals_fn=aggregate_edges_for_globals_fn,
      attention_reduce_fn=attention_reduce_fn)

  # Runs graph propagation on (implicitly batched) graphs.
  updated_graph = network(single_graph)
  logging.info("Updated graph from single graph %r", updated_graph)

  updated_graph = network(nested_graph)
  # Log the network output, not the unchanged input graph.
  logging.info("Updated graph from nested graph %r", updated_graph)

  updated_graph = network(implicitly_batched_graph)
  logging.info("Updated graph from implicitly batched graph %r", updated_graph)

  updated_graph = network(padded_graph)
  logging.info("Updated graph from padded graph %r", updated_graph)

  # Runs graph propagation on an explicitly batched graph.
  # WARNING: This code relies on an undocumented JAX feature (jax.mask) which
  # might stop working at any time!
  graph_shape = jraph.GraphsTuple(
      n_node="(g)",
      n_edge="(g)",
      nodes="(n, {})".format(explicitly_batched_graph.nodes.shape[-1]),
      edges="(e, {})".format(explicitly_batched_graph.edges.shape[-1]),
      globals="(g, {})".format(explicitly_batched_graph.globals.shape[-1]),
      senders="(e)",
      receivers="(e)")
  batch_size = explicitly_batched_graph.globals.shape[0]
  # Per batch element: one graph, and the true node/edge counts obtained by
  # summing n_node / n_edge over their last axis.
  logical_env = {
      "g": jnp.ones(batch_size, dtype=jnp.int32),
      "n": jnp.sum(explicitly_batched_graph.n_node, axis=-1),
      "e": jnp.sum(explicitly_batched_graph.n_edge, axis=-1)
  }

  # Stays None if building the masked function fails (e.g. jax.mask is not
  # available), so the JIT example below can skip it explicitly instead of
  # relying on a swallowed NameError.
  propagation_fn = None
  try:
    propagation_fn = jax.vmap(
        jax.mask(network, in_shapes=[graph_shape], out_shape=graph_shape))
    updated_graph = propagation_fn([explicitly_batched_graph], logical_env)
    logging.info("Updated graph from explicitly batched graph %r",
                 updated_graph)
  except Exception:  # pylint: disable=broad-except
    logging.warning(MASK_BROKEN_MSG)

  # JIT-compile graph propagation.
  # Use padded graphs to avoid re-compilation at every step!
  jitted_network = jax.jit(network)
  updated_graph = jitted_network(padded_graph)
  logging.info("(JIT) updated graph from padded graph %r", updated_graph)

  # Or use an explicit batch dimension.
  if propagation_fn is None:
    logging.warning(MASK_BROKEN_MSG)
  else:
    try:
      jitted_propagation_fn = jax.jit(propagation_fn)
      updated_graph = jitted_propagation_fn([explicitly_batched_graph],
                                            logical_env)
      logging.info("(JIT) Updated graph from explicitly batched graph %r",
                   updated_graph)
    except Exception:  # pylint: disable=broad-except
      logging.warning(MASK_BROKEN_MSG)

  logging.info("basic.py complete!")