Example #1
def get_td3pg(evaluation_environment, parameters):
    """
    Build the TD3 networks, exploration policy, and trainer.

    :param evaluation_environment: environment used to size the observation and action spaces
    :param parameters: dict with 'hidden_sizes_qf', 'hidden_sizes_policy', and 'trainer_params'
    :return: (exploration_policy, policy, trainer)
    """
    obs_dim = evaluation_environment.observation_space.low.size
    action_dim = evaluation_environment.action_space.low.size

    hidden_sizes_qf = parameters['hidden_sizes_qf']
    hidden_sizes_policy = parameters['hidden_sizes_policy']

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes_qf,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes_policy,
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes_policy,
    )
    es = GaussianStrategy(
        action_space=evaluation_environment.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **parameters['trainer_params'])
    return exploration_policy, policy, trainer
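
For reference, a minimal sketch of the parameters dict this helper reads ('hidden_sizes_qf', 'hidden_sizes_policy', 'trainer_params'); the trainer_params keys are assumed rlkit TD3Trainer kwargs and all values are illustrative only.

# Hypothetical configuration for get_td3pg; values are illustrative.
td3pg_parameters = dict(
    hidden_sizes_qf=[256, 256],
    hidden_sizes_policy=[256, 256],
    trainer_params=dict(
        discount=0.99,  # assumed TD3Trainer kwarg
        tau=5e-3,       # assumed soft-target update rate
    ),
)
# eval_env: any environment with Box observation/action spaces
exploration_policy, policy, trainer = get_td3pg(eval_env, td3pg_parameters)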
Example #2
def get_td3_trainer(env, hidden_sizes=[256, 256], **kwargs):
    obs_dim = env.observation_space.low.size
    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    hidden_sizes=hidden_sizes)
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    hidden_sizes=hidden_sizes)
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           hidden_sizes=hidden_sizes)
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           hidden_sizes=hidden_sizes)
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           hidden_sizes=hidden_sizes)
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  hidden_sizes=hidden_sizes)

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **kwargs)  # forward trainer kwargs; hidden_sizes belongs to the networks above
    return trainer
Example #3
def experiment(variant):
    expl_env = NormalizedBoxEnv(HalfCheetahEnv())
    eval_env = NormalizedBoxEnv(HalfCheetahEnv())
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant["qf_kwargs"])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant["qf_kwargs"])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant["qf_kwargs"])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant["qf_kwargs"])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant["policy_kwargs"])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant["policy_kwargs"])
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es, policy=policy)
    eval_path_collector = MdpPathCollector(eval_env, policy)
    expl_path_collector = MdpPathCollector(expl_env, exploration_policy)
    replay_buffer = EnvReplayBuffer(variant["replay_buffer_size"], expl_env)
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant["trainer_kwargs"])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant["algorithm_kwargs"])
    algorithm.to(ptu.device)
    algorithm.train()
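
A hedged sketch of the variant this experiment expects; the algorithm_kwargs mirror the arguments passed explicitly to TorchBatchRLAlgorithm in Example #4 below, the trainer_kwargs key is an assumed TD3Trainer kwarg, and all values are illustrative.

# Hypothetical variant for the HalfCheetah TD3 experiment above.
variant = dict(
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    trainer_kwargs=dict(discount=0.99),
    replay_buffer_size=int(1e6),
    algorithm_kwargs=dict(
        num_epochs=1000,
        num_eval_steps_per_epoch=1000,
        num_trains_per_train_loop=1000,
        num_expl_steps_per_train_loop=1000,
        min_num_steps_before_training=1000,
        max_path_length=1000,
        batch_size=256,
    ),
)
experiment(variant)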
Example #4
def run_rlkit(env, seed, log_dir):
    """
    Create rlkit model and training.

    :param seed: Random seed for the trial.
    :param log_dir: Log dir path.
    :return result csv file
    """
    reset_execution_environment()
    gt.reset()
    setup_logger(log_dir=log_dir)

    expl_env = NormalizedBoxEnv(env)
    eval_env = NormalizedBoxEnv(env)
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=params['qf_hidden_sizes'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     hidden_sizes=params['qf_hidden_sizes'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=params['qf_hidden_sizes'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            hidden_sizes=params['qf_hidden_sizes'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           hidden_sizes=params['policy_hidden_sizes'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  hidden_sizes=params['policy_hidden_sizes'])
    es = RLkitGaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=params['sigma'],
        min_sigma=params['sigma'],
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(
        params['replay_buffer_size'],
        expl_env,
    )
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         discount=params['discount'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        num_epochs=params['n_epochs'],
        num_train_loops_per_epoch=params['steps_per_epoch'],
        num_trains_per_train_loop=params['n_train_steps'],
        num_expl_steps_per_train_loop=params['n_rollout_steps'],
        num_eval_steps_per_epoch=params['n_rollout_steps'],
        min_num_steps_before_training=params['min_buffer_size'],
        max_path_length=params['n_rollout_steps'],
        batch_size=params['buffer_batch_size'],
    )
    algorithm.to(ptu.device)
    algorithm.train()
    return osp.join(log_dir, 'progress.csv')
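
The function above references a module-level params dict that this snippet does not define; a hedged sketch containing only the keys it reads, with illustrative values:

# Hypothetical module-level params used by run_rlkit.
params = dict(
    qf_hidden_sizes=[256, 256],
    policy_hidden_sizes=[256, 256],
    sigma=0.1,
    replay_buffer_size=int(1e6),
    discount=0.99,
    n_epochs=500,
    steps_per_epoch=1,
    n_train_steps=1000,
    n_rollout_steps=1000,
    min_buffer_size=10000,
    buffer_batch_size=256,
)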
Example #5
def experiment(variant):
    dummy_env = make_env(variant['env'])
    obs_dim = dummy_env.observation_space.low.size
    action_dim = dummy_env.action_space.low.size
    expl_env = VectorEnv([
        lambda: make_env(variant['env'])
        for _ in range(variant['expl_env_num'])
    ])
    expl_env.seed(variant["seed"])
    expl_env.action_space.seed(variant["seed"])
    eval_env = SubprocVectorEnv([
        lambda: make_env(variant['env'])
        for _ in range(variant['eval_env_num'])
    ])
    eval_env.seed(variant["seed"])

    M = variant['layer_size']
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[M, M],
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[M, M],
    )
    es = GaussianStrategy(
        action_space=dummy_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = VecMdpPathCollector(
        eval_env,
        policy,
    )
    expl_path_collector = VecMdpStepCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = TorchReplayBuffer(
        variant['replay_buffer_size'],
        dummy_env,
    )
    trainer = TD3Trainer(
        policy=policy,
        target_policy=target_policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        **variant['trainer_kwargs'],
    )
    algorithm = TorchVecOnlineRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'],
    )
    algorithm.to(ptu.device)
    algorithm.train()
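
A hedged sketch of the variant driving this vectorized-env version; every top-level key shown is read by the function above, while the gym id passed to make_env and the trainer_kwargs/algorithm_kwargs contents are assumptions with illustrative values.

# Hypothetical variant for the vectorized TD3 experiment above.
variant = dict(
    env='HalfCheetah-v2',   # assumes make_env accepts a gym id
    seed=0,
    expl_env_num=4,
    eval_env_num=2,
    layer_size=256,
    replay_buffer_size=int(1e6),
    trainer_kwargs=dict(discount=0.99),
    algorithm_kwargs=dict(num_epochs=1000, max_path_length=1000, batch_size=256),
)
experiment(variant)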
Example #6
def td3_experiment(variant):
    import rlkit.samplers.rollout_functions as rf
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from rlkit.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)

    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm

    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    # preprocess_rl_variant(variant)
    env = get_envs(variant)
    expl_env = env
    eval_env = env
    es = get_exploration_strategy(variant, env)

    if variant.get("use_masks", False):
        mask_wrapper_kwargs = variant.get("mask_wrapper_kwargs", dict())

        expl_mask_distribution_kwargs = variant[
            "expl_mask_distribution_kwargs"]
        expl_mask_distribution = DiscreteDistribution(
            **expl_mask_distribution_kwargs)
        expl_env = RewardMaskWrapper(env, expl_mask_distribution,
                                     **mask_wrapper_kwargs)

        eval_mask_distribution_kwargs = variant[
            "eval_mask_distribution_kwargs"]
        eval_mask_distribution = DiscreteDistribution(
            **eval_mask_distribution_kwargs)
        eval_env = RewardMaskWrapper(env, eval_mask_distribution,
                                     **mask_wrapper_kwargs)
        env = eval_env

    max_path_length = variant['max_path_length']

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key',
                                    'latent_achieved_goal')
    # achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size +
               env.observation_space.spaces[desired_goal_key].low.size)

    action_dim = env.action_space.low.size
    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])

    if variant.get("use_subgoal_policy", False):
        from rlkit.policies.timed_policy import SubgoalPolicyWrapper

        subgoal_policy_kwargs = variant.get('subgoal_policy_kwargs', {})

        policy = SubgoalPolicyWrapper(wrapped_policy=policy,
                                      env=env,
                                      episode_length=max_path_length,
                                      **subgoal_policy_kwargs)
        target_policy = SubgoalPolicyWrapper(wrapped_policy=target_policy,
                                             env=env,
                                             episode_length=max_path_length,
                                             **subgoal_policy_kwargs)

    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        # use_masks=variant.get("use_masks", False),
        **variant['replay_buffer_kwargs'])

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['td3_trainer_kwargs'])
    # if variant.get("use_masks", False):
    #     from rlkit.torch.her.her import MaskedHERTrainer
    #     trainer = MaskedHERTrainer(trainer)
    # else:
    trainer = HERTrainer(trainer)
    if variant.get("do_state_exp", False):
        eval_path_collector = GoalConditionedPathCollector(
            eval_env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
            # full_mask=True,
        )
        expl_path_collector = GoalConditionedPathCollector(
            expl_env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            # use_masks=variant.get("use_masks", False),
        )
    else:
        eval_path_collector = VAEWrappedEnvPathCollector(
            env,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['evaluation_goal_sampling_mode'],
        )
        expl_path_collector = VAEWrappedEnvPathCollector(
            env,
            expl_policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            goal_sampling_mode=variant['exploration_goal_sampling_mode'],
        )

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs'])

    vis_variant = variant.get('vis_kwargs', {})
    vis_list = vis_variant.get('vis_list', [])
    if variant.get("save_video", True):
        if variant.get("do_state_exp", False):
            rollout_function = rf.create_rollout_function(
                rf.multitask_rollout,
                max_path_length=max_path_length,
                observation_key=observation_key,
                desired_goal_key=desired_goal_key,
                # use_masks=variant.get("use_masks", False),
                # full_mask=True,
                # vis_list=vis_list,
            )
            video_func = get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        else:
            video_func = VideoSaveFunction(
                env,
                variant,
            )
        algorithm.post_train_funcs.append(video_func)

    algorithm.to(ptu.device)
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)
    algorithm.train()
Example #7
def experiment(variant):
    img_size = 64
    train_top10 = VisualRandomizationConfig(
        image_directory='./experiment_textures/train/top10',
        whitelist=[
            'Floor', 'Roof', 'Wall1', 'Wall2', 'Wall3', 'Wall4',
            'diningTable_visible'
        ],
        apply_arm=False,
        apply_gripper=False,
        apply_floor=True)
    expl_env = gym.make('reach_target_easy-vision-v0',
                        sparse=False,
                        img_size=img_size,
                        force_randomly_place=True,
                        force_change_position=False,
                        blank=True)
    expl_env = wrappers.FlattenDictWrapper(expl_env, dict_keys=['observation'])
    t_fn = variant["t_fn"]
    expl_env = TransformObservationWrapper(expl_env, t_fn)
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    conv_args = {
        "input_width": 64,
        "input_height": 64,
        "input_channels": 3,
        "kernel_sizes": [4, 4, 3],
        "n_channels": [32, 64, 64],
        "strides": [2, 1, 1],
        "paddings": [0, 0, 0],
        "hidden_sizes": [1024, 512],
        "batch_norm_conv": False,
        "batch_norm_fc": False,
        'init_w': 1e-4,
        "hidden_init": nn.init.orthogonal_,
        "hidden_activation": nn.ReLU(),
    }

    qf1 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    qf2 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    target_qf1 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    target_qf2 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    policy = TanhCNNPolicy(output_size=action_dim,
                           **variant['policy_kwargs'],
                           **conv_args)
    target_policy = TanhCNNPolicy(output_size=action_dim,
                                  **variant['policy_kwargs'],
                                  **conv_args)
    # es = GaussianStrategy(
    #     action_space=expl_env.action_space,
    #     max_sigma=0.3,
    #     min_sigma=0.1,  # Constant sigma
    # )

    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        epsilon=0.3,
        max_sigma=0.0,
        min_sigma=0.0,  #constant sigma 0
        decay_period=1000000)

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=None,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example #8
def experiment(variant, agent="SAC"):

    # Make sure agent is a valid choice
    assert agent in AGENTS, "Invalid agent selected. Selected: {}. Valid options: {}".format(
        agent, AGENTS)

    # Get environment configs for expl and eval envs and create the appropriate envs
    # suites[0] is expl and suites[1] is eval
    suites = []
    for env_config in (variant["expl_environment_kwargs"],
                       variant["eval_environment_kwargs"]):
        # Load controller
        controller = env_config.pop("controller")
        if controller in set(ALL_CONTROLLERS):
            # This is a default controller
            controller_config = load_controller_config(
                default_controller=controller)
        else:
            # This is a string to the custom controller
            controller_config = load_controller_config(custom_fpath=controller)
        # Create robosuite env and append to our list
        suites.append(
            suite.make(
                **env_config,
                has_renderer=False,
                has_offscreen_renderer=False,
                use_object_obs=True,
                use_camera_obs=False,
                reward_shaping=True,
                controller_configs=controller_config,
            ))
    # Create gym-compatible envs
    expl_env = NormalizedBoxEnv(GymWrapper(suites[0]))
    eval_env = NormalizedBoxEnv(GymWrapper(suites[1]))

    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs'],
    )

    # Define references to variables that are agent-specific
    trainer = None
    eval_policy = None
    expl_policy = None

    # Instantiate trainer with appropriate agent
    if agent == "SAC":
        expl_policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            **variant['policy_kwargs'],
        )
        eval_policy = MakeDeterministic(expl_policy)
        trainer = SACTrainer(env=eval_env,
                             policy=expl_policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             **variant['trainer_kwargs'])
    elif agent == "TD3":
        eval_policy = TanhMlpPolicy(input_size=obs_dim,
                                    output_size=action_dim,
                                    **variant['policy_kwargs'])
        target_policy = TanhMlpPolicy(input_size=obs_dim,
                                      output_size=action_dim,
                                      **variant['policy_kwargs'])
        es = GaussianStrategy(
            action_space=expl_env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,  # Constant sigma
        )
        expl_policy = PolicyWrappedWithExplorationStrategy(
            exploration_strategy=es,
            policy=eval_policy,
        )
        trainer = TD3Trainer(policy=eval_policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             target_policy=target_policy,
                             **variant['trainer_kwargs'])
    else:
        print("Error: No valid agent chosen!")

    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        eval_policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        expl_policy,
    )

    # Define algorithm
    algorithm = CustomTorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
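
A hedged sketch of how this robosuite experiment might be driven with the TD3 branch; the env kwargs (env_name, robots) and the "OSC_POSE" controller name are assumptions about robosuite's API, algorithm_kwargs are assumed to mirror TorchBatchRLAlgorithm's, and all values are illustrative.

# Hypothetical call; robosuite kwargs and controller name are assumptions.
env_kwargs = dict(env_name="Lift", robots="Panda", controller="OSC_POSE",
                  horizon=500, control_freq=20)
variant = dict(
    # copies, since the function pops "controller" from each config
    expl_environment_kwargs=dict(env_kwargs),
    eval_environment_kwargs=dict(env_kwargs),
    qf_kwargs=dict(hidden_sizes=[256, 256]),
    policy_kwargs=dict(hidden_sizes=[256, 256]),
    trainer_kwargs=dict(discount=0.99),
    replay_buffer_size=int(1e6),
    algorithm_kwargs=dict(num_epochs=500, max_path_length=500, batch_size=256),
)
experiment(variant, agent="TD3")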
Example #9
def experiment(variant):
    expl_env = gym.make('pick_and_lift-state-v0',
                        sparse=True,
                        not_special_p=0.5,
                        ground_p=0,
                        special_is_grip=True,
                        img_size=256,
                        force_randomly_place=False,
                        force_change_position=False)

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'

    achieved_goal_key = "achieved_goal"
    replay_buffer = ObsDictRelabelingBuffer(
        env=expl_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    action_dim = expl_env.action_space.low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.3,
        min_sigma=0.1,
        decay_period=1000000  # sigma decays from 0.3 to 0.1 over 1M steps
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    trainer = TD3Trainer(
        # env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    
    trainer = HERTrainer(trainer)
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=None,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
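
A minimal sketch of the HER-specific pieces of this variant; the relabeling-fraction names are assumed rlkit ObsDictRelabelingBuffer kwargs, the trainer/algo keys are assumed as in the earlier examples, and all values are illustrative.

# Hypothetical variant for the HER + TD3 experiment above.
variant = dict(
    replay_buffer_kwargs=dict(
        max_size=int(1e6),
        fraction_goals_rollout_goals=0.2,  # assumed ObsDictRelabelingBuffer kwargs
        fraction_goals_env_goals=0.0,
    ),
    qf_kwargs=dict(hidden_sizes=[256, 256]),
    policy_kwargs=dict(hidden_sizes=[256, 256]),
    td3_trainer_kwargs=dict(discount=0.98),
    algo_kwargs=dict(num_epochs=500, max_path_length=50, batch_size=128),
)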
Example #10
def rl_context_experiment(variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.torch.td3.td3 import TD3 as TD3Trainer
    from rlkit.torch.sac.sac import SACTrainer
    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    from rlkit.torch.networks import ConcatMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.sac.policies import MakeDeterministic

    preprocess_rl_variant(variant)
    max_path_length = variant['max_path_length']
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = variant.get('achieved_goal_key', 'latent_achieved_goal')

    contextual_mdp = variant.get('contextual_mdp', True)
    print("contextual_mdp:", contextual_mdp)

    mask_variant = variant.get('mask_variant', {})
    mask_conditioned = mask_variant.get('mask_conditioned', False)
    print("mask_conditioned:", mask_conditioned)

    if mask_conditioned:
        assert contextual_mdp

    if 'sac' in variant['algorithm'].lower():
        rl_algo = 'sac'
    elif 'td3' in variant['algorithm'].lower():
        rl_algo = 'td3'
    else:
        raise NotImplementedError
    print("RL algorithm:", rl_algo)

    ### load the example dataset, if running checkpoints ###
    if 'ckpt' in variant:
        import os.path as osp
        example_set_variant = variant.get('example_set_variant', dict())
        example_set_variant['use_cache'] = True
        example_set_variant['cache_path'] = osp.join(variant['ckpt'], 'example_dataset.npy')

    if mask_conditioned:
        env = get_envs(variant)
        mask_format = mask_variant['param_variant']['mask_format']
        assert mask_format in ['vector', 'matrix', 'distribution', 'cond_distribution']
        goal_dim = env.observation_space.spaces[desired_goal_key].low.size
        if mask_format in ['vector']:
            context_dim_for_networks = goal_dim + goal_dim
        elif mask_format in ['matrix', 'distribution', 'cond_distribution']:
            context_dim_for_networks = goal_dim + (goal_dim * goal_dim)
        else:
            raise TypeError

        if 'ckpt' in variant:
            from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
            import os.path as osp

            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'masks.npy'))
            masks = np.load(filename, allow_pickle=True)[()]
        else:
            masks = get_mask_params(
                env=env,
                example_set_variant=variant['example_set_variant'],
                param_variant=mask_variant['param_variant'],
            )

        mask_keys = list(masks.keys())
        context_keys = [desired_goal_key] + mask_keys
    else:
        context_keys = [desired_goal_key]


    def contextual_env_distrib_and_reward(mode='expl'):
        assert mode in ['expl', 'eval']
        env = get_envs(variant)

        if mode == 'expl':
            goal_sampling_mode = variant.get('expl_goal_sampling_mode', None)
        elif mode == 'eval':
            goal_sampling_mode = variant.get('eval_goal_sampling_mode', None)
        if goal_sampling_mode not in [None, 'example_set']:
            env.goal_sampling_mode = goal_sampling_mode

        mask_ids_for_training = mask_variant.get('mask_ids_for_training', None)

        if mask_conditioned:
            context_distrib = MaskDictDistribution(
                env,
                desired_goal_keys=[desired_goal_key],
                mask_format=mask_format,
                masks=masks,
                max_subtasks_to_focus_on=mask_variant.get('max_subtasks_to_focus_on', None),
                prev_subtask_weight=mask_variant.get('prev_subtask_weight', None),
                mask_distr=mask_variant.get('train_mask_distr', None),
                mask_ids=mask_ids_for_training,
            )
            reward_fn = ContextualMaskingRewardFn(
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                mask_keys=mask_keys,
                mask_format=mask_format,
                use_g_for_mean=mask_variant['use_g_for_mean'],
                use_squared_reward=mask_variant.get('use_squared_reward', False),
            )
        else:
            if goal_sampling_mode == 'example_set':
                example_dataset = gen_example_sets(get_envs(variant), variant['example_set_variant'])
                assert len(example_dataset['list_of_waypoints']) == 1
                from rlkit.envs.contextual.set_distributions import GoalDictDistributionFromSet
                context_distrib = GoalDictDistributionFromSet(
                    example_dataset['list_of_waypoints'][0],
                    desired_goal_keys=[desired_goal_key],
                )
            else:
                context_distrib = GoalDictDistributionFromMultitaskEnv(
                    env,
                    desired_goal_keys=[desired_goal_key],
                )
            reward_fn = ContextualRewardFnFromMultitaskEnv(
                env=env,
                achieved_goal_from_observation=IndexIntoAchievedGoal(achieved_goal_key),
                desired_goal_key=desired_goal_key,
                achieved_goal_key=achieved_goal_key,
                additional_obs_keys=variant['contextual_replay_buffer_kwargs'].get('observation_keys', None),
            )
        diag_fn = GoalConditionedDiagnosticsToContextualDiagnostics(
            env.goal_conditioned_diagnostics,
            desired_goal_key=desired_goal_key,
            observation_key=observation_key,
        )
        env = ContextualEnv(
            env,
            context_distribution=context_distrib,
            reward_fn=reward_fn,
            observation_key=observation_key,
            contextual_diagnostics_fns=[diag_fn],
            update_env_info_fn=delete_info if not variant.get('keep_env_infos', False) else None,
        )
        return env, context_distrib, reward_fn

    env, context_distrib, reward_fn = contextual_env_distrib_and_reward(mode='expl')
    eval_env, eval_context_distrib, _ = contextual_env_distrib_and_reward(mode='eval')

    if mask_conditioned:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + context_dim_for_networks
        )
    elif contextual_mdp:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
            + env.observation_space.spaces[desired_goal_key].low.size
        )
    else:
        obs_dim = env.observation_space.spaces[observation_key].low.size

    action_dim = env.action_space.low.size

    if 'ckpt' in variant and 'ckpt_epoch' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp

        ckpt_epoch = variant['ckpt_epoch']
        if ckpt_epoch is not None:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = torch.load(filename)
        qf1 = data['trainer/qf1']
        qf2 = data['trainer/qf2']
        target_qf1 = data['trainer/target_qf1']
        target_qf2 = data['trainer/target_qf2']
        policy = data['trainer/policy']
        eval_policy = data['evaluation/policy']
        expl_policy = data['exploration/policy']
    else:
        qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf1 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        target_qf2 = ConcatMlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            **variant['qf_kwargs']
        )
        if rl_algo == 'td3':
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = create_exploration_policy(
                env, policy,
                exploration_version=variant['exploration_type'],
                exploration_noise=variant['exploration_noise'],
            )
            eval_policy = policy
        elif rl_algo == 'sac':
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                **variant['policy_kwargs']
            )
            expl_policy = policy
            eval_policy = MakeDeterministic(policy)

    post_process_mask_fn = partial(
        full_post_process_mask_fn,
        mask_conditioned=mask_conditioned,
        mask_variant=mask_variant,
        context_distrib=context_distrib,
        context_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
    )

    def context_from_obs_dict_fn(obs_dict):
        context_dict = {
            desired_goal_key: obs_dict[achieved_goal_key]
        }

        if mask_conditioned:
            sample_masks_for_relabeling = mask_variant.get('sample_masks_for_relabeling', True)
            if sample_masks_for_relabeling:
                batch_size = next(iter(obs_dict.values())).shape[0]
                sampled_contexts = context_distrib.sample(batch_size)
                for mask_key in mask_keys:
                    context_dict[mask_key] = sampled_contexts[mask_key]
            else:
                for mask_key in mask_keys:
                    context_dict[mask_key] = obs_dict[mask_key]

        return context_dict

    def concat_context_to_obs(batch, replay_buffer=None, obs_dict=None, next_obs_dict=None, new_contexts=None):
        obs = batch['observations']
        next_obs = batch['next_observations']
        batch_size = obs.shape[0]
        if mask_conditioned:
            if obs_dict is not None and new_contexts is not None:
                if not mask_variant.get('relabel_masks', True):
                    for k in mask_keys:
                        new_contexts[k] = next_obs_dict[k][:]
                    batch.update(new_contexts)
                if not mask_variant.get('relabel_goals', True):
                    new_contexts[desired_goal_key] = next_obs_dict[desired_goal_key][:]
                    batch.update(new_contexts)

                new_contexts = post_process_mask_fn(obs_dict, new_contexts)
                batch.update(new_contexts)

            if mask_format in ['vector', 'matrix']:
                goal = batch[desired_goal_key]
                mask = batch['mask'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, mask], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, mask], axis=1)
            elif mask_format == 'distribution':
                goal = batch[desired_goal_key]
                sigma_inv = batch['mask_sigma_inv'].reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, goal, sigma_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, goal, sigma_inv], axis=1)
            elif mask_format == 'cond_distribution':
                goal = batch[desired_goal_key]
                mu_w = batch['mask_mu_w']
                mu_g = batch['mask_mu_g']
                mu_A = batch['mask_mu_mat']
                sigma_inv = batch['mask_sigma_inv']
                if mask_variant['use_g_for_mean']:
                    mu_w_given_g = goal
                else:
                    mu_w_given_g = mu_w + np.squeeze(mu_A @ np.expand_dims(goal - mu_g, axis=-1), axis=-1)
                sigma_w_given_g_inv = sigma_inv.reshape((batch_size, -1))
                batch['observations'] = np.concatenate([obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
                batch['next_observations'] = np.concatenate([next_obs, mu_w_given_g, sigma_w_given_g_inv], axis=1)
            else:
                raise NotImplementedError
        elif contextual_mdp:
            goal = batch[desired_goal_key]
            batch['observations'] = np.concatenate([obs, goal], axis=1)
            batch['next_observations'] = np.concatenate([next_obs, goal], axis=1)
        else:
            batch['observations'] = obs
            batch['next_observations'] = next_obs

        return batch

    if 'observation_keys' not in variant['contextual_replay_buffer_kwargs']:
        variant['contextual_replay_buffer_kwargs']['observation_keys'] = []
    observation_keys = variant['contextual_replay_buffer_kwargs']['observation_keys']
    if observation_key not in observation_keys:
        observation_keys.append(observation_key)
    if achieved_goal_key not in observation_keys:
        observation_keys.append(achieved_goal_key)

    replay_buffer = ContextualRelabelingReplayBuffer(
        env=env,
        context_keys=context_keys,
        context_distribution=context_distrib,
        sample_context_from_obs_dict_fn=context_from_obs_dict_fn,
        reward_fn=reward_fn,
        post_process_batch_fn=concat_context_to_obs,
        **variant['contextual_replay_buffer_kwargs']
    )

    if rl_algo == 'td3':
        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['td3_trainer_kwargs']
        )
    elif rl_algo == 'sac':
        trainer = SACTrainer(
            env=env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['sac_trainer_kwargs']
        )

    def create_path_collector(
            env,
            policy,
            mode='expl',
            mask_kwargs={},
    ):
        assert mode in ['expl', 'eval']

        save_env_in_snapshot = variant.get('save_env_in_snapshot', True)

        if mask_conditioned:
            if 'rollout_mask_order' in mask_kwargs:
                rollout_mask_order = mask_kwargs['rollout_mask_order']
            else:
                if mode == 'expl':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_expl', 'fixed')
                elif mode == 'eval':
                    rollout_mask_order = mask_variant.get('rollout_mask_order_for_eval', 'fixed')
                else:
                    raise TypeError

            if 'mask_distr' in mask_kwargs:
                mask_distr = mask_kwargs['mask_distr']
            else:
                if mode == 'expl':
                    mask_distr = mask_variant['expl_mask_distr']
                elif mode == 'eval':
                    mask_distr = mask_variant['eval_mask_distr']
                else:
                    raise TypeError

            if 'mask_ids' in mask_kwargs:
                mask_ids = mask_kwargs['mask_ids']
            else:
                if mode == 'expl':
                    mask_ids = mask_variant.get('mask_ids_for_expl', None)
                elif mode == 'eval':
                    mask_ids = mask_variant.get('mask_ids_for_eval', None)
                else:
                    raise TypeError

            prev_subtask_weight = mask_variant.get('prev_subtask_weight', None)
            max_subtasks_to_focus_on = mask_variant.get('max_subtasks_to_focus_on', None)
            max_subtasks_per_rollout = mask_variant.get('max_subtasks_per_rollout', None)

            # use a separate name so the expl/eval mode checked below is not shadowed
            post_process_mode = mask_variant.get('context_post_process_mode', None)
            if post_process_mode in ['dilute_prev_subtasks_uniform', 'dilute_prev_subtasks_fixed']:
                prev_subtask_weight = 0.5

            return MaskPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                concat_context_to_obs_fn=concat_context_to_obs,
                save_env_in_snapshot=save_env_in_snapshot,
                mask_sampler=(context_distrib if mode=='expl' else eval_context_distrib),
                mask_distr=mask_distr.copy(),
                mask_ids=mask_ids,
                max_path_length=max_path_length,
                rollout_mask_order=rollout_mask_order,
                prev_subtask_weight=prev_subtask_weight,
                max_subtasks_to_focus_on=max_subtasks_to_focus_on,
                max_subtasks_per_rollout=max_subtasks_per_rollout,
            )
        elif contextual_mdp:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=context_keys,
                save_env_in_snapshot=save_env_in_snapshot,
            )
        else:
            return ContextualPathCollector(
                env,
                policy,
                observation_key=observation_key,
                context_keys_for_policy=[],
                save_env_in_snapshot=save_env_in_snapshot,
            )

    expl_path_collector = create_path_collector(env, expl_policy, mode='expl')
    eval_path_collector = create_path_collector(eval_env, eval_policy, mode='eval')

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    algorithm.to(ptu.device)

    if variant.get("save_video", True):
        save_period = variant.get('save_video_period', 50)
        dump_video_kwargs = variant.get("dump_video_kwargs", dict())
        dump_video_kwargs['horizon'] = max_path_length

        renderer = EnvRenderer(**variant.get('renderer_kwargs', {}))

        def add_images(env, state_distribution):
            state_env = env.env
            image_goal_distribution = AddImageDistribution(
                env=state_env,
                base_distribution=state_distribution,
                image_goal_key='image_desired_goal',
                renderer=renderer,
            )
            img_env = InsertImagesEnv(state_env, renderers={
                'image_observation' : renderer,
            })
            context_env = ContextualEnv(
                img_env,
                context_distribution=image_goal_distribution,
                reward_fn=reward_fn,
                observation_key=observation_key,
                update_env_info_fn=delete_info,
            )
            return context_env

        img_eval_env = add_images(eval_env, eval_context_distrib)

        if variant.get('log_eval_video', True):
            video_path_collector = create_path_collector(img_eval_env, eval_policy, mode='eval')
            rollout_function = video_path_collector._rollout_fn
            eval_video_func = get_save_video_function(
                rollout_function,
                img_eval_env,
                eval_policy,
                tag="eval",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(eval_video_func)

        # additional eval videos for mask conditioned case
        if mask_conditioned:
            default_list = [
                'atomic',
                'atomic_seq',
                'cumul_seq',
                'full',
            ]
            eval_rollouts_for_videos = mask_variant.get('eval_rollouts_for_videos', default_list)
            for key in eval_rollouts_for_videos:
                assert key in default_list

            if 'cumul_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            cumul_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_cumul" if mask_conditioned else "eval",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'full' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            full=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_full",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

            if 'atomic_seq' in eval_rollouts_for_videos:
                video_path_collector = create_path_collector(
                    img_eval_env,
                    eval_policy,
                    mode='eval',
                    mask_kwargs=dict(
                        mask_distr=dict(
                            atomic_seq=1.0
                        ),
                    ),
                )
                rollout_function = video_path_collector._rollout_fn
                eval_video_func = get_save_video_function(
                    rollout_function,
                    img_eval_env,
                    eval_policy,
                    tag="eval_atomic",
                    imsize=variant['renderer_kwargs']['width'],
                    image_format='HWC',
                    save_video_period=save_period,
                    **dump_video_kwargs
                )
                algorithm.post_train_funcs.append(eval_video_func)

        if variant.get('log_expl_video', True) and not variant['algo_kwargs'].get('eval_only', False):
            img_expl_env = add_images(env, context_distrib)
            video_path_collector = create_path_collector(img_expl_env, expl_policy, mode='expl')
            rollout_function = video_path_collector._rollout_fn
            expl_video_func = get_save_video_function(
                rollout_function,
                img_expl_env,
                expl_policy,
                tag="expl",
                imsize=variant['renderer_kwargs']['width'],
                image_format='CHW',
                save_video_period=save_period,
                **dump_video_kwargs
            )
            algorithm.post_train_funcs.append(expl_video_func)

    addl_collectors = []
    addl_log_prefixes = []
    if mask_conditioned and mask_variant.get('log_mask_diagnostics', True):
        default_list = [
            'atomic',
            'atomic_seq',
            'cumul_seq',
            'full',
        ]
        eval_rollouts_to_log = mask_variant.get('eval_rollouts_to_log', default_list)
        for key in eval_rollouts_to_log:
            assert key in default_list

        # atomic masks
        if 'atomic' in eval_rollouts_to_log:
            for mask_id in eval_path_collector.mask_ids:
                mask_kwargs=dict(
                    mask_ids=[mask_id],
                    mask_distr=dict(
                        atomic=1.0,
                    ),
                )
                collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
                addl_collectors.append(collector)
            addl_log_prefixes += [
                'mask_{}/'.format(mask_id)
                for mask_id in eval_path_collector.mask_ids
            ]

        # full mask
        if 'full' in eval_rollouts_to_log:
            mask_kwargs=dict(
                mask_distr=dict(
                    full=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_full/')

        # cumulative, sequential mask
        if 'cumul_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    cumul_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_cumul_seq/')

        # atomic, sequential mask
        if 'atomic_seq' in eval_rollouts_to_log:
            mask_kwargs=dict(
                rollout_mask_order='fixed',
                mask_distr=dict(
                    atomic_seq=1.0,
                ),
            )
            collector = create_path_collector(eval_env, eval_policy, mode='eval', mask_kwargs=mask_kwargs)
            addl_collectors.append(collector)
            addl_log_prefixes.append('mask_atomic_seq/')

        def get_mask_diagnostics(unused):
            from rlkit.core.logging import append_log, add_prefix, OrderedDict
            log = OrderedDict()
            for prefix, collector in zip(addl_log_prefixes, addl_collectors):
                paths = collector.collect_new_paths(
                    max_path_length,
                    variant['algo_kwargs']['num_eval_steps_per_epoch'],
                    discard_incomplete_paths=True,
                )
                old_path_info = eval_env.get_diagnostics(paths)

                keys_to_keep = []
                for key in old_path_info.keys():
                    if ('env_infos' in key) and ('final' in key) and ('Mean' in key):
                        keys_to_keep.append(key)
                path_info = OrderedDict()
                for key in keys_to_keep:
                    path_info[key] = old_path_info[key]

                generic_info = add_prefix(
                    path_info,
                    prefix,
                )
                append_log(log, generic_info)

            for collector in addl_collectors:
                collector.end_epoch(0)
            return log

        algorithm._eval_get_diag_fns.append(get_mask_diagnostics)
        
    if 'ckpt' in variant:
        from rlkit.misc.asset_loader import local_path_from_s3_or_local_path
        import os.path as osp
        assert variant['algo_kwargs'].get('eval_only', False)

        def update_networks(algo, epoch):
            # Reload the evaluation policy from the checkpoint directory at each
            # evaluation epoch (skipped when a fixed 'ckpt_epoch' is specified).
            if 'ckpt_epoch' in variant:
                return

            if epoch % algo._eval_epoch_freq == 0:
                filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
                print("Loading ckpt from", filename)
                data = torch.load(filename)  # optionally pass map_location='cuda:1' to remap devices
                eval_policy = data['evaluation/policy']
                eval_policy.to(ptu.device)
                algo.eval_data_collector._policy = eval_policy
                for collector in addl_collectors:
                    collector._policy = eval_policy

        algorithm.post_train_funcs.insert(0, update_networks)

    algorithm.train()
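A minimal sketch of the configuration keys that drive the mask-diagnostics and checkpoint-reload branches above; the key names are taken from the code, but every value here is an illustrative assumption:

# Hypothetical configuration fragment (values are placeholders).
mask_variant = dict(
    log_mask_diagnostics=True,
    eval_rollouts_to_log=['atomic', 'full'],  # any subset of ['atomic', 'atomic_seq', 'cumul_seq', 'full']
)
variant = dict(
    algo_kwargs=dict(
        num_eval_steps_per_epoch=1000,
        eval_only=True,  # required by the assert when 'ckpt' is set
    ),
    ckpt='path/to/experiment/dir',  # directory containing itr_<epoch>.pkl snapshots
    # ckpt_epoch=50,                # optional: pin to one checkpoint and skip reloading
)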
Example #11
0
def experiment(variant, args):
    # gym.make()-based construction did not work for these environments, so the
    # MuJoCo envs are instantiated directly below.
    # import gym
    # expl_env = NormalizedBoxEnv(gym.make(args.env))
    # eval_env = NormalizedBoxEnv(gym.make(args.env))

    if 'Ant' in args.env:
        expl_env = NormalizedBoxEnv(AntEnv())
        eval_env = NormalizedBoxEnv(AntEnv())
    elif 'InvertedPendulum' in args.env:
        expl_env = NormalizedBoxEnv(InvertedPendulumEnv())
        eval_env = NormalizedBoxEnv(InvertedPendulumEnv())
    elif 'HalfCheetah' in args.env:
        expl_env = NormalizedBoxEnv(HalfCheetahEnv())
        eval_env = NormalizedBoxEnv(HalfCheetahEnv())
    elif 'Hopper' in args.env:
        expl_env = NormalizedBoxEnv(HopperEnv())
        eval_env = NormalizedBoxEnv(HopperEnv())
    elif 'Reacher' in args.env:
        expl_env = NormalizedBoxEnv(ReacherEnv())
        eval_env = NormalizedBoxEnv(ReacherEnv())
    elif 'Swimmer' in args.env:
        expl_env = NormalizedBoxEnv(SwimmerEnv())
        eval_env = NormalizedBoxEnv(SwimmerEnv())
    elif 'Walker2d' in args.env:
        expl_env = NormalizedBoxEnv(Walker2dEnv())
        eval_env = NormalizedBoxEnv(Walker2dEnv())
    else:
        raise ValueError(args.env)

    # Standard TD3 network and algorithm setup from here on.
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
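This experiment unpacks its settings from a single variant dict; the sketch below lists the keys it reads, with illustrative values that are assumptions rather than the author's settings:

# Hypothetical variant for the experiment above (values are placeholders).
variant = dict(
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    trainer_kwargs=dict(discount=0.99),
    replay_buffer_size=int(1e6),
    algorithm_kwargs=dict(
        num_epochs=100,
        num_eval_steps_per_epoch=1000,
        num_trains_per_train_loop=1000,
        num_expl_steps_per_train_loop=1000,
        min_num_steps_before_training=1000,
        max_path_length=1000,
        batch_size=256,
    ),
)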
def prepare_trainer(algorithm, expl_env, obs_dim, action_dim, pretrained_policy_load, variant):
    print("Preparing for {} trainer.".format(algorithm))
    if algorithm == "SAC":
        if not pretrained_policy_load:
            M = variant['layer_size']
            qf1 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                hidden_sizes=[M, M],
            )
            qf2 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                hidden_sizes=[M, M],
            )
            target_qf1 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                hidden_sizes=[M, M],
            )
            target_qf2 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                hidden_sizes=[M, M],
            )
            policy = TanhGaussianPolicy(
                obs_dim=obs_dim,
                action_dim=action_dim,
                hidden_sizes=[M, M],
            )
        else:
            snapshot = torch.load(pretrained_policy_load)
            qf1 = snapshot['trainer/qf1']
            qf2 = snapshot['trainer/qf2']
            target_qf1 = snapshot['trainer/target_qf1']
            target_qf2 = snapshot['trainer/target_qf2']
            policy = snapshot['exploration/policy']
            if variant['trainer_kwargs']['use_automatic_entropy_tuning']:
                log_alpha = snapshot['trainer/log_alpha'] 
                variant['trainer_kwargs']['log_alpha'] = log_alpha
                alpha_optimizer = snapshot['trainer/alpha_optimizer'] 
                variant['trainer_kwargs']['alpha_optimizer'] = alpha_optimizer
            print("loaded the pretrained policy {}".format(pretrained_policy_load))
        
        eval_policy = MakeDeterministic(policy)
        expl_policy = policy

        trainer = SACTrainer(
            env=expl_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            **variant['trainer_kwargs']
        )

    elif algorithm == "TD3":
        if not pretrained_policy_load:
            qf1 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **variant['qf_kwargs']
            )
            qf2 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **variant['qf_kwargs']
            )
            target_qf1 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **variant['qf_kwargs']
            )
            target_qf2 = FlattenMlp(
                input_size=obs_dim + action_dim,
                output_size=1,
                **variant['qf_kwargs']
            )
            policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            target_policy = TanhMlpPolicy(
                input_size=obs_dim,
                output_size=action_dim,
                **variant['policy_kwargs']
            )
            es = GaussianStrategy(
                action_space=expl_env.action_space,
                max_sigma=0.1,
                min_sigma=0.1,  # Constant sigma
            )
            exploration_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=policy,
            )
            expl_policy = exploration_policy
            eval_policy = policy
        else:
            # Loading a pretrained TD3 policy is not implemented; without it the
            # networks used below would be undefined, so fail explicitly here.
            raise NotImplementedError(
                "Loading a pretrained TD3 policy is not supported yet.")

        trainer = TD3Trainer(
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            target_policy=target_policy,
            **variant['trainer_kwargs']
        )

    return expl_policy, eval_policy, trainer
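A hedged usage sketch for prepare_trainer; the environment and variant values below are placeholders chosen only to satisfy the keys the function reads:

# Hypothetical call site (fresh SAC networks, no pretrained snapshot).
expl_env = NormalizedBoxEnv(HalfCheetahEnv())
obs_dim = expl_env.observation_space.low.size
action_dim = expl_env.action_space.low.size
variant = dict(
    layer_size=256,
    trainer_kwargs=dict(use_automatic_entropy_tuning=True),
)
expl_policy, eval_policy, trainer = prepare_trainer(
    "SAC", expl_env, obs_dim, action_dim,
    pretrained_policy_load=None, variant=variant,
)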
def experiment(variant):
    expl_env = envs[variant['env']](variant['dr'])
    expl_env = wrappers.FlattenDictWrapper(expl_env, dict_keys=['observation'])
    t_fn = variant["t_fn"]
    expl_env = TransformObservationWrapper(expl_env, t_fn)
    action_dim = expl_env.action_space.low.size
    conv_args = {
        "input_width": 16,
        "input_height": 16,
        "input_channels": 8,
        "kernel_sizes": [4],
        "n_channels": [32],
        "strides": [4],
        "paddings": [0],
        "hidden_sizes": [1024, 512],
        "batch_norm_conv": False,
        "batch_norm_fc": False,
        'init_w': 1e-4,
        "hidden_init": nn.init.orthogonal_,
        "hidden_activation": nn.ReLU(),
    }

    qf1 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    qf2 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    target_qf1 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    target_qf2 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    policy = TanhCNNPolicy(output_size=action_dim,
                           **variant['policy_kwargs'],
                           **conv_args)
    target_policy = TanhCNNPolicy(output_size=action_dim,
                                  **variant['policy_kwargs'],
                                  **conv_args)
    if variant['noise'] == "eps":
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            epsilon=0.3,
            max_sigma=0.0,
            min_sigma=0.0,  #constant sigma 0
            decay_period=1000000)
    elif variant['noise'] == "gaussian":
        es = GaussianStrategy(action_space=expl_env.action_space,
                              max_sigma=0.3,
                              min_sigma=0.1,
                              decay_period=1000000)
    else:
        raise ValueError(
            "unsupported value for variant['noise']: {}".format(variant['noise']))
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(variant['replay_buffer_size'], expl_env)
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=None,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
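The image-based experiment above reads the same style of variant dict; below is a hedged sketch of the keys it consumes, where 'my_env' and the transform function are hypothetical placeholders for entries in the module-level envs registry:

# Hypothetical variant fragment (placeholders only).
def identity_t_fn(obs):
    # stand-in for the observation transform used by TransformObservationWrapper
    return obs

variant = dict(
    env='my_env',            # key into the module-level envs dict (placeholder)
    dr=False,                # flag forwarded to the env constructor
    t_fn=identity_t_fn,
    noise='gaussian',        # or 'eps'
    qf_kwargs=dict(),
    policy_kwargs=dict(),
    replay_buffer_size=int(1e5),
    trainer_kwargs=dict(discount=0.99),
    algorithm_kwargs=dict(
        num_epochs=100,
        num_eval_steps_per_epoch=0,   # evaluation collectors are passed as None above
        num_trains_per_train_loop=1000,
        num_expl_steps_per_train_loop=1000,
        min_num_steps_before_training=1000,
        max_path_length=50,
        batch_size=256,
    ),
)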
def experiment(variant):
    expl_env = envs[variant['env']](variant['dr'])
    expl_env = TransformObservationWrapper(expl_env, variant['main_t_fn'])

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'
    achieved_goal_key = "achieved_goal"
    # NOTE: rerendering_env is not constructed in this function; it is assumed to
    # be defined at module level (or otherwise provided) before this runs.
    replay_buffer = imgObsDictRelabelingBuffer(
        env=expl_env,
        rerendering_env=rerendering_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        t_fn=variant['t_fn'],
        **variant['replay_buffer_kwargs'])
    # obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size
    if variant['mlf']:
        if variant['alt'] == 'both':
            conv_args = {
                "input_width": 16,
                "input_height": 16,
                "input_channels": 16,
                "kernel_sizes": [4],
                "n_channels": [32],
                "strides": [4],
                "paddings": [0],
                "hidden_sizes": [1024, 512],
                # "added_fc_input_size": action_dim,
                "batch_norm_conv": False,
                "batch_norm_fc": False,
                'init_w': 1e-4,
                "hidden_init": nn.init.orthogonal_,
                "hidden_activation": nn.ReLU(),
            }
        else:
            conv_args = {
                "input_width": 16,
                "input_height": 16,
                "input_channels": 8,
                "kernel_sizes": [4],
                "n_channels": [32],
                "strides": [4],
                "paddings": [0],
                "hidden_sizes": [1024, 512],
                "batch_norm_conv": False,
                "batch_norm_fc": False,
                'init_w': 1e-4,
                "hidden_init": nn.init.orthogonal_,
                "hidden_activation": nn.ReLU(),
            }
    else:
        if variant['alt'] == 'both':
            conv_args = {
                "input_width": 64,
                "input_height": 64,
                "input_channels": 6,
                "kernel_sizes": [4, 4, 3],
                "n_channels": [32, 64, 64],
                "strides": [2, 1, 1],
                "paddings": [0, 0, 0],
                "hidden_sizes": [1024, 512],
                "batch_norm_conv": False,
                "batch_norm_fc": False,
                'init_w': 1e-4,
                "hidden_init": nn.init.orthogonal_,
                "hidden_activation": nn.ReLU(),
            }
        else:
            conv_args = {
                "input_width": 64,
                "input_height": 64,
                "input_channels": 3,
                "kernel_sizes": [4, 4, 3],
                "n_channels": [32, 64, 64],
                "strides": [2, 1, 1],
                "paddings": [0, 0, 0],
                "hidden_sizes": [1024, 512],
                "batch_norm_conv": False,
                "batch_norm_fc": False,
                'init_w': 1e-4,
                "hidden_init": nn.init.orthogonal_,
                "hidden_activation": nn.ReLU(),
            }

    qf1 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    qf2 = FlattenCNN(output_size=1,
                     added_fc_input_size=action_dim,
                     **variant['qf_kwargs'],
                     **conv_args)
    target_qf1 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    target_qf2 = FlattenCNN(output_size=1,
                            added_fc_input_size=action_dim,
                            **variant['qf_kwargs'],
                            **conv_args)
    policy = TanhCNNPolicy(output_size=action_dim,
                           **variant['policy_kwargs'],
                           **conv_args)
    target_policy = TanhCNNPolicy(output_size=action_dim,
                                  **variant['policy_kwargs'],
                                  **conv_args)
    if variant['noise'] == "eps":
        es = GaussianAndEpislonStrategy(
            action_space=expl_env.action_space,
            epsilon=0.3,
            max_sigma=0.0,
            min_sigma=0.0,  #constant sigma 0
            decay_period=1000000)
    elif variant['noise'] == "gaussian":
        es = GaussianStrategy(action_space=expl_env.action_space,
                              max_sigma=0.3,
                              min_sigma=0.1,
                              decay_period=1000000)
    else:
        raise ValueError(
            "unsupported value for variant['noise']: {}".format(variant['noise']))

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    # eval_path_collector = MdpPathCollector(
    #     eval_env,
    #     policy,
    # )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        exploration_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    # Constructed for completeness but unused: the algorithm below is built with
    # evaluation_data_collector=None.
    eval_path_collector = GoalConditionedPathCollector(
        expl_env,
        exploration_policy.policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['trainer_kwargs'])
    # trainer = HERTrainer(trainer)
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=None,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=None,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example #15
0
def experiment(variant):
    base_expl_env = PointMassEnv(n=variant["num_tasks"],
                                 reward_type=variant["reward_type"])
    expl_env = FlatGoalEnv(base_expl_env, append_goal_to_obs=True)

    base_eval_env = PointMassEnv(n=variant["num_tasks"],
                                 reward_type=variant["reward_type"])
    eval_env = FlatGoalEnv(base_eval_env, append_goal_to_obs=True)
    obs_dim = expl_env.observation_space.low.size
    action_dim = expl_env.action_space.low.size

    print(expl_env.observation_space, expl_env.action_space)
    qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                     output_size=1,
                     **variant['qf_kwargs'])
    target_qf1 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    target_qf2 = FlattenMlp(input_size=obs_dim + action_dim,
                            output_size=1,
                            **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.train()
Example #16
0
File: td3.py Project: cvigoe/rlkit
def experiment(variant):
    # Previously tried environments (kept for reference):
    # expl_env = NormalizedBoxEnv(HalfCheetahEnv())
    # expl_env = NormalizedBoxEnv(gym.make('Walker2d-v2'))
    # eval_env = NormalizedBoxEnv(HalfCheetahEnv())
    # eval_env = NormalizedBoxEnv(gym.make('Walker2d-v2'))
    # obs_dim = expl_env.observation_space.low.size
    # action_dim = expl_env.action_space.low.size

    expl_env = NormalizedBoxEnv(gym.make('activesearchrl-v0'))
    eval_env = NormalizedBoxEnv(gym.make('activesearchrl-v0'))
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                    output_size=1,
                    **variant['qf_kwargs'])
    target_qf1 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    target_qf2 = ConcatMlp(input_size=obs_dim + action_dim,
                           output_size=1,
                           **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    target_policy = TanhMlpPolicy(input_size=obs_dim,
                                  output_size=action_dim,
                                  **variant['policy_kwargs'])
    es = GaussianStrategy(
        action_space=expl_env.action_space,
        max_sigma=0.1,
        min_sigma=0.1,  # Constant sigma
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    eval_path_collector = MdpPathCollector(
        eval_env,
        policy,
    )
    expl_path_collector = MdpPathCollector(
        expl_env,
        exploration_policy,
    )
    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )
    trainer = TD3Trainer(policy=policy,
                         qf1=qf1,
                         qf2=qf2,
                         target_qf1=target_qf1,
                         target_qf2=target_qf2,
                         target_policy=target_policy,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Example #17
0
def skewfit_experiment(variant, other_variant):
    import rlkit.torch.pytorch_util as ptu
    from rlkit.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from rlkit.torch.networks import FlattenMlp, TanhMlpPolicy
    from rlkit.torch.sac.policies import TanhGaussianPolicy
    from rlkit.torch.vae.vae_trainer import ConvVAETrainer

    skewfit_preprocess_variant(variant)
    env = get_envs(variant)

    uniform_dataset_fn = variant.get('generate_uniform_dataset_fn', None)
    if uniform_dataset_fn:
        uniform_dataset = uniform_dataset_fn(
            **variant['generate_uniform_dataset_kwargs']
        )
    else:
        uniform_dataset = None

    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (
        env.observation_space.spaces[observation_key].low.size
        + env.observation_space.spaces[desired_goal_key].low.size
    )
    action_dim = env.action_space.low.size
    hidden_sizes = variant.get('hidden_sizes', [400, 300])
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=hidden_sizes,
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=hidden_sizes,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=env.vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    vae_trainer = ConvVAETrainer(
        variant['vae_train_data'],
        variant['vae_test_data'],
        env.vae,
        **variant['online_vae_trainer_kwargs']
    )
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    max_path_length = variant['max_path_length']

    trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    eval_path_collector = VAEWrappedEnvPathCollector(
        variant['evaluation_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = VAEWrappedEnvPathCollector(
        variant['exploration_goal_sampling_mode'],
        env,
        policy,
        max_path_length,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )

    algorithm = OnlineVaeAlgorithm(
        automatic_policy_schedule=other_variant,
        trainer=trainer,
        exploration_env=env,
        evaluation_env=env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        vae=vae,
        vae_trainer=vae_trainer,
        uniform_dataset=uniform_dataset,
        max_path_length=max_path_length,
        **variant['algo_kwargs']
    )

    if variant['custom_goal_sampler'] == 'replay_buffer':
        env.custom_goal_sampler = replay_buffer.sample_buffer_goals

    algorithm.to(ptu.device)
    vae.to(ptu.device)
    algorithm.train()
Example #18
0
def experiment(variant):
    import multiworld
    multiworld.register_all_envs()
    eval_env = gym.make('SawyerPushXYZEnv-v0')
    expl_env = gym.make('SawyerPushXYZEnv-v0')
    observation_key = 'state_observation'
    desired_goal_key = 'state_desired_goal'
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    es = GaussianAndEpislonStrategy(
        action_space=expl_env.action_space,
        max_sigma=.2,
        min_sigma=.2,  # constant sigma
        epsilon=.3,
    )
    obs_dim = expl_env.observation_space.spaces['observation'].low.size
    goal_dim = expl_env.observation_space.spaces['desired_goal'].low.size
    action_dim = expl_env.action_space.low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + goal_dim + action_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    trainer = TD3Trainer(
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        expl_policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Example #19
0
def experiment(variant):
    # unwrap the TimeLimit wrapper since we manually terminate after 50 steps
    eval_env = gym.make('FetchReach-v1').env
    expl_env = gym.make('FetchReach-v1').env

    observation_key = 'observation'
    desired_goal_key = 'desired_goal'

    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    replay_buffer = ObsDictRelabelingBuffer(
        env=eval_env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )
    obs_dim = eval_env.observation_space.spaces['observation'].low.size
    action_dim = eval_env.action_space.low.size
    goal_dim = eval_env.observation_space.spaces['desired_goal'].low.size
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim + goal_dim,
        output_size=1,
        **variant['qf_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )
    target_policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim,
        output_size=action_dim,
        **variant['policy_kwargs']
    )

    trainer = TD3Trainer(
        # env=eval_env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        target_qf1=target_qf1,
        target_qf2=target_qf2,
        target_policy=target_policy,
        **variant['td3_trainer_kwargs']
    )
    trainer = HERTrainer(trainer)
    eval_path_collector = GoalConditionedPathCollector(
        eval_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    expl_path_collector = GoalConditionedPathCollector(
        expl_env,
        policy,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
    )
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **variant['algo_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
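A hedged sketch of a variant that would satisfy the keys the HER + TD3 example above reads; the values are illustrative, and the relabeling fractions follow rlkit's ObsDictRelabelingBuffer arguments:

# Hypothetical variant for the FetchReach HER example (values are placeholders).
variant = dict(
    algo_kwargs=dict(
        num_epochs=100,
        max_path_length=50,            # paths are manually terminated after 50 steps
        num_eval_steps_per_epoch=1000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
        min_num_steps_before_training=1000,
        batch_size=128,
    ),
    replay_buffer_kwargs=dict(
        max_size=int(1e6),
        fraction_goals_rollout_goals=0.2,  # fraction of sampled goals kept from the rollout
        fraction_goals_env_goals=0.0,      # the remainder are relabeled with achieved goals
    ),
    qf_kwargs=dict(hidden_sizes=[400, 300]),
    policy_kwargs=dict(hidden_sizes=[400, 300]),
    td3_trainer_kwargs=dict(discount=0.99),
)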