コード例 #1
0
def create_model(name="discrete"):
    """Return a zero-argument factory for the model variant chosen by ``name``.

    Args:
        name: One of ``"discrete"``, ``"continuous"``, ``"random_normal"`` or
            ``"discrete_with_critic"``.

    Returns:
        A callable taking no arguments that builds the selected model each
        time it is invoked.

    Raises:
        ValueError: If ``name`` is not one of the supported variants.
    """
    if name == "discrete":
        return lambda: DiscreteUniform(n_actions=10)
    if name == "continuous":
        # Built once here and captured by the factory below.
        action_bounds = Bounds(low=-1, high=1, shape=(3, ))
        return lambda: ContinuousUniform(bounds=action_bounds)
    if name == "random_normal":
        action_bounds = Bounds(low=-1, high=1, shape=(3, ))
        return lambda: NormalContinuous(loc=0, scale=1, bounds=action_bounds)
    if name == "discrete_with_critic":
        # The critic is created a single time, so every model produced by the
        # returned factory receives the same critic instance.
        shared_critic = GaussianDt(max_dt=3)
        return lambda: DiscreteUniform(n_actions=10, critic=shared_critic)
    raise ValueError("Invalid param `name`.")
コード例 #2
0
ファイル: test_swarm.py プロジェクト: softmaxhuanchen/fragile
def create_cartpole_swarm():
    """Build a ``Swarm`` that runs a ``DiscreteUniform`` model on ``CartPole-v0``.

    Returns:
        The constructed ``Swarm`` instance.
    """
    # Model and environment are handed over as factories, not instances;
    # the model factory receives a single argument that it forwards as ``env``.
    swarm_settings = dict(
        model=lambda x: DiscreteUniform(env=x),
        walkers=Walkers,
        env=lambda: DiscreteEnv(ClassicControl("CartPole-v0")),
        reward_limit=121,
        n_walkers=150,
        max_epochs=300,
        reward_scale=2,
    )
    return Swarm(**swarm_settings)
コード例 #3
0
ファイル: test_swarm.py プロジェクト: vmarkovtsev/fragile
def create_cartpole_swarm():
    """Build a ``Swarm`` over a default ``ClassicControl`` environment.

    Returns:
        The constructed ``Swarm`` instance.
    """

    def build_env():
        # Zero-argument environment factory passed to the swarm.
        return DiscreteEnv(ClassicControl())

    def build_model(environment):
        # One-argument model factory; the argument is forwarded as ``env``.
        return DiscreteUniform(env=environment)

    return Swarm(
        model=build_model,
        walkers=Walkers,
        env=build_env,
        n_walkers=20,
        max_iters=200,
        prune_tree=True,
        reward_scale=2,
    )
コード例 #4
0
ファイル: test_model.py プロジェクト: vmarkovtsev/fragile
    def test_sample_with_critic(self, n_actions):
        """Check actions and critic scores of a ``DiscreteUniform`` given a ``DummyCritic``.

        Both ``predict`` and ``sample`` must return one-dimensional actions
        with at most ``n_actions`` distinct values inside ``[0, n_actions]``,
        and a ``critic_score`` entry whose values all equal 5.
        """

        def check_actions(sampled):
            # Checks shared by the ``predict`` and ``sample`` results.
            assert len(sampled.shape) == 1
            assert len(numpy.unique(sampled)) <= n_actions
            assert all(sampled >= 0)
            assert all(sampled <= n_actions)

        def check_critic(result):
            assert "critic_score" in result.keys()
            assert (result.critic_score == 5).all()

        sampler = DiscreteUniform(n_actions=n_actions, critic=DummyCritic())

        predicted = sampler.predict(batch_size=1000)
        check_actions(predicted.actions)
        check_critic(predicted)

        initial_states = create_model_states(batch_size=100, model=sampler)
        resampled = sampler.sample(batch_size=initial_states.n, model_states=initial_states)
        resampled_actions = resampled.actions
        check_actions(resampled_actions)
        # Sampled actions must be integer-valued.
        assert numpy.allclose(resampled_actions, resampled_actions.astype(int))
        check_critic(resampled)
コード例 #5
0
def create_atari_swarm():
    """Build a small ``Swarm`` on ``MsPacman-ram-v0`` with a ``GaussianDt`` critic.

    Returns:
        The constructed ``Swarm`` instance.
    """
    # Created once up front; the factories below close over these objects,
    # so every call to them reuses the same environment and critic.
    atari_env = AtariEnvironment(name="MsPacman-ram-v0")
    dt_critic = GaussianDt(min_dt=10, max_dt=100, loc_dt=5, scale_dt=2)
    swarm_settings = dict(
        model=lambda x: DiscreteUniform(env=x, critic=dt_critic),
        env=lambda: DiscreteEnv(atari_env),
        n_walkers=6,
        max_epochs=10,
        reward_scale=2,
        reward_limit=1,
    )
    return Swarm(**swarm_settings)
コード例 #6
0
ファイル: test_model.py プロジェクト: vmarkovtsev/fragile
    def test_sample(self, n_actions):
        """Check actions and critic scores of a ``DiscreteUniform`` built without a critic.

        Both ``predict`` and ``sample`` must return one-dimensional actions
        with at most ``n_actions`` distinct values inside ``[0, n_actions]``,
        and a ``critic_score`` entry whose values all equal 1.
        """

        def check_actions(sampled):
            # Checks shared by the ``predict`` and ``sample`` results.
            assert len(sampled.shape) == 1
            assert len(numpy.unique(sampled)) <= n_actions
            assert all(sampled >= 0)
            assert all(sampled <= n_actions)

        sampler = DiscreteUniform(n_actions=n_actions)

        predicted = sampler.predict(batch_size=1000)
        check_actions(predicted.actions)
        assert "critic_score" in predicted.keys()
        assert isinstance(predicted.critic_score, numpy.ndarray)
        # The scores are attached as the failure message to ease debugging.
        assert (predicted.critic_score == 1).all(), predicted.critic_score

        initial_states = create_model_states(batch_size=100, model=sampler)
        resampled = sampler.sample(batch_size=initial_states.n, model_states=initial_states)
        resampled_actions = resampled.actions
        check_actions(resampled_actions)
        # Sampled actions must be integer-valued.
        assert numpy.allclose(resampled_actions, resampled_actions.astype(int))
        assert "critic_score" in resampled.keys()
        assert (resampled.critic_score == 1).all()
コード例 #7
0
    def test_sample(self, n_actions):
        """Check actions and critic scores of a ``DiscreteUniform`` using the ``judo`` API.

        Both ``predict`` and ``sample`` must return one-dimensional actions
        with at most ``n_actions`` distinct values inside ``[0, n_actions]``,
        and a tensor ``critic_score`` entry whose values all equal 1.
        """

        def check_actions(sampled):
            # Checks shared by the ``predict`` and ``sample`` results.
            assert len(sampled.shape) == 1
            assert len(judo.unique(sampled)) <= n_actions
            assert all(sampled >= 0)
            assert all(sampled <= n_actions)

        sampler = DiscreteUniform(n_actions=n_actions)

        predicted = sampler.predict(batch_size=1000)
        check_actions(predicted.actions)
        assert "critic_score" in predicted.keys()
        assert dtype.is_tensor(predicted.critic_score)
        # The scores are attached as the failure message to ease debugging.
        assert (predicted.critic_score == 1).all(), predicted.critic_score

        initial_states = create_model_states(batch_size=100, model=sampler)
        resampled = sampler.sample(batch_size=initial_states.n, model_states=initial_states)
        resampled_actions = resampled.actions
        check_actions(resampled_actions)
        # Sampled actions must be integer-valued.
        assert judo.allclose(resampled_actions, judo.astype(resampled_actions, dtype.int))
        assert "critic_score" in resampled.keys()
        assert (resampled.critic_score == 1).all()
コード例 #8
0
ファイル: test_swarm.py プロジェクト: softmaxhuanchen/fragile
def create_atari_swarm():
    """Build a ``Swarm`` on ``MsPacman-ram-v0`` with seed cloning and autoreset enabled.

    Returns:
        The constructed ``Swarm`` instance.
    """
    # Both objects are created once; the inner factories reuse them on
    # every call instead of building fresh ones.
    atari_env = AtariEnvironment(
        name="MsPacman-ram-v0",
        clone_seeds=True,
        autoreset=True,
    )
    dt_critic = GaussianDt(min_dt=3, max_dt=100, loc_dt=5, scale_dt=2)

    def build_model(environment):
        # One-argument model factory; the argument is forwarded as ``env``.
        return DiscreteUniform(env=environment, critic=dt_critic)

    def build_env():
        # Zero-argument environment factory wrapping the shared Atari env.
        return DiscreteEnv(atari_env)

    return Swarm(
        model=build_model,
        walkers=Walkers,
        env=build_env,
        n_walkers=67,
        max_epochs=500,
        reward_scale=2,
        reward_limit=751,
    )
コード例 #9
0
ファイル: test_swarm.py プロジェクト: vmarkovtsev/fragile
def create_atari_swarm():
    """Build a ``Swarm`` on ``MsPacman-ram-v0`` wrapped in a ``ParallelEnvironment``.

    Returns:
        The constructed ``Swarm`` instance.
    """
    parallel_settings = dict(
        env_class=AtariEnvironment,
        name="MsPacman-ram-v0",
        clone_seeds=True,
        autoreset=True,
        blocking=False,
    )
    # Created once; the factories below close over these objects, so each
    # call to them reuses the same environment and critic.
    parallel_env = ParallelEnvironment(**parallel_settings)
    dt_critic = GaussianDt(min_dt=3, max_dt=100, loc_dt=5, scale_dt=2)
    swarm_settings = dict(
        model=lambda x: DiscreteUniform(env=x, critic=dt_critic),
        walkers=Walkers,
        env=lambda: DiscreteEnv(parallel_env),
        n_walkers=67,
        max_iters=20,
        prune_tree=True,
        reward_scale=2,
    )
    return Swarm(**swarm_settings)