Example #1
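All snippets below assume torch is imported and that CartPole, TorchGym, MemoryUpdater, and the module aliases pg/rl (learners), sp (selection policies), and cb/rcb (callbacks) are imported from the surrounding project; the exact import paths are not shown in the source.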
# Build a CartPole QLearner with softmax (temperature .3) action selection;
# memory_args is accepted but unused in this variant.
def factory(Model, model_args, env_args, optim_args, memory_args, learner_args,
            name):
    model = Model(**model_args)
    env = CartPole(**env_args)

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss()
    l_args = dict(learner_args)
    l_args['name'] = f"{learner_args['name']}_{name}"

    return pg.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=.3),
                       callbacks=[rcb.MemoryUpdater(1.)],
                       **l_args)
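
A hypothetical invocation of this factory; the model class and every argument value below are illustrative, not taken from the source:

learner = factory(
    Model=DQN,                                   # hypothetical model class
    model_args={'in_features': 4, 'out_features': 2},  # hypothetical signature
    env_args={},
    optim_args={'lr': 1e-3, 'momentum': .9},
    memory_args={},                              # accepted but unused here
    learner_args={'name': 'q_cartpole', 'batch_size': 256},
    name='run_0')
learner.fit(30, 'cpu')                           # fit signature as in Example #5
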
Example #2
# Build a QLearner for an arbitrary gym environment; the criterion arguments
# and the selection temperature are configurable by the caller, and
# memory_args is again unused.
def factory(Model, model_args, env_args, optim_args, memory_args, learner_args,
            crit_args, temp, name):
    model = Model(**model_args)
    env = TorchGym(env_args['env_name'])

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss(**crit_args)

    l_args = dict(learner_args)
    l_args['name'] = f"{learner_args['name']}_{name}"

    return rl.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=temp),
                       callbacks=[],
                       **l_args)
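
sp.QActionSelection(temperature=temp) presumably samples actions from a temperature-scaled softmax over the predicted Q-values. A minimal standalone sketch of that idea (an assumption, not the library's implementation):

import torch

def softmax_action(q_values: torch.Tensor, temperature: float = .3) -> int:
    # Lower temperatures sharpen the distribution toward the greedy action.
    probs = torch.softmax(q_values / temperature, dim=-1)
    return int(torch.multinomial(probs, 1).item())
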
Example #3
# Build a CartPole QLearner around a shared `core` module, with a dedicated
# memory updater plus checkpointing and periodic evaluation callbacks.
def factory(Model, core, model_args, env_args, optim_args, memory_args,
            learner_args, name):
    model = Model(core, **model_args)
    env = CartPole(**env_args)
    memory_updater = MemoryUpdater(**memory_args)

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss()
    l_args = dict(learner_args)  # copy so the caller's dict is not mutated
    l_args['name'] = name

    return pg.QLearner(env=env,
                       model=model,
                       optimizer=optim,
                       memory_updater=memory_updater,
                       crit=crit,
                       action_selector=sp.QActionSelection(temperature=.3),
                       callbacks=[
                           cb.Checkpointer(),
                           rcb.EnvironmentEvaluator(env=env,
                                                    n_evaluations=10,
                                                    frequency=1),
                       ],
                       **l_args)
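
Unlike Example #1, which refreshes the replay memory through a callback (rcb.MemoryUpdater(1.)), this variant passes a dedicated memory_updater built from memory_args; Example #5 constructs it as MemoryUpdater(memory_refresh_rate=.1). One plausible reading of that parameter, sketched here purely as an assumption rather than the library's implementation, is that a fraction of the stored transitions is discarded each update so the memory is refilled with fresh experience:

class SketchMemoryUpdater:
    # Illustrative only; not the library's MemoryUpdater.
    def __init__(self, memory_refresh_rate: float = .1):
        self.memory_refresh_rate = memory_refresh_rate

    def __call__(self, memory: list) -> list:
        # Drop the oldest fraction of transitions; the learner can then
        # refill the freed capacity with newly sampled experience.
        n_drop = int(len(memory) * self.memory_refresh_rate)
        return memory[n_drop:]
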
Example #4
# Build a QLearner for an environment that is constructed by the caller and
# passed in directly; no action selector or callbacks are configured here.
def factory(Model, model_args, optim_args, learner_args, crit_args, name, env):
    model = Model(**model_args)
    # env = TorchGym(**env_args)

    optim = torch.optim.SGD(model.parameters(), **optim_args)
    crit = torch.nn.MSELoss(**crit_args)

    l_args = dict(learner_args)
    l_args['name'] = f"{learner_args['name']}_{name}"

    return rl.QLearner(
        env=env,
        model=model,
        optimizer=optim,
        crit=crit,
        action_selector=None,
        # memory=Memory(**memory_args),
        callbacks=[
            # rcb.EnvironmentEvaluator(env=env,
            #                          n_evaluations=10,
            #                          frequency=10,
            #                          action_selector=sp.GreedyValueSelection()),
        ],
        **l_args)
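
The commented-out evaluator above uses sp.GreedyValueSelection; a minimal sketch of greedy selection over Q-values (assumed behavior, not the library's code):

import torch

def greedy_action(q_values: torch.Tensor) -> int:
    # Deterministically pick the action with the highest predicted value.
    return int(torch.argmax(q_values).item())
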
Example #5
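# This snippet assumes env, model, and optim were defined earlier in the script.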
crit = torch.nn.MSELoss()
memory_updater = MemoryUpdater(memory_refresh_rate=.1)

learner = pg.QLearner(env=env,
                      model=model,
                      optimizer=optim,
                      memory_updater=memory_updater,
                      crit=crit,
                      action_selector=pg.QActionSelection(temperature=.3),
                      # action_selector=pg.EpsilonGreedyActionSelection(action_space=np.arange(env.action_space.n),
                      #                                                 epsilon=.95),
                      gamma=.95,          # discount factor
                      alpha=.2,
                      batch_size=256,
                      n_samples=8000,
                      grad_clip=5.,       # gradient clipping threshold
                      memory_size=10000,  # replay memory capacity
                      load_checkpoint=False,
                      name='test_Q',
                      callbacks=[
                          rcb.EnvironmentEvaluator(env=env, n_evaluations=10, frequency=5),
                          # rcb.AgentVisualizer(env=env, frequency=5),
                          cb.MetricPlotter(frequency=1, metric='rewards', smoothing_window=100),
                          cb.MetricPlotter(frequency=1, metric='train_losses', smoothing_window=100),
                          cb.MetricPlotter(frequency=1, metric='avg_reward', smoothing_window=5),
                          cb.MetricPlotter(frequency=5, metric='val_reward', x='val_epoch', smoothing_window=5),
                      ],
                      dump_path='tests/q_learner/tmp',
                      device='cpu')

# Train on the CPU; the first argument is presumably the number of epochs.
learner.fit(30, 'cpu', restore_early_stopping=False, verbose=False)
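
The commented-out pg.EpsilonGreedyActionSelection above takes an action_space and epsilon=.95. A minimal sketch, assuming that epsilon here is the probability of acting greedily (an assumption suggested by the .95 value, not confirmed by the source):

import numpy as np
import torch

def epsilon_greedy(q_values: torch.Tensor, action_space: np.ndarray,
                   epsilon: float = .95) -> int:
    # With probability epsilon exploit the greedy action; otherwise explore
    # uniformly over the provided action space.
    if np.random.rand() < epsilon:
        return int(torch.argmax(q_values).item())
    return int(np.random.choice(action_space))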