def main(args):
    args_dict = vars(args)
    print('args: {}'.format(args_dict))

    with tf.Graph().as_default() as g:
        # rollout subgraph
        with tf.name_scope('rollout'):
            observations = tf.placeholder(shape=(None, OBSERVATION_DIM), dtype=tf.float32)

            logits = build_graph(observations)

            logits_for_sampling = tf.reshape(logits, shape=(1, len(ACTIONS)))

            # Sample the action to be played during rollout.
            sample_action = tf.squeeze(tf.multinomial(logits=logits_for_sampling, num_samples=1))

        optimizer = tf.train.RMSPropOptimizer(
            learning_rate=args.learning_rate,
            decay=args.decay
        )

        # dataset subgraph for experience replay
        with tf.name_scope('dataset'):
            # the dataset reads from MEMORY
            ds = tf.data.Dataset.from_generator(gen, output_types=(tf.float32, tf.int32, tf.float32))
            ds = ds.shuffle(MEMORY_CAPACITY).repeat().batch(args.batch_size)
            iterator = ds.make_one_shot_iterator()

        # training subgraph
        with tf.name_scope('train'):
            # The train_op includes getting a batch of data from the dataset,
            # so we do not need to use a feed_dict when running the train_op.
            next_batch = iterator.get_next()
            train_observations, labels, processed_rewards = next_batch

            # This reuses the same weights as in the rollout phase.
            train_observations.set_shape((args.batch_size, OBSERVATION_DIM))
            train_logits = build_graph(train_observations)

            cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=train_logits,
                labels=labels
            )

            # Extra loss when the paddle is moved, to encourage more natural moves.
            probs = tf.nn.softmax(logits=train_logits)
            move_cost = args.laziness * tf.reduce_sum(probs * [0, 1.0, 1.0], axis=1)

            loss = tf.reduce_sum(processed_rewards * cross_entropies + move_cost)

            global_step = tf.train.get_or_create_global_step()

            train_op = optimizer.minimize(loss, global_step=global_step)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=args.max_to_keep)

        with tf.name_scope('summaries'):
            rollout_reward = tf.placeholder(
                shape=(),
                dtype=tf.float32
            )

            # the weights into the hidden layer can be visualized
            hidden_weights = tf.trainable_variables()[0]
            for h in range(args.hidden_dim):
                slice_ = tf.slice(hidden_weights, [0, h], [-1, 1])
                image = tf.reshape(slice_, [1, 80, 80, 1])
                tf.summary.image('hidden_{:04d}'.format(h), image)

            for var in tf.trainable_variables():
                tf.summary.histogram(var.op.name, var)
                tf.summary.scalar('{}_max'.format(var.op.name), tf.reduce_max(var))
                tf.summary.scalar('{}_min'.format(var.op.name), tf.reduce_min(var))

            tf.summary.scalar('rollout_reward', rollout_reward)
            tf.summary.scalar('loss', loss)

            merged = tf.summary.merge_all()

        print('Number of trainable variables: {}'.format(len(tf.trainable_variables())))

    inner_env = gym.make('Pong-v0')
    # tf.agents helper to more easily track consecutive pairs of frames
    env = FrameHistory(inner_env, past_indices=[0, 1], flatten=False)
    # tf.agents helper to automatically reset the environment
    env = AutoReset(env)

    with tf.Session(graph=g) as sess:
        if args.restore:
            restore_path = tf.train.latest_checkpoint(args.output_dir)
            print('Restoring from {}'.format(restore_path))
            saver.restore(sess, restore_path)
        else:
            sess.run(init)

        summary_path = os.path.join(args.output_dir, 'summary')
        summary_writer = tf.summary.FileWriter(summary_path, sess.graph)

        # lowest possible score after an episode as the
        # starting value of the running reward
        _rollout_reward = -21.0

        for i in range(args.n_epoch):
            print('>>>>>>> epoch {}'.format(i + 1))

            print('>>> Rollout phase')
            epoch_memory = []
            episode_memory = []

            # The loop for actions/steps
            _observation = np.zeros(OBSERVATION_DIM)
            while True:
                # sample one action with the given probability distribution
                _label = sess.run(sample_action, feed_dict={observations: [_observation]})

                _action = ACTIONS[_label]
                _pair_state, _reward, _done, _ = env.step(_action)

                if args.render:
                    env.render()

                # record experience
                episode_memory.append((_observation, _label, _reward))

                # Get processed frame delta for the next step
                pair_state = _pair_state

                current_state, previous_state = pair_state
                current_x = prepro(current_state)
                previous_x = prepro(previous_state)

                _observation = current_x - previous_x

                if _done:
                    obs, lbl, rwd = zip(*episode_memory)

                    # processed rewards
                    prwd = discount_rewards(rwd, args.gamma)
                    prwd -= np.mean(prwd)
                    prwd /= np.std(prwd)

                    # store the processed experience to memory
                    epoch_memory.extend(zip(obs, lbl, prwd))

                    # calculate the running rollout reward
                    _rollout_reward = 0.9 * _rollout_reward + 0.1 * sum(rwd)

                    episode_memory = []

                    if args.render:
                        _ = input('episode done, press Enter to replay')
                        epoch_memory = []
                        continue

                    if len(epoch_memory) >= ROLLOUT_SIZE:
                        break

            # add to the global memory
            MEMORY.extend(epoch_memory)

            print('>>> Train phase')
            print('rollout reward: {}'.format(_rollout_reward))

            # Here we train only once.
            _, _global_step = sess.run([train_op, global_step])

            if _global_step % args.save_checkpoint_steps == 0:
                print('Writing summary')

                feed_dict = {rollout_reward: _rollout_reward}
                summary = sess.run(merged, feed_dict=feed_dict)

                summary_writer.add_summary(summary, _global_step)

                save_path = os.path.join(args.output_dir, 'model.ckpt')
                save_path = saver.save(sess, save_path, global_step=_global_step)
                print('Model checkpoint saved: {}'.format(save_path))
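

# main() refers to several module-level names -- OBSERVATION_DIM, ACTIONS,
# MEMORY, MEMORY_CAPACITY, ROLLOUT_SIZE, gen, prepro and build_graph -- that
# are defined elsewhere in the trainer. The definitions below are a minimal
# sketch reconstructed from how main() uses them (80x80 difference frames,
# three Pong actions, a replay deque consumed by a generator, and a policy
# network whose weights are shared between rollout and training); the exact
# constant values and function bodies are assumptions, not the originals.
from collections import deque

import numpy as np
import tensorflow as tf

OBSERVATION_DIM = 80 * 80       # flattened 80x80 difference frame
ACTIONS = [0, 2, 3]             # assumed Pong actions: stay, up, down
ROLLOUT_SIZE = 10000            # assumed number of steps collected per epoch
MEMORY_CAPACITY = 100000        # assumed capacity of the experience buffer
MEMORY = deque([], maxlen=MEMORY_CAPACITY)


def gen():
    # Yield (observation, label, processed_reward) tuples for tf.data.
    for item in list(MEMORY):
        yield item


def prepro(frame):
    # Karpathy-style preprocessing: crop the board, downsample by a factor
    # of 2, drop the color channels, erase the background, and binarize the
    # paddles and ball into a flat float vector of length OBSERVATION_DIM.
    frame = frame[35:195]
    frame = frame[::2, ::2, 0].astype(np.float32)
    frame[frame == 144] = 0
    frame[frame == 109] = 0
    frame[frame != 0] = 1
    return frame.ravel()


def build_graph(observations, hidden_dim=200):
    # Policy network: one hidden ReLU layer and a linear logits layer.
    # AUTO_REUSE lets the rollout and train subgraphs share the same weights.
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
        hidden = tf.layers.dense(
            observations, hidden_dim, use_bias=False, activation=tf.nn.relu)
        logits = tf.layers.dense(hidden, len(ACTIONS), use_bias=False)
    return logits
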

def discount_rewards(r, gamma):
    # Standard Karpathy-style discounting for Pong: returns propagate
    # backward within a rally, and the running sum resets at each game
    # boundary (any nonzero reward).
    r = np.array(r, dtype=np.float32)
    discounted_r = np.zeros_like(r)
    running_add = 0.0
    for t in reversed(range(r.size)):
        if r[t] != 0:
            running_add = 0.0
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r.tolist()


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model-directory', default='/tmp/pong')
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--train-batch-size', default=10000, type=int)
    parser.add_argument('--learning-rate', type=float, default=5e-3)
    parser.add_argument('--gamma', default=0.99, type=float)
    parser.add_argument('--decay', type=float, default=0.99)
    args = parser.parse_args()

    inner_env = gym.make('Pong-v0')
    env = FrameHistory(inner_env, past_indices=[0, 1], flatten=False)
    env = AutoReset(env)

    writer = tf.contrib.summary.create_file_writer(args.model_directory)
    writer.set_as_default()

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(200, input_shape=(OBSERVATION_DIM,), use_bias=False,
                              activation='relu', name='hidden'),
        tf.keras.layers.Dense(len(ACTIONS), use_bias=False, name='logits')
    ])

    optimizer = tf.train.RMSPropOptimizer(learning_rate=args.learning_rate, decay=args.decay)
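
    # The Keras model and optimizer above are built but not yet used at this
    # point in the listing. As one possible next step, and assuming eager
    # execution was enabled earlier in the script (tf.enable_eager_execution()),
    # a single policy-gradient update mirroring the reward-weighted
    # cross-entropy loss from main() might look like the sketch below.
    # train_step and its argument names are illustrative, not part of the
    # original trainer, and the move_cost/laziness penalty is omitted.
    def train_step(batch_observations, batch_labels, batch_rewards):
        # batch_observations: float32 [batch, OBSERVATION_DIM]
        # batch_labels: int32 [batch], indices of the sampled actions
        # batch_rewards: float32 [batch], discounted and normalized returns
        with tf.GradientTape() as tape:
            logits = model(batch_observations)
            cross_entropies = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=batch_labels)
            loss = tf.reduce_sum(batch_rewards * cross_entropies)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        return loss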