Пример #1
0
def tdm_td3_experiment(variant):
    """Assemble a TDM-TD3 agent from ``variant`` and train it.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``exploration_type`` (``'ou'``, ``'gaussian'`` or
    ``'epsilon'``), ``replay_buffer_class``, ``replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    Note that ``variant['algo_kwargs']`` is modified in place: the Q-function
    criterion and the (``None``) normalizer are written into its nested
    ``td3_kwargs`` / ``tdm_kwargs`` dictionaries.

    Raises:
        Exception: if ``exploration_type`` is not one of the supported names.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    # This experiment runs without a TDM normalizer.
    tdm_normalizer = None

    qf1 = TdmQf(
        env=env,
        vectorized=True,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    qf2 = TdmQf(
        env=env,
        vectorized=True,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )

    # Each factory is only invoked for the selected exploration type.
    noise_kind = variant['exploration_type']
    strategy_factories = {
        'ou': lambda: OUStrategy(action_space=env.action_space),
        # Equal max/min sigma keeps the noise scale constant.
        'gaussian': lambda: GaussianStrategy(
            action_space=env.action_space,
            max_sigma=0.1,
            min_sigma=0.1,
        ),
        'epsilon': lambda: EpsilonGreedy(
            action_space=env.action_space,
            prob_random_action=0.1,
        ),
    }
    if noise_kind not in strategy_factories:
        raise Exception("Invalid type: " + noise_kind)
    strategy = strategy_factories[noise_kind]()

    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=strategy,
        policy=policy,
    )
    buffer = variant['replay_buffer_class'](
        env=env,
        **variant['replay_buffer_kwargs']
    )
    criterion = variant['qf_criterion_class']()

    trainer_kwargs = variant['algo_kwargs']
    trainer_kwargs['td3_kwargs']['qf_criterion'] = criterion
    trainer_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    trainer = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        replay_buffer=buffer,
        policy=policy,
        exploration_policy=behavior_policy,
        **trainer_kwargs
    )
    trainer.to(ptu.device)
    trainer.train()
Пример #2
0
def experiment(variant):
    """Train TDM-SAC on a normalized ``MultitaskPoint2DEnv``.

    Keys read from ``variant``: ``policy_kwargs``, ``qf_kwargs``,
    ``vf_kwargs`` and ``algo_params`` (with ``base_kwargs``, ``sac_kwargs``,
    ``tdm_kwargs`` and ``supervised_weight`` entries). The replay buffer size
    is taken from ``algo_params['base_kwargs']['replay_buffer_size']``.
    """
    env = NormalizedBoxEnv(MultitaskPoint2DEnv())

    policy = StochasticTdmPolicy(env=env, **variant['policy_kwargs'])
    # Both value networks are built in vectorized mode.
    qf = TdmQf(
        env=env,
        vectorized=True,
        norm_order=2,
        **variant['qf_kwargs']
    )
    vf = TdmVf(env=env, vectorized=True, **variant['vf_kwargs'])

    algo_params = variant['algo_params']
    buffer = HerReplayBuffer(
        algo_params['base_kwargs']['replay_buffer_size'],
        env,
    )

    # The three kwargs dictionaries are passed positionally, as whole dicts.
    trainer = TdmSac(
        env,
        qf,
        vf,
        algo_params['sac_kwargs'],
        algo_params['tdm_kwargs'],
        algo_params['base_kwargs'],
        supervised_weight=algo_params['supervised_weight'],
        policy=policy,
        replay_buffer=buffer,
    )
    if ptu.gpu_enabled():
        trainer.cuda()

    trainer.train()
Пример #3
0
def experiment(variant):
    """Run a TDM-DDPG experiment with OU exploration and a HER replay buffer.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is mutated in place: the criterion instance
    and the (``None``) normalizer are stored in its nested ``ddpg_kwargs`` /
    ``tdm_kwargs`` dictionaries before being forwarded to ``TdmDdpg``.
    """
    env = variant['env_class'](**variant['env_kwargs'])
    # The environment is used unwrapped and no TdmNormalizer is constructed.
    tdm_normalizer = None

    qf = TdmQf(
        env=env,
        vectorized=True,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )

    noise = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=policy,
    )
    buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion = variant['qf_criterion_class']()

    trainer_kwargs = variant['algo_kwargs']
    trainer_kwargs['ddpg_kwargs']['qf_criterion'] = criterion
    trainer_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    trainer = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=buffer,
        policy=policy,
        exploration_policy=behavior_policy,
        **trainer_kwargs
    )
    trainer.to(ptu.device)
    trainer.train()
Пример #4
0
def experiment(variant):
    """Run a TDM-TD3 experiment with OU exploration and a HER replay buffer.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is updated in place with the criterion
    instance (under ``td3_kwargs``) and the ``None`` normalizer (under
    ``tdm_kwargs``).
    """
    env = variant['env_class'](**variant['env_kwargs'])
    # No normalizer is used; the same None is shared by all components.
    tdm_normalizer = None

    # Two Q-functions built from identical settings.
    qf1 = TdmQf(env=env,
                vectorized=True,
                tdm_normalizer=tdm_normalizer,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=True,
                tdm_normalizer=tdm_normalizer,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       tdm_normalizer=tdm_normalizer,
                       **variant['policy_kwargs'])

    noise = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=policy,
    )
    buffer = HerReplayBuffer(env=env, **variant['her_replay_buffer_kwargs'])
    criterion = variant['qf_criterion_class']()

    trainer_kwargs = variant['algo_kwargs']
    trainer_kwargs['td3_kwargs']['qf_criterion'] = criterion
    trainer_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    trainer = TdmTd3(env,
                     qf1=qf1,
                     qf2=qf2,
                     replay_buffer=buffer,
                     policy=policy,
                     exploration_policy=behavior_policy,
                     **trainer_kwargs)
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train TDM-DDPG on ``MultiTaskSawyerXYZReachingEnv``.

    Keys read from ``variant``: ``env_params``, ``hidden_sizes`` (a single
    width used for both hidden layers of the Q-function and the policy),
    ``es_kwargs``, ``her_replay_buffer_kwargs``, ``qf_criterion_class`` and
    ``ddpg_tdm_kwargs`` (which must contain ``tdm_kwargs['max_tau']`` and a
    ``ddpg_kwargs`` dictionary).

    ``variant['ddpg_tdm_kwargs']`` is deep-copied before the criterion is
    inserted, so the caller's ``variant`` is left unmodified.
    """
    env_params = variant['env_params']
    env = MultiTaskSawyerXYZReachingEnv(env_params)
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized=True,
        max_tau=variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau'],
    )
    qf = TdmQf(
        env=env,
        vectorized=True,
        hidden_sizes=[variant['hidden_sizes'], variant['hidden_sizes']],
        structure='norm_difference',
        tdm_normalizer=tdm_normalizer,
    )
    policy = TdmPolicy(
        env=env,
        hidden_sizes=[variant['hidden_sizes'], variant['hidden_sizes']],
        tdm_normalizer=tdm_normalizer,
    )
    es = OUStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    qf_criterion = variant['qf_criterion_class']()
    # Work on a private copy so the criterion object does not leak back into
    # the caller's variant dictionary.
    ddpg_tdm_kwargs = copy.deepcopy(variant['ddpg_tdm_kwargs'])
    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion
    # NOTE(review): sibling experiments also store tdm_normalizer under
    # tdm_kwargs before building the algorithm; this one does not — confirm
    # whether that omission is intentional.
    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        # Forward the prepared copy; unpacking variant['ddpg_tdm_kwargs']
        # here would silently discard the qf_criterion set above.
        **ddpg_tdm_kwargs
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
def experiment(variant):
    """Train TDM-SAC with a shared ``TdmNormalizer`` and a HER replay buffer.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (its ``tdm_kwargs`` must
    provide ``vectorized`` and ``max_tau``), ``env_class``, ``env_kwargs``,
    ``tdm_normalizer_kwargs``, ``qf_kwargs``, ``vf_kwargs``,
    ``policy_kwargs`` and ``her_replay_buffer_kwargs``.
    """
    vectorized = variant['sac_tdm_kwargs']['tdm_kwargs']['vectorized']
    env = NormalizedBoxEnv(variant['env_class'](**variant['env_kwargs']))
    max_tau = variant['sac_tdm_kwargs']['tdm_kwargs']['max_tau']

    # One normalizer instance is shared by the Q-function, V-function and
    # policy.
    normalizer = TdmNormalizer(env,
                               vectorized,
                               max_tau=max_tau,
                               **variant['tdm_normalizer_kwargs'])
    qf = TdmQf(env=env,
               vectorized=vectorized,
               tdm_normalizer=normalizer,
               **variant['qf_kwargs'])
    vf = TdmVf(env=env,
               vectorized=vectorized,
               tdm_normalizer=normalizer,
               **variant['vf_kwargs'])
    policy = StochasticTdmPolicy(env=env,
                                 tdm_normalizer=normalizer,
                                 **variant['policy_kwargs'])
    buffer = HerReplayBuffer(env=env, **variant['her_replay_buffer_kwargs'])

    trainer = TdmSac(env=env,
                     policy=policy,
                     qf=qf,
                     vf=vf,
                     replay_buffer=buffer,
                     **variant['sac_tdm_kwargs'])
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train TDM-DDPG with a normalizer and a configurable norm order.

    Keys read from ``variant``: ``vectorized``, ``norm_order``,
    ``ddpg_tdm_kwargs`` (with ``tdm_kwargs['max_tau']`` and ``ddpg_kwargs``),
    ``env_class``, ``env_kwargs``, ``tdm_normalizer_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``qf_criterion_kwargs``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place: ``vectorized``,
    ``norm_order`` and the normalizer are written under ``tdm_kwargs`` and the
    criterion under ``ddpg_kwargs``.
    """
    vectorized = variant['vectorized']
    norm_order = variant['norm_order']

    # Mirror the top-level settings into the algorithm's TDM kwargs.
    tdm_settings = variant['ddpg_tdm_kwargs']['tdm_kwargs']
    tdm_settings['vectorized'] = vectorized
    tdm_settings['norm_order'] = norm_order

    env = NormalizedBoxEnv(variant['env_class'](**variant['env_kwargs']))
    max_tau = variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau']
    normalizer = TdmNormalizer(
        env,
        vectorized,
        max_tau=max_tau,
        **variant['tdm_normalizer_kwargs']
    )
    qf = TdmQf(
        env=env,
        vectorized=vectorized,
        norm_order=norm_order,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )

    noise = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=policy,
    )
    buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion = variant['qf_criterion_class'](
        **variant['qf_criterion_kwargs']
    )

    trainer_kwargs = variant['ddpg_tdm_kwargs']
    trainer_kwargs['ddpg_kwargs']['qf_criterion'] = criterion
    trainer_kwargs['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=buffer,
        policy=policy,
        exploration_policy=behavior_policy,
        **trainer_kwargs
    )
    trainer.to(ptu.device)
    trainer.train()
def experiment(variant):
    """Train TDM-DDPG on ``MultiTaskSawyerXYZReachingEnv`` with a normalizer.

    Keys read from ``variant``: ``env_params`` (unpacked as keyword arguments
    for the environment), ``ddpg_tdm_kwargs`` (with ``tdm_kwargs['max_tau']``
    and ``ddpg_kwargs``), ``qf_kwargs``, ``policy_kwargs``, ``es_kwargs``,
    ``her_replay_buffer_kwargs`` and ``qf_criterion_class``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place with the criterion
    and the normalizer.
    """
    env = MultiTaskSawyerXYZReachingEnv(**variant['env_params'])
    max_tau = variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau']
    normalizer = TdmNormalizer(env, vectorized=True, max_tau=max_tau)

    qf = TdmQf(
        env=env,
        vectorized=True,
        norm_order=2,
        tdm_normalizer=normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )

    noise = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=policy,
    )
    buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion = variant['qf_criterion_class']()

    trainer_kwargs = variant['ddpg_tdm_kwargs']
    trainer_kwargs['ddpg_kwargs']['qf_criterion'] = criterion
    trainer_kwargs['tdm_kwargs']['tdm_normalizer'] = normalizer

    trainer = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=buffer,
        policy=policy,
        exploration_policy=behavior_policy,
        **trainer_kwargs
    )
    # Only move to the GPU when the project utilities report it as enabled.
    if ptu.gpu_enabled():
        trainer.cuda()
    trainer.train()
Пример #9
0
def experiment(variant):
    """Train TDM-SAC, optionally exploring/evaluating with MPC controllers.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (its ``tdm_kwargs`` must
    provide ``vectorized`` and ``max_tau``), ``env_class``, ``env_kwargs``,
    ``qf_kwargs``, ``tdm_normalizer_kwargs``, ``vf_kwargs``,
    ``policy_kwargs``, ``her_replay_buffer_kwargs``,
    ``mpc_controller_kwargs``, ``state_only_mpc_controller_kwargs``,
    ``es_kwargs``, ``explore_with`` and ``eval_with``.

    When ``explore_with`` / ``eval_with`` name one of the two controllers
    (``'TdmLBfgsBCMC'`` or ``'TdmLBfgsBStateOnlyCMC'``), the corresponding
    policy is written into ``variant['sac_tdm_kwargs']['base_kwargs']`` in
    place; any other value leaves ``base_kwargs`` untouched.
    """
    vectorized = variant['sac_tdm_kwargs']['tdm_kwargs']['vectorized']
    env = NormalizedBoxEnv(variant['env_class'](**variant['env_kwargs']))
    max_tau = variant['sac_tdm_kwargs']['tdm_kwargs']['max_tau']

    # The Q-function is created without the normalizer; the V-function and
    # policy below receive it.
    qf = TdmQf(env, vectorized=vectorized, **variant['qf_kwargs'])
    normalizer = TdmNormalizer(
        env,
        vectorized,
        max_tau=max_tau,
        **variant['tdm_normalizer_kwargs']
    )
    implicit_model = TdmToImplicitModel(env, qf, tau=0)
    vf = TdmVf(
        env=env,
        vectorized=vectorized,
        tdm_normalizer=normalizer,
        **variant['vf_kwargs']
    )
    policy = StochasticTdmPolicy(
        env=env,
        tdm_normalizer=normalizer,
        **variant['policy_kwargs']
    )
    buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )

    goal_slice = env.ob_to_goal_slice
    model_controller = TdmLBfgsBCMC(
        implicit_model,
        env,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['mpc_controller_kwargs']
    )
    state_only_controller = TdmLBfgsBStateOnlyCMC(
        vf,
        policy,
        env,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['state_only_mpc_controller_kwargs']
    )
    noise = GaussianStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    controllers = {
        'TdmLBfgsBCMC': model_controller,
        'TdmLBfgsBStateOnlyCMC': state_only_controller,
    }

    # Exploration: wrap the chosen controller with Gaussian noise.
    explore_controller = controllers.get(variant['explore_with'])
    if explore_controller is not None:
        variant['sac_tdm_kwargs']['base_kwargs']['exploration_policy'] = (
            PolicyWrappedWithExplorationStrategy(
                exploration_strategy=noise,
                policy=explore_controller,
            )
        )

    # Evaluation: use the chosen controller directly, without noise.
    eval_controller = controllers.get(variant['eval_with'])
    if eval_controller is not None:
        variant['sac_tdm_kwargs']['base_kwargs']['eval_policy'] = (
            eval_controller
        )

    trainer = TdmSac(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        replay_buffer=buffer,
        **variant['sac_tdm_kwargs']
    )
    trainer.to(ptu.device)
    trainer.train()
Пример #10
0
def tdm_td3_experiment(variant):
    """Build and train a TDM-TD3 agent, optionally from a saved checkpoint.

    The networks are either loaded from a checkpoint (when ``variant``
    contains ``'ckpt'``) or constructed from ``qf_kwargs`` / ``policy_kwargs``.
    An optional ``SubgoalPlanner`` can be used as the evaluation policy, and
    a video-saving hook is registered unless ``vis_kwargs['save_video']`` is
    falsy.

    ``variant`` is modified in place in several places (``algo_kwargs``,
    ``replay_buffer_kwargs`` and, on the non-checkpoint path, ``qf_kwargs``).

    Args:
        variant: experiment configuration dictionary; it is first passed to
            ``preprocess_rl_variant``.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import \
        ObsDictRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy
    )
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3
    from railrl.state_distance.subgoal_planner import SubgoalPlanner
    from railrl.misc.asset_loader import local_path_from_s3_or_local_path
    import joblib

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    # Observation-dict keys; the achieved-goal key is derived from the
    # desired-goal key by string substitution.
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    # Vectorized mode is inferred from the environment's reward type name and
    # propagated to the algorithm and the replay buffer.
    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['replay_buffer_kwargs']['vectorized'] = vectorized

    if 'ckpt' in variant:
        # Load a specific epoch's snapshot if requested, otherwise params.pkl.
        # NOTE(review): `osp` is not imported in this function; presumably a
        # module-level alias for os.path — confirm.
        if 'ckpt_epoch' in variant:
            epoch = variant['ckpt_epoch']
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'itr_%d.pkl' % epoch))
        else:
            filename = local_path_from_s3_or_local_path(osp.join(variant['ckpt'], 'params.pkl'))
        print("Loading ckpt from", filename)
        data = joblib.load(filename)
        qf1 = data['qf1']
        qf2 = data['qf2']
        policy = data['policy']
        # Reuse the reward scale stored on the loaded policy.
        variant['algo_kwargs']['base_kwargs']['reward_scale'] = policy.reward_scale
    else:
        obs_dim = (
            env.observation_space.spaces[observation_key].low.size
        )
        goal_dim = (
            env.observation_space.spaces[desired_goal_key].low.size
        )
        action_dim = env.action_space.low.size

        variant['qf_kwargs']['vectorized'] = vectorized
        norm_order = env.norm_order
        variant['qf_kwargs']['norm_order'] = norm_order
        # Take one random step to inspect the reward: if it is a sequence,
        # its length is used as the Q-function output dimension.
        env.reset()
        _, rew, _, _ = env.step(env.action_space.sample())
        if hasattr(rew, "__len__"):
            variant['qf_kwargs']['output_dim'] = len(rew)
        qf1 = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        qf2 = TdmQf(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            **variant['qf_kwargs']
        )
        policy = TdmPolicy(
            env=env,
            observation_dim=obs_dim,
            goal_dim=goal_dim,
            action_dim=action_dim,
            reward_scale=variant['algo_kwargs']['base_kwargs'].get('reward_scale', 1.0),
            **variant['policy_kwargs']
        )

    # Optional planner-based evaluation policy; None is passed to the
    # algorithm otherwise.
    eval_policy = None
    if variant.get('eval_policy', None) == 'SubgoalPlanner':
        eval_policy = SubgoalPlanner(
            env,
            qf1,
            policy,
            observation_key=observation_key,
            desired_goal_key=desired_goal_key,
            achieved_goal_key=achieved_goal_key,
            state_based=variant.get("do_state_exp", False),
            max_tau=variant['algo_kwargs']['tdm_kwargs']['max_tau'],
            reward_scale=variant['algo_kwargs']['base_kwargs'].get('reward_scale', 1.0),
            **variant['SubgoalPlanner_kwargs']
        )

    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # algo_kwargs, base_kwargs and tdm_kwargs alias dictionaries inside
    # variant, so the assignments below are visible through
    # variant['algo_kwargs'] when it is unpacked into TdmTd3.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['replay_buffer'] = replay_buffer
    base_kwargs = algo_kwargs['base_kwargs']
    base_kwargs['training_env'] = env
    base_kwargs['render'] = variant.get("render", False)
    base_kwargs['render_during_eval'] = variant.get("render_during_eval", False)
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algorithm = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        eval_policy=eval_policy,
        **variant['algo_kwargs']
    )

    if variant.get("test_ckpt", False):
        algorithm.post_epoch_funcs.append(get_update_networks_func(variant))

    vis_variant = variant.get('vis_kwargs', {})
    vis_list = vis_variant.get('vis_list', [])
    # Video saving defaults to on when 'save_video' is not specified.
    if vis_variant.get("save_video", True):
        # NOTE(review): decrement_tau and cycle_tau both receive
        # algorithm.cycle_taus_for_rollout — confirm this is intended.
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
            vis_list=vis_list,
            dont_terminate=True,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            variant,
        )
        algorithm.post_epoch_funcs.append(video_func)

    if ptu.gpu_enabled():
        print("using GPU")
        algorithm.cuda()
        # env.vae is only touched when this is not a state-based experiment.
        if not variant.get("do_state_exp", False):
            env.vae.cuda()

    env.reset()
    # Dump the environment's diagnostics before training starts (skipped for
    # state-based experiments).
    if not variant.get("do_state_exp", False):
        env.dump_samples(epoch=None)
        env.dump_reconstructions(epoch=None)
        env.dump_latent_plots(epoch=None)

    algorithm.train()
Пример #11
0
def tdm_twin_sac_experiment(variant):
    """Build and train a TDM twin-SAC agent on a dict-observation env.

    ``variant`` is preprocessed with ``preprocess_rl_variant`` and then used
    to create the environment, two Q-functions, a V-function, a stochastic
    policy and a relabeling replay buffer. ``variant['algo_kwargs']``,
    ``qf_kwargs``, ``vf_kwargs`` and ``replay_buffer_kwargs`` are modified in
    place. A video-saving hook is appended to ``post_train_funcs`` unless
    ``variant['save_video']`` is falsy.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer,
    )
    from railrl.state_distance.tdm_networks import (
        TdmQf,
        TdmVf,
        StochasticTdmPolicy,
    )
    from railrl.state_distance.tdm_twin_sac import TdmTwinSAC

    preprocess_rl_variant(variant)
    env = get_envs(variant)

    # The achieved-goal key mirrors the desired-goal key.
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Settings derived from the environment are copied into the kwargs of
    # the algorithm and of both value networks.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    for net_key in ('qf_kwargs', 'vf_kwargs'):
        variant[net_key]['vectorized'] = vectorized
    for net_key in ('qf_kwargs', 'vf_kwargs'):
        variant[net_key]['norm_order'] = norm_order

    qf1 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    qf2 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    vf = TdmVf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        **variant['vf_kwargs']
    )
    policy = StochasticTdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # These names alias dictionaries stored inside variant.
    trainer_kwargs = variant['algo_kwargs']
    trainer_kwargs['replay_buffer'] = buffer
    base_settings = trainer_kwargs['base_kwargs']
    base_settings['training_env'] = env
    # The single "render" flag controls rendering for both training and eval.
    base_settings['render'] = variant["render"]
    base_settings['render_during_eval'] = variant["render"]
    tdm_settings = trainer_kwargs['tdm_kwargs']
    tdm_settings['observation_key'] = observation_key
    tdm_settings['desired_goal_key'] = desired_goal_key

    algorithm = TdmTwinSAC(
        env,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        policy=policy,
        **trainer_kwargs
    )

    if variant.get("save_video", True):
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                algorithm.eval_policy,
                variant,
            )
        )

    algorithm.to(ptu.device)
    # env.vae is only moved when this is not a state-based experiment.
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    algorithm.train()
Пример #12
0
def tdm_td3_experiment_online_vae(variant):
    """Build and train TDM-TD3 jointly with online VAE training.

    ``variant`` is preprocessed with ``preprocess_rl_variant`` and then used
    to create the environment, two Q-functions, a policy, an online-VAE
    relabeling replay buffer and a ``ConvVAETrainer`` for ``env.vae``.
    ``variant['algo_kwargs']['tdm_td3_kwargs']`` is modified in place.

    ``variant['vae_trainer_kwargs']`` is optional; when it is missing or
    ``None`` the trainer is built with no extra keyword arguments.

    Raises:
        AssertionError: if ``'vae_training_schedule'`` is a top-level key of
            ``variant`` (it belongs in ``algo_kwargs``).
        KeyError: if a required key such as ``'render'`` is missing.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.data_management.online_vae_replay_buffer import \
        OnlineVaeRelabelingBuffer
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy)
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.torch.vae.vae_trainer import ConvVAETrainer
    from railrl.torch.online_vae.online_vae_tdm_td3 import OnlineVaeTdmTd3
    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)
    # Fall back to an empty dict: unpacking None with ** raises TypeError.
    vae_trainer_kwargs = variant.get('vae_trainer_kwargs') or {}
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")
    obs_dim = (env.observation_space.spaces[observation_key].low.size)
    goal_dim = (env.observation_space.spaces[desired_goal_key].low.size)
    action_dim = env.action_space.low.size

    # Vectorized mode is inferred from the environment's reward type name.
    vectorized = 'vectorized' in env.reward_type
    variant['algo_kwargs']['tdm_td3_kwargs']['tdm_kwargs'][
        'vectorized'] = vectorized

    # norm_order is handed to the Q-functions only; it is not written into
    # the algorithm's tdm_kwargs.
    norm_order = env.norm_order

    qf1 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env,
                vectorized=vectorized,
                norm_order=norm_order,
                observation_dim=obs_dim,
                goal_dim=goal_dim,
                action_dim=action_dim,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env,
                       observation_dim=obs_dim,
                       goal_dim=goal_dim,
                       action_dim=action_dim,
                       **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    vae = env.vae

    replay_buffer = OnlineVaeRelabelingBuffer(
        vae=vae,
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs'])
    # algo_kwargs aliases variant['algo_kwargs']['tdm_td3_kwargs'], so these
    # writes are picked up when that dict is unpacked below.
    algo_kwargs = variant['algo_kwargs']['tdm_td3_kwargs']
    td3_kwargs = algo_kwargs['td3_kwargs']
    td3_kwargs['training_env'] = env
    tdm_kwargs = algo_kwargs['tdm_kwargs']
    tdm_kwargs['observation_key'] = observation_key
    tdm_kwargs['desired_goal_key'] = desired_goal_key
    algo_kwargs["replay_buffer"] = replay_buffer

    t = ConvVAETrainer(variant['vae_train_data'],
                       variant['vae_test_data'],
                       vae,
                       beta=variant['online_vae_beta'],
                       **vae_trainer_kwargs)
    # The value is unused; the lookup is kept so that a variant without a
    # "render" key still fails here as before.
    render = variant["render"]
    assert 'vae_training_schedule' not in variant, "Just put it in algo_kwargs"
    algorithm = OnlineVaeTdmTd3(
        online_vae_kwargs=dict(vae=vae,
                               vae_trainer=t,
                               **variant['algo_kwargs']['online_vae_kwargs']),
        tdm_td3_kwargs=dict(env=env,
                            qf1=qf1,
                            qf2=qf2,
                            policy=policy,
                            exploration_policy=exploration_policy,
                            **variant['algo_kwargs']['tdm_td3_kwargs']),
    )

    # Move the algorithm and the VAE to the target device once; `vae` is the
    # same object as env.vae, so no second transfer is needed later.
    algorithm.to(ptu.device)
    vae.to(ptu.device)
    if variant.get("save_video", True):
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm._sample_max_tau_for_rollout(),
            decrement_tau=algorithm.cycle_taus_for_rollout,
            cycle_tau=algorithm.cycle_taus_for_rollout,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        video_func = get_video_save_func(
            rollout_function,
            env,
            algorithm.eval_policy,
            variant,
        )
        algorithm.post_train_funcs.append(video_func)

    algorithm.train()
Пример #13
0
def tdm_td3_experiment(variant):
    """Build and train a TDM-TD3 agent on a dict-observation environment.

    ``variant`` is preprocessed with ``preprocess_rl_variant`` and then used
    to create the environment, exploration strategy, two Q-functions, a
    policy and a relabeling replay buffer. ``variant['algo_kwargs']``,
    ``qf_kwargs`` and ``replay_buffer_kwargs`` are modified in place. A
    video-saving hook is appended to ``post_train_funcs`` unless
    ``variant['save_video']`` is falsy.
    """
    import railrl.samplers.rollout_functions as rf
    import railrl.torch.pytorch_util as ptu
    from railrl.core import logger
    from railrl.data_management.obs_dict_replay_buffer import (
        ObsDictRelabelingBuffer,
    )
    from railrl.exploration_strategies.base import (
        PolicyWrappedWithExplorationStrategy,
    )
    from railrl.state_distance.tdm_networks import TdmQf, TdmPolicy
    from railrl.state_distance.tdm_td3 import TdmTd3

    preprocess_rl_variant(variant)
    env = get_envs(variant)
    es = get_exploration_strategy(variant, env)

    # The achieved-goal key mirrors the desired-goal key.
    observation_key = variant.get('observation_key', 'latent_observation')
    desired_goal_key = variant.get('desired_goal_key', 'latent_desired_goal')
    achieved_goal_key = desired_goal_key.replace("desired", "achieved")

    obs_dim = env.observation_space.spaces[observation_key].low.size
    goal_dim = env.observation_space.spaces[desired_goal_key].low.size
    action_dim = env.action_space.low.size

    # Environment-derived settings are copied into the relevant kwargs.
    vectorized = 'vectorized' in env.reward_type
    norm_order = env.norm_order
    variant['algo_kwargs']['tdm_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['vectorized'] = vectorized
    variant['qf_kwargs']['norm_order'] = norm_order

    qf1 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    qf2 = TdmQf(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        action_dim=action_dim,
        **variant['policy_kwargs']
    )
    behavior_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    variant['replay_buffer_kwargs']['vectorized'] = vectorized
    buffer = ObsDictRelabelingBuffer(
        env=env,
        observation_key=observation_key,
        desired_goal_key=desired_goal_key,
        achieved_goal_key=achieved_goal_key,
        **variant['replay_buffer_kwargs']
    )

    # These names alias dictionaries stored inside variant.
    trainer_kwargs = variant['algo_kwargs']
    trainer_kwargs['replay_buffer'] = buffer
    base_settings = trainer_kwargs['base_kwargs']
    base_settings['training_env'] = env
    # The single "render" flag controls rendering for both training and eval.
    base_settings['render'] = variant["render"]
    base_settings['render_during_eval'] = variant["render"]
    tdm_settings = trainer_kwargs['tdm_kwargs']
    tdm_settings['observation_key'] = observation_key
    tdm_settings['desired_goal_key'] = desired_goal_key

    algorithm = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=behavior_policy,
        **trainer_kwargs
    )

    algorithm.to(ptu.device)
    # env.vae is only moved when this is not a state-based experiment.
    if not variant.get("do_state_exp", False):
        env.vae.to(ptu.device)

    if variant.get("save_video", True):
        # The snapshot directory is fetched but not used further here.
        logdir = logger.get_snapshot_dir()
        policy.train(False)
        rollout_function = rf.create_rollout_function(
            rf.tdm_rollout,
            init_tau=algorithm.max_tau,
            max_path_length=algorithm.max_path_length,
            observation_key=algorithm.observation_key,
            desired_goal_key=algorithm.desired_goal_key,
        )
        algorithm.post_train_funcs.append(
            get_video_save_func(
                rollout_function,
                env,
                policy,
                variant,
            )
        )

    algorithm.train()
Пример #14
0
def tdm_td3_experiment(variant):
    """Assemble a TDM-TD3 run from ``variant`` and start training.

    Keys read from ``variant``:
        env_class, env_kwargs, reward_params: used to construct the env
            (``reward_params`` is merged into ``env_kwargs`` in place).
        multiworld_env (optional, default True): selects the dict-observation
            code path; any value other than the literal ``True`` selects the
            wrapped, non-multiworld path.
        render: read only on the non-multiworld path.
        normalize: whether to wrap the env in ``NormalizedBoxEnv``.
        exploration_type: one of ``'ou'``, ``'gaussian'``, ``'epsilon'``.
        es_kwargs, qf_kwargs, policy_kwargs, replay_buffer_kwargs: extra
            keyword arguments forwarded to the respective constructors.
        algo_kwargs: forwarded to ``TdmTd3``; its ``tdm_kwargs`` and
            ``td3_kwargs`` sub-dicts are updated in place by this function.
        observation_key, desired_goal_key, achieved_goal_key (optional):
            dict-observation keys used on the multiworld path.
        tau_schedule_kwargs (optional): if present, used to build an
            ``IntPiecewiseLinearSchedule``.

    Raises:
        Exception: if ``exploration_type`` is not one of the supported names.
    """
    variant['env_kwargs'].update(variant['reward_params'])
    env = variant['env_class'](**variant['env_kwargs'])

    # Identity check on purpose: only the literal ``True`` selects the
    # multiworld path, truthy stand-ins such as ``1`` do not.
    is_multiworld = variant.get('multiworld_env', True) is True

    if not is_multiworld:
        env = MultitaskEnvToSilentMultitaskEnv(env)
        if variant["render"]:
            env.pause_on_goal = True

    if variant['normalize']:
        env = NormalizedBoxEnv(env)

    # Pick the strategy class and its fixed settings first, then build it
    # with a single constructor call shared by all three variants.
    strategy_name = variant['exploration_type']
    if strategy_name == 'ou':
        strategy_cls = OUStrategy
        fixed_es_kwargs = dict(max_sigma=0.1)
    elif strategy_name == 'gaussian':
        strategy_cls = GaussianStrategy
        # max_sigma == min_sigma keeps the noise scale constant.
        fixed_es_kwargs = dict(max_sigma=0.1, min_sigma=0.1)
    elif strategy_name == 'epsilon':
        strategy_cls = EpsilonGreedy
        fixed_es_kwargs = dict(prob_random_action=0.1)
    else:
        raise Exception("Invalid type: " + strategy_name)
    strategy = strategy_cls(
        action_space=env.action_space,
        **fixed_es_kwargs,
        **variant['es_kwargs']
    )

    # Explicit dimensions are only supplied on the multiworld path; otherwise
    # the networks receive ``None`` for all three.
    if is_multiworld:
        obs_spaces = env.observation_space.spaces
        obs_dim = obs_spaces['observation'].low.size
        action_dim = env.action_space.low.size
        goal_dim = obs_spaces['desired_goal'].low.size
    else:
        obs_dim = None
        action_dim = None
        goal_dim = None
    dim_kwargs = dict(
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
    )

    algo_kwargs = variant['algo_kwargs']
    tdm_kwargs = algo_kwargs['tdm_kwargs']

    vectorized = 'vectorized' in env.reward_type
    tdm_kwargs['vectorized'] = vectorized

    norm_order = env.norm_order
    tdm_kwargs['norm_order'] = norm_order

    def build_qf():
        # Both critics are constructed from exactly the same arguments.
        return TdmQf(
            env=env,
            **dim_kwargs,
            vectorized=vectorized,
            norm_order=norm_order,
            **variant['qf_kwargs']
        )

    qf1 = build_qf()
    qf2 = build_qf()
    policy = TdmPolicy(env=env, **dim_kwargs, **variant['policy_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=strategy,
        policy=policy,
    )

    # The replay buffer gets its own copy of the env, produced by a pickle
    # round trip of the (possibly wrapped) training env.
    relabeling_env = pickle.loads(pickle.dumps(env))

    if is_multiworld:
        key_kwargs = dict(
            observation_key=variant.get('observation_key',
                                        'state_observation'),
            desired_goal_key=variant.get('desired_goal_key',
                                         'state_desired_goal'),
            achieved_goal_key=variant.get('achieved_goal_key',
                                          'state_achieved_goal'),
        )
        replay_buffer = ObsDictRelabelingBuffer(
            env=relabeling_env,
            **key_kwargs,
            vectorized=vectorized,
            **variant['replay_buffer_kwargs']
        )
        tdm_kwargs['observation_key'] = key_kwargs['observation_key']
        tdm_kwargs['desired_goal_key'] = key_kwargs['desired_goal_key']
    else:
        replay_buffer = RelabelingReplayBuffer(
            env=relabeling_env,
            **variant['replay_buffer_kwargs']
        )

    # No qf_criterion is injected into td3_kwargs here; td3_kwargs is passed
    # on with whatever the variant already contains plus the training env.
    algo_kwargs['td3_kwargs']['training_env'] = env
    tdm_kwargs['epoch_max_tau_schedule'] = (
        IntPiecewiseLinearSchedule(**variant['tau_schedule_kwargs'])
        if 'tau_schedule_kwargs' in variant
        else None
    )

    algorithm = TdmTd3(
        env,
        qf1=qf1,
        qf2=qf2,
        policy=policy,
        exploration_policy=exploration_policy,
        replay_buffer=replay_buffer,
        **algo_kwargs
    )
    if ptu.gpu_enabled():
        for module in (qf1, qf2, policy, algorithm):
            module.to(ptu.device)
    algorithm.train()