Example #1 — file: trainer.py, project: tianhai123/-
def main(_):
    """Trainer entry point: configure TF/JAX from command-line flags, then train.

    The single positional argument is the argv list forwarded by absl.app;
    it is unused because all configuration comes from FLAGS.
    """
    logging.set_verbosity(FLAGS.log_level)

    # Optional TF execution modes, toggled by flags.
    if FLAGS.tf_eager:
        tf.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)

    # Grappler optimizer options for the TF graph/input pipeline.
    tf.config.optimizer.set_experimental_options(
        {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {"layout_optimizer": FLAGS.tf_opt_layout})

    _setup_gin()

    if FLAGS.tf_eager and backend.get_name() in ("numpy", "jax"):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], "GPU")

    # Setup output directory
    output_dir = FLAGS.output_dir or _default_output_dir()
    trax.log("Using --output_dir %s" % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update("jax_platform_name", "tpu")

    trax.train(output_dir=output_dir)
예제 #2
0
    def test_train_restart(self, backend_name):
        """Restarting from a checkpoint continues the global step count."""
        if xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            num_classes = 4
            num_steps = 2
            build_model = functools.partial(models.MLP,
                                            d_hidden=16,
                                            n_output_classes=num_classes)

            def build_inputs(unused_n_devices):
                return test_inputs(num_classes)

            # First run: train for `num_steps` steps, leaving a checkpoint.
            trax.train(output_dir,
                       model=build_model,
                       inputs=build_inputs,
                       train_steps=num_steps,
                       eval_steps=num_steps)

            # Second run in the same dir with a doubled step budget: it should
            # resume from the checkpoint rather than restart from scratch.
            final_state = trax.train(output_dir,
                                     model=build_model,
                                     inputs=build_inputs,
                                     train_steps=(2 * num_steps),
                                     eval_steps=num_steps)

            # The step counter reflects both runs combined.
            self.assertEqual(final_state.step, 2 * num_steps)
예제 #3
0
    def test_train_restart(self):
        """Restarting training in the same output dir resumes the step count."""
        with self.tmp_dir() as output_dir:
            num_classes = 4
            num_steps = 2
            build_model = functools.partial(models.MLP,
                                            d_hidden=16,
                                            n_output_classes=num_classes)

            def build_inputs(unused_n_devices):
                return test_inputs(num_classes)

            # First training run.
            trax.train(output_dir,
                       model=build_model,
                       inputs=build_inputs,
                       train_steps=num_steps,
                       eval_steps=num_steps)

            # Second run from the same checkpoint directory.
            state = trax.train(output_dir,
                               model=build_model,
                               inputs=build_inputs,
                               train_steps=num_steps,
                               eval_steps=num_steps)

            # The step counter accumulates across both runs.
            self.assertEqual(state.step, 2 * num_steps)
예제 #4
0
def main(argv):
    """Entry point: run the trax/JAX trainer if --jax, else the T2T trainer.

    Args:
        argv: command-line arguments remaining after flag parsing; argv[1:]
            is used to override hparams on the T2T path.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.jax:
        # Hacking main v1 flags to work with jax.
        config_strs = []
        config_strs.append("train.train_steps=" + str(FLAGS.train_steps))
        config_strs.append("train.eval_steps=" + str(FLAGS.eval_steps))
        config_strs.append("train.eval_frequency=" +
                           str(FLAGS.local_eval_frequency))
        if FLAGS.hparams:
            config_strs.extend(str(FLAGS.hparams).split(","))
        data_dir = os.path.expanduser(FLAGS.data_dir)
        output_dir = os.path.expanduser(FLAGS.output_dir)

        # Translate T2T-style flags into gin bindings.
        gin.bind_parameter("train.dataset", FLAGS.problem)
        config_strs += ["train.model=@" + FLAGS.model]
        config_files = []
        if FLAGS.hparams_set:
            config_files = [os.path.expanduser(FLAGS.hparams_set)]
        gin.parse_config_files_and_bindings(config_files, config_strs)
        trax.train(data_dir=data_dir, output_dir=output_dir)
        # The trax path handles everything; skip the T2T machinery below.
        return

    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # If we just have to print the registry, do that and exit early.
    maybe_log_registry_and_exit()

    # Create HParams.
    if argv:
        set_hparams_from_args(argv[1:])
    hparams = create_hparams()

    # MLPerf logging of run start and random seed, when applicable.
    if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
        mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
    if FLAGS.schedule == "run_std_server":
        run_std_server()
    mlperf_log.transformer_print(key=mlperf_log.RUN_SET_RANDOM_SEED,
                                 value=FLAGS.random_seed,
                                 hparams=hparams)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    # Launch on Cloud ML Engine instead of running locally, if requested.
    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        generate_data()

    # When running under Cloud ML Engine, write outputs to its job dir.
    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    # Build and execute the experiment according to --schedule.
    exp_fn = create_experiment_fn()
    exp = exp_fn(create_run_config(hparams), hparams)
    if is_chief():
        save_metadata(hparams)
    execute_schedule(exp)
    if FLAGS.schedule != "train":
        mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
예제 #5
0
def main(_):
  """Trainer entry point: configure TF from flags, parse gin, run trax.train.

  The positional argument is argv forwarded by absl.app and is unused.
  """
  logging.set_verbosity(FLAGS.log_level)

  # Optional TF execution modes, toggled by flags.
  if FLAGS.tf_eager:
    tf.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)

  # Grappler optimizer options for the TF graph/input pipeline.
  tf.config.optimizer.set_experimental_options(
      {"pin_to_host_optimization": FLAGS.tf_opt_pin_to_host}
  )

  tf.config.optimizer.set_experimental_options(
      {"layout_optimizer": FLAGS.tf_opt_layout}
  )

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=output_dir)
예제 #6
0
    def train_model(self):
        """Train the world model on all trajectory data dumped so far."""
        logging.info("Epoch %d: training model", self._simple_epoch)

        # Load data from all epochs.
        # TODO(pkozakowski): Handle the case when the data won't fit in the memory.
        (train_trajs, eval_trajs) = self._load_trajectories(
            self._trajectory_dump_root_dir)

        def train_stream():
            return self._data_stream(train_trajs, self._model_train_batch_size)

        def eval_stream():
            return self._data_stream(eval_trajs, self._model_train_batch_size)

        def make_inputs(unused_n_devices):
            # Ignore n_devices for now.
            return trax_inputs.Inputs(
                train_stream=train_stream,
                train_eval_stream=train_stream,
                eval_stream=eval_stream,
                input_shape=self._sim_env.model_input_shape,
                input_dtype=self._sim_env.model_input_dtype,
            )

        # Extend the cumulative step budget, then train up to it.
        self._model_train_step += self._n_model_train_steps
        trax.train(
            model=self._sim_env.model,
            inputs=make_inputs,
            train_steps=self._model_train_step,
            output_dir=self._model_dir,
            has_weights=True,
        )
예제 #7
0
def main(_):
  """Minimal trainer entry point: parse gin config and run trax.train."""
  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(output_dir=output_dir)
예제 #8
0
def main(_):
  """Trainer entry point: set log verbosity, parse gin, run trax.train."""
  logging.set_verbosity(FLAGS.log_level)

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(output_dir=output_dir)
예제 #9
0
def main(_):
  """Trainer entry point requiring explicit data and output directories.

  Raises:
    ValueError: if no data directory or no output directory is configured.
  """
  _setup_gin()

  # Setup directories. Raise explicit exceptions instead of using `assert`,
  # which is silently stripped when Python runs with -O.
  data_dir = FLAGS.data_dir
  output_dir = FLAGS.output_dir or _default_output_dir()
  if not data_dir:
    raise ValueError("Must specify a data directory")
  if not output_dir:
    raise ValueError("Must specify an output directory")
  trax.log("Using --output_dir %s" % output_dir)

  # Expand "~" so downstream file APIs receive usable paths.
  data_dir = os.path.expanduser(data_dir)
  output_dir = os.path.expanduser(output_dir)

  trax.train(data_dir=data_dir, output_dir=output_dir)
예제 #10
0
    def test_inits_policy_by_world_model_checkpoint(self):
        """Policy params can be initialized from a world-model checkpoint."""
        shared_kwargs = {
            "d_model": 1,
            "d_ff": 1,
            "n_layers": 1,
            "n_heads": 1,
            "max_len": 128,
            "mode": "train",
        }
        rng = jax_random.PRNGKey(123)
        output_dir = self.get_temp_dir()

        # Write a world-model checkpoint by running one training step.
        world_model_fn = functools.partial(
            models.TransformerLM, vocab_size=4, **shared_kwargs)
        trax.train(
            output_dir,
            model=world_model_fn,
            inputs=functools.partial(inputs.random_inputs,
                                     input_shape=(1, 1),
                                     output_shape=(1, 1)),
            train_steps=1,
            eval_steps=1,
        )

        # Build a policy net over the same transformer body.
        policy = ppo.policy_and_value_net(
            n_actions=3,
            n_controls=2,
            vocab_size=4,
            bottom_layers_fn=functools.partial(models.TransformerDecoder,
                                               **shared_kwargs),
            two_towers=False,
        )
        (policy_params, policy_state) = policy.initialize_once(
            input_shapes=(1, 1), input_dtype=np.int32, rng=rng)

        # Transfer the world-model weights into the policy...
        new_policy_params = ppo.init_policy_from_world_model_checkpoint(
            policy_params, output_dir)
        # ...and check that the policy still runs with the new parameters.
        observations = np.zeros((1, 100), dtype=np.int32)
        policy(observations,
               params=new_policy_params,
               state=policy_state,
               rng=rng)
예제 #11
0
def main(_):
  """Trainer entry point: parse gin, select TPU platform if flagged, train."""
  logging.set_verbosity(FLAGS.log_level)

  _setup_gin()

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trax.log("Using --output_dir %s" % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update("jax_platform_name", "tpu")

  trax.train(output_dir=output_dir)
예제 #12
0
    def test_train_eval_predict(self):
        """End-to-end: train, check eval history, then run prediction."""
        with self.tmp_dir() as output_dir:
            n_out = 4
            steps = 2
            model = functools.partial(models.MLP,
                                      hidden_size=16,
                                      num_output_classes=n_out)

            def make_inputs():
                return test_inputs(n_out)

            # Train and evaluate.
            state = trax.train(output_dir,
                               model=model,
                               inputs=make_inputs,
                               train_steps=steps,
                               eval_steps=steps)

            # The step counter matches the requested training budget.
            self.assertEqual(state.step, steps)

            # One eval per epoch => two accuracy entries on each stream.
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(eval_acc), 2)
            self.assertEqual(len(train_acc), len(eval_acc))

            # Prediction with the final parameters still works.
            _, predict_fun = model()
            stream = make_inputs().train_stream()
            predict_fun(state.params, next(stream)[0])
Example #13 — file: trax_test.py, project: tianhai123/-
    def test_train_with_weights(self, backend_name):
        """Training with per-example weights reaches the requested step."""
        if jax.lib.xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Tell the input pipeline that batches carry weights.
            gin.bind_parameter("unpack_batch.has_weights", True)

            num_classes = 4
            num_steps = 2
            mlp_fn = functools.partial(models.MLP,
                                       d_hidden=16,
                                       n_output_classes=num_classes)

            def weighted_inputs(unused_n_devices):
                return test_inputs(num_classes, with_weights=True)

            # Train and evaluate.
            final_state = trax.train(output_dir,
                                     model=mlp_fn,
                                     inputs=weighted_inputs,
                                     train_steps=num_steps,
                                     eval_steps=num_steps)

            # Step counter matches the training budget.
            self.assertEqual(final_state.step, num_steps)
예제 #14
0
    def test_train_eval_predict_sm3(self):
        """Train/eval with the SM3 optimizer, then predict with final params."""
        with self.tmp_dir() as output_dir:
            num_classes = 4
            num_steps = 2
            mlp_fn = functools.partial(models.MLP,
                                       d_hidden=16,
                                       n_output_classes=num_classes)

            def make_inputs(unused_n_devices):
                return test_inputs(num_classes)

            # Train and evaluate with SM3 instead of the default optimizer.
            state = trax.train(output_dir,
                               model=mlp_fn,
                               inputs=make_inputs,
                               train_steps=num_steps,
                               eval_steps=num_steps,
                               optimizer=trax_opt.SM3)

            # Step counter matches the training budget.
            self.assertEqual(state.step, num_steps)

            # Two evaluations ran, one per epoch, on both streams.
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(eval_acc), 2)
            self.assertEqual(len(train_acc), len(eval_acc))

            # Prediction with the final parameters still works.
            stream = make_inputs(1).train_stream()
            model = layers.Serial(mlp_fn())
            model(next(stream)[0], state.params[0])
예제 #15
0
    def test_train_eval_predict_sm3(self, backend_name):
        """Train/eval with SM3 on the given backend, then run prediction."""
        if xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            num_classes = 4
            num_steps = 2
            mlp_fn = functools.partial(models.MLP,
                                       d_hidden=16,
                                       n_output_classes=num_classes)

            def make_inputs(unused_n_devices):
                return test_inputs(num_classes)

            # Train and evaluate with the SM3 optimizer.
            state = trax.train(output_dir,
                               model=mlp_fn,
                               inputs=make_inputs,
                               train_steps=num_steps,
                               eval_steps=num_steps,
                               optimizer=trax_opt.SM3)

            # Step counter matches the training budget.
            self.assertEqual(state.step, num_steps)

            # Two evaluations ran, on both streams.
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Prediction with the final parameters still works.
            stream = make_inputs(1).train_stream()
            model = layers.Serial(mlp_fn())
            model(next(stream)[0], params=state.opt_state.params)
예제 #16
0
  def train_model(self):
    """Train the world model on the trajectories collected so far."""
    logging.info("Epoch %d: training model", self._epoch)

    def train_stream():
      return self._data_stream(
          self._train_trajectories, self._model_train_batch_size)

    def eval_stream():
      return self._data_stream(
          self._eval_trajectories, self._model_train_batch_size)

    def make_inputs(unused_n_devices):
      # Ignore n_devices for now.
      return trax_inputs.Inputs(
          train_stream=train_stream,
          train_eval_stream=train_stream,
          eval_stream=eval_stream,
          input_shape=self._sim_env.model_input_shape,
          input_dtype=self._sim_env.model_input_dtype,
      )

    trax.train(
        model=self._sim_env.model,
        inputs=make_inputs,
        output_dir=self._model_dir,
        has_weights=True,
    )
예제 #17
0
    def _test_train(self, train_args):
        """Run trax.train with `train_args` and verify step/eval bookkeeping."""
        with self.tmp_dir() as output_dir:
            state = trax.train(output_dir, **train_args)

            # Final step counter equals the requested number of steps.
            self.assertEqual(state.step, train_args["train_steps"])

            # Two epochs ran: two accuracy entries on each stream.
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(eval_acc), 2)
            self.assertEqual(len(train_acc), len(eval_acc))
예제 #18
0
    def train_model(self):
        """Train the world model for the current SimPLe epoch.

        Extends the cumulative train-step budget by the per-epoch increment
        and trains the simulated environment's model in self._model_dir.
        """
        logging.info("SimPLe epoch [% 6d]: training model.",
                     self._simple_epoch)

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
        )

        # Cumulative budget: presumably trax.train resumes from the latest
        # checkpoint in output_dir and runs until `train_steps` total steps
        # have been taken — TODO confirm against trax.train's contract.
        self._model_train_step += self._n_model_train_steps
        trax.train(
            model=self._sim_env.model,
            inputs=inputs,
            train_steps=self._model_train_step,
            output_dir=self._model_dir,
            has_weights=True,
        )
예제 #19
0
    def test_train_eval_predict(self, backend_name):
        """End-to-end train/eval/predict with stateful layers in the model."""
        if xla_bridge.device_count() > 1 and backend_name == "tf":
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            num_classes = 4
            num_steps = 2

            # Dropout and BatchNorm exercise the trainer's state handling.
            def stateful_model(mode="train"):
                return layers.Model(
                    layers.Dropout(mode=mode, rate=0.1),
                    layers.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=num_classes,
                               mode=mode))

            def make_inputs(unused_n_devices):
                return test_inputs(num_classes)

            # Train and evaluate.
            state = trax.train(output_dir,
                               model=stateful_model,
                               inputs=make_inputs,
                               train_steps=num_steps,
                               eval_steps=num_steps)

            # Step counter matches the training budget.
            self.assertEqual(state.step, num_steps)

            # Two evaluations ran, on both streams.
            train_acc = state.history.get("train", "metrics/accuracy")
            eval_acc = state.history.get("eval", "metrics/accuracy")
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Prediction with the final parameters still works.
            stream = make_inputs(1).train_stream()
            model = layers.Serial(stateful_model())
            model(next(stream)[0], params=state.opt_state.params)
예제 #20
0
    def train_model(self):
        """Train the model.

        Returns:
          bool: whether the training was skipped due to a restart, i.e. the
          checkpointed step was already past this epoch's cumulative target.
        """
        logging.info("SimPLe epoch [% 6d]: training model.",
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
            # TODO(lukaszkaiser): correct those, they may differ from inputs.
            target_shape=self._sim_env.model_input_shape,
            target_dtype=self._sim_env.model_input_dtype)

        # First epoch uses the initial budget; later epochs add the
        # per-epoch increment to the cumulative step target.
        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        # Scope gin bindings so world-model config doesn't leak elsewhere.
        with gin.config_scope("world_model"):
            state = trax.train(
                model=self._sim_env.model,
                inputs=inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, "Training model took %0.2f sec.",
                     time.time() - start_time)
        # If the restored step already exceeds the target, this call did not
        # actually train (restart case).
        return state.step > self._model_train_step
예제 #21
0
    def test_train_with_weights(self):
        """Training with weighted examples reaches the requested step count."""
        with self.tmp_dir() as output_dir:
            # Tell the input pipeline that batches carry per-example weights.
            gin.bind_parameter("unpack_batch.has_weights", True)

            num_classes = 4
            num_steps = 2
            mlp_fn = functools.partial(models.MLP,
                                       d_hidden=16,
                                       n_output_classes=num_classes)

            def weighted_inputs(unused_n_devices):
                return test_inputs(num_classes, with_weights=True)

            # Train and evaluate.
            final_state = trax.train(output_dir,
                                     model=mlp_fn,
                                     inputs=weighted_inputs,
                                     train_steps=num_steps,
                                     eval_steps=num_steps)

            # Step counter matches the training budget.
            self.assertEqual(final_state.step, num_steps)
예제 #22
0
def main(argv):
  """Entry point: run the trax/JAX trainer if --jax, else the T2T trainer.

  Args:
    argv: command-line arguments remaining after flag parsing; argv[1:] is
      used to override hparams on the T2T path.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.jax:
    # Setup trax FLAGS
    dataset = FLAGS.problem
    model = FLAGS.model
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir
    config_file = [FLAGS.hparams_set]
    config = [
        "train.train_steps=%d" % FLAGS.train_steps,
        "train.eval_steps=%d" % FLAGS.eval_steps,
        "train.eval_frequency=%d" % FLAGS.local_eval_frequency,
    ] + str(FLAGS.hparams).split(",")

    # Copied _setup_gin exactly from trax/trainer.py and removed "FLAGS."

    def _setup_gin():
      """Setup gin configuration."""
      # Imports for configurables
      # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable
      from tensor2tensor.trax import inputs as _trax_inputs
      from tensor2tensor.trax import models as _trax_models
      from tensor2tensor.trax import optimizers as _trax_opt
      # pylint: disable=g-import-not-at-top,unused-import,g-bad-import-order,reimported,unused-variable

      configs = config or []
      # Override with --dataset and --model
      # NOTE(review): the "[email protected]" bindings below look mangled (likely by
      # an email-obfuscating scraper); confirm the intended "name=@target"
      # gin bindings against the original trainer source.
      if dataset:
        configs.append("inputs.dataset_name='%s'" % dataset)
        configs.append("inputs.data_dir='%s'" % data_dir)
        configs.append("[email protected]")
      if model:
        configs.append("[email protected].%s" % model)
      gin.parse_config_files_and_bindings(config_file, configs)

    _setup_gin()
    trax.train(output_dir=output_dir)
    # Bug fix: without this return, a --jax run fell through and also
    # launched the full T2T training path below.
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  # MLPerf logging of run start and random seed, when applicable.
  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  # Launch on Cloud ML Engine instead of running locally, if requested.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  # When running under Cloud ML Engine, write outputs to its job dir.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  # Build and execute the experiment according to --schedule.
  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL,
                                 hparams=hparams)