Example #1
def run_linear_ocm_exp(variant):
    from railrl.tf.ddpg import DDPG
    from railrl.envs.flattened_product_box import FlattenedProductBox
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.tf.policies.nn_policy import FeedForwardPolicy
    from railrl.qfunctions.nn_qfunction import FeedForwardCritic
    from railrl.envs.memory.continuous_memory_augmented import (
        ContinuousMemoryAugmented
    )
    from railrl.launchers.launcher_util import (
        set_seed,
    )

    """
    Set up experiment variants.
    """
    seed = variant['seed']
    algo_params = variant['algo_params']
    env_class = variant['env_class']
    env_params = variant['env_params']
    memory_dim = variant['memory_dim']
    ou_params = variant['ou_params']

    set_seed(seed)

    """
    Code for running the experiment.
    """

    env = env_class(**env_params)
    env = ContinuousMemoryAugmented(
        env,
        num_memory_states=memory_dim,
    )
    env = FlattenedProductBox(env)

    qf = FeedForwardCritic(
        name_or_scope="critic",
        env_spec=env.spec,
    )
    policy = FeedForwardPolicy(
        name_or_scope="policy",
        env_spec=env.spec,
    )
    es = OUStrategy(
        env_spec=env.spec,
        **ou_params
    )
    algorithm = DDPG(
        env,
        es,
        policy,
        qf,
        **algo_params
    )

    algorithm.train()
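
A minimal sketch of the variant dictionary this launcher reads, reconstructed from the keys it accesses above; the environment class and the empty parameter dicts are hypothetical placeholders, while the DDPG keyword arguments mirror those used in the later examples:

variant = dict(
    seed=0,
    memory_dim=20,                                   # width of the appended continuous memory
    env_class=OneCharMemory,                         # hypothetical env class, imported elsewhere
    env_params=dict(),                               # forwarded to env_class(**env_params)
    algo_params=dict(n_epochs=30, batch_size=1024),  # forwarded to DDPG(**algo_params)
    ou_params=dict(),                                # forwarded to OUStrategy(**ou_params)
)
run_linear_ocm_exp(variant)
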
Example #2
def example(variant):
    load_policy_file = variant.get('load_policy_file', None)
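    # Warm start: if the snapshot exists, reuse its saved policy, Q-function, and replay pool.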
    if load_policy_file is not None and exists(load_policy_file):
        with tf.Session():
            data = joblib.load(load_policy_file)
            print(data)
            policy = data['policy']
            qf = data['qf']
            replay_buffer = data['pool']
        env = HalfCheetahEnv()
        es = OUStrategy(action_space=env.action_space)
        use_new_version = variant['use_new_version']
        algorithm = DDPG(
            env,
            es,
            policy,
            qf,
            n_epochs=2,
            batch_size=1024,
            replay_pool=replay_buffer,
            use_new_version=use_new_version,
        )
        algorithm.train()
    else:
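        # No snapshot: build a fresh critic and actor and train from scratch.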
        env = HalfCheetahEnv()
        es = OUStrategy(action_space=env.action_space)
        qf = FeedForwardCritic(
            name_or_scope="critic",
            env_spec=env.spec,
        )
        policy = FeedForwardPolicy(
            name_or_scope="actor",
            env_spec=env.spec,
        )
        use_new_version = variant['use_new_version']
        algorithm = DDPG(
            env,
            es,
            policy,
            qf,
            n_epochs=2,
            batch_size=1024,
            use_new_version=use_new_version,
        )
        algorithm.train()
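
A sketch of the variant this example expects; the snapshot path below is a hypothetical placeholder, not a real file:

variant = dict(
    load_policy_file='/tmp/ddpg_snapshot.pkl',  # hypothetical joblib snapshot; omit it to train from scratch
    use_new_version=True,
)
example(variant)
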
Example #3
def example(*_):
    env = DoublePendulumEnv()
    es = OUStrategy(env_spec=env.spec)
    qf = FeedForwardCritic(
        name_or_scope="critic",
        env_spec=env.spec,
    )
    policy = FeedForwardPolicy(
        name_or_scope="actor",
        env_spec=env.spec,
    )
    algorithm = DDPG(
        env,
        es,
        policy,
        qf,
        n_epochs=30,
        batch_size=1024,
    )
    algorithm.train()
Example #4
def example(*_):
    env = HalfCheetahEnv()
    es = OUStrategy(env_spec=env.spec)
    qf = FeedForwardCritic(
        name_or_scope="critic",
        env_spec=env.spec,
    )
    policy = FeedForwardPolicy(
        name_or_scope="actor",
        env_spec=env.spec,
    )
    algorithm = DDPG(
        env,
        es,
        policy,
        qf,
        n_epochs=25,
        epoch_length=1000,
        batch_size=1024,
    )
    algorithm.train()
Example #5
def example(variant):
    env_settings = get_env_settings(
        **variant['env_params']
    )
    env = env_settings['env']
    es = OUStrategy(env_spec=env.spec)
    qf = FeedForwardCritic(
        name_or_scope="critic",
        env_spec=env.spec,
    )
    policy = FeedForwardPolicy(
        name_or_scope="actor",
        env_spec=env.spec,
    )
    algorithm = DDPG(
        env,
        es,
        policy,
        qf,
        **variant['ddpg_params']
    )
    algorithm.train()
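
A sketch of a matching variant; the env_params key is a guess at what get_env_settings might accept, while ddpg_params mirrors the DDPG keyword arguments used in the other examples:

variant = dict(
    env_params=dict(env_id='cheetah'),  # hypothetical key; depends on get_env_settings
    ddpg_params=dict(n_epochs=25, epoch_length=1000, batch_size=1024),
)
example(variant)
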
Example #6
def run_linear_ocm_exp(variant):
    from railrl.tf.ddpg import DDPG
    from railrl.launchers.launcher_util import set_seed
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.tf.policies.nn_policy import FeedForwardPolicy
    from railrl.qfunctions.nn_qfunction import FeedForwardCritic
    """
    Set up experiment variants.
    """
    H = variant['H']
    seed = variant['seed']
    algo_params = variant['algo_params']
    env_class = variant['env_class']
    env_params = variant['env_params']
    ou_params = variant['ou_params']

    set_seed(seed)
    """
    Code for running the experiment.
    """

    env = env_class(**env_params)

    qf = FeedForwardCritic(
        name_or_scope="critic",
        env_spec=env.spec,
    )
    policy = FeedForwardPolicy(
        name_or_scope="policy",
        env_spec=env.spec,
    )
    es = OUStrategy(env_spec=env.spec, **ou_params)
    algorithm = DDPG(env, es, policy, qf, **algo_params)

    algorithm.train()