  def testBatchModeTrueWithTwoDimensionalState(self):
    state = np.arange(2).reshape(1, 2)
    mock_agent = mock.create_autospec(caql_agent.CaqlAgent, instance=True)
    mock_agent.best_action.return_value = (
        np.arange(3).reshape(1, 3), None, None, True)
    policy = agent_policy.AgentPolicy(self._action_spec, mock_agent)
    action = policy.action(state, batch_mode=True)
    self.assertAllEqual(np.arange(3).reshape(1, 3), action)
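  # A hedged companion sketch (not part of the original test file): it assumes
  # that AgentPolicy.action() with the default batch_mode=False unwraps the
  # leading batch dimension and returns a single action vector. The method
  # name and the expected shape are illustrative assumptions.
  def testDefaultBatchModeWithTwoDimensionalStateSketch(self):
    state = np.arange(2).reshape(1, 2)
    mock_agent = mock.create_autospec(caql_agent.CaqlAgent, instance=True)
    mock_agent.best_action.return_value = (
        np.arange(3).reshape(1, 3), None, None, True)
    policy = agent_policy.AgentPolicy(self._action_spec, mock_agent)
    # Assumption: without batch_mode=True the policy returns the first (and
    # only) row of the batched action.
    action = policy.action(state)
    self.assertAllEqual(np.arange(3), action)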
  def testInitializeWithCheckpoint(self):
    # Create an agent, train 17 steps, and save a checkpoint.
    env = utils.create_env('Pendulum')
    state_spec, action_spec = utils.get_state_and_action_specs(env)
    replay_memory = replay_memory_lib.ReplayMemory(
        name='ReplayBuffer', capacity=100000)
    save_path = os.path.join(self.get_temp_dir(), 'test_checkpoint')

    with self.test_session(graph=tf.Graph()) as sess:
      agent = _create_agent(sess, state_spec, action_spec)
      saver = tf.train.Saver()
      step = agent.initialize(saver)
      self.assertEqual(0, step)

      greedy_policy = agent_policy.AgentPolicy(action_spec, agent)
      behavior_policy = gaussian_noise_policy.GaussianNoisePolicy(
          greedy_policy, 1., .99, .01)
      for _ in range(100):
        env = utils.create_env('Pendulum')
        episode, _, _ = utils._collect_episode(
            env=env,
            time_out=200,
            discount_factor=.99,
            behavior_policy=behavior_policy)
        replay_memory.extend(episode)
        if hasattr(env, 'close'):
          env.close()

      while step < 17:
        minibatch = replay_memory.sample_with_replacement(64)
        (_, _, _, best_train_label_batch, _, _) = (
            agent.train_q_function_network(minibatch, None, None))
        agent.train_action_function_network(best_train_label_batch)
        step += 1
      saver.save(sess, save_path)

    # Create an agent and restore TF variables from the checkpoint.
    with self.test_session() as sess:
      agent = _create_agent(sess, state_spec, action_spec)
      saver = tf.train.Saver()
      self.assertEqual(17, agent.initialize(saver, self.get_temp_dir()))
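# The test above relies on a module-level `_create_agent` helper that is not
# shown in this excerpt. A minimal sketch of what such a helper could look
# like, assuming it mirrors the CaqlAgent constructor arguments used in main()
# below; every default value here is an illustrative assumption, not the
# original test configuration.
def _create_agent(sess, state_spec, action_spec):
  return caql_agent.CaqlAgent(
      session=sess,
      state_spec=state_spec,
      action_spec=action_spec,
      discount_factor=.99,
      hidden_layers=[32, 16],
      learning_rate=.001,
      learning_rate_action=.005,
      learning_rate_ga=.01,
      action_maximization_iterations=20,
      tau_copy=.001,
      clipped_target_flag=False,
      hard_update_steps=500,
      batch_size=64,
      l2_loss_flag=False,
      simple_lambda_flag=True,
      dual_filter_clustering_flag=False,
      solver='gradient_ascent',
      dual_q_label=True,
      initial_lambda=1.,
      tolerance_min_max=[.001, 100.])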
def main(_):
  logging.set_verbosity(logging.INFO)
  assert (FLAGS.replay_memory_capacity >
          FLAGS.batch_size * FLAGS.train_steps_per_iteration)

  replay_memory = replay_memory_lib.ReplayMemory(
      name='ReplayBuffer', capacity=FLAGS.replay_memory_capacity)
  replay_memory.restore(FLAGS.checkpoint_dir)

  env = utils.create_env(FLAGS.env_name)
  state_spec, action_spec = utils.get_state_and_action_specs(
      env, action_bounds=FLAGS.action_bounds)
  hidden_layers = [int(h) for h in FLAGS.hidden_layers]

  summary_writer = None
  if FLAGS.result_dir is not None:
    hparam_dict = {
        'env_name': FLAGS.env_name,
        'discount_factor': FLAGS.discount_factor,
        'time_out': FLAGS.time_out,
        'action_bounds': FLAGS.action_bounds,
        'max_iterations': FLAGS.max_iterations,
        'num_episodes_per_iteration': FLAGS.num_episodes_per_iteration,
        'collect_experience_parallelism':
            FLAGS.collect_experience_parallelism,
        'hidden_layers': FLAGS.hidden_layers,
        'batch_size': FLAGS.batch_size,
        'train_steps_per_iteration': FLAGS.train_steps_per_iteration,
        'target_update_steps': FLAGS.target_update_steps,
        'learning_rate': FLAGS.learning_rate,
        'learning_rate_action': FLAGS.learning_rate_action,
        'learning_rate_ga': FLAGS.learning_rate_ga,
        'action_maximization_iterations':
            FLAGS.action_maximization_iterations,
        'tau_copy': FLAGS.tau_copy,
        'clipped_target': FLAGS.clipped_target,
        'hard_update_steps': FLAGS.hard_update_steps,
        'l2_loss_flag': FLAGS.l2_loss_flag,
        'simple_lambda_flag': FLAGS.simple_lambda_flag,
        'dual_filter_clustering_flag': FLAGS.dual_filter_clustering_flag,
        'solver': FLAGS.solver,
        'initial_lambda': FLAGS.initial_lambda,
        'tolerance_init': FLAGS.tolerance_init,
        'tolerance_min': FLAGS.tolerance_min,
        'tolerance_max': FLAGS.tolerance_max,
        'tolerance_decay': FLAGS.tolerance_decay,
        'warmstart': FLAGS.warmstart,
        'dual_q_label': FLAGS.dual_q_label,
        'seed': FLAGS.seed,
    }
    if FLAGS.exploration_policy == 'egreedy':
      hparam_dict.update({
          'epsilon': FLAGS.epsilon,
          'epsilon_decay': FLAGS.epsilon_decay,
          'epsilon_min': FLAGS.epsilon_min,
      })
    elif FLAGS.exploration_policy == 'gaussian':
      hparam_dict.update({
          'sigma': FLAGS.sigma,
          'sigma_decay': FLAGS.sigma_decay,
          'sigma_min': FLAGS.sigma_min,
      })
    utils.save_hparam_config(hparam_dict, FLAGS.result_dir)
    summary_writer = tf.summary.FileWriter(FLAGS.result_dir)

  with tf.Session() as sess:
    agent = caql_agent.CaqlAgent(
        session=sess,
        state_spec=state_spec,
        action_spec=action_spec,
        discount_factor=FLAGS.discount_factor,
        hidden_layers=hidden_layers,
        learning_rate=FLAGS.learning_rate,
        learning_rate_action=FLAGS.learning_rate_action,
        learning_rate_ga=FLAGS.learning_rate_ga,
        action_maximization_iterations=FLAGS.action_maximization_iterations,
        tau_copy=FLAGS.tau_copy,
        clipped_target_flag=FLAGS.clipped_target,
        hard_update_steps=FLAGS.hard_update_steps,
        batch_size=FLAGS.batch_size,
        l2_loss_flag=FLAGS.l2_loss_flag,
        simple_lambda_flag=FLAGS.simple_lambda_flag,
        dual_filter_clustering_flag=FLAGS.dual_filter_clustering_flag,
        solver=FLAGS.solver,
        dual_q_label=FLAGS.dual_q_label,
        initial_lambda=FLAGS.initial_lambda,
        tolerance_min_max=[FLAGS.tolerance_min, FLAGS.tolerance_max])

    saver = tf.train.Saver(max_to_keep=None)
    step = agent.initialize(saver, FLAGS.checkpoint_dir)
    iteration = int(step / FLAGS.train_steps_per_iteration)
    if iteration >= FLAGS.max_iterations:
      return

    greedy_policy = agent_policy.AgentPolicy(action_spec, agent)
    if FLAGS.exploration_policy == 'egreedy':
      epsilon_init = max(FLAGS.epsilon * (FLAGS.epsilon_decay**iteration),
                         FLAGS.epsilon_min)
      behavior_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
          greedy_policy, epsilon_init, FLAGS.epsilon_decay, FLAGS.epsilon_min)
    elif FLAGS.exploration_policy == 'gaussian':
      sigma_init = max(FLAGS.sigma * (FLAGS.sigma_decay**iteration),
                       FLAGS.sigma_min)
      behavior_policy = gaussian_noise_policy.GaussianNoisePolicy(
          greedy_policy, sigma_init, FLAGS.sigma_decay, FLAGS.sigma_min)
    elif FLAGS.exploration_policy == 'none':
      behavior_policy = greedy_policy

    logging.info('Start with iteration %d, step %d, %s', iteration, step,
                 behavior_policy.params_debug_str())

    while iteration < FLAGS.max_iterations:
      utils.collect_experience_parallel(
          num_episodes=FLAGS.num_episodes_per_iteration,
          session=sess,
          behavior_policy=behavior_policy,
          time_out=FLAGS.time_out,
          discount_factor=FLAGS.discount_factor,
          replay_memory=replay_memory)

      if (replay_memory.size <
          FLAGS.batch_size * FLAGS.train_steps_per_iteration):
        continue

      tf_summary = None
      if summary_writer:
        tf_summary = tf.Summary()

      q_function_losses = []
      q_vals = []
      lambda_function_losses = []
      action_function_losses = []
      portion_active_data = []
      portion_active_data_and_clusters = []

      ts_begin = time.time()

      # 'step' can be started from any number if the program is restored from
      # a checkpoint after crash or pre-emption.
      local_step = step % FLAGS.train_steps_per_iteration
      while local_step < FLAGS.train_steps_per_iteration:
        minibatch = replay_memory.sample(FLAGS.batch_size)

        if FLAGS.tolerance_decay is not None:
          tolerance_decay = FLAGS.tolerance_decay**iteration
        else:
          tolerance_decay = None

        # Leave summary only for the last one.
        agent_tf_summary_vals = None
        if local_step == FLAGS.train_steps_per_iteration - 1:
          agent_tf_summary_vals = []

        # Train q_function and lambda_function networks.
        (q_function_loss, target_q_vals, lambda_function_loss,
         best_train_label_batch, portion_active_constraint,
         portion_active_constraint_and_cluster) = (
             agent.train_q_function_network(minibatch, FLAGS.tolerance_init,
                                            tolerance_decay, FLAGS.warmstart,
                                            agent_tf_summary_vals))
        action_function_loss = agent.train_action_function_network(
            best_train_label_batch)

        q_function_losses.append(q_function_loss)
        q_vals.append(target_q_vals)
        lambda_function_losses.append(lambda_function_loss)
        action_function_losses.append(action_function_loss)
        portion_active_data.append(portion_active_constraint)
        portion_active_data_and_clusters.append(
            portion_active_constraint_and_cluster)

        local_step += 1
        step += 1
        if step % FLAGS.target_update_steps == 0:
          agent.update_target_network()
        if FLAGS.clipped_target and step % FLAGS.hard_update_steps == 0:
          agent.update_target_network2()

      elapsed_secs = time.time() - ts_begin
      steps_per_sec = FLAGS.train_steps_per_iteration / elapsed_secs

      iteration += 1
      logging.info(
          'Iteration: %d, steps per sec: %.2f, replay memory size: %d, %s, '
          'avg q_function loss: %.3f, '
          'avg lambda_function loss: %.3f, '
          'avg action_function loss: %.3f, '
          'avg portion active data: %.3f, '
          'avg portion active data and cluster: %.3f', iteration,
          steps_per_sec, replay_memory.size,
          behavior_policy.params_debug_str(), np.mean(q_function_losses),
          np.mean(lambda_function_losses), np.mean(action_function_losses),
          np.mean(portion_active_data),
          np.mean(portion_active_data_and_clusters))

      if tf_summary:
        if agent_tf_summary_vals:
          tf_summary.value.extend(agent_tf_summary_vals)
        tf_summary.value.extend([
            tf.Summary.Value(tag='steps_per_sec', simple_value=steps_per_sec),
            tf.Summary.Value(
                tag='avg_q_loss', simple_value=np.mean(q_function_losses)),
            tf.Summary.Value(tag='avg_q_val', simple_value=np.mean(q_vals)),
            tf.Summary.Value(
                tag='avg_portion_active_data',
                simple_value=np.mean(portion_active_data)),
            tf.Summary.Value(
                tag='avg_portion_active_data_and_cluster',
                simple_value=np.mean(portion_active_data_and_clusters))
        ])

      behavior_policy.update_params()

      utils.periodic_updates(
          iteration=iteration,
          train_step=step,
          replay_memories=(replay_memory,),
          greedy_policy=greedy_policy,
          use_action_function=True,
          saver=saver,
          sess=sess,
          time_out=FLAGS.time_out,
          tf_summary=tf_summary)

      if summary_writer and tf_summary:
        summary_writer.add_summary(tf_summary, step)

    logging.info('Training is done.')
    env.close()
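# A minimal entry-point sketch (an assumption, not shown in this excerpt): the
# `main(_)` signature and the FLAGS usage above are consistent with an
# absl.app-style launcher; the flag definitions themselves live elsewhere in
# the file.
if __name__ == '__main__':
  from absl import app  # Imported here only to keep the sketch self-contained.
  app.run(main)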