Example #1
    def test_default_preprocessor(self):
        rngs = PRNGSequence(13)

        box = gym.spaces.Box(low=0, high=1, shape=(2, 3))
        dsc = gym.spaces.Discrete(7)
        mbn = gym.spaces.MultiBinary(11)
        mds = gym.spaces.MultiDiscrete(nvec=[3, 5])
        tup = gym.spaces.Tuple((box, dsc, mbn, mds))
        dct = gym.spaces.Dict({'box': box, 'dsc': dsc, 'mbn': mbn, 'mds': mds})

        self.assertArrayShape(default_preprocessor(box)(next(rngs), box.sample()), (1, 2, 3))
        self.assertArrayShape(default_preprocessor(dsc)(next(rngs), dsc.sample()), (1, 7))
        self.assertArrayShape(default_preprocessor(mbn)(next(rngs), mbn.sample()), (1, 11))
        self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds.sample())[0], (1, 3))
        self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds.sample())[1], (1, 5))

        self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[0], (1, 2, 3))
        self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[1], (1, 7))
        self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[2], (1, 11))
        self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[3][0], (1, 3))
        self.assertArrayShape(default_preprocessor(tup)(next(rngs), tup.sample())[3][1], (1, 5))

        self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['box'], (1, 2, 3))
        self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['dsc'], (1, 7))
        self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mbn'], (1, 11))
        self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][0], (1, 3))
        self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][1], (1, 5))
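
All of these snippets share the same pattern: seed a PRNGSequence once, then call next() whenever a fresh PRNG key is needed, so no key is ever reused. A minimal sketch of that pattern, assuming Haiku's hk.PRNGSequence (the PRNGSequence imported in the examples here is expected to behave the same way):

import jax
import haiku as hk

# Seed the sequence once; every next() yields a brand-new key derived from that seed.
rngs = hk.PRNGSequence(13)
x = jax.random.normal(next(rngs), shape=(3, 5))  # first key
y = jax.random.normal(next(rngs), shape=(3, 5))  # second key, independent of the first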
Example #2
def test_add_noise():
    rng = PRNGSequence(0)
    assert np.isclose(add_noise(0.0, next(rng), 0.0), 0.0)
    assert np.isclose(add_noise(0.0, next(rng), 0.0, -1.0, 1.0), 0.0)
    assert np.isclose(add_noise(0.0, next(rng), 100.0, -1.0, 1.0, 0.0, 0.0),
                      0.0)
    assert -1.0 <= add_noise(0.0, next(rng), 100.0, -1.0, 1.0) <= 1.0
    assert -20.0 <= add_noise(0.0, next(rng), 100.0, -20.0, 20.0) <= 20.0
Example #3
    def test_argmax_random_tiebreaking(self):
        rngs = PRNGSequence(13)

        vec = jnp.ones(shape=(5,))
        mat = jnp.ones(shape=(3, 5))

        self.assertEqual(argmax(next(rngs), vec), 2)  # ties broken at random, so not necessarily index 0
        self.assertArrayAlmostEqual(argmax(next(rngs), mat), [1, 1, 3])
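
For reference, an argmax with random tie-breaking can be built by adding noise only where the maximum is attained; a minimal sketch under that assumption (not necessarily the implementation under test):

import jax
import jax.numpy as jnp

def argmax_random_ties(key, x, axis=-1):
    # Keep random noise only at the maximal entries and mask everything else,
    # then take a plain argmax: ties are resolved uniformly at random.
    is_max = x == jnp.max(x, axis=axis, keepdims=True)
    noise = jax.random.uniform(key, x.shape)
    return jnp.argmax(jnp.where(is_max, noise, -1.0), axis=axis)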
Example #4
    def test_tree_sample(self):
        rngs = PRNGSequence(42)
        tn = get_transition_batch(self.env_discrete, batch_size=5)

        tn_sample = tree_sample(tn, next(rngs), n=3)
        assert tn_sample.batch_size == 3

        tn_sample = tree_sample(tn, next(rngs), n=7, replace=True)
        assert tn_sample.batch_size == 7

        msg = r"Cannot take a larger sample than population when 'replace=False'"
        with self.assertRaisesRegex(ValueError, msg):
            tree_sample(tn, next(rngs), n=7, replace=False)
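
The test only exercises the interface; one way such a pytree sampler could look is sketched below with plain jax utilities (tree_sample_sketch and its assumption that every leaf shares a leading batch axis are illustrative, not the implementation under test):

import jax

def tree_sample_sketch(pytree, key, n, replace=False):
    # Draw one set of row indices and apply it to every leaf so the batch stays aligned.
    # jax.random.choice raises the ValueError asserted above when n exceeds the
    # population size and replace=False.
    batch_size = jax.tree_util.tree_leaves(pytree)[0].shape[0]
    idx = jax.random.choice(key, batch_size, shape=(n,), replace=replace)
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], pytree)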
Example #5
File: TD3.py  Project: ethanabrooks/jax-rl
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        lr=3e-4,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        expl_noise=0.1,
        noise_clip=0.5,
        policy_freq=2,
        seed=0,
    ):

        self.rng = PRNGSequence(seed)

        actor_input_dim = [((1, state_dim), jnp.float32)]

        init_rng = next(self.rng)

        actor = build_td3_actor_model(actor_input_dim, action_dim, max_action,
                                      init_rng)
        self.actor_target = build_td3_actor_model(actor_input_dim, action_dim,
                                                  max_action, init_rng)
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [
            ((1, state_dim), jnp.float32),
            ((1, action_dim), jnp.float32),
        ]

        critic = build_td3_critic_model(critic_input_dim, init_rng)
        self.critic_target = build_td3_critic_model(critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.expl_noise = expl_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.total_it = 0
Example #6
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        lr: float = 3e-4,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_noise: float = 0.2,
        expl_noise: float = 0.1,
        noise_clip: float = 0.5,
        policy_freq: int = 2,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        actor_input_dim = (1, state_dim)

        init_rng = next(self.rng)

        actor_params = build_td3_actor_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        self.actor_target_params = build_td3_actor_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_td3_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_td3_critic_model(critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.expl_noise = expl_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq

        self.action_dim = action_dim

        self.total_it = 0
Example #7
    def __init__(
        self,
        features: tp.Iterable[h5py.Group],
        labels: h5py.Group,
        batch_size: int,
        shuffle_rng=0,
    ):
        def get_features(group):
            if hasattr(group, "train"):
                return group["train"]
            return group

        super().__init__((get_features(g) for g in features), labels,
                         batch_size)
        self._rng = PRNGSequence(shuffle_rng)
Example #8
File: SAC.py  Project: henry-prior/jax-rl
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_freq: int = 2,
        lr: float = 3e-4,
        entropy_tune: bool = True,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        actor_input_dim = (1, state_dim)

        actor_params = build_gaussian_policy_model(actor_input_dim, action_dim,
                                                   max_action, next(self.rng))
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_double_critic_model(
            critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        self.entropy_tune = entropy_tune

        log_alpha_params = build_constant_model(-3.5, next(self.rng))
        log_alpha_optimizer = optim.Adam(
            learning_rate=lr).create(log_alpha_params)
        self.log_alpha_optimizer = jax.device_put(log_alpha_optimizer)
        self.target_entropy = -action_dim

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_freq = policy_freq

        self.action_dim = action_dim

        self.total_it = 0
Example #9
    def __init__(
        self,
        language_dim: int,
        vision_dim: Tuple[int, int],
        num_embeddings: int,
        embedding_dim: int,
        memory_hidden_dim: int,
        tokenizer: Tokenizer,
        discount: float = 0.9,
        hidden_size: int = 512,
        lr: float = 1e-5,
        policy_eps: float = 0.1,
        baseline_eps: float = 0.5,
        entropy_eps: float = 1e-5,
        reconstruction_eps: float = 1.0,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        input_dims = [
            (1, language_dim),
            (1, *vision_dim),
        ]
        agent_params = build_fast_slow_agent_model(
            input_dims,
            memory_hidden_dim,
            embedding_dim,
            num_embeddings,
            embedding_dim,
            next(self.rng),
        )
        agent_optimizer = optim.Adam(learning_rate=lr).create(agent_params)
        self.agent_optimizer = jax.device_put(agent_optimizer)

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.discount = discount
        self.policy_eps = policy_eps
        self.baseline_eps = baseline_eps
        self.entropy_eps = entropy_eps
        self.reconstruction_eps = reconstruction_eps

        self.tokenizer = tokenizer

        self.total_it = 0
Example #10
    def test_argmax_consistent(self):
        rngs = PRNGSequence(13)

        vec = jax.random.normal(next(rngs), shape=(5, ))
        mat = jax.random.normal(next(rngs), shape=(3, 5))
        ten = jax.random.normal(next(rngs), shape=(3, 5, 7))

        self.assertEqual(argmax(next(rngs), vec), jnp.argmax(vec, axis=-1))
        self.assertArrayAlmostEqual(argmax(next(rngs), mat),
                                    jnp.argmax(mat, axis=-1))
        self.assertArrayAlmostEqual(argmax(next(rngs), mat, axis=0),
                                    jnp.argmax(mat, axis=0))
        self.assertArrayAlmostEqual(argmax(next(rngs), ten),
                                    jnp.argmax(ten, axis=-1))
        self.assertArrayAlmostEqual(argmax(next(rngs), ten, axis=0),
                                    jnp.argmax(ten, axis=0))
        self.assertArrayAlmostEqual(argmax(next(rngs), ten, axis=1),
                                    jnp.argmax(ten, axis=1))
Example #11
def main():
    loss_obj = hk.transform(loss_fn, apply_rng=True)
    # Initial parameter values are typically random. In JAX you need a key in order
    # to generate random numbers and so Haiku requires you to pass one in.
    rng = PRNGSequence(42)

    # `init` runs your function, so we need an example input. Typically you can
    # pass "dummy" inputs (e.g. ones of the same shape and dtype), since initialization
    # is not usually data dependent.
    shape = [([1000], float)]
    adam = optim.Adam(learning_rate=0.1)
    partial = Net.partial()
    _, params = partial.init_by_shape(next(rng), shape)
    net = nn.Model(partial, params)

    optimizer = jax.device_put(adam.create(net))
    _, params = partial.init_by_shape(next(rng), shape)  # re-initialize params with a fresh key
    net = net.replace(params=params)
    optimizer = jax.device_put(adam.create(net))  # rebuild the optimizer around the new params
Example #12
    def __init__(
        self,
        num_agent_steps,
        state_space,
        action_space,
        seed,
        max_grad_norm,
        gamma,
    ):
        np.random.seed(seed)
        self.rng = PRNGSequence(seed)

        self.agent_step = 0
        self.episode_step = 0
        self.learning_step = 0
        self.num_agent_steps = num_agent_steps
        self.state_space = state_space
        self.action_space = action_space
        self.gamma = gamma
        self.max_grad_norm = max_grad_norm
        self.discrete_action = not isinstance(action_space, Box)
Example #13
File: main.py  Project: ethanabrooks/jax-rl
    def __init__(
        self,
        batch_size,
        buffer_size,
        discount,
        env_id,
        eval_freq,
        eval_episodes,
        learning_rate,
        load_path,
        max_time_steps,
        actor_freq,
        save_freq,
        save_model,
        seed,
        start_time_steps,
        tau,
        train_steps,
        render,
        use_tune,
        prefix=None,
        env=None,
    ):
        self.prefix = prefix
        seed = int(seed)
        policy = "SAC"
        if env_id == "levels":
            env_id = None
        self.use_tune = use_tune
        self.max_time_steps = int(max_time_steps) if max_time_steps else None
        self.start_time_steps = int(start_time_steps)
        self.train_steps = int(train_steps)
        self.batch_size = int(batch_size)
        self.buffer_size = int(buffer_size)
        self.eval_freq = eval_freq

        def make_env():
            if env is None:
                return Environment.wrap(gym.make(env_id) if env_id else Env(100))
            return env

        def eval_policy():
            eval_env = make_env()
            eval_env.seed(seed)

            avg_reward = 0.0
            it = (
                itertools.count() if render else tqdm(range(eval_episodes), desc="eval")
            )
            for _ in it:
                eval_time_step = eval_env.reset()
                while not eval_time_step.last():
                    if render:
                        eval_env.render()
                    action = policy.select_action(eval_time_step.observation)
                    eval_time_step = eval_env.step(action)
                    avg_reward += eval_time_step.reward

            avg_reward /= eval_episodes

            self.report(eval_reward=avg_reward)
            return avg_reward

        self.eval_policy = eval_policy

        self.report(policy=policy)
        self.report(env=env_id)
        self.report(seed=seed)
        if save_model and not os.path.exists("./models"):
            os.makedirs("./models")
        self.env = env = make_env()
        assert isinstance(env, Environment)
        # Set seeds
        np.random.seed(seed)
        state_shape = env.observation_spec().shape
        action_dim = env.action_spec().shape[0]
        max_action = env.max_action()
        # Initialize policy
        self.policy = policy = SAC.SAC(  # TODO
            state_shape=state_shape,
            action_dim=action_dim,
            max_action=max_action,
            save_freq=save_freq,
            discount=discount,
            lr=learning_rate,
            actor_freq=actor_freq,
            tau=tau,
        )
        self.rng = PRNGSequence(seed)
Example #14
def test_preprocess_state():
    rng = PRNGSequence(0)
    state = np.random.randint(0, 256, size=(64, 64, 3)).astype(np.uint8)
    state = preprocess_state(state, next(rng))
    assert (-0.5 <= state).all() and (state <= 0.5).all()
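
The key argument suggests dequantization noise; a common way to map uint8 pixels into [-0.5, 0.5] is sketched below (an assumption about preprocess_state, not its actual definition):

import jax

def preprocess_state_sketch(state, key):
    # Add uniform noise in [0, 1) to the uint8 pixels, scale into [0, 1), then center;
    # the result always lies within [-0.5, 0.5).
    noise = jax.random.uniform(key, state.shape)
    return (state.astype('float32') + noise) / 256.0 - 0.5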
Example #15
File: MPO.py  Project: ethanabrooks/jax-rl
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        lr=3e-4,
        eps_q=0.1,
        eps_mu=0.1,
        eps_sig=1e-4,
        temp_steps=10,
        target_freq=250,
        seed=0,
    ):

        self.rng = PRNGSequence(seed)

        init_rng = next(self.rng)

        actor_input_dim = [((1, state_dim), jnp.float32)]

        actor = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        self.actor_target = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [
            ((1, state_dim), jnp.float32),
            ((1, action_dim), jnp.float32),
        ]

        critic = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target = build_double_critic_model(critic_input_dim, init_rng)
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        temp = build_constant_model(1.0, next(self.rng))
        temp_optimizer = optim.Adam(learning_rate=lr).create(temp)
        self.temp_optimizer = jax.device_put(temp_optimizer)

        mu_lagrange = build_constant_model(1.0, next(self.rng))
        mu_lagrange_optimizer = optim.Adam(learning_rate=lr).create(mu_lagrange)
        self.mu_lagrange_optimizer = jax.device_put(mu_lagrange_optimizer)
        sig_lagrange = build_constant_model(100.0, next(self.rng))
        sig_lagrange_optimizer = optim.Adam(learning_rate=lr).create(sig_lagrange)
        self.sig_lagrange_optimizer = jax.device_put(sig_lagrange_optimizer)

        self.eps_q = eps_q
        self.eps_mu = eps_mu
        self.eps_sig = eps_sig
        self.temp_steps = temp_steps

        self.max_action = max_action
        self.discount = discount
        self.target_freq = target_freq

        self.action_dim = action_dim

        self.total_it = 0
Example #16
    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        discount: float = 0.99,
        lr: float = 3e-4,
        eps_eta: float = 0.1,
        eps_mu: float = 5e-4,
        eps_sig: float = 1e-5,
        target_freq: int = 250,
        seed: int = 0,
    ):
        self.rng = PRNGSequence(seed)

        init_rng = next(self.rng)

        actor_input_dim = (1, state_dim)

        actor_params = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        self.actor_target_params = build_gaussian_policy_model(
            actor_input_dim, action_dim, max_action, init_rng
        )
        actor_optimizer = optim.Adam(learning_rate=lr).create(actor_params)
        self.actor_optimizer = jax.device_put(actor_optimizer)

        init_rng = next(self.rng)

        critic_input_dim = [(1, state_dim), (1, action_dim)]

        critic_params = build_double_critic_model(critic_input_dim, init_rng)
        self.critic_target_params = build_double_critic_model(
            critic_input_dim, init_rng
        )
        critic_optimizer = optim.Adam(learning_rate=lr).create(critic_params)
        self.critic_optimizer = jax.device_put(critic_optimizer)

        mu_lagrange_params = build_constant_model(
            1.0, absolute=True, init_rng=next(self.rng)
        )
        mu_lagrange_optimizer = optim.Adam(learning_rate=lr).create(mu_lagrange_params)
        self.mu_lagrange_optimizer = jax.device_put(mu_lagrange_optimizer)

        sig_lagrange_params = build_constant_model(
            100.0, absolute=True, init_rng=next(self.rng)
        )
        sig_lagrange_optimizer = optim.Adam(learning_rate=lr).create(
            sig_lagrange_params
        )
        self.sig_lagrange_optimizer = jax.device_put(sig_lagrange_optimizer)

        self.temp = 1.0
        self.eps_eta = eps_eta
        self.eps_mu = eps_mu
        self.eps_sig = eps_sig

        self.max_action = max_action
        self.discount = discount
        self.target_freq = target_freq

        self.state_dim = state_dim
        self.action_dim = action_dim

        self.total_it = 0
Example #17
    def seed(self, seed=None):
        self.rng = PRNGSequence(seed)
Example #18
File: SAC.py  Project: ethanabrooks/jax-rl
    def __init__(
        self,
        state_shape,
        action_dim,
        max_action,
        save_freq,
        discount=0.99,
        tau=0.005,
        actor_freq=2,
        lr=3e-4,
        entropy_tune=False,
        seed=0,
    ):

        self.rng = PRNGSequence(seed)

        actor_input_dim = [((1, *state_shape), jnp.float32)]
        critic_input_dim = self.critic_input_dim = [
            ((1, *state_shape), jnp.float32),
            ((1, action_dim), jnp.float32),
        ]
        self.actor = None
        self.critic = None
        self.log_alpha = None
        self.entropy_tune = entropy_tune
        self.target_entropy = -action_dim

        self.adam = Optimizers(
            actor=optim.Adam(learning_rate=lr),
            critic=optim.Adam(learning_rate=lr),
            log_alpha=optim.Adam(learning_rate=lr),
        )
        self.module = Modules(
            actor=GaussianPolicy.partial(action_dim=action_dim,
                                         max_action=max_action),
            critic=DoubleCritic.partial(),
            alpha=Constant.partial(start_value=1),
        )
        self.optimizer = None

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_freq = actor_freq
        self.save_freq = save_freq

        self.total_it = 0
        self.model = None

        def new_params(module: nn.Module, shape=None):
            _, params = (module.init(next(self.rng)) if shape is None else
                         module.init_by_shape(next(self.rng), shape))
            return params

        def new_model(module: nn.Module, shape=None) -> nn.Model:
            return nn.Model(module, new_params(module, shape))

        def update_model(model: nn.Model, shape=None) -> nn.Model:
            return model.replace(params=new_params(model.module, shape))

        def reset_models() -> Models:
            if self.model is None:
                critic = new_model(self.module.critic, critic_input_dim)
                return Models(
                    actor=new_model(self.module.actor, actor_input_dim),
                    critic=critic,
                    target_critic=critic.replace(params=critic.params),
                    alpha=new_model(self.module.alpha),
                )
            else:
                critic = update_model(self.model.critic, critic_input_dim)
                return Models(
                    actor=update_model(self.model.actor, actor_input_dim),
                    critic=critic,
                    target_critic=critic.replace(params=critic.params),
                    alpha=update_model(self.model.alpha),
                )

        self.reset_models = reset_models

        def reset_optimizer(adam: Adam, model: nn.Model) -> Optimizer:
            return jax.device_put(adam.create(model))

        def reset_optimizers() -> Optimizers:
            return Optimizers(
                actor=reset_optimizer(self.adam.actor, self.model.actor),
                critic=reset_optimizer(self.adam.critic, self.model.critic),
                log_alpha=reset_optimizer(self.adam.log_alpha,
                                          self.model.alpha),
            )

        self.reset_optimizers = reset_optimizers
        self.i = 0