Code Example #1
File: simple_trainer.py  Project: zhaoqiuye/trax
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = trax_inputs.Inputs(train_stream=(lambda _: train_stream),
                                    eval_stream=(lambda _: eval_stream))
        (obs, act, _, _) = next(train_stream)
        # TODO(pkozakowski): Refactor Inputs so this can be inferred correctly.
        inputs._input_shape = (tuple(obs.shape)[1:], tuple(act.shape)[1:])  # pylint: disable=protected-access
        inputs._input_dtype = (obs.dtype, act.dtype)  # pylint: disable=protected-access

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                steps=self._model_train_step,
                output_dir=self._model_dir,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
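The returned flag relies on trainer_lib.train restoring any existing checkpoint from output_dir: if the restored step counter already exceeds the step budget requested for this epoch, no new training steps were run. A minimal sketch of that check, with hypothetical numbers chosen only for illustration:

    # Hypothetical values, for illustration only: a restored checkpoint that is
    # already past the requested step budget means this epoch's training was a
    # no-op, so train_model reports it as skipped.
    requested_step = 1000  # self._model_train_step after adding this epoch's steps
    restored_step = 1500   # state.step after trainer_lib.train restores a checkpoint
    skipped = restored_step > requested_step  # True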
Code Example #2
File: simple_trainer.py  Project: zzszmyf/trax
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = trax_inputs.Inputs(train_stream=(lambda _: train_stream),
                                    eval_stream=(lambda _: eval_stream))

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
Code Example #3
def _test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)):
    """Make trainer_lib.inputs.Inputs."""
    batch_size = 2 * xla_bridge.device_count()

    def input_stream(n_devices):
        del n_devices
        key = fastmath.random.get_prng(0)
        while True:
            keys = fastmath.random.split(key, 4)
            key = keys[0]
            inputs = fastmath.random.uniform(keys[1],
                                             [batch_size] + list(input_shape))
            targets = fastmath.random.randint(keys[2], [batch_size],
                                              dtype=jnp.int32,
                                              minval=0,
                                              maxval=n_classes)
            weights = fastmath.random.uniform(keys[3], [batch_size])
            if with_weights:
                yield inputs, targets, weights
            else:
                yield inputs, targets

    def input_stream_masked(n_devices):
        return inputs_lib.add_loss_weights(input_stream(n_devices))

    return inputs_lib.Inputs(input_stream_masked)
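A minimal usage sketch for the helper above, assuming the stream-accessor convention visible in the other examples (calling inputs.train_stream(n_devices) returns a batch generator); the argument values and the unpacking into weighted triples are illustrative assumptions:

    # Illustrative only: pull one weighted batch from the test inputs.
    test_inputs = _test_inputs(n_classes=10, with_weights=True)
    batch_stream = test_inputs.train_stream(1)  # n_devices=1
    inputs_batch, targets, weights = next(batch_stream)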
Code Example #4
def test_inputs(n_classes, with_weights=False, input_shape=(6, 6, 3)):
    """Make trainer_lib.inputs.Inputs."""
    batch_size = 2 * xla_bridge.device_count()

    def input_stream():
        key = math.random.get_prng(0)
        while True:
            keys = math.random.split(key, 4)
            key = keys[0]
            inputs = math.random.uniform(keys[1],
                                         [batch_size] + list(input_shape))
            targets = math.random.randint(keys[2], [batch_size],
                                          dtype=np.int32,
                                          minval=0,
                                          maxval=n_classes)
            weights = math.random.uniform(keys[3], [batch_size])
            if with_weights:
                yield inputs, targets, weights
            else:
                yield inputs, targets

    return inputs_lib.Inputs(train_stream=input_stream,
                             train_eval_stream=input_stream,
                             eval_stream=input_stream,
                             input_shape=input_shape,
                             input_dtype=np.float32,
                             target_shape=(),
                             target_dtype=np.int32)
Code Example #5
    def test_training_loop_cartpole_serialized_init_from_world_model(
            self, two_towers):
        gin.bind_parameter('BoxSpaceSerializer.precision', 1)

        transformer_kwargs = {
            'd_model': 1,
            'd_ff': 1,
            'n_layers': 1,
            'n_heads': 1,
            'max_len': 128,
        }
        obs_serializer = space_serializer.create(
            gym.spaces.MultiDiscrete([2, 2]), vocab_size=4)
        act_serializer = space_serializer.create(gym.spaces.Discrete(2),
                                                 vocab_size=4)
        model_fn = lambda mode: serialization_utils.SerializedModel(  # pylint: disable=g-long-lambda
            seq_model=models.TransformerLM(
                mode=mode, vocab_size=4, **transformer_kwargs),
            observation_serializer=obs_serializer,
            action_serializer=act_serializer,
            significance_decay=0.9,
        )
        with self.tmp_dir() as output_dir:
            model_dir = os.path.join(output_dir, 'model')

            def dummy_stream(_):
                while True:
                    obs = np.zeros((1, 2, 2), dtype=np.int32)
                    act = np.zeros((1, 1), dtype=np.int32)
                    mask = np.ones_like(obs)
                    yield (obs, act, obs, mask)

            inputs = trax_inputs.Inputs(train_stream=dummy_stream,
                                        eval_stream=dummy_stream)
            inputs._input_shape = ((2, 2), (1, ))  # pylint: disable=protected-access
            inputs._input_dtype = (np.int32, np.int32)  # pylint: disable=protected-access

            # Initialize a world model checkpoint by running the trainer.
            trainer_lib.train(
                model_dir,
                model=model_fn,
                inputs=inputs,
                steps=1,
                eval_steps=1,
                has_weights=True,
            )

            policy_dir = os.path.join(output_dir, 'policy')
            trainer = self._make_trainer(
                train_env=self.get_wrapped_env('CartPole-v0', 2),
                eval_env=self.get_wrapped_env('CartPole-v0', 2),
                output_dir=policy_dir,
                model=functools.partial(models.TransformerDecoder,
                                        **transformer_kwargs),
                policy_and_value_vocab_size=4,
                init_policy_from_world_model_output_dir=model_dir,
                policy_and_value_two_towers=two_towers,
            )
            trainer.training_loop(n_epochs=2)
Code Example #6
def _add_weights(trax_inputs):
    """Add weights to inputs."""
    def _weight_stream(input_stream):
        """Add weights to the given stream."""
        for example in input_stream:
            inp, targets = example
            weights = np.ones_like(targets).astype(np.float32)
            yield (inp, targets, weights)

    return inputs.Inputs(
        train_stream=lambda n: _weight_stream(trax_inputs.train_stream(n)),
        eval_stream=lambda n: _weight_stream(trax_inputs.eval_stream(n)),
        train_eval_stream=lambda n: _weight_stream(  # pylint: disable=g-long-lambda
            trax_inputs.train_eval_stream(n)))
Code Example #7
    def inputs(n_devices):
        del n_devices
        stream = itertools.repeat(
            (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
             np.zeros(obs_shape), np.zeros(reward_shape)))
        inp = trax_inputs.Inputs(train_stream=lambda: stream,
                                 train_eval_stream=lambda: stream,
                                 eval_stream=lambda: stream)
        inp._input_shape = (history_shape[1:], action_shape[1:])
        inp._input_dtype = (np.float32, np.int32)
        inp._target_shape = (obs_shape[1:], reward_shape[1:])
        inp._target_dtype = (np.float32, np.float32)
        return inp
Code Example #8
    def train_model(self):
        """Train the model.

    Returns:
      whether the training was skipped due to a restart.
    """
        logging.info('SimPLe epoch [% 6d]: training model.',
                     self._simple_epoch)
        start_time = time.time()

        (train_stream, eval_stream) = self._make_input_streams()
        # Ignore n_devices for now.
        inputs = lambda _: trax_inputs.Inputs(  # pylint: disable=g-long-lambda
            train_stream=(lambda: train_stream),
            train_eval_stream=(lambda: train_stream),
            eval_stream=(lambda: eval_stream),
            input_shape=self._sim_env.model_input_shape,
            input_dtype=self._sim_env.model_input_dtype,
            # TODO(lukaszkaiser): correct those, they may differ from inputs.
            target_shape=self._sim_env.model_input_shape,
            target_dtype=self._sim_env.model_input_dtype)

        if self._simple_epoch == 0:
            train_steps = self._n_model_initial_train_steps
        else:
            train_steps = self._n_model_train_steps_per_epoch
        self._model_train_step += train_steps
        with gin.config_scope('world_model'):
            state = trainer_lib.train(
                model=self._sim_env.model,
                inputs=inputs,
                train_steps=self._model_train_step,
                output_dir=self._model_dir,
                has_weights=True,
            )

        logging.vlog(1, 'Training model took %0.2f sec.',
                     time.time() - start_time)
        return state.step > self._model_train_step
Code Example #9
File: trainer_lib.py  Project: victorustc/trax
def _add_weights_and_mask(inputs, id_to_mask):
    """Add weights to inputs without weights and masks by id if requested.

  Each of the (train, eval, train_eval) streams of inputs is augmented in
  the following way:
  * if the stream consists of pairs (inputs, targets), a loss mask is added
    that is creates as a tensor of ones of the same shape as targets
  * if id_to_mask is not None, and the stream (after the previous point) has
    triples (inputs, targets, weights), the weights are multipled by a 0/1 mask
    that is 0 iff targets is equal to id_to_mask (1 otherwise).

  Args:
    inputs: a trax_inputs.Inputs object to operate on
    id_to_mask: int or None, id to pad in targets if not None

  Returns:
    a trax_inputs.Inputs object with augmented streams
  """
    def _with_masks(input_stream):
        """Create masks for the given stream."""
        for example in input_stream:
            if len(example) > 3 or len(example) < 2:
                assert id_to_mask is None, 'Cannot automatically mask this stream.'
                yield example
            else:
                if len(example) == 2:
                    weights = np.ones_like(example[1]).astype(np.float32)
                else:
                    weights = example[2].astype(np.float32)
                mask = 1.0 - np.equal(example[1], id_to_mask).astype(
                    np.float32)
                weights *= mask
                yield (example[0], example[1], weights)

    return trax_inputs.Inputs(
        train_stream=lambda n: _with_masks(inputs.train_stream(n)),
        eval_stream=lambda n: _with_masks(inputs.eval_stream(n)),
        train_eval_stream=lambda n: _with_masks(inputs.train_eval_stream(n)))
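As a small numeric illustration of the masking step above (values chosen purely for illustration), a padding id of 0 zeroes out the weight of every padded target position:

    # Illustration only: the 0/1 mask computed in _with_masks for id_to_mask=0.
    import numpy as np
    targets = np.array([3, 0, 5])
    mask = 1.0 - np.equal(targets, 0).astype(np.float32)  # -> [1., 0., 1.]
    weights = np.ones_like(targets).astype(np.float32) * mask  # -> [1., 0., 1.]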
Code Example #10
File: ppo_trainer_test.py  Project: zzszmyf/trax
    def test_training_loop_simulated(self):
        n_actions = 5
        history_shape = (3, 2, 3)
        action_shape = (3, )
        obs_shape = (3, 3)
        reward_shape = (3, 1)

        def model(mode):
            del mode
            return layers.Serial(
                layers.Parallel(
                    layers.Flatten(),  # Observation stack.
                    layers.Embedding(d_feature=1,
                                     vocab_size=n_actions),  # Action.
                ),
                layers.Concatenate(),
                layers.Dense(n_units=1),
                layers.Dup(),
                layers.Parallel(
                    layers.Dense(n_units=obs_shape[1]),  # New observation.
                    None,  # Reward.
                ))

        stream = itertools.repeat(
            (np.zeros(history_shape), np.zeros(action_shape, dtype=np.int32),
             np.zeros(obs_shape), np.zeros(reward_shape)))
        inp = trax_inputs.Inputs(lambda _: stream)
        inp._input_shape = (history_shape[1:], action_shape[1:])
        inp._input_dtype = (np.float32, np.int32)
        inp._target_shape = (obs_shape[1:], reward_shape[1:])
        inp._target_dtype = (np.float32, np.float32)
        inputs = inp

        def loss(id_to_mask=None, has_weights=False):
            """Cross-entropy loss as scalar compatible with Trax masking."""
            return layers.Serial(
                # Swap from (pred-obs, pred-reward, target-obs, target-reward)
                # to (pred-obs, target-obs, pred-reward, target-reward).
                layers.Parallel([], layers.Swap()),
                # Cross-entropy loss for obs, L2 loss on reward.
                layers.Parallel(
                    layers.CrossEntropyLoss(id_to_mask, has_weights),
                    layers.L2Loss(id_to_mask, has_weights)),
                # Add both losses.
                layers.Add(),
                # Zero out in this test.
                layers.Fn(lambda x: x * 0.0),
            )

        with self.tmp_dir() as output_dir:
            # Run fake training just to save the parameters.
            trainer = trainer_lib.Trainer(
                model=model,
                loss_fn=loss,
                inputs=inputs,
                optimizer=trax_opt.SM3,
                lr_schedule=lr.MultifactorSchedule,
                output_dir=output_dir,
            )
            trainer.train_epoch(n_steps=1, n_eval_steps=1)

            # Repeat the history over and over again.
            stream = itertools.repeat(np.zeros(history_shape))
            env_fn = functools.partial(
                simulated_env_problem.RawSimulatedEnvProblem,
                model=model,
                history_length=history_shape[1],
                trajectory_length=3,
                batch_size=history_shape[0],
                observation_space=gym.spaces.Box(low=-np.inf,
                                                 high=np.inf,
                                                 shape=(obs_shape[1], )),
                action_space=gym.spaces.Discrete(n=n_actions),
                reward_range=(-1, 1),
                discrete_rewards=False,
                history_stream=stream,
                output_dir=output_dir,
            )

            trainer = self._make_trainer(
                train_env=env_fn(),
                eval_env=env_fn(),
                output_dir=output_dir,
            )
            trainer.training_loop(n_epochs=2)