def get_flow_inference_function(checkpoint, height, width):
    """Restore a RAFT-based SMURF network and return its inference callable.

    The Keras session is cleared and `raft_model_parameters.max_rec_iters` is
    bound to 32 through gin before the network is built and restored.

    Args:
      checkpoint: Path of the checkpoint to restore.
      height: Image height used for inference.
      width: Image width used for inference.

    Returns:
      A `functools.partial` of `SMURFNet.infer_no_tf_function` with the input
      size fixed, flow resized to image resolution, and occlusion and
      backward-flow inference turned off.
    """
    tf.keras.backend.clear_session()
    gin.parse_config('raft_model_parameters.max_rec_iters = 32')

    network = SMURFNet(
        checkpoint,
        flow_architecture='raft',
        feature_architecture='raft',
    )
    network.restore()

    inference_kwargs = dict(
        input_height=height,
        input_width=width,
        resize_flow_to_img_res=True,
        infer_occlusion=False,
        infer_bw=False,
    )
    return functools.partial(network.infer_no_tf_function, **inference_kwargs)
def test_with_mock_training(self):
    """Trains a mock TPU model with an async export hook and checks the exports.

    After training, the model directory and the export directory must both be
    non-empty, every exported model subdirectory must be non-empty, and an
    `ExportedSavedModelPredictor` must restore successfully from the exports.
    """
    model_dir = self.create_tempdir().full_path
    t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
        device_type='tpu',
        use_avg_model_params=True)
    input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)

    export_dir = os.path.join(model_dir, _EXPORT_DIR)
    export_hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
        export_dir=export_dir,
        create_export_fn=async_export_hook_builder.default_create_export_fn)

    # One TPU iteration per loop and a checkpoint on every step.
    for binding in ('tf.contrib.tpu.TPUConfig.iterations_per_loop=1',
                    'tf.estimator.RunConfig.save_checkpoints_steps=1'):
        gin.parse_config(binding)

    # We optimize our network.
    train_eval.train_eval_model(
        t2r_model=t2r_model,
        input_generator_train=input_generator,
        train_hook_builders=[export_hook_builder],
        model_dir=model_dir,
        max_train_steps=_MAX_STEPS)

    self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
    self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
    for exported_name in tf.io.gfile.listdir(export_dir):
        exported_path = os.path.join(export_dir, exported_name)
        self.assertNotEmpty(tf.io.gfile.listdir(exported_path))

    predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
        export_dir=export_dir)
    self.assertTrue(predictor.restore())
def run(agent, game, num_steps, root_dir, restore_ckpt,
        use_legacy_checkpoint=False):
    """Main entrypoint for running an agent and generating visualizations.

    Args:
      agent: str, agent type to use.
      game: str, Atari 2600 game to run.
      num_steps: int, number of steps to play the game for.
      root_dir: str, root directory where files will be stored.
      restore_ckpt: str, path of the checkpoint to reload.
      use_legacy_checkpoint: bool, whether to restore from a legacy
        (pre-Keras) checkpoint.
    """
    tf.compat.v1.reset_default_graph()
    gin_bindings = """
    atari_lib.create_atari_environment.game_name = '{}'
    WrappedReplayBuffer.replay_capacity = 300
    """.format(game)
    output_dir = os.path.join(root_dir, 'agent_viz', game, agent)
    gin.parse_config(gin_bindings)

    runner = create_runner(output_dir, restore_ckpt, agent,
                           use_legacy_checkpoint)
    image_dir = os.path.join(output_dir, 'images')
    runner.visualize(image_dir, num_global_steps=num_steps)
def test_train_mnist(self):
    """Train an MNIST MLP (almost) fully to compare with other implementations.

    Cross-entropy loss and accuracy are evaluated every 50 steps; their values
    are visible in the test log. The loop must end at step 1000.
    """
    gin.parse_config([
        'batch_fn.batch_size_per_device = 256',
        'batch_fn.eval_batch_size = 256',
    ])

    layers = [
        tl.Flatten(),
        tl.Dense(512), tl.Relu(),
        tl.Dense(512), tl.Relu(),
        tl.Dense(10), tl.LogSoftmax(),
    ]
    model = tl.Serial(*layers)

    def every_50_steps(step_n):
        return step_n % 50 == 0

    train_task = training.TrainTask(
        itertools.cycle(_mnist_dataset().train_stream(1)),
        tl.CrossEntropyLoss(),
        adafactor.Adafactor(.02))
    eval_task = training.EvalTask(
        itertools.cycle(_mnist_dataset().eval_stream(1)),
        [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
        names=['CrossEntropyLoss', 'AccuracyScalar'],
        eval_at=every_50_steps,
        eval_N=10)

    loop = training.Loop(model, train_task, eval_task=eval_task)
    loop.run(n_steps=1000)
    self.assertEqual(loop.current_step(), 1000)
def eval(self,
         mixture_or_task_name,
         checkpoint_steps=None,
         summary_dir=None,
         split="validation"):
    """Evaluate the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, name of the Mixture or Task to evaluate on.
        Must be registered in the global `TaskRegistry` or `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. With an int or a list of
        ints, evaluation runs on the checkpoints in `model_dir` whose global
        steps are closest to the requested ones. With None, eval runs
        continuously, waiting for new checkpoints. With -1, the latest
        checkpoint in the model directory is used.
      summary_dir: str, path for the TensorBoard eval summaries. If None, use
        model_dir/eval_{split}.
      split: str, the mixture/task split to evaluate on.
    """
    if checkpoint_steps == -1:
        checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    mesh_transformer = t5.models.mesh_transformer
    vocabulary = mesh_transformer.get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name)

    # Parse the operative config stored in the model dir, then re-apply this
    # model's own gin bindings on top of it.
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        gin.parse_config(self._gin_bindings)

    utils.eval_model(self.estimator(vocabulary), vocabulary,
                     self._sequence_length, self.batch_size, split,
                     self._model_dir, dataset_fn, summary_dir,
                     checkpoint_steps)
def _worker(self, root_dir, parameters, device_queue):
    """Run one training job on a device borrowed from `device_queue`.

    The device is always put back on the queue, whether training succeeds or
    fails, so a failing job cannot permanently remove a device from the pool.

    Args:
      root_dir: Directory passed to `train_eval`.
      parameters: Mapping of gin parameter names to values; every item is
        parsed as a `name=value` gin binding.
      device_queue: Queue of device ids. One id is taken for the duration of
        the job and returned afterwards.

    Raises:
      Any exception raised while setting up or running the job is logged
      together with its traceback and then re-raised unchanged.
    """
    import traceback  # pylint: disable=import-outside-toplevel

    try:
        # Sleep for random seconds to avoid crowded launching.
        time.sleep(random.uniform(0, 3))
        device = device_queue.get()
        try:
            if self._conf.use_gpu:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

            from alf.utils.common import set_per_process_memory_growth
            set_per_process_memory_growth()

            logging.set_verbosity(logging.INFO)
            logging.info("parameters %s" % parameters)
            with gin.unlock_config():
                gin.parse_config(
                    ['%s=%s' % (k, v) for k, v in parameters.items()])
            train_eval(root_dir)
        finally:
            # Return the device even when the job fails; otherwise workers
            # waiting on the queue could block forever.
            device_queue.put(device)
    except Exception:
        logging.info(traceback.format_exc())
        raise
def parse_gin_config(self, ckpt):
    """Parse the model's operative gin config, then apply streaming overrides.

    The latest operative config is looked up in the checkpoint's directory.
    The overrides set `F0PowerPreprocessor.time_steps` to 1, set the
    Harmonic and FilteredNoise sample counts to the trained number of samples
    per frame, and redefine `ProcessorGroup.dag` without a reverb processor.
    """
    with gin.unlock_config():
        checkpoint_dir = os.path.dirname(ckpt)
        config_path = train_util.get_latest_operative_config(checkpoint_dir)
        print(f'Parsing from operative_config {config_path}')
        gin.parse_config_file(config_path, skip_unknown=True)

        # Samples per frame in the trained configuration.
        frames = gin.query_parameter('F0PowerPreprocessor.time_steps')
        total_samples = gin.query_parameter('Harmonic.n_samples')
        frame_size = int(total_samples / frames)

        # Remove reverb processor.
        dag_binding = """ProcessorGroup.dag = [
          (@synths.Harmonic(),
            ['amps', 'harmonic_distribution', 'f0_hz']),
          (@synths.FilteredNoise(),
            ['noise_magnitudes']),
          (@processors.Add(),
            ['filtered_noise/signal', 'harmonic/signal']),
        ]"""

        gin.parse_config([
            'F0PowerPreprocessor.time_steps=1',
            f'Harmonic.n_samples={frame_size}',
            f'FilteredNoise.n_samples={frame_size}',
            dag_binding,
        ])
def load_model(instrument_model, audio_length):
    """Locate a solo-instrument checkpoint and configure gin for `audio_length`.

    Args:
      instrument_model: Instrument name. It is lower-cased to form the folder
        name `solo_<name>_ckpt` under `CKPT_DIR`.
      audio_length: Desired number of audio samples. It is rounded down to a
        whole number of frames using the trained model's hop size.

    Returns:
      Tuple `(ckpt, time_steps, n_samples)`: the checkpoint path prefix
      (`.../model.ckpt-<iter>`), the number of frames, and the number of audio
      samples (`time_steps * hop_size`).

    Raises:
      FileNotFoundError: if the model folder contains no `model.ckpt` files.
    """
    # Build checkpoint path.
    # Assumes only one checkpoint in the folder, 'model.ckpt-[iter]'.
    model_dir = os.path.join(CKPT_DIR, "solo_%s_ckpt" % instrument_model.lower())
    ckpt_files = [
        f for f in tf.gfile.ListDirectory(model_dir) if "model.ckpt" in f
    ]
    if not ckpt_files:
        raise FileNotFoundError(
            "No 'model.ckpt' files found in %s" % model_dir)
    # Keep the first two dot-separated parts,
    # e.g. 'model.ckpt-100.index' -> 'model.ckpt-100'.
    ckpt_name = ".".join(ckpt_files[0].split(".")[:2])
    ckpt = os.path.join(model_dir, ckpt_name)

    # Parse gin config.
    with gin.unlock_config():
        gin_file = os.path.join(model_dir, "operative_config-0.gin")
        gin.parse_config_file(gin_file, skip_unknown=True)

    # Keep the trained samples-per-frame ratio (hop size) while changing length.
    time_steps_train = gin.query_parameter("DefaultPreprocessor.time_steps")
    n_samples_train = gin.query_parameter("Additive.n_samples")
    hop_size = int(n_samples_train / time_steps_train)

    time_steps = int(audio_length / hop_size)
    n_samples = time_steps * hop_size

    gin_params = [
        "Additive.n_samples = {}".format(n_samples),
        "FilteredNoise.n_samples = {}".format(n_samples),
        "DefaultPreprocessor.time_steps = {}".format(time_steps),
    ]
    with gin.unlock_config():
        gin.parse_config(gin_params)

    return ckpt, time_steps, n_samples
def configure_gin(self, ckpt):
    """Parse the model's operative config, then apply streaming overrides.

    The overrides set `F0PowerPreprocessor.time_steps` to 1, set the Harmonic
    and FilteredNoise sample counts to the trained samples per frame, and
    redefine `ProcessorGroup.dag` without a reverb processor.
    """
    parse_operative_config(ckpt)

    # Samples per frame in the trained configuration.
    frames = gin.query_parameter('F0PowerPreprocessor.time_steps')
    total_samples = gin.query_parameter('Harmonic.n_samples')
    frame_size = int(total_samples / frames)

    # Remove reverb processor.
    dag_binding = """ProcessorGroup.dag = [
      (@synths.Harmonic(),
        ['amps', 'harmonic_distribution', 'f0_hz']),
      (@synths.FilteredNoise(),
        ['noise_magnitudes']),
      (@processors.Add(),
        ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

    overrides = [
        'F0PowerPreprocessor.time_steps = 1',
        f'Harmonic.n_samples = {frame_size}',
        f'FilteredNoise.n_samples = {frame_size}',
        dag_binding,
    ]
    with gin.unlock_config():
        gin.parse_config(overrides)
def test_build_layer(self, kwarg_modules):
    """Tests that the layer builds and produces outputs of the correct shape.

    Args:
      kwarg_modules: Use `self.gin_config_kwarg_modules` when True, otherwise
        `self.gin_config_dag_modules`.
    """
    if kwarg_modules:
        gin_config = self.gin_config_kwarg_modules
    else:
        gin_config = self.gin_config_dag_modules
    with gin.unlock_config():
        gin.clear_config()
        gin.parse_config(gin_config)

    dag_layer = ConfigurableDAGLayer()
    outputs = dag_layer(self.inputs)
    self.assertIsInstance(outputs, dict)

    bottleneck = outputs['bottleneck']['z_bottleneck']
    reconstruction = outputs['decoder']['reconstruction']
    reconstruction_out = outputs['out']['reconstruction']

    # Confirm that the layer generates correctly sized tensors.
    input_shape = self.x.shape
    self.assertEqual(outputs['test_data'].shape, input_shape)
    self.assertEqual(outputs['inputs']['test_data'].shape, input_shape)
    self.assertEqual(reconstruction.shape, input_shape)
    self.assertEqual(bottleneck.shape[-1], self.z_dims)
    self.assertAllClose(reconstruction, reconstruction_out)

    # Confirm that variables are inherited by DAGLayer.
    self.assertLen(dag_layer.trainable_variables, 6)  # 3 weights, 3 biases.
def test_l1_attack(self, attack_name, random_start):
    """Checks the L1 attack's perturbation budget and sparsity.

    Per example, the L1 norm of the perturbation must not exceed `epsilon`
    (plus 1e-5 tolerance), and at most
    `num_features * num_iter * percentile / 100` entries may change.
    """
    num_iter = 4
    step_size = 1.0
    epsilon = 2.5
    percentile = 99
    gin.parse_config([
        f"attacks.l1_config.num_iter = {num_iter}",
        f"attacks.l1_config.step_size = {step_size}",
        f"attacks.l1_config.epsilon = {epsilon}",
        f"attacks.l1_config.percentile = {percentile}",
        "attacks.union_config.restart = 5",
    ])

    x = tf.random.uniform(shape=self.batched_input_shape)
    y = tf.random.categorical(
        tf.zeros([self.batch_size, self.num_classes]), 1)
    attack = attacks.construct_attack(attack_name)
    adv_x = attack.attack(tf.constant(x), tf.constant(y), self.model,
                          self.loss_fn, random_start=random_start)

    flat_delta = tf.reshape(adv_x - x, (self.batch_size, -1)).numpy()
    l1_norms = np.linalg.norm(flat_delta, ord=1, axis=-1)
    self.assertAllLessEqual(l1_norms, epsilon + 1e-5)

    num_changed = np.count_nonzero(flat_delta, axis=-1)
    max_changed = flat_delta.shape[1] * num_iter * percentile / 100
    self.assertAllLessEqual(num_changed, max_changed)
def test_train(self):
    """Trains for 2 one-step epochs on mock MNIST and inspects the checkpoint.

    The checkpoint `ckpt-2` must exist and contain variables whose names
    start with "model" and with "selector".
    """
    data_bindings = [
        "data.preprocess_image.height = 28",
        "data.preprocess_image.width = 28",
        "data.preprocess_image.num_channels = 1",
        "data.get_test_dataset.batch_size = 1",
        "data.get_test_dataset.dataset = 'mnist'",
        "data.get_training_dataset.batch_size = 1",
        "data.get_training_dataset.dataset = 'mnist'",
        "data.get_training_dataset.shuffle_buffer_size = 1",
        "data.get_validation_dataset.batch_size = 1",
        "data.get_validation_dataset.dataset = 'mnist'",
        "data.get_validation_dataset.split = '2'",
    ]
    model_bindings = [
        "resnet.build_resnet_v1.input_shape = (28, 28, 1)",
        "resnet.build_resnet_v1.depth = 8",
    ]
    selector_bindings = [
        "selectors.construct_representation_selector.selection_strategy = 'multiweight'",
        "selectors.construct_representation_selector.sample_freq = 1",
        "selectors.construct_representation_selector.update_freq = 1",
    ]
    trainer_bindings = [
        "trainer.train.epochs = 2",
        "trainer.train.steps_per_epoch = 1",
        "trainer.train.representation_list = [('identity', 'l2'), ('dct', 'l2')]",
    ]
    gin.parse_config(
        data_bindings + model_bindings + selector_bindings + trainer_bindings)

    with tfds.testing.mock_data(num_examples=10):
        trainer.train(self.ckpt_dir.full_path, self.summary_dir.full_path)

    ckpt_path = os.path.join(self.ckpt_dir, "ckpt-2")
    self.assertTrue(tf.io.gfile.exists(ckpt_path + ".index"))

    variable_names = [name for name, _ in tf.train.list_variables(ckpt_path)]
    for prefix in ("model", "selector"):
        self.assertTrue(any(name.startswith(prefix) for name in variable_names))
def configure_gin(self, ckpt):
    """Parse the model's operative config, then apply streaming overrides.

    The number of trained frames is read from the preprocessor currently bound
    to `Autoencoder.preprocessor`. The number of samples comes from
    `Harmonic.n_samples`, or from the `%n_samples` macro when that binding is
    not an int. The overrides switch to `F0PowerPreprocessor` with one frame,
    set per-frame sample counts, and redefine `ProcessorGroup.dag` without
    reverb and crop processors.
    """
    parse_operative_config(ckpt)

    # Streaming parameters derived from the trained configuration.
    preprocessor_selector = gin.query_parameter(
        'Autoencoder.preprocessor').scoped_selector
    frames = gin.query_parameter(f'{preprocessor_selector}.time_steps')
    total_samples = gin.query_parameter('Harmonic.n_samples')
    if not isinstance(total_samples, int):
        # Not a plain int (presumably a macro reference); read the macro.
        total_samples = gin.query_parameter('%n_samples')
    frame_size = int(total_samples / frames)

    # Remove reverb and crop processors.
    dag_binding = """ProcessorGroup.dag = [
      (@synths.Harmonic(),
        ['amps', 'harmonic_distribution', 'f0_hz']),
      (@synths.FilteredNoise(),
        ['noise_magnitudes']),
      (@processors.Add(),
        ['filtered_noise/signal', 'harmonic/signal']),
    ]"""

    overrides = [
        'Autoencoder.preprocessor = @F0PowerPreprocessor()',
        'F0PowerPreprocessor.time_steps = 1',
        f'Harmonic.n_samples = {frame_size}',
        f'FilteredNoise.n_samples = {frame_size}',
        dag_binding,
    ]
    with gin.unlock_config():
        gin.parse_config(overrides)
def __init__(
    self,
    env_class,
    agent_class,
    network_fn,
    model_class,
    model_network_fn,
    config,
    init_hooks,
):
    """Configure gin and TensorFlow, run init hooks, and build worker state.

    Args:
      env_class: Callable constructing the environment.
      agent_class: Callable constructing the agent. It receives
        `model_class=model_class` when `model_class` is not None.
      network_fn: Network factory (or a falsy value). When truthy it is
        wrapped so that it is called with `metrics=None`.
      model_class: Forwarded to `agent_class` when not None.
      model_network_fn: Model network factory (or a falsy value). When truthy
        it is wrapped so that it is called with `metrics=None`.
      config: Gin config, parsed with `skip_unknown=True`.
      init_hooks: Iterable of zero-argument callables, run after the config
        is parsed.
    """
    # Limit number of threads used between independent tf.op-s to 1.
    import tensorflow as tf  # pylint: disable=import-outside-toplevel
    threading_config = tf.config.threading
    threading_config.set_inter_op_parallelism_threads(1)
    threading_config.set_intra_op_parallelism_threads(1)

    gin.parse_config(config, skip_unknown=True)
    for init_hook in init_hooks:
        init_hook()

    self.env = env_class()
    if model_class is None:
        self.agent = agent_class()
    else:
        self.agent = agent_class(model_class=model_class)

    # Metrics cause some problems with Ray, so they are switched off; no
    # networks are trained inside the worker.
    if network_fn:
        network_fn = functools.partial(network_fn, metrics=None)
    if model_network_fn:
        model_network_fn = functools.partial(model_network_fn, metrics=None)
    self._request_handler = core.RequestHandler(
        network_fn, model_network_fn=model_network_fn)
def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False,
              score_in_predict_mode=False):
    """Build an estimator for this model via `utils.get_estimator`.

    When no TPU is configured, or `disable_tpu` is set, the gin parameters
    `utils.get_variable_dtype.slice_dtype` and `.activation_dtype` are bound
    to "float32". `self._gin_bindings` are parsed in every case.

    Args:
      vocabulary: Passed through to `utils.get_estimator`.
      init_checkpoint: Passed through as `init_checkpoint`.
      disable_tpu: If True, an empty mesh shape, no mesh devices and
        `use_tpu=None` are used instead of the configured values.
      score_in_predict_mode: Passed through to `utils.get_estimator`.

    Returns:
      Whatever `utils.get_estimator` returns.
    """
    if not self._tpu or disable_tpu:
        # Off TPU: variables and activations in float32.
        with gin.unlock_config():
            gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
            gin.bind_parameter(
                "utils.get_variable_dtype.activation_dtype", "float32")
    with gin.unlock_config():
        gin.parse_config(self._gin_bindings)

    if disable_tpu:
        mesh_shape = mtf.Shape([])
        mesh_devices = None
        use_tpu = None
    else:
        mesh_shape = self._mesh_shape
        mesh_devices = self._mesh_devices
        use_tpu = self._tpu

    return utils.get_estimator(
        model_type=self._model_type,
        vocabulary=vocabulary,
        layout_rules=self._layout_rules,
        mesh_shape=mesh_shape,
        mesh_devices=mesh_devices,
        model_dir=self._model_dir,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        autostack=self._autostack,
        learning_rate_schedule=self._learning_rate_schedule,
        keep_checkpoint_max=self._keep_checkpoint_max,
        save_checkpoints_steps=self._save_checkpoints_steps,
        optimizer=self._optimizer,
        predict_fn=self._predict_fn,
        variable_filter=self._variable_filter,
        ensemble_inputs=self._ensemble_inputs,
        use_tpu=use_tpu,
        tpu_job_name=self._tpu_job_name,
        iterations_per_loop=self._iterations_per_loop,
        cluster=self._cluster,
        init_checkpoint=init_checkpoint,
        score_in_predict_mode=score_in_predict_mode)
def test_singletons(self):
    """A gin `singleton` is constructed lazily, once, and shared by bindings.

    Inside a copied GinState the existing instance is reused. In a fresh
    GinState with the config parsed again, a new instance is constructed.
    """

    @gin.configurable
    class Champ(object):
        count = 0

        def __init__(self):
            Champ.count += 1

    def assert_constructed(times):
        self.assertEqual(Champ.count, times)

    config = '''
        chuck_norris/singleton.constructor = @Champ
        f.x = @chuck_norris/singleton()
        g.z = @chuck_norris/singleton()
    '''
    gin.parse_config(config)
    # Parsing alone constructs nothing.
    assert_constructed(0)
    f()
    assert_constructed(1)
    g()
    assert_constructed(1)

    with GinState(copy_state=True):
        f()
        assert_constructed(1)

    with GinState():
        gin.parse_config(config)
        f()
        assert_constructed(2)
def _worker(self, root_dir, parameters, device_queue):
    """Run one search/training job on a device borrowed from `device_queue`.

    The device is always put back on the queue, whether training succeeds or
    fails, so a failing job cannot permanently remove a device from the pool.

    Args:
      root_dir: Directory passed to `train_eval`.
      parameters: Mapping of gin parameter names to values; every item is
        parsed as a `name=value` gin binding.
      device_queue: Queue of device ids. One id is taken for the duration of
        the job and returned afterwards.

    Raises:
      Any exception raised while setting up or running the job is logged
      together with its traceback and then re-raised unchanged.
    """
    try:
        # Sleep for random seconds to avoid crowded launching.
        time.sleep(random.uniform(0, 3))
        device = device_queue.get()
        try:
            if self._conf.use_gpu:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

            if torch.cuda.is_available():
                alf.set_default_device("cuda")
            logging.set_verbosity(logging.INFO)

            logging.info("Search parameters %s" % parameters)
            with gin.unlock_config():
                gin.parse_config(
                    ['%s=%s' % (k, v) for k, v in parameters.items()])
                gin.parse_config(
                    "TrainerConfig.confirm_checkpoint_upon_crash=False")
            train_eval(FLAGS.ml_type, root_dir)
        finally:
            # Return the device even when the job fails; otherwise workers
            # waiting on the queue could block forever.
            device_queue.put(device)
    except Exception:
        logging.info(traceback.format_exc())
        raise
def main(argv):
    """Parse gin configuration from flags, then run `evaluation()`.

    Modules listed in --import_module are imported first, before any gin
    parsing. Bindings are read from --config_file (blank lines dropped) and
    then from --params. --run_functions_eagerly turns on eager execution of
    tf.functions. --eval_dir is created if it does not exist.
    """
    del argv  # Unused.

    # Import modules BEFORE running Gin.
    if FLAGS.import_module:
        for module_name in FLAGS.import_module:
            __import__(module_name)

    # First, try to parse from a config file.
    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file) as f:
            lines = f.readlines()
        bindings = [six.ensure_str(line) for line in lines if line.strip()]
        gin.parse_config('\n'.join(bindings))
    if FLAGS.params:
        gin.parse_config(FLAGS.params)

    if FLAGS.run_functions_eagerly:
        tf.config.experimental_run_functions_eagerly(True)

    if not tf.io.gfile.exists(FLAGS.eval_dir):
        tf.io.gfile.makedirs(FLAGS.eval_dir)

    evaluation()
def parse_gin(model_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Later parses override earlier ones, in this order:
    1. `optimization/base.gin`, or `optimization/base_tpu.gin` when --tpu is
       set.
    2. `operative_config-0.gin` in `model_dir`, if it exists.
    3. `ddsp.core.cumsum.use_tpu`.
    4. --gin_file and --gin_param.
    """
    # Make GIN_PATH and the user's folders searchable for gin files.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for path in search_paths:
        gin.add_config_file_search_path(path)

    with gin.unlock_config():
        # Optimization defaults.
        use_tpu = bool(FLAGS.tpu)
        optimization_file = 'base_tpu.gin' if use_tpu else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', optimization_file))

        # The operative config exists once the model has already trained.
        operative_config = os.path.join(model_dir, 'operative_config-0.gin')
        if tf.io.gfile.exists(operative_config):
            gin.parse_config_file(operative_config, skip_unknown=True)

        # Only use the custom cumsum for TPUs.
        gin.parse_config('ddsp.core.cumsum.use_tpu={}'.format(use_tpu))

        # User gin config and user hyperparameters from flags.
        gin.parse_config_files_and_bindings(
            FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
def __init__(
    self,
    env_class,
    agent_class,
    network_fn,
    config,
    scope,
    init_hooks,
    compress_episodes,
):
    """Configure gin and TensorFlow, run init hooks, and build worker state.

    Args:
      env_class: Callable constructing the environment (under `scope`).
      agent_class: Callable constructing the agent (under `scope`).
      network_fn: Passed to `core.RequestHandler`.
      config: Gin config, parsed with `skip_unknown=True`.
      scope: Gin config scope active while the env and agent are built.
      init_hooks: Iterable of zero-argument callables, run after the config
        is parsed.
      compress_episodes: Stored on `self._compress_episodes`.
    """
    gin.parse_config(config, skip_unknown=True)
    for init_hook in init_hooks:
        init_hook()

    # Restrict TensorFlow to a single inter-op and intra-op thread.
    import tensorflow as tf  # pylint: disable=import-outside-toplevel
    threading_config = tf.config.threading
    threading_config.set_inter_op_parallelism_threads(1)
    threading_config.set_intra_op_parallelism_threads(1)

    with gin.config_scope(scope):
        self.env = env_class()
        self.agent = agent_class()
    self._request_handler = core.RequestHandler(network_fn)
    self._compress_episodes = compress_episodes
def _test_stability(max_time=5, render=False, test_generator=None):
    """Tests the stability of the controller using speed profiles.

    For each `(name, speed_profile)` yielded by `test_generator()`, the env
    and controller are reset. The simulation then steps until
    `env.get_time_since_reset()` reaches `max_time` or the env reports
    `done`, and the survival time is printed.

    Args:
      max_time: Time limit per profile, compared against
        `env.get_time_since_reset()`.
      render: Passed to `load_sim_config` as `render`.
      test_generator: Callable returning an iterable of
        `(name, speed_profile)` pairs. `speed_profile[0]` and
        `speed_profile[1]` are passed to `_generate_linear_angular_speed`.
        It is required despite the `None` default.

    Raises:
      ValueError: if `test_generator` is None. The check runs before any
        simulation setup.
    """
    if test_generator is None:
        raise ValueError("test_generator must be provided.")

    locomotion_controller_setup.load_sim_config(render=render)
    gin.parse_config(SCENARIO_SET_CONFIG)
    if FLAGS.add_random_push:
        locomotion_controller_setup.add_random_push_config()

    env = env_loader.load()
    controller = locomotion_controller_setup.setup_controller(
        env.robot, gait=FLAGS.gait)

    for name, speed_profile in test_generator():
        env.reset()
        controller.reset()
        current_time = 0
        while current_time < max_time:
            current_time = env.get_time_since_reset()
            lin_speed, ang_speed = _generate_linear_angular_speed(
                current_time, speed_profile[0], speed_profile[1])
            _update_controller_params(controller, lin_speed, ang_speed)

            # Needed before every call to get_action().
            controller.update()
            hybrid_action = controller.get_action()

            _, _, done, _ = env.step(hybrid_action)
            if done:
                break
        print(
            f"Scene name: flat ground. Random push: {FLAGS.add_random_push}. "
            f"Survival time for {name} = {speed_profile[1]} is {current_time}")
def test_transformer_steps(self, config, expected_block_count):
    """Checks the transformer_steps parameter structure and its gradients.

    Args:
      config: Gin bindings to parse before building the component.
      expected_block_count: Expected number of top-level parameter entries.
    """
    gin.parse_config(config)

    step_inputs = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
        neighbor_mask=jnp.zeros((5, 5), jnp.float32),
        num_real_nodes_per_graph=4,
    )
    _, params = edge_supervision_models.transformer_steps.init(
        jax.random.PRNGKey(0), **step_inputs)

    # This component should contain the right number of blocks,
    # each with 4 sublayers.
    self.assertLen(params, expected_block_count)
    for block_params in params.values():
        self.assertLen(block_params, 4)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.transformer_steps.call, **step_inputs)
    outputs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outputs)
def setUp(self):
    """Reset gin to GIN_CONFIG and patch `loaders.TFDSLoader.load`.

    The autospec'd mock is stored on `self.mock_load`. All active patches are
    stopped during cleanup.
    """
    super().setUp()
    gin.clear_config()
    gin.parse_config(GIN_CONFIG)
    self.addCleanup(mock.patch.stopall)
    load_patcher = mock.patch.object(
        loaders.TFDSLoader, 'load', autospec=True)
    self.mock_load = load_patcher.start()
def test_transformer_steps_masking(self):
    """Transformer should mask out padding even if not masked to neighbors."""
    gin.parse_config(
        textwrap.dedent("""\
            transformer_steps.layers = 1
            transformer_steps.share_weights = False
            transformer_steps.mask_to_neighbors = False
            NodeSelfAttention.heads = 2
            NodeSelfAttention.query_key_dim = 3
            NodeSelfAttention.value_dim = 4
            """))

    with flax.nn.capture_module_outputs() as outputs:
        edge_supervision_models.transformer_steps.init(
            jax.random.PRNGKey(0),
            node_embeddings=jnp.zeros((5, 3), jnp.float32),
            edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
            neighbor_mask=jnp.zeros((5, 5), jnp.float32),
            num_real_nodes_per_graph=4)

    # Exactly one captured attention-weights output is expected.
    [attention_weights] = [
        captured[0]
        for module_name, captured in outputs.as_dict().items()
        if module_name.endswith("attend/attention_weights")
    ]

    # Shape (2, 5, 5), presumably (heads, queries, keys): 0.25 on each of the
    # 4 real nodes and 0.0 on the padding node.
    real_node_row = [0.25, 0.25, 0.25, 0.25, 0.0]
    expected = np.array([[real_node_row] * 5] * 2)
    np.testing.assert_allclose(attention_weights, expected)
def testSynchronousTrainCollectEval(self):
    """End-to-end integration test.

    First collects data using `random_collect.gin` with `train_fn=None`. Then
    runs synchronous train/collect/eval using `train_dqn.gin`.
    """
    env = grasping_env.KukaGraspingProceduralEnv(
        downsample_width=64,
        downsample_height=64,
        continuous=True,
        remove_height_hack=True,
        render_mode='DIRECT')
    testdata_dir = os.path.join(FLAGS.test_srcdir, 'testdata')

    def parse_testdata_gin(filename):
        with open(os.path.join(testdata_dir, filename), 'r') as f:
            gin.parse_config(f)

    # Collect initial data from random policy without training.
    parse_testdata_gin('random_collect.gin')
    train_collect_eval.train_collect_eval(
        collect_env=env,
        eval_env=None,
        test_env=None,
        root_dir=self._root_dir,
        train_fn=None)

    # Run training (synchronous train, collect, & eval).
    parse_testdata_gin('train_dqn.gin')
    train_collect_eval.train_collect_eval(
        collect_env=env,
        eval_env=None,
        test_env=None,
        root_dir=self._root_dir)
def test_nri_steps(self):
    """Checks the nri_steps parameter structure and its gradients."""
    gin.parse_config(
        textwrap.dedent("""\
            graph_layers.NRIEdgeLayer.allow_non_adjacent = True
            graph_layers.NRIEdgeLayer.mlp_vtoe_dims = [4, 4]
            nri_steps.mlp_etov_dims = [8, 8]
            nri_steps.with_residual_layer_norm = True
            nri_steps.layers = 3
            """))

    step_inputs = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
        num_real_nodes_per_graph=4,
    )
    _, params = edge_supervision_models.nri_steps.init(
        jax.random.PRNGKey(0), **step_inputs)

    # This component should contain the right number of blocks (layers = 3).
    self.assertLen(params, 3)
    for block_params in params.values():
        # Each block contains 5 sublayers:
        # - NRI message pass
        # - Three dense layers (from mlp_etov_dims, then back to embedding space)
        # - Layer norm
        self.assertLen(block_params, 5)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.nri_steps.call, **step_inputs)
    outputs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outputs)
def testGinConfig(self):
    """CategoricalQNetwork picks up conv/fc layer params bound through gin."""
    batch_size = 3
    num_state_dims = 5
    action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
    num_actions = action_spec.maximum - action_spec.minimum + 1
    self.assertEqual(num_actions, 2)

    observations_spec = tensor_spec.TensorSpec(
        [3, 3, num_state_dims], tf.float32)
    batch_shape = [batch_size, 3, 3, num_state_dims]
    observations = tf.random.uniform(batch_shape)
    next_observations = tf.random.uniform(batch_shape)
    time_steps = ts.restart(observations, batch_size)
    next_time_steps = ts.restart(next_observations, batch_size)

    gin.parse_config("""
        CategoricalQNetwork.conv_layer_params = [(16, 2, 1), (15, 2, 1)]
        CategoricalQNetwork.fc_layer_params = [4, 3, 5]
    """)

    q_network = categorical_q_network.CategoricalQNetwork(
        input_tensor_spec=observations_spec, action_spec=action_spec)
    logits, _ = q_network(time_steps.observation)
    next_logits, _ = q_network(next_time_steps.observation)

    expected_shape = [batch_size, num_actions, q_network._num_atoms]
    self.assertAllEqual(logits.shape.as_list(), expected_shape)
    self.assertAllEqual(next_logits.shape.as_list(), expected_shape)

    # Six layers (two conv, three fc, one final logits layer), each with a
    # kernel and a bias: 12 trainable variables in total.
    self.assertLen(q_network.trainable_variables, 12)
def test_ggtnn_steps(self):
    """Checks the ggnn_steps parameter structure and its gradients."""
    gin.parse_config(
        textwrap.dedent("""\
            edge_supervision_models.ggnn_steps.iterations = 10
            graph_layers.LinearMessagePassing.message_dim = 5
            """))

    step_inputs = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
    )
    _, params = edge_supervision_models.ggnn_steps.init(
        jax.random.PRNGKey(0), **step_inputs)

    # This component should only contain one step block, with two sublayers.
    self.assertEqual(set(params.keys()), {"step"})
    self.assertLen(params["step"], 2)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.ggnn_steps.call, **step_inputs)
    outputs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outputs)
def main(_):
    """Seed RNGs, parse --gin_config, and dispatch on --run_mode.

    --gin_config names either a file (parsed from its contents) or inline gin
    bindings. The gin config is finalized before dispatch.
    """
    np.random.seed(FLAGS.task)
    tf.set_random_seed(FLAGS.task)
    task = FLAGS.task if FLAGS.distributed else 0

    gin_config = FLAGS.gin_config
    if gin_config:
        if tf.gfile.Exists(gin_config):
            # Parse as a file.
            with tf.gfile.Open(gin_config) as f:
                gin.parse_config(f)
        else:
            # Treat the flag value itself as gin bindings.
            gin.parse_config(gin_config)
    gin.finalize()

    run_mode = FLAGS.run_mode
    if run_mode == 'collect_eval_once':
        # NOTE(review): this mode passes FLAGS.task rather than `task`, unlike
        # the other modes - confirm that is intended.
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              train_fn=None,
                                              task=FLAGS.task)
    elif run_mode == 'train_only':
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              do_collect_eval=False,
                                              task=task,
                                              master=FLAGS.master,
                                              ps_tasks=FLAGS.ps_tasks)
    elif run_mode == 'collect_eval_loop':
        raise NotImplementedError('collect_eval_loops')
    else:
        # Synchronous train-collect-eval.
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              task=task)
def parse_args(args=None):
    """Turn `name=value` command-line arguments into gin configuration.

    `torchexp.config.manual_seed.seed` is first bound to the `%seed` macro,
    which is set to None. Each argument is then processed in order:

    * `--yaml=<file>` is handed to `_read_yaml_macros`.
    * `--gin=<file>` is parsed as a gin config file.
    * Any other `key=value` becomes the gin binding `key = value`. Unless
      `check_gin_special(value)` is true, `value` is passed through
      `literal_eval` when that succeeds, and the result is `repr`-ed. A value
      `literal_eval` rejects, such as `abc`, therefore becomes the quoted
      string `'abc'`.

    Finally `manual_seed()` is called.

    Args:
      args: List of argument strings. Defaults to `sys.argv[1:]`.

    Raises:
      ValueError: if an argument contains no `=`.
    """
    gin.parse_config('torchexp.config.manual_seed.seed = %seed')
    gin.bind_parameter('%seed', None)
    if args is None:
        args = sys.argv[1:]
    for arg in args:
        try:
            key, value = arg.split('=', maxsplit=1)
        except ValueError:
            # Hide the unhelpful "not enough values to unpack" context.
            raise ValueError(f'The argument `{arg}` is not accepted!'
                             ' All argument should be the form name=value,'
                             ' --yaml=config.yaml or --gin=config.gin') from None
        if key == '--yaml':
            _read_yaml_macros(value)
        elif key == '--gin':
            gin.parse_config_file(value)
        else:
            if not check_gin_special(value):
                try:
                    value = literal_eval(value)
                except (ValueError, SyntaxError):
                    # Not a Python literal: keep the raw string.
                    pass
                value = repr(value)
            gin.parse_config(f'{key} = {value}')
    manual_seed()