示例#1
0
def get_ens_callbacks(params, env):
    """Assemble the training callbacks for an ensemble learner.

    Args:
        params: Configuration mapping. ``'eval_selection_strategy'`` is read
            with plain indexing; ``'memory_update'`` and ``'selection_args'``
            are optional and fall back to an empty dict.
        env: Environment passed to both ``EnvironmentEvaluator`` callbacks.

    Returns:
        A list holding, in order: the episode updater, the checkpointer, the
        uncertainty updater, the greedy ("det") evaluator, the evaluator
        driven by the configured strategy ("prob"), and the reward plotter.
    """
    episode_updater = rcb.EpisodeUpdater(**params.get('memory_update', {}))
    checkpointer = cb.Checkpointer(frequency=10)
    uncertainty_updater = rcb.UncertaintyUpdater(head=hat.EntropyHat())

    # Evaluation with greedy action selection on the ensemble output.
    greedy_selector = sp.GreedyValueSelection(
        post_pipeline=[hat.EnsembleHat()])
    # NOTE(review): the std key uses the prefix 'deter_' while the mean uses
    # 'det_'. Kept verbatim because consumers may rely on the exact name.
    greedy_metrics = {
        'det_val_reward_mean': np.mean,
        'deter_val_reward_std': np.std,
    }
    greedy_evaluator = rcb.EnvironmentEvaluator(
        env=env,
        n_evaluations=10,
        action_selector=greedy_selector,
        metrics=greedy_metrics,
        frequency=10,
        epoch_name='det_val_epoch',
    )

    # Evaluation with the selection strategy named in the configuration.
    configured_selector = get_selection_strategy(
        params['eval_selection_strategy'],
        params.get('selection_args', {}),
    )
    configured_metrics = {
        'prob_val_reward_mean': np.mean,
        'prob_val_reward_std': np.std,
    }
    configured_evaluator = rcb.EnvironmentEvaluator(
        env=env,
        n_evaluations=10,
        action_selector=configured_selector,
        metrics=configured_metrics,
        frequency=10,
        epoch_name='prob_val_epoch',
    )

    # Maps each plotted metric to the epoch counter it is recorded against.
    reward_plotter = rcb.EnsembleRewardPlotter(
        frequency=10,
        metrics={
            'det_val_reward_mean': 'det_val_epoch',
            'prob_val_reward_mean': 'prob_val_epoch',
        },
    )

    return [
        episode_updater,
        checkpointer,
        uncertainty_updater,
        greedy_evaluator,
        configured_evaluator,
        reward_plotter,
    ]
示例#2
0
def run(root, path_script):
    """Configure an experiment stored under ``root`` and fit a ``DQNEnsemble``.

    Args:
        root: Experiment root. It is passed to ``Experiment`` and also used
            as the dump path of the learner.
        path_script: Script path forwarded to ``Experiment.document_script``.
    """
    experiment = Experiment(root=root)
    factory = experiment.get_factory()
    params = experiment.get_params()

    # All trainer-related settings live in this sub-dict; it is mutated in
    # place below and finally handed to the learner as ``trainer_args``.
    factory_args = params['factory_args']
    factory_args['learner_args']['dump_path'] = root

    Model = experiment.get_model_class()
    overwrite = params['overwrite']
    experiment.document_script(path_script, overwrite=overwrite)

    env = MultiInstanceGym(**params['env_args'])

    # Size the network from the environment's observation and action spaces.
    n_inputs = env.observation_space.shape[0]
    model_args = factory_args['model_args']
    model_args['in_nodes'] = n_inputs
    model_args['out_nodes'] = env.action_space.n
    factory_args['env'] = env

    dqn_player = get_player(key=params['player_type'])
    selection_strategy = get_selection_strategy(
        params['selection_strategy'],
        params.get('selection_args', {}),
    )

    with with_experiment(experiment=experiment, overwrite=overwrite):
        memory = get_memory(params['memory_type'], params['memory_args'])
        factory_args['learner_args']['memory'] = memory

        # Looked up before the callbacks are built so a missing key surfaces
        # at the same point as it would inside the learner call.
        n_model = params['n_learner']

        episode_updater = rcb.EpisodeUpdater(
            **params.get('memory_update', {}), frequency=5)
        uncertainty_updater = rcb.UncertaintyUpdater(hat=hat.EntropyHat())
        checkpointer = cb.Checkpointer(frequency=10)

        # Evaluation with greedy action selection on the ensemble output.
        # NOTE(review): the std key uses the prefix 'deter_' while the mean
        # uses 'det_'. Kept verbatim because consumers may rely on it.
        greedy_evaluator = rcb.EnvironmentEvaluator(
            env=env,
            n_evaluations=10,
            action_selector=sp.GreedyValueSelection(
                post_pipeline=[hat.EnsembleHat()]),
            metrics={
                'det_val_reward_mean': np.mean,
                'deter_val_reward_std': np.std,
            },
            frequency=10,
            epoch_name='det_val_epoch',
        )

        # Evaluation with the strategy named by 'eval_selection_strategy'.
        configured_evaluator = rcb.EnvironmentEvaluator(
            env=env,
            n_evaluations=10,
            action_selector=get_selection_strategy(
                params['eval_selection_strategy'],
                params.get('selection_args', {}),
            ),
            metrics={
                'prob_val_reward_mean': np.mean,
                'prob_val_reward_std': np.std,
            },
            frequency=10,
            epoch_name='prob_val_epoch',
        )

        # Maps each plotted metric to the epoch counter it is recorded with.
        reward_plotter = rcb.EnsembleRewardPlotter(
            frequency=10,
            metrics={
                'det_val_reward_mean': 'det_val_epoch',
                'prob_val_reward_mean': 'prob_val_epoch',
            },
        )

        learner = DQNEnsemble(
            model_class=Model,
            trainer_factory=factory,
            memory=memory,
            env=env,
            player=dqn_player,
            selection_strategy=selection_strategy,
            trainer_args=factory_args,
            n_model=n_model,
            dump_path=root,
            callbacks=[
                episode_updater,
                uncertainty_updater,
                checkpointer,
                greedy_evaluator,
                configured_evaluator,
                reward_plotter,
            ],
        )
        learner.fit(**params['fit'])
示例#3
0
    post_pipeline=[EnsembleHatStd()])

# Build the memory from the 'memory_args' entry of the factory configuration.
memory = PriorityMemory(**params["factory_args"]['memory_args'])
# Store the same instance in learner_args as well; it is also passed directly
# to DQNEnsemble through ``memory=`` in the call that follows.
# NOTE(review): presumably the trainer factory reads learner_args['memory'];
# confirm against the factory implementation.
params['factory_args']['learner_args']['memory'] = memory

learner = DQNEnsemble(
    model_class=Model,
    trainer_factory=factory,
    memory=memory,
    env=env,
    player=dqn_player,
    selection_strategy=selection_strategy,
    trainer_args=params['factory_args'],
    n_model=params['n_learner'],
    callbacks=[
        rcb.EpisodeUpdater(**params.get('memory_update', {})),
        cb.Checkpointer(frequency=1),
        rcb.UncertaintyUpdater(),
        rcb.EnvironmentEvaluator(
            env=TorchGym(**params['factory_args']['env_args']),
            n_evaluations=10,
            action_selector=sp.GreedyValueSelection(
                post_pipeline=[EnsembleHat()]),
            metrics={
                'det_val_reward_mean': np.mean,
                'deter_val_reward_std': np.std
            },
            frequency=1,
            epoch_name='det_val_epoch'),
        rcb.EnvironmentEvaluator(
            env=TorchGym(**params['factory_args']['env_args']),
示例#4
0
def run(root, path_script):
    """Configure an experiment stored under ``root`` and fit a ``DQNEnsemble``.

    The training policy is epsilon-greedy on the ensemble output. Two
    evaluators run after every epoch: one with greedy selection ("det") and
    one that reuses the training selection strategy ("prob").

    Args:
        root: Experiment root. It is passed to ``Experiment`` and stored as
            the learner's dump path in ``learner_args``.
        path_script: Script path forwarded to ``Experiment.document_script``.
    """
    print(root, path_script)
    experiment = Experiment(root=root)
    factory = experiment.get_factory()
    params = experiment.get_params()

    # All trainer-related settings live in this sub-dict; it is mutated in
    # place below and finally handed to the learner as ``trainer_args``.
    factory_args = params['factory_args']
    factory_args['learner_args']['dump_path'] = root

    Model = experiment.get_model_class()
    overwrite = params['overwrite']
    experiment.document_script(path_script, overwrite=overwrite)

    env_name = factory_args['env_args']['env_name']
    env = TorchGym(env_name)

    # Size the network from the environment's observation and action spaces.
    n_actions = env.action_space.n
    model_args = factory_args['model_args']
    model_args['in_nodes'] = env.observation_space.shape[0]
    model_args['out_nodes'] = n_actions

    dqn_player = DQNPlayer()
    # Derive the action set from the environment instead of hard-coding
    # [0, 1, 2, 3], so random exploration always covers exactly the actions
    # the model has outputs for (identical for 4-action environments).
    selection_strategy = sp.EpsilonGreedyActionSelection(
        action_space=list(range(n_actions)),
        post_pipeline=[EnsembleHat()],
    )

    with with_experiment(experiment=experiment, overwrite=overwrite):
        memory = Memory(**factory_args['memory_args'])
        factory_args['learner_args']['memory'] = memory

        n_model = params['n_learner']

        episode_updater = rcb.EpisodeUpdater(
            **params.get('memory_update', {}))
        checkpointer = cb.Checkpointer(frequency=1)

        # Each evaluator gets its own TorchGym instance rather than sharing
        # the training environment object.
        # NOTE(review): the std key uses the prefix 'deter_' while the mean
        # uses 'det_'. Kept verbatim because consumers may rely on it.
        greedy_evaluator = rcb.EnvironmentEvaluator(
            env=TorchGym(env_name),
            n_evaluations=10,
            action_selector=sp.GreedyValueSelection(
                post_pipeline=[EnsembleHat()]),
            metrics={
                'det_val_reward_mean': np.mean,
                'deter_val_reward_std': np.std,
            },
            frequency=1,
            epoch_name='det_val_epoch',
        )
        # NOTE(review): this evaluator shares the selection strategy object
        # with training; confirm that any internal state it keeps is meant
        # to be shared.
        training_policy_evaluator = rcb.EnvironmentEvaluator(
            env=TorchGym(env_name),
            n_evaluations=10,
            action_selector=selection_strategy,
            metrics={
                'prob_val_reward_mean': np.mean,
                'prob_val_reward_std': np.std,
            },
            frequency=1,
            epoch_name='prob_val_epoch',
        )

        # Maps each plotted metric to the epoch counter it is recorded with.
        reward_plotter = rcb.EnsembleRewardPlotter(
            metrics={
                'det_val_reward_mean': 'det_val_epoch',
                'prob_val_reward_mean': 'prob_val_epoch',
            })

        learner = DQNEnsemble(
            model_class=Model,
            trainer_factory=factory,
            memory=memory,
            env=env,
            player=dqn_player,
            selection_strategy=selection_strategy,
            trainer_args=factory_args,
            n_model=n_model,
            callbacks=[
                episode_updater,
                checkpointer,
                greedy_evaluator,
                training_policy_evaluator,
                reward_plotter,
            ],
        )

        learner.fit(**params['fit'])