def _init(rng):
  """Returns freshly initialized model parameters.

  Runs ``model_def.init`` on an all-zeros example padded according to
  ``padding_config``. Both of those names are free variables resolved from
  the enclosing scope, not parameters of this function.

  Args:
    rng: PRNG key handed to ``model_def.init``.

  Returns:
    The second element of the pair returned by ``model_def.init`` (the
    initial parameters); the first element is discarded.
  """
  # Provide a dummy stochastic scope for any random perturbations the model
  # performs during init. This fixed key is unrelated to `rng`, which goes to
  # `model_def.init` directly.
  dummy_scope_key = jax.random.PRNGKey(0)
  with flax.nn.stochastic(dummy_scope_key):
    zero_example = graph_bundle.zeros_like_padded_example(padding_config)
    # Convert every leaf to a JAX array before tracing through the model.
    zero_example = jax.tree_map(jnp.array, zero_example)
    _, params = model_def.init(rng, zero_example)
  return params
def test_variants_from_edges(self):
  """Checks `variants_from_edges` on a small hand-built edge operator.

  A zero example padded to 5 nodes and 8 edges is overwritten with a 4-node
  graph whose first six edge entries have value 1 and whose last two have
  value 0. Only the non-padded 4x4 block of the resulting weights is compared
  against hand-computed expectations.
  """
  padding = graph_bundle.PaddingConfig(
      static_max_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=5, num_input_tagged_nodes=0),
      max_initial_transitions=0,
      max_in_tagged_transitions=0,
      max_edges=8)
  padded = graph_bundle.zeros_like_padded_example(padding)

  # Each row: (input index, output row, output col, value). Presumably the
  # input index is the edge type and the output pair is (source, dest) --
  # TODO confirm against SparseCoordOperator's edge encoding.
  edge_rows = [
      (0, 1, 2, 1),
      (0, 2, 3, 1),
      (0, 3, 0, 1),
      (1, 2, 0, 1),
      (1, 0, 2, 1),
      (2, 0, 3, 1),
      (0, 0, 0, 0),
      (0, 0, 0, 0),
  ]
  edge_operator = sparse_operator.SparseCoordOperator(
      input_indices=jnp.array([[kind] for kind, _, _, _ in edge_rows]),
      output_indices=jnp.array([[src, dst] for _, src, dst, _ in edge_rows]),
      values=jnp.array([value for _, _, _, value in edge_rows]))

  example = dataclasses.replace(
      padded,
      graph_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=4, num_input_tagged_nodes=0),
      edges=edge_operator)

  weights = edge_supervision_models.variants_from_edges(
      example,
      automaton_builder.EncodedGraphMetadata(
          num_nodes=5, num_input_tagged_nodes=0),
      variant_edge_type_indices=[2, 0],
      num_edge_types=3)

  expected = np.array(
      [
          [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0]],
          [[1, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]],
          [[1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
          [[0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0]],
      ],
      np.float32)

  # Rows/columns beyond the 4 real nodes are padding; ignore them.
  real_block = weights[:4, :4]
  np.testing.assert_allclose(real_block, expected)
def zeros_like_padded_example(config):
  """Creates a VarMisuseExample in which every array is zero-filled.

  Useful for initializing model parameters or for writing tests.

  Args:
    config: Padding configuration. ``config.bundle_padding`` sizes the graph
      bundle and the per-node masks; ``config.max_tokens`` sizes both the
      token operator and the unique-candidate operator.

  Returns:
    A zero-filled VarMisuseExample of the configured size, with
    ``bug_node_index`` set to -1 and ``repair_id`` set to 0.
  """
  token_count = config.max_tokens
  node_count = config.bundle_padding.static_max_metadata.num_nodes

  def _zero_operator(value_dtype):
    # In this function both operators use int32 (token_count, 1) index arrays;
    # they differ only in the dtype of their values.
    return sparse_operator.SparseCoordOperator(
        input_indices=np.zeros(shape=(token_count, 1), dtype=np.int32),
        output_indices=np.zeros(shape=(token_count, 1), dtype=np.int32),
        values=np.zeros(shape=(token_count,), dtype=value_dtype))

  graph_with_tokens = GraphBundleWithTokens(
      bundle=graph_bundle.zeros_like_padded_example(config.bundle_padding),
      tokens=_zero_operator(np.int32))

  return VarMisuseExample(
      input_graph=graph_with_tokens,
      bug_node_index=-1,
      repair_node_mask=np.zeros(shape=(node_count,), dtype=np.float32),
      candidate_node_mask=np.zeros(shape=(node_count,), dtype=np.float32),
      unique_candidate_operator=_zero_operator(np.float32),
      repair_id=0)
def _make_example(self):
  """Builds a small 4-node test example padded to 5 nodes and 8 edges.

  Starts from an all-zeros padded example and replaces its graph metadata
  and edge operator. The first six edge entries have value 1; the last two
  have value 0.

  Returns:
    The example produced by ``dataclasses.replace`` on the zero example.
  """
  zero_example = graph_bundle.zeros_like_padded_example(
      graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=5, num_input_tagged_nodes=0),
          max_initial_transitions=0,
          max_in_tagged_transitions=0,
          max_edges=8))

  # Each row: (input index, output row, output col, value). Presumably the
  # input index is the edge type and the output pair is (source, dest) --
  # TODO confirm against SparseCoordOperator's edge encoding.
  rows = [
      (0, 1, 2, 1),
      (0, 2, 3, 1),
      (0, 2, 2, 1),
      (0, 3, 0, 1),
      (1, 0, 2, 1),
      (2, 0, 3, 1),
      (2, 0, 0, 0),
      (0, 0, 0, 0),
  ]
  input_indices = jnp.array([[r[0]] for r in rows])
  output_indices = jnp.array([[r[1], r[2]] for r in rows])
  values = jnp.array([r[3] for r in rows])

  return dataclasses.replace(
      zero_example,
      graph_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=4, num_input_tagged_nodes=0),
      edges=sparse_operator.SparseCoordOperator(
          input_indices=input_indices,
          output_indices=output_indices,
          values=values))
def test_component_shapes(self,
                          component,
                          embed_edges,
                          expected_dims,
                          extra_config=None):
  """Checks the output shapes of one end-to-end stack component.

  Parses the gin ``CONFIG`` (plus ``extra_config`` when given), runs ``init``
  of ``end_to_end_stack.ALL_COMPONENTS[component]`` on zero-filled
  placeholder inputs for a 16-node graph, and asserts the node and edge
  outputs have the expected shapes.

  Args:
    component: Key into ``end_to_end_stack.ALL_COMPONENTS``.
    embed_edges: Passed as ``SharedGraphContext.edges_are_embedded``.
    expected_dims: Mapping with "node" and "edge" keys giving the expected
      last dimension of the node and edge outputs respectively.
    extra_config: Optional extra gin config string parsed after ``CONFIG``.
  """
  num_nodes = 16

  gin.clear_config()
  gin.parse_config(CONFIG)
  if extra_config:
    gin.parse_config(extra_config)

  # Placeholder graph inputs; every array inside is zero-filled.
  bundle = graph_bundle.zeros_like_padded_example(
      graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=num_nodes, num_input_tagged_nodes=32),
          max_initial_transitions=11,
          max_in_tagged_transitions=12,
          max_edges=13))
  static_metadata = automaton_builder.EncodedGraphMetadata(
      num_nodes=num_nodes, num_input_tagged_nodes=32)
  builder = automaton_builder.AutomatonBuilder({
      graph_types.NodeType("node"):
          graph_types.NodeSchema(
              in_edges=[graph_types.InEdgeType("in")],
              # Outgoing edges are typed `OutEdgeType`; the original used
              # `InEdgeType` here (harmless at runtime since both are
              # `NewType`s of str, but wrong for type checking).
              out_edges=[graph_types.OutEdgeType("out")])
  })

  # Run the computation with placeholder inputs.
  (node_out, edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
      jax.random.PRNGKey(0),
      graph_context=end_to_end_stack.SharedGraphContext(
          bundle=bundle,
          static_metadata=static_metadata,
          edge_types_to_indices={"foo": 0},
          builder=builder,
          edges_are_embedded=embed_edges),
      node_embeddings=jnp.zeros((num_nodes, NODE_DIM)),
      edge_embeddings=jnp.zeros((num_nodes, num_nodes, EDGE_DIM)))

  self.assertEqual(node_out.shape, (num_nodes, expected_dims["node"]))
  self.assertEqual(edge_out.shape,
                   (num_nodes, num_nodes, expected_dims["edge"]))
def test_zeros_like_padded_example(self):
  """Checks that the zeros example has the same layout as a padded one.

  A real example converted from the Python source "pass" is padded with
  ``pad_example``. It is then compared leaf by leaf against the output of
  ``zeros_like_padded_example`` for the same config. Each pair of leaves must
  agree in shape and dtype.
  """
  parsed = gast.parse("pass")
  py_graph, _ = py_ast_graphs.py_ast_to_graph(parsed)
  converted = graph_bundle.convert_graph_with_edges(
      py_graph, [], builder=py_ast_graphs.BUILDER)

  config = graph_bundle.PaddingConfig(
      static_max_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=16, num_input_tagged_nodes=34),
      max_initial_transitions=64,
      max_in_tagged_transitions=128,
      max_edges=4)

  reference = graph_bundle.pad_example(converted, config)
  zeros = graph_bundle.zeros_like_padded_example(config)

  def _assert_same_layout(generated_leaf, reference_leaf):
    generated_arr = np.asarray(generated_leaf)
    reference_arr = np.asarray(reference_leaf)
    self.assertEqual(generated_arr.shape, reference_arr.shape)
    self.assertEqual(generated_arr.dtype, reference_arr.dtype)

  # NOTE(review): `jax.tree_multimap` was removed in newer JAX releases,
  # where `jax.tree_util.tree_map` accepts multiple trees. It is kept here
  # for compatibility with the legacy JAX this file appears to target.
  jax.tree_multimap(_assert_same_layout, zeros, reference)
def train(
    runner,
    dataset_paths=gin.REQUIRED,
    prefetch=4,
    batch_size_per_device=gin.REQUIRED,
    validation_example_count=gin.REQUIRED,
    seed=None,
):
  """Train the maze automaton.

  Args:
    runner: Helper object that runs the experiment.
    dataset_paths: Dictionary of dataset paths, with keys:
      - "train_dataset": Path to training dataset files.
      - "eval_dataset": Path to validation dataset files.
    prefetch: Maximum number of examples to prefetch in a background thread.
      A falsy value disables prefetching.
    batch_size_per_device: Batch size for each device.
    validation_example_count: How many examples to use when computing
      validation metrics.
    seed: Optional integer seed for parameter initialization, allowing
      reproducible runs. If None (the default), the seed is derived from the
      current wall-clock time in milliseconds, as before.

  Returns:
    Whatever ``runner.training_loop`` returns -- presumably the optimizer at
    the end of training (for interactive debugging); confirm against runner.
  """
  num_devices = jax.local_device_count()
  # List the local devices so the logged count and listing describe the same
  # set (jax.devices() would include other hosts' devices on multi-host).
  logging.info("Found %d devices: %s", num_devices, jax.local_devices())

  with contextlib.ExitStack() as exit_stack:
    logging.info("Setting up datasets...")
    raw_train_iterator = runner.build_sampling_iterator(
        dataset_paths["train_dataset"], example_type=graph_bundle.GraphBundle)

    raw_valid_iterator_factory = runner.build_one_pass_iterator_factory(
        dataset_paths["eval_dataset"],
        example_type=graph_bundle.GraphBundle,
        truncate_at=validation_example_count)

    # Add the example id into the example itself, so that we can use it to
    # randomly choose a goal.
    def reify_id(it):
      for item in it:
        yield dataclasses.replace(item, example=(item.example, item.example_id))

    def reify_id_and_batch(it):
      # Batch to (num_devices, batch_size_per_device); a final partial batch
      # is padded with zeros rather than dropped.
      return data_loading.batch(
          reify_id(it), (num_devices, batch_size_per_device),
          remainder_behavior=data_loading.BatchRemainderBehavior.PAD_ZERO)

    train_iterator = reify_id_and_batch(raw_train_iterator)
    valid_iterator_factory = (
        lambda: reify_id_and_batch(raw_valid_iterator_factory()))

    if prefetch:
      # Registered on the exit stack so the prefetcher is cleaned up when
      # training returns or raises.
      train_iterator = exit_stack.enter_context(
          data_loading.ThreadedPrefetcher(train_iterator, prefetch))

    logging.info("Setting up model...")
    padding_config = maze_task.PADDING_CONFIG
    model_def = automaton_layer.FiniteStateGraphAutomaton.partial(
        static_metadata=padding_config.static_max_metadata,
        builder=maze_task.BUILDER)

    # Initialize parameters randomly.
    if seed is None:
      seed = int(time.time() * 1000)
    _, initial_params = model_def.init(
        jax.random.PRNGKey(seed),
        graph_bundle.zeros_like_padded_example(padding_config).automaton_graph,
        dynamic_metadata=padding_config.static_max_metadata)

    model = flax.nn.Model(model_def, initial_params)
    optimizer = flax.optim.Adam().create(model)

    extra_artifacts = {
        "builder.pickle": maze_task.BUILDER,
    }

    return runner.training_loop(
        optimizer=optimizer,
        train_iterator=train_iterator,
        loss_fn=loss_fn,
        validation_fn=train_util.build_averaging_validator(
            loss_fn, valid_iterator_factory),
        extra_artifacts=extra_artifacts)