def _make_schedule(
    self,
    history,
    start_lr=1e-3,
    observation_metrics=(("eval", "metrics/accuracy"),),
    action_multipliers=(1.0,),
):
  policy_and_value_model = atari_cnn.FrameStackMLP
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      n_controls=1,
      bottom_layers_fn=policy_and_value_model,
      two_towers=False,
  )
  rng = jax_random.get_prng(seed=0)
  obs_dim = len(observation_metrics)
  (params, state) = net.initialize((1, 1, obs_dim), np.float32, rng)
  policy_dir = self.get_temp_dir()
  # Optimizer slots should not be used for anything.
  slots = None
  opt_state = (params, slots)
  ppo.save_opt_state(policy_dir, opt_state, state, epoch=0, total_opt_step=0)
  return learning_rate.PolicySchedule(
      history,
      observation_metrics=observation_metrics,
      include_lr_in_observation=False,
      action_multipliers=action_multipliers,
      start_lr=start_lr,
      policy_and_value_model=policy_and_value_model,
      policy_and_value_two_towers=False,
      policy_dir=policy_dir,
  )
def check_shape_agreement(test_case, init_fun, apply_fun, input_shape):
  rng_key1, rng_key2 = random.split(random.get_prng(0))
  result_shape, params = init_fun(rng_key1, input_shape)
  inputs = random_inputs(onp.random.RandomState(0), input_shape)
  result = apply_fun(params, inputs, rng=rng_key2)
  test_case.assertEqual(result.shape, result_shape)
  return result_shape
def check_shape_agreement(test_case, layer, input_shape):
  rng_key1, rng_key2 = random.split(random.get_prng(0))
  result_shape = layer.output_shape(input_shape)
  params = layer.initialize(input_shape, rng_key1)
  inputs = random_inputs(onp.random.RandomState(0), input_shape)
  result = layer(inputs, params, rng=rng_key2)
  test_case.assertEqual(result.shape, result_shape)
  return result_shape
def get_random_number_generator_and_set_seed(seed=None):
  """Get a JAX random number generator and set the random seed everywhere."""
  random.seed(seed)
  # While Python's random accepts None as a seed (it then seeds from time/os),
  # some other functions expect integers, so we create one here.
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  tf.set_random_seed(seed)
  numpy.random.seed(seed)
  return jax_random.get_prng(seed)
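# A minimal usage sketch for the seeding helper above, assuming the same
# module-level imports (`random`, `numpy`, `tf`, `jax_random`). The function
# name `example_seeded_split` is illustrative, not part of the library.
def example_seeded_split():
  # One call seeds Python, NumPy and TensorFlow, and returns a JAX PRNG key.
  rng = get_random_number_generator_and_set_seed(seed=17)
  # Split before each use so no key is ever reused.
  rng, subkey = jax_random.split(rng)
  return subkey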
def _test_cell_runs(self, model, input_shape, output_shape):
  source = np.ones(input_shape, dtype=np.float32)
  # Build params.
  rng = jax_random.get_prng(0)
  model.initialize(input_shape, rng)
  # Run the network.
  output = model(source)
  self.assertEqual(output_shape, output.shape)
def test_div(self):
  init_fun, apply_fun = stax.Div(divisor=2.0)
  input_np = onp.array([[1, 2, 3], [4, 5, 6]], dtype=onp.float32)
  input_shape = input_np.shape
  rng = random.get_prng(0)
  _, _ = init_fun(rng, input_shape)
  output_np = apply_fun(None, input_np)
  # absltest doesn't have ndarray equalities.
  expected_output_np = input_np / 2.0
  self.assertAlmostEqual(
      0.0, onp.sum((output_np - expected_output_np)**2), delta=1e-6)
def test_computes(self):
  rng_key = jax_random.get_prng(0)
  hidden_size = (4, 4)
  output_size = 6
  model = atari_cnn.FrameStackMLP(
      hidden_sizes=hidden_size, output_size=output_size)
  B, T, OBS = 2, 2, 3  # pylint: disable=invalid-name
  rng_key, key = jax_random.split(rng_key)
  _, _ = model.initialize_once((1, 1, OBS), onp.float32, key)
  x = onp.arange(B * (T + 1) * OBS).reshape(B, T + 1, OBS)
  y = model(x)
  self.assertEqual((B, T + 1, output_size), y.shape)
def test_dense_param_sharing(self):
  model1 = layers.Serial(layers.Dense(32), layers.Dense(32))
  layer = layers.Dense(32)
  model2 = layers.Serial(layer, layer)
  rng = random.get_prng(0)
  params1 = model1.initialize((-1, 32), rng)
  params2 = model2.initialize((-1, 32), rng)
  # The first model's parameters have 2 kernels of shape (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second model's parameters have 1 kernel of shape (32, 32) and an
  # empty tuple, since the layer is shared.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def test_computes(self):
  rng_key = jax_random.get_prng(0)
  hidden_size = (4, 4)
  output_size = 6
  model = atari_cnn.AtariCnn(
      hidden_sizes=hidden_size, output_size=output_size)
  B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
  rng_key, key = jax_random.split(rng_key)
  _, _ = model.initialize_once((1, 1) + OBS, onp.float32, key)
  x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
      B, T + 1, *OBS)
  y = model(x)
  self.assertEqual((B, T + 1, output_size), y.shape)
def check_shape_agreement(layer_instance, input_shape, integer_inputs=False):
  """Check if layer.output_shape agrees with the actual output shape."""
  rng1, rng2, rng3 = random.split(random.get_prng(0), 3)
  output_shape = layer_instance.output_shape(input_shape)
  output_shape = nested_map(output_shape, int)  # Make non-numpy.
  params = layer_instance.initialize(input_shape, rng1)
  inputs = _random_inputs(input_shape, rng2, integer_inputs=integer_inputs)
  result = layer_instance(inputs, params, rng=rng3)
  result_shape = shapes(result)
  msg = 'output shape %s != real result shape %s' % (
      output_shape, result_shape)
  assert output_shape == result_shape, msg
  return output_shape
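# Hypothetical example of calling the shape-agreement helper above, assuming
# the `layers` module used elsewhere in this file is in scope. A Dense layer
# mapping 32 features to 16 should both report and produce shape (7, 16).
def example_check_dense_shape():
  layer_instance = layers.Dense(16)
  return check_shape_agreement(layer_instance, (7, 32))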
def test_computes(self):
  rng_key = jax_random.get_prng(0)
  hidden_size = (4, 4)
  output_size = 6
  policy = atari_cnn.AtariCnn(
      hidden_sizes=hidden_size, output_size=output_size)
  B, T, OBS = 2, 2, (28, 28, 3)  # pylint: disable=invalid-name
  rng_key, key = jax_random.split(rng_key)
  params = policy.initialize((-1, -1) + OBS, key)
  x = onp.arange(B * (T + 1) * functools.reduce(op.mul, OBS)).reshape(
      B, T + 1, *OBS)
  rng_key, key = jax_random.split(rng_key)
  y = policy(x, params, rng=key)
  self.assertEqual((B, T + 1, output_size), y.shape)
def test_dense_param_sharing(self):
  model1 = stax.Serial(stax.Dense(32), stax.Dense(32))
  layer = stax.Dense(32)
  model2 = stax.Serial(layer, layer)
  init_fun1, _ = model1
  init_fun2, _ = model2
  rng = random.get_prng(0)
  _, params1 = init_fun1(rng, [-1, 32])
  _, params2 = init_fun2(rng, [-1, 32])
  # The first model's parameters have 2 kernels of shape (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second model's parameters have 1 kernel of shape (32, 32) and an
  # empty tuple, since the layer is shared.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
def test_ngpu(self):
  vocab_size = 2
  in_shape = [3, 5, 7]
  source = np.ones(in_shape, dtype=np.int32)
  model = neural_gpu.NeuralGPU(
      feature_depth=30, steps=4, vocab_size=vocab_size)
  # Build params.
  rng = jax_random.get_prng(0)
  model.initialize(in_shape, rng)
  # Run the network.
  output = model(source)
  self.assertEqual(tuple(in_shape + [vocab_size]), output.shape)
def seed(self, seed=None):
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  self._rng = jax_random.get_prng(seed)
  return super(SimulatedEnvProblem, self).seed(seed=seed)
def test_random_normal(self):
  initializer = initializers.RandomNormalInitializer()
  input_shape = (29, 5, 7, 20)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual(tuple(init_value.shape), input_shape)
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_lr_in_observation: bool, whether to include the learning rate in
      observations.
    observation_range: tuple (low, high), range to clip the observation to.
    start_lr: starting learning rate.
    max_lr: maximum value to clip the learning rate to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction are computed by separate model towers.
    policy_dir: directory with the policy checkpoint.

  Returns:
    a function learning_rate(step): float -> float, the step-dependent lr.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      include_lr_in_observation)
  logging.vlog(1, "Building observations took %0.2f sec.",
               time.time() - start_time)
  if observations.shape[0] == 0:
    return lambda _: start_lr

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_actions=len(action_multipliers),
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, "Building the policy network took %0.2f sec.",
               time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
  logging.vlog(1, "Running the policy took %0.2f sec.",
               time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  action = utils.gumbel_sample(log_probs[0, -1, :])

  # Get a new learning rate.
  new_lr = online_tune.new_learning_rate(
      action, history, action_multipliers, max_lr)
  return lambda _: new_lr
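# Sketch of how the schedule above is consumed; `make_lr_schedule` and the
# checkpoint path are illustrative, not part of the library. With an empty
# history, PolicySchedule returns the constant `start_lr` without reading the
# checkpoint; otherwise a saved policy must exist under `policy_dir`.
def make_lr_schedule(history):
  lr_fn = PolicySchedule(
      history,
      start_lr=1e-3,
      policy_dir="/tmp/policy_ckpt",  # Placeholder checkpoint directory.
  )
  # The returned callable ignores its step argument; the lr changes only
  # when PolicySchedule itself is called again with fresh history.
  return lr_fn(0)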
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction are computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial controls.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(1, "Building observations took %0.2f sec.",
               time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  assert policy_and_value_vocab_size is None, (
      "Serialized policies are not supported yet.")
  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, "Building the policy network took %0.2f sec.",
               time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  (log_probs, _) = (
      net(np.array([observations]), params=params, state=state, rng=rng))
  logging.vlog(1, "Running the policy took %0.2f sec.",
               time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs) * observations.shape[0], len(action_multipliers))
  action = utils.gumbel_sample(
      log_probs[0, -len(control_configs):, :] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
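# Analogous sketch for the multi-control variant above; `make_controls` and
# the checkpoint path are illustrative. This schedule yields a dict of
# nontrainable parameters keyed by control name instead of a bare float.
def make_controls(history):
  controls_fn = PolicySchedule(
      history,
      control_configs=(("learning_rate", 1e-3, (1e-9, 10.0), False),),
      policy_dir="/tmp/policy_ckpt",  # Placeholder checkpoint directory.
  )
  return controls_fn(0)["learning_rate"]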
def test_kaiming_uniform(self):
  initializer = initializers.KaimingUniformInitializer()
  input_shape = (29, 5, 7, 20)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual(tuple(init_value.shape), input_shape)