Example #1
0
def get_dqn_agent(env,
                  dqn_agent_subtype,
                  folder='/SRC/pathway/alphaxos/models/',
                  load=False,
                  load_mem=True,
                  side_normalization_factor=1.0):
    """Build and compile a ``DQNAgent`` (double DQN enabled) for ``env``.

    The Q-network factory and the persistence file names are looked up in
    the module-level ``dqn_agents`` table under ``dqn_agent_subtype``; the
    training hyper-parameters are read from the module-level
    ``regime_params`` mapping.

    Args:
        env: Environment exposing ``action_space.n``; it is also attached
            to the exploration policy as ``policy.env``.
        dqn_agent_subtype: Key into ``dqn_agents`` selecting the entry that
            provides ``'qfn'``, ``'modelfile'`` and ``'memoryfile'``.
        folder: Prefix prepended to the model/memory file names by plain
            string concatenation, so it must already end with a path
            separator.
        load: When truthy, ``dqn.reload()`` is called after compiling.
        load_mem: When truthy, ``dqn.reload_memory()`` is called after
            compiling. Note that this defaults to ``True`` while ``load``
            defaults to ``False``.
        side_normalization_factor: Forwarded unchanged to the Q-network
            factory.

    Returns:
        The compiled agent, with ``modelfile`` and ``memoryfile``
        attributes set.

    Raises:
        KeyError: If ``dqn_agent_subtype`` is not in ``dqn_agents`` or a
            required key is missing from the entry or ``regime_params``.
    """
    spec = dqn_agents[dqn_agent_subtype]

    build_q_network = spec['qfn']
    n_actions = env.action_space.n
    q_network = build_q_network(
        out_width=n_actions,
        side_normalization_factor=side_normalization_factor)

    # Replay memory configured as in the keras-rl dueling-DQN example:
    # https://github.com/keras-rl/keras-rl/blob/master/examples/duel_dqn_cartpole.py
    replay_buffer = SequentialMemory(limit=100000, window_length=1)

    explorer = EpsGreedyQPolicy(regime_params['epsilon-train'])
    # The policy object is handed the environment as a plain attribute.
    # NOTE(review): presumably for action-validity checks in custom
    # policies -- confirm whether EpsGreedyQPolicy actually reads it.
    explorer.env = env

    agent_options = dict(
        model=q_network,
        batch_size=regime_params['memory_batch_size'],
        gamma=regime_params['gamma'],
        nb_actions=n_actions,
        memory=replay_buffer,
        nb_steps_warmup=regime_params['steps_warmup'],
        target_model_update=regime_params['steps_target_model_update'],
        policy=explorer,
        # No dedicated test policy is supplied; None is passed explicitly.
        test_policy=None,
        enable_double_dqn=True,
    )
    agent = DQNAgent(**agent_options)

    optimizer = Adam(lr=regime_params['learning_rate'])
    agent.compile(optimizer, metrics=['mae'])

    # Persistence locations are stored on the agent itself.
    # NOTE(review): reload()/reload_memory() presumably read these
    # attributes -- verify against their definitions.
    agent.modelfile = folder + spec['modelfile']
    agent.memoryfile = folder + spec['memoryfile']

    if load:
        agent.reload()
    if load_mem:
        agent.reload_memory()

    return agent