Example #1
def get_selection_policy(key, params):
    """Map a configuration key to a configured action-selection policy.

    Each branch wires the policy with a `post_pipeline` of `hat` objects
    that post-process the model output before an action is picked;
    `params` is forwarded verbatim to the policy constructor.
    """
    if key == 'AdaptiveQSelection':
        return AdaptiveQActionSelectionEntropy(post_pipeline=[hat.EntropyHat()], **params)
    if key == 'QSelectionCertainty':
        return QActionSelectionCertainty(post_pipeline=[hat.EntropyHat()], **params)
    if key == 'QSelection':
        return QActionSelection(post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'DuelingQSelection':
        return QActionSelection(post_pipeline=[hat.IdxFilterHat(var_idx=1), hat.EnsembleHat()], **params)
    if key == 'EpsilonGreedy':
        return EpsilonGreedyActionSelection(post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'AdaptiveEpsilonGreedy':
        return AdaptiveEpsilonGreedyActionSelection(post_pipeline=[hat.EntropyHat()], **params)
    if key == 'Greedy':
        return GreedyValueSelection(post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'ThompsonGreedy':
        return EpsilonGreedyActionSelection(post_pipeline=[hat.ThompsonAggregation()], **params)
    raise ValueError(f'Unknown selection strategy: {key}')
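A minimal usage sketch (not part of the original snippet), assuming the selection-policy classes and the `hat` module above are importable from the surrounding project; the key and the empty parameter dict are illustrative.

# Hypothetical call site: `selection_args` is forwarded verbatim to the
# chosen policy's constructor, so its contents depend on that class.
selection_args = {}  # illustrative; e.g. exploration settings
policy = get_selection_policy('EpsilonGreedy', selection_args)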
Example #2
def get_selection_strategy(key, params):
    """Resolve a configuration key to an action-selection strategy from `sp`."""
    if key == 'AdaptiveQSelection':
        return sp.AdaptiveQActionSelectionEntropy(
            post_pipeline=[hat.EntropyHat()], **params)
    if key == 'QSelectionCertainty':
        return sp.QActionSelectionCertainty(
            post_pipeline=[hat.EntropyHat()], **params)
    if key == 'QSelection':
        return sp.QActionSelection(
            post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'EpsilonGreedy':
        return sp.EpsilonGreedyActionSelection(
            post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'AdaptiveEpsilonGreedy':
        return sp.AdaptiveEpsilonGreedyActionSelection(
            post_pipeline=[hat.EntropyHat()], **params)
    if key == 'Greedy':
        return sp.GreedyValueSelection(
            post_pipeline=[hat.EnsembleHat()], **params)
    raise ValueError(f'Unknown selection strategy: {key}')
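Since unsupported keys raise a ValueError, a caller can fall back to a default strategy; a hedged sketch, where `config` is an illustrative parameter dict rather than anything defined in the original code.

# Illustrative guard: fall back to greedy selection when the configured
# key is not one of the strategies dispatched above.
try:
    strategy = get_selection_strategy(config['selection_strategy'],
                                      config.get('selection_args', {}))
except ValueError:
    strategy = get_selection_strategy('Greedy', {})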
Example #3
def run(root, path_script):
    """Build the environment, memory and ensemble DQN from the experiment
    configuration under `root`, then train it."""
    experiment = Experiment(root=root)
    factory = experiment.get_factory()
    params = experiment.get_params()
    # callbacks_factory = experiment.get_factory(source_function='get_ens_callbacks')
    params['factory_args']['learner_args']['dump_path'] = root
    # params['factory_args']['dump_path'] = root

    Model = experiment.get_model_class()
    experiment.document_script(path_script, overwrite=params['overwrite'])
    env = MultiInstanceGym(**params['env_args'])
    # Size the model I/O to the environment's observation and action spaces.
    params['factory_args']['model_args']['in_nodes'] = env.observation_space.shape[0]
    params['factory_args']['model_args']['out_nodes'] = env.action_space.n
    params['factory_args']['env'] = env

    dqn_player = get_player(key=params['player_type'])
    selection_strategy = get_selection_strategy(
        params['selection_strategy'], params.get('selection_args', {}))

    with with_experiment(experiment=experiment, overwrite=params['overwrite']):
        memory = get_memory(params['memory_type'], params['memory_args'])
        params['factory_args']['learner_args']['memory'] = memory

        learner = DQNEnsemble(
            model_class=Model,
            trainer_factory=factory,
            memory=memory,
            env=env,
            player=dqn_player,
            selection_strategy=selection_strategy,
            trainer_args=params['factory_args'],
            n_model=params['n_learner'],
            dump_path=root,
            callbacks=[
                # Periodic memory refresh, uncertainty update and checkpointing.
                rcb.EpisodeUpdater(**params.get('memory_update', {}),
                                   frequency=5),
                rcb.UncertaintyUpdater(hat=hat.EntropyHat()),
                cb.Checkpointer(frequency=10),
                # Deterministic evaluation: greedy selection over the ensemble.
                rcb.EnvironmentEvaluator(
                    env=env,
                    n_evaluations=10,
                    action_selector=sp.GreedyValueSelection(
                        post_pipeline=[hat.EnsembleHat()]),
                    metrics={
                        'det_val_reward_mean': np.mean,
                        'det_val_reward_std': np.std
                    },
                    frequency=10,
                    epoch_name='det_val_epoch'),
                # Probabilistic evaluation: the configured eval selection strategy.
                rcb.EnvironmentEvaluator(
                    env=env,
                    n_evaluations=10,
                    action_selector=get_selection_strategy(
                        params['eval_selection_strategy'],
                        params.get('selection_args', {})),
                    metrics={
                        'prob_val_reward_mean': np.mean,
                        'prob_val_reward_std': np.std
                    },
                    frequency=10,
                    epoch_name='prob_val_epoch'),
                # Plot both evaluation reward curves in one figure.
                rcb.EnsembleRewardPlotter(
                    frequency=10,
                    metrics={
                        'det_val_reward_mean': 'det_val_epoch',
                        'prob_val_reward_mean': 'prob_val_epoch',
                    })
            ])
        learner.fit(**params['fit'])
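A minimal entry-point sketch for this script (not in the original); the root path is illustrative, and passing __file__ matches how `experiment.document_script` archives the training script alongside the run.

if __name__ == '__main__':
    # Hypothetical invocation: `root` is the experiment directory and
    # `path_script` the file that document_script will copy there.
    run(root='experiments/dqn_ensemble', path_script=__file__)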