def init_random_number_generators(seed=None):
  """Initializes random generators for Python, NumPy, TensorFlow, and JAX."""
  # Seed Python random (None as seed is okay), then use it to seed the others.
  random.seed(seed)
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  numpy.random.seed(seed)
  tf.random.set_seed(seed)
  return jax_random.get_prng(seed)
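# Usage sketch (illustrative, not part of the original module): seed every
# backend once at program start, then thread the returned JAX PRNG key through
# the code that needs randomness. `example_setup_rngs` is a hypothetical
# helper; it assumes `jax_random.split` behaves like `jax.random.split`.
def example_setup_rngs(seed=17):
  rng = init_random_number_generators(seed)
  # Derive independent subkeys rather than reusing `rng` directly.
  (init_rng, train_rng) = jax_random.split(rng, 2)
  return init_rng, train_rng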
def test_from_file(self):
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  filename = self.create_tempfile('params.npy').full_path
  with open(filename, 'wb') as f:
    np.save(f, params)
  initializer = initializers.InitializerFromFile(filename)
  input_shape = (3, 2)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual('%s' % init_value, '%s' % params)
def test_from_file(self):
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  # `create_tempfile` needs access to --test_tmpdir, but in the OSS world
  # pytest doesn't run `absltest.main`, so we have to parse the flag manually.
  test_utils.ensure_flag('test_tmpdir')
  filename = self.create_tempfile('params.npy').full_path
  with open(filename, 'wb') as f:
    np.save(f, params)
  initializer = initializers.InitializerFromFile(filename)
  input_shape = (3, 2)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual('%s' % init_value, '%s' % params)
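# Illustrative sketch (not the library implementation): the behavior the test
# above exercises could look roughly like this hypothetical
# `initializer_from_file` helper, which loads the saved array and ignores the
# PRNG key because the values are fixed by the file.
def initializer_from_file(filename):
  def initializer(shape, rng):
    del rng  # Unused: values come from the file, not from randomness.
    with open(filename, 'rb') as f:
      values = np.load(f)
    assert tuple(values.shape) == tuple(shape), (
        'Saved parameters have shape %s, expected %s' % (values.shape, shape))
    return values
  return initializer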
def seed(self, seed=None):
  if seed is None:
    seed = random.randint(0, 2**31 - 1)
  self._rng = jax_random.get_prng(seed)
  return super(SimulatedEnvProblem, self).seed(seed=seed)
def test_orthogonal(self):
  initializer = initializers.OrthogonalInitializer()
  input_shape = (29, 5, 7, 20)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual(tuple(init_value.shape), input_shape)
def test_kaiming_uniform(self):
  initializer = initializers.KaimingUniformInitializer()
  input_shape = (29, 5, 7, 20)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual(tuple(init_value.shape), input_shape)
def test_lecun_normal(self):
  initializer = initializers.LeCunNormalInitializer()
  input_shape = (29, 5, 7, 20)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual(tuple(init_value.shape), input_shape)
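# Usage sketch (illustrative): the initializers tested above all share the
# same call signature -- they take an output shape and a PRNG key and return
# an array of that shape. A hypothetical direct use, with arbitrary example
# dimensions:
def example_init_dense_kernel(n_in=7, n_out=20):
  initializer = initializers.KaimingUniformInitializer()
  kernel = initializer((n_in, n_out), random.get_prng(0))
  return kernel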
def PolicySchedule(
    history,
    observation_metrics=(
        ('train', 'metrics/accuracy'),
        ('train', 'metrics/loss'),
        ('eval', 'metrics/accuracy'),
        ('eval', 'metrics/loss'),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ('learning_rate', 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {'name': float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(1, 'Building observations took %0.2f sec.',
               time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  # Build the policy network and load its parameters.
  start_time = time.time()
  (low, high) = observation_range
  observation_space = gym.spaces.Box(
      shape=observations.shape[1:], low=low, high=high)
  action_space = gym.spaces.MultiDiscrete(
      nvec=(len(action_multipliers),) * len(control_configs))
  (net, _) = policy_based_utils.policy_and_value_net(
      bottom_layers_fn=policy_and_value_model,
      observation_space=observation_space,
      action_space=action_space,
      vocab_size=policy_and_value_vocab_size,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, 'Building the policy network took %0.2f sec.',
               time.time() - start_time)

  start_time = time.time()
  # (opt_state, state, epoch, opt_step, history)
  (opt_state, state, _, _, _) = policy_based_utils.maybe_restore_opt_state(
      policy_dir)
  assert opt_state is not None, 'Policy checkpoint not found.'
  (params, _, _) = opt_state
  logging.vlog(1, 'Restoring the policy parameters took %0.2f sec.',
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  n_timesteps = observations.shape[0]
  # (log_probs, value_preds, state, rng)
  (log_probs, _, _, _) = policy_based_utils.run_policy(
      policy_and_value_net_apply=net,
      observations=np.array([observations]),
      lengths=np.array([n_timesteps]),
      weights=params,
      state=state,
      rng=rng,
      action_space=action_space,
  )
  logging.vlog(1, 'Running the policy took %0.2f sec.',
               time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (1, len(control_configs), len(action_multipliers))
  action = tl.gumbel_sample(log_probs[0], temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
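# Usage sketch (illustrative, with hypothetical values): `PolicySchedule`
# returns a function from training step to a dict of nontrainable parameters,
# so a trainer can query the current learning rate like this. The `history`
# object and the checkpoint under `policy_dir` are assumed to exist;
# `example_learning_rate_at_step` is not part of the original module.
def example_learning_rate_at_step(history, policy_dir, step):
  schedule = PolicySchedule(history, policy_dir=policy_dir)
  nontrainable_params = schedule(step)  # e.g. {'learning_rate': 1e-3}
  return nontrainable_params['learning_rate']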