Example #1
def run_dqn(**kargs):
    # Fall back to the log directory when no explicit output directory is given.
    if kargs['output_dir'] is None and kargs['logdir'] is not None:
        kargs['output_dir'] = kargs['logdir']

    q_model_initial = kargs.get('q_model_initial')

    # Expose the keyword arguments through attribute access (args.game, args.mode, ...).
    from collections import namedtuple
    args = namedtuple("DQNParams", kargs.keys())(*kargs.values())

    if not kargs.get('dont_init_tf'):
        init_nn_library(kargs.get("gpu") is not None, kargs.get("gpu", "1"))

    # Use the caller-provided environment when one is supplied; otherwise build it
    # from the game name with the configured environment transforms.
    if kargs.get('use_env') is not None:
        env = kargs['use_env']
    else:
        env = get_env(args.game, args.atari, args.env_transforms,
                      kargs.get('monitor_dir'))
        # Optionally replace the real environment with a learned environment model.
        if (kargs.get('env_model') is not None
                and kargs.get('env_weightfile') is not None):
            print('Using simulated environment')
            envOps = EnvOps(env.observation_space.shape, env.action_space.n,
                            args.learning_rate)
            env_model = globals()[kargs['env_model']](envOps)
            env_model.model.load_weights(kargs['env_weightfile'])
            env = SimulatedEnv(env, env_model,
                               use_reward=kargs.get('env_reward', False))

    modelOps = DqnOps(env.action_count)
    modelOps.dueling_network = args.dueling_dqn

    viewer = None
    if args.enable_render:
        viewer = EnvViewer(env, args.render_step, 'human')
    if args.atari:
        proproc = PreProPipeline(
            [GrayPrePro(), ResizePrePro(modelOps.INPUT_SIZE)])
        rewproc = PreProPipeline([RewardClipper(-1, 1)])
    else:
        if env.observation_space.__class__.__name__ == 'Discrete':
            modelOps.INPUT_SIZE = env.observation_space.n
        else:
            modelOps.INPUT_SIZE = env.observation_space.shape
        modelOps.AGENT_HISTORY_LENGTH = 1
        proproc = None
        rewproc = None

    modelOps.LEARNING_RATE = args.learning_rate
    if q_model_initial is None:
        q_model = globals()[args.model](modelOps)
    else:
        q_model = q_model_initial

    if args.load_weightfile is not None:
        q_model.model.load_weights(args.load_weightfile)

    summary_writer = tf.summary.FileWriter(
        args.logdir,
        K.get_session().graph) if args.logdir is not None else None

    agentOps = DqnAgentOps()
    agentOps.double_dqn = args.double_dqn
    agentOps.mode = args.mode
    if args.mode == "train":
        agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = args.target_network_update

    replay_buffer = None
    if args.replay_buffer_size > 0:
        if kargs.get('load_trajectory') is not None:
            replay_buffer = TrajectoryReplay(kargs['load_trajectory'],
                                             kargs['batch_size'],
                                             args.update_frequency,
                                             args.replay_start_size)
        else:
            replay_buffer = ReplayBuffer(args.replay_buffer_size,
                                         modelOps.AGENT_HISTORY_LENGTH,
                                         args.update_frequency,
                                         args.replay_start_size,
                                         args.batch_size)
    #replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, 8)
    agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc,
                     agentOps, summary_writer)

    egreedyOps = EGreedyOps()
    if replay_buffer is not None:
        egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
    egreedyOps.mode = args.mode
    egreedyOps.test_epsilon = args.test_epsilon
    #egreedyOps.FINAL_EXPLORATION_FRAME = 10000
    if args.mode == "train":
        egreedyOps.FINAL_EXPLORATION_FRAME = args.egreedy_final_step

    if args.mode == "train":
        if args.egreedy_decay < 1:
            egreedyOps.DECAY = args.egreedy_decay
            egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)
        else:
            egreedyAgent = MultiEGreedyAgent(
                env.action_space,
                egreedyOps,
                agent,
                args.egreedy_props,
                args.egreedy_final,
                final_exp_frame=args.egreedy_final_step)
    else:
        egreedyAgent = EGreedyAgent(env.action_space, egreedyOps, agent)

    runner = Runner(env,
                    egreedyAgent,
                    proproc,
                    modelOps.AGENT_HISTORY_LENGTH,
                    max_step=args.max_step,
                    max_episode=args.max_episode)
    if replay_buffer is not None:
        runner.listen(replay_buffer, proproc)
    runner.listen(agent, None)
    runner.listen(egreedyAgent, None)
    if viewer is not None:
        runner.listen(viewer, None)

    if args.output_dir is not None:
        networkSaver = NetworkSaver(kargs.get('save_interval', 50000),
                                    args.output_dir, q_model.model)
        runner.listen(networkSaver, None)

    return runner, agent
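
A minimal invocation sketch for run_dqn. The keyword names mirror the keys the function reads above, but every value is an illustrative assumption rather than a documented default, and the surrounding library (get_env, the model classes, etc.) must be importable for it to run.

# Hypothetical call: each key below is read somewhere inside run_dqn, but the
# concrete values are assumptions for illustration only.
runner, agent = run_dqn(
    game='CartPole-v0', atari=False, env_transforms=[], monitor_dir=None,
    mode='train', model='CartPoleModel', load_weightfile=None,
    learning_rate=0.001, dueling_dqn=False, double_dqn=False,
    target_network_update=1000,
    replay_buffer_size=10000, replay_start_size=1000,
    update_frequency=4, batch_size=32,
    egreedy_props=[1.0], egreedy_final=[0.01], egreedy_decay=1.0,
    egreedy_final_step=10000, test_epsilon=0.01,
    enable_render=False, render_step=1,
    max_step=100000, max_episode=10000,
    logdir='logs/dqn', output_dir=None)
runner.run()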
Example #2
agentOps = DqnAgentOps()
agentOps.double_dqn = args.double_dqn
agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 20
#agentOps.REPLAY_START_SIZE = 100
#agentOps.FINAL_EXPLORATION_FRAME = 10000

# ReplayBuffer(buffer_size, history_length, update_frequency, replay_start_size, batch_size)
replay_buffer = ReplayBuffer(2000, 1, 1, 1000, 64)
#replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH, 8)
agent = DqnAgent(env.action_space, q_model, replay_buffer, rewproc, agentOps,
                 summary_writer)

egreedyOps = EGreedyOps()
egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
egreedyOps.FINAL_EXPLORATION_FRAME = 10000
egreedyOps.FINAL_EXPLORATION = 0.01
egreedyOps.DECAY = 0.999
egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)

runner = Runner(env, egreedyAgent, proproc, 1)
runner.listen(replay_buffer, proproc)
runner.listen(agent, None)
runner.listen(egreedyAgent, None)
if viewer is not None:
    runner.listen(viewer, None)

if args.logdir is not None:
    networkSaver = NetworkSaver(50000, args.logdir, q_model.model)
    runner.listen(networkSaver, None)

runner.run()
Example #3
class AgentThread(StoppableThread, RunnerListener):
    def __init__(self, threadId, sess, graph):
        StoppableThread.__init__(self)
        self.threadId = threadId
        self.sess = sess
        self.graph = graph
        with self.graph.as_default():

            if args.game == "Grid":
                env = GridEnv()
            else:
                env = gym_env(args.game)
                env = Penalizer(env)

            proproc = None
            rewproc = None
            q_model = CartPoleModel(modelOps)

            q_model.model_update = model.model
            q_model.set_weights(model.get_weights())
            summary_writer = tf.summary.FileWriter(
                args.logdir + '/thread-' + str(threadId),
                K.get_session().graph) if args.logdir is not None else None

            agentOps = DqnAgentOps()
            agentOps.double_dqn = args.double_dqn
            agentOps.REPLAY_START_SIZE = 1
            # Effectively disable the per-agent target sync; the target network
            # is cloned globally in on_step() every target_network_update_freq steps.
            agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 1e10

            #replay_buffer = ReplayBuffer(int(1e6), 4, 4, agentOps.REPLAY_START_SIZE, 32)
            replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH,
                                        args.nstep)
            agent = DqnAgent(env.action_space,
                             q_model,
                             replay_buffer,
                             rewproc,
                             agentOps,
                             summary_writer,
                             model_eval=model_eval)

            egreedyOps = EGreedyOps()
            egreedyOps.REPLAY_START_SIZE = 1
            egreedyOps.FINAL_EXPLORATION_FRAME = 5000
            egreedyOps.FINAL_EXPLORATION = 0.01
            egreedyOps.DECAY = 0.999
            egreedyAgent = MultiEGreedyAgent(env.action_space, egreedyOps,
                                             agent, [0.4, 0.3, 0.3],
                                             [0.1, 0.01, 0.5])

            self.runner = Runner(env, egreedyAgent, proproc,
                                 modelOps.AGENT_HISTORY_LENGTH)
            self.runner.listen(replay_buffer, proproc)
            self.runner.listen(agent, None)
            self.runner.listen(egreedyAgent, None)
            self.runner.listen(self, proproc)

    def run(self):
        with self.graph.as_default():
            self.runner.run()

    def on_step(self, ob, action, next_ob, reward, done):
        global T
        global model, model_eval
        with tLock:
            T = T + 1
        #if T % 1000 == 0:
        #	print('STEP', T)
        if T % target_network_update_freq == 0:
            print('CLONE TARGET')
            model_eval.set_weights(model.get_weights())
            for agent in agents:
                agent.model_eval = model_eval
        if T % args.render_step == 0 and ENABLE_RENDER:
            viewer.imshow(
                np.repeat(np.reshape(ob, ob.shape + (1, )), 3, axis=2))
        if T % SAVE_FREQ == 0 and args.mode == "train":
            if args.output_dir is not None:
                model.model.save_weights(args.output_dir +
                                         '/weights_{0}.h5'.format(T))
        #print(T)
    def stop(self):
        super(AgentThread, self).stop()
        self.runner.stop()
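
A launch sketch for the worker class above, assuming StoppableThread wraps threading.Thread (so start() and join() are available), that thread_count is an available argument, and that the surrounding script already defines the shared globals used in on_step (model, model_eval, agents, tLock, T, and so on).

# Hypothetical launcher: spin up several asynchronous DQN workers that share
# the global model defined in the surrounding script.
threads = [AgentThread(i, K.get_session(), K.get_session().graph)
           for i in range(args.thread_count)]
for t in threads:
    t.start()
for t in threads:
    t.join()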
Example #4
    class AgentThread(StoppableThread, RunnerListener):
        def __init__(self, threadId, sess, graph):
            StoppableThread.__init__(self)
            self.threadId = threadId
            self.sess = sess
            self.graph = graph
            with self.graph.as_default():
                if args.atari:
                    env = gym_env(args.game + 'NoFrameskip-v0')
                    env = WarmUp(env, min_step=0, max_step=30)
                    env = ActionRepeat(env, 4)
                    proproc = PreProPipeline(
                        [GrayPrePro(),
                         ResizePrePro(modelOps.INPUT_SIZE)])
                    rewproc = PreProPipeline([RewardClipper(-1, 1)])
                    #q_model = A3CModel(modelOps)
                else:
                    if args.game == "Grid":
                        env = GridEnv()
                    else:
                        env = gym_env(args.game)
                    proproc = None
                    rewproc = None
                    #q_model = TabularQModel(modelOps)
                for trans in args.env_transforms:
                    env = globals()[trans](env)

                if kargs.get('shared_model'):
                    q_model = model
                else:
                    q_model = globals()[args.model](modelOps)
                    q_model.model_update = model.model
                    q_model.set_weights(model.get_weights())
                summary_writer = tf.summary.FileWriter(
                    args.logdir + '/thread-' + str(threadId),
                    K.get_session().graph) if args.logdir is not None else None

                agentOps = DqnAgentOps()
                agentOps.double_dqn = args.double_dqn
                agentOps.REPLAY_START_SIZE = 1
                #agentOps.INITIAL_EXPLORATION = 0
                agentOps.TARGET_NETWORK_UPDATE_FREQUENCY = 1e10

                # Use an n-step buffer when nstep > 0, otherwise fall back to a
                # standard replay buffer.
                if args.nstep > 0:
                    replay_buffer = NStepBuffer(modelOps.AGENT_HISTORY_LENGTH,
                                                args.nstep)
                else:
                    replay_buffer = ReplayBuffer(args.replay_buffer_size,
                                                 modelOps.AGENT_HISTORY_LENGTH,
                                                 args.update_frequency,
                                                 args.replay_start_size,
                                                 args.batch_size)

                #print(kargs['agent'])
                if kargs['agent'] == 'ActorCriticAgent':
                    agent = ActorCriticAgent(env.action_space,
                                             q_model,
                                             replay_buffer,
                                             rewproc,
                                             agentOps,
                                             summary_writer,
                                             ac_model_update=model)
                else:
                    agent = DqnAgent(env.action_space,
                                     q_model,
                                     replay_buffer,
                                     rewproc,
                                     agentOps,
                                     summary_writer,
                                     model_eval=model_eval)

                egreedyAgent = None

                # The first thread (threadId == 0) is reserved for testing.
                if threadId > 0 and kargs['agent'] != 'ActorCriticAgent':
                    egreedyOps = EGreedyOps()
                    egreedyOps.REPLAY_START_SIZE = replay_buffer.REPLAY_START_SIZE
                    #egreedyOps.FINAL_EXPLORATION_FRAME = int(args.egreedy_final_step / args.thread_count)
                    #if args.egreedy_decay<1:
                    #	egreedyAgent = EGreedyAgentExp(env.action_space, egreedyOps, agent)
                    #else:
                    # Integer-valued proportions assign whole threads to each
                    # epsilon group; pick this thread's group by cumulative count.
                    if (len(args.egreedy_props) > 1 and
                            args.egreedy_props[0] == round(args.egreedy_props[0])):
                        cs = np.array(args.egreedy_props)
                        cs = np.cumsum(cs)
                        idx = np.searchsorted(cs, threadId)
                        print('Egreedyagent selected', idx,
                              args.egreedy_final[idx], args.egreedy_decay[idx],
                              args.egreedy_final_step[idx])
                        egreedyAgent = MultiEGreedyAgent(
                            env.action_space, egreedyOps, agent, [1],
                            [args.egreedy_final[idx]],
                            [args.egreedy_decay[idx]],
                            [args.egreedy_final_step[idx]])
                    else:
                        egreedyAgent = MultiEGreedyAgent(
                            env.action_space, egreedyOps, agent,
                            args.egreedy_props, args.egreedy_final,
                            args.egreedy_decay, args.egreedy_final_step)

                self.runner = Runner(
                    env, egreedyAgent if egreedyAgent is not None else agent,
                    proproc, modelOps.AGENT_HISTORY_LENGTH)
                if replay_buffer is not None:
                    self.runner.listen(replay_buffer, proproc)
                self.runner.listen(agent, None)
                if egreedyAgent is not None:
                    self.runner.listen(egreedyAgent, None)
                self.runner.listen(self, proproc)
                self.agent = agent
                self.q_model = q_model

        def run(self):
            with self.graph.as_default():
                self.runner.run()

        def on_step(self, ob, action, next_ob, reward, done):
            global T
            global model, model_eval
            with tLock:
                T = T + 1
                if (T % target_network_update_freq == 0
                        and kargs['agent'] != 'ActorCriticAgent'):
                    #print('CLONE TARGET: ' + str(T))
                    model_eval.set_weights(model.get_weights())
                    for agent in agents:
                        agent.model_eval = model_eval
                if T % SAVE_FREQ == 0 and args.mode == "train":
                    if args.output_dir is not None:
                        model.model.save_weights(args.output_dir +
                                                 '/weights_{0}.h5'.format(T))
            #if T % 1000 == 0:
            #	print('STEP', T)
            #if self.threadId == 0 and T % 10 == 0:
            #	self.q_model.set_weights(model.get_weights())
            if T % args.render_step == 0 and ENABLE_RENDER:
                viewer.imshow(
                    np.repeat(np.reshape(ob, ob.shape + (1, )), 3, axis=2))
            if T > args.max_step:
                self.stop()
            #print(T)
        def stop(self):
            super(AgentThread, self).stop()
            self.runner.stop()
Example #5

#print(env.observation_space.n)

modelOps = DqnOps(env.action_count)
modelOps.dueling_network = args.dueling_dqn
modelOps.INPUT_SIZE = env.observation_space.n
modelOps.LEARNING_RATE = 0.2

q_model = TabularQModel(modelOps)

summary_writer = tf.summary.FileWriter(
    args.logdir, K.get_session().graph) if args.logdir is not None else None

agentOps = DqnAgentOps()
agentOps.double_dqn = args.double_dqn

# NStepBuffer(history_length, nstep)
replay_buffer = NStepBuffer(1, args.nstep)
agent = DqnAgent(env.action_space, q_model, replay_buffer, None, agentOps, summary_writer)

egreedyOps = EGreedyOps()
egreedyOps.REPLAY_START_SIZE = 1
egreedyOps.FINAL_EXPLORATION_FRAME = 10000
egreedyAgent = EGreedyAgent(env.action_space, egreedyOps, agent)

runner = Runner(env, egreedyAgent, None, 1)
runner.listen(replay_buffer, None)
runner.listen(agent, None)
runner.listen(egreedyAgent, None)

runner.run()