# NOTE: this fragment assumes the surrounding module's imports are in scope:
# numpy as np, tensorflow as tf (1.x, for tf.summary.FileWriter), the Keras
# backend as K, and the repo's env/agent/model helpers (get_env, DqnOps, ...).
def run_dqn(**kargs):
    # Fall back to the log directory for network snapshots.
    if kargs['output_dir'] is None and kargs['logdir'] is not None:
        kargs['output_dir'] = kargs['logdir']
    q_model_initial = kargs['q_model_initial'] if 'q_model_initial' in kargs else None

    # Freeze the keyword arguments into an attribute-style namespace.
    from collections import namedtuple
    args = namedtuple("DQNParams", kargs.keys())(*kargs.values())

    if 'dont_init_tf' not in kargs or not kargs['dont_init_tf']:
        #init_nn_library(True, "1")
        init_nn_library("gpu" in kargs and kargs["gpu"] is not None,
                        kargs["gpu"] if "gpu" in kargs else "1")

    #if args.atari:
    #    env = gym_env(args.game + 'NoFrameskip-v0')
    #    env = WarmUp(env, min_step=0, max_step=30)
    #    env = ActionRepeat(env, 4)
    #    #q_model = A3CModel(modelOps)
    #else:
    #    if args.game == "Grid":
    #        env = GridEnv()
    #    else:
    #        env = gym_env(args.game)
    #    #q_model = TabularQModel(modelOps)
    #for trans in args.env_transforms:
    #    env = globals()[trans](env)
    if 'use_env' in kargs and kargs['use_env'] is not None:
        env = kargs['use_env']
    else:
        env = get_env(args.game, args.atari, args.env_transforms,
                      kargs['monitor_dir'] if 'monitor_dir' in kargs else None)

    # Optionally wrap the real environment with a learned dynamics model.
    if 'env_model' in kargs and kargs['env_model'] is not None and kargs['env_weightfile'] is not None:
        print('Using simulated environment')
        envOps = EnvOps(env.observation_space.shape, env.action_space.n, args.learning_rate)
        env_model = globals()[kargs['env_model']](envOps)
        env_model.model.load_weights(kargs['env_weightfile'])
        env = SimulatedEnv(env, env_model,
                           use_reward='env_reward' in kargs and kargs['env_reward'])

    modelOps = DqnOps(env.action_count)
    modelOps.dueling_network = args.dueling_dqn

    viewer = None
    if args.enable_render:
        viewer = EnvViewer(env, args.render_step, 'human')

    if args.atari:
        proproc = PreProPipeline([GrayPrePro(), ResizePrePro(modelOps.INPUT_SIZE)])
        rewproc = PreProPipeline([RewardClipper(-1, 1)])
    else:
        if env.observation_space.__class__.__name__ == 'Discrete':
            modelOps.INPUT_SIZE = env.observation_space.n
        else:
            modelOps.INPUT_SIZE = env.observation_space.shape
        modelOps.AGENT_HISTORY_LENGTH = 1
        proproc = None
        rewproc = None

    modelOps.LEARNING_RATE = args.learning_rate

    if q_model_initial is None:
        q_model = globals()[args.model](modelOps)
    else:
        q_model = q_model_initial
    if args.load_weightfile is not None:
        q_model.model.load_weights(args.load_weightfile)

    summary_writer = tf.summary.FileWriter(args.logdir, K.get_session().graph) if args.logdir is not None else None

    agentOps = DqnAgentOps()
    agentOps.double_dqn = args.double_dqn
    agentOps.mode = args.mode
    if args.mode == "train":
        agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = args.target_network_update

    replay_buffer = None
    if args.replay_buffer_size > 0:
        if 'load_trajectory' in kargs and kargs['load_trajectory'] is not None:
            replay_buffer = TrajectoryReplay(kargs['load_trajectory'], kargs['batch_size'],
                                             args.update_frequency, args.replay_start_size)
        else:
            replay_buffer = ReplayBuffer(args.replay_buffer_size, modelOps.AGENT_HISTORY_LENGTH,
                                         args.update_frequency, args.replay_start_size, args.batch_size)
    #replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, 8)
    agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps, summary_writer)

    egreedyOps = EGreedyOps()
    if replay_buffer is not None:
        egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
    egreedyOps.mode = args.mode
    egreedyOps.test_epsilon = args.test_epsilon
    #egreedyOps.FINAL_EXPLORATION_FRAME = 10000
    if args.mode == "train":
        egreedyOps.FINAL_EXPLORATION_FRAME = args.egreedy_final_step

    # Exponential epsilon decay when a scalar decay < 1 is given, otherwise a
    # mixture of linear schedules; plain epsilon-greedy when testing.
    if args.mode == "train":
        if args.egreedy_decay < 1:
            egreedyOps.DECAY = args.egreedy_decay
            egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)
        else:
            egreedyAgent = MultiEGreedyAgent(env.action_space, egreedyOps, agent,
                                             args.egreedy_props, args.egreedy_final,
                                             final_exp_frame=args.egreedy_final_step)
    else:
        egreedyAgent = EGreedyAgent(env.action_space, egreedyOps, agent)

    runner = Runner(env, egreedyAgent, proproc, modelOps.AGENT_HISTORY_LENGTH,
                    max_step=args.max_step, max_episode=args.max_episode)
    if replay_buffer is not None:
        runner.listen(replay_buffer, proproc)
    runner.listen(agent, None)
    runner.listen(egreedyAgent, None)
    if viewer is not None:
        runner.listen(viewer, None)
    if args.output_dir is not None:
        networkSaver = NetworkSaver(50000 if 'save_interval' not in kargs else kargs['save_interval'],
                                    args.output_dir, q_model.model)
        runner.listen(networkSaver, None)
    return runner, agent
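# Minimal invocation sketch (not from the repo): every attribute read from
# `args` above must be supplied as a keyword, so a hypothetical CartPole run
# could look like the following. All parameter values here are illustrative
# assumptions, not recommended settings.
if __name__ == '__main__':
    runner, agent = run_dqn(
        game='CartPole-v0', atari=False, env_transforms=[],
        model='CartPoleModel',            # resolved via globals()[args.model]
        mode='train', double_dqn=False, dueling_dqn=False,
        learning_rate=0.001, target_network_update=500,
        replay_buffer_size=10000, replay_start_size=1000,
        batch_size=32, update_frequency=1,
        egreedy_decay=0.999, egreedy_props=[1.0], egreedy_final=[0.01],
        egreedy_final_step=10000, test_epsilon=0.01,
        max_step=100000, max_episode=None,
        enable_render=False, render_step=1,
        load_weightfile=None, output_dir=None, logdir=None)
    runner.run()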
# Stand-alone training setup: small replay buffer with exponentially
# decaying epsilon-greedy exploration.
agentOps = DqnAgentOps()
agentOps.double_dqn = args.double_dqn
agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 20
#agentOps.REPLAY_START_SIZE = 100
#agentOps.FINAL_EXPLORATION_FRAME = 10000

# capacity 2000, history 1, update every step, warm-up 1000, batch 64
replay_buffer = ReplayBuffer(2000, 1, 1, 1000, 64)
#replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, 8)
agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps, summary_writer)

egreedyOps = EGreedyOps()
egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
egreedyOps.FINAL_EXPLORATION_FRAME = 10000
egreedyOps.FINAL_EXPLORATION = 0.01
egreedyOps.DECAY = 0.999
egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)

runner = Runner(env, egreedyAgent, proproc, 1)
runner.listen(replay_buffer, proproc)
runner.listen(agent, None)
runner.listen(egreedyAgent, None)
if viewer is not None:
    runner.listen(viewer, None)
if args.logdir is not None:
    networkSaver = NetworkSaver(50000, args.logdir, q_model.model)
    runner.listen(networkSaver, None)
runner.run()
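# Rough schedule check, assuming EGreedyAgentExp multiplies epsilon by DECAY
# once per step (as the option names suggest): with DECAY = 0.999, annealing
# from 1.0 down to the FINAL_EXPLORATION floor of 0.01 takes about
# ln(0.01) / ln(0.999) ~= 4603 steps, well inside the 10000-frame window.
import math
steps_to_final = math.log(0.01) / math.log(0.999)
print('epsilon reaches its floor after ~%d steps' % steps_to_final)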
# Worker thread for multi-threaded DQN on the Grid/CartPole environments.
# Relies on module-level shared state: `model` (shared online network),
# `model_eval` (shared target network), `agents`, `tLock`, step counter `T`,
# `target_network_update_freq`, `SAVE_FREQ`, `ENABLE_RENDER` and `viewer`.
class AgentThread(StoppableThread, RunnerListener):
    def __init__(self, threadId, sess, graph):
        StoppableThread.__init__(self)
        self.threadId = threadId
        self.sess = sess
        self.graph = graph
        with self.graph.as_default():
            if args.game == "Grid":
                env = GridEnv()
            else:
                env = gym_env(args.game)
            env = Penalizer(env)
            proproc = None
            rewproc = None

            # Thread-local network that pushes its updates to the shared model.
            q_model = CartPoleModel(modelOps)
            q_model.model_update = model.model
            q_model.set_weights(model.get_weights())

            summary_writer = tf.summary.FileWriter(args.logdir + '/thread-' + str(threadId), K.get_session().graph) if args.logdir is not None else None

            agentOps = DqnAgentOps()
            agentOps.double_dqn = args.double_dqn
            agentOps.REPLAY_START_SIZE = 1
            #agentOps.INITIAL_EXPLORATION = 0
            # Per-agent target sync is effectively disabled; the shared target
            # network is cloned globally in on_step instead.
            agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 1e10

            #replay_buffer = ReplayBuffer(int(1e6), 4, 4, agentOps.REPLAY_START_SIZE, 32)
            replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, args.nstep)
            agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps, summary_writer, model_eval=model_eval)

            egreedyOps = EGreedyOps()
            egreedyOps.REPLAY_START_SIZE = 1
            egreedyOps.FINAL_EXPLORATION_FRAME = 5000
            egreedyOps.FINAL_EXPLORATION = 0.01
            egreedyOps.DECAY = 0.999
            # 40/30/30% of threads anneal epsilon toward 0.1, 0.01 and 0.5,
            # matching the mixture from Mnih et al.'s asynchronous methods.
            egreedyAgent = MultiEGreedyAgent(env.action_space, egreedyOps, agent, [0.4, 0.3, 0.3], [0.1, 0.01, 0.5])

            self.runner = Runner(env, egreedyAgent, proproc, modelOps.AGENT_HISTORY_LENGTH)
            self.runner.listen(replay_buffer, proproc)
            self.runner.listen(agent, None)
            self.runner.listen(egreedyAgent, None)
            self.runner.listen(self, proproc)

    def run(self):
        with self.graph.as_default():
            self.runner.run()

    def on_step(self, ob, action, next_ob, reward, done):
        global T
        global model, model_eval
        with tLock:
            T = T + 1
            #if T % 1000 == 0:
            #    print('STEP', T)
            if T % target_network_update_freq == 0:
                print('CLONE TARGET')
                model_eval.set_weights(model.get_weights())
                for agent in agents:
                    agent.model_eval = model_eval
            if T % args.render_step == 0 and ENABLE_RENDER:
                viewer.imshow(np.repeat(np.reshape(ob, ob.shape + (1,)), 3, axis=2))
            if T % SAVE_FREQ == 0 and args.mode == "train":
                if args.output_dir is not None:
                    model.model.save_weights(args.output_dir + '/weights_{0}.h5'.format(T))
        #print(T)

    def stop(self):
        super(AgentThread, self).stop()
        self.runner.stop()
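# Hypothetical launch code (not from the repo) wiring up the module-level
# state the class above reads. The thread count, update frequencies and the
# use of CartPoleModel for both networks are illustrative assumptions.
import threading

model = CartPoleModel(modelOps)        # shared online network
model_eval = CartPoleModel(modelOps)   # shared target network
model_eval.set_weights(model.get_weights())
tLock = threading.Lock()
T = 0
agents = []                            # presumably each thread's agent is registered here
target_network_update_freq = 1000      # assumed value
SAVE_FREQ = 50000                      # assumed value
ENABLE_RENDER = False

threads = [AgentThread(i, K.get_session(), tf.get_default_graph())
           for i in range(4)]
for th in threads:
    th.start()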
# Generalized worker thread: supports Atari and non-Atari games, shared or
# thread-local networks, and either DQN or actor-critic agents. Thread 0 is
# reserved for greedy evaluation; the others explore.
class AgentThread(StoppableThread, RunnerListener):
    def __init__(self, threadId, sess, graph):
        StoppableThread.__init__(self)
        self.threadId = threadId
        self.sess = sess
        self.graph = graph
        with self.graph.as_default():
            if args.atari:
                env = gym_env(args.game + 'NoFrameskip-v0')
                env = WarmUp(env, min_step=0, max_step=30)
                env = ActionRepeat(env, 4)
                proproc = PreProPipeline([GrayPrePro(), ResizePrePro(modelOps.INPUT_SIZE)])
                rewproc = PreProPipeline([RewardClipper(-1, 1)])
                #q_model = A3CModel(modelOps)
            else:
                if args.game == "Grid":
                    env = GridEnv()
                else:
                    env = gym_env(args.game)
                proproc = None
                rewproc = None
                #q_model = TabularQModel(modelOps)
            for trans in args.env_transforms:
                env = globals()[trans](env)

            if 'shared_model' in kargs and kargs['shared_model']:
                q_model = model
            else:
                q_model = globals()[args.model](modelOps)
                q_model.model_update = model.model
                q_model.set_weights(model.get_weights())
            summary_writer = tf.summary.FileWriter(args.logdir + '/thread-' + str(threadId), K.get_session().graph) if args.logdir is not None else None

            agentOps = DqnAgentOps()
            agentOps.double_dqn = args.double_dqn
            agentOps.REPLAY_START_SIZE = 1
            #agentOps.INITIAL_EXPLORATION = 0
            agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 1e10

            #replay_buffer = ReplayBuffer(int(1e6), 4, 4, agentOps.REPLAY_START_SIZE, 32)
            #if threadId > 0:
            if args.nstep > 0:
                replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, args.nstep)
            else:
                replay_buffer = ReplayBuffer(args.replay_buffer_size, modelOps.AGENT_HISTORY_LENGTH,
                                             args.update_frequency, args.replay_start_size, args.batch_size)

            #print(kargs['agent'])
            if kargs['agent'] == 'ActorCriticAgent':
                agent = ActorCriticAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps, summary_writer, ac_model_update=model)
            else:
                agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps, summary_writer, model_eval=model_eval)

            egreedyAgent = None
            if threadId > 0 and kargs['agent'] != 'ActorCriticAgent':  # first thread is for testing
                egreedyOps = EGreedyOps()
                egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
                #egreedyOps.FINAL_EXPLORATION_FRAME = int(args.egreedy_final_step / args.thread_count)
                #if args.egreedy_decay < 1:
                #    egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)
                #else:
                if len(args.egreedy_props) > 1 and args.egreedy_props[0] == round(args.egreedy_props[0]):
                    # Integer-valued props are interpreted as thread counts:
                    # pick exactly one epsilon schedule for this thread.
                    cs = np.cumsum(np.array(args.egreedy_props))
                    idx = np.searchsorted(cs, threadId)
                    print('Egreedyagent selected', idx, args.egreedy_final[idx], args.egreedy_decay[idx], args.egreedy_final_step[idx])
                    egreedyAgent = MultiEGreedyAgent(env.action_space, egreedyOps, agent,
                                                     [1], [args.egreedy_final[idx]],
                                                     [args.egreedy_decay[idx]], [args.egreedy_final_step[idx]])
                else:
                    egreedyAgent = MultiEGreedyAgent(env.action_space, egreedyOps, agent,
                                                     args.egreedy_props, args.egreedy_final,
                                                     args.egreedy_decay, args.egreedy_final_step)

            self.runner = Runner(env, egreedyAgent if egreedyAgent is not None else agent,
                                 proproc, modelOps.AGENT_HISTORY_LENGTH)
            if replay_buffer is not None:
                self.runner.listen(replay_buffer, proproc)
            self.runner.listen(agent, None)
            if egreedyAgent is not None:
                self.runner.listen(egreedyAgent, None)
            self.runner.listen(self, proproc)
            self.agent = agent
            self.q_model = q_model

    def run(self):
        with self.graph.as_default():
            self.runner.run()

    def on_step(self, ob, action, next_ob, reward, done):
        global T
        global model, model_eval
        with tLock:
            T = T + 1
            if T % target_network_update_freq == 0 and kargs['agent'] != 'ActorCriticAgent':
                #print('CLONE TARGET: ' + str(T))
                model_eval.set_weights(model.get_weights())
                for agent in agents:
                    agent.model_eval = model_eval
            if T % SAVE_FREQ == 0 and args.mode == "train":
                if args.output_dir is not None:
                    model.model.save_weights(args.output_dir + '/weights_{0}.h5'.format(T))
            #if T % 1000 == 0:
            #    print('STEP', T)
            #if self.threadId == 0 and T % 10 == 0:
            #    self.q_model.set_weights(model.get_weights())
            if T % args.render_step == 0 and ENABLE_RENDER:
                viewer.imshow(np.repeat(np.reshape(ob, ob.shape + (1,)), 3, axis=2))
            if T > args.max_step:
                self.stop()
        #print(T)

    def stop(self):
        super(AgentThread, self).stop()
        self.runner.stop()
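# Worked example of the integer-props branch above (values assumed): with
# egreedy_props = [1, 2, 2], np.cumsum gives [1, 3, 5] and np.searchsorted
# maps threadId 1 -> schedule 0, threads 2-3 -> schedule 1 and threads
# 4-5 -> schedule 2, so each exploration worker gets exactly one schedule.
import numpy as np

cs = np.cumsum(np.array([1, 2, 2]))
for threadId in range(1, 6):
    print(threadId, '->', np.searchsorted(cs, threadId))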
# Tabular Q-learning variant for Discrete observation spaces (e.g. GridEnv):
# INPUT_SIZE is the number of discrete states rather than an image shape.
#print(env.observation_space.n)
modelOps = DqnOps(env.action_count)
modelOps.dueling_network = args.dueling_dqn
modelOps.INPUT_SIZE = env.observation_space.n
modelOps.LEARNING_RATE = 0.2

q_model = TabularQModel(modelOps)
summary_writer = tf.summary.FileWriter(args.logdir, K.get_session().graph) if args.logdir is not None else None

agentOps = DqnAgentOps()
agentOps.double_dqn = args.double_dqn

replay_buffer = NStepBuffer(1, args.nstep)
agent = DqnAgent(env.action_space, q_model, replay_buffer, None, agentOps, summary_writer)

egreedyOps = EGreedyOps()
egreedyOps.REPLAY_START_SIZE = 1
egreedyOps.FINAL_EXPLORATION_FRAME = 10000
egreedyAgent = EGreedyAgent(env.action_space, egreedyOps, agent)

runner = Runner(env, egreedyAgent, None, 1)
runner.listen(replay_buffer, None)
runner.listen(agent, None)
runner.listen(egreedyAgent, None)
runner.run()
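# Minimal sketch (an assumption, not the repo's TabularQModel) of what a
# tabular Q "model" needs under the ops above: a |S| x |A| table nudged
# toward each target with LEARNING_RATE as the step size.
import numpy as np

class TabularQSketch(object):
    def __init__(self, n_states, n_actions, lr=0.2):
        self.q = np.zeros((n_states, n_actions))
        self.lr = lr

    def predict(self, s):
        # Q-values for every action in state s.
        return self.q[s]

    def update(self, s, a, target):
        # Q(s,a) <- Q(s,a) + lr * (target - Q(s,a))
        self.q[s, a] += self.lr * (target - self.q[s, a])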