import os
import time
import threading as th

import gym
import numpy as np
import scipy.misc as smp  # smp.imsave; available in older SciPy releases
import tensorflow as tf

# FLAGS, GymWrapperFactory, QlearningAgent, AgentSummary and the module-level
# `training_finished` flag are assumed to be defined elsewhere in this module.


def evaluate():
    """Evaluates the current agent and records a video of its performance."""
    envwrap = GymWrapperFactory.make(FLAGS.env,
                                     actrep=FLAGS.action_repeat,
                                     memlen=FLAGS.memory_len,
                                     w=FLAGS.width,
                                     h=FLAGS.height)
    with tf.Session() as sess:
        agent = QlearningAgent(session=sess,
                               action_size=envwrap.action_size,
                               h=FLAGS.height,
                               w=FLAGS.width,
                               channels=FLAGS.memory_len,
                               opt=tf.train.AdamOptimizer(FLAGS.lr))
        sess.run(tf.global_variables_initializer())
        if not os.path.exists(FLAGS.logdir):
            print('ERROR! No', FLAGS.logdir, 'folder found!')
            return
        ckpt = tf.train.latest_checkpoint(FLAGS.logdir)
        if ckpt is not None:
            tf.train.Saver().restore(sess, ckpt)
            agent.update_target()
            print('Session was restored from %s' % ckpt)
        else:
            print('ERROR! No checkpoint found at', FLAGS.logdir)
            return
        envwrap.env = gym.wrappers.Monitor(
            envwrap.env, os.path.join(FLAGS.evaldir, FLAGS.env))
        total_reward = 0
        for epnum in range(FLAGS.eval_iter):
            r_max = []  # max predicted Q-value at each step of the episode
            s = envwrap.reset()
            terminal = False
            while not terminal:
                reward_per_action = agent.predict_rewards(s)
                rmax = np.max(reward_per_action)
                # Act greedily with respect to the predicted action values.
                s, r, terminal, info = envwrap.step(
                    np.argmax(reward_per_action), test=True)
                # Unstack the frame history into (memory, width, height).
                sreshape = np.transpose(
                    np.reshape(s, (FLAGS.width, FLAGS.height,
                                   FLAGS.action_repeat)),
                    [2, 0, 1])
                total_reward += r
                r_max.append(rmax)
                smp.imsave('plots2/' + str(len(r_max)) + '_screen.png',
                           sreshape[2])
                plot2(reward_per_action[:4], len(r_max))
            plot(r_max, epnum)
        envwrap.env.close()
        print('Evaluation finished.')
        print('Average reward per episode: %0.4f'
              % (total_reward / FLAGS.eval_iter))
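# The `plot` and `plot2` helpers called by evaluate() are not defined in this
# file. The sketches below show one plausible implementation, assuming
# matplotlib and the hardcoded 'plots2/' directory used above; the function
# bodies, file names and axis labels are assumptions, not the originals.
import matplotlib
matplotlib.use('Agg')  # headless backend, so no display is required
import matplotlib.pyplot as plt


def plot(r_max, epnum):
    """Hypothetical sketch: line plot of max Q per step for episode `epnum`."""
    plt.figure()
    plt.plot(r_max)
    plt.xlabel('step')
    plt.ylabel('max predicted Q')
    plt.savefig('plots2/%d_qmax.png' % epnum)
    plt.close()


def plot2(q_values, step):
    """Hypothetical sketch: bar plot of per-action Q-values at step `step`."""
    plt.figure()
    plt.bar(range(len(q_values)), q_values)
    plt.xlabel('action')
    plt.ylabel('predicted Q')
    plt.savefig('plots2/%d_qvalues.png' % step)
    plt.close()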
def run(worker):
    """Launches `worker` asynchronously in FLAGS.threads threads.

    :param worker: worker function
    """
    print('Starting. Threads:', FLAGS.threads)
    processes = []
    envs = []
    for _ in range(FLAGS.threads):
        env = GymWrapperFactory.make(FLAGS.env,
                                     actrep=FLAGS.action_repeat,
                                     memlen=FLAGS.memory_len,
                                     w=FLAGS.width,
                                     h=FLAGS.height)
        envs.append(env)
    with tf.Session() as sess:
        # A single shared agent; all worker threads update the same graph.
        agent = QlearningAgent(session=sess,
                               action_size=envs[0].action_size,
                               h=FLAGS.height,
                               w=FLAGS.width,
                               channels=FLAGS.memory_len,
                               opt=tf.train.AdamOptimizer(FLAGS.lr))
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=2)
        sess.run(tf.global_variables_initializer())
        if not os.path.exists(FLAGS.logdir):
            os.makedirs(FLAGS.logdir)
        ckpt = tf.train.latest_checkpoint(FLAGS.logdir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
            agent.update_target()
            print('Restoring session from %s' % ckpt)
        summary = AgentSummary(FLAGS.logdir, agent, FLAGS.env)
        for i in range(FLAGS.threads):
            processes.append(
                th.Thread(target=worker,
                          args=(agent, envs[i], sess, summary, saver, i)))
        for p in processes:
            p.daemon = True
            p.start()
        # Keep the main thread alive (optionally rendering every environment)
        # until the workers flip the module-level `training_finished` flag.
        while not training_finished:
            if FLAGS.render:
                for i in range(FLAGS.threads):
                    envs[i].render()
            time.sleep(.01)
        for p in processes:
            p.join()
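# `run` expects its worker to accept (agent, env, sess, summary, saver,
# thread_idx). The sketch below only illustrates that contract, using calls
# that appear elsewhere in this file; a real worker would also explore,
# perform training updates, and eventually set `training_finished`.
def example_worker(agent, env, sess, summary, saver, thread_idx):
    """Hypothetical worker sketch: greedy rollouts until training finishes."""
    s = env.reset()
    while not training_finished:
        reward_per_action = agent.predict_rewards(s)
        action = np.argmax(reward_per_action)  # greedy; no exploration here
        s, r, terminal, info = env.step(action)
        if terminal:
            s = env.reset()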
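# A minimal entry point, assuming a hypothetical boolean FLAGS.eval switch
# that selects between evaluation and asynchronous training:
if __name__ == '__main__':
    if FLAGS.eval:
        evaluate()
    else:
        run(example_worker)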