def main(_):
  """Trainer entry point: configures TF/JAX from flags and launches training."""
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.tf_eager:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
  tf.config.optimizer.set_experimental_options(
      {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {"layout_optimizer": FLAGS.tf_opt_layout})

  _setup_gin()

  if FLAGS.tf_eager and backend.get_name() in ("numpy", "jax"):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], "GPU")

  # Resolve the output directory: the flag wins, otherwise use the default.
  out_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % out_dir)
  out_dir = os.path.expanduser(out_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=out_dir)
def test_train_restart(self, backend_name):
  """Restarted training resumes from the checkpoint up to the doubled budget."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes)

    # First run: train and evaluate.
    trax.train(output_dir,
               model=make_model,
               inputs=make_inputs,
               train_steps=n_train_steps,
               eval_steps=n_eval_steps)

    # Second run from the same directory restores the checkpoint and trains
    # up to the new, doubled step budget.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=(2 * n_train_steps),
                       eval_steps=n_eval_steps)

    # Total steps should reflect both runs.
    self.assertEqual(state.step, 2 * n_train_steps)
def test_train_restart(self):
  """Training a second time from the same directory accumulates steps."""
  with self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes)

    # Train and evaluate.
    trax.train(output_dir,
               model=make_model,
               inputs=make_inputs,
               train_steps=n_train_steps,
               eval_steps=n_eval_steps)

    # Restart training from the saved checkpoint.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps)

    # Both runs together should add up to twice the per-run step count.
    self.assertEqual(state.step, 2 * n_train_steps)
def main(argv):
  """Entry point: runs the trax (JAX) trainer, or falls back to the T2T flow.

  When --jax is set, v1 trainer flags are translated into gin bindings and
  trax.train is invoked; the function then returns without touching the T2T
  pipeline below.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.jax:
    # Hacking main v1 flags to work with jax.
    config_strs = []
    config_strs.append("train.train_steps=" + str(FLAGS.train_steps))
    config_strs.append("train.eval_steps=" + str(FLAGS.eval_steps))
    config_strs.append("train.eval_frequency=" +
                       str(FLAGS.local_eval_frequency))
    if FLAGS.hparams:
      # --hparams is a comma-separated list of extra gin bindings.
      config_strs.extend(str(FLAGS.hparams).split(","))
    data_dir = os.path.expanduser(FLAGS.data_dir)
    output_dir = os.path.expanduser(FLAGS.output_dir)
    gin.bind_parameter("train.dataset", FLAGS.problem)
    config_strs += ["train.model=@" + FLAGS.model]
    config_files = []
    if FLAGS.hparams_set:
      # --hparams_set is interpreted as a path to a gin config file here.
      config_files = [os.path.expanduser(FLAGS.hparams_set)]
    gin.parse_config_files_and_bindings(config_files, config_strs)
    trax.train(data_dir=data_dir, output_dir=output_dir)
    # The trax path ends here; the T2T flow below does not run.
    return
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()
  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()
  # MLPerf compliance logging for the training schedules that require it.
  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED,
                               value=FLAGS.random_seed,
                               hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  if FLAGS.cloud_mlengine:
    # Launch on Cloud ML Engine instead of running locally.
    cloud_mlengine.launch()
    return
  if FLAGS.generate_data:
    generate_data()
  if cloud_mlengine.job_dir():
    # When running under ML Engine, the job dir overrides --output_dir.
    FLAGS.output_dir = cloud_mlengine.job_dir()
  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
def main(_):
  """Configures TF (eager/XLA/grappler), gin, and TPU, then trains."""
  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.tf_eager:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
  tf.config.optimizer.set_experimental_options(
      {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host})
  tf.config.optimizer.set_experimental_options(
      {"layout_optimizer": FLAGS.tf_opt_layout})

  _setup_gin()

  # Resolve the output directory: the flag wins, otherwise use the default.
  out_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % out_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=os.path.expanduser(out_dir))
def train_model(self):
  """Runs one epoch of world-model training on all collected trajectories."""
  logging.info("Epoch %d: training model", self._simple_epoch)

  # Load data from all epochs.
  # TODO(pkozakowski): Handle the case when the data won't fit in the memory.
  (train_trajs, eval_trajs) = self._load_trajectories(
      self._trajectory_dump_root_dir)

  def make_train_stream():
    return self._data_stream(train_trajs, self._model_train_batch_size)

  def make_eval_stream():
    return self._data_stream(eval_trajs, self._model_train_batch_size)

  # Ignore n_devices for now.
  def make_inputs(_):
    return trax_inputs.Inputs(
        train_stream=make_train_stream,
        train_eval_stream=make_train_stream,
        eval_stream=make_eval_stream,
        input_shape=self._sim_env.model_input_shape,
        input_dtype=self._sim_env.model_input_dtype,
    )

  # Raise the cumulative step budget so trax resumes from the checkpoint
  # and trains for another increment of steps.
  self._model_train_step += self._n_model_train_steps
  trax.train(
      model=self._sim_env.model,
      inputs=make_inputs,
      train_steps=self._model_train_step,
      output_dir=self._model_dir,
      has_weights=True,
  )
def main(_):
  """Parses gin configuration and launches training into the output dir."""
  _setup_gin()

  # Pick the output directory: the flag value if given, the default otherwise.
  out_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % out_dir)

  trax.train(output_dir=os.path.expanduser(out_dir))
def main(_):
  """Sets up logging and gin configuration, then launches training."""
  logging.set_verbosity(FLAGS.log_level)
  _setup_gin()

  # Pick the output directory: the flag value if given, the default otherwise.
  out_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % out_dir)

  trax.train(output_dir=os.path.expanduser(out_dir))
def main(_):
  """Validates the data/output directory flags and runs training."""
  _setup_gin()

  # Setup directories; both must be provided (or derivable).
  data_dir = FLAGS.data_dir
  out_dir = FLAGS.output_dir or _default_output_dir()
  assert data_dir, "Must specify a data directory"
  assert out_dir, "Must specify an output directory"
  trax.log("Using --output_dir %s" % out_dir)

  trax.train(data_dir=os.path.expanduser(data_dir),
             output_dir=os.path.expanduser(out_dir))
def test_inits_policy_by_world_model_checkpoint(self):
  """Policy params can be seeded from a trained world-model checkpoint."""
  transformer_kwargs = {
      "d_model": 1,
      "d_ff": 1,
      "n_layers": 1,
      "n_heads": 1,
      "max_len": 128,
      "mode": "train",
  }
  rng = jax_random.PRNGKey(123)
  init_kwargs = {
      "input_shapes": (1, 1),
      "input_dtype": np.int32,
      "rng": rng,
  }
  model_fn = functools.partial(
      models.TransformerLM, vocab_size=4, **transformer_kwargs)
  output_dir = self.get_temp_dir()

  # Initialize a world model checkpoint by running the trainer for one step.
  trax.train(
      output_dir,
      model=model_fn,
      inputs=functools.partial(
          inputs.random_inputs, input_shape=(1, 1), output_shape=(1, 1)),
      train_steps=1,
      eval_steps=1,
  )

  # Build a policy network sharing the transformer body configuration.
  policy = ppo.policy_and_value_net(
      n_actions=3,
      n_controls=2,
      vocab_size=4,
      bottom_layers_fn=functools.partial(
          models.TransformerDecoder, **transformer_kwargs),
      two_towers=False,
  )
  (policy_params, policy_state) = policy.initialize_once(**init_kwargs)

  # Initialize policy parameters from world model parameters.
  new_policy_params = ppo.init_policy_from_world_model_checkpoint(
      policy_params, output_dir)

  # Try to run the policy with new parameters.
  observations = np.zeros((1, 100), dtype=np.int32)
  policy(observations, params=new_policy_params, state=policy_state, rng=rng)
def main(_):
  """Sets logging, gin, and the TPU platform, then trains."""
  logging.set_verbosity(FLAGS.log_level)
  _setup_gin()

  # Pick the output directory: the flag value if given, the default otherwise.
  out_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % out_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=os.path.expanduser(out_dir))
def test_train_eval_predict(self):
  """End-to-end: train, check metric history, then run prediction."""
  with self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    model = functools.partial(models.MLP,
                              hidden_size=16,
                              num_output_classes=num_classes)
    make_inputs = lambda: test_inputs(num_classes)

    # Train and evaluate.
    state = trax.train(output_dir,
                       model=model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps)

    # Assert total train steps.
    self.assertEqual(n_train_steps, state.step)

    # Assert 2 evaluations ran.
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertEqual(2, len(eval_acc))

    # Predict with final params.
    _, predict_fun = model()
    stream = make_inputs().train_stream()
    predict_fun(state.params, next(stream)[0])
def test_train_with_weights(self, backend_name):
  """Training works when the input stream also yields example weights."""
  if jax.lib.xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Tell the input unpacker that batches carry per-example weights.
    gin.bind_parameter("unpack_batch.has_weights", True)

    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes, with_weights=True)

    # Train and evaluate.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps)

    # Assert total train steps.
    self.assertEqual(state.step, n_train_steps)
def test_train_eval_predict_sm3(self):
  """Train/eval with the SM3 optimizer, then run a forward pass."""
  with self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes)

    # Train and evaluate with the SM3 optimizer.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps,
                       optimizer=trax_opt.SM3)

    # Assert total train steps.
    self.assertEqual(n_train_steps, state.step)

    # Assert 2 evaluations ran.
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertEqual(2, len(eval_acc))

    # Predict with final params.
    stream = make_inputs(1).train_stream()
    model = layers.Serial(make_model())
    model(next(stream)[0], state.params[0])
def test_train_eval_predict_sm3(self, backend_name):
  """Train/eval with the SM3 optimizer on a given backend, then predict."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes)

    # Train and evaluate with the SM3 optimizer.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps,
                       optimizer=trax_opt.SM3)

    # Assert total train steps.
    self.assertEqual(n_train_steps, state.step)

    # Assert 2 evaluations ran.
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params.
    stream = make_inputs(1).train_stream()
    model = layers.Serial(make_model())
    model(next(stream)[0], params=state.opt_state.params)
def train_model(self):
  """Trains the world model on the trajectories collected so far."""
  logging.info("Epoch %d: training model", self._epoch)

  def make_train_stream():
    return self._data_stream(self._train_trajectories,
                             self._model_train_batch_size)

  def make_eval_stream():
    return self._data_stream(self._eval_trajectories,
                             self._model_train_batch_size)

  # Ignore n_devices for now.
  def make_inputs(_):
    return trax_inputs.Inputs(
        train_stream=make_train_stream,
        train_eval_stream=make_train_stream,
        eval_stream=make_eval_stream,
        input_shape=self._sim_env.model_input_shape,
        input_dtype=self._sim_env.model_input_dtype,
    )

  trax.train(
      model=self._sim_env.model,
      inputs=make_inputs,
      output_dir=self._model_dir,
      has_weights=True,
  )
def _test_train(self, train_args):
  """Runs trax.train with `train_args` and checks steps and metric history."""
  with self.tmp_dir() as output_dir:
    state = trax.train(output_dir, **train_args)

    # Assert total train steps.
    self.assertEqual(train_args["train_steps"], state.step)

    # Assert 2 epochs ran: train and eval accuracy histories match in length.
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertEqual(2, len(eval_acc))
def train_model(self):
  """Trains the world model for another increment of steps."""
  logging.info("SimPLe epoch [% 6d]: training model.", self._simple_epoch)

  (train_stream, eval_stream) = self._make_input_streams()

  # Ignore n_devices for now.
  def make_inputs(_):
    return trax_inputs.Inputs(
        train_stream=(lambda: train_stream),
        train_eval_stream=(lambda: train_stream),
        eval_stream=(lambda: eval_stream),
        input_shape=self._sim_env.model_input_shape,
        input_dtype=self._sim_env.model_input_dtype,
    )

  # Raise the cumulative step budget so trax resumes from the checkpoint.
  self._model_train_step += self._n_model_train_steps
  trax.train(
      model=self._sim_env.model,
      inputs=make_inputs,
      train_steps=self._model_train_step,
      output_dir=self._model_dir,
      has_weights=True,
  )
def test_train_eval_predict(self, backend_name):
  """Train/eval with state-carrying layers, then run a forward pass."""
  if xla_bridge.device_count() > 1 and backend_name == "tf":
    self.skipTest(
        "tf-numpy backend doesn't support multi-devices yet.")
  with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2

    # Adds Dropout and BatchNorm to test state handling.
    def model_fn(mode="train"):
      return layers.Model(
          layers.Dropout(mode=mode, rate=0.1),
          layers.BatchNorm(mode=mode),
          models.MLP(d_hidden=16, n_output_classes=num_classes, mode=mode))

    make_inputs = lambda _: test_inputs(num_classes)

    # Train and evaluate.
    state = trax.train(output_dir,
                       model=model_fn,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps)

    # Assert total train steps.
    self.assertEqual(n_train_steps, state.step)

    # Assert 2 evaluations ran.
    train_acc = state.history.get("train", "metrics/accuracy")
    eval_acc = state.history.get("eval", "metrics/accuracy")
    self.assertEqual(len(train_acc), len(eval_acc))
    self.assertLen(eval_acc, 2)

    # Predict with final params.
    stream = make_inputs(1).train_stream()
    model = layers.Serial(model_fn())
    model(next(stream)[0], params=state.opt_state.params)
def train_model(self):
  """Train the model.

  Returns:
    whether the training was skipped due to a restart.
  """
  logging.info("SimPLe epoch [% 6d]: training model.", self._simple_epoch)
  start_time = time.time()
  (train_stream, eval_stream) = self._make_input_streams()
  # Ignore n_devices for now.
  inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
      train_stream=(lambda: train_stream),
      train_eval_stream=(lambda: train_stream),
      eval_stream=(lambda: eval_stream),
      input_shape=self._sim_env.model_input_shape,
      input_dtype=self._sim_env.model_input_dtype,
      # TODO(lukaszkaiser): correct those, they may differ from inputs.
      target_shape=self._sim_env.model_input_shape,
      target_dtype=self._sim_env.model_input_dtype)
  # First epoch gets its own (typically larger) step budget; later epochs
  # add the per-epoch increment.
  if self._simple_epoch == 0:
    train_steps = self._n_model_initial_train_steps
  else:
    train_steps = self._n_model_train_steps_per_epoch
  # Cumulative budget: trax resumes from the checkpoint and trains up to it.
  self._model_train_step += train_steps
  with gin.config_scope("world_model"):
    state = trax.train(
        model=self._sim_env.model,
        inputs=inputs,
        train_steps=self._model_train_step,
        output_dir=self._model_dir,
        has_weights=True,
    )
  logging.vlog(
      1, "Training model took %0.2f sec.", time.time() - start_time)
  # If the restored checkpoint was already past the requested step count,
  # no new training happened this epoch; report that to the caller.
  return state.step > self._model_train_step
def test_train_with_weights(self):
  """Training succeeds when batches carry per-example weights."""
  with self.tmp_dir() as output_dir:
    # Tell the input unpacker that batches carry per-example weights.
    gin.bind_parameter("unpack_batch.has_weights", True)

    # Prepare model and inputs.
    num_classes = 4
    n_train_steps = 2
    n_eval_steps = 2
    make_model = functools.partial(models.MLP,
                                   d_hidden=16,
                                   n_output_classes=num_classes)
    make_inputs = lambda _: test_inputs(num_classes, with_weights=True)

    # Train and evaluate.
    state = trax.train(output_dir,
                       model=make_model,
                       inputs=make_inputs,
                       train_steps=n_train_steps,
                       eval_steps=n_eval_steps)

    # Assert total train steps.
    self.assertEqual(state.step, n_train_steps)
def main(argv):
  """Entry point: runs the trax (JAX) trainer, or falls back to the T2T flow.

  When --jax is set, v1 trainer flags are translated into gin bindings,
  trax.train is invoked, and the function returns immediately; otherwise the
  standard T2T experiment pipeline below runs.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.jax:
    # Setup trax FLAGS
    dataset = FLAGS.problem
    model = FLAGS.model
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir
    config_file = [FLAGS.hparams_set]
    config = [
        "train.train_steps=%d" % FLAGS.train_steps,
        "train.eval_steps=%d" % FLAGS.eval_steps,
        "train.eval_frequency=%d" % FLAGS.local_eval_frequency,
    ] + str(FLAGS.hparams).split(",")

    # Copied _setup_gin exactly from trax/trainer.py and removed "FLAGS."
    def _setup_gin():
      """Setup gin configuration."""
      # Imports for configurables
      # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
      from tensor2tensor.trax import inputs as _trax_inputs
      from tensor2tensor.trax import models as _trax_models
      from tensor2tensor.trax import optimizers as _trax_opt
      # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
      configs = config or []
      # Override with --dataset and --model
      if dataset:
        configs.append("inputs.dataset_name='%s'" % dataset)
        configs.append("inputs.data_dir='%s'" % data_dir)
        configs.append("[email protected]")
      if model:
        configs.append("[email protected].%s" % model)
      gin.parse_config_files_and_bindings(config_file, configs)

    _setup_gin()
    trax.train(output_dir=output_dir)
    # Return early so the T2T pipeline below does not also run a second
    # training after the trax one (previously this fell through).
    return
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()
  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()
  # MLPerf compliance logging for the training schedules that require it.
  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  if FLAGS.cloud_mlengine:
    # Launch on Cloud ML Engine instead of running locally.
    cloud_mlengine.launch()
    return
  if FLAGS.generate_data:
    generate_data()
  if cloud_mlengine.job_dir():
    # When running under ML Engine, the job dir overrides --output_dir.
    FLAGS.output_dir = cloud_mlengine.job_dir()
  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)