Example #1
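Learn a neural-network dynamics model for Pendulum-v0 from random-policy rollouts, then control the real environment with a CEM-based MPC policy.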
from blackbox_mpc.policies.mpc_policy import MPCPolicy
# the import paths for RandomPolicy and DeterministicMLP below are assumed
# from the blackbox_mpc package layout; they were missing from the original snippet
from blackbox_mpc.policies.random_policy import RandomPolicy
from blackbox_mpc.dynamics_functions.deterministic_mlp import DeterministicMLP
from blackbox_mpc.utils.dynamics_learning import learn_dynamics_from_policy
from blackbox_mpc.environment_utils import EnvironmentWrapper
from blackbox_mpc.utils.pendulum import pendulum_reward_function
import gym
import tensorflow as tf

env = gym.make("Pendulum-v0")
# deterministic MLP dynamics model: the input is the concatenated
# (action, observation) vector, the output has the observation
# dimensionality; three tanh hidden layers of 32 units and a linear output
dynamics_function = DeterministicMLP(
    layers=[
        env.action_space.shape[0] + env.observation_space.shape[0], 32, 32, 32,
        env.observation_space.shape[0]
    ],
    activation_functions=[tf.math.tanh, tf.math.tanh, tf.math.tanh, None])
# random exploration policy used to collect the training rollouts
policy = RandomPolicy(number_of_agents=10, env_action_space=env.action_space)
# collect 5 rollouts of 200 steps each with 10 parallel agents and fit the model
dynamics_handler = learn_dynamics_from_policy(
    env=EnvironmentWrapper.make_standard_gym_env("Pendulum-v0",
                                                 num_of_agents=10),
    policy=policy,
    number_of_rollouts=5,
    task_horizon=200,
    dynamics_function=dynamics_function)
# MPC policy that plans over the learned dynamics with the cross-entropy method
mpc_policy = MPCPolicy(reward_function=pendulum_reward_function,
                       env_action_space=env.action_space,
                       env_observation_space=env.observation_space,
                       dynamics_handler=dynamics_handler,
                       optimizer_name='CEM')

current_obs = env.reset()
for t in range(200):
    # the original snippet was cut off here; the loop body below follows the
    # library's documented usage pattern: act() returns the chosen action
    # together with the model's predicted next observation and reward
    action_to_execute, expected_obs, expected_reward = mpc_policy.act(
        current_obs, t)
    current_obs, reward, _, info = env.step(action_to_execute)
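The planning optimizer is selected by name alone, so swapping it leaves the rest of the setup untouched. A minimal variation, assuming 'PI2' is among the optimizer names blackbox_mpc accepts ('CEM' is the only name confirmed by the snippet above):

pi2_policy = MPCPolicy(reward_function=pendulum_reward_function,
                       env_action_space=env.action_space,
                       env_observation_space=env.observation_space,
                       dynamics_handler=dynamics_handler,
                       optimizer_name='PI2')  # assumed optimizer name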
Example #2
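Iteratively learn a dynamics model for a custom HalfCheetah environment by alternating data collection under an MPC policy with model refinement, logging to TensorBoard along the way.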
import tensorflow as tf
# the import paths below are assumed from the blackbox_mpc package layout and
# were missing from the original snippet; HalfCheetahEnvModified and
# reward_function come from the tutorial's own modules and are not shown here
from blackbox_mpc.policies.random_policy import RandomPolicy
from blackbox_mpc.dynamics_functions.deterministic_mlp import DeterministicMLP
from blackbox_mpc.utils.dynamics_learning import learn_dynamics_iteratively_w_mpc
from blackbox_mpc.environment_utils import EnvironmentWrapper

# TensorBoard writer used to log training and refinement statistics
log_dir = './'
tf_writer = tf.summary.create_file_writer(log_dir)
env = HalfCheetahEnvModified()
num_of_agents = 1
# wrap the custom env class so rollouts can be collected across parallel agents
parallel_env = EnvironmentWrapper.make_custom_gym_env(
    HalfCheetahEnvModified, num_of_agents=num_of_agents)

# same deterministic MLP dynamics model as in Example #1, here with three
# 500-unit tanh hidden layers and a linear output
dynamics_function = DeterministicMLP(
    layers=[
        env.action_space.shape[0] + env.observation_space.shape[0], 500, 500,
        500, env.observation_space.shape[0]
    ],
    activation_functions=[tf.math.tanh, tf.math.tanh, tf.math.tanh, None])
initial_policy = RandomPolicy(number_of_agents=num_of_agents,
                              env_action_space=env.action_space)

# alternate between collecting data and refining the model: 5 initial rollouts
# under the random policy, then 1 refinement step of 3 rollouts collected
# under the MPC policy, planning 15 steps ahead over a 1000-step task
system_dynamics_handler, mpc_policy = learn_dynamics_iteratively_w_mpc(
    env=parallel_env,
    env_action_space=env.action_space,
    env_observation_space=env.observation_space,
    number_of_initial_rollouts=5,
    number_of_rollouts_for_refinement=3,
    number_of_refinement_steps=1,
    task_horizon=1000,
    planning_horizon=15,
    initial_policy=initial_policy,
    dynamics_function=dynamics_function,
    num_agents=num_of_agents,
    reward_function=reward_function,
    log_dir=log_dir,
    tf_writer=tf_writer)  # closing kwarg assumed; the original snippet was cut off here
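Once the iterative loop returns, the resulting mpc_policy can be rolled out on the single-agent environment just as in Example #1. A minimal sketch, assuming act() keeps the (observation, timestep) signature used there:

current_obs = env.reset()
for t in range(1000):
    # act() is assumed to return the action along with the model's predicted
    # next observation and reward, as in Example #1
    action, expected_obs, expected_reward = mpc_policy.act(current_obs, t)
    current_obs, reward, done, info = env.step(action)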