def test_default_preprocessor(self):
    rngs = PRNGSequence(13)

    box = gym.spaces.Box(low=0, high=1, shape=(2, 3))
    dsc = gym.spaces.Discrete(7)
    mbn = gym.spaces.MultiBinary(11)
    mds = gym.spaces.MultiDiscrete(nvec=[3, 5])
    tup = gym.spaces.Tuple((box, dsc, mbn, mds))
    dct = gym.spaces.Dict({'box': box, 'dsc': dsc, 'mbn': mbn, 'mds': mds})

    self.assertArrayShape(default_preprocessor(box)(next(rngs), box.sample()), (1, 2, 3))
    self.assertArrayShape(default_preprocessor(dsc)(next(rngs), dsc.sample()), (1, 7))
    self.assertArrayShape(default_preprocessor(mbn)(next(rngs), mbn.sample()), (1, 11))
    self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds.sample())[0], (1, 3))
    self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds.sample())[1], (1, 5))
    self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[0], (1, 2, 3))
    self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[1], (1, 7))
    self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[2], (1, 11))
    self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[3][0], (1, 3))
    self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[3][1], (1, 5))
    self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['box'], (1, 2, 3))
    self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['dsc'], (1, 7))
    self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mbn'], (1, 11))
    self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][0], (1, 3))
    self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][1], (1, 5))
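# Every snippet in this section draws fresh PRNG keys with next(...) on a PRNGSequence.
# A minimal sketch of that pattern built directly on jax.random (the generator name
# `key_sequence` is hypothetical; the real PRNGSequence may differ in details):
import jax


def key_sequence(seed):
    # infinite stream of independent PRNG keys split off a single root key
    key = jax.random.PRNGKey(seed)
    while True:
        key, subkey = jax.random.split(key)
        yield subkey


# usage: rngs = key_sequence(13); fresh_key = next(rngs)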
def test_add_noise():
    rng = PRNGSequence(0)
    assert np.isclose(add_noise(0.0, next(rng), 0.0), 0.0)
    assert np.isclose(add_noise(0.0, next(rng), 0.0, -1.0, 1.0), 0.0)
    assert np.isclose(add_noise(0.0, next(rng), 100.0, -1.0, 1.0, 0.0, 0.0), 0.0)
    assert -1.0 <= add_noise(0.0, next(rng), 100.0, -1.0, 1.0) <= 1.0
    assert -20.0 <= add_noise(0.0, next(rng), 100.0, -20.0, 20.0) <= 20.0
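# The exact signature of add_noise is not shown here; the test above only pins down that a
# zero standard deviation (or a zero-width noise clip) leaves the value unchanged and that
# the result respects the given bounds. A rough sketch of a helper with that behaviour,
# assuming clipped Gaussian exploration noise (the name `clipped_gaussian_noise` is hypothetical):
import jax
import jax.numpy as jnp


def clipped_gaussian_noise(value, key, std, low=-jnp.inf, high=jnp.inf,
                           noise_low=-jnp.inf, noise_high=jnp.inf):
    # sample Gaussian noise, clip the noise itself, then clip the perturbed value
    noise = std * jax.random.normal(key, shape=jnp.shape(value))
    noise = jnp.clip(noise, noise_low, noise_high)
    return jnp.clip(value + noise, low, high)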
def test_argmax_random_tiebreaking(self):
    rngs = PRNGSequence(13)
    vec = jnp.ones(shape=(5,))
    mat = jnp.ones(shape=(3, 5))
    self.assertEqual(argmax(next(rngs), vec), 2)  # not zero
    self.assertArrayAlmostEqual(argmax(next(rngs), mat), [1, 1, 3])
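# argmax here breaks ties at random: on an all-ones vector it need not return index 0.
# A minimal sketch of one way to implement that, assuming plain jax.random (the helper
# name `argmax_random_tiebreak` is hypothetical, not necessarily the tested function):
import jax
import jax.numpy as jnp


def argmax_random_tiebreak(key, x, axis=-1):
    # mark the entries that attain the maximum along `axis`
    is_max = x == jnp.max(x, axis=axis, keepdims=True)
    # give only those entries a random positive score, then take an ordinary argmax,
    # so ties are broken uniformly at random
    scores = jnp.where(is_max, jax.random.uniform(key, x.shape), -1.0)
    return jnp.argmax(scores, axis=axis)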
def test_tree_sample(self):
    rngs = PRNGSequence(42)
    tn = get_transition_batch(self.env_discrete, batch_size=5)

    tn_sample = tree_sample(tn, next(rngs), n=3)
    assert tn_sample.batch_size == 3

    tn_sample = tree_sample(tn, next(rngs), n=7, replace=True)
    assert tn_sample.batch_size == 7

    msg = r"Cannot take a larger sample than population when 'replace=False'"
    with self.assertRaisesRegex(ValueError, msg):
        tree_sample(tn, next(rngs), n=7, replace=False)
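# tree_sample draws n transitions from a batched pytree of transitions. A rough sketch of
# the underlying idea, assuming jax.random.choice over the leading batch axis (the helper
# name `sample_pytree_batch` is hypothetical and omits the TransitionBatch wrapper):
import jax


def sample_pytree_batch(tree, key, n, replace=False):
    batch_size = jax.tree_util.tree_leaves(tree)[0].shape[0]
    # like numpy, choice refuses to draw more unique samples than the population holds
    idx = jax.random.choice(key, batch_size, shape=(n,), replace=replace)
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree)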
def __init__(
    self,
    state_dim,
    action_dim,
    max_action,
    lr=3e-4,
    discount=0.99,
    tau=0.005,
    policy_noise=0.2,
    expl_noise=0.1,
    noise_clip=0.5,
    policy_freq=2,
    seed=0,
):
    self.rng = PRNGSequence(seed)

    # actor and its target share an init key so they start with identical parameters
    actor_input_dim = [((1, state_dim), jnp.float32)]
    init_rng = next(self.rng)
    actor = build_td3_actor_model(actor_input_dim, action_dim, max_action, init_rng)
    self.actor_target = build_td3_actor_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
    self.actor_optimizer = jax.device_put(actor_optimizer)

    # critic and its target likewise share an init key
    init_rng = next(self.rng)
    critic_input_dim = [
        ((1, state_dim), jnp.float32),
        ((1, action_dim), jnp.float32),
    ]
    critic = build_td3_critic_model(critic_input_dim, init_rng)
    self.critic_target = build_td3_critic_model(critic_input_dim, init_rng)
    critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
    self.critic_optimizer = jax.device_put(critic_optimizer)

    self.max_action = max_action
    self.discount = discount
    self.tau = tau
    self.policy_noise = policy_noise
    self.expl_noise = expl_noise
    self.noise_clip = noise_clip
    self.policy_freq = policy_freq
    self.total_it = 0
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    lr: float = 3e-4,
    discount: float = 0.99,
    tau: float = 0.005,
    policy_noise: float = 0.2,
    expl_noise: float = 0.1,
    noise_clip: float = 0.5,
    policy_freq: int = 2,
    seed: int = 0,
):
    self.rng = PRNGSequence(seed)

    # actor and its target share an init key so they start with identical parameters
    actor_input_dim = (1, state_dim)
    init_rng = next(self.rng)
    actor_params = build_td3_actor_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    self.actor_target_params = build_td3_actor_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
    self.actor_optimizer = jax.device_put(actor_optimizer)

    # critic and its target likewise share an init key
    init_rng = next(self.rng)
    critic_input_dim = [(1, state_dim), (1, action_dim)]
    critic_params = build_td3_critic_model(critic_input_dim, init_rng)
    self.critic_target_params = build_td3_critic_model(critic_input_dim, init_rng)
    critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
    self.critic_optimizer = jax.device_put(critic_optimizer)

    self.max_action = max_action
    self.discount = discount
    self.tau = tau
    self.policy_noise = policy_noise
    self.expl_noise = expl_noise
    self.noise_clip = noise_clip
    self.policy_freq = policy_freq
    self.action_dim = action_dim
    self.total_it = 0
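# Both TD3 constructors above keep target networks alongside a `tau` coefficient, which is
# conventionally used for Polyak-averaged target updates. A minimal sketch of that update,
# assuming the parameters are pytrees (the helper name `soft_update` is hypothetical, not
# necessarily this repo's function):
import jax


def soft_update(target_params, online_params, tau):
    # target <- (1 - tau) * target + tau * online, applied leaf by leaf
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target_params, online_params
    )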
def __init__(
    self,
    features: tp.Iterable[h5py.Group],
    labels: h5py.Group,
    batch_size: int,
    shuffle_rng=0,
):
    def get_features(group):
        # HDF5 groups expose their members by key, not as Python attributes
        if "train" in group:
            return group["train"]
        return group

    super().__init__((get_features(g) for g in features), labels, batch_size)
    self._rng = PRNGSequence(shuffle_rng)
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    discount: float = 0.99,
    tau: float = 0.005,
    policy_freq: int = 2,
    lr: float = 3e-4,
    entropy_tune: bool = True,
    seed: int = 0,
):
    self.rng = PRNGSequence(seed)

    # Gaussian policy (actor)
    actor_input_dim = (1, state_dim)
    actor_params = build_gaussian_policy_model(
        actor_input_dim, action_dim, max_action, next(self.rng)
    )
    actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
    self.actor_optimizer = jax.device_put(actor_optimizer)

    # double critic and its target, built from the same init key
    init_rng = next(self.rng)
    critic_input_dim = [(1, state_dim), (1, action_dim)]
    critic_params = build_double_critic_model(critic_input_dim, init_rng)
    self.critic_target_params = build_double_critic_model(critic_input_dim, init_rng)
    critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
    self.critic_optimizer = jax.device_put(critic_optimizer)

    # temperature parameter for automatic entropy tuning
    self.entropy_tune = entropy_tune
    log_alpha_params = build_constant_model(-3.5, next(self.rng))
    log_alpha_optimizer = optim.Adam(learning_rate=lr).create(log_alpha_params)
    self.log_alpha_optimizer = jax.device_put(log_alpha_optimizer)
    self.target_entropy = -action_dim

    self.max_action = max_action
    self.discount = discount
    self.tau = tau
    self.policy_freq = policy_freq
    self.action_dim = action_dim
    self.total_it = 0
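# The SAC constructor above tracks a log_alpha constant model and target_entropy = -action_dim,
# the usual automatic entropy tuning setup. A minimal sketch of the standard temperature loss
# under that convention (the name `temperature_loss` is illustrative, not this repo's function):
import jax.numpy as jnp


def temperature_loss(log_alpha, log_pi, target_entropy):
    # pushes alpha up when the policy's entropy falls below target_entropy and down otherwise;
    # log_pi holds log-probabilities of sampled actions and is treated as a constant here
    return -jnp.mean(log_alpha * (log_pi + target_entropy))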
def __init__(
    self,
    language_dim: int,
    vision_dim: Tuple[int, int],
    num_embeddings: int,
    embedding_dim: int,
    memory_hidden_dim: int,
    tokenizer: Tokenizer,
    discount: float = 0.9,
    hidden_size: int = 512,
    lr: float = 1e-5,
    policy_eps: float = 0.1,
    baseline_eps: float = 0.5,
    entropy_eps: float = 1e-5,
    reconstruction_eps: float = 1.0,
    seed: int = 0,
):
    self.rng = PRNGSequence(seed)

    # the agent takes a language observation and a vision observation
    input_dims = [
        (1, language_dim),
        (1, *vision_dim),
    ]
    agent_params = build_fast_slow_agent_model(
        input_dims,
        memory_hidden_dim,
        embedding_dim,
        num_embeddings,
        embedding_dim,
        next(self.rng),
    )
    agent_optimizer = optim.Adam(learning_rate=lr).create(agent_params)
    self.agent_optimizer = jax.device_put(agent_optimizer)

    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.hidden_size = hidden_size
    self.discount = discount
    self.policy_eps = policy_eps
    self.baseline_eps = baseline_eps
    self.entropy_eps = entropy_eps
    self.reconstruction_eps = reconstruction_eps
    self.tokenizer = tokenizer
    self.total_it = 0
def test_argmax_consistent(self):
    rngs = PRNGSequence(13)
    vec = jax.random.normal(next(rngs), shape=(5,))
    mat = jax.random.normal(next(rngs), shape=(3, 5))
    ten = jax.random.normal(next(rngs), shape=(3, 5, 7))

    self.assertEqual(argmax(next(rngs), vec), jnp.argmax(vec, axis=-1))
    self.assertArrayAlmostEqual(argmax(next(rngs), mat), jnp.argmax(mat, axis=-1))
    self.assertArrayAlmostEqual(argmax(next(rngs), mat, axis=0), jnp.argmax(mat, axis=0))
    self.assertArrayAlmostEqual(argmax(next(rngs), ten), jnp.argmax(ten, axis=-1))
    self.assertArrayAlmostEqual(argmax(next(rngs), ten, axis=0), jnp.argmax(ten, axis=0))
    self.assertArrayAlmostEqual(argmax(next(rngs), ten, axis=1), jnp.argmax(ten, axis=1))
def main():
    loss_obj = hk.transform(loss_fn, apply_rng=True)

    # Initial parameter values are typically random. In JAX you need a key in order
    # to generate random numbers and so Haiku requires you to pass one in.
    rng = PRNGSequence(42)

    # `init` runs your function, as such we need an example input. Typically you can
    # pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
    # is not usually data dependent.
    shape = [([1000], float)]

    adam = optim.Adam(learning_rate=0.1)
    partial = Net.partial()

    _, params = partial.init_by_shape(next(rng), shape)
    net = nn.Model(partial, params)
    optimizer = jax.device_put(adam.create(net))

    _, params = partial.init_by_shape(next(rng), shape)  # HERE
    net = net.replace(params=params)
    optimizer = jax.device_put(adam.create(net))  # HERE
def __init__(
    self,
    num_agent_steps,
    state_space,
    action_space,
    seed,
    max_grad_norm,
    gamma,
):
    np.random.seed(seed)
    self.rng = PRNGSequence(seed)

    self.agent_step = 0
    self.episode_step = 0
    self.learning_step = 0
    self.num_agent_steps = num_agent_steps
    self.state_space = state_space
    self.action_space = action_space
    self.gamma = gamma
    self.max_grad_norm = max_grad_norm
    # Box spaces are continuous; anything else is treated as discrete
    self.discrete_action = not isinstance(action_space, Box)
def __init__(
    self,
    batch_size,
    buffer_size,
    discount,
    env_id,
    eval_freq,
    eval_episodes,
    learning_rate,
    load_path,
    max_time_steps,
    actor_freq,
    save_freq,
    save_model,
    seed,
    start_time_steps,
    tau,
    train_steps,
    render,
    use_tune,
    prefix=None,
    env=None,
):
    self.prefix = prefix
    seed = int(seed)
    policy = "SAC"
    if env_id == "levels":
        env_id = None
    self.use_tune = use_tune
    self.max_time_steps = int(max_time_steps) if max_time_steps else None
    self.start_time_steps = int(start_time_steps)
    self.train_steps = int(train_steps)
    self.batch_size = int(batch_size)
    self.buffer_size = int(buffer_size)
    self.eval_freq = eval_freq

    def make_env():
        if env is None:
            return Environment.wrap(gym.make(env_id) if env_id else Env(100))
        return env

    def eval_policy():
        # `policy` is rebound to the SAC agent below; this closure picks up that rebinding
        eval_env = make_env()
        eval_env.seed(seed)
        avg_reward = 0.0
        it = (
            itertools.count() if render else tqdm(range(eval_episodes), desc="eval")
        )
        for _ in it:
            eval_time_step = eval_env.reset()
            while not eval_time_step.last():
                if render:
                    eval_env.render()
                action = policy.select_action(eval_time_step.observation)
                eval_time_step = eval_env.step(action)
                avg_reward += eval_time_step.reward
        avg_reward /= eval_episodes
        self.report(eval_reward=avg_reward)
        return avg_reward

    self.eval_policy = eval_policy
    self.report(policy=policy)
    self.report(env=env_id)
    self.report(seed=seed)

    if save_model and not os.path.exists("./models"):
        os.makedirs("./models")

    self.env = env = make_env()
    assert isinstance(env, Environment)

    # Set seeds
    np.random.seed(seed)

    state_shape = env.observation_spec().shape
    action_dim = env.action_spec().shape[0]
    max_action = env.max_action()

    # Initialize policy
    self.policy = policy = SAC.SAC(  # TODO
        state_shape=state_shape,
        action_dim=action_dim,
        max_action=max_action,
        save_freq=save_freq,
        discount=discount,
        lr=learning_rate,
        actor_freq=actor_freq,
        tau=tau,
    )
    self.rng = PRNGSequence(seed)
def test_preprocess_state():
    rng = PRNGSequence(0)
    state = np.random.randint(0, 256, size=(64, 64, 3)).astype(np.uint8)
    state = preprocess_state(state, next(rng))
    assert (-0.5 <= state).all() and (state <= 0.5).all()
def __init__(
    self,
    state_dim,
    action_dim,
    max_action,
    discount=0.99,
    lr=3e-4,
    eps_q=0.1,
    eps_mu=0.1,
    eps_sig=1e-4,
    temp_steps=10,
    target_freq=250,
    seed=0,
):
    self.rng = PRNGSequence(seed)

    # Gaussian policy and its target, built from the same init key
    init_rng = next(self.rng)
    actor_input_dim = [((1, state_dim), jnp.float32)]
    actor = build_gaussian_policy_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    self.actor_target = build_gaussian_policy_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
    self.actor_optimizer = jax.device_put(actor_optimizer)

    # double critic and its target
    init_rng = next(self.rng)
    critic_input_dim = [
        ((1, state_dim), jnp.float32),
        ((1, action_dim), jnp.float32),
    ]
    critic = build_double_critic_model(critic_input_dim, init_rng)
    self.critic_target = build_double_critic_model(critic_input_dim, init_rng)
    critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
    self.critic_optimizer = jax.device_put(critic_optimizer)

    # temperature and Lagrange multipliers
    temp = build_constant_model(1.0, next(self.rng))
    temp_optimizer = optim.Adam(learning_rate=lr).create(temp)
    self.temp_optimizer = jax.device_put(temp_optimizer)

    mu_lagrange = build_constant_model(1.0, next(self.rng))
    mu_lagrange_optimizer = optim.Adam(learning_rate=lr).create(mu_lagrange)
    self.mu_lagrange_optimizer = jax.device_put(mu_lagrange_optimizer)

    sig_lagrange = build_constant_model(100.0, next(self.rng))
    sig_lagrange_optimizer = optim.Adam(learning_rate=lr).create(sig_lagrange)
    self.sig_lagrange_optimizer = jax.device_put(sig_lagrange_optimizer)

    self.eps_q = eps_q
    self.eps_mu = eps_mu
    self.eps_sig = eps_sig
    self.temp_steps = temp_steps
    self.max_action = max_action
    self.discount = discount
    self.target_freq = target_freq
    self.action_dim = action_dim
    self.total_it = 0
def __init__(
    self,
    state_dim: int,
    action_dim: int,
    max_action: float,
    discount: float = 0.99,
    lr: float = 3e-4,
    eps_eta: float = 0.1,
    eps_mu: float = 5e-4,
    eps_sig: float = 1e-5,
    target_freq: int = 250,
    seed: int = 0,
):
    self.rng = PRNGSequence(seed)

    # Gaussian policy and its target, built from the same init key
    init_rng = next(self.rng)
    actor_input_dim = (1, state_dim)
    actor_params = build_gaussian_policy_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    self.actor_target_params = build_gaussian_policy_model(
        actor_input_dim, action_dim, max_action, init_rng
    )
    actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
    self.actor_optimizer = jax.device_put(actor_optimizer)

    # double critic and its target
    init_rng = next(self.rng)
    critic_input_dim = [(1, state_dim), (1, action_dim)]
    critic_params = build_double_critic_model(critic_input_dim, init_rng)
    self.critic_target_params = build_double_critic_model(critic_input_dim, init_rng)
    critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
    self.critic_optimizer = jax.device_put(critic_optimizer)

    # Lagrange multipliers for the mean and covariance constraints
    mu_lagrange_params = build_constant_model(
        1.0, absolute=True, init_rng=next(self.rng)
    )
    mu_lagrange_optimizer = optim.Adam(learning_rate=lr).create(mu_lagrange_params)
    self.mu_lagrange_optimizer = jax.device_put(mu_lagrange_optimizer)

    sig_lagrange_params = build_constant_model(
        100.0, absolute=True, init_rng=next(self.rng)
    )
    sig_lagrange_optimizer = optim.Adam(learning_rate=lr).create(sig_lagrange_params)
    self.sig_lagrange_optimizer = jax.device_put(sig_lagrange_optimizer)

    self.temp = 1.0
    self.eps_eta = eps_eta
    self.eps_mu = eps_mu
    self.eps_sig = eps_sig
    self.max_action = max_action
    self.discount = discount
    self.target_freq = target_freq
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.total_it = 0
def seed(self, seed=None):
    self.rng = PRNGSequence(seed)
def __init__(
    self,
    state_shape,
    action_dim,
    max_action,
    save_freq,
    discount=0.99,
    tau=0.005,
    actor_freq=2,
    lr=3e-4,
    entropy_tune=False,
    seed=0,
):
    self.rng = PRNGSequence(seed)

    actor_input_dim = [((1, *state_shape), jnp.float32)]
    critic_input_dim = self.critic_input_dim = [
        ((1, *state_shape), jnp.float32),
        ((1, action_dim), jnp.float32),
    ]

    self.actor = None
    self.critic = None
    self.log_alpha = None
    self.entropy_tune = entropy_tune
    self.target_entropy = -action_dim

    self.adam = Optimizers(
        actor=optim.Adam(learning_rate=lr),
        critic=optim.Adam(learning_rate=lr),
        log_alpha=optim.Adam(learning_rate=lr),
    )
    self.module = Modules(
        actor=GaussianPolicy.partial(action_dim=action_dim, max_action=max_action),
        critic=DoubleCritic.partial(),
        alpha=Constant.partial(start_value=1),
    )
    self.optimizer = None

    self.max_action = max_action
    self.discount = discount
    self.tau = tau
    self.policy_freq = actor_freq
    self.save_freq = save_freq
    self.total_it = 0
    self.model = None

    def new_params(module: nn.Module, shape=None):
        # initialize a module either directly or from example input shapes
        _, params = (
            module.init(next(self.rng))
            if shape is None
            else module.init_by_shape(next(self.rng), shape)
        )
        return params

    def new_model(module: nn.Module, shape=None) -> nn.Model:
        return nn.Model(module, new_params(module, shape))

    def update_model(model: nn.Model, shape=None) -> nn.Model:
        return model.replace(params=new_params(model.module, shape))

    def reset_models() -> Models:
        # build fresh models on the first call, re-initialize the existing ones afterwards
        if self.model is None:
            critic = new_model(self.module.critic, critic_input_dim)
            return Models(
                actor=new_model(self.module.actor, actor_input_dim),
                critic=critic,
                target_critic=critic.replace(params=critic.params),
                alpha=new_model(self.module.alpha),
            )
        else:
            critic = update_model(self.model.critic, critic_input_dim)
            return Models(
                actor=update_model(self.model.actor, actor_input_dim),
                critic=critic,
                target_critic=critic.replace(params=critic.params),
                alpha=update_model(self.model.alpha),
            )

    self.reset_models = reset_models

    def reset_optimizer(adam: Adam, model: nn.Model) -> Optimizer:
        return jax.device_put(adam.create(model))

    def reset_optimizers() -> Optimizers:
        return Optimizers(
            actor=reset_optimizer(self.adam.actor, self.model.actor),
            critic=reset_optimizer(self.adam.critic, self.model.critic),
            log_alpha=reset_optimizer(self.adam.log_alpha, self.model.alpha),
        )

    self.reset_optimizers = reset_optimizers
    self.i = 0