def main(_):
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout})

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # The numpy backend doesn't benefit from having the input pipeline run on
    # GPU, and the jax backend has GPU memory contention if TF uses the GPU.
    # Gin must be set up first, before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Set up the output directory.
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')

  trainer_lib.train(output_dir=output_dir)
def test_train_restart(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    trainer_lib.train(output_dir,
                      model=model_fn,
                      inputs=inputs,
                      train_steps=train_steps,
                      eval_steps=eval_steps)

    # Restart training
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=(2 * train_steps),
                              eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(state.step, 2 * train_steps)
def test_inits_policy_by_world_model_checkpoint(self):
  transformer_kwargs = {
      "d_model": 1,
      "d_ff": 1,
      "n_layers": 1,
      "n_heads": 1,
      "max_len": 128,
      "mode": "train",
  }
  rng = jax_random.PRNGKey(123)
  init_kwargs = {
      "input_shapes": (1, 1),
      "input_dtype": np.int32,
      "rng": rng,
  }
  model_fn = functools.partial(
      models.TransformerLM, vocab_size=4, **transformer_kwargs)
  output_dir = self.get_temp_dir()
  # Initialize a world model checkpoint by running the trainer.
  trainer_lib.train(
      output_dir,
      model=model_fn,
      inputs=functools.partial(
          inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1)),
      train_steps=1,
      eval_steps=1,
  )
  policy = ppo.policy_and_value_net(
      n_actions=3,
      n_controls=2,
      vocab_size=4,
      bottom_layers_fn=functools.partial(
          models.TransformerDecoder, **transformer_kwargs),
      two_towers=False,
  )
  (policy_params, policy_state) = policy.initialize_once(**init_kwargs)
  # Initialize policy parameters from world model parameters.
  new_policy_params = ppo.init_policy_from_world_model_checkpoint(
      policy_params, output_dir)
  # Try to run the policy with new parameters.
  observations = np.zeros((1, 100), dtype=np.int32)
  policy(observations, params=new_policy_params, state=policy_state, rng=rng)
def test_train_eval_predict_sm3(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2
    model_fn = functools.partial(
        models.MLP, d_hidden=16, n_output_classes=n_classes)
    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=train_steps,
                              eval_steps=eval_steps,
                              optimizer=trax_opt.SM3)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], params=state.opt_state.params)
def test_inits_policy_by_world_model_checkpoint(self):
  transformer_kwargs = {
      'd_model': 1,
      'd_ff': 1,
      'n_layers': 1,
      'n_heads': 1,
      'max_len': 128,
      'mode': 'train',
  }
  rng = jax_random.PRNGKey(123)
  model_fn = functools.partial(
      models.TransformerLM, vocab_size=4, **transformer_kwargs)
  output_dir = self.get_temp_dir()
  # Initialize a world model checkpoint by running the trainer.
  trainer_lib.train(
      output_dir,
      model=model_fn,
      inputs=functools.partial(
          inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1)),
      train_steps=1,
      eval_steps=1,
  )
  policy = ppo.policy_and_value_net(
      n_actions=3,
      n_controls=2,
      vocab_size=4,
      bottom_layers_fn=functools.partial(
          models.TransformerDecoder, **transformer_kwargs),
      two_towers=False,
  )
  input_signature = ShapeDtype((1, 1), np.int32)
  policy._set_rng(rng)
  (policy_params, policy_state) = policy.initialize_once(input_signature)
  # Initialize policy parameters from world model parameters.
  new_policy_params = ppo.init_policy_from_world_model_checkpoint(
      policy_params, output_dir)
  # Try to run the policy with new parameters.
  observations = np.zeros((1, 100), dtype=np.int32)
  policy(observations, weights=new_policy_params, state=policy_state, rng=rng)
def main(_):
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout})

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # The numpy backend doesn't benefit from having the input pipeline run on
    # GPU, and the jax backend has GPU memory contention if TF uses the GPU.
    # Gin must be set up first, before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Set up the output directory.
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly
        # make the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)
      else:
        trainer_lib.train(output_dir=output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
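# The two main() variants above only reference FLAGS attributes; the flag
# definitions themselves live elsewhere in trainer.py. Below is a minimal
# sketch of definitions consistent with those references: the flag names are
# taken from the attributes used above, the types are inferred from usage, and
# the defaults and help strings are illustrative assumptions rather than the
# library's actual values.
from absl import flags

flags.DEFINE_string('output_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_string('log_level', 'INFO', 'Verbosity passed to absl logging.')
flags.DEFINE_bool('enable_eager_execution', True, 'Enable TF eager execution.')
flags.DEFINE_bool('tf_xla', True, 'Compile the TF graph with XLA.')
flags.DEFINE_bool('tf_xla_forced_compile', False, 'Force XLA compilation.')
flags.DEFINE_bool('tf_opt_pin_to_host', False, 'Pin-to-host grappler pass.')
flags.DEFINE_bool('tf_opt_layout', False, 'Layout-optimizer grappler pass.')
flags.DEFINE_bool('tf_allow_float64', False, 'Allow float64 in the TF backend.')
flags.DEFINE_bool('use_tpu', False, 'Whether to run on TPU.')
flags.DEFINE_string('jax_xla_backend', 'xla', 'Value for jax_xla_backend.')
flags.DEFINE_string('jax_backend_target', 'local', 'Value for jax_backend_target.')

FLAGS = flags.FLAGS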
def _test_train_eval_predict(self, backend_name):
  if xla_bridge.device_count() > 1 and backend_name == 'tf':
    self.skipTest("tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs
    n_classes = 4
    train_steps = 2
    eval_steps = 2

    # Adds Dropout and BatchNorm to test state handling.
    def model_fn(mode='train'):
      return layers.Serial(
          layers.Dropout(mode=mode, rate=0.1),
          layers.BatchNorm(mode=mode),
          models.MLP(d_hidden=16, n_output_classes=n_classes, mode=mode))

    inputs = lambda _: test_inputs(n_classes)

    # Train and evaluate
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=train_steps,
                              eval_steps=eval_steps)

    # Assert total train steps
    self.assertEqual(train_steps, state.step)

    # Assert 2 evaluations ran
    train_acc = state.history.get('train', 'metrics/accuracy')
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params
    inputs = inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(inputs)[0], params=state.opt_state.params)
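# The tests above call a test_inputs(n_classes) fixture defined elsewhere in
# this test module. A minimal sketch of such a fixture is given below,
# assuming all the tests need is an endless stream of fixed-shape
# (features, integer-label) batches. The batch size, feature shape, use of
# plain numpy, and the trax_inputs import alias are illustrative assumptions;
# the Inputs constructor fields mirror the ones used in train_model below.
def test_inputs(n_classes):
  batch_size = 2

  def input_stream():
    # Yield random features with integer class labels forever.
    rng = np.random.RandomState(0)
    while True:
      features = rng.uniform(size=(batch_size, 4)).astype(np.float32)
      labels = rng.randint(0, n_classes, size=(batch_size,), dtype=np.int32)
      yield features, labels

  return trax_inputs.Inputs(
      train_stream=input_stream,
      train_eval_stream=input_stream,
      eval_stream=input_stream,
      input_shape=(4,),
      input_dtype=np.float32,
      target_shape=(),
      target_dtype=np.int32)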
def train_model(self):
  """Train the model.

  Returns:
    whether the training was skipped due to a restart.
  """
  logging.info('SimPLe epoch [% 6d]: training model.', self._simple_epoch)
  start_time = time.time()

  (train_stream, eval_stream) = self._make_input_streams()
  # Ignore n_devices for now.
  inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
      train_stream=(lambda: train_stream),
      train_eval_stream=(lambda: train_stream),
      eval_stream=(lambda: eval_stream),
      input_shape=self._sim_env.model_input_shape,
      input_dtype=self._sim_env.model_input_dtype,
      # TODO(lukaszkaiser): correct those, they may differ from inputs.
      target_shape=self._sim_env.model_input_shape,
      target_dtype=self._sim_env.model_input_dtype)

  if self._simple_epoch == 0:
    train_steps = self._n_model_initial_train_steps
  else:
    train_steps = self._n_model_train_steps_per_epoch

  self._model_train_step += train_steps
  with gin.config_scope('world_model'):
    state = trainer_lib.train(
        model=self._sim_env.model,
        inputs=inputs,
        train_steps=self._model_train_step,
        output_dir=self._model_dir,
        has_weights=True,
    )

  logging.vlog(
      1, 'Training model took %0.2f sec.', time.time() - start_time)
  # If a restart restored a checkpoint that is already past the requested step
  # count, this call did no training, i.e. training was skipped.
  return state.step > self._model_train_step