Code example #1
 def _make_schedule(
     self,
     history,
     start_lr=1e-3,
     observation_metrics=(("eval", "metrics/accuracy"),),
     action_multipliers=(1.0,),
 ):
   policy_and_value_model = atari_cnn.FrameStackMLP
   net = ppo.policy_and_value_net(
       n_actions=len(action_multipliers),
       n_controls=1,
       bottom_layers_fn=policy_and_value_model,
       two_towers=False,
   )
   rng = jax_random.get_prng(seed=0)
   obs_dim = len(observation_metrics)
   (params, state) = net.initialize((1, 1, obs_dim), np.float32, rng)
   policy_dir = self.get_temp_dir()
   # Optimizer slots should not be used for anything.
   slots = None
   opt_state = (params, slots)
   ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
   return learning_rate.PolicySchedule(
       history,
       observation_metrics=observation_metrics,
       include_lr_in_observation=False,
       action_multipliers=action_multipliers,
       start_lr=start_lr,
       policy_and_value_model=policy_and_value_model,
       policy_and_value_two_towers=False,
       policy_dir=policy_dir,
   )
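A test would feed this helper a History object and query the returned schedule. A minimal sketch, assuming the History constructor from the same codebase (trax_history.History); with no recorded metrics, PolicySchedule falls back to the start learning rate (see code example #16 below):

    # Sketch only; trax_history.History() is assumed from the surrounding codebase.
    history = trax_history.History()           # no metrics recorded yet
    schedule = self._make_schedule(history, start_lr=1e-3)
    self.assertEqual(1e-3, schedule(0))        # no observations -> start_lr is returned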
Code example #2
def check_shape_agreement(test_case, init_fun, apply_fun, input_shape):
    rng_key1, rng_key2 = random.split(random.get_prng(0))
    result_shape, params = init_fun(rng_key1, input_shape)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)
    result = apply_fun(params, inputs, rng=rng_key2)
    test_case.assertEqual(result.shape, result_shape)
    return result_shape
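Callers unpack a stax-style (init_fun, apply_fun) pair into this helper; a hedged sketch inside a test method, with the layer and shapes chosen only for illustration:

    # Illustrative call; any stax-style (init_fun, apply_fun) pair works the same way.
    init_fun, apply_fun = stax.Dense(16)
    check_shape_agreement(self, init_fun, apply_fun, (4, 8))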
Code example #3
File: base_test.py  Project: yawenz/tensor2tensor
def check_shape_agreement(test_case, layer, input_shape):
    rng_key1, rng_key2 = random.split(random.get_prng(0))
    result_shape = layer.output_shape(input_shape)
    params = layer.initialize(input_shape, rng_key1)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)
    result = layer(inputs, params, rng=rng_key2)
    test_case.assertEqual(result.shape, result_shape)
    return result_shape
Code example #4
def get_random_number_generator_and_set_seed(seed=None):
    """Get a JAX random number generator and set random seed everywhere."""
    random.seed(seed)
    # While python random accepts None as seed and uses time/os seed then,
    # some other functions expect integers so we create one here.
    if seed is None:
        seed = random.randint(0, 2**31 - 1)
    tf.set_random_seed(seed)
    numpy.random.seed(seed)
    return jax_random.get_prng(seed)
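The helper returns a JAX PRNG key, so call sites typically split it further before initializing models; a brief sketch:

    # Sketch: seed Python, NumPy and TensorFlow once, then derive JAX subkeys.
    rng = get_random_number_generator_and_set_seed(17)
    rng, init_rng = jax_random.split(rng)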
Code example #5
File: rnn_test.py  Project: yawenz/tensor2tensor
    def _test_cell_runs(self, model, input_shape, output_shape):
        source = np.ones(input_shape, dtype=np.float32)

        # Build params
        rng = jax_random.get_prng(0)
        model.initialize(input_shape, rng)

        # Run network
        output = model(source)

        self.assertEqual(output_shape, output.shape)
Code example #6
 def test_div(self):
     init_fun, apply_fun = stax.Div(divisor=2.0)
     input_np = onp.array([[1, 2, 3], [4, 5, 6]], dtype=onp.float32)
     input_shape = input_np.shape
     rng = random.get_prng(0)
     _, _ = init_fun(rng, input_shape)
     output_np = apply_fun(None, input_np)
     # absltest doesn't have ndarray equalities.
     expected_output_np = input_np / 2.0
     self.assertAlmostEqual(0.0,
                            onp.sum((output_np - expected_output_np)**2),
                            delta=1e-6)
Code example #7
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.FrameStackMLP(hidden_sizes=hidden_size,
                                     output_size=output_size)
     B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     _, _ = model.initialize_once((1, 1, OBS), onp.float32, key)
     x = onp.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)
Code example #8
File: base_test.py  Project: yawenz/tensor2tensor
 def test_dense_param_sharing(self):
     model1 = layers.Serial(layers.Dense(32), layers.Dense(32))
     layer = layers.Dense(32)
     model2 = layers.Serial(layer, layer)
     rng = random.get_prng(0)
     params1 = model1.initialize((-1, 32), rng)
     params2 = model2.initialize((-1, 32), rng)
     # The first parameters have 2 kernels of size (32, 32).
     self.assertEqual((32, 32), params1[0][0].shape)
     self.assertEqual((32, 32), params1[1][0].shape)
     # The second parameters have 1 kernel of size (32, 32) and an empty dict.
     self.assertEqual((32, 32), params2[0][0].shape)
     self.assertEqual((), params2[1])
Code example #9
 def test_computes(self):
     rng_key = jax_random.get_prng(0)
     hidden_size = (4, 4)
     output_size = 6
     model = atari_cnn.AtariCnn(hidden_sizes=hidden_size,
                                output_size=output_size)
     B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
     rng_key, key = jax_random.split(rng_key)
     _, _ = model.initialize_once((1, 1) + OBS, onp.float32, key)
     x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
         B, T + 1, *OBS)
     y = model(x)
     self.assertEqual((B, T + 1, output_size), y.shape)
Code example #10
def check_shape_agreement(layer_instance, input_shape, integer_inputs=False):
    """Check if layer.output_shape agrees with the actual output shape."""
    rng1, rng2, rng3 = random.split(random.get_prng(0), 3)
    output_shape = layer_instance.output_shape(input_shape)
    output_shape = nested_map(output_shape, int)  # Make non-numpy.
    params = layer_instance.initialize(input_shape, rng1)
    inputs = _random_inputs(input_shape, rng2, integer_inputs=integer_inputs)
    result = layer_instance(inputs, params, rng=rng3)
    result_shape = shapes(result)
    msg = 'output shape %s != real result shape %s' % (output_shape,
                                                       result_shape)
    assert output_shape == result_shape, msg
    return output_shape
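Call sites pass a constructed layer object plus an input shape; a hedged sketch reusing the Dense layer seen in the parameter-sharing examples, with illustrative shapes:

    # Illustrative call: for Dense(16) on inputs of shape (4, 8) the helper
    # checks and returns the agreed output shape, (4, 16).
    output_shape = check_shape_agreement(layers.Dense(16), (4, 8))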
Code example #11
 def test_computes(self):
   rng_key = jax_random.get_prng(0)
   hidden_size = (4, 4)
   output_size = 6
   policy = atari_cnn.AtariCnn(
       hidden_sizes=hidden_size, output_size=output_size)
   B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
   rng_key, key = jax_random.split(rng_key)
   params = policy.initialize((-1, -1) + OBS, key)
   x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
       B, T + 1, *OBS)
   rng_key, key = jax_random.split(rng_key)
   y = policy(x, params, rng=key)
   self.assertEqual((B, T + 1, output_size), y.shape)
Code example #12
 def test_dense_param_sharing(self):
     model1 = stax.Serial(stax.Dense(32), stax.Dense(32))
     layer = stax.Dense(32)
     model2 = stax.Serial(layer, layer)
     init_fun1, _ = model1
     init_fun2, _ = model2
     rng = random.get_prng(0)
     _, params1 = init_fun1(rng, [-1, 32])
     _, params2 = init_fun2(rng, [-1, 32])
     # The first parameters have 2 kernels of size (32, 32).
     self.assertEqual((32, 32), params1[0][0].shape)
     self.assertEqual((32, 32), params1[1][0].shape)
     # The second parameters have 1 kernel of size (32, 32) and an empty dict.
     self.assertEqual((32, 32), params2[0][0].shape)
     self.assertEqual((), params2[1])
Code example #13
    def test_ngpu(self):
        vocab_size = 2
        in_shape = [3, 5, 7]
        source = np.ones(in_shape, dtype=np.int32)

        model = neural_gpu.NeuralGPU(feature_depth=30,
                                     steps=4,
                                     vocab_size=vocab_size)
        # Build params
        rng = jax_random.get_prng(0)
        model.initialize(in_shape, rng)

        # Run network
        output = model(source)

        self.assertEqual(tuple(in_shape + [vocab_size]), output.shape)
Code example #14
 def seed(self, seed=None):
     if seed is None:
         seed = random.randint(0, 2**31 - 1)
     self._rng = jax_random.get_prng(seed)
     return super(SimulatedEnvProblem, self).seed(seed=seed)
Code example #15
 def test_random_normal(self):
     initializer = initializers.RandomNormalInitializer()
     input_shape = (29, 5, 7, 20)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual(tuple(init_value.shape), input_shape)
Code example #16
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_lr_in_observation: bool, whether to include the learning rate in
      observations.
    observation_range: tuple (low, high), range to clip the observation to.
    start_lr: starting learning rate.
    max_lr: maximum value to clip the learning rate to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_dir: directory with the policy checkpoint.

  Returns:
    a function learning_rate(step): float -> float, the step-dependent lr.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        include_lr_in_observation)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        return lambda _: start_lr

    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_actions=len(action_multipliers),
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    action = utils.gumbel_sample(log_probs[0, -1, :])

    # Get a new learning rate.
    new_lr = online_tune.new_learning_rate(action, history, action_multipliers,
                                           max_lr)
    return lambda _: new_lr
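The returned callable is a constant learning-rate schedule; the trainer queries it by step and re-invokes PolicySchedule as the history grows. A hedged usage sketch with a placeholder checkpoint directory:

    # Sketch only; the policy_dir path is a placeholder.
    lr_fn = PolicySchedule(history, policy_dir="/path/to/policy_checkpoints")
    current_lr = lr_fn(step)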
Code example #17
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
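This newer variant returns a schedule over a dict of nontrainable parameters rather than a single float. A hedged sketch, again with a placeholder checkpoint directory:

    # Sketch only; each call returns the current control values keyed by name.
    schedule = PolicySchedule(history, policy_dir="/path/to/policy_checkpoints")
    controls = schedule(step)             # e.g. {"learning_rate": 6.6e-4}
    learning_rate = controls["learning_rate"]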
Code example #18
 def test_kaiming_uniform(self):
     initializer = initializers.KaimingUniformInitializer()
     input_shape = (29, 5, 7, 20)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual(tuple(init_value.shape), input_shape)