Example #1
def main(_):

    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.enable_eager_execution:
        tf.enable_eager_execution()

    if FLAGS.tf_xla:
        tf.config.optimizer.set_jit(True)

    tf.config.optimizer.set_experimental_options(
        {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host})

    tf.config.optimizer.set_experimental_options(
        {'layout_optimizer': FLAGS.tf_opt_layout})

    _setup_gin()

    if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
        # Numpy backend doesn't benefit from having the input pipeline run on GPU,
        # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
        # set up first before determining the backend.
        tf.config.experimental.set_visible_devices([], 'GPU')

    # Setup output directory
    output_dir = FLAGS.output_dir or _default_output_dir()
    trainer_lib.log('Using --output_dir %s' % output_dir)
    output_dir = os.path.expanduser(output_dir)

    # If on TPU, let JAX know.
    if FLAGS.use_tpu:
        jax.config.update('jax_platform_name', 'tpu')

    trainer_lib.train(output_dir=output_dir)
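For context, a main like this is normally launched through absl's app.run, with the flags it reads (log_level, enable_eager_execution, output_dir, use_tpu, ...) defined via absl.flags in the same module. A minimal launch sketch, assuming only a small subset of those flag definitions:

# Minimal launch sketch (assumption: the real module defines many more flags,
# e.g. enable_eager_execution, tf_xla, use_tpu).
from absl import app, flags, logging

FLAGS = flags.FLAGS
flags.DEFINE_string('output_dir', None, 'Directory for checkpoints and logs.')
flags.DEFINE_string('log_level', 'INFO', 'absl logging verbosity.')

if __name__ == '__main__':
    app.run(main)  # main is the function defined above.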
Example #2
    def test_train_restart(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs,
                              train_steps=train_steps,
                              eval_steps=eval_steps)

            # Restart training
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      train_steps=(2 * train_steps),
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(state.step, 2 * train_steps)
Example #3
    def test_inits_policy_by_world_model_checkpoint(self):
        transformer_kwargs = {
            "d_model": 1,
            "d_ff": 1,
            "n_layers": 1,
            "n_heads": 1,
            "max_len": 128,
            "mode": "train",
        }
        rng = jax_random.PRNGKey(123)
        init_kwargs = {
            "input_shapes": (1, 1),
            "input_dtype": np.int32,
            "rng": rng,
        }
        model_fn = functools.partial(models.TransformerLM,
                                     vocab_size=4,
                                     **transformer_kwargs)
        output_dir = self.get_temp_dir()
        # Initialize a world model checkpoint by running the trainer.
        trainer_lib.train(
            output_dir,
            model=model_fn,
            inputs=functools.partial(inputs.random_inputs,
                                     input_shape=(1, 1),
                                     output_shape=(1, 1)),
            train_steps=1,
            eval_steps=1,
        )

        policy = ppo.policy_and_value_net(
            n_actions=3,
            n_controls=2,
            vocab_size=4,
            bottom_layers_fn=functools.partial(models.TransformerDecoder,
                                               **transformer_kwargs),
            two_towers=False,
        )
        (policy_params, policy_state) = policy.initialize_once(**init_kwargs)

        # Initialize policy parameters from world model parameters.
        new_policy_params = ppo.init_policy_from_world_model_checkpoint(
            policy_params, output_dir)
        # Try to run the policy with new parameters.
        observations = np.zeros((1, 100), dtype=np.int32)
        policy(observations,
               params=new_policy_params,
               state=policy_state,
               rng=rng)
Example #4
    def test_train_eval_predict_sm3(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest(
                "tf-numpy backend doesn't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2
            model_fn = functools.partial(models.MLP,
                                         d_hidden=16,
                                         n_output_classes=n_classes)
            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      train_steps=train_steps,
                                      eval_steps=eval_steps,
                                      optimizer=trax_opt.SM3)

            # Assert total train steps
            self.assertEqual(train_steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final params
            inputs = inputs(1).train_stream()
            model = layers.Serial(model_fn())
            model(next(inputs)[0], params=state.opt_state.params)
Example #5
    def test_inits_policy_by_world_model_checkpoint(self):
        transformer_kwargs = {
            'd_model': 1,
            'd_ff': 1,
            'n_layers': 1,
            'n_heads': 1,
            'max_len': 128,
            'mode': 'train',
        }
        rng = jax_random.PRNGKey(123)
        model_fn = functools.partial(models.TransformerLM,
                                     vocab_size=4,
                                     **transformer_kwargs)
        output_dir = self.get_temp_dir()
        # Initialize a world model checkpoint by running the trainer.
        trainer_lib.train(
            output_dir,
            model=model_fn,
            inputs=functools.partial(inputs.random_inputs,
                                     input_shape=(1, 1),
                                     output_shape=(1, 1)),
            train_steps=1,
            eval_steps=1,
        )

        policy = ppo.policy_and_value_net(
            n_actions=3,
            n_controls=2,
            vocab_size=4,
            bottom_layers_fn=functools.partial(models.TransformerDecoder,
                                               **transformer_kwargs),
            two_towers=False,
        )
        input_signature = ShapeDtype((1, 1), np.int32)
        policy._set_rng(rng)
        (policy_params, policy_state) = policy.initialize_once(input_signature)

        # Initialize policy parameters from world model parameters.
        new_policy_params = ppo.init_policy_from_world_model_checkpoint(
            policy_params, output_dir)
        # Try to run the policy with new parameters.
        observations = np.zeros((1, 100), dtype=np.int32)
        policy(observations,
               weights=new_policy_params,
               state=policy_state,
               rng=rng)
Example #6
def main(_):

  logging.set_verbosity(FLAGS.log_level)

  if FLAGS.enable_eager_execution:
    tf.compat.v1.enable_eager_execution()

  if FLAGS.tf_xla:
    tf.config.optimizer.set_jit(True)
    backend.set_tf_xla_forced_compile(FLAGS.tf_xla_forced_compile)

  tf.config.optimizer.set_experimental_options(
      {'pin_to_host_optimization': FLAGS.tf_opt_pin_to_host}
  )

  tf.config.optimizer.set_experimental_options(
      {'layout_optimizer': FLAGS.tf_opt_layout}
  )

  set_tf_allow_float64(FLAGS.tf_allow_float64)

  _setup_gin()

  if FLAGS.enable_eager_execution and backend.get_name() in ('numpy', 'jax'):
    # Numpy backend doesn't benefit from having the input pipeline run on GPU,
    # and jax backend has GPU memory contention if TF uses the GPU. Gin must be
    # set up first before determining the backend.
    tf.config.experimental.set_visible_devices([], 'GPU')

  # Setup output directory
  output_dir = FLAGS.output_dir or _default_output_dir()
  trainer_lib.log('Using --output_dir %s' % output_dir)
  output_dir = os.path.expanduser(output_dir)

  # If on TPU, let JAX know.
  if FLAGS.use_tpu:
    jax.config.update('jax_platform_name', 'tpu')
    jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
    jax.config.update('jax_backend_target', FLAGS.jax_backend_target)

  if FLAGS.use_tpu and backend.get_name() == 'tf':
    worker_cpu = tf_init_tpu()
    with tf.device(worker_cpu):
      if trainer_lib.num_devices() == 1:
        # TF's device priority is GPU > CPU > TPU, so we need to explicitly make
        # the TPU core the default device here.
        with tf.device('/device:TPU:0'):
          trainer_lib.train(output_dir=output_dir)
      else:
        trainer_lib.train(output_dir=output_dir)
  else:
    trainer_lib.train(output_dir=output_dir)

  trainer_lib.log('Finished training.')
Example #7
    def _test_train_eval_predict(self, backend_name):
        if xla_bridge.device_count() > 1 and backend_name == 'tf':
            self.skipTest("tf-numpy backend does't support multi-devices yet.")
        with backend.use_backend(backend_name), self.tmp_dir() as output_dir:
            # Prepare model and inputs
            n_classes = 4
            train_steps = 2
            eval_steps = 2

            # Adds Dropout and BatchNorm to test state handling.
            def model_fn(mode='train'):
                return layers.Serial(
                    layers.Dropout(mode=mode, rate=0.1),
                    layers.BatchNorm(mode=mode),
                    models.MLP(d_hidden=16,
                               n_output_classes=n_classes,
                               mode=mode))

            inputs = lambda _: test_inputs(n_classes)

            # Train and evaluate
            state = trainer_lib.train(output_dir,
                                      model=model_fn,
                                      inputs=inputs,
                                      train_steps=train_steps,
                                      eval_steps=eval_steps)

            # Assert total train steps
            self.assertEqual(train_steps, state.step)

            # Assert 2 evaluations ran
            train_acc = state.history.get('train', 'metrics/accuracy')
            eval_acc = state.history.get('eval', 'metrics/accuracy')
            self.assertEqual(len(train_acc), len(eval_acc))
            self.assertLen(eval_acc, 2)

            # Predict with final params
            inputs = inputs(1).train_stream()
            model = layers.Serial(model_fn())
            model(next(inputs)[0], params=state.opt_state.params)
Example #8
    def train_model(self):
        """Train the model.

        Returns:
          whether the training was skipped due to a restart.
        """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
            # TODO(lukaszkaiser): correct those, they may differ from inputs.
            target_shape=self._sim_env.model_input_shape,
            target_dtype=self._sim_env.model_input_dtype)

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
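The examples above all share the same call shape: build a model constructor with functools.partial, provide an inputs callable, pass both to trainer_lib.train together with step counts, and read step and history off the returned state. A condensed sketch of that pattern, assuming the import paths and MLP keyword names used in the snippets on this page (they vary across Trax releases) and an Inputs-producing helper comparable to test_inputs:

# Condensed sketch of the pattern shared by the examples above.
# Import paths and the MLP keyword names are assumptions taken from this
# page's snippets; they differ between Trax versions.
import functools

from trax import models
from trax.supervised import trainer_lib

def run_short_training(output_dir, inputs_fn, n_classes=4):
    model_fn = functools.partial(models.MLP, d_hidden=16,
                                 n_output_classes=n_classes)
    state = trainer_lib.train(output_dir,
                              model=model_fn,
                              inputs=inputs_fn,  # callable: n_devices -> Inputs
                              train_steps=2,
                              eval_steps=2)
    # The returned state exposes the global step and the metric history.
    assert state.step == 2
    eval_acc = state.history.get('eval', 'metrics/accuracy')
    return state, eval_acc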