コード例 #1
0
def experiment(variant):
    """Assemble a Double-DQN training run on ``ToolsEnv`` and execute it.

    ``variant`` is read as a nested mapping:

    * ``variant['env_kwargs']`` is unpacked into each of the three
      ``ToolsEnv`` constructors (exploration, evaluation, rollout).
    * ``variant['algo_kwargs']`` is passed whole to ``gen_network`` and also
      supplies ``layer_size``, ``trainer_kwargs``, ``replay_buffer_size``
      and ``algorithm_kwargs``.

    The run is treated as a "lifetime" run when ``env_kwargs`` either omits
    ``time_horizon`` or sets it to 0. In that mode the evaluation collector
    uses the exploration policy instead of the greedy one.
    """
    # The imported name is never used below; presumably the import is kept
    # for its side effects -- TODO confirm before removing.
    from rlkit.envs.gym_minigrid.gym_minigrid import envs

    env_cfg = variant['env_kwargs']
    expl_env = ToolsEnv(**env_cfg)
    eval_env = ToolsEnv(**env_cfg)

    rollout_env = ToolsEnv(**env_cfg)

    # obs_dim is not consumed anywhere in this function; the read is kept so
    # the attribute access still happens exactly as before.
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.n
    algo_cfg = variant['algo_kwargs']
    layer_size = algo_cfg['layer_size']

    lifetime = env_cfg.get('time_horizon', 0) == 0
    if lifetime:
        assert eval_env.time_horizon == 0, 'cannot have time horizon for lifetime env'

    # Online and target Q-networks are built by two identical calls.
    qf = gen_network(algo_cfg, action_dim, layer_size)
    target_qf = gen_network(algo_cfg, action_dim, layer_size)
    qf_criterion = nn.MSELoss()

    greedy_policy = ArgmaxDiscretePolicy(qf)
    exploration_strategy = EpsilonGreedyDecay(expl_env.action_space, 1e-5, 1, 0.1)
    expl_policy = PolicyWrappedWithExplorationStrategy(
        exploration_strategy,
        greedy_policy,
    )
    # Lifetime runs evaluate with the same policy they explore with.
    eval_policy = expl_policy if lifetime else greedy_policy

    # Only the branch that is taken resolves its class name.
    if eval_env.time_horizon == 0:
        if lifetime:
            collector_cls = LifetimeMdpPathCollector
        else:
            collector_cls = MdpPathCollector
    else:
        collector_cls = MdpPathCollectorConfig
    eval_path_collector = collector_cls(eval_env, eval_policy)
    expl_path_collector = collector_cls(expl_env, expl_policy)

    trainer = DoubleDQNTrainer(
        qf=qf,
        target_qf=target_qf,
        qf_criterion=qf_criterion,
        **algo_cfg['trainer_kwargs']
    )
    replay_buffer = EnvReplayBuffer(algo_cfg['replay_buffer_size'], expl_env)

    # The human-input lifetime algorithm is used regardless of ``lifetime``.
    algorithm = TorchHumanInputLifetimeRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        rollout_env=rollout_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        **algo_cfg['algorithm_kwargs']
    )
    algorithm.to(ptu.device)
    algorithm.train()
コード例 #2
0
def gen_validation_envs(n, filename, **kwargs):
    """Create *n* randomly seeded ``ToolsEnv`` instances and pickle them.

    The pickle written to *filename* holds a dict with two keys: ``'envs'``
    (the list of constructed environments) and ``'seeds'`` (the list of
    integer seeds drawn for them, in the same order).

    Args:
        n: Number of environments to generate.
        filename: Destination path of the pickle file.
        **kwargs: Overrides applied on top of the default ``ToolsEnv``
            constructor kwargs below. They are applied last, so they also
            win over the per-env ``seed``.
    """
    # Seeds come from NumPy's global RNG, one per environment.
    seeds = np.random.randint(0, 100000, n).tolist()
    envs = []
    for seed in seeds:
        env_kwargs = dict(
            grid_size=8,
            # start agent at random pos
            agent_start_pos=None,
            health_cap=1000,
            gen_resources=True,
            fully_observed=False,
            task='make axe',
            make_rtype='sparse',
            fixed_reset=False,
            only_partial_obs=True,
            init_resources={
                'metal': 3,
                'wood': 3
            },
            resource_prob={
                'metal': 0.02,
                'wood': 0.02
            },
            fixed_expected_resources=True,
            end_on_task_completion=True,
            time_horizon=200,
            seed=seed)
        env_kwargs.update(kwargs)
        envs.append(ToolsEnv(**env_kwargs))

    # Context manager guarantees the handle is flushed and closed even if
    # pickling raises part-way through.
    with open(filename, 'wb') as pkl_file:
        pickle.dump({'envs': envs, 'seeds': seeds}, pkl_file)
    print('Generated %d envs at file: %s' % (n, filename))
コード例 #3
0
def gen_validation_envs(n, filename, **kwargs):
    """Create *n* randomly seeded ``ToolsEnv`` instances and save them.

    Two files are written:

    * a pickle at *filename* holding ``{'envs': [...], 'seeds': [...]}``;
    * a JSON file beside it (the same path with a trailing ``.pkl`` replaced
      by ``.json``; if *filename* does not end in ``.pkl`` the ``.json`` is
      simply appended) recording the constructor kwargs of the last
      environment built.

    When ``n`` is 0 no environment is built, so the JSON records the default
    kwargs plus overrides, without a ``seed`` entry unless one is supplied
    through ``**kwargs``.

    Args:
        n: Number of environments to generate.
        filename: Destination path of the pickle file.
        **kwargs: Overrides applied on top of the default ``ToolsEnv``
            constructor kwargs. They are applied last, so they also win over
            the per-env ``seed``.
    """

    def default_env_kwargs():
        # Rebuilt on every call so each env receives its own nested dicts,
        # exactly as when the literal was evaluated inside the loop.
        return dict(
            grid_size=8,
            # start agent at random pos
            agent_start_pos=None,
            health_cap=1000,
            gen_resources=False,
            fully_observed=False,
            task='make axe',
            make_rtype='sparse',
            fixed_reset=False,
            only_partial_obs=True,
            init_resources={
                'metal': 1,
                'wood': 1
            },
            resource_prob={
                'metal': 0.0,
                'wood': 0.0
            },
            fixed_expected_resources=True,
            end_on_task_completion=True,
            time_horizon=100)

    # Seeds come from NumPy's global RNG, one per environment.
    seeds = np.random.randint(0, 100000, n).tolist()
    envs = []

    # Fallback config for n == 0, so the JSON dump below always has
    # something to describe instead of hitting an unbound local.
    env_kwargs = default_env_kwargs()
    env_kwargs.update(kwargs)

    for seed in seeds:
        env_kwargs = default_env_kwargs()
        env_kwargs['seed'] = seed
        env_kwargs.update(kwargs)
        envs.append(ToolsEnv(**env_kwargs))

    with open(filename, 'wb') as pkl_file:
        pickle.dump({'envs': envs, 'seeds': seeds}, pkl_file)

    # str.strip('.pkl') would remove any of the characters '.', 'p', 'k', 'l'
    # from BOTH ends of the path; only the literal suffix must be dropped.
    suffix = '.pkl'
    if filename.endswith(suffix):
        base_path = filename[:-len(suffix)]
    else:
        base_path = filename
    with open(base_path + '.json', 'w') as json_file:
        json.dump(env_kwargs, json_file, indent=4, sort_keys=True)

    print('Generated %d envs at file: %s' % (n, filename))