Code Example #1
 def test_from_file(self):
     params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
     filename = self.create_tempfile('params.npy').full_path
     with open(filename, 'wb') as f:
         np.save(f, params)
     initializer = initializers.InitializerFromFile(filename)
     input_shape = (3, 2)
     init_value = initializer(input_shape, random.get_prng(0))
     # init_value may be a backend (JAX) array while params is a NumPy array,
     # so compare their printed representations rather than the objects.
     self.assertEqual('%s' % init_value, '%s' % params)
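The file the test feeds to InitializerFromFile is a plain NumPy .npy file; the on-disk round trip it relies on can be reproduced with NumPy alone (a minimal sketch, independent of trax):

import numpy as np

params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
with open('params.npy', 'wb') as f:
    np.save(f, params)              # write the array in .npy format
loaded = np.load('params.npy')      # the same data InitializerFromFile returns above
assert np.array_equal(loaded, params)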
Code Example #2
File: trainer_lib.py  Project: shivajid/trax
def init_random_number_generators(seed=None):
  """Initializes random generators for Python, NumPy, TensorFlow, and JAX."""
  # Seed Python random (None as seed is okay), then use it to seed the others.
  random.seed(seed)
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  numpy.random.seed(seed)
  tf.random.set_seed(seed)
  return jax_random.get_prng(seed)
Code Example #3
def get_random_number_generator_and_set_seed(seed=None):
    """Get a JAX random number generator and set random seed everywhere."""
    random.seed(seed)
    # Python's random accepts None as a seed (it then seeds from OS entropy),
    # but NumPy and TensorFlow expect an integer, so draw one here.
    if seed is None:
        seed = random.randint(0, 2**31 - 1)
    tf.random.set_seed(seed)
    numpy.random.seed(seed)
    return jax_random.get_prng(seed)
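Both helpers above do the same job: seed Python's random module, NumPy, and TensorFlow, and return a JAX PRNG key. A minimal usage sketch (assuming the jax_random alias used in these files is in scope):

# Seed every global generator and keep the returned JAX key.
rng = get_random_number_generator_and_set_seed(seed=17)
# Split the key per use, as the tests in the later examples do.
rng, key = jax_random.split(rng)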
Code Example #4
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=hidden_size,
                                     output_size=output_size)
     B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     # Initialize with a minimal (batch, time, obs) signature; the real batch
     # and time sizes come from the input below.
     _, _ = model.initialize_once((1, 1, OBS), onp.float32, key)
     x = onp.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)
Code Example #5
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(('eval', 'metrics/accuracy'), ),
         action_multipliers=(1.0, ),
         vocab_size=None,
 ):
     policy_and_value_model = functools.partial(
         transformer.TransformerDecoder,
         d_model=2,
         d_ff=2,
         n_layers=0,
         vocab_size=vocab_size,
     )
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     rng = jax_random.get_prng(seed=0)
     obs_dim = len(observation_metrics)
     if vocab_size is None:
         shape = (1, 1, obs_dim)
         dtype = np.float32
     else:
         shape = (1, 1)
         dtype = np.int32
     (params, state) = net.initialize_once(shape, dtype, rng)
     policy_dir = self.get_temp_dir()
     # Optimizer slots should not be used for anything.
     slots = None
     opt_state = (params, slots)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0,
                        history=history)
     return learning_rate.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_and_value_vocab_size=vocab_size,
         policy_dir=policy_dir,
     )
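A hypothetical call of the helper above, for orientation: the schedule it returns maps a step number to a dict keyed by the control names in control_configs (per the PolicySchedule docstring in Code Example #8):

# Hypothetical usage inside a test case; control_configs follows the
# (name, start, (low, high), flip) format shown in Code Example #8.
schedule = self._make_schedule(
    history,
    control_configs=(('learning_rate', 1e-3, (1e-9, 10.0), False),),
)
lr = schedule(1)['learning_rate']  # the step argument is ignored by the schedule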
Code Example #6
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.AtariCnn(hidden_sizes=hidden_size,
                                output_size=output_size)
     B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     # Initialize with a minimal (batch, time) + frame-shape signature.
     _, _ = model.initialize_once((1, 1) + OBS, onp.float32, key)
     x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
         B, T + 1, *OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)
Code Example #7
 def _make_schedule(
         self,
         history,
         control_configs,
         observation_metrics=(("eval", "metrics/accuracy"), ),
         action_multipliers=(1.0, ),
 ):
     policy_and_value_model = atari_cnn.FrameStackMLP
     net = ppo.policy_and_value_net(
         n_actions=len(action_multipliers),
         n_controls=len(control_configs),
         vocab_size=None,
         bottom_layers_fn=policy_and_value_model,
         two_towers=False,
     )
     rng = jax_random.get_prng(seed=0)
     obs_dim = len(observation_metrics)
     (params, state) = net.initialize_once((1, 1, obs_dim), np.float32, rng)
     policy_dir = self.get_temp_dir()
     # Optimizer slots should not be used for anything.
     slots = None
     opt_state = (params, slots)
     ppo.save_opt_state(policy_dir,
                        opt_state,
                        state,
                        epoch=0,
                        total_opt_step=0)
     return learning_rate.PolicySchedule(
         history,
         observation_metrics=observation_metrics,
         include_controls_in_observation=False,
         action_multipliers=action_multipliers,
         control_configs=control_configs,
         policy_and_value_model=policy_and_value_model,
         policy_and_value_two_towers=False,
         policy_dir=policy_dir,
     )
Code Example #8
def PolicySchedule(
    history,
    observation_metrics=(
        ('train', 'metrics/accuracy'),
        ('train', 'metrics/loss'),
        ('eval', 'metrics/accuracy'),
        ('eval', 'metrics/loss'),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ('learning_rate', 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

    Args:
      history: the history of training and evaluation (History object).
      observation_metrics: list of pairs (mode, metric), as in the History
        object.
      include_controls_in_observation: bool, whether to include the controls
        in observations.
      control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
      observation_range: tuple (low, high), range to clip the metrics to.
      action_multipliers: sequence of LR multipliers that policy actions
        correspond to.
      policy_and_value_model: Trax model to use as the policy.
      policy_and_value_two_towers: bool, whether the action distribution and
        value prediction are computed by separate model towers.
      policy_and_value_vocab_size: vocabulary size of a policy and value
        network operating on serialized representation. If None, use raw
        continuous representation.
      policy_dir: directory with the policy checkpoint.
      temperature: temperature for sampling from the policy.

    Returns:
      a function nontrainable_params(step): float -> {'name': float}, the
      step-dependent schedule for nontrainable parameters.
    """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, 'Building observations took %0.2f sec.',
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, 'Building the policy network took %0.2f sec.',
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step, history)
    (opt_state, state, _, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, 'Policy checkpoint not found.'
    (params, _, _) = opt_state
    logging.vlog(1, 'Restoring the policy parameters took %0.2f sec.',
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()

    (low, high) = observation_range
    observation_space = gym.spaces.Box(shape=observations.shape[1:],
                                       low=low,
                                       high=high)
    action_space = gym.spaces.MultiDiscrete(nvec=(len(action_multipliers), ) *
                                            len(control_configs))
    n_timesteps = observations.shape[0]
    rewards_to_actions = ppo.init_rewards_to_actions(
        policy_and_value_vocab_size, observation_space, action_space,
        n_timesteps)
    # (log_probs, value_preds, state, rng)
    (log_probs, _, _, _) = ppo.run_policy(
        policy_and_value_net_apply=net,
        observations=np.array([observations]),
        lengths=np.array([n_timesteps]),
        weights=params,
        state=state,
        rng=rng,
        vocab_size=policy_and_value_vocab_size,
        observation_space=observation_space,
        action_space=action_space,
        rewards_to_actions=rewards_to_actions,
    )

    logging.vlog(1, 'Running the policy took %0.2f sec.',
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs),
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0] / temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
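Since policy_dir is gin.REQUIRED, PolicySchedule is normally configured through gin. A hypothetical plain-Python equivalent binds the required arguments with functools.partial (the checkpoint path and multipliers below are illustrative only):

# Hypothetical wiring without gin.
lr_schedule_fn = functools.partial(
    PolicySchedule,
    policy_dir='/tmp/policy_checkpoints',        # illustrative checkpoint dir
    action_multipliers=(1.0 / 1.25, 1.0, 1.25),  # illustrative multipliers
)
schedule = lr_schedule_fn(history)
controls = schedule(0)  # dict mapping control names to values, e.g. 'learning_rate'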
Code Example #9
 def seed(self, seed=None):
     if seed is None:
         seed = random.randint(0, 2**31 - 1)
     self._rng = jax_random.get_prng(seed)
     return super(SimulatedEnvProblem, self).seed(seed=seed)
Code Example #10
 def test_orthogonal(self):
     initializer = initializers.OrthogonalInitializer()
     input_shape = (29, 5, 7, 20)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual(tuple(init_value.shape), input_shape)
Code Example #11
 def test_kaiming_uniform(self):
     initializer = initializers.KaimingUniformInitializer()
     input_shape = (29, 5, 7, 20)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual(tuple(init_value.shape), input_shape)
Code Example #12
 def test_lecun_normal(self):
     initializer = initializers.LeCunNormalInitializer()
     input_shape = (29, 5, 7, 20)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual(tuple(init_value.shape), input_shape)
Code Example #13
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

    Args:
      history: the history of training and evaluation (History object).
      observation_metrics: list of pairs (mode, metric), as in the History
        object.
      include_controls_in_observation: bool, whether to include the controls
        in observations.
      control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
      observation_range: tuple (low, high), range to clip the metrics to.
      action_multipliers: sequence of LR multipliers that policy actions
        correspond to.
      policy_and_value_model: Trax model to use as the policy.
      policy_and_value_two_towers: bool, whether the action distribution and
        value prediction are computed by separate model towers.
      policy_and_value_vocab_size: vocabulary size of a policy and value
        network operating on serialized representation. If None, use raw
        continuous representation.
      policy_dir: directory with the policy checkpoint.
      temperature: temperature for sampling from the policy.

    Returns:
      a function nontrainable_params(step): float -> {"name": float}, the
      step-dependent schedule for nontrainable parameters.
    """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
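Both PolicySchedule variants draw the action with utils.gumbel_sample(log_probs / temperature). As a rough illustration of that step (an assumption about its behaviour, not the trax implementation), the Gumbel-max trick samples from the categorical distribution defined by the log-probabilities:

import numpy as onp

def gumbel_sample_sketch(log_probs, rng=onp.random):
    # Adding standard Gumbel noise to log-probabilities and taking the argmax
    # over the last axis is equivalent to sampling from softmax(log_probs).
    u = rng.uniform(low=1e-9, high=1.0 - 1e-9, size=log_probs.shape)
    gumbel = -onp.log(-onp.log(u))
    return onp.argmax(log_probs + gumbel, axis=-1)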