Example #1
def factory(Model, model_args, env_args, optim_args, memory_args, learner_args,
            name):
    # Build the model and a CartPole environment from the supplied kwargs.
    model = Model(**model_args)
    env = CartPole(**env_args)

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss()
    # Copy learner_args so the caller's dict is not mutated, then tag the run.
    l_args = dict(learner_args)
    l_args['name'] = f"{learner_args['name']}_{name}"

    # Temperature-based selection over Q-values; the MemoryUpdater callback
    # keeps the replay memory up to date during training.
    return pg.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=.3),
                       callbacks=[rcb.MemoryUpdater(1.)],
                       **l_args)
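A minimal usage sketch for the factory above. The QNet class, layer sizes, and every argument value below are illustrative assumptions, not part of the original example:

import torch

class QNet(torch.nn.Module):
    # Hypothetical two-layer Q-network for CartPole (4 observations, 2 actions).
    def __init__(self, in_nodes, out_nodes):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(in_nodes, 64),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(64, out_nodes))

    def forward(self, x):
        return self.net(x)

learner = factory(Model=QNet,
                  model_args={'in_nodes': 4, 'out_nodes': 2},
                  env_args={},
                  optim_args={'lr': 1e-3},
                  memory_args={},            # accepted but unused by this factory
                  learner_args={'name': 'cartpole_q'},
                  name='run_0')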
Example #2
File: factory.py  Project: raharth/PyMatch
def factory(Model, model_args, env_args, optim_args, memory_args, learner_args,
            crit_args, temp, name):
    model = Model(**model_args)
    # TorchGym wraps a Gym environment identified by name.
    env = TorchGym(env_args['env_name'])

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss(**crit_args)

    # Copy learner_args so the caller's dict is not mutated, then tag the run.
    l_args = dict(learner_args)
    l_args['name'] = f"{learner_args['name']}_{name}"

    return rl.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=temp),
                       callbacks=[],
                       **l_args)
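The same call pattern with a named Gym environment; every value below, including the environment id, is an assumption for illustration:

learner = factory(Model=QNet,                # QNet as sketched under Example #1
                  model_args={'in_nodes': 4, 'out_nodes': 2},
                  env_args={'env_name': 'CartPole-v1'},
                  optim_args={'lr': 1e-3},
                  memory_args={},
                  learner_args={'name': 'cartpole_q'},
                  crit_args={},
                  temp=.3,
                  name='run_0')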
Example #3
def get_selection_strategy(key, params):
    # Map a config key to a PyMatch selection policy; each policy post-processes
    # the model output through a hat before choosing an action.
    if key == 'AdaptiveQSelection':
        return sp.AdaptiveQActionSelectionEntropy(
            post_pipeline=[hat.EntropyHat()], **params)
    if key == 'QSelectionCertainty':
        return sp.QActionSelectionCertainty(post_pipeline=[hat.EntropyHat()],
                                            **params)
    if key == 'QSelection':
        return sp.QActionSelection(post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'EpsilonGreedy':
        return sp.EpsilonGreedyActionSelection(
            post_pipeline=[hat.EnsembleHat()], **params)
    if key == 'AdaptiveEpsilonGreedy':
        return sp.AdaptiveEpsilonGreedyActionSelection(
            post_pipeline=[hat.EntropyHat()], **params)
    if key == 'Greedy':
        return sp.GreedyValueSelection(post_pipeline=[hat.EnsembleHat()],
                                       **params)
    raise ValueError(f'Unknown selection strategy: {key}')
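A short usage sketch for the dispatcher above; the key and parameter values are assumptions:

# Temperature-based selection over Q-values; unknown keys raise ValueError.
selector = get_selection_strategy('QSelection', {'temperature': .3})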
Example #4
def factory(Model, core, model_args, env_args, optim_args, memory_args,
            learner_args, name):
    # The model is built around a shared "core" network passed in by the caller.
    model = Model(core, **model_args)
    env = CartPole(**env_args)
    memory_updater = MemoryUpdater(**memory_args)

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss()
    # Copy learner_args so the caller's dict is not mutated.
    l_args = dict(learner_args)
    l_args['name'] = name

    # Checkpointer persists the learner; EnvironmentEvaluator measures
    # performance on the environment (10 runs, every epoch).
    return pg.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       memory_updater=memory_updater,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=.3),
                       callbacks=[
                           cb.Checkpointer(),
                           rcb.EnvironmentEvaluator(env=env,
                                                    n_evaluations=10,
                                                    frequency=1),
                       ],
                       **l_args)
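A sketch of the shared-core idea this factory implies. The core/head split and all values below are assumptions about what Model(core, **model_args) expects, not the project's actual classes:

import torch

core = torch.nn.Sequential(torch.nn.Linear(4, 64), torch.nn.ReLU())

class QHead(torch.nn.Module):
    # Hypothetical head that reuses the shared core as a feature extractor.
    def __init__(self, core, out_nodes):
        super().__init__()
        self.core = core
        self.out = torch.nn.Linear(64, out_nodes)

    def forward(self, x):
        return self.out(self.core(x))

learner = factory(Model=QHead,
                  core=core,
                  model_args={'out_nodes': 2},
                  env_args={},
                  optim_args={'lr': 1e-3},
                  memory_args={'memory_refresh_rate': 1.},  # hypothetical key
                  learner_args={'name': 'cartpole_q'},
                  name='run_0')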
Example #5
        rcb.UncertaintyUpdater(),
        # Deterministic evaluation: greedy action selection over the ensemble.
        rcb.EnvironmentEvaluator(
            env=TorchGym(**params['factory_args']['env_args']),
            n_evaluations=10,
            action_selector=sp.GreedyValueSelection(
                post_pipeline=[EnsembleHat()]),
            metrics={
                'det_val_reward_mean': np.mean,
                'det_val_reward_std': np.std
            },
            frequency=1,
            epoch_name='det_val_epoch'),
        # Probabilistic evaluation: temperature-based sampling over Q-values.
        rcb.EnvironmentEvaluator(
            env=TorchGym(**params['factory_args']['env_args']),
            n_evaluations=10,
            action_selector=sp.QActionSelection(temperature=params['temp'],
                                                post_pipeline=[EnsembleHat()]),
            metrics={
                'prob_val_reward_mean': np.mean,
                'prob_val_reward_std': np.std
            },
            frequency=1,
            epoch_name='prob_val_epoch'),
        # Plot both mean-reward curves against their own epoch counters.
        rcb.EnsembleRewardPlotter(
            metrics={
                'det_val_reward_mean': 'det_val_epoch',
                'prob_val_reward_mean': 'prob_val_epoch',
            }),
    ])

learner.fit(**params['fit'])
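The fragment above reads everything from a params dict. A hedged sketch of the minimal structure it expects; the concrete values, including the fit kwargs, are assumptions inferred from the fragment itself:

params = {
    'factory_args': {'env_args': {'env_name': 'CartPole-v1'}},  # fed to TorchGym
    'temp': .3,                     # sampling temperature for QActionSelection
    'fit': {'epochs': 100},         # hypothetical kwargs for learner.fit
}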