コード例 #1
0
def experiment(variant):
    """Train a ``TdmSac`` agent on a normalized environment.

    Keys read from ``variant``: ``env_class``, ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']``), ``qf_params``, ``vf_params``,
    ``policy_params`` and ``her_replay_buffer_params``.
    """
    env = NormalizedBoxEnv(variant['env_class']())

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    sac_tdm_kwargs = variant['sac_tdm_kwargs']
    vectorized = sac_tdm_kwargs['tdm_kwargs']['vectorized']

    goal_dim = env.goal_dim
    # Vectorized networks emit one value per goal dimension, otherwise a
    # single scalar.
    value_size = goal_dim if vectorized else 1
    # Input width shared by the value function and the policy.
    # NOTE(review): the trailing "+ 1" is presumably the TDM horizon input
    # -- confirm against the network definitions.
    state_goal_size = obs_dim + goal_dim + 1

    qf = FlattenMlp(
        input_size=state_goal_size + action_dim,
        output_size=value_size,
        **variant['qf_params']
    )
    vf = FlattenMlp(
        input_size=state_goal_size,
        output_size=value_size,
        **variant['vf_params']
    )
    policy = TanhGaussianPolicy(
        obs_dim=state_goal_size,
        action_dim=action_dim,
        **variant['policy_params']
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_params']
    )

    algorithm = TdmSac(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        replay_buffer=replay_buffer,
        **sac_tdm_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #2
0
def experiment(variant):
    """Assemble a ``TdmDdpg`` agent from ``variant`` and train it.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is modified in place: the Q-function criterion
    is stored under ``ddpg_kwargs`` and the TDM normalizer under
    ``tdm_kwargs``.
    """
    env_class = variant['env_class']
    env = env_class(**variant['env_kwargs'])

    # No environment wrapping and no TDM normalization in this setup; the
    # normalizer slot is explicitly filled with None everywhere.
    tdm_normalizer = None

    qf = TdmQf(
        env=env,
        vectorized=True,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=OUStrategy(
            action_space=env.action_space,
            **variant['es_kwargs']
        ),
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )

    # Inject the objects built above into the algorithm configuration.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['ddpg_kwargs']['qf_criterion'] = variant['qf_criterion_class']()
    algo_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #3
0
def experiment(variant):
    """Build a ``GcmDdpg`` learner from ``variant`` and run its training.

    Keys read from ``variant``: ``env_class``, ``gcm_kwargs``,
    ``policy_kwargs``, ``her_replay_buffer_kwargs``, ``gcm_criterion_class``,
    ``gcm_criterion_kwargs`` and ``algo_kwargs``.

    ``variant['algo_kwargs']['base_kwargs']`` is modified in place: the replay
    buffer created here is stored under ``'replay_buffer'``.
    """
    env = NormalizedBoxEnv(variant['env_class']())

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    goal_dim = env.goal_dim

    gcm = FlattenMlp(
        input_size=goal_dim + obs_dim + action_dim + 1,
        output_size=goal_dim,
        **variant['gcm_kwargs']
    )
    policy = TanhMlpPolicy(
        input_size=obs_dim + goal_dim + 1,
        output_size=action_dim,
        **variant['policy_kwargs']
    )

    # Fixed exploration-noise settings (min and max sigma are equal).
    ou_settings = dict(theta=0.1, max_sigma=0.1, min_sigma=0.1)
    es = OUStrategy(action_space=env.action_space, **ou_settings)
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion_class = variant['gcm_criterion_class']
    gcm_criterion = criterion_class(**variant['gcm_criterion_kwargs'])

    # The replay buffer is handed to the algorithm through its base kwargs.
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['base_kwargs']['replay_buffer'] = replay_buffer

    algorithm = GcmDdpg(
        env,
        gcm=gcm,
        policy=policy,
        exploration_policy=exploration_policy,
        gcm_criterion=gcm_criterion,
        **algo_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #4
0
def experiment(variant):
    """Train ``TdmSac`` with network classes chosen by ``variant``.

    Keys read from ``variant``: ``env_class``, ``algo_params`` (including
    ``tdm_kwargs['vectorized']``), ``qf_class``, ``vf_class``,
    ``policy_class``, ``qf_params``, ``policy_params`` and
    ``her_replay_buffer_params``.
    """
    env = NormalizedBoxEnv(variant['env_class']())

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    algo_params = variant['algo_params']
    vectorized = algo_params['tdm_kwargs']['vectorized']
    qf_class, vf_class, policy_class = (
        variant['qf_class'],
        variant['vf_class'],
        variant['policy_class'],
    )

    goal_dim = env.goal_dim
    # One output per goal dimension when vectorized, else a scalar.
    value_size = goal_dim if vectorized else 1

    qf = qf_class(
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        output_size=value_size,
        **variant['qf_params']
    )
    # NOTE(review): the value function is configured from 'qf_params', not a
    # separate 'vf_params' entry -- confirm this is intentional.
    vf = vf_class(
        observation_dim=obs_dim,
        goal_dim=goal_dim,
        output_size=value_size,
        **variant['qf_params']
    )
    policy = policy_class(
        obs_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        **variant['policy_params']
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_params']
    )

    algorithm = TdmSac(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        replay_buffer=replay_buffer,
        **algo_params
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
コード例 #5
0
def experiment(variant):
    """Train a vectorized ``TdmSac`` agent on ``MultitaskPoint2DEnv``.

    Keys read from ``variant``: ``policy_kwargs``, ``qf_kwargs``,
    ``vf_kwargs`` and ``algo_params`` (its ``sac_kwargs``, ``tdm_kwargs``,
    ``base_kwargs`` and ``supervised_weight`` entries; the replay-buffer size
    comes from ``base_kwargs['replay_buffer_size']``).
    """
    env = NormalizedBoxEnv(MultitaskPoint2DEnv())
    vectorized = True

    policy = StochasticTdmPolicy(env=env, **variant['policy_kwargs'])
    qf = TdmQf(
        env=env,
        vectorized=vectorized,
        norm_order=2,
        **variant['qf_kwargs']
    )
    vf = TdmVf(env=env, vectorized=vectorized, **variant['vf_kwargs'])

    algo_params = variant['algo_params']
    buffer_capacity = algo_params['base_kwargs']['replay_buffer_size']
    replay_buffer = HerReplayBuffer(buffer_capacity, env)

    # The three kwarg groups are passed positionally, in this exact order.
    algorithm = TdmSac(
        env,
        qf,
        vf,
        algo_params['sac_kwargs'],
        algo_params['tdm_kwargs'],
        algo_params['base_kwargs'],
        supervised_weight=algo_params['supervised_weight'],
        policy=policy,
        replay_buffer=replay_buffer,
    )
    if ptu.gpu_enabled():
        algorithm.cuda()

    algorithm.train()
コード例 #6
0
def experiment(variant):
    """Train ``TdmN3dpg`` with a one-hot-tau Q-function and a 1-D plotter.

    Keys read from ``variant``: ``env_class``, ``algo_params`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``qf_params``, ``vf_params``, ``policy_params``,
    ``her_replay_buffer_params``, ``qf_criterion_class`` and
    ``qf_criterion_params``.

    ``variant['algo_params']['n3dpg_kwargs']`` is modified in place: the
    Q-function criterion and the plotter are stored in it.
    """
    env = variant['env_class']()

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    algo_params = variant['algo_params']
    vectorized = algo_params['tdm_kwargs']['vectorized']

    goal_dim = env.goal_dim
    # One output per goal dimension when vectorized, else a scalar.
    value_size = goal_dim if vectorized else 1
    state_goal_size = obs_dim + goal_dim + 1

    qf = OneHotTauQF(
        observation_dim=obs_dim,
        action_dim=action_dim,
        goal_dim=goal_dim,
        output_size=value_size,
        **variant['qf_params']
    )
    vf = FlattenMlp(
        input_size=state_goal_size,
        output_size=value_size,
        **variant['vf_params']
    )
    policy = MlpPolicy(
        input_size=state_goal_size,
        output_size=action_dim,
        **variant['policy_params']
    )

    # Fixed exploration-noise settings (min and max sigma are equal).
    ou_settings = dict(theta=0.1, max_sigma=0.1, min_sigma=0.1)
    es = OUStrategy(action_space=env.action_space, **ou_settings)
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )

    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_params']
    )
    criterion_class = variant['qf_criterion_class']
    qf_criterion = criterion_class(**variant['qf_criterion_params'])

    n3dpg_kwargs = algo_params['n3dpg_kwargs']
    n3dpg_kwargs['qf_criterion'] = qf_criterion
    # The plotter is given the Q-function plus three locations and goals.
    n3dpg_kwargs['plotter'] = Simple1DTdmPlotter(
        tdm=qf,
        location_lst=np.array([-5, 0, 5]),
        goal_lst=np.array([-5, 0, 5]),
        max_tau=algo_params['tdm_kwargs']['max_tau'],
        grid_size=10,
    )

    algorithm = TdmN3dpg(
        env,
        qf=qf,
        vf=vf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_params
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #7
0
def experiment(variant):
    """Train ``TdmDdpg`` with a discrete-action Q-function and a plotter.

    Keys read from ``variant``: ``env_class``, ``algo_params`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``qf_params``, ``policy_params`` (vectorized case only),
    ``her_replay_buffer_params``, ``qf_criterion_class`` and
    ``qf_criterion_params``.

    ``variant['algo_params']['ddpg_kwargs']`` is modified in place: the
    Q-function criterion and the plotter are stored in it.
    """
    env = variant['env_class']()

    # The unused obs_dim/action_dim locals were dropped. The latter read
    # `env.action_space.low`, while the rest of this function sizes the
    # action space via `.n`.
    algo_params = variant['algo_params']
    vectorized = algo_params['tdm_kwargs']['vectorized']
    if vectorized:
        qf = VectorizedDiscreteQFunction(
            observation_dim=int(np.prod(env.observation_space.low.shape)),
            action_dim=env.action_space.n,
            goal_dim=env.goal_dim,
            **variant['qf_params']
        )
        policy = ArgmaxDiscreteTdmPolicy(qf, **variant['policy_params'])
    else:
        # Flat network: one output per discrete action.
        qf = FlattenMlp(
            input_size=(int(np.prod(env.observation_space.shape))
                        + env.goal_dim + 1),
            output_size=env.action_space.n,
            **variant['qf_params']
        )
        policy = ArgmaxDiscretePolicy(qf)

    es = OUStrategy(
        action_space=env.action_space,
        theta=0.1,
        max_sigma=0.1,
        min_sigma=0.1,
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_params']
    )
    qf_criterion = variant['qf_criterion_class'](
        **variant['qf_criterion_params']
    )

    algo_params['ddpg_kwargs']['qf_criterion'] = qf_criterion
    plotter = Simple1DTdmDiscretePlotter(
        tdm=qf,
        location_lst=np.array([-5, 0, 5]),
        goal_lst=np.array([-5, 0, 5]),
        max_tau=algo_params['tdm_kwargs']['max_tau'],
        grid_size=10,
    )
    algo_params['ddpg_kwargs']['plotter'] = plotter

    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_params
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #8
0
def experiment(variant):
    """Assemble a twin-Q ``TdmTd3`` agent from ``variant`` and train it.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``algo_kwargs``.

    ``variant['algo_kwargs']`` is modified in place: the Q-function criterion
    is stored under ``td3_kwargs`` and the TDM normalizer under
    ``tdm_kwargs``.
    """
    env_class = variant['env_class']
    env = env_class(**variant['env_kwargs'])
    # No TDM normalization in this setup.
    tdm_normalizer = None

    # Two Q-functions built from the same configuration.
    qf1 = TdmQf(env=env, vectorized=True, tdm_normalizer=tdm_normalizer,
                **variant['qf_kwargs'])
    qf2 = TdmQf(env=env, vectorized=True, tdm_normalizer=tdm_normalizer,
                **variant['qf_kwargs'])
    policy = TdmPolicy(env=env, tdm_normalizer=tdm_normalizer,
                       **variant['policy_kwargs'])

    es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(env=env,
                                    **variant['her_replay_buffer_kwargs'])

    # Inject the objects built above into the algorithm configuration.
    qf_criterion = variant['qf_criterion_class']()
    algo_kwargs = variant['algo_kwargs']
    algo_kwargs['td3_kwargs']['qf_criterion'] = qf_criterion
    algo_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    algorithm = TdmTd3(env,
                       qf1=qf1,
                       qf2=qf2,
                       replay_buffer=replay_buffer,
                       policy=policy,
                       exploration_policy=exploration_policy,
                       **algo_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #9
0
def experiment(variant):
    """Train ``TdmDdpg`` on ``MultiTaskSawyerXYZReachingEnv``.

    Keys read from ``variant``: ``env_params``, ``ddpg_tdm_kwargs`` (including
    ``tdm_kwargs['max_tau']``), ``hidden_sizes``, ``es_kwargs``,
    ``her_replay_buffer_kwargs`` and ``qf_criterion_class``.

    ``variant`` itself is left untouched: the algorithm kwargs are
    deep-copied before the Q-function criterion is added to them.
    """
    env_params = variant['env_params']
    # NOTE(review): env_params is passed as one positional argument here --
    # confirm the constructor does not expect it unpacked as keywords.
    env = MultiTaskSawyerXYZReachingEnv(env_params)
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized=True,
        max_tau=variant['ddpg_tdm_kwargs']['tdm_kwargs']['max_tau'],
    )
    qf = TdmQf(
        env=env,
        vectorized=True,
        hidden_sizes=[variant['hidden_sizes'], variant['hidden_sizes']],
        structure='norm_difference',
        tdm_normalizer=tdm_normalizer,
    )
    policy = TdmPolicy(
        env=env,
        hidden_sizes=[variant['hidden_sizes'], variant['hidden_sizes']],
        tdm_normalizer=tdm_normalizer,
    )
    es = OUStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    qf_criterion = variant['qf_criterion_class']()
    # Work on a private copy so the caller's variant is not mutated.
    ddpg_tdm_kwargs = copy.deepcopy(variant['ddpg_tdm_kwargs'])
    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion
    # NOTE(review): the normalizer is not added to
    # ddpg_tdm_kwargs['tdm_kwargs'] here -- confirm whether TdmDdpg needs it.
    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        # Pass the copy that carries qf_criterion. The original code passed
        # variant['ddpg_tdm_kwargs'], which silently dropped the criterion.
        **ddpg_tdm_kwargs
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
コード例 #10
0
def experiment(variant):
    """Train ``TdmDdpg`` with a ``StructuredQF`` on a normalized environment.

    Keys read from ``variant``: ``env_class``, ``ddpg_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']``), ``qf_kwargs``, ``policy_kwargs``,
    ``es_kwargs``, ``her_replay_buffer_kwargs``, ``qf_criterion_class`` and
    ``qf_criterion_kwargs``.

    ``variant['ddpg_tdm_kwargs']['ddpg_kwargs']`` is modified in place: the
    Q-function criterion is stored under ``'qf_criterion'``.
    """
    env = NormalizedBoxEnv(variant['env_class']())

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    ddpg_tdm_kwargs = variant['ddpg_tdm_kwargs']
    vectorized = ddpg_tdm_kwargs['tdm_kwargs']['vectorized']
    goal_dim = env.goal_dim

    # One output per goal dimension when vectorized, else a scalar.
    qf = StructuredQF(observation_dim=obs_dim,
                      action_dim=action_dim,
                      goal_dim=goal_dim,
                      output_size=goal_dim if vectorized else 1,
                      **variant['qf_kwargs'])
    policy = TanhMlpPolicy(input_size=obs_dim + goal_dim + 1,
                           output_size=action_dim,
                           **variant['policy_kwargs'])
    es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(env=env,
                                    **variant['her_replay_buffer_kwargs'])

    criterion_class = variant['qf_criterion_class']
    qf_criterion = criterion_class(**variant['qf_criterion_kwargs'])
    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion

    algorithm = TdmDdpg(env,
                        qf=qf,
                        replay_buffer=replay_buffer,
                        policy=policy,
                        exploration_policy=exploration_policy,
                        **ddpg_tdm_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #11
0
def experiment(variant):
    """Train ``TdmSac`` on ``Reacher7DofFullGoal`` with an MPC controller.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']``), ``qf_params``, ``vf_params``,
    ``policy_params`` and ``her_replay_buffer_params``.

    ``variant['sac_tdm_kwargs']['base_kwargs']`` is modified in place: the
    same ``CollocationMpcController`` instance is stored as both
    ``'eval_policy'`` and ``'exploration_policy'``.
    """
    env = NormalizedBoxEnv(Reacher7DofFullGoal())

    obs_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))
    sac_tdm_kwargs = variant['sac_tdm_kwargs']
    vectorized = sac_tdm_kwargs['tdm_kwargs']['vectorized']

    goal_dim = env.goal_dim
    # One output per goal dimension when vectorized, else a scalar.
    value_size = goal_dim if vectorized else 1
    state_goal_size = obs_dim + goal_dim + 1

    qf = FlattenMlp(input_size=state_goal_size + action_dim,
                    output_size=value_size,
                    **variant['qf_params'])
    vf = FlattenMlp(input_size=state_goal_size,
                    output_size=value_size,
                    **variant['vf_params'])
    policy = TanhGaussianPolicy(obs_dim=state_goal_size,
                                action_dim=action_dim,
                                **variant['policy_params'])

    # The controller serves for evaluation and for exploration.
    mpc_controller = CollocationMpcController(env, qf, policy)
    base_kwargs = sac_tdm_kwargs['base_kwargs']
    base_kwargs['eval_policy'] = mpc_controller
    base_kwargs['exploration_policy'] = mpc_controller

    replay_buffer = HerReplayBuffer(env=env,
                                    **variant['her_replay_buffer_params'])
    algorithm = TdmSac(env=env,
                       policy=policy,
                       qf=qf,
                       vf=vf,
                       replay_buffer=replay_buffer,
                       **sac_tdm_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #12
0
def experiment(variant):
    """Train ``TdmSac`` with a shared ``TdmNormalizer``.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``env_class``, ``env_kwargs``, ``tdm_normalizer_kwargs``, ``qf_kwargs``,
    ``vf_kwargs``, ``policy_kwargs`` and ``her_replay_buffer_kwargs``.
    """
    sac_tdm_kwargs = variant['sac_tdm_kwargs']
    vectorized = sac_tdm_kwargs['tdm_kwargs']['vectorized']
    env_class = variant['env_class']
    env = NormalizedBoxEnv(env_class(**variant['env_kwargs']))
    max_tau = sac_tdm_kwargs['tdm_kwargs']['max_tau']

    # One normalizer instance is shared by the Q-function, the value
    # function and the policy.
    tdm_normalizer = TdmNormalizer(env,
                                   vectorized,
                                   max_tau=max_tau,
                                   **variant['tdm_normalizer_kwargs'])
    qf = TdmQf(env=env,
               vectorized=vectorized,
               tdm_normalizer=tdm_normalizer,
               **variant['qf_kwargs'])
    vf = TdmVf(env=env,
               vectorized=vectorized,
               tdm_normalizer=tdm_normalizer,
               **variant['vf_kwargs'])
    policy = StochasticTdmPolicy(env=env,
                                 tdm_normalizer=tdm_normalizer,
                                 **variant['policy_kwargs'])
    replay_buffer = HerReplayBuffer(env=env,
                                    **variant['her_replay_buffer_kwargs'])

    algorithm = TdmSac(env=env,
                       policy=policy,
                       qf=qf,
                       vf=vf,
                       replay_buffer=replay_buffer,
                       **sac_tdm_kwargs)
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #13
0
def experiment(variant):
    """Train ``TdmDdpg`` with top-level ``vectorized``/``norm_order`` flags.

    Keys read from ``variant``: ``vectorized``, ``norm_order``,
    ``ddpg_tdm_kwargs`` (including ``tdm_kwargs['max_tau']``), ``env_class``,
    ``env_kwargs``, ``tdm_normalizer_kwargs``, ``qf_kwargs``,
    ``policy_kwargs``, ``es_kwargs``, ``her_replay_buffer_kwargs``,
    ``qf_criterion_class`` and ``qf_criterion_kwargs``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place: ``vectorized``,
    ``norm_order`` and the normalizer are written into ``tdm_kwargs`` and the
    Q-function criterion into ``ddpg_kwargs``.
    """
    vectorized = variant['vectorized']
    norm_order = variant['norm_order']

    # Copy the top-level switches into the TDM kwargs given to the algorithm.
    ddpg_tdm_kwargs = variant['ddpg_tdm_kwargs']
    tdm_kwargs = ddpg_tdm_kwargs['tdm_kwargs']
    tdm_kwargs['vectorized'] = vectorized
    tdm_kwargs['norm_order'] = norm_order

    env_class = variant['env_class']
    env = NormalizedBoxEnv(env_class(**variant['env_kwargs']))
    max_tau = tdm_kwargs['max_tau']
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized,
        max_tau=max_tau,
        **variant['tdm_normalizer_kwargs']
    )
    qf = TdmQf(
        env=env,
        vectorized=vectorized,
        norm_order=norm_order,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )
    es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion_class = variant['qf_criterion_class']
    qf_criterion = criterion_class(**variant['qf_criterion_kwargs'])

    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion
    ddpg_tdm_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **ddpg_tdm_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #14
0
ファイル: her.py プロジェクト: Asap7772/rail-rl-franka-eval
def experiment(variant):
    """Train a ``HER`` agent with fixed per-quantity normalizers.

    Keys read from ``variant``: ``env_class``, ``env_kwargs``,
    ``tdm_normalizer_kwargs``, ``qf_kwargs``, ``policy_kwargs``,
    ``es_kwargs``, ``her_replay_buffer_kwargs``, ``qf_criterion_class``,
    ``qf_criterion_kwargs`` and ``ddpg_tdm_kwargs``.

    ``variant['ddpg_tdm_kwargs']['ddpg_kwargs']`` is modified in place: the
    Q-function criterion is stored under ``'qf_criterion'``.
    """
    env_class = variant['env_class']
    env = NormalizedBoxEnv(env_class(**variant['env_kwargs']))
    observation_dim = int(np.prod(env.observation_space.low.shape))
    action_dim = int(np.prod(env.action_space.low.shape))

    # One fixed normalizer per quantity; the goal and distance normalizers
    # are both sized by env.goal_dim.
    obs_normalizer = TorchFixedNormalizer(observation_dim)
    goal_normalizer = TorchFixedNormalizer(env.goal_dim)
    action_normalizer = TorchFixedNormalizer(action_dim)
    distance_normalizer = TorchFixedNormalizer(env.goal_dim)
    tdm_normalizer = TdmNormalizer(
        env,
        obs_normalizer=obs_normalizer,
        goal_normalizer=goal_normalizer,
        action_normalizer=action_normalizer,
        distance_normalizer=distance_normalizer,
        max_tau=1,
        **variant['tdm_normalizer_kwargs']
    )

    # Only the policy receives the normalizer; the Q-function does not.
    qf = HerQFunction(env=env, **variant['qf_kwargs'])
    policy = HerPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )
    es = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=es,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )
    criterion_class = variant['qf_criterion_class']
    qf_criterion = criterion_class(**variant['qf_criterion_kwargs'])

    algo_kwargs = variant['ddpg_tdm_kwargs']
    algo_kwargs['ddpg_kwargs']['qf_criterion'] = qf_criterion

    algorithm = HER(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **algo_kwargs
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #15
0
def experiment(variant):
    """Train ``TdmDdpg`` on ``MultiTaskSawyerXYZReachingEnv``.

    Keys read from ``variant``: ``env_params``, ``ddpg_tdm_kwargs`` (including
    ``tdm_kwargs['max_tau']``), ``qf_kwargs``, ``policy_kwargs``,
    ``es_kwargs``, ``her_replay_buffer_kwargs`` and ``qf_criterion_class``.

    ``variant['ddpg_tdm_kwargs']`` is modified in place: the Q-function
    criterion is stored under ``ddpg_kwargs`` and the normalizer under
    ``tdm_kwargs``.
    """
    env = MultiTaskSawyerXYZReachingEnv(**variant['env_params'])

    ddpg_tdm_kwargs = variant['ddpg_tdm_kwargs']
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized=True,
        max_tau=ddpg_tdm_kwargs['tdm_kwargs']['max_tau'],
    )
    qf = TdmQf(
        env=env,
        vectorized=True,
        norm_order=2,
        tdm_normalizer=tdm_normalizer,
        **variant['qf_kwargs']
    )
    policy = TdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )
    noise = OUStrategy(action_space=env.action_space, **variant['es_kwargs'])
    exploration_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy=noise,
        policy=policy,
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )

    # Inject the objects built above into the algorithm configuration.
    ddpg_tdm_kwargs['ddpg_kwargs']['qf_criterion'] = (
        variant['qf_criterion_class']()
    )
    ddpg_tdm_kwargs['tdm_kwargs']['tdm_normalizer'] = tdm_normalizer

    algorithm = TdmDdpg(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        exploration_policy=exploration_policy,
        **ddpg_tdm_kwargs
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
コード例 #16
0
def experiment(variant):
    """Train a ``TdmDqn`` agent on a discrete-action environment.

    Keys read from ``variant``: ``env_class``, ``algo_params`` (including
    ``tdm_kwargs['vectorized']``), ``qf_params``, ``policy_params`` and
    ``her_replay_buffer_params``.
    """
    env = variant['env_class']()

    algo_params = variant['algo_params']
    num_actions_of = env.action_space
    if algo_params['tdm_kwargs']['vectorized']:
        qf = VectorizedDiscreteQFunction(
            observation_dim=int(np.prod(env.observation_space.low.shape)),
            action_dim=num_actions_of.n,
            goal_dim=env.goal_dim,
            **variant['qf_params']
        )
    else:
        # Flat network: one output per discrete action.
        flat_input = env.observation_space.low.size + env.goal_dim + 1
        qf = FlattenMlp(
            input_size=flat_input,
            output_size=num_actions_of.n,
            **variant['qf_params']
        )

    # The same policy class wraps either kind of Q-function.
    policy = ArgmaxDiscreteTdmPolicy(qf, **variant['policy_params'])
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_params']
    )
    algorithm = TdmDqn(
        env,
        qf=qf,
        replay_buffer=replay_buffer,
        policy=policy,
        **algo_params
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #17
0
def experiment(variant):
    """Train ``TdmSac``, optionally exploring/evaluating with MPC controllers.

    Keys read from ``variant``: ``sac_tdm_kwargs`` (including
    ``tdm_kwargs['vectorized']`` and ``tdm_kwargs['max_tau']``),
    ``env_class``, ``env_kwargs``, ``qf_kwargs``, ``tdm_normalizer_kwargs``,
    ``vf_kwargs``, ``policy_kwargs``, ``her_replay_buffer_kwargs``,
    ``mpc_controller_kwargs``, ``state_only_mpc_controller_kwargs``,
    ``es_kwargs``, ``explore_with`` and ``eval_with``.

    When ``explore_with`` / ``eval_with`` equals ``'TdmLBfgsBCMC'`` or
    ``'TdmLBfgsBStateOnlyCMC'``, the matching controller is written in place
    into ``variant['sac_tdm_kwargs']['base_kwargs']`` (noise-wrapped for
    exploration, bare for evaluation). Any other value leaves it untouched.
    """
    vectorized = variant['sac_tdm_kwargs']['tdm_kwargs']['vectorized']
    env_class = variant['env_class']
    env = NormalizedBoxEnv(env_class(**variant['env_kwargs']))
    max_tau = variant['sac_tdm_kwargs']['tdm_kwargs']['max_tau']

    # The Q-function is built without the normalizer; only the value
    # function and the policy receive it.
    qf = TdmQf(env, vectorized=vectorized, **variant['qf_kwargs'])
    tdm_normalizer = TdmNormalizer(
        env,
        vectorized,
        max_tau=max_tau,
        **variant['tdm_normalizer_kwargs']
    )
    implicit_model = TdmToImplicitModel(env, qf, tau=0)
    vf = TdmVf(
        env=env,
        vectorized=vectorized,
        tdm_normalizer=tdm_normalizer,
        **variant['vf_kwargs']
    )
    policy = StochasticTdmPolicy(
        env=env,
        tdm_normalizer=tdm_normalizer,
        **variant['policy_kwargs']
    )
    replay_buffer = HerReplayBuffer(
        env=env,
        **variant['her_replay_buffer_kwargs']
    )

    # Both controllers are always constructed, whichever one gets selected.
    goal_slice = env.ob_to_goal_slice
    lbfgs_mpc_controller = TdmLBfgsBCMC(
        implicit_model,
        env,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['mpc_controller_kwargs']
    )
    state_only_mpc_controller = TdmLBfgsBStateOnlyCMC(
        vf,
        policy,
        env,
        goal_slice=goal_slice,
        multitask_goal_slice=goal_slice,
        **variant['state_only_mpc_controller_kwargs']
    )
    es = GaussianStrategy(
        action_space=env.action_space,
        **variant['es_kwargs']
    )

    controllers_by_name = (
        ('TdmLBfgsBCMC', lbfgs_mpc_controller),
        ('TdmLBfgsBStateOnlyCMC', state_only_mpc_controller),
    )

    # Exploration: wrap the first controller whose name matches with noise.
    for name, controller in controllers_by_name:
        if variant['explore_with'] == name:
            exploration_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=es,
                policy=controller,
            )
            variant['sac_tdm_kwargs']['base_kwargs']['exploration_policy'] = (
                exploration_policy)
            break

    # Evaluation: use the first controller whose name matches, unwrapped.
    for name, controller in controllers_by_name:
        if variant['eval_with'] == name:
            variant['sac_tdm_kwargs']['base_kwargs']['eval_policy'] = (
                controller)
            break

    algorithm = TdmSac(
        env=env,
        policy=policy,
        qf=qf,
        vf=vf,
        replay_buffer=replay_buffer,
        **variant['sac_tdm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()