Example #1
def td3_experiment_online_vae_exploring(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.torch.her.online_vae_joint_algo import OnlineVaeHerJointAlgo
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.torch.td3.td3 import TD3
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    exploring_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    exploring_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs'],
    )
    exploring_exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=exploring_policy,
    )

    vae = env.vae
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    if variant.get('use_replay_buffer_goals', False):
        env.replay_buffer = replay_buffer
        env.use_replay_buffer_goals = True

    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)

    control_algorithm = TD3(env=env,
                            training_env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs'])
    exploring_algorithm = TD3(env=env,
                              training_env=env,
                              qf1=exploring_qf1,
                              qf2=exploring_qf2,
                              policy=exploring_policy,
                              exploration_policy=exploring_exploration_policy,
                              **variant['algo_kwargs'])

    assert 'vae_training_schedule' not in variant,\
        "Just put it in joint_algo_kwargs"
    algorithm = OnlineVaeHerJointAlgo(vae=vae,
                                      vae_trainer=t,
                                      env=env,
                                      training_env=env,
                                      policy=policy,
                                      exploration_policy=exploration_policy,
                                      replay_buffer=replay_buffer,
                                      algo1=control_algorithm,
                                      algo2=exploring_algorithm,
                                      algo1_prefix="Control_",
                                      algo2_prefix="VAE_Exploration_",
                                      observation_key=observation_key,
                                      desired_goal_key=desired_goal_key,
                                      **variant['joint_algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
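
For orientation, a minimal sketch of the variant dictionary this launcher reads. The key names are taken from the lookups in the function above; the values are illustrative assumptions, and the VAE datasets would normally be filled in by the surrounding launcher code.

variant = dict(
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),  # assumed buffer kwargs
    algo_kwargs=dict(),        # per-TD3 kwargs; the replay buffer is injected above
    joint_algo_kwargs=dict(),  # e.g. vae_training_schedule, per the assert above
    vae_trainer_kwargs=dict(),
    online_vae_beta=2.5,
    vae_train_data=None,       # image datasets supplied by the outer launcher
    vae_test_data=None,
    save_video=False,
)
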
def td3_experiment(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)

    from railrl.torch.td3.td3 import TD3 as TD3Trainer
    from railrl.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    # preprocess_rl_variant(variant)
    env = get_envs(variant)
    expl_env = env
    eval_env = env
    es = get_exploration_strategy(variant, env)

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env

    max_path_length = variant['max_path_length']

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')
    # achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)

    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])

    if variant.get("use_subgoal_policy", False):
        from railrl.policies.timed_policy import SubgoalPolicyWrapper

        subgoal_policy_kwargs = variant.get('subgoal_policy_kwargs', {})

        policy = SubgoalPolicyWrapper(wrapped_policy=policy,
                                      env=env,
                                      episode_length=max_path_length,
                                      **subgoal_policy_kwargs)
        target_policy = SubgoalPolicyWrapper(wrapped_policy=target_policy,
                                             env=env,
                                             episode_length=max_path_length,
                                             **subgoal_policy_kwargs)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        # use_masks=variant.get("use_masks", False),
        **variant['replay_buffer_kwargs'])

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_trainer_kwargs'])
    # if variant.get("use_masks", False):
    #     from railrl.torch.her.her import MaskedHERTrainer
    #     trainer = MaskedHERTrainer(trainer)
    # else:
    trainer = HERTrainer(trainer)
    if variant.get("do_state_exp", False):
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
            # full_mask=True,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
        )
    else:
        eval_path_collector = VAEWrappedEnvPathCollector(
            env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['exploration_goal_sampling_mode'],
        )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    vis_variant = variant.get('vis_kwargs', {})
    vis_list = vis_variant.get('vis_list', [])
    if variant.get("save_video", True):
        if variant.get("do_state_exp", False):
            rollout_function = rf.create_rollout_function(
                rf.multitask_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
                # use_masks=variant.get("use_masks", False),
                # full_mask=True,
                # vis_list=vis_list,
            )
            video_func = get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        else:
            video_func = VideoSaveFunction(
                env,
                variant,
            )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    algorithm.train()
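
A matching sketch of the variant for this refactored launcher, which separates the TD3 trainer, the path collectors, and TorchBatchRLAlgorithm. Key names mirror the lookups above; the numbers are placeholders, not tuned settings.

variant = dict(
    max_path_length=100,
    do_state_exp=True,  # True: GoalConditionedPathCollector; False: VAE collectors
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),  # assumed
    td3_trainer_kwargs=dict(discount=0.99),        # assumed
    algo_kwargs=dict(
        batch_size=128,
        num_epochs=100,
        num_eval_steps_per_epoch=1000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=1000,
    ),
    save_video=False,
)
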
Example #3
def experiment(variant):
    feat_points = 16
    history = 1
    latent_obs_dim = feat_points * 2 * history
    imsize = 64
    downsampled_size = 32

    env = SawyerXYZEnv()
    extra_fc_size = env.obs_dim
    env = ImageMujocoWithObsEnv(env,
                                imsize=imsize,
                                normalize=True,
                                grayscale=True,
                                keep_prev=history - 1,
                                init_camera=camera.sawyer_init_camera)
    """env = ImageMujocoEnv(env,
                        imsize=imsize,
                        keep_prev=history-1,
                        init_camera=camera.sawyer_init_camera)"""

    es = GaussianStrategy(action_space=env.action_space)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    ae = FeatPointMlp(input_size=imsize,
                      downsample_size=downsampled_size,
                      input_channels=1,
                      num_feat_points=feat_points)
    replay_buffer = AEEnvReplayBuffer(int(1e4),
                                      env,
                                      imsize=imsize,
                                      history_length=history,
                                      downsampled_size=downsampled_size)

    qf = FlattenMlp(input_size=latent_obs_dim + extra_fc_size + action_dim,
                    output_size=1,
                    hidden_sizes=[400, 300])
    policy = AETanhPolicy(
        input_size=latent_obs_dim + extra_fc_size,
        ae=ae,
        env=env,
        history_length=history,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    algorithm = FeatPointDDPG(ae,
                              history,
                              env=env,
                              qf=qf,
                              policy=policy,
                              exploration_policy=exploration_policy,
                              replay_buffer=replay_buffer,
                              extra_fc_size=extra_fc_size,
                              imsize=imsize,
                              downsampled_size=downsampled_size,
                              **variant['algo_params'])

    algorithm.to(ptu.device)
    algorithm.train()
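
The only variant key this feature-point DDPG script reads is algo_params. A hedged sketch, assuming the usual schedule kwargs of railrl's older DDPG-style algorithms rather than FeatPointDDPG's documented signature:

variant = dict(
    algo_params=dict(
        num_epochs=100,           # assumed schedule kwargs
        num_steps_per_epoch=1000,
        num_steps_per_eval=1000,
        max_path_length=100,
        batch_size=128,
        discount=0.99,
    ),
)
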
Example #4
        'e': np.array([-1, -1, 0, 0]),
        'z': np.array([1, 1, 0, 0]),
        'c': np.array([-1, 1, 0, 0]),
        'x': 'toggle',
        'r': 'reset',
    }
    # np.random.seed(100)
    env = SawyerDoorPullOpenActionLimitedEnv(fix_goal=True, min_y_pos=.3)
    policy = ZeroPolicy(env.action_space.low.size)
    es = OUStrategy(env.action_space, theta=1)
    es = EpsilonGreedy(
        action_space=env.action_space,
        prob_random_action=0.1,
    )
    policy = exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    env.reset()
    ACTION_FROM = 'hardcoded'
    # ACTION_FROM = 'pd'
    # ACTION_FROM = 'random'
    H = 10000
    # H = 300
    # H = 50
    goal = .25

    while True:
        lock_action = False
        obs = env.reset()
        last_reward_t = 0
Example #5
def experiment(variant):
    expl_env = variant['env_class'](**variant['env_kwargs'])
    eval_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    trainer = TD3(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
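
A hedged end-to-end sketch of launching this experiment. SawyerReachXYEnv from multiworld is only an example of a dict-observation goal env, and its import path is an assumption about the installed multiworld version; any env exposing 'observation' and 'desired_goal' spaces (plus the state_* keys used by the buffer) should work.

from multiworld.envs.mujoco.sawyer_xyz.sawyer_reach import SawyerReachXYEnv

variant = dict(
    env_class=SawyerReachXYEnv,
    env_kwargs=dict(),
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),  # assumed
    trainer_kwargs=dict(discount=0.99),            # assumed
    algo_kwargs=dict(
        batch_size=128,
        num_epochs=50,
        num_eval_steps_per_epoch=1000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=1000,
        max_path_length=50,
    ),
)
experiment(variant)
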
Example #6
def experiment(variant):
    rdim = variant["rdim"]
    vae_paths = {
        2:
        "/home/ashvin/data/s3doodad/ashvin/vae/pusher2d-conv-sweep2/run1/id0/params.pkl",
        4:
        "/home/ashvin/data/s3doodad/ashvin/vae/pusher2d-conv-sweep2/run1/id1/params.pkl",
        8:
        "/home/ashvin/data/s3doodad/ashvin/vae/pusher2d-conv-sweep2/run1/id2/params.pkl",
        16:
        "/home/ashvin/data/s3doodad/ashvin/vae/pusher2d-conv-sweep2/run1/id3/params.pkl"
    }
    vae_path = vae_paths[rdim]
    vae = torch.load(vae_path)
    print("loaded", vae_path)

    if variant['multitask']:
        env = CylinderXYPusher2DEnv(**variant["env_kwargs"])
        env = ImageMujocoEnv(env, 84, camera_name="topview", transpose=True)
        env = VAEWrappedEnv(env,
                            vae,
                            use_vae_obs=True,
                            use_vae_reward=True,
                            use_vae_goals=True)
        env = MultitaskToFlatEnv(env)
    # else:
    # env = Pusher2DEnv(**variant['env_kwargs'])
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(env,
                    training_env=env,
                    qf1=qf1,
                    qf2=qf2,
                    policy=policy,
                    exploration_policy=exploration_policy,
                    **variant['algo_kwargs'])
    print("use_gpu", variant["use_gpu"], bool(variant["use_gpu"]))
    if variant["use_gpu"]:
        gpu_id = variant["gpu_id"]
        ptu.set_gpu_mode(True)
        ptu.set_device(gpu_id)
        algorithm.to(ptu.device)
        env._wrapped_env.vae.to(ptu.device)
    algorithm.train()
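
Variant sketch for the VAE-wrapped pusher run above. rdim must match one of the hard-coded checkpoint paths; the remaining values are illustrative, and the algo_kwargs names are assumptions in the style of the old TD3 constructor used here.

variant = dict(
    rdim=4,
    multitask=True,   # the non-multitask branch is commented out above
    normalize=False,
    exploration_type='ou',
    use_gpu=True,
    gpu_id=0,
    env_kwargs=dict(),
    algo_kwargs=dict(
        num_epochs=100,           # assumed schedule kwargs
        num_steps_per_epoch=1000,
        num_steps_per_eval=1000,
        max_path_length=100,
        batch_size=128,
    ),
)
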
Example #7
def her_td3_experiment(variant):
    import gym

    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import \
        PolicyWrappedWithExplorationStrategy
    from railrl.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from railrl.launchers.launcher_util import setup_logger
    from railrl.samplers.data_collector import GoalConditionedPathCollector
    from railrl.torch.her.her import HERTrainer
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.torch.td3.td3 import TD3
    from railrl.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    import railrl.samplers.rollout_functions as rf
    from railrl.torch.grill.launcher import get_state_experiment_video_save_function

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + goal_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + goal_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + goal_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + goal_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = get_state_experiment_video_save_function(
            rollout_function,
            eval_env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
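
When launching this via env_id, the id must already be registered with gym and the underlying env must expose the 'state_observation'/'state_desired_goal' keys used above (gym's own Fetch envs expose only 'observation'/'desired_goal'). A hedged registration sketch, reusing the multiworld call shown in the HER baseline further below:

import gym
import multiworld

multiworld.register_all_envs()  # registers multiworld's goal-conditioned envs with gym
# variant = dict(env_id='<a registered multiworld env id>', qf_kwargs=..., ...)
# her_td3_experiment(variant)
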
def experiment(variant):
    if variant.get("pretrained_algorithm_path", False):
        resume(variant)
        return

    if 'env' in variant:
        env_params = ENV_PARAMS[variant['env']]
        variant.update(env_params)

        if 'env_id' in env_params:
            if env_params['env_id'] in [
                    'pen-v0', 'pen-sparse-v0', 'door-v0', 'relocate-v0',
                    'hammer-v0', 'pen-sparse-v0', 'door-sparse-v0',
                    'relocate-sparse-v0', 'hammer-sparse-v0'
            ]:
                import mj_envs
            expl_env = gym.make(env_params['env_id'])
            eval_env = gym.make(env_params['env_id'])
        else:
            expl_env = NormalizedBoxEnv(variant['env_class']())
            eval_env = NormalizedBoxEnv(variant['env_class']())

        if variant.get('sparse_reward', False):
            expl_env = RewardWrapperEnv(expl_env, compute_hand_sparse_reward)
            eval_env = RewardWrapperEnv(eval_env, compute_hand_sparse_reward)

        if variant.get('add_env_demos', False):
            variant["path_loader_kwargs"]["demo_paths"].append(
                variant["env_demo_path"])

        if variant.get('add_env_offpolicy_data', False):
            variant["path_loader_kwargs"]["demo_paths"].append(
                variant["env_offpolicy_data_path"])
    else:
        expl_env = encoder_wrapped_env(variant)
        eval_env = encoder_wrapped_env(variant)

    path_loader_kwargs = variant.get("path_loader_kwargs", {})
    stack_obs = path_loader_kwargs.get("stack_obs", 1)
    if stack_obs > 1:
        expl_env = StackObservationEnv(expl_env, stack_obs=stack_obs)
        eval_env = StackObservationEnv(eval_env, stack_obs=stack_obs)

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size
    if hasattr(expl_env, 'info_sizes'):
        env_info_sizes = expl_env.info_sizes
    else:
        env_info_sizes = dict()

    M = variant['layer_size']
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy_class = variant.get("policy_class", TanhGaussianPolicy)
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'],
    )

    buffer_policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **variant['policy_kwargs'],
    )

    eval_policy = MakeDeterministic(policy)
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )

    expl_policy = policy
    exploration_kwargs = variant.get('exploration_kwargs', {})
    if exploration_kwargs:
        if exploration_kwargs.get("deterministic_exploration", False):
            expl_policy = MakeDeterministic(policy)

        exploration_strategy = exploration_kwargs.get("strategy", None)
        if exploration_strategy is None:
            pass
        elif exploration_strategy == 'ou':
            es = OUStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        elif exploration_strategy == 'gauss_eps':
            es = GaussianAndEpislonStrategy(
                action_space=expl_env.action_space,
                max_sigma=exploration_kwargs['noise'],
                min_sigma=exploration_kwargs['noise'],  # constant sigma
                epsilon=0,
            )
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=expl_policy,
            )
        else:
            raise Exception("Invalid exploration strategy: " +
                            exploration_strategy)

    if variant.get('replay_buffer_class',
                   EnvReplayBuffer) == AWREnvReplayBuffer:
        main_replay_buffer_kwargs = variant['replay_buffer_kwargs']
        main_replay_buffer_kwargs['env'] = expl_env
        main_replay_buffer_kwargs['qf1'] = qf1
        main_replay_buffer_kwargs['qf2'] = qf2
        main_replay_buffer_kwargs['policy'] = policy
    else:
        main_replay_buffer_kwargs = dict(
            max_replay_buffer_size=variant['replay_buffer_size'],
            env=expl_env,
        )
    replay_buffer_kwargs = dict(
        max_replay_buffer_size=variant['replay_buffer_size'],
        env=expl_env,
    )

    replay_buffer = variant.get('replay_buffer_class',
                                EnvReplayBuffer)(**main_replay_buffer_kwargs)
    trainer = AWRSACTrainer(env=eval_env,
                            policy=policy,
                            qf1=qf1,
                            qf2=qf2,
                            target_qf1=target_qf1,
                            target_qf2=target_qf2,
                            buffer_policy=buffer_policy,
                            **variant['trainer_kwargs'])
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    else:
        expl_path_collector = MdpPathCollector(
            expl_env,
            expl_policy,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant[
                'num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant[
                'min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)

    demo_train_buffer = EnvReplayBuffer(**replay_buffer_kwargs)
    demo_test_buffer = EnvReplayBuffer(**replay_buffer_kwargs)

    if variant.get('save_paths', False):
        algorithm.post_train_funcs.append(save_paths)
    if variant.get('load_demos', False):
        path_loader_class = variant.get('path_loader_class', MDPPathLoader)
        path_loader = path_loader_class(trainer,
                                        replay_buffer=replay_buffer,
                                        demo_train_buffer=demo_train_buffer,
                                        demo_test_buffer=demo_test_buffer,
                                        **path_loader_kwargs)
        path_loader.load_demos()
    if variant.get('pretrain_policy', False):
        trainer.pretrain_policy_with_bc()
    if variant.get('pretrain_rl', False):
        trainer.pretrain_q_with_bc_data()
    if variant.get('save_pretrained_algorithm', False):
        p_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.p')
        pt_path = osp.join(logger.get_snapshot_dir(), 'pretrain_algorithm.pt')
        data = algorithm._get_snapshot()
        data['algorithm'] = algorithm
        torch.save(data, open(pt_path, "wb"))
        torch.save(data, open(p_path, "wb"))
    if variant.get('train_rl', True):
        algorithm.train()
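
A variant sketch for the AWR-style experiment above, restricted to an online run with no demos or pretraining. Every key name appears in a lookup inside the function; the values are placeholders.

variant = dict(
    # 'env' must name an entry in the module-level ENV_PARAMS table; without
    # it, the function falls back to encoder_wrapped_env(variant).
    collection_mode='batch',
    layer_size=256,
    replay_buffer_size=int(1e6),
    max_path_length=1000,
    batch_size=256,
    num_epochs=100,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
    policy_kwargs=dict(hidden_sizes=[256, 256]),
    trainer_kwargs=dict(discount=0.99),  # assumed AWRSACTrainer kwargs
    load_demos=False,
    pretrain_policy=False,
    pretrain_rl=False,
)
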
def her_td3_experiment(variant):
    env = variant['env_class'](**variant['env_kwargs'])
    her_kwargs = variant['algo_kwargs']['her_kwargs']
    observation_key = her_kwargs['observation_key']
    desired_goal_key = her_kwargs['desired_goal_key']
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = env.observation_space.spaces['observation'].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(
            action_space=env.action_space,
            **variant['es_kwargs']
        )
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    num_ensemble_qs = variant.get("num_ensemble_qs", 0)
    ensemble_qs = [FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    ) for _ in range(num_ensemble_qs)]
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    render = variant.get("render", False)
    algorithm = HerExplorationTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        render=render,
        render_during_eval=render,
        ensemble_qs=ensemble_qs,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
def encoder_wrapped_td3bc_experiment(variant):
    representation_size = 128
    output_classes = 20

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        # decoder_output_activation=decoder_activation,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )
    # model = torch.nn.DataParallel(model)

    model_path = variant.get("model_path")
    # model = load_local_or_remote_file(model_path)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    model.to(ptu.device)
    model.eval()

    traj = np.load(variant.get("desired_trajectory"), allow_pickle=True)[0]

    goal_image = traj["observations"][-1]["image_observation"]
    goal_image = goal_image.reshape(1, 3, 500, 300).transpose(
        [0, 1, 3, 2]) / 255.0
    # goal_image = goal_image.reshape(1, 300, 500, 3).transpose([0, 3, 1, 2]) / 255.0 # BECAUSE RLBENCH DEMOS ARENT IMAGE_ENV WRAPPED
    # goal_image = goal_image[:, :, :240, 60:500]
    goal_image = goal_image[:, :, 60:, 60:500]
    goal_image_pt = ptu.from_numpy(goal_image)
    save_image(goal_image_pt.data.cpu(), 'demos/goal.png', nrow=1)
    goal_latent = model.encode(goal_image_pt).detach().cpu().numpy().flatten()

    initial_image = traj["observations"][0]["image_observation"]
    initial_image = initial_image.reshape(1, 3, 500, 300).transpose(
        [0, 1, 3, 2]) / 255.0
    # initial_image = initial_image.reshape(1, 300, 500, 3).transpose([0, 3, 1, 2]) / 255.0
    # initial_image = initial_image[:, :, :240, 60:500]
    initial_image = initial_image[:, :, 60:, 60:500]
    initial_image_pt = ptu.from_numpy(initial_image)
    save_image(initial_image_pt.data.cpu(), 'demos/initial.png', nrow=1)
    initial_latent = model.encode(
        initial_image_pt).detach().cpu().numpy().flatten()

    # Move these to td3_bc and bc_v3 (or at least type for reward_params)
    reward_params = dict(
        goal_latent=goal_latent,
        initial_latent=initial_latent,
        type=variant["reward_params_type"],
    )

    config_params = variant.get("config_params")

    env = variant['env_class'](**variant['env_kwargs'])
    env = ImageEnv(
        env,
        recompute_reward=False,
        transpose=True,
        image_length=450000,
        reward_type="image_distance",
        # init_camera=sawyer_pusher_camera_upright_v2,
    )
    env = EncoderWrappedEnv(
        env, model, reward_params, config_params,
        **variant.get("encoder_wrapped_env_kwargs", dict()))

    expl_env = env  # variant['env_class'](**variant['env_kwargs'])
    eval_env = env  # variant['env_class'](**variant['env_kwargs'])

    observation_key = variant.get("observation_key", 'state_observation')
    # one of 'state_observation', 'latent_observation', 'concat_observation'
    desired_goal_key = 'latent_desired_goal'

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        **variant["exploration_kwargs"],
    )
    obs_dim = expl_env.observation_space.spaces[observation_key].low.size
    goal_dim = expl_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        # output_activation=TorchMaxClamp(0.0),
        **variant['qf_kwargs'])
    qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        # output_activation=TorchMaxClamp(0.0),
        **variant['qf_kwargs'])
    target_qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        # output_activation=TorchMaxClamp(0.0),
        **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        # output_activation=TorchMaxClamp(0.0),
        **variant['qf_kwargs'])

    # Support for CNNPolicy based policy/target policy
    # Defaults to TanhMlpPolicy unless cnn_params is supplied in variant
    if 'cnn_params' in variant.keys():
        imsize = 48
        policy = CNNPolicy(
            input_width=imsize,
            input_height=imsize,
            output_size=action_dim,
            input_channels=3,
            **variant['cnn_params'],
            output_activation=torch.tanh,
        )
        target_policy = CNNPolicy(
            input_width=imsize,
            input_height=imsize,
            output_size=action_dim,
            input_channels=3,
            **variant['cnn_params'],
            output_activation=torch.tanh,
        )
    else:
        policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                               output_size=action_dim,
                               **variant['policy_kwargs'])
        target_policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                                      output_size=action_dim,
                                      **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    demo_train_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    demo_test_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    td3bc_trainer = TD3BCTrainer(env=env,
                                 policy=policy,
                                 qf1=qf1,
                                 qf2=qf2,
                                 replay_buffer=replay_buffer,
                                 demo_train_buffer=demo_train_buffer,
                                 demo_test_buffer=demo_test_buffer,
                                 target_qf1=target_qf1,
                                 target_qf2=target_qf2,
                                 target_policy=target_policy,
                                 **variant['trainer_kwargs'])
    trainer = HERTrainer(td3bc_trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)

    td3bc_trainer.load_demos()

    td3bc_trainer.pretrain_policy_with_bc()
    td3bc_trainer.pretrain_q_with_bc_data()

    algorithm.train()
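
Besides the usual network and buffer kwargs, this TD3+BC launcher needs a trained encoder checkpoint and a demonstration trajectory. The paths below are hypothetical placeholders, reward_params_type is an assumed value, and env_class must be the robot env the demos were recorded in.

variant = dict(
    model_kwargs=dict(),
    model_path='/path/to/encoder_checkpoint.pt',         # placeholder
    desired_trajectory='/path/to/demo_trajectory.npy',   # placeholder
    reward_params_type='latent_distance',  # assumption; see EncoderWrappedEnv
    config_params=dict(),
    env_class=None,     # fill in the env class the demos were recorded in
    env_kwargs=dict(),
    exploration_kwargs=dict(max_sigma=.2, min_sigma=.2, epsilon=.3),
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),  # assumed
    trainer_kwargs=dict(discount=0.99),            # assumed TD3BCTrainer kwargs
    algo_kwargs=dict(
        batch_size=128,
        num_epochs=100,
        max_path_length=100,
        num_eval_steps_per_epoch=1000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=1000,
    ),
    save_video=False,
)
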
Example #11
def experiment(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.torch.her.her import HERTrainer
    from railrl.torch.td3.td3 import TD3 as TD3Trainer
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from railrl.samplers.data_collector import GoalConditionedPathCollector
    from railrl.torch.grill.launcher import (
        grill_preprocess_variant,
        get_envs,
        get_exploration_strategy,
        full_experiment_variant_preprocess,
        train_vae_and_update_variant,
        get_video_save_func,
    )

    full_experiment_variant_preprocess(variant)
    if not variant['grill_variant'].get('do_state_exp', False):
        train_vae_and_update_variant(variant)
    variant = variant['grill_variant']

    grill_preprocess_variant(variant)
    eval_env = get_envs(variant)
    expl_env = get_envs(variant)
    es = get_exploration_strategy(variant, expl_env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (expl_env.observation_space.spaces[observation_key].low.size +
               expl_env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):  # Does not work
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            eval_env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        eval_env.vae.to(ptu.device)
        expl_env.vae.to(ptu.device)

    algorithm.train()
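
This GRILL-style launcher expects a two-level variant: an outer dict holding a grill_variant (the RL settings read above) and, by assumption, a train_vae_variant consumed by train_vae_and_update_variant. A sketch with placeholder values:

variant = dict(
    grill_variant=dict(
        do_state_exp=False,
        qf_kwargs=dict(hidden_sizes=[400, 300]),
        policy_kwargs=dict(hidden_sizes=[400, 300]),
        replay_buffer_kwargs=dict(max_size=int(1e6)),  # assumed
        td3_kwargs=dict(discount=0.99),                # assumed
        algo_kwargs=dict(
            batch_size=128,
            num_epochs=100,
            max_path_length=100,
            num_eval_steps_per_epoch=1000,
            num_expl_steps_per_train_loop=1000,
            num_trains_per_train_loop=1000,
            min_num_steps_before_training=1000,
        ),
        save_video=False,
    ),
    train_vae_variant=dict(),  # VAE pretraining settings (key name is an assumption)
)
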
Example #12
def HER_baseline_td3_experiment(variant):
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.torch.her.her_td3 import HerTd3
    from railrl.torch.networks import MergedCNN, CNNPolicy
    import torch
    from multiworld.core.image_env import ImageEnv
    from railrl.misc.asset_loader import load_local_or_remote_file

    init_camera = variant.get("init_camera", None)
    presample_goals = variant.get('presample_goals', False)
    presampled_goals_path = get_presampled_goals_path(
        variant.get('presampled_goals_path', None))

    if 'env_id' in variant:
        import gym
        import multiworld
        multiworld.register_all_envs()
        env = gym.make(variant['env_id'])
    else:
        env = variant["env_class"](**variant['env_kwargs'])
    image_env = ImageEnv(
        env,
        variant.get('imsize'),
        reward_type='image_sparse',
        init_camera=init_camera,
        transpose=True,
        normalize=True,
    )
    if presample_goals:
        if presampled_goals_path is None:
            image_env.non_presampled_goal_img_is_garbage = True
            presampled_goals = variant['generate_goal_dataset_fctn'](
                env=image_env, **variant['goal_generation_kwargs'])
        else:
            presampled_goals = load_local_or_remote_file(
                presampled_goals_path).item()
        del image_env
        env = ImageEnv(
            env,
            variant.get('imsize'),
            reward_type='image_distance',
            init_camera=init_camera,
            transpose=True,
            normalize=True,
            presampled_goals=presampled_goals,
        )
    else:
        env = image_env

    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'image_observation')
    desired_goal_key = variant.get('desired_goal_key', 'image_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    imsize = variant['imsize']
    action_dim = env.action_space.low.size
    qf1 = MergedCNN(input_width=imsize,
                    input_height=imsize,
                    output_size=1,
                    input_channels=3 * 2,
                    added_fc_input_size=action_dim,
                    **variant['cnn_params'])
    qf2 = MergedCNN(input_width=imsize,
                    input_height=imsize,
                    output_size=1,
                    input_channels=3 * 2,
                    added_fc_input_size=action_dim,
                    **variant['cnn_params'])

    policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    target_qf1 = MergedCNN(input_width=imsize,
                           input_height=imsize,
                           output_size=1,
                           input_channels=3 * 2,
                           added_fc_input_size=action_dim,
                           **variant['cnn_params'])
    target_qf2 = MergedCNN(input_width=imsize,
                           input_height=imsize,
                           output_size=1,
                           input_channels=3 * 2,
                           added_fc_input_size=action_dim,
                           **variant['cnn_params'])

    target_policy = CNNPolicy(
        input_width=imsize,
        input_height=imsize,
        added_fc_input_size=0,
        output_size=action_dim,
        input_channels=3 * 2,
        output_activation=torch.tanh,
        **variant['cnn_params'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       target_qf1=target_qf1,
                       target_qf2=target_qf2,
                       target_policy=target_policy,
                       exploration_policy=exploration_policy,
                       **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    algorithm.train()
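
Variant sketch for this image-based HER-TD3 baseline. The algo_kwargs nesting into base_kwargs / her_kwargs mirrors how the function mutates it above; the cnn_params key names are assumptions about railrl's CNN constructor, and the env must be supplied via env_id or env_class.

variant = dict(
    imsize=84,
    render=False,
    init_camera=None,
    # provide either env_id (a registered multiworld env) or env_class + env_kwargs
    cnn_params=dict(              # assumed constructor kwargs
        kernel_sizes=[5, 5],
        n_channels=[16, 32],
        strides=[3, 3],
        paddings=[0, 0],
        hidden_sizes=[400, 300],
    ),
    replay_buffer_kwargs=dict(max_size=int(1e5)),  # assumed
    algo_kwargs=dict(
        base_kwargs=dict(num_epochs=100, batch_size=128, max_path_length=100),
        her_kwargs=dict(),
    ),
)
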
Example #13
def tdm_td3_experiment_online_vae(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs')
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized

    norm_order = env.norm_order
    # variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
    #     'norm_order'] = norm_order

    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(vae=vae,
                               vae_trainer=t,
                               **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
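Unlike the HER-TD3 examples, tdm_td3_experiment_online_vae nests its algorithm settings two levels deep (algo_kwargs -> tdm_td3_kwargs -> td3_kwargs / tdm_kwargs, plus online_vae_kwargs). A sketch of that skeleton, with empty dicts where the function fills values in itself; all concrete values are assumptions, only the nesting comes from the code above:

# Hypothetical variant skeleton for tdm_td3_experiment_online_vae.
variant = dict(
    render=False,
    save_video=True,
    do_state_exp=False,
    online_vae_beta=2.5,                 # assumed value
    vae_train_data=None,                 # image arrays in the real setup
    vae_test_data=None,
    vae_trainer_kwargs=dict(),
    qf_kwargs=dict(),                    # forwarded to TdmQf
    policy_kwargs=dict(),                # forwarded to TdmPolicy
    replay_buffer_kwargs=dict(),
    algo_kwargs=dict(
        online_vae_kwargs=dict(),        # vae / vae_trainer are injected by the function
        tdm_td3_kwargs=dict(
            td3_kwargs=dict(),           # training_env is injected
            tdm_kwargs=dict(),           # vectorized and the goal keys are injected
        ),
    ),
)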
Beispiel #14
0
def tdm_td3_experiment(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.core import logger
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['norm_order'] = norm_order

    qf1 = TdmQf(env=env,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algorithm = TdmTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm.max_tau,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    algorithm.train()
Beispiel #15
0
def experiment(variant):
    vectorized = variant['td3_tdm_kwargs']['tdm_kwargs']['vectorized']
    env = NormalizedBoxEnv(variant['env_class'](**variant['env_kwargs']))
    max_tau = variant['td3_tdm_kwargs']['tdm_kwargs']['max_tau']
    qf1 = TdmQf(env, vectorized=vectorized, **variant['qf_kwargs'])
    qf2 = TdmQf(env, vectorized=vectorized, **variant['qf_kwargs'])
    tdm_normalizer = TdmNormalizer(env,
                                   vectorized,
                                   max_tau=max_tau,
                                   **variant['tdm_normalizer_kwargs'])
    implicit_model = TdmToImplicitModel(
        env,
        qf1,
        tau=0,
    )
    vf = TdmVf(env=env,
               vectorized=vectorized,
               tdm_normalizer=tdm_normalizer,
               **variant['vf_kwargs'])
    policy = TdmPolicy(env=env,
                       tdm_normalizer=tdm_normalizer,
                       **variant['policy_kwargs'])
    replay_buffer = HerReplayBuffer(env=env,
                                    **variant['her_replay_buffer_kwargs'])
    goal_slice = env.ob_to_goal_slice
    lbfgs_mpc_controller = TdmLBfgsBCMC(implicit_model,
                                        env,
                                        goal_slice=goal_slice,
                                        multitask_goal_slice=goal_slice,
                                        tdm_policy=policy,
                                        **variant['mpc_controller_kwargs'])
    state_only_mpc_controller = TdmLBfgsBStateOnlyCMC(
        vf,
        policy,
        env,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['state_only_mpc_controller_kwargs'])
    es = GaussianStrategy(action_space=env.action_space,
                          **variant['es_kwargs'])
    if variant['explore_with'] == 'TdmLBfgsBCMC':
        raw_exploration_policy = lbfgs_mpc_controller
    elif variant['explore_with'] == 'TdmLBfgsBStateOnlyCMC':
        raw_exploration_policy = state_only_mpc_controller
    else:
        raw_exploration_policy = policy
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=raw_exploration_policy,
    )
    if variant['eval_with'] == 'TdmLBfgsBCMC':
        eval_policy = lbfgs_mpc_controller
    elif variant['eval_with'] == 'TdmLBfgsBStateOnlyCMC':
        eval_policy = state_only_mpc_controller
    else:
        eval_policy = policy
    # variant['td3_tdm_kwargs']['base_kwargs']['eval_policy'] = eval_policy
    algorithm = TdmTd3(env=env,
                       policy=policy,
                       qf1=qf1,
                       qf2=qf2,
                       vf=vf,
                       exploration_policy=exploration_policy,
                       eval_policy=eval_policy,
                       replay_buffer=replay_buffer,
                       **variant['td3_tdm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
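The experiment above selects its exploration and evaluation policies from string switches in the variant. A small sketch of just those switches; the two controller names are exactly the strings the if/elif chain checks, and anything else falls through to the raw TdmPolicy:

# Hypothetical fragment of the variant for the MPC-controller experiment above.
variant_fragment = dict(
    explore_with='TdmLBfgsBCMC',        # or 'TdmLBfgsBStateOnlyCMC', or anything else
    eval_with='TdmLBfgsBStateOnlyCMC',  # same three-way choice
)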
Beispiel #16
0
def her_td3_experiment(variant):
    env = variant['env_class'](**variant['env_kwargs'])
    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    replay_buffer = variant['replay_buffer_class'](
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    variant['count_based_sampler_kwargs']['replay_buffer'] = replay_buffer
    env = CountBasedGoalSamplingEnv(wrapped_env=env, **variant['count_based_sampler_kwargs'])

    obs_dim = env.observation_space.spaces['observation'].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(
            action_space=env.action_space,
            **variant['es_kwargs']
        )
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        training_env=env,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
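her_td3_experiment dispatches on variant['exploration_type'] to build the exploration strategy. A sketch of the three accepted settings; the es_kwargs contents shown here are assumptions taken from how these strategies are constructed elsewhere in this listing:

# Hypothetical exploration settings for the HER-TD3 example above.
variant_ou = dict(exploration_type='ou',
                  es_kwargs=dict(max_sigma=0.1))
variant_gaussian = dict(exploration_type='gaussian',
                        es_kwargs=dict(max_sigma=0.1, min_sigma=0.1))
variant_epsilon = dict(exploration_type='epsilon',
                       es_kwargs=dict(prob_random_action=0.1))
# Any other string raises Exception("Invalid type: " + exploration_type)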
Beispiel #17
0
def her_td3_experiment(variant):
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from railrl.exploration_strategies.gaussian_strategy import GaussianStrategy
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.torch.grill.launcher import get_video_save_func
    from railrl.torch.her.her_td3 import HerTd3
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    imsize = 84
    env = GymToMultiEnv(env.env)  # unwrap TimeLimit
    env = ImageEnv(env, non_presampled_goal_img_is_garbage=True)

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    obs_dim = env.observation_space.spaces[observation_key].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)

    use_images_for_q = variant["use_images_for_q"]
    use_images_for_pi = variant["use_images_for_pi"]

    qs = []
    for i in range(2):
        if use_images_for_q:
            image_q = MergedCNN(input_width=imsize,
                                input_height=imsize,
                                output_size=1,
                                input_channels=3,
                                added_fc_input_size=action_dim,
                                **variant['cnn_params'])
            q = ImageStateQ(image_q, None)
        else:
            state_q = FlattenMlp(input_size=action_dim + goal_dim,
                                 output_size=1,
                                 **variant['qf_kwargs'])
            q = ImageStateQ(None, state_q)
        qs.append(q)
    qf1, qf2 = qs

    if use_images_for_pi:
        image_policy = CNNPolicy(
            input_width=imsize,
            input_height=imsize,
            output_size=action_dim,
            input_channels=3,
            **variant['cnn_params'],
            output_activation=torch.tanh,
        )
        policy = ImageStatePolicy(image_policy, None)
    else:
        state_policy = TanhMlpPolicy(input_size=goal_dim,
                                     output_size=action_dim,
                                     **variant['policy_kwargs'])
        policy = ImageStatePolicy(None, state_policy)

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       replay_buffer=replay_buffer,
                       **variant['algo_kwargs'])
    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
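This image-based variant of her_td3_experiment switches between CNN and MLP critics/policies with two boolean flags. A sketch of the relevant variant fragment; the cnn_params contents are an assumption, only the flag names and the cnn_params/qf_kwargs/policy_kwargs keys come from the code:

# Hypothetical fragment for the image/state switch in the example above.
variant_fragment = dict(
    use_images_for_q=True,    # MergedCNN critic on 84x84x3 images plus the action
    use_images_for_pi=False,  # TanhMlpPolicy on the goal vector instead of images
    cnn_params=dict(),        # conv architecture; contents are an assumption
    qf_kwargs=dict(),         # used only when use_images_for_q is False
    policy_kwargs=dict(),     # used only when use_images_for_pi is False
)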
Beispiel #18
0
def grill_her_td3_experiment(variant):
    env = variant["env_class"](**variant['env_kwargs'])

    render = variant["render"]

    rdim = variant["rdim"]
    vae_path = variant["vae_paths"][str(rdim)]
    reward_params = variant.get("reward_params", dict())

    init_camera = variant.get("init_camera", None)
    if init_camera is None:
        camera_name = "topview"
    else:
        camera_name = None

    env = ImageEnv(
        env,
        84,
        init_camera=init_camera,
        camera_name=camera_name,
        transpose=True,
        normalize=True,
    )

    env = VAEWrappedEnv(
        env,
        vae_path,
        decode_goals=render,
        render_goals=render,
        render_rollouts=render,
        reward_params=reward_params,
        **variant.get('vae_wrapped_env_kwargs', {})
    )

    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    exploration_noise = variant.get('exploration_noise', 0.1)
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=exploration_noise,
            min_sigma=exploration_noise,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=exploration_noise,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    training_mode = variant.get("training_mode", "train")
    testing_mode = variant.get("testing_mode", "test")

    testing_env = pickle.loads(pickle.dumps(env))
    testing_env.mode(testing_mode)

    training_env = pickle.loads(pickle.dumps(env))
    training_env.mode(training_mode)

    relabeling_env = pickle.loads(pickle.dumps(env))
    relabeling_env.mode(training_mode)
    relabeling_env.disable_render()

    video_vae_env = pickle.loads(pickle.dumps(env))
    video_vae_env.mode("video_vae")
    video_goal_env = pickle.loads(pickle.dumps(env))
    video_goal_env.mode("video_env")


    replay_buffer = ObsDictRelabelingBuffer(
        env=relabeling_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_kwargs']
    )
    variant["algo_kwargs"]["replay_buffer"] = replay_buffer
    algorithm = HerTd3(
        testing_env,
        training_env=training_env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        render=render,
        render_during_eval=render,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.to(ptu.device)
        for e in [testing_env, training_env, video_vae_env, video_goal_env]:
            e.vae.to(ptu.device)

    algorithm.train()

    if variant.get("save_video", True):
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        filename = osp.join(logdir, 'video_final_env.mp4')
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        dump_video(video_goal_env, policy, filename, rollout_function)
        filename = osp.join(logdir, 'video_final_vae.mp4')
        dump_video(video_vae_env, policy, filename, rollout_function)
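grill_her_td3_experiment creates several independent copies of the wrapped env (testing, training, relabeling, and two video envs) with pickle.loads(pickle.dumps(env)), i.e. a pickle round-trip used as a deep copy so each copy can be put in a different mode without affecting the others. A self-contained illustration of that idiom on a toy object:

import pickle

class ToyEnv:
    def __init__(self):
        self.mode = "train"

env = ToyEnv()
env_copy = pickle.loads(pickle.dumps(env))  # independent deep copy via a pickle round-trip
env_copy.mode = "video_vae"
print(env.mode, env_copy.mode)  # -> train video_vae : the original is untouched

copy.deepcopy would typically serve the same purpose; the round-trip additionally exercises the env's picklability.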
Beispiel #19
0
def her_td3_experiment(variant):
    env = variant['env_class'](**variant['env_kwargs'])
    observation_key = variant.get('observation_key', 'observation')
    desired_goal_key = variant.get('desired_goal_key', 'desired_goal')
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = env.observation_space.spaces['observation'].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            **variant['es_kwargs']
        )
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        **variant['algo_kwargs']
    )
    if ptu.gpu_enabled():
        qf1.to(ptu.device)
        qf2.to(ptu.device)
        policy.to(ptu.device)
        algorithm.to(ptu.device)
    algorithm.train()
Beispiel #20
0
def state_td3bc_experiment(variant):
    if variant.get('env_id', None):
        import gym
        import multiworld
        multiworld.register_all_envs()
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es_strat = variant.get('es', 'ou')
    if es_strat == 'ou':
        es = OUStrategy(
            action_space=expl_env.action_space,
            max_sigma=variant['exploration_noise'],
            min_sigma=variant['exploration_noise'],
        )
    elif es_strat == 'gauss_eps':
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            max_sigma=.2,
            min_sigma=.2,  # constant sigma
            epsilon=.3,
        )
    else:
        raise ValueError("invalid exploration strategy provided")
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    demo_train_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size']
    )
    demo_test_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size'],
    )
    if variant.get('td3_bc', True):
        td3_trainer = TD3BCTrainer(
            env=expl_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            replay_buffer=replay_buffer,
            demo_train_buffer=demo_train_buffer,
            demo_test_buffer=demo_test_buffer,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_bc_trainer_kwargs']
        )
    else:
        td3_trainer = TD3(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    trainer = HERTrainer(td3_trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        if variant.get("presampled_goals", None):
            variant['image_env_kwargs']['presampled_goals'] = load_local_or_remote_file(variant['presampled_goals']).item()
        image_eval_env = ImageEnv(eval_env, **variant["image_env_kwargs"])
        image_eval_path_collector = GoalConditionedPathCollector(
            image_eval_env,
            policy,
            observation_key='state_observation',
            desired_goal_key='state_desired_goal',
        )
        image_expl_env = ImageEnv(expl_env, **variant["image_env_kwargs"])
        image_expl_path_collector = GoalConditionedPathCollector(
            image_expl_env,
            expl_policy,
            observation_key='state_observation',
            desired_goal_key='state_desired_goal',
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if variant.get('load_demos', False):
        td3_trainer.load_demos()
    if variant.get('pretrain_policy', False):
        td3_trainer.pretrain_policy_with_bc()
    if variant.get('pretrain_rl', False):
        td3_trainer.pretrain_q_with_bc_data()
    algorithm.train()
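state_td3bc_experiment toggles between a TD3+BC trainer and plain TD3, and can optionally load demonstrations and pretrain before the RL phase. A sketch of just those switches as read by the variant.get calls above (values are illustrative):

# Hypothetical fragment controlling the BC/pretraining path in the example above.
variant_fragment = dict(
    td3_bc=True,           # False -> plain TD3 trainer instead of TD3BCTrainer
    load_demos=True,       # calls td3_trainer.load_demos()
    pretrain_policy=True,  # calls td3_trainer.pretrain_policy_with_bc()
    pretrain_rl=False,     # calls td3_trainer.pretrain_q_with_bc_data()
)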
Beispiel #21
0
def td3_experiment(variant):
    import gym
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from railrl.exploration_strategies.gaussian_strategy import GaussianStrategy
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.torch.grill.launcher import get_state_experiment_video_save_function
    from railrl.torch.her.her_td3 import HerTd3
    from railrl.torch.td3.td3 import TD3
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictReplayBuffer)
    from railrl.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from railrl.samplers.data_collector.path_collector import ObsDictPathCollector

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    # desired_goal_key = variant['desired_goal_key']
    # variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    # variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    # achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    replay_buffer = ObsDictReplayBuffer(
        env=eval_env,
        observation_key=observation_key,
        # desired_goal_key=desired_goal_key,
        # achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=eval_env.action_space,
                        **variant['es_kwargs'])
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=eval_env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=eval_env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    qf1 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])
    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    eval_path_collector = ObsDictPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        # render=True,
        # desired_goal_key=desired_goal_key,
    )
    expl_path_collector = ObsDictPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        # render=True,
        # desired_goal_key=desired_goal_key,
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    # if variant.get("save_video", False):
    #     rollout_function = rf.create_rollout_function(
    #         rf.multitask_rollout,
    #         max_path_length=algorithm.max_path_length,
    #         observation_key=observation_key,
    #         desired_goal_key=algorithm.desired_goal_key,
    #     )
    #     video_func = get_state_experiment_video_save_function(
    #         rollout_function,
    #         env,
    #         policy,
    #         variant,
    #     )
    #     algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
Beispiel #22
0
def generate_vae_dataset(
        N=10000, test_p=0.9, use_cached=True, imsize=84, show=False,
        dataset_path=None, policy_path=None, action_space_sampling=False, env_class=SawyerPushAndPullDoorEnv, env_kwargs=None,
        action_plus_random_sampling=False, init_camera=sawyer_door_env_camera, ratio_action_sample_to_random=1 / 2, env_id=None,
):
    if policy_path is not None:
        filename = "/tmp/sawyer_door_push_and_pull_open_oracle+random_policy_data_closer_zoom_action_limited" + str(N) + ".npy"
    elif action_space_sampling:
        filename = "/tmp/sawyer_door_push_and_pull_open_zoomed_in_action_space_sampling" + str(N) + ".npy"
    else:
        filename = "/tmp/sawyer_door_push_and_pull_open" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    elif action_plus_random_sampling:
        if env_id is not None:
            import gym
            env = gym.make(env_id)
        else:
            env = env_class(**env_kwargs)
            env = ImageEnv(
                env, imsize,
                transpose=True,
                init_camera=init_camera,
                normalize=True,
            )
        action_sampled_data = int(N*ratio_action_sample_to_random)
        dataset = np.zeros((N, imsize * imsize * 3), dtype=np.uint8)
        print('Action Space Sampling')
        for i in range(action_sampled_data):
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        print('Random Sampling')
        for i in range(action_sampled_data, N):
            if i % 20 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(
                    action
                )
            goal = env.sample_goal()
            env.set_to_goal_angle(goal['state_desired_goal'])
            img = env._get_flat_img()
            dataset[i, :] = unormalize_image(img)
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        env._wrapped_env.min_y_pos = .5
        info['env'] = env
    else:
        raise NotImplementedError()
    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
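Note how the split at the end of generate_vae_dataset works: n = int(N * test_p) and the first n rows become train_dataset, so test_p is effectively the training fraction despite its name. A tiny numeric sketch of that arithmetic:

# Split arithmetic used by generate_vae_dataset above.
N, test_p = 10000, 0.9
n = int(N * test_p)
train_rows = n       # 9000 images -> train_dataset = dataset[:n, :]
test_rows = N - n    # 1000 images -> test_dataset  = dataset[n:, :]
print(train_rows, test_rows)  # 9000 1000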
Beispiel #23
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
    action_space_sampling=False,
    init_camera=None,
    env_class=None,
    env_kwargs=None,
):
    filename = "/tmp/sawyer_xyz_pos_control_new_zoom_cam" + str(N) + '.npy'
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    else:
        now = time.time()
        if env_kwargs is None:
            env_kwargs = dict()
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        dataset = np.zeros((N, imsize * imsize * 3), dtype=np.uint8)
        if action_space_sampling:
            action_space = Box(np.array([-.1, .5, 0]), np.array([.1, .7, .5]))
            for i in range(N):
                env.set_to_goal(env.sample_goal())
                img = env._get_flat_img()
                dataset[i, :] = unormalize_image(img)
                if show:
                    cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                    cv2.waitKey(1)
                print(i)
            info['env'] = env
        else:
            policy = RandomPolicy(env.action_space)
            es = OUStrategy(action_space=env.action_space, theta=0)
            exploration_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=policy,
            )
            for i in range(N):
                # Move the goal out of the image
                env.wrapped_env.set_goal(np.array([100, 100, 100]))
                if i % 50 == 0:
                    print('Reset')
                    env.reset()
                    exploration_policy.reset()
                for _ in range(1):
                    action = exploration_policy.get_action()[0] * 10
                    env.wrapped_env.step(action)
                img = env.step(env.action_space.sample())[0]
                dataset[i, :] = img
                if show:
                    cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                    cv2.waitKey(1)
                print(i)

        print("done making training data", time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Beispiel #24
0
def generate_vae_dataset(
    N=10000,
    test_p=0.9,
    use_cached=True,
    imsize=84,
    show=False,
    dataset_path=None,
    policy_path=None,
    ratio_oracle_policy_data_to_random=1 / 2,
    action_space_sampling=False,
    env_class=None,
    env_kwargs=None,
    action_plus_random_sampling=False,
    init_camera=sawyer_door_env_camera,
):
    if policy_path is not None:
        filename = "/tmp/sawyer_door_push_open_oracle+random_policy_data_closer_zoom_action_limited" + str(
            N) + ".npy"
    elif action_space_sampling:
        filename = "/tmp/sawyer_door_push_open_zoomed_in_action_space_sampling" + str(
            N) + ".npy"
    else:
        filename = "/tmp/sawyer_door_push_open" + str(N) + ".npy"
    info = {}
    if dataset_path is not None:
        filename = local_path_from_s3_or_local_path(dataset_path)
        dataset = np.load(filename)
    elif use_cached and osp.isfile(filename):
        dataset = np.load(filename)
        print("loaded data from saved file", filename)
    elif action_space_sampling:
        env = SawyerDoorPushOpenEnv(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=False,
            init_camera=sawyer_door_env_camera,
            normalize=False,
        )
        action_space = Box(np.array([-env.max_x_pos, .5, .06]),
                           np.array([env.max_x_pos, env.max_y_pos, .06]))
        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            env.set_to_goal_pos(action_space.sample())  #move arm to spot
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env.get_image().flatten()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        info['env'] = env
    elif action_plus_random_sampling:
        env = env_class(**env_kwargs)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=init_camera,
            normalize=True,
        )
        action_space = Box(np.array([-env.max_x_pos, .5, .06]),
                           np.array([env.max_x_pos, .6, .06]))
        action_sampled_data = int(N / 2)
        dataset = np.zeros((N, imsize * imsize * 3))
        print('Action Space Sampling')
        for i in range(action_sampled_data):
            env.set_to_goal_pos(action_space.sample())  # move arm to spot
            goal = env.sample_goal()
            env.set_to_goal(goal)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        env._wrapped_env.min_y_pos = .6
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        print('Random Sampling')
        for i in range(action_sampled_data, N):
            if i % 20 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(10):
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            img = env._get_flat_img()
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        env._wrapped_env.min_y_pos = .5
        info['env'] = env
    else:
        now = time.time()
        env = SawyerDoorPushOpenEnv(max_angle=.5)
        env = ImageEnv(
            env,
            imsize,
            transpose=True,
            init_camera=sawyer_door_env_camera,
            normalize=True,
        )
        info['env'] = env
        policy = RandomPolicy(env.action_space)
        es = OUStrategy(action_space=env.action_space, theta=0)
        exploration_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=policy,
        )
        dataset = np.zeros((N, imsize * imsize * 3))
        for i in range(N):
            if i % 100 == 0:
                env.reset()
                exploration_policy.reset()
            for _ in range(25):
                # env.wrapped_env.step(
                #     env.wrapped_env.action_space.sample()
                # )
                action = exploration_policy.get_action()[0]
                env.wrapped_env.step(action)
            goal = env.sample_goal_for_rollout()
            env.set_to_goal(goal)
            img = env.step(env.action_space.sample())[0]
            dataset[i, :] = img
            if show:
                cv2.imshow('img', img.reshape(3, 84, 84).transpose())
                cv2.waitKey(1)
            print(i)
        print("done making training data", filename, time.time() - now)
        np.save(filename, dataset)

    n = int(N * test_p)
    train_dataset = dataset[:n, :]
    test_dataset = dataset[n:, :]
    return train_dataset, test_dataset, info
Beispiel #25
0
def grill_her_td3_experiment(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.demos.her_td3bc import HerTD3BC
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    grill_preprocess_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    demo_train_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    demo_test_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant["render"]
    base_kwargs['render_during_eval'] = variant["render"]
    her_kwargs = algo_kwargs['her_kwargs']
    her_kwargs['observation_key'] = observation_key
    her_kwargs['desired_goal_key'] = desired_goal_key
    # algorithm = HerTd3(
    #     env,
    #     qf1=qf1,
    #     qf2=qf2,
    #     policy=policy,
    #     exploration_policy=exploration_policy,
    #     **variant['algo_kwargs']
    # )
    env.vae.to(ptu.device)

    algorithm = HerTD3BC(env,
                         qf1=qf1,
                         qf2=qf2,
                         policy=policy,
                         exploration_policy=exploration_policy,
                         demo_train_buffer=demo_train_buffer,
                         demo_test_buffer=demo_test_buffer,
                         demo_path=variant["demo_path"],
                         add_demo_latents=True,
                         **variant['algo_kwargs'])

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
Beispiel #26
0
def grill_her_td3_experiment_offpolicy_online_vae(variant):
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.torch.td3.td3 import TD3
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from railrl.torch.vae.online_vae_offpolicy_algorithm import OnlineVaeOffpolicyAlgorithm

    import gc
    gc.collect()  # Ashvin: this line works around a GPU memory error

    grill_preprocess_variant(variant)
    env = get_envs(variant)

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
        # **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
        # **variant['policy_kwargs']
    )
    # es = get_exploration_strategy(variant, env)
    es = GaussianAndEpislonStrategy(
        action_space=env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=variant.get('exploration_noise', 0.1),
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    replay_buffer_class = variant.get("replay_buffer_class",
                                      OnlineVaeRelabelingBuffer)
    replay_buffer = replay_buffer_class(vae=env.vae,
                                        env=env,
                                        observation_key=observation_key,
                                        desired_goal_key=desired_goal_key,
                                        achieved_goal_key=achieved_goal_key,
                                        **variant['replay_buffer_kwargs'])
    replay_buffer.representation_size = vae.representation_size

    vae_trainer_class = variant.get("vae_trainer_class", ConvVAETrainer)
    vae_trainer = vae_trainer_class(env.vae,
                                    **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['td3_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        env,
        expl_policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=variant['exploration_goal_sampling_mode'],
    )

    algorithm = OnlineVaeOffpolicyAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)

    algorithm.pretrain()
    algorithm.train()
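grill_her_td3_experiment_offpolicy_online_vae reads a few goal-sampling switches in addition to the usual kwargs; in particular, setting custom_goal_sampler to 'replay_buffer' makes the env sample goals from the replay buffer. A sketch of those keys; the key names come from the code above, but the mode strings and numbers are assumptions:

# Hypothetical goal-sampling fragment for the off-policy online-VAE example above.
variant_fragment = dict(
    exploration_goal_sampling_mode='vae_prior',    # assumed mode string
    evaluation_goal_sampling_mode='reset_of_env',  # assumed mode string
    custom_goal_sampler='replay_buffer',           # enables buffer goal sampling
    exploration_noise=0.2,
    max_path_length=100,
)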
Beispiel #27
0
def her_td3_experiment(variant):
    import gym
    import multiworld.envs.mujoco
    import multiworld.envs.pygame
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from railrl.exploration_strategies.gaussian_strategy import GaussianStrategy
    from railrl.exploration_strategies.ou_strategy import OUStrategy
    from railrl.torch.grill.launcher import get_video_save_func
    from railrl.demos.her_td3bc import HerTD3BC
    from railrl.torch.networks import FlattenMlp, TanhMlpPolicy
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer)

    if 'env_id' in variant:
        env = gym.make(variant['env_id'])
    else:
        env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    variant['algo_kwargs']['her_kwargs']['observation_key'] = observation_key
    variant['algo_kwargs']['her_kwargs']['desired_goal_key'] = desired_goal_key
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    obs_dim = env.observation_space.spaces['observation'].low.size
    action_dim = env.action_space.low.size
    goal_dim = env.observation_space.spaces['desired_goal'].low.size
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            **variant['es_kwargs'],
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    qf1 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim + goal_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = HerTD3BC(env,
                         qf1=qf1,
                         qf2=qf2,
                         policy=policy,
                         exploration_policy=exploration_policy,
                         replay_buffer=replay_buffer,
                         demo_path=variant["demo_path"],
                         **variant['algo_kwargs'])
    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)
    algorithm.to(ptu.device)
    algorithm.train()
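A minimal sketch of a `variant` that satisfies the keys her_td3_experiment reads is shown below; the environment id, network sizes, buffer kwarg, and demo path are placeholder assumptions.

# Sketch of a `variant` for her_td3_experiment; the key structure mirrors what
# the launcher reads, while every value is a placeholder assumption.
variant = dict(
    env_id='FetchPush-v1',                  # any dict-observation goal env (assumed id)
    observation_key='observation',
    desired_goal_key='desired_goal',
    exploration_type='ou',                  # 'ou', 'gaussian', or 'epsilon'
    es_kwargs=dict(),
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),   # assumed buffer kwarg
    demo_path='/path/to/demos.npy',         # consumed by HerTD3BC
    algo_kwargs=dict(her_kwargs=dict()),    # her_kwargs is populated by the launcher
    save_video=False,
)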
Example #28
0
def experiment(variant):
    env_params = ENV_PARAMS[variant['env']]
    variant.update(env_params)

    expl_env = NormalizedBoxEnv(variant['env_class']())
    eval_env = NormalizedBoxEnv(variant['env_class']())
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[M, M],
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[M, M],
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = TD3(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    if variant['collection_mode'] == 'online':
        expl_path_collector = MdpStepCollector(
            expl_env,
            expl_policy,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant['min_num_steps_before_training'],
        )
    else:
        expl_path_collector = MdpPathCollector(
            expl_env,
            expl_policy,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            max_path_length=variant['max_path_length'],
            batch_size=variant['batch_size'],
            num_epochs=variant['num_epochs'],
            num_eval_steps_per_epoch=variant['num_eval_steps_per_epoch'],
            num_expl_steps_per_train_loop=variant['num_expl_steps_per_train_loop'],
            num_trains_per_train_loop=variant['num_trains_per_train_loop'],
            min_num_steps_before_training=variant['min_num_steps_before_training'],
        )
    algorithm.to(ptu.device)
    algorithm.train()
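The launcher above pulls everything from `variant` after merging in ENV_PARAMS[variant['env']]; a hedged sketch of such a dict follows, with all values chosen as placeholders.

# Sketch of a `variant` for the launcher above. ENV_PARAMS[variant['env']] is
# merged in via variant.update(env_params), so some keys may come from there
# instead. All values below are placeholder assumptions.
variant = dict(
    env='half-cheetah',                     # assumed key into ENV_PARAMS
    env_class=None,                         # set to an env class, or supplied by ENV_PARAMS
    layer_size=256,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(discount=0.99),     # forwarded to the TD3 trainer
    collection_mode='batch',                # 'online' uses MdpStepCollector instead
    max_path_length=1000,
    batch_size=256,
    num_epochs=1000,
    num_eval_steps_per_epoch=5000,
    num_expl_steps_per_train_loop=1000,
    num_trains_per_train_loop=1000,
    min_num_steps_before_training=1000,
)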
Example #29
0
def tdm_td3_experiment(variant):
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy
    )
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3
    from railrl.state_distance.subgoal_planner import SubgoalPlanner
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path
    import joblib

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['replay_buffer_kwargs']['vectorized'] = vectorized

    if 'ckpt' in variant:
        if 'ckpt_epoch' in variant:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = joblib.load(filename)
        qf1 = data['qf1']
        qf2 = data['qf2']
        policy = data['policy']
        variant['algo_kwargs']['base_kwargs']['reward_scale'] = policy.reward_scale
    else:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
        )
        goal_dim = (
            env.observation_space.spaces[desired_goal_key].low.size
        )
        action_dim = env.action_space.low.size

        variant['qf_kwargs']['vectorized'] = vectorized
        norm_order = env.norm_order
        variant['qf_kwargs']['norm_order'] = norm_order
        env.reset()
        _, rew, _, _ = env.step(env.action_space.sample())
        if hasattr(rew, "__len__"):
            variant['qf_kwargs']['output_dim'] = len(rew)
        qf1 = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        qf2 = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        policy = TdmPolicy(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            reward_scale=variant['algo_kwargs']['base_kwargs'].get('reward_scale', 1.0),
            **variant['policy_kwargs']
        )

    eval_policy = None
    if variant.get('eval_policy', None) == 'SubgoalPlanner':
        eval_policy = SubgoalPlanner(
            env,
            qf1,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            state_based=variant.get("do_state_exp", False),
            max_tau=variant['algo_kwargs']['tdm_kwargs']['max_tau'],
            reward_scale=variant['algo_kwargs']['base_kwargs'].get('reward_scale', 1.0),
            **variant['SubgoalPlanner_kwargs']
        )

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant.get("render", False)
    base_kwargs['render_during_eval'] = variant.get("render_during_eval", False)
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algorithm = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        eval_policy=eval_policy,
        **variant['algo_kwargs']
    )

    if variant.get("test_ckpt", False):
        algorithm.post_epoch_funcs.append(get_update_networks_func(variant))

    vis_variant = variant.get('vis_kwargs', {})
    vis_list = vis_variant.get('vis_list', [])
    if vis_variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
            vis_list=vis_list,
            dont_terminate=True,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.cuda()
        if not variant.get("do_state_exp", False):
            env.vae.cuda()

    env.reset()
    if not variant.get("do_state_exp", False):
        env.dump_samples(epoch=None)
        env.dump_reconstructions(epoch=None)
        env.dump_latent_plots(epoch=None)

    algorithm.train()
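For orientation, the nested `variant` structure that tdm_td3_experiment accesses looks roughly like the sketch below; the concrete numbers, the max_tau value, and the buffer kwarg are assumptions.

# Rough sketch of the nested `variant` read by tdm_td3_experiment; the launcher
# expects algo_kwargs to contain both base_kwargs and tdm_kwargs. All values
# are illustrative assumptions.
variant = dict(
    observation_key='latent_observation',
    desired_goal_key='latent_desired_goal',
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    replay_buffer_kwargs=dict(max_size=int(1e6)),   # assumed buffer kwarg
    algo_kwargs=dict(
        base_kwargs=dict(reward_scale=1.0),
        tdm_kwargs=dict(max_tau=25),        # also read by SubgoalPlanner
    ),
    eval_policy=None,                       # or 'SubgoalPlanner'
    SubgoalPlanner_kwargs=dict(),
    vis_kwargs=dict(save_video=False),
    do_state_exp=False,
)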
def experiment(variant):
    rdim = variant["rdim"]
    use_env_goals = variant["use_env_goals"]
    vae_path = variant["vae_paths"][str(rdim)]
    render = variant["render"]
    wrap_mujoco_env = variant.get("wrap_mujoco_env", False)

    # vae = torch.load(vae_path)
    # print("loaded", vae_path)

    from railrl.envs.wrappers import ImageMujocoEnv, NormalizedBoxEnv
    from railrl.images.camera import sawyer_init_camera

    env = variant["env"](**variant['env_kwargs'])
    env = NormalizedBoxEnv(ImageMujocoEnv(
        env,
        imsize=84,
        keep_prev=0,
        init_camera=sawyer_init_camera,
    ))
    if wrap_mujoco_env:
        env = ImageMujocoEnv(env, 84, camera_name="topview", transpose=True, normalize=True)
    if use_env_goals:
        track_qpos_goal = variant.get("track_qpos_goal", 0)
        env = VAEWrappedImageGoalEnv(env, vae_path, use_vae_obs=True,
                                     use_vae_reward=True, use_vae_goals=True,
                                     render_goals=render, render_rollouts=render, track_qpos_goal=track_qpos_goal)
    else:
        env = VAEWrappedEnv(env, vae_path, use_vae_obs=True,
                            use_vae_reward=True, use_vae_goals=True,
                            render_goals=render, render_rollouts=render)

    env = MultitaskToFlatEnv(env)
    if variant['normalize']:
        env = NormalizedBoxEnv(env)
    exploration_type = variant['exploration_type']
    if exploration_type == 'ou':
        es = OUStrategy(action_space=env.action_space)
    elif exploration_type == 'gaussian':
        es = GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
    elif exploration_type == 'epsilon':
        es = EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        )
    else:
        raise Exception("Invalid type: " + exploration_type)
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[400, 300],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[400, 300],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    algorithm = TD3(
        env,
        training_env=env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    env._wrapped_env.vae.to(ptu.device)
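Finally, a hedged sketch of the `variant` keys read by the image-based experiment above; the env class, VAE path, and rdim value are placeholders.

# Sketch of the `variant` for the image-based experiment above; keys mirror
# what the code reads and every value is a placeholder assumption.
variant = dict(
    env=None,                               # set to a multiworld Sawyer env class
    env_kwargs=dict(),
    rdim=16,
    vae_paths={'16': '/path/to/vae.pkl'},   # keyed by str(rdim)
    use_env_goals=True,
    track_qpos_goal=0,
    render=False,
    wrap_mujoco_env=False,
    normalize=False,
    exploration_type='ou',                  # 'ou', 'gaussian', or 'epsilon'
    algo_kwargs=dict(),                     # forwarded to TD3(...)
)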