def get_dqn_agent(env, dqn_agent_subtype, folder='/SRC/pathway/alphaxos/models/',
                  load=False, load_mem=True, side_normalization_factor=1.0):
    """Build and compile a double-DQN agent for ``env``.

    The Q-network factory and the model/memory file names are looked up in
    the module-level ``dqn_agents`` registry under ``dqn_agent_subtype``.
    Training hyper-parameters are read from the module-level
    ``regime_params`` mapping.

    Args:
        env: Environment whose ``action_space.n`` sets the number of actions
            and the Q-network's output width.
        dqn_agent_subtype: Key into ``dqn_agents``. The entry must provide
            ``'qfn'``, ``'modelfile'`` and ``'memoryfile'``.
        folder: Directory holding the model and memory files. It may be given
            with or without a trailing path separator.
        load: If true, call ``dqn.reload()`` after compiling.
        load_mem: Only consulted when ``load`` is true. If so, also call
            ``dqn.reload_memory()``.
        side_normalization_factor: Passed through to the Q-network factory.

    Returns:
        The compiled agent, with ``modelfile`` and ``memoryfile`` attributes
        set to the joined paths.
    """
    import os  # local import keeps this block self-contained

    agent_info = dqn_agents[dqn_agent_subtype]
    nb_actions = env.action_space.n

    model = agent_info['qfn'](
        out_width=nb_actions,
        side_normalization_factor=side_normalization_factor)

    # Replay-memory setup follows the keras-rl dueling-DQN example:
    # see https://github.com/keras-rl/keras-rl/blob/master/examples/duel_dqn_cartpole.py
    memory = SequentialMemory(limit=100000, window_length=1)

    # Training uses epsilon-greedy exploration with epsilon taken from the
    # regime config. `env` is attached to the policy object; presumably this
    # is for an action-masking policy variant that reads it -- TODO confirm
    # whether EpsGreedyQPolicy needs it.
    policy = EpsGreedyQPolicy(regime_params['epsilon-train'])
    policy.env = env
    # No explicit test policy; the agent class chooses its own default.
    test_policy = None

    dqn = DQNAgent(
        model=model,
        batch_size=regime_params['memory_batch_size'],
        gamma=regime_params['gamma'],
        nb_actions=nb_actions,
        memory=memory,
        nb_steps_warmup=regime_params['steps_warmup'],
        target_model_update=regime_params['steps_target_model_update'],
        policy=policy,
        test_policy=test_policy,
        enable_double_dqn=True)
    dqn.compile(Adam(lr=regime_params['learning_rate']), metrics=['mae'])

    # os.path.join is correct whether or not `folder` ends with a separator.
    # Plain concatenation turned e.g. '/models' + 'x.h5' into '/modelsx.h5'.
    dqn.modelfile = os.path.join(folder, agent_info['modelfile'])
    dqn.memoryfile = os.path.join(folder, agent_info['memoryfile'])

    if load:
        dqn.reload()
        # NOTE(review): memory is only reloaded together with the model, so the
        # default (load=False, load_mem=True) never touches the memory file.
        # Confirm no caller relies on load_mem alone triggering a reload.
        if load_mem:
            dqn.reload_memory()

    return dqn