Example #1
0
def her_td3_experiment(variant):
    """Build and run a ``HerTD3BC`` experiment (HER + TD3 with demo buffers).

    Creates the environment, three identically configured relabeling
    buffers (replay, demo-train, demo-test), an exploration strategy, two
    Q-networks and a tanh policy, then trains ``HerTD3BC`` on them.

    Args:
        variant: Experiment configuration dict. Keys read here:
            ``env_id`` (passed to ``gym.make``) or, if absent,
            ``env_class`` / ``env_kwargs``; ``observation_key``;
            ``desired_goal_key``; ``algo_kwargs`` (must contain a
            ``her_kwargs`` dict, which is mutated in place to carry the two
            keys); ``replay_buffer_kwargs``; ``exploration_type`` (one of
            ``'ou'``, ``'gaussian'``, ``'epsilon'``); ``es_kwargs``;
            ``qf_kwargs``; ``policy_kwargs``; ``demo_path``; and the
            optional flags ``normalize`` and ``save_video``.

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``exploration_type`` is not a supported value.
    """
    import gym
    # The two multiworld modules are imported for their side effects only;
    # nothing below references them (presumably they register env ids for
    # ``gym.make`` -- confirm).
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from railrl.exploration_strategies.gaussian_strategy import GaussianStrategy
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.torch.grill.launcher import get_video_save_func
    from railrl.demos.her_td3bc import HerTD3BC
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    # The algorithm receives the keys through its HER kwargs, so the
    # caller's variant dict is updated in place.
    her_kwargs = variant['algo_kwargs']['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    def make_buffer():
        # All three buffers share exactly the same configuration.
        return ObsDictRelabelingBuffer(
            env=env,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            **variant['replay_buffer_kwargs'])

    replay_buffer = make_buffer()
    demo_train_buffer = make_buffer()
    demo_test_buffer = make_buffer()

    # Size the networks from the configured keys (not the literal
    # 'observation' / 'desired_goal' entries) so the input widths match the
    # entries the buffers and algorithm are configured with.
    obs_spaces = env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    strategy_classes = {
        'ou': OUStrategy,
        'gaussian': GaussianStrategy,
        'epsilon': EpsilonGreedy,
    }
    if exploration_type not in strategy_classes:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_classes[exploration_type](
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTD3BC(env,
                         qf1=qf1,
                         qf2=qf2,
                         policy=policy,
                         exploration_policy=exploration_policy,
                         demo_train_buffer=demo_train_buffer,
                         demo_test_buffer=demo_test_buffer,
                         replay_buffer=replay_buffer,
                         demo_path=variant["demo_path"],
                         **variant['algo_kwargs'])
    if variant.get("save_video", False):
        # Register a post-epoch hook built from multitask rollouts of the
        # (non-exploration) policy.
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
Example #2
0
def her_td3_experiment(variant):
    """Build and run ``HerTd3`` on an env wrapped in ``CountBasedGoalSamplingEnv``.

    The replay buffer is created on the raw environment and then handed to
    the ``CountBasedGoalSamplingEnv`` wrapper; the wrapped environment is
    used for everything afterwards (network sizing, training, video).

    This function performs no imports of its own: ``CountBasedGoalSamplingEnv``,
    the exploration strategies, ``FlattenMlp``, ``TanhMlpPolicy``,
    ``PolicyWrappedWithExplorationStrategy``, ``HerTd3``, ``rf``,
    ``get_video_save_func`` and ``ptu`` must be available at module scope.

    Args:
        variant: Experiment configuration dict. Keys read here:
            ``env_class`` / ``env_kwargs``; ``observation_key``;
            ``desired_goal_key``; ``algo_kwargs`` (must contain a
            ``her_kwargs`` dict, mutated in place to carry the two keys);
            ``replay_buffer_class`` / ``replay_buffer_kwargs``;
            ``count_based_sampler_kwargs`` (mutated in place to carry the
            replay buffer); ``exploration_type`` (one of ``'ou'``,
            ``'gaussian'``, ``'epsilon'``); ``es_kwargs``; ``qf_kwargs``;
            ``policy_kwargs``; and the optional flag ``save_video``.

    Raises:
        Exception: If ``exploration_type`` is not a supported value.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # The algorithm receives the keys through its HER kwargs, so the
    # caller's variant dict is updated in place.
    her_kwargs = variant['algo_kwargs']['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    replay_buffer = variant['replay_buffer_class'](
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    # The wrapper is given the replay buffer via its kwargs; from here on
    # ``env`` refers to the wrapped environment.
    variant['count_based_sampler_kwargs']['replay_buffer'] = replay_buffer
    env = CountBasedGoalSamplingEnv(
        wrapped_env=env,
        **variant['count_based_sampler_kwargs']
    )

    # Size the networks from the configured keys (not the literal
    # 'observation' / 'desired_goal' entries) so the input widths match the
    # entries the buffer and algorithm are configured with.
    obs_spaces = env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    strategy_classes = {
        'ou': OUStrategy,
        'gaussian': GaussianStrategy,
        'epsilon': EpsilonGreedy,
    }
    if exploration_type not in strategy_classes:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_classes[exploration_type](
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        training_env=env,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    if variant.get("save_video", False):
        # Register a post-epoch hook built from multitask rollouts of the
        # (non-exploration) policy.
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def her_td3_experiment(variant):
    """Build and run ``HerTd3`` on an ``ImageEnv``-wrapped environment.

    The Q-functions and the policy are each built either from a CNN or from
    an MLP, depending on ``use_images_for_q`` / ``use_images_for_pi``, and
    wrapped in ``ImageStateQ`` / ``ImageStatePolicy`` respectively.

    Only part of what this function needs is imported locally; ``gym``,
    ``torch``, ``MujocoGymToMultiEnv``, ``ImageEnv``, ``MergedCNN``,
    ``CNNPolicy``, ``ImageStateQ``, ``ImageStatePolicy``, ``FlattenMlp`` and
    ``TanhMlpPolicy`` must be available at module scope.

    Args:
        variant: Experiment configuration dict. Keys read here:
            ``env_id`` (passed to ``gym.make``) or, if absent,
            ``env_class`` / ``env_kwargs``; ``observation_key``;
            ``desired_goal_key``; ``algo_kwargs`` (must contain a
            ``her_kwargs`` dict, mutated in place to carry the two keys);
            ``replay_buffer_kwargs``; ``exploration_type`` (one of ``'ou'``,
            ``'gaussian'``, ``'epsilon'``); ``es_kwargs``;
            ``use_images_for_q``; ``use_images_for_pi``; ``cnn_params``
            and/or ``qf_kwargs`` / ``policy_kwargs`` depending on those two
            flags; and the optional flags ``normalize`` and ``save_video``.

    Raises:
        NotImplementedError: If ``variant['normalize']`` is truthy.
        Exception: If ``exploration_type`` is not a supported value.
    """
    # The two multiworld modules are imported for their side effects only;
    # nothing below references them (presumably they register env ids for
    # ``gym.make`` -- confirm).
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from railrl.exploration_strategies.gaussian_strategy import GaussianStrategy
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.torch.grill.launcher import get_video_save_func
    from railrl.torch.her.her_td3 import HerTd3
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    # Width and height, in pixels, passed to the CNNs below.
    imsize = 84
    # NOTE(review): ``env.env`` is accessed on both construction paths; this
    # assumes an env built from ``env_class`` also exposes ``.env`` -- confirm.
    env = MujocoGymToMultiEnv(env.env)  # unwrap TimeLimit
    env = ImageEnv(env,
                   non_presampled_goal_img_is_garbage=True,
                   recompute_reward=False)

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    # The algorithm receives the keys through its HER kwargs, so the
    # caller's variant dict is updated in place.
    her_kwargs = variant['algo_kwargs']['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # NOTE(review): the state-based networks below are sized from the goal
    # (and action) dimensions only; the observation dimension is not used.
    # Confirm this matches what ImageStateQ / ImageStatePolicy feed them.
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size

    exploration_type = variant['exploration_type']
    strategy_classes = {
        'ou': OUStrategy,
        'gaussian': GaussianStrategy,
        'epsilon': EpsilonGreedy,
    }
    if exploration_type not in strategy_classes:
        raise Exception("Invalid type: " + exploration_type)
    es = strategy_classes[exploration_type](
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    use_images_for_q = variant["use_images_for_q"]
    use_images_for_pi = variant["use_images_for_pi"]

    def build_q_function():
        # Exactly one of the two ImageStateQ slots is filled.
        if use_images_for_q:
            image_q = MergedCNN(input_width=imsize,
                                input_height=imsize,
                                output_size=1,
                                input_channels=3,
                                added_fc_input_size=action_dim,
                                **variant['cnn_params'])
            return ImageStateQ(image_q, None)
        state_q = FlattenMlp(input_size=action_dim + goal_dim,
                             output_size=1,
                             **variant['qf_kwargs'])
        return ImageStateQ(None, state_q)

    # Two independently constructed Q-functions, built in this order.
    qf1 = build_q_function()
    qf2 = build_q_function()

    if use_images_for_pi:
        image_policy = CNNPolicy(
            input_width=imsize,
            input_height=imsize,
            output_size=action_dim,
            input_channels=3,
            **variant['cnn_params'],
            output_activation=torch.tanh,
        )
        policy = ImageStatePolicy(image_policy, None)
    else:
        state_policy = TanhMlpPolicy(input_size=goal_dim,
                                     output_size=action_dim,
                                     **variant['policy_kwargs'])
        policy = ImageStatePolicy(None, state_policy)

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       replay_buffer=replay_buffer,
                       **variant['algo_kwargs'])
    if variant.get("save_video", False):
        # Register a post-epoch hook built from multitask rollouts of the
        # (non-exploration) policy.
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
Example #4
0
def experiment(variant):
    """Run HER-wrapped TD3 through ``TorchBatchRLAlgorithm``.

    The outer ``variant`` is preprocessed, a VAE is trained unless
    ``grill_variant['do_state_exp']`` is truthy, and from then on only the
    nested ``variant['grill_variant']`` dict is used as configuration.
    Separate evaluation and exploration environments are created, along
    with online and target copies of two Q-networks and the policy.

    Args:
        variant: Experiment configuration dict with a ``grill_variant``
            entry. Keys read from ``grill_variant`` here:
            ``observation_key`` (default ``'latent_observation'``),
            ``desired_goal_key`` (default ``'latent_desired_goal'``),
            ``qf_kwargs``, ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``td3_kwargs``, ``algo_kwargs`` and the optional flags
            ``do_state_exp`` and ``save_video``.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.torch.her.her import HERTrainer
    from railrl.torch.td3.td3 import TD3 as TD3Trainer
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from railrl.samplers.data_collector import GoalConditionedPathCollector
    from railrl.torch.grill.launcher import (
        grill_preprocess_variant,
        get_envs,
        get_exploration_strategy,
        full_experiment_variant_preprocess,
        train_vae_and_update_variant,
        get_video_save_func,
    )

    full_experiment_variant_preprocess(variant)
    skip_vae_training = variant['grill_variant'].get('do_state_exp', False)
    if not skip_vae_training:
        train_vae_and_update_variant(variant)
    # Everything below is configured by the nested grill variant only.
    variant = variant['grill_variant']

    grill_preprocess_variant(variant)
    eval_env = get_envs(variant)
    expl_env = get_envs(variant)
    es = get_exploration_strategy(variant, expl_env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    # Keyword pair shared by the path collectors and the rollout function.
    goal_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    # Here ``obs_dim`` already includes the goal: the networks take the
    # observation and desired-goal entries together.
    observation_size = expl_env.observation_space.spaces[
        observation_key].low.size
    goal_size = expl_env.observation_space.spaces[desired_goal_key].low.size
    obs_dim = observation_size + goal_size
    action_dim = expl_env.action_space.low.size

    def new_qf():
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    def new_policy():
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            **variant['policy_kwargs']
        )

    # Construction order is kept as-is: online networks first, then targets.
    qf1 = new_qf()
    qf2 = new_qf()
    policy = new_policy()
    target_qf1 = new_qf()
    target_qf2 = new_qf()
    target_policy = new_policy()
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        achieved_goal_key=achieved_goal_key,
        **goal_keys,
        **variant['replay_buffer_kwargs']
    )
    td3_trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_kwargs']
    )
    trainer = HERTrainer(td3_trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, policy, **goal_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, exploration_policy, **goal_keys)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    # NOTE(review): the original author marked this branch "Does not work",
    # yet it is enabled by default (missing key -> True). Confirm whether
    # the default is intended.
    if variant.get("save_video", True):  # Does not work
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            **goal_keys
        )
        algorithm.post_epoch_funcs.append(
            get_video_save_func(rollout_function, eval_env, policy, variant)
        )

    algorithm.to(ptu.device)
    # Both envs expose a ``vae`` attribute unless this is a state-only run;
    # move those to the training device as well.
    if not variant.get("do_state_exp", False):
        for vae_env in (eval_env, expl_env):
            vae_env.vae.to(ptu.device)

    algorithm.train()