Exemplo n.º 1
0
def experiment(variant):
    """Run DQN wrapped in ``HERTrainer`` on the ``GoalGridworld-v0`` env.

    Creates separate exploration and evaluation environments, a Q-network
    plus target network, an argmax policy (wrapped with epsilon-greedy for
    exploration), a relabeling replay buffer and goal-conditioned path
    collectors, then trains with ``TorchBatchRLAlgorithm``.

    Args:
        variant: Experiment configuration. The entries
            ``'replay_buffer_kwargs'``, ``'trainer_kwargs'`` and
            ``'algo_kwargs'`` are forwarded as keyword arguments to the
            replay buffer, the DQN trainer and the algorithm respectively.
    """
    env_id = 'GoalGridworld-v0'
    expl_env = gym.make(env_id)
    eval_env = gym.make(env_id)

    spaces = expl_env.observation_space.spaces
    # The Q-network sees the observation concatenated with the goal.
    net_input_size = (spaces['observation'].low.size
                      + spaces['desired_goal'].low.size)
    num_actions = expl_env.action_space.n

    def build_q_network():
        # A fresh hidden_sizes list per call, so the networks share nothing.
        return FlattenMlp(
            input_size=net_input_size,
            output_size=num_actions,
            hidden_sizes=[400, 300],
        )

    qf = build_q_network()
    target_qf = build_q_network()

    eval_policy = ArgmaxDiscretePolicy(qf)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=EpsilonGreedy(
            action_space=expl_env.action_space,
        ),
        policy=eval_policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        **variant['replay_buffer_kwargs']
    )

    # Both collectors read the same observation-dict entries.
    collector_keys = dict(
        observation_key='observation',
        desired_goal_key='desired_goal',
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, expl_policy, **collector_keys)

    dqn_trainer = DQNTrainer(
        qf=qf,
        target_qf=target_qf,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=HERTrainer(dqn_trainer),
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 2
0
def sac(variant):
    """Train SAC on ``variant["env_name"]``, optionally wrapped in HER.

    The replay buffer, networks and path collectors are produced by the
    ``get_replay_buffer``, ``get_networks`` and ``get_path_collector``
    helpers. When ``variant["mode"]`` is ``"her"``, the goal-related
    observation keys are stored under ``variant["her"]`` before those
    helpers run, and the SAC trainer is wrapped in ``HERTrainer``.

    Args:
        variant: Experiment configuration. Reads ``"env_name"``, ``"seed"``,
            ``"mode"``, ``"archi"``, ``"trainer_kwargs"`` and
            ``"algorithm_kwargs"``; may write ``"her"``.
    """
    env_name = variant["env_name"]
    expl_env = gym.make(env_name)
    eval_env = gym.make(env_name)
    expl_env.seed(variant["seed"])
    eval_env.set_eval()

    mode = variant["mode"]
    # Not used below; the lookup is kept so a missing key still fails here.
    archi = variant["archi"]
    if mode == "her":
        variant["her"] = {
            "observation_key": "observation",
            "desired_goal_key": "desired_goal",
            "achieved_goal_key": "achieved_goal",
            "representation_goal_key": "representation_goal",
        }

    replay_buffer = get_replay_buffer(variant, expl_env)
    networks = get_networks(variant, expl_env)
    # The sixth element (shared base) is not needed by this launcher.
    qf1, qf2, target_qf1, target_qf2, policy, _shared_base = networks

    expl_path_collector, eval_path_collector = get_path_collector(
        variant,
        expl_env,
        eval_env,
        policy,
        MakeDeterministic(policy),
    )

    # Re-read the mode after the helpers have seen ``variant``.
    wrap_with_her = variant["mode"] == "her"
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant["trainer_kwargs"],
    )
    if wrap_with_her:
        trainer = HERTrainer(trainer)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"],
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 3
0
def twin_sac_experiment(variant):
    """Train SAC wrapped in ``HERTrainer`` on the env from ``get_envs``.

    Path collection uses ``GoalConditionedPathCollector`` when
    ``variant['do_state_exp']`` is truthy and ``VAEWrappedEnvPathCollector``
    otherwise; in the latter case ``env.vae`` is also moved to
    ``ptu.device`` before training.

    Args:
        variant: Experiment configuration. Reads ``'max_path_length'``,
            ``'qf_kwargs'``, ``'policy_kwargs'``, ``'replay_buffer_kwargs'``,
            ``'twin_sac_trainer_kwargs'`` and ``'algo_kwargs'``; the optional
            ``'observation_key'`` (default ``'latent_observation'``),
            ``'desired_goal_key'`` (default ``'latent_desired_goal'``),
            ``'do_state_exp'`` (default ``False``) and ``'save_video'``
            (default ``True``); and, for the VAE collectors,
            ``'evaluation_goal_sampling_mode'`` and
            ``'exploration_goal_sampling_mode'``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer,
    )
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.sac.policies import MakeDeterministic
    from rlkit.torch.sac.sac import SACTrainer

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    max_path_length = variant['max_path_length']

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # The policy input is the observation concatenated with the goal.
    policy_input_size = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size

    def build_q_network():
        return ConcatMlp(
            input_size=policy_input_size + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    qf1 = build_q_network()
    qf2 = build_q_network()
    target_qf1 = build_q_network()
    target_qf2 = build_q_network()
    policy = TanhGaussianPolicy(
        obs_dim=policy_input_size,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    sac_trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['twin_sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    if not variant.get("do_state_exp", False):
        eval_path_collector = VAEWrappedEnvPathCollector(
            variant['evaluation_goal_sampling_mode'],
            env,
            MakeDeterministic(policy),
            **collector_keys
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            variant['exploration_goal_sampling_mode'],
            env,
            policy,
            **collector_keys
        )
    else:
        eval_path_collector = GoalConditionedPathCollector(
            env, MakeDeterministic(policy), **collector_keys)
        expl_path_collector = GoalConditionedPathCollector(
            env, policy, **collector_keys)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        algorithm.post_train_funcs.append(VideoSaveFunction(env, variant))

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 4
0
def her_td3_experiment(variant):
    """Build and run a TD3 experiment wrapped in ``HERTrainer``.

    Environments come from ``gym.make(variant['env_id'])`` when ``'env_id'``
    is present. Otherwise ``variant['env_class']`` is instantiated with
    ``variant['env_kwargs']``; the evaluation env uses
    ``variant['eval_env_kwargs']`` instead when that key is given.

    Args:
        variant: Experiment configuration. Besides the environment keys
            above, reads ``'qf_kwargs'``, ``'policy_kwargs'``,
            ``'replay_buffer_kwargs'``, ``'trainer_kwargs'`` and
            ``'algo_kwargs'``, plus the optional ``'save_video'`` flag
            (default ``False``).
    """
    import gym

    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import \
        PolicyWrappedWithExplorationStrategy
    from rlkit.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from rlkit.launchers.launcher_util import setup_logger
    from rlkit.samplers.data_collector import GoalConditionedPathCollector
    from rlkit.torch.her.her import HERTrainer
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.td3.td3 import TD3
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    import rlkit.samplers.rollout_functions as rf
    from rlkit.torch.grill.launcher import get_state_experiment_video_save_function

    if 'env_id' in variant:
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )

    # Size the networks from the same dict entries that the replay buffer
    # and path collectors are configured with. The hard-coded
    # 'observation' / 'desired_goal' entries used before only worked when
    # they happened to have the same dimensions as the keys above.
    obs_spaces = expl_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    goal_dim = obs_spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    qf_input_size = obs_dim + goal_dim + action_dim
    policy_input_size = obs_dim + goal_dim
    qf1 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=qf_input_size,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=qf_input_size,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=policy_input_size,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=policy_input_size,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    trainer = TD3(policy=policy,
                  qf1=qf1,
                  qf2=qf2,
                  target_qf1=target_qf1,
                  target_qf2=target_qf2,
                  target_policy=target_policy,
                  **variant['trainer_kwargs'])
    trainer = HERTrainer(trainer)
    # Evaluation uses the bare policy; exploration uses the noisy wrapper.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", False):
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = get_state_experiment_video_save_function(
            rollout_function,
            eval_env,
            policy,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 5
0
def experiment(variant):
    """Train TD3 wrapped in ``HERTrainer`` on ``SawyerPushXYZEnv-v0``.

    Registers the multiworld environments, builds twin Q-networks, a policy
    and their targets, a relabeling replay buffer and goal-conditioned path
    collectors, then trains with ``TorchBatchRLAlgorithm``.

    Args:
        variant: Experiment configuration. Reads ``'qf_kwargs'``,
            ``'policy_kwargs'``, ``'replay_buffer_kwargs'``,
            ``'trainer_kwargs'`` and ``'algo_kwargs'``, each forwarded as
            keyword arguments to the corresponding constructor.
    """
    import multiworld
    multiworld.register_all_envs()
    eval_env = gym.make('SawyerPushXYZEnv-v0')
    expl_env = gym.make('SawyerPushXYZEnv-v0')
    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )

    # Size the networks from the same dict entries that the replay buffer
    # and path collectors are configured with, rather than the hard-coded
    # 'observation' / 'desired_goal' entries, so the input widths always
    # match the data that is actually fed to the networks.
    obs_spaces = expl_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    goal_dim = obs_spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    qf_input_size = obs_dim + goal_dim + action_dim
    policy_input_size = obs_dim + goal_dim
    qf1 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=qf_input_size,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=policy_input_size,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=policy_input_size,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    # Evaluation uses the bare policy; exploration uses the noisy wrapper.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 6
0
def experiment(variant):
    """Train SAC wrapped in ``HERTrainer`` with online or batch collection.

    Registers the custom multiworld MuJoCo environments and creates the
    environment named by ``variant['env_id']``. When
    ``variant['collection_mode']`` is ``'online'`` exploration uses a
    ``GoalConditionedStepCollector`` with ``TorchOnlineRLAlgorithm``;
    any other value uses a ``GoalConditionedPathCollector`` with
    ``TorchBatchRLAlgorithm``.

    Args:
        variant: Experiment configuration. Reads ``'env_id'``,
            ``'replay_buffer_kwargs'``, ``'layer_size'`` (width of both
            hidden layers of every network), ``'trainer_kwargs'``,
            ``'collection_mode'`` and ``'algo_kwargs'``.
    """
    import gym
    from multiworld.envs.mujoco import register_custom_envs

    register_custom_envs()
    observation_key = 'state_observation'
    desired_goal_key = 'xy_desired_goal'
    expl_env = gym.make(variant['env_id'])
    eval_env = gym.make(variant['env_id'])

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # Size the networks from the entries that are actually fed to them
    # (observation_key / desired_goal_key). Reading the hard-coded
    # 'observation' / 'desired_goal' entries, as was done before, gives the
    # wrong input width whenever e.g. 'xy_desired_goal' and 'desired_goal'
    # differ in dimension.
    obs_spaces = eval_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    M = variant['layer_size']
    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = ConcatMlp(
        input_size=qf_input_size,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        hidden_sizes=[M, M],
    )
    eval_policy = MakeDeterministic(policy)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    if variant['collection_mode'] == 'online':
        expl_step_collector = GoalConditionedStepCollector(
            expl_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        algorithm = TorchOnlineRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_step_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs']
        )
    else:
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        algorithm = TorchBatchRLAlgorithm(
            trainer=trainer,
            exploration_env=expl_env,
            evaluation_env=eval_env,
            exploration_data_collector=expl_path_collector,
            evaluation_data_collector=eval_path_collector,
            replay_buffer=replay_buffer,
            **variant['algo_kwargs']
        )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 7
0
def experiment(variant):
    """Train SAC wrapped in ``HERTrainer`` on a multiworld MuJoCo env.

    Registers the multiworld MuJoCo environments, creates the environment
    named by ``variant['env_id']`` twice (evaluation and exploration), sets
    ``reward_type`` on both, and trains with ``TorchBatchRLAlgorithm``.

    Args:
        variant: Experiment configuration. Reads ``'env_id'``,
            ``'reward_type'``, ``'replay_buffer_kwargs'``, ``'qf_kwargs'``,
            ``'policy_kwargs'``, ``'sac_trainer_kwargs'`` and
            ``'algo_kwargs'``.
    """
    from multiworld.envs.mujoco import register_mujoco_envs
    register_mujoco_envs()
    env_id = variant['env_id']
    eval_env = gym.make(env_id)
    expl_env = gym.make(env_id)
    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'

    eval_env.reward_type = variant['reward_type']
    expl_env.reward_type = variant['reward_type']

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Size the networks from the same dict entries that the replay buffer
    # and path collectors are configured with, rather than the hard-coded
    # 'observation' / 'desired_goal' entries, so the input widths always
    # match the data that is actually fed to the networks.
    obs_spaces = eval_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    qf_input_size = obs_dim + action_dim + goal_dim
    qf1 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=qf_input_size,
                     output_size=1,
                     **variant['qf_kwargs'])
    target_qf1 = FlattenMlp(input_size=qf_input_size,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=qf_input_size,
                            output_size=1,
                            **variant['qf_kwargs'])
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **variant['policy_kwargs'])
    eval_policy = MakeDeterministic(policy)
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        eval_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 8
0
def twin_sac_experiment_online_vae(variant):
    """Train SAC wrapped in ``HERTrainer`` while also training ``env.vae``.

    The environment comes from ``get_envs(variant)``; its ``vae`` attribute
    is shared by the replay buffer, the VAE trainer and
    ``OnlineVaeAlgorithm``.

    Args:
        variant: Experiment configuration. Required keys:
            ``'replay_buffer_kwargs'``, ``'online_vae_trainer_kwargs'``,
            ``'max_path_length'``, ``'twin_sac_trainer_kwargs'``,
            ``'evaluation_goal_sampling_mode'``,
            ``'exploration_goal_sampling_mode'``, ``'algo_kwargs'`` and
            ``'custom_goal_sampler'``. Optional keys:
            ``'generate_uniform_dataset_fn'`` (called with
            ``'generate_uniform_dataset_kwargs'`` when truthy),
            ``'observation_key'``, ``'desired_goal_key'``,
            ``'hidden_sizes'`` (default ``[400, 300]``),
            ``'replay_buffer_class'``, ``'vae_trainer_class'`` and
            ``'save_video'`` (default ``True``).

    Raises:
        AssertionError: If ``'vae_training_schedule'`` is a top-level key of
            ``variant``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import (
        OnlineVaeRelabelingBuffer,
    )
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    uniform_dataset = (
        dataset_fn(**variant['generate_uniform_dataset_kwargs'])
        if dataset_fn else None
    )

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # The policy input is the observation concatenated with the goal.
    policy_input_size = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def build_q_network():
        return ConcatMlp(
            input_size=policy_input_size + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = build_q_network()
    qf2 = build_q_network()
    target_qf1 = build_q_network()
    target_qf2 = build_q_network()
    policy = TanhGaussianPolicy(
        obs_dim=policy_input_size,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    buffer_cls = variant.get("replay_buffer_class", OnlineVaeRelabelingBuffer)
    replay_buffer = buffer_cls(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    vae_trainer_cls = variant.get("vae_trainer_class", ConvVAETrainer)
    vae_trainer = vae_trainer_cls(vae, **variant['online_vae_trainer_kwargs'])
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    sac_trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['twin_sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        **collector_keys
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        **collector_keys
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        algorithm.post_train_funcs.append(VideoSaveFunction(env, variant))
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 9
0
def skewfit_experiment(variant):
    """Train SAC wrapped in ``HERTrainer`` together with an online VAE.

    Most hyper-parameter groups are optional: when a group is absent from
    ``variant`` the defaults spelled out below are used instead.

    Args:
        variant: Experiment configuration. Required keys: ``'vae_model'``,
            ``'vae_train_data'``, ``'vae_test_data'`` and
            ``'custom_goal_sampler'``. Optional keys (with defaults):
            ``'observation_key'``, ``'desired_goal_key'``,
            ``'hidden_sizes'``, ``'replay_buffer_kwargs'``,
            ``'online_vae_trainer_kwargs'``, ``'max_path_length'``,
            ``'algo_kwargs'``, ``'twin_sac_trainer_kwargs'``,
            ``'evaluation_goal_sampling_mode'`` and
            ``'exploration_goal_sampling_mode'``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import (
        OnlineVaeRelabelingBuffer,
    )
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    import rlkit.torch.vae.vae_schedules as vae_schedules

    # ---- environment and hyper-parameters for the VAE and RL training ----
    env = get_envs(variant)
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # The policy input is the observation concatenated with the goal.
    policy_input_size = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    default_replay_buffer_kwargs = {
        'start_skew_epoch': 10,
        'max_size': 100000,
        'fraction_goals_rollout_goals': 0.2,
        'fraction_goals_env_goals': 0.5,
        'exploration_rewards_type': 'None',
        'vae_priority_type': 'vae_prob',
        'priority_function_kwargs': {
            'sampling_method': 'importance_sampling',
            'decoder_distribution': 'gaussian_identity_variance',
            'num_latents_to_sample': 10,
        },
        'power': 0,
        'relabeling_goal_sampling_mode': 'vae_prior',
    }
    replay_buffer_kwargs = variant.get(
        'replay_buffer_kwargs', default_replay_buffer_kwargs)

    online_vae_trainer_kwargs = variant.get(
        'online_vae_trainer_kwargs', {'beta': 20, 'lr': 1e-3})
    max_path_length = variant.get('max_path_length', 50)

    default_algo_kwargs = {
        'batch_size': 1024,
        'num_epochs': 1000,
        'num_eval_steps_per_epoch': 500,
        'num_expl_steps_per_train_loop': 500,
        'num_trains_per_train_loop': 1000,
        'min_num_steps_before_training': 10000,
        'vae_training_schedule': vae_schedules.custom_schedule_2,
        'oracle_data': False,
        'vae_save_period': 50,
        'parallel_vae_train': False,
    }
    algo_kwargs = variant.get('algo_kwargs', default_algo_kwargs)

    default_sac_kwargs = {
        'discount': 0.99,
        'reward_scale': 1,
        'soft_target_tau': 1e-3,
        'target_update_period': 1,
        'use_automatic_entropy_tuning': True,
    }
    twin_sac_trainer_kwargs = variant.get(
        'twin_sac_trainer_kwargs', default_sac_kwargs)

    # ---- networks ----
    def build_q_network():
        return FlattenMlp(
            input_size=policy_input_size + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    qf1 = build_q_network()
    qf2 = build_q_network()
    target_qf1 = build_q_network()
    target_qf2 = build_q_network()
    policy = TanhGaussianPolicy(
        obs_dim=policy_input_size,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = variant['vae_model']
    # Replay buffer used for training the VAE online.
    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    # Trainer that updates the VAE during RL training.
    vae_trainer = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        vae,
        **online_vae_trainer_kwargs
    )
    # SAC trainer for the Q-functions and the policy, wrapped for HER.
    sac_trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant.get('evaluation_goal_sampling_mode', 'reset_of_env'),
        env,
        MakeDeterministic(policy),
        max_path_length,
        **collector_keys
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant.get('exploration_goal_sampling_mode', 'vae_prior'),
        env,
        policy,
        max_path_length,
        **collector_keys
    )
    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **algo_kwargs
    )

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 10
0
def experiment(variant):
    """Run goal-conditioned SAC with HER relabeling on ``FetchReach-v1``.

    Builds two environment instances, a relabeling replay buffer, twin
    Q-networks with their targets and a tanh-Gaussian policy, wires them into
    a ``TorchBatchRLAlgorithm`` and starts training.

    Args:
        variant: Mapping of component keyword arguments. The keys read here
            are ``replay_buffer_kwargs``, ``qf_kwargs``, ``policy_kwargs``,
            ``sac_trainer_kwargs`` and ``algo_kwargs``.
    """
    # `.env` drops the outermost wrapper. Per the original author's note this
    # is the time-limit wrapper, removed because episodes are terminated
    # manually after 50 steps.
    eval_env = gym.make('FetchReach-v1').env
    expl_env = gym.make('FetchReach-v1').env

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    obs_spaces = eval_env.observation_space.spaces
    obs_dim = obs_spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = obs_spaces['desired_goal'].low.size

    def build_critic():
        # Input width is the sum of the observation, action and goal sizes.
        return ConcatMlp(input_size=obs_dim + action_dim + goal_dim,
                         output_size=1,
                         **variant['qf_kwargs'])

    # Creation order is kept (qf1, qf2, target_qf1, target_qf2, policy) so any
    # seeded weight initialisation is unchanged.
    qf1 = build_critic()
    qf2 = build_critic()
    target_qf1 = build_critic()
    target_qf2 = build_critic()
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **variant['policy_kwargs'])
    eval_policy = MakeDeterministic(policy)

    sac_trainer = SACTrainer(env=eval_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **variant['sac_trainer_kwargs'])
    trainer = HERTrainer(sac_trainer)

    # Both collectors read the same observation / goal entries.
    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, policy, **collector_keys)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 11
0
def skewfit_experiment(cfgs):
    """Assemble and run an online-VAE SAC+HER experiment from a config object.

    Args:
        cfgs: Configuration object accessed by attribute. The sections read
            here are ``GENERATE_VAE_DATASET``, ``SKEW_FIT``, ``Q_FUNCTION``,
            ``POLICY``, ``PRIORITY_FUNCTION``, ``REPLAY_BUFFER``,
            ``VAE_TRAINER``, ``TWIN_SAC_TRAINER`` and ``ALGORITHM``. It is
            also handed to ``skewfit_preprocess_variant`` and ``get_envs``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(cfgs)
    env = get_envs(cfgs)

    # TODO (carried over from the original): optional uniform dataset.
    dataset_cfg = cfgs.GENERATE_VAE_DATASET
    uniform_dataset_fn = dataset_cfg.get('uniform_dataset_generator', None)
    uniform_dataset = (
        uniform_dataset_fn(**dataset_cfg.generate_uniform_dataset_kwargs)
        if uniform_dataset_fn else None
    )

    skew_fit_cfg = cfgs.SKEW_FIT
    observation_key = skew_fit_cfg.get('observation_key',
                                       'latent_observation')
    desired_goal_key = skew_fit_cfg.get('desired_goal_key',
                                        'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # "obs_dim" here already includes the goal: it is the sum of the
    # observation and desired-goal sizes.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size +
               spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = cfgs.Q_FUNCTION.get('hidden_sizes', [400, 300])

    def build_critic():
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    # Same construction order as before: critics, targets, then the policy.
    qf1 = build_critic()
    qf2 = build_critic()
    target_qf1 = build_critic()
    target_qf2 = build_critic()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=cfgs.POLICY.get('hidden_sizes', [400, 300]),
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        priority_function_kwargs=cfgs.PRIORITY_FUNCTION,
        **cfgs.REPLAY_BUFFER)

    vae_trainer_cfg = cfgs.VAE_TRAINER
    vae_trainer = ConvVAETrainer(
        vae_trainer_cfg.train_data,
        vae_trainer_cfg.test_data,
        vae,
        beta=vae_trainer_cfg.beta,
        lr=vae_trainer_cfg.lr,
    )

    max_path_length = cfgs.SKEW_FIT.max_path_length
    sac_trainer = SACTrainer(env=env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **cfgs.TWIN_SAC_TRAINER)
    trainer = HERTrainer(sac_trainer)

    # TODO (carried over from the original): check that decode_goals=True is
    # what the evaluation collector should use.
    collector_kwargs = dict(
        decode_goals=True,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.evaluation_goal_sampling_mode,
        env,
        MakeDeterministic(policy),
        **collector_kwargs)
    expl_path_collector = VAEWrappedEnvPathCollector(
        cfgs.SKEW_FIT.exploration_goal_sampling_mode,
        env,
        policy,
        **collector_kwargs)

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,  # TODO used in test vae
        max_path_length=max_path_length,
        parallel_vae_train=cfgs.VAE_TRAINER.parallel_train,
        **cfgs.ALGORITHM)

    # Optionally let the environment draw goals from the replay buffer.
    if cfgs.SKEW_FIT.custom_goal_sampler == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 12
0
def _e2e_disentangled_experiment(max_path_length,
                                 encoder_kwargs,
                                 disentangled_qf_kwargs,
                                 qf_kwargs,
                                 twin_sac_trainer_kwargs,
                                 replay_buffer_kwargs,
                                 policy_kwargs,
                                 vae_evaluation_goal_sampling_mode,
                                 vae_exploration_goal_sampling_mode,
                                 base_env_evaluation_goal_sampling_mode,
                                 base_env_exploration_goal_sampling_mode,
                                 algo_kwargs,
                                 env_id=None,
                                 env_class=None,
                                 env_kwargs=None,
                                 observation_key='state_observation',
                                 desired_goal_key='state_desired_goal',
                                 achieved_goal_key='state_achieved_goal',
                                 latent_dim=2,
                                 vae_wrapped_env_kwargs=None,
                                 vae_path=None,
                                 vae_n_vae_training_kwargs=None,
                                 vectorized=False,
                                 save_video=True,
                                 save_video_kwargs=None,
                                 have_no_disentangled_encoder=False,
                                 **kwargs):
    """Train SAC+HER on VAE-wrapped envs with an optional disentangled critic.

    Args:
        max_path_length: Forwarded to both path collectors, the algorithm and
            the video rollout function.
        encoder_kwargs: Extra keyword arguments for the ``ConcatMlp`` encoder.
        disentangled_qf_kwargs: Extra keyword arguments for
            ``DisentangledMlpQf``.
        qf_kwargs: Keyword arguments for the plain ``ConcatMlp`` critic, or
            passed as ``qf_kwargs`` to ``DisentangledMlpQf``.
        twin_sac_trainer_kwargs: Extra keyword arguments for ``SACTrainer``.
        replay_buffer_kwargs: Extra keyword arguments for
            ``ObsDictRelabelingBuffer``.
        policy_kwargs: Extra keyword arguments for ``TanhGaussianPolicy``.
        vae_evaluation_goal_sampling_mode: ``goal_sampling_mode`` of the
            evaluation path collector.
        vae_exploration_goal_sampling_mode: ``goal_sampling_mode`` of the
            exploration path collector.
        base_env_evaluation_goal_sampling_mode: Assigned to
            ``eval_env.goal_sampling_mode`` before the env is wrapped.
        base_env_exploration_goal_sampling_mode: Assigned to
            ``train_env.goal_sampling_mode`` before the env is wrapped.
        algo_kwargs: Extra keyword arguments for ``TorchBatchRLAlgorithm``.
        env_id: Gym id used with ``gym.make``; takes precedence over
            ``env_class`` when truthy.
        env_class: Callable used to build the envs when ``env_id`` is falsy.
        env_kwargs: Keyword arguments for ``env_class``; ``None`` means ``{}``.
        observation_key: Observation-dict key used as the observation.
        desired_goal_key: Observation-dict key used as the desired goal.
        achieved_goal_key: Observation-dict key given to the replay buffer as
            the achieved goal.
        latent_dim: Encoder output size; also passed to ``get_n_train_vae``
            and used as the number of ``v_vals_dim_*`` image keys.
        vae_wrapped_env_kwargs: Extra keyword arguments for ``VAEWrappedEnv``;
            ``None`` means ``{}``.
        vae_path: If truthy, the VAE is loaded with
            ``load_local_or_remote_file`` instead of being trained.
        vae_n_vae_training_kwargs: Extra keyword arguments for
            ``get_n_train_vae``; ``None`` means ``{}``.
        vectorized: Forwarded to the disentangled critic, the replay buffer
            and the heatmap helper.
        save_video: When true, register eval/train video savers as
            post-train hooks.
        save_video_kwargs: Options for video saving. ``'save_vf_heatmap'``
            (default ``True``) is read here and the whole mapping is forwarded
            to ``get_save_video_function``; ``None`` means ``{}``.
        have_no_disentangled_encoder: When true, use plain ``ConcatMlp``
            critics instead of ``DisentangledMlpQf``.
        **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If neither ``env_id`` nor ``env_class`` is given.
    """
    # Optional mapping arguments default to None in the signature but are
    # unpacked with ** (or queried with .get) below; normalise them up front so
    # omitting any of them does not raise a TypeError/AttributeError.
    if env_kwargs is None:
        env_kwargs = {}
    if vae_wrapped_env_kwargs is None:
        vae_wrapped_env_kwargs = {}
    if vae_n_vae_training_kwargs is None:
        vae_n_vae_training_kwargs = {}
    if save_video_kwargs is None:
        save_video_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    train_env.goal_sampling_mode = base_env_exploration_goal_sampling_mode
    eval_env.goal_sampling_mode = base_env_evaluation_goal_sampling_mode

    if vae_path:
        vae = load_local_or_remote_file(vae_path)
    else:
        vae = get_n_train_vae(latent_dim=latent_dim,
                              env=eval_env,
                              **vae_n_vae_training_kwargs)

    train_env = VAEWrappedEnv(train_env,
                              vae,
                              imsize=train_env.imsize,
                              **vae_wrapped_env_kwargs)
    # NOTE(review): this reads imsize from the already-wrapped train_env;
    # presumably identical to the eval env's own imsize -- confirm.
    eval_env = VAEWrappedEnv(eval_env,
                             vae,
                             imsize=train_env.imsize,
                             **vae_wrapped_env_kwargs)

    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    # A single encoder instance is shared by every DisentangledMlpQf below.
    encoder = ConcatMlp(input_size=obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)

    def make_qf():
        """Build one critic of the configured kind."""
        if have_no_disentangled_encoder:
            return ConcatMlp(
                input_size=obs_dim + goal_dim + action_dim,
                output_size=1,
                **qf_kwargs,
            )
        else:
            return DisentangledMlpQf(encoder=encoder,
                                     preprocess_obs_dim=obs_dim,
                                     action_dim=action_dim,
                                     qf_kwargs=qf_kwargs,
                                     vectorized=vectorized,
                                     **disentangled_qf_kwargs)

    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **twin_sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_evaluation_goal_sampling_mode,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode=vae_exploration_goal_sampling_mode,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        if have_no_disentangled_encoder:

            def v_function(obs):
                # Score the policy's own action with the first critic.
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action)

            add_heatmap = partial(add_heatmap_img_to_o_dict,
                                  v_function=v_function)
        else:

            def v_function(obs):
                # Same as above, but ask the disentangled critic for its
                # individual Q-values as well.
                action = policy.get_actions(obs)
                obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
                return qf1(obs, action, return_individual_q_vals=True)

            add_heatmap = partial(
                add_heatmap_imgs_to_o_dict,
                v_function=v_function,
                vectorized=vectorized,
            )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
Exemplo n.º 13
0
def skewfit_experiment(variant):
    """Assemble and run an online-VAE SAC+HER experiment from a variant dict.

    Args:
        variant: Experiment configuration mapping. Required keys read here:
            ``replay_buffer_kwargs``, ``vae_train_data``, ``vae_test_data``,
            ``online_vae_trainer_kwargs``, ``max_path_length``,
            ``twin_sac_trainer_kwargs``, ``evaluation_goal_sampling_mode``,
            ``exploration_goal_sampling_mode``, ``algo_kwargs`` and
            ``custom_goal_sampler``. Optional keys:
            ``generate_uniform_dataset_fn`` (with
            ``generate_uniform_dataset_kwargs``), ``observation_key``,
            ``desired_goal_key`` and ``hidden_sizes``.

    Raises:
        AssertionError: If ``vae_training_schedule`` is a top-level key of
            ``variant``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    # The uniform dataset is only built when a generator is configured.
    uniform_dataset = None
    uniform_dataset_fn = variant.get("generate_uniform_dataset_fn", None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant["generate_uniform_dataset_kwargs"])

    observation_key = variant.get("observation_key", "latent_observation")
    desired_goal_key = variant.get("desired_goal_key", "latent_desired_goal")
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # "obs_dim" already includes the goal: observation size + goal size.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size +
               spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get("hidden_sizes", [400, 300])

    def build_critic():
        return FlattenMlp(input_size=obs_dim + action_dim,
                          output_size=1,
                          hidden_sizes=hidden_sizes)

    # Same construction order as before: critics, targets, then the policy.
    qf1 = build_critic()
    qf2 = build_critic()
    target_qf1 = build_critic()
    target_qf2 = build_critic()
    policy = TanhGaussianPolicy(obs_dim=obs_dim,
                                action_dim=action_dim,
                                hidden_sizes=hidden_sizes)

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant["replay_buffer_kwargs"])
    vae_trainer = ConvVAETrainer(variant["vae_train_data"],
                                 variant["vae_test_data"],
                                 vae,
                                 **variant["online_vae_trainer_kwargs"])
    assert "vae_training_schedule" not in variant, "Just put it in algo_kwargs"
    max_path_length = variant["max_path_length"]

    sac_trainer = SACTrainer(env=env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **variant["twin_sac_trainer_kwargs"])
    trainer = HERTrainer(sac_trainer)

    # Both collectors read the same observation / goal entries.
    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant["evaluation_goal_sampling_mode"],
        env,
        MakeDeterministic(policy),
        max_path_length,
        **collector_keys)
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant["exploration_goal_sampling_mode"],
        env,
        policy,
        max_path_length,
        **collector_keys)

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant["algo_kwargs"])

    # Optionally let the environment draw goals from the replay buffer.
    if variant["custom_goal_sampler"] == "replay_buffer":
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
def skewfit_experiment(variant):
    """Assemble and run the online VAE + LSTM SAC+HER experiment.

    The environment exposes two models, ``vae_original`` and
    ``lstm_segmented``; each gets its own trainer and both are handed to
    ``OnlineLSTMAlgorithm``.

    Args:
        variant: Experiment configuration mapping. Required keys read here:
            ``replay_buffer_kwargs``, ``vae_train_data`` and
            ``vae_test_data`` (indexed with ``[0]`` for the LSTM trainer and
            ``[1]`` for the VAE trainer), ``online_vae_trainer_kwargs``,
            ``online_lstm_trainer_kwargs``, ``max_path_length``,
            ``twin_sac_trainer_kwargs``, ``evaluation_goal_sampling_mode``,
            ``exploration_goal_sampling_mode``, ``algo_kwargs`` and
            ``custom_goal_sampler``. Optional keys:
            ``generate_uniform_dataset_fn`` (with
            ``generate_uniform_dataset_kwargs``), ``observation_key``,
            ``desired_goal_key``, ``hidden_sizes`` and
            ``keep_train_segmentation_lstm``.

    Raises:
        AssertionError: If ``vae_training_schedule`` is a top-level key of
            ``variant``.
    """
    import rlkit.torch.pytorch_util as ptu
    from ROLL.online_LSTM_replay_buffer import OnlineLSTMRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp, FlattenPreprocessMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy, TanhGaussianPreprocessPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from ROLL.LSTM_trainer import ConvLSTMTrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    # The uniform dataset is only built when a generator is configured.
    uniform_dataset = None
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs'])

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # "obs_dim" already includes the goal: observation size + goal size.
    spaces = env.observation_space.spaces
    obs_dim = (spaces[observation_key].low.size +
               spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def build_critic():
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    # Same construction order as before: critics, targets, then the policy.
    qf1 = build_critic()
    qf2 = build_critic()
    target_qf1 = build_critic()
    target_qf2 = build_critic()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae_original = env.vae_original
    lstm_segmented = env.lstm_segmented

    replay_buffer = OnlineLSTMRelabelingBuffer(
        vae_ori=vae_original,
        lstm_seg=lstm_segmented,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Index 1 of the train/test data feeds the VAE trainer, index 0 the LSTM
    # trainer.
    vae_trainer_original = ConvVAETrainer(
        variant['vae_train_data'][1],
        variant['vae_test_data'][1],
        vae_original,
        **variant['online_vae_trainer_kwargs'])
    lstm_trainer_segmented = ConvLSTMTrainer(
        variant['vae_train_data'][0],
        variant['vae_test_data'][0],
        lstm_segmented,
        **variant['online_lstm_trainer_kwargs'])

    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    sac_trainer = SACTrainer(env=env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(sac_trainer)

    # Both collectors read the same observation / goal entries.
    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = LSTMWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        MakeDeterministic(policy),
        max_path_length,
        **collector_keys)
    expl_path_collector = LSTMWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        **collector_keys)

    keep_train_segmentation_lstm = variant.get('keep_train_segmentation_lstm',
                                               False)
    algorithm = OnlineLSTMAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae_original=vae_original,
        lstm_segmented=lstm_segmented,
        vae_trainer_original=vae_trainer_original,
        lstm_trainer_segmented=lstm_trainer_segmented,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        keep_train_segmentation_lstm=keep_train_segmentation_lstm,
        **variant['algo_kwargs'])

    # Optionally let the environment draw goals from the replay buffer.
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    # Only the original VAE is moved explicitly here, as in the original code.
    vae_original.to(ptu.device)
    algorithm.train()
Exemplo n.º 15
0
def main(config: str, output: str, render: bool) -> None:
    """Run (or resume) an online-VAE SAC+HER experiment with disentangled goals.

    Args:
        config: Passed to ``load_config`` to obtain the experiment variant.
        output: Base directory; joined with the env id to form the log dir.
        render: Forwarded to the environment as both ``render_rollouts`` and
            ``render_goals``. It is applied whether the run is created fresh
            or resumed from ``saved_experiment_dir``.

    Raises:
        AssertionError: If a disentangled reward type is selected without
            ``disentanglement_indices``, or if ``exploration_goal_sampler``
            is not one of ``'replay_buffer'``, ``'vae_prior'`` or
            ``'observation_space'``.
    """
    disable_tensorflow_gpu()

    variant = load_config(config)
    her_variant = variant['her_variant']
    disentanglement_indices = her_variant.get('disentanglement_indices', None)
    desired_goal_key = her_variant['desired_goal_key']
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Load datasets.
    train_dset = load_dset(variant['env']['train_dataset'],
                           stack_images=variant['env'].get(
                               'stack_images', False))
    eval_dset = load_dset(variant['env']['eval_dataset'],
                          stack_images=variant['env'].get(
                              'stack_images', False))

    # Set logging.
    log_dir = create_exp_name(os.path.join(output, variant['env']['env_id']))
    print('Logging to:', log_dir)
    setup_logger(log_dir=log_dir, variant=variant, snapshot_mode='last')
    ptu.set_gpu_mode(torch.cuda.is_available())

    # Load disentanglement model and get disentanglement latent space.
    disentanglement_model = None
    disentanglement_space = None
    internal_keys = None
    goal_keys = None
    if her_variant['reward_params']['type'] in [
            'disentangled_distance', 'latent_and_disentangled_distance'
    ]:
        assert disentanglement_indices is not None
        disentanglement_model, disentanglement_space = load_disentanglement_model(
            train_dset,
            model_kwargs=variant['disentanglement'].get('model_kwargs', {}),
            factor_indices=disentanglement_indices,
            model_path=variant['env']['disentanglement_model_path'],
        )
        internal_keys = [
            'disentangled_achieved_goal', 'disentangled_desired_goal'
        ]
        goal_keys = ['disentangled_desired_goal']

    def make_env(base_env, vae_model):
        """Wrap ``base_env`` in a disentangled env built around ``vae_model``.

        Shared by the resume and fresh-run paths so both receive the same
        options, including the render flags.
        """
        return create_disentangled_env(
            base_env,
            vae_model,
            eval_dset=eval_dset,
            disentangled_env_kwargs=dict(
                disentanglement_model=disentanglement_model,
                disentanglement_space=disentanglement_space,
                disentanglement_indices=disentanglement_indices,
                desired_goal_key=desired_goal_key,
                reward_params=her_variant['reward_params'],
                render_rollouts=render,
                render_goals=render,
                **variant['env']['vae_wrapped_env_kwargs']))

    experiment_dir = variant.get('saved_experiment_dir', None)
    if experiment_dir is not None:
        # Read saved models from file. Only the params mapping is used; the
        # VAE is taken from it rather than from the third return value.
        _, params, _ = load_experiment(experiment_dir)
        vae = params['vae']
        qf1 = params['trainer/qf1']
        qf2 = params['trainer/qf2']
        target_qf1 = params['trainer/target_qf1']
        target_qf2 = params['trainer/target_qf2']
        policy = params['trainer/policy']
        replay_buffer_snapshot = params['replay_buffer']

        # Create VAE trainer. (Note: This does not train the VAE.)
        vae, vae_trainer = train_vae(
            vae,
            dset=train_dset,
            vae_trainer_kwargs=variant['vae']['vae_trainer_kwargs'])

        # Recreate the DisentangledEnv, since saved_env has incorrect camera
        # view for some reason.
        env = make_env(params['exploration/env'].wrapped_env.wrapped_env, vae)
    else:
        # Create and pre-train VAE.
        vae = VAE(
            x_shape=train_dset.observation_shape,
            representation_size=her_variant['representation_size'],
            num_factors=train_dset.num_factors,
            imsize=variant['env']['imsize'],
            **variant['vae']['vae_kwargs'])
        vae.to(ptu.device)
        vae, vae_trainer = train_vae(
            vae,
            dset=train_dset,
            **variant['vae']['pretrain_kwargs'],
            vae_trainer_kwargs=variant['vae']['vae_trainer_kwargs'])

        # Create the environment.
        register_all_envs()
        wrapped_env = gym.make(variant['env']['env_id'])
        env = make_env(wrapped_env, vae)

        # Initialize policy model.
        hidden_sizes = variant.get('hidden_sizes', [400, 300])
        latent_dim = env.observation_space.spaces[
            'latent_observation'].low.size
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        action_dim = env.action_space.low.size

        mlp_kwargs = dict(
            input_size=latent_dim + goal_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )
        qf1 = FlattenMlp(**mlp_kwargs)
        qf2 = FlattenMlp(**mlp_kwargs)
        target_qf1 = FlattenMlp(**mlp_kwargs)
        target_qf2 = FlattenMlp(**mlp_kwargs)

        policy_kwargs = dict(
            obs_dim=latent_dim + goal_dim,
            action_dim=action_dim,
            hidden_sizes=hidden_sizes,
        )
        policy = TanhGaussianPolicy(**policy_kwargs)

        replay_buffer_snapshot = None

    replay_buffer = ReplayBuffer(
        vae=vae,
        env=env,
        snapshot=replay_buffer_snapshot,
        observation_key='latent_observation',
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        internal_keys=internal_keys,
        goal_keys=goal_keys,
        fraction_goals_env_goals=her_variant['fraction_goals_env_goals'],
        fraction_goals_rollout_goals=her_variant[
            'fraction_goals_rollout_goals'],
        relabeling_goal_sampling_mode='custom_goal_sampler',
        power=her_variant['power'],
        **variant['replay_buffer_kwargs'])
    max_path_length = variant['env']['max_path_length']

    trainer = SACTrainer(env=env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        'presampled',  # Use presampled images from eval_dataset
        env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key='latent_observation',
        desired_goal_key=desired_goal_key,
    )

    # Exploration goals come either from the replay buffer (via the env's
    # custom goal sampler) or from one of the env's built-in sampling modes.
    exploration_goal_sampler = her_variant['exploration_goal_sampler']
    if exploration_goal_sampler == 'replay_buffer':
        expl_goal_sampling_mode = 'custom_goal_sampler'
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals
    else:
        assert exploration_goal_sampler in ['vae_prior', 'observation_space']
        expl_goal_sampling_mode = exploration_goal_sampler

    expl_path_collector = VAEWrappedEnvPathCollector(
        expl_goal_sampling_mode,
        env,
        policy,
        max_path_length,
        observation_key='latent_observation',
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 16
0
def experiment(args, variant):
    """Train a goal-conditioned SAC agent with HER on a DeepBuilder session.

    A single ``sim.DeepBuilderEnv`` is created from ``args`` and wrapped twice
    in ``stuff.NormalizedActions``; one wrapper is used for evaluation and the
    other for exploration. The function then builds the relabeling replay
    buffer, two Q-networks with their targets, a Tanh-Gaussian policy, the
    HER-wrapped SAC trainer and the batch RL algorithm, moves the algorithm to
    ``ptu.device`` and starts training.

    Args:
        args: object exposing ``session_name``, ``act_dim``, ``box_dim``,
            ``max_num_boxes`` and ``height_field_dim``.
        variant: dict with the keys ``replay_buffer_kwargs``, ``qf_kwargs``,
            ``policy_kwargs``, ``sac_trainer_kwargs`` and ``algo_kwargs``.
    """
    simulator = sim.DeepBuilderEnv(
        args.session_name,
        args.act_dim,
        args.box_dim,
        args.max_num_boxes,
        args.height_field_dim,
    )
    # Both wrappers are built around the very same simulator object.
    eval_env = stuff.NormalizedActions(simulator)
    expl_env = stuff.NormalizedActions(simulator)

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    obs_spaces = eval_env.observation_space.spaces
    obs_dim = obs_spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = obs_spaces['desired_goal'].low.size

    def build_qf():
        # Q(s, a, g): the network input is the concatenation of all three.
        return FlattenMlp(
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    eval_policy = MakeDeterministic(policy)

    sac_trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, policy, **collector_keys)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 17
0
def experiment(variant):
    """Train a goal-conditioned SAC agent with HER on the ``Reacher`` env.

    Two separate ``Reacher`` instances (evaluation and exploration) are built
    with ``reward_type='threshold'`` and a path length taken from
    ``variant['algo_kwargs']['max_path_length']``. The function then creates
    the relabeling buffer, two Q-networks with targets, a Tanh-Gaussian
    policy, the HER-wrapped SAC trainer and the batch RL algorithm, moves the
    algorithm to ``ptu.device`` and starts training.

    Args:
        variant: dict with the keys ``algo_kwargs`` (must contain
            ``max_path_length``), ``replay_buffer_kwargs``, ``qf_kwargs``,
            ``policy_kwargs`` and ``sac_trainer_kwargs``.
    """
    reacher_kwargs = dict(
        path_length=variant['algo_kwargs']['max_path_length'],
        reward_type='threshold',
    )
    eval_env = Reacher(**reacher_kwargs)
    expl_env = Reacher(**reacher_kwargs)

    observation_key = 'observation'
    # Ground-truth goals.
    desired_goal_key = 'desired_goal'
    achieved_goal_key = 'achieved_goal'

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # Observation and goal widths are hard-coded instead of being read from
    # eval_env.observation_space.spaces[<key>].low.size.
    # NOTE(review): confirm that 128 matches what the env actually emits.
    obs_dim = 128
    action_dim = eval_env.action_space.low.size
    goal_dim = 128

    def build_qf():
        # Q(s, a, g): the network input is the concatenation of all three.
        return FlattenMlp(
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    eval_policy = MakeDeterministic(policy)

    sac_trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, policy, **collector_keys)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 18
0
def skewfit_experiment(variant, other_variant):
    """Build and run an online-VAE TD3 + HER experiment.

    The variant is first handed to ``skewfit_preprocess_variant`` and
    ``get_envs``; the VAE used throughout is the one exposed by the returned
    env as ``env.vae``. The function then builds the TD3 networks, the
    VAE-aware relabeling buffer, the VAE trainer, both path collectors and an
    ``OnlineVaeAlgorithm``, and finally starts training.

    Args:
        variant: experiment configuration dict. Keys read here:
            ``generate_uniform_dataset_fn`` and
            ``generate_uniform_dataset_kwargs`` (optional pair),
            ``observation_key``, ``desired_goal_key`` and ``hidden_sizes``
            (optional, with defaults), ``replay_buffer_kwargs``,
            ``vae_train_data``, ``vae_test_data``,
            ``online_vae_trainer_kwargs``, ``max_path_length``,
            ``td3_trainer_kwargs``, ``evaluation_goal_sampling_mode``,
            ``exploration_goal_sampling_mode``, ``algo_kwargs`` and
            ``custom_goal_sampler``. Must not contain
            ``vae_training_schedule``.
        other_variant: forwarded unchanged to ``OnlineVaeAlgorithm`` as
            ``automatic_policy_schedule``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    # Optional user-supplied factory for a uniform dataset.
    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    uniform_dataset = (
        uniform_dataset_fn(**variant['generate_uniform_dataset_kwargs'])
        if uniform_dataset_fn
        else None
    )

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_spaces = env.observation_space.spaces
    # "obs_dim" already includes the goal: networks see [observation, goal].
    obs_dim = (
        obs_spaces[observation_key].low.size
        + obs_spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def build_qf():
        return FlattenMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    def build_policy():
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=hidden_sizes,
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()
    policy = build_policy()
    target_policy = build_policy()

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    vae_trainer = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        vae,
        **variant['online_vae_trainer_kwargs']
    )
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    td3_trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    trainer = HERTrainer(td3_trainer)

    def build_collector(goal_sampling_mode):
        # Evaluation and exploration differ only in the goal sampling mode.
        return VAEWrappedEnvPathCollector(
            goal_sampling_mode,
            env,
            policy,
            max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )

    eval_path_collector = build_collector(
        variant['evaluation_goal_sampling_mode'])
    expl_path_collector = build_collector(
        variant['exploration_goal_sampling_mode'])

    algorithm = OnlineVaeAlgorithm(
        automatic_policy_schedule=other_variant,
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    # Optionally let the env draw its goals from the replay buffer.
    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Exemplo n.º 19
0
def encoder_wrapped_td3bc_experiment(variant):
    """Run TD3+BC with HER on an image env wrapped by a pretrained encoder.

    Steps, in order:

    1. Build the encoder model and load its weights from
       ``variant['model_path']``.
    2. Load a demonstration trajectory, encode its last and first image
       observations into a goal latent and an initial latent (the
       preprocessed images are also written to ``demos/goal.png`` and
       ``demos/initial.png``).
    3. Wrap the environment in ``ImageEnv`` and ``EncoderWrappedEnv``.
    4. Build Q-functions, policy (CNN or MLP), replay/demo buffers, the
       ``TD3BCTrainer`` wrapped in ``HERTrainer`` and the batch RL algorithm.
    5. Load demos, pretrain policy and Q-functions, then train.

    Args:
        variant: experiment configuration dict. Keys read here:
            ``model_class`` (optional, defaults to
            ``TimestepPredictionModel``), ``model_kwargs``, ``model_path``,
            ``desired_trajectory``, ``reward_params_type``,
            ``config_params`` (optional), ``env_class``, ``env_kwargs``,
            ``encoder_wrapped_env_kwargs`` (optional), ``observation_key``
            (optional), ``exploration_kwargs``, ``qf_kwargs``, either
            ``cnn_params`` or ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``trainer_kwargs``, ``algo_kwargs`` and ``save_video``
            (optional, defaults to True).
    """
    representation_size = 128
    output_classes = 20

    model_class = variant.get('model_class', TimestepPredictionModel)
    model = model_class(
        representation_size,
        output_classes=output_classes,
        **variant['model_kwargs'],
    )

    model_path = variant.get("model_path")
    # map_location lets a checkpoint that was saved on a GPU be restored on
    # whatever device this run uses (including CPU-only hosts).
    state_dict = torch.load(model_path, map_location=ptu.device)
    model.load_state_dict(state_dict)
    model.to(ptu.device)
    model.eval()

    # SECURITY: allow_pickle=True unpickles arbitrary objects; only point
    # 'desired_trajectory' at demo files from a trusted source.
    traj = np.load(variant.get("desired_trajectory"), allow_pickle=True)[0]

    def encode_frame(flat_image, png_path):
        """Preprocess one demo image, save it as PNG and return its latent."""
        # Reshape to (1, 3, 500, 300), swap the last two axes and divide by
        # 255.0. NOTE(review): per the original comments, demos that are not
        # ImageEnv-wrapped (e.g. raw RLBench) would instead need
        # reshape(1, 300, 500, 3).transpose([0, 3, 1, 2]).
        image = flat_image.reshape(1, 3, 500, 300)
        image = image.transpose([0, 1, 3, 2]) / 255.0
        # Crop: drop the first 60 rows and keep columns 60..499.
        image = image[:, :, 60:, 60:500]
        image_pt = ptu.from_numpy(image)
        save_image(image_pt.data.cpu(), png_path, nrow=1)
        return model.encode(image_pt).detach().cpu().numpy().flatten()

    goal_latent = encode_frame(
        traj["observations"][-1]["image_observation"], 'demos/goal.png')
    initial_latent = encode_frame(
        traj["observations"][0]["image_observation"], 'demos/initial.png')

    # Move these to td3_bc and bc_v3 (or at least type for reward_params)
    reward_params = dict(
        goal_latent=goal_latent,
        initial_latent=initial_latent,
        type=variant["reward_params_type"],
    )

    config_params = variant.get("config_params")

    env = variant['env_class'](**variant['env_kwargs'])
    env = ImageEnv(
        env,
        recompute_reward=False,
        transpose=True,
        image_length=450000,  # == 3 * 500 * 300, the flat size reshaped above
        reward_type="image_distance",
    )
    env = EncoderWrappedEnv(
        env,
        model,
        reward_params,
        config_params,
        **variant.get("encoder_wrapped_env_kwargs", dict())
    )

    # Exploration and evaluation share the single wrapped env instance.
    expl_env = env
    eval_env = env

    # One of 'state_observation', 'latent_observation', 'concat_observation'.
    observation_key = variant.get("observation_key", 'state_observation')
    desired_goal_key = 'latent_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        **variant["exploration_kwargs"],
    )
    obs_dim = expl_env.observation_space.spaces[observation_key].low.size
    goal_dim = expl_env.observation_space.spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    def build_qf():
        # Q(s, g, a): the network input is the concatenation of all three.
        return ConcatMlp(
            input_size=obs_dim + goal_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()

    # A CNN policy/target policy is used when 'cnn_params' is supplied in the
    # variant; otherwise both default to TanhMlpPolicy.
    if 'cnn_params' in variant:
        imsize = 48

        def build_policy():
            return CNNPolicy(
                input_width=imsize,
                input_height=imsize,
                output_size=action_dim,
                input_channels=3,
                **variant['cnn_params'],
                output_activation=torch.tanh,
            )
    else:
        def build_policy():
            return TanhMlpPolicy(
                input_size=obs_dim + goal_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )

    policy = build_policy()
    target_policy = build_policy()

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    def build_buffer(buffer_env):
        # The online buffer and both demo buffers share one configuration.
        return ObsDictRelabelingBuffer(
            env=buffer_env,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            **variant['replay_buffer_kwargs']
        )

    replay_buffer = build_buffer(eval_env)
    demo_train_buffer = build_buffer(env)
    demo_test_buffer = build_buffer(env)

    td3bc_trainer = TD3BCTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        replay_buffer=replay_buffer,
        demo_train_buffer=demo_train_buffer,
        demo_test_buffer=demo_test_buffer,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    trainer = HERTrainer(td3bc_trainer)

    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        video_func = VideoSaveFunction(
            env,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)

    # Demonstrations are loaded and used for pretraining before RL starts.
    td3bc_trainer.load_demos()

    td3bc_trainer.pretrain_policy_with_bc()
    td3bc_trainer.pretrain_q_with_bc_data()

    algorithm.train()
Exemplo n.º 20
0
def grill_her_sac_experiment(variant):
    """Assemble (but do not run) a HER + SAC algorithm for a GRILL variant.

    The full variant is passed through ``full_experiment_variant_preprocess``;
    everything afterwards reads from its ``'grill_variant'`` sub-dict, which
    is itself passed through ``grill_preprocess_variant``.

    Args:
        variant: full experiment dict containing a ``'grill_variant'`` entry.
            Keys read from that entry here: ``observation_key`` and
            ``desired_goal_key`` (optional, with defaults), ``qf_kwargs``,
            ``policy_kwargs``, ``replay_buffer_kwargs``,
            ``sac_trainer_kwargs`` and ``algo_kwargs``.

    Returns:
        The constructed ``TorchBatchRLAlgorithm``. It is neither moved to a
        device nor trained here; that is left to the caller.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.her.her import HERTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.samplers.data_collector import GoalConditionedPathCollector
    from rlkit.torch.grill.launcher import (grill_preprocess_variant, get_envs,
                                            get_exploration_strategy)

    full_experiment_variant_preprocess(variant)
    grill_variant = variant['grill_variant']
    grill_preprocess_variant(grill_variant)

    env = get_envs(grill_variant)
    es = get_exploration_strategy(grill_variant, env)

    observation_key = grill_variant.get(
        'observation_key', 'latent_observation')
    desired_goal_key = grill_variant.get(
        'desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_spaces = env.observation_space.spaces
    # "obs_dim" already includes the goal: networks see [observation, goal].
    obs_dim = (obs_spaces[observation_key].low.size
               + obs_spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    def build_qf():
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **grill_variant['qf_kwargs']
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **grill_variant['policy_kwargs']
    )
    eval_policy = MakeDeterministic(policy)
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        env, exploration_policy, **collector_keys)

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **grill_variant['replay_buffer_kwargs']
    )

    sac_trainer = SACTrainer(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **grill_variant['sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **grill_variant['algo_kwargs']
    )
    return algorithm
Exemplo n.º 21
0
def _disentangled_her_twin_sac_experiment_v2(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        latent_dim=2,
        save_video_kwargs=None,
        **kwargs
):
    """Run HER + SAC with disentangled Q-functions sharing a goal encoder.

    A single ``ConcatMlp`` encoder maps goals to ``latent_dim`` outputs. The
    two online Q-functions receive the encoder directly; the two target
    Q-functions receive it wrapped in ``Detach``. When ``save_video`` is
    true, evaluation and training video functions are appended to the
    algorithm's ``post_train_funcs`` before training starts.

    Args:
        max_path_length: passed to both path collectors, the algorithm and
            the video rollout function.
        encoder_kwargs: extra kwargs for the goal encoder ``ConcatMlp``.
        disentangled_qf_kwargs: extra kwargs for each ``DisentangledMlpQf``.
        qf_kwargs: forwarded to each ``DisentangledMlpQf`` as ``qf_kwargs``.
        twin_sac_trainer_kwargs: extra kwargs for ``SACTrainer``.
        replay_buffer_kwargs: extra kwargs for ``ObsDictRelabelingBuffer``.
        policy_kwargs: extra kwargs for ``TanhGaussianPolicy``.
        evaluation_goal_sampling_mode: goal sampling mode of the evaluation
            path collector.
        exploration_goal_sampling_mode: goal sampling mode of the
            exploration path collector.
        algo_kwargs: extra kwargs for ``TorchBatchRLAlgorithm``.
        save_video: whether to register the video-saving functions.
        env_id: gym id of the env; takes precedence over ``env_class``.
        env_class: env constructor, used when ``env_id`` is falsy.
        env_kwargs: kwargs for ``env_class`` (defaults to ``{}``).
        observation_key: observation-dict key of the observation.
        desired_goal_key: observation-dict key of the desired goal.
        achieved_goal_key: observation-dict key of the achieved goal.
        latent_dim: encoder output size; also the number of
            ``v_vals_dim_<i>`` image keys.
        save_video_kwargs: extra kwargs for ``get_save_video_function``; its
            optional ``save_vf_heatmap`` entry (default True) toggles the
            heatmap post-processing of rollouts.
        **kwargs: accepted and ignored.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    if save_video_kwargs is None:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)
    else:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)

    obs_spaces = train_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    goal_dim = obs_spaces[desired_goal_key].low.size
    action_dim = train_env.action_space.low.size

    encoder = ConcatMlp(
        input_size=goal_dim,
        output_size=latent_dim,
        **encoder_kwargs
    )

    def build_qf(goal_encoder):
        return DisentangledMlpQf(
            encoder=goal_encoder,
            preprocess_obs_dim=obs_dim,
            action_dim=action_dim,
            qf_kwargs=qf_kwargs,
            **disentangled_qf_kwargs
        )

    # Online Q-functions use the encoder as-is; each target Q-function gets
    # its own Detach wrapper around the same encoder.
    qf1 = build_qf(encoder)
    qf2 = build_qf(encoder)
    target_qf1 = build_qf(Detach(encoder))
    target_qf2 = build_qf(Detach(encoder))

    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    goal_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        goal_sampling_mode=evaluation_goal_sampling_mode,
        **goal_keys
    )
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        policy,
        max_path_length,
        goal_sampling_mode=exploration_goal_sampling_mode,
        **goal_keys
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs,
    )
    algorithm.to(ptu.device)

    if save_video:
        save_vf_heatmap = save_video_kwargs.get('save_vf_heatmap', True)

        def v_function(obs):
            # Evaluate qf1 at the policy's own actions for the given obs.
            action = policy.get_actions(obs)
            obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
            return qf1(obs, action, return_individual_q_vals=True)

        add_heatmap = partial(
            add_heatmap_imgs_to_o_dict, v_function=v_function)
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
            **goal_keys
        )

        # One overall image key plus one per latent dimension.
        img_keys = ['v_vals']
        img_keys += [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]

        def build_video_func(video_env, video_policy, tag):
            return get_save_video_function(
                rollout_function,
                video_env,
                video_policy,
                tag=tag,
                get_extra_imgs=partial(get_extra_imgs, img_keys=img_keys),
                **save_video_kwargs
            )

        eval_video_func = build_video_func(
            eval_env, MakeDeterministic(policy), "eval")
        train_video_func = build_video_func(train_env, policy, "train")

        # NOTE(review): this decoder is created and moved to the device but
        # is not referenced by any active code below; the decoder/encoder
        # plotting hooks that presumably used it are disabled.
        decoder = ConcatMlp(
            input_size=obs_dim,
            output_size=obs_dim,
            hidden_sizes=[128, 128],
        )
        decoder.to(ptu.device)

        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)

    algorithm.train()
Exemplo n.º 22
0
def experiment(variant):
    """Train a goal-conditioned TD3 agent with HER on ``pick_and_lift``.

    Only an exploration environment is created: the algorithm is given
    ``None`` for both the evaluation env and the evaluation data collector.

    Args:
        variant: dict with the keys ``replay_buffer_kwargs``, ``qf_kwargs``,
            ``policy_kwargs``, ``td3_trainer_kwargs`` and ``algo_kwargs``.
    """
    expl_env = gym.make(
        'pick_and_lift-state-v0',
        sparse=True,
        not_special_p=0.5,
        ground_p=0,
        special_is_grip=True,
        img_size=256,
        force_randomly_place=False,
        force_change_position=False,
    )

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = "achieved_goal"

    replay_buffer = ObsDictRelabelingBuffer(
        env=expl_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    obs_spaces = expl_env.observation_space.spaces
    obs_dim = obs_spaces['observation'].low.size
    action_dim = expl_env.action_space.low.size
    goal_dim = obs_spaces['desired_goal'].low.size

    def build_qf():
        # Q(s, a, g): the network input is the concatenation of all three.
        return FlattenMlp(
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    def build_policy():
        return TanhMlpPolicy(
            input_size=obs_dim + goal_dim,
            output_size=action_dim,
            **variant['policy_kwargs']
        )

    # Instantiated one after another, in this exact order.
    qf1 = build_qf()
    qf2 = build_qf()
    target_qf1 = build_qf()
    target_qf2 = build_qf()
    policy = build_policy()
    target_policy = build_policy()

    # NOTE(review): the original comment on decay_period said "Constant
    # sigma", but max_sigma and min_sigma differ here — confirm the intended
    # noise schedule against GaussianStrategy.
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.3,
        min_sigma=0.1,
        decay_period=1000000
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    td3_trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    trainer = HERTrainer(td3_trainer)

    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    # No evaluation env or collector is configured for this experiment.
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=None,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 23
0
def her_sac_experiment(
        max_path_length,
        qf_kwargs,
        twin_sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        save_video=True,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        observation_key='state_observation',
        desired_goal_key='state_desired_goal',
        achieved_goal_key='state_achieved_goal',
        # Video parameters
        save_video_kwargs=None,
        exploration_policy_kwargs=None,
        **kwargs
):
    """Run a goal-conditioned SAC experiment with HER relabeling.

    Two environments (training and evaluation) are created either from
    ``env_id`` via ``gym.make`` (after registering the multiworld envs) or
    by calling ``env_class(**env_kwargs)``. Twin Q-functions, their targets
    and a ``TanhGaussianPolicy`` are sized from the observation and
    desired-goal spaces; the ``SACTrainer`` is wrapped in ``HERTrainer`` and
    driven by ``TorchBatchRLAlgorithm``. When ``save_video`` is true, eval
    and exploration video callbacks are appended to
    ``algorithm.post_train_funcs`` before training starts.

    At least one of ``env_id`` / ``env_class`` must be given. Extra
    ``**kwargs`` are accepted and not used by this function.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    # Normalise the optional dict arguments.
    exploration_policy_kwargs = (
        {} if exploration_policy_kwargs is None else exploration_policy_kwargs
    )
    save_video_kwargs = save_video_kwargs or {}
    env_kwargs = {} if env_kwargs is None else env_kwargs

    assert env_id or env_class
    if not env_id:
        eval_env = env_class(**env_kwargs)
        train_env = env_class(**env_kwargs)
    else:
        import gym
        import multiworld
        multiworld.register_all_envs()
        train_env = gym.make(env_id)
        eval_env = gym.make(env_id)

    # The networks consume the observation concatenated with the goal.
    obs_spaces = train_env.observation_space.spaces
    obs_dim = sum(
        obs_spaces[key].low.size
        for key in (observation_key, desired_goal_key)
    )
    action_dim = train_env.action_space.low.size

    def new_critic():
        """Build one Q-network over (observation + goal, action)."""
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **qf_kwargs
        )

    # Construction order is kept: qf1, qf2, target_qf1, target_qf2, policy.
    qf1 = new_critic()
    qf2 = new_critic()
    target_qf1 = new_critic()
    target_qf2 = new_critic()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim,
        action_dim=action_dim,
        **policy_kwargs
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **replay_buffer_kwargs
    )
    sac_trainer = SACTrainer(
        env=train_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **twin_sac_trainer_kwargs
    )
    trainer = HERTrainer(sac_trainer)

    # Keyword arguments shared by both path collectors and the rollout fn.
    key_kwargs = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        goal_sampling_mode=evaluation_goal_sampling_mode,
        **key_kwargs
    )
    exploration_policy = create_exploration_policy(
        train_env, policy, **exploration_policy_kwargs)
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        exploration_policy,
        max_path_length,
        goal_sampling_mode=exploration_goal_sampling_mode,
        **key_kwargs
    )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs
    )
    algorithm.to(ptu.device)

    if save_video:
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            return_dict_obs=True,
            **key_kwargs
        )
        eval_video_func = get_save_video_function(
            rollout_function,
            eval_env,
            MakeDeterministic(policy),
            tag="eval",
            **save_video_kwargs
        )
        train_video_func = get_save_video_function(
            rollout_function,
            train_env,
            exploration_policy,
            tag="expl",
            **save_video_kwargs
        )
        # Evaluation video is registered before the exploration one.
        for video_func in (eval_video_func, train_video_func):
            algorithm.post_train_funcs.append(video_func)

    algorithm.train()
Exemplo n.º 24
0
def _use_disentangled_encoder_distance(
        max_path_length,
        encoder_kwargs,
        disentangled_qf_kwargs,
        qf_kwargs,
        sac_trainer_kwargs,
        replay_buffer_kwargs,
        policy_kwargs,
        evaluation_goal_sampling_mode,
        exploration_goal_sampling_mode,
        algo_kwargs,
        env_id=None,
        env_class=None,
        env_kwargs=None,
        encoder_key_prefix='encoder',
        encoder_input_prefix='state',
        latent_dim=2,
        reward_mode=EncoderWrappedEnv.ENCODER_DISTANCE_REWARD,
        # Video parameters
        save_video=True,
        save_video_kwargs=None,
        save_vf_heatmap=True,
        **kwargs):
    """Train HER + SAC on environments wrapped with an encoder.

    Raw train/eval environments are created from ``env_id`` (via
    ``gym.make`` after registering the multiworld envs) or from
    ``env_class(**env_kwargs)``; at least one of the two must be given.
    Both are wrapped in ``EncoderWrappedEnv`` and the observation /
    desired-goal / achieved-goal keys are derived from
    ``encoder_key_prefix``. Four ``DisentangledMlpQf`` networks and a
    ``TanhGaussianPolicy`` are trained by a ``SACTrainer`` wrapped in
    ``HERTrainer`` and driven by ``TorchBatchRLAlgorithm``.

    When ``save_video`` is true, eval and train video callbacks are
    appended to ``algorithm.post_train_funcs``; ``save_vf_heatmap``
    controls whether the heatmap post-processing function is passed to the
    rollout function. Extra ``**kwargs`` are accepted and not used here.
    """
    if save_video_kwargs is None:
        save_video_kwargs = {}
    if env_kwargs is None:
        env_kwargs = {}
    assert env_id or env_class
    # Forwarded to the Q-functions, the replay buffer and the heatmap helper.
    vectorized = (
        reward_mode == EncoderWrappedEnv.VECTORIZED_ENCODER_DISTANCE_REWARD)

    if env_id:
        import gym
        import multiworld
        multiworld.register_all_envs()
        raw_train_env = gym.make(env_id)
        raw_eval_env = gym.make(env_id)
    else:
        raw_eval_env = env_class(**env_kwargs)
        raw_train_env = env_class(**env_kwargs)

    # The sampling modes are set as attributes on the raw (unwrapped) envs.
    raw_train_env.goal_sampling_mode = exploration_goal_sampling_mode
    raw_eval_env.goal_sampling_mode = evaluation_goal_sampling_mode

    raw_obs_dim = (
        raw_train_env.observation_space.spaces['state_observation'].low.size)
    action_dim = raw_train_env.action_space.low.size

    # NOTE(review): this ConcatMlp is constructed and then immediately
    # replaced by Identity() on the next statement, so encoder_kwargs and
    # latent_dim do not affect the encoder that is actually used. Confirm
    # whether the Identity override is an intentional ablation or a leftover.
    encoder = ConcatMlp(input_size=raw_obs_dim,
                        output_size=latent_dim,
                        **encoder_kwargs)
    encoder = Identity()
    # Size attributes are attached to the Identity module by hand.
    encoder.input_size = raw_obs_dim
    encoder.output_size = raw_obs_dim

    # The same np_encoder instance is shared by the train and eval wrappers.
    np_encoder = EncoderFromNetwork(encoder)
    train_env = EncoderWrappedEnv(
        raw_train_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    eval_env = EncoderWrappedEnv(
        raw_eval_env,
        np_encoder,
        encoder_input_prefix,
        key_prefix=encoder_key_prefix,
        reward_mode=reward_mode,
    )
    # Dict-observation keys looked up on the wrapped envs' observation space.
    observation_key = '{}_observation'.format(encoder_key_prefix)
    desired_goal_key = '{}_desired_goal'.format(encoder_key_prefix)
    achieved_goal_key = '{}_achieved_goal'.format(encoder_key_prefix)
    obs_dim = train_env.observation_space.spaces[observation_key].low.size
    goal_dim = train_env.observation_space.spaces[desired_goal_key].low.size

    def make_qf():
        """Build one DisentangledMlpQf around the shared encoder."""
        return DisentangledMlpQf(encoder=encoder,
                                 preprocess_obs_dim=obs_dim,
                                 action_dim=action_dim,
                                 qf_kwargs=qf_kwargs,
                                 vectorized=vectorized,
                                 **disentangled_qf_kwargs)

    # Twin critics and their target copies; all four share `encoder`.
    qf1 = make_qf()
    qf2 = make_qf()
    target_qf1 = make_qf()
    target_qf2 = make_qf()

    # The policy input is the observation concatenated with the goal.
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **policy_kwargs)

    replay_buffer = ObsDictRelabelingBuffer(
        env=train_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        vectorized=vectorized,
        **replay_buffer_kwargs)
    sac_trainer = SACTrainer(env=train_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **sac_trainer_kwargs)
    trainer = HERTrainer(sac_trainer)

    # Both collectors are given goal_sampling_mode='env'; the configured
    # exploration/evaluation modes were assigned to the raw envs above.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        MakeDeterministic(policy),
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    # Exploration uses the stochastic policy directly (no MakeDeterministic).
    expl_path_collector = GoalConditionedPathCollector(
        train_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        goal_sampling_mode='env',
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=train_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **algo_kwargs)
    algorithm.to(ptu.device)

    if save_video:

        def v_function(obs):
            """Evaluate qf1 at the policy's own actions for ``obs``."""
            action = policy.get_actions(obs)
            obs, action = ptu.from_numpy(obs), ptu.from_numpy(action)
            return qf1(obs, action, return_individual_q_vals=True)

        add_heatmap = partial(
            add_heatmap_imgs_to_o_dict,
            v_function=v_function,
            vectorized=vectorized,
        )
        rollout_function = rf.create_rollout_function(
            rf.multitask_rollout,
            max_path_length=max_path_length,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            full_o_postprocess_func=add_heatmap if save_vf_heatmap else None,
        )
        # NOTE(review): per-dimension keys are generated for
        # range(latent_dim), while the encoder's output_size was set to
        # raw_obs_dim above — confirm these are meant to agree.
        img_keys = ['v_vals'] + [
            'v_vals_dim_{}'.format(dim) for dim in range(latent_dim)
        ]
        eval_video_func = get_save_video_function(rollout_function,
                                                  eval_env,
                                                  MakeDeterministic(policy),
                                                  get_extra_imgs=partial(
                                                      get_extra_imgs,
                                                      img_keys=img_keys),
                                                  tag="eval",
                                                  **save_video_kwargs)
        train_video_func = get_save_video_function(rollout_function,
                                                   train_env,
                                                   policy,
                                                   get_extra_imgs=partial(
                                                       get_extra_imgs,
                                                       img_keys=img_keys),
                                                   tag="train",
                                                   **save_video_kwargs)
        algorithm.post_train_funcs.append(eval_video_func)
        algorithm.post_train_funcs.append(train_video_func)
    algorithm.train()
Exemplo n.º 25
0
def relabeling_tsac_experiment(variant):
    """Run twin-SAC with HER relabeling, configured by ``variant``.

    Keys read from ``variant``:

    * ``env_id`` (optional) -- if set, both envs come from ``gym.make``;
      otherwise ``env_class`` is called with ``env_kwargs``.
    * ``observation_key`` / ``desired_goal_key`` -- dict-observation keys.
      The achieved-goal key is derived by replacing ``"desired"`` with
      ``"achieved"`` in ``desired_goal_key``.
    * ``replay_buffer_kwargs``, ``qf_kwargs``, ``policy_kwargs``,
      ``twin_sac_trainer_kwargs``, ``algo_kwargs`` -- forwarded to the
      corresponding constructors.
    * ``max_path_length`` -- passed to both path collectors.

    Raises:
        NotImplementedError: if ``presample_goals`` is present in
            ``variant`` or ``normalize`` is truthy.
    """
    if 'presample_goals' in variant:
        raise NotImplementedError()
    # Use .get so an explicit ``env_id=None`` falls back to ``env_class``
    # instead of being handed to gym.make.
    if variant.get('env_id'):
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env = variant['env_class'](**variant['env_kwargs'])
        expl_env = variant['env_class'](**variant['env_kwargs'])

    observation_key = variant['observation_key']
    desired_goal_key = variant['desired_goal_key']
    if variant.get('normalize', False):
        raise NotImplementedError()

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    # Size the networks from the *configured* keys. The buffer and the path
    # collectors already use these keys, so reading fixed 'observation' /
    # 'desired_goal' entries here would either raise KeyError or build
    # networks whose input size disagrees with the data they are fed.
    obs_spaces = eval_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = obs_spaces[desired_goal_key].low.size

    qf1 = ConcatMlp(input_size=obs_dim + action_dim + goal_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim + goal_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim + goal_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim + goal_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhGaussianPolicy(obs_dim=obs_dim + goal_dim,
                                action_dim=action_dim,
                                **variant['policy_kwargs'])
    max_path_length = variant['max_path_length']
    # Evaluation uses the deterministic wrapper; exploration samples from
    # the stochastic policy itself.
    eval_policy = MakeDeterministic(policy)
    trainer = SACTrainer(env=eval_env,
                         policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         **variant['twin_sac_trainer_kwargs'])
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        eval_policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])
    # NOTE: video saving is not wired up in this launcher; a
    # variant["save_video"] entry has no effect here.
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 26
0
def td3_experiment_offpolicy_online_vae(variant):
    """Run TD3 + HER on a VAE-wrapped env while training the VAE online.

    ``variant`` is first passed through ``preprocess_rl_variant`` and then
    to ``get_envs``; the returned env is used for both exploration and
    evaluation and must expose a ``vae`` attribute. Critics, policies, the
    relabeling buffer and the VAE trainer are assembled from ``variant``
    and handed to ``OnlineVaeOffpolicyAlgorithm``, which is pretrained and
    then trained.

    ``variant`` must not contain ``vae_training_schedule`` (it belongs in
    ``algo_kwargs``) and must contain ``custom_goal_sampler``.
    """
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer
    from rlkit.torch.td3.td3 import TD3
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from rlkit.exploration_strategies.gaussian_and_epislon import \
        GaussianAndEpislonStrategy
    from rlkit.torch.vae.online_vae_offpolicy_algorithm import OnlineVaeOffpolicyAlgorithm

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    # Optional dataset built by a user-supplied factory.
    dataset_factory = variant.get('generate_uniform_dataset_fn', None)
    uniform_dataset = (
        dataset_factory(**variant['generate_uniform_dataset_kwargs'])
        if dataset_factory else None
    )

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Network inputs are the observation concatenated with the goal.
    obs_part = env.observation_space.spaces[observation_key].low.size
    goal_part = env.observation_space.spaces[desired_goal_key].low.size
    obs_dim = obs_part + goal_part
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])

    def new_critic():
        """Build one Q-network over (observation + goal, action)."""
        return ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=hidden_sizes,
        )

    def new_actor():
        """Build one deterministic tanh policy network."""
        return TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=hidden_sizes,
        )

    # Construction order is kept: critics, target critics, policy, target.
    qf1 = new_critic()
    qf2 = new_critic()
    target_qf1 = new_critic()
    target_qf2 = new_critic()
    policy = new_actor()
    target_policy = new_actor()

    es = GaussianAndEpislonStrategy(
        action_space=env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    # Keys shared by the replay buffer and both path collectors.
    obs_goal_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    buffer_cls = variant.get("replay_buffer_class",
                             OnlineVaeRelabelingBuffer)
    replay_buffer = buffer_cls(
        vae=env.vae,
        env=env,
        achieved_goal_key=achieved_goal_key,
        **obs_goal_keys,
        **variant['replay_buffer_kwargs']
    )
    replay_buffer.representation_size = vae.representation_size

    vae_trainer_cls = variant.get("vae_trainer_class", ConvVAETrainer)
    vae_trainer = vae_trainer_cls(
        env.vae,
        **variant['online_vae_trainer_kwargs']
    )
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    td3_trainer = TD3(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    trainer = HERTrainer(td3_trainer)

    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        **obs_goal_keys
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        expl_policy,
        max_path_length,
        **obs_goal_keys
    )

    algorithm = OnlineVaeOffpolicyAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    if variant.get("save_video", True):
        algorithm.post_train_funcs.append(VideoSaveFunction(env, variant))
    if variant['custom_goal_sampler'] == 'replay_buffer':
        # Let the env draw goals from the replay buffer.
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)

    algorithm.pretrain()
    algorithm.train()
Exemplo n.º 27
0
def td3_experiment(variant):
    """Train goal-conditioned TD3 with HER relabeling, configured by ``variant``.

    The env comes from ``get_envs(variant)`` and the exploration strategy
    from ``get_exploration_strategy(variant, env)``. Notable ``variant``
    entries:

    * ``use_masks`` -- when truthy, the env is wrapped in two
      ``RewardMaskWrapper`` instances (exploration / evaluation) built from
      ``expl_mask_distribution_kwargs`` / ``eval_mask_distribution_kwargs``.
    * ``observation_key`` / ``desired_goal_key`` / ``achieved_goal_key`` --
      dict-observation keys, defaulting to the ``latent_*`` names.
    * ``use_subgoal_policy`` -- when truthy, policy and target policy are
      wrapped in ``SubgoalPolicyWrapper``.
    * ``do_state_exp`` -- when truthy, ``GoalConditionedPathCollector`` is
      used; otherwise ``VAEWrappedEnvPathCollector`` is used with the goal
      sampling modes read from ``evaluation_goal_sampling_mode`` and
      ``exploration_goal_sampling_mode``, and ``env.vae`` is moved to the
      torch device.
    * ``save_video`` -- defaults to True; appends a video callback to
      ``algorithm.post_train_funcs``.
    * ``max_path_length``, ``qf_kwargs``, ``policy_kwargs``,
      ``replay_buffer_kwargs``, ``td3_trainer_kwargs``, ``algo_kwargs`` --
      required and forwarded to the corresponding constructors.
    """
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)

    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy

    env = get_envs(variant)
    expl_env = env
    eval_env = env
    es = get_exploration_strategy(variant, env)
    do_state_exp = variant.get("do_state_exp", False)

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        # From here on ``env`` refers to the evaluation mask wrapper.
        env = eval_env

    max_path_length = variant['max_path_length']

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')
    # Network inputs are the observation concatenated with the goal.
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)

    action_dim = env.action_space.low.size
    # Construction order (qf1, qf2, policy, targets) is kept as-is.
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])

    if variant.get("use_subgoal_policy", False):
        from rlkit.policies.timed_policy import SubgoalPolicyWrapper

        subgoal_policy_kwargs = variant.get('subgoal_policy_kwargs', {})

        policy = SubgoalPolicyWrapper(wrapped_policy=policy,
                                      env=env,
                                      episode_length=max_path_length,
                                      **subgoal_policy_kwargs)
        target_policy = SubgoalPolicyWrapper(wrapped_policy=target_policy,
                                             env=env,
                                             episode_length=max_path_length,
                                             **subgoal_policy_kwargs)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_trainer_kwargs'])
    trainer = HERTrainer(trainer)

    if do_state_exp:
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
    else:
        # The sampling modes must be looked up in ``variant``. Passing the
        # one-element list ``['evaluation_goal_sampling_mode']`` (the key
        # name wrapped in a list) hands the collector a list instead of
        # the configured mode.
        eval_path_collector = VAEWrappedEnvPathCollector(
            env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['exploration_goal_sampling_mode'],
        )

    # NOTE(review): the algorithm receives ``env`` for both roles rather
    # than ``expl_env`` / ``eval_env``; with use_masks enabled that is the
    # evaluation wrapper -- confirm this is intended.
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        if do_state_exp:
            rollout_function = rf.create_rollout_function(
                rf.multitask_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
            )
            video_func = get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        else:
            video_func = VideoSaveFunction(
                env,
                variant,
            )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not do_state_exp:
        env.vae.to(ptu.device)
    algorithm.train()
def experiment(variant):
    """Run HER + SAC on the ``RLkitGoalUR-v0`` environment.

    ``variant`` supplies ``replay_buffer_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``sac_trainer_kwargs`` and ``algo_kwargs``, each
    forwarded to the matching constructor. Separate evaluation and
    exploration environments are created; evaluation rolls out the
    deterministic wrapper of the policy and exploration the policy itself.
    """
    env_name = 'RLkitGoalUR-v0'
    # A throwaway env instance is created only to call
    # _start_ros_services(); the result is kept bound for the rest of the
    # function even though it is not referenced again.
    ros_services = gym.make(env_name)._start_ros_services()
    eval_env = gym.make(env_name)
    expl_env = gym.make(env_name)

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size

    def new_critic():
        """Build one Q-network over (observation, action, goal)."""
        return FlattenMlp(
            input_size=obs_dim + action_dim + goal_dim,
            output_size=1,
            **variant['qf_kwargs']
        )

    # Construction order is kept: qf1, qf2, target_qf1, target_qf2, policy.
    qf1 = new_critic()
    qf2 = new_critic()
    target_qf1 = new_critic()
    target_qf2 = new_critic()
    policy = TanhGaussianPolicy(
        obs_dim=obs_dim + goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    eval_policy = MakeDeterministic(policy)

    sac_trainer = SACTrainer(
        env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['sac_trainer_kwargs']
    )
    trainer = HERTrainer(sac_trainer)

    # Keys shared by both path collectors.
    collector_keys = dict(
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    eval_path_collector = GoalConditionedPathCollector(
        eval_env, eval_policy, **collector_keys)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env, policy, **collector_keys)

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemplo n.º 29
0
def state_td3bc_experiment(variant):
    """Build and run a goal-conditioned TD3 / TD3+BC experiment with HER.

    The function creates evaluation and exploration environments, an
    exploration strategy, twin Q-functions with targets, a deterministic
    policy with a target, relabeling replay buffers, a trainer wrapped in
    ``HERTrainer`` and a ``TorchBatchRLAlgorithm``; it then optionally runs
    demo loading / pretraining steps and finally trains.

    Keys read from ``variant``:
        env_id: optional; if truthy, both environments come from
            ``gym.make`` after registering the multiworld environments.
        env_class, env_kwargs: used when ``env_id`` is absent or falsy.
        eval_env_kwargs: optional override for the evaluation environment;
            defaults to ``env_kwargs``.
        es: exploration strategy name, ``'ou'`` (default) or ``'gauss_eps'``.
        exploration_noise: sigma for the ``'ou'`` strategy.
        qf_kwargs, policy_kwargs: extra keyword arguments for the networks.
        replay_buffer_kwargs: keyword arguments for the main replay buffer;
            its ``'max_size'`` entry also sizes the two demo buffers.
        td3_bc: if true (default) a ``TD3BCTrainer`` is built from
            ``td3_bc_trainer_kwargs``; otherwise a ``TD3`` trainer is built
            from ``td3_trainer_kwargs``.
        algo_kwargs: keyword arguments for ``TorchBatchRLAlgorithm``.
        save_video: if true (default) a video-saving post-train hook is
            installed, which requires ``image_env_kwargs``.
        presampled_goals: optional path loaded and stored into
            ``image_env_kwargs['presampled_goals']`` (this mutates
            ``variant``).
        load_demos, pretrain_policy, pretrain_rl: optional flags (default
            false) that trigger the corresponding trainer methods before
            training.

    Raises:
        ValueError: if ``variant['es']`` names an unknown strategy.
        KeyError: if a required ``variant`` entry is missing.
    """
    if variant.get('env_id', None):
        # Imported lazily so that class-based configurations do not need
        # gym / multiworld to be importable.
        import gym
        import multiworld
        multiworld.register_all_envs()
        eval_env = gym.make(variant['env_id'])
        expl_env = gym.make(variant['env_id'])
    else:
        eval_env_kwargs = variant.get('eval_env_kwargs', variant['env_kwargs'])
        eval_env = variant['env_class'](**eval_env_kwargs)
        expl_env = variant['env_class'](**variant['env_kwargs'])

    # Single source of truth for the observation-dict keys used by the
    # networks, the replay buffers and every path collector below.
    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    es_strat = variant.get('es', 'ou')
    if es_strat == 'ou':
        es = OUStrategy(
            action_space=expl_env.action_space,
            max_sigma=variant['exploration_noise'],
            min_sigma=variant['exploration_noise'],
        )
    elif es_strat == 'gauss_eps':
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            max_sigma=.2,
            min_sigma=.2,  # constant sigma
            epsilon=.3,
        )
    else:
        raise ValueError("invalid exploration strategy provided")

    # Size the networks from the same keys the collectors and buffers use,
    # so the input dimensions always match the data actually fed to them.
    obs_spaces = expl_env.observation_space.spaces
    obs_dim = obs_spaces[observation_key].low.size
    goal_dim = obs_spaces[desired_goal_key].low.size
    action_dim = expl_env.action_space.low.size

    qf_input_size = obs_dim + goal_dim + action_dim
    policy_input_size = obs_dim + goal_dim

    # Construction order is kept (qf1, qf2, targets, policy, target policy)
    # so that any seeded weight initialisation is unchanged.
    qf1 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=qf_input_size,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=qf_input_size,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=qf_input_size,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=policy_input_size,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=policy_input_size,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # The demo buffers only inherit the capacity of the main buffer, not
    # its other keyword arguments.
    demo_train_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size'])
    demo_test_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        max_size=variant['replay_buffer_kwargs']['max_size'],
    )

    if variant.get('td3_bc', True):
        td3_trainer = TD3BCTrainer(env=expl_env,
                                   policy=policy,
                                   qf1=qf1,
                                   qf2=qf2,
                                   replay_buffer=replay_buffer,
                                   demo_train_buffer=demo_train_buffer,
                                   demo_test_buffer=demo_test_buffer,
                                   target_qf1=target_qf1,
                                   target_qf2=target_qf2,
                                   target_policy=target_policy,
                                   **variant['td3_bc_trainer_kwargs'])
    else:
        td3_trainer = TD3(policy=policy,
                          qf1=qf1,
                          qf2=qf2,
                          target_qf1=target_qf1,
                          target_qf2=target_qf2,
                          target_policy=target_policy,
                          **variant['td3_trainer_kwargs'])
    trainer = HERTrainer(td3_trainer)

    # Evaluation uses the bare policy; exploration uses the noisy wrapper.
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs'])

    if variant.get("save_video", True):
        if variant.get("presampled_goals", None):
            # Stored back into ``variant`` because ``variant`` itself is
            # handed to VideoSaveFunction below.
            variant['image_env_kwargs'][
                'presampled_goals'] = load_local_or_remote_file(
                    variant['presampled_goals']).item()
        # The image wrappers are used for rendering only; the policies
        # still read the same state keys as during training.
        image_eval_env = ImageEnv(eval_env, **variant["image_env_kwargs"])
        image_eval_path_collector = GoalConditionedPathCollector(
            image_eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        image_expl_env = ImageEnv(expl_env, **variant["image_env_kwargs"])
        image_expl_path_collector = GoalConditionedPathCollector(
            image_expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
        )
        video_func = VideoSaveFunction(
            image_eval_env,
            variant,
            image_expl_path_collector,
            image_eval_path_collector,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    # NOTE(review): the three hooks below are called on whichever trainer
    # was built above; presumably only TD3BCTrainer provides them, so these
    # flags should be left off when ``td3_bc`` is false - verify.
    if variant.get('load_demos', False):
        td3_trainer.load_demos()
    if variant.get('pretrain_policy', False):
        td3_trainer.pretrain_policy_with_bc()
    if variant.get('pretrain_rl', False):
        td3_trainer.pretrain_q_with_bc_data()
    algorithm.train()