Exemplo n.º 1
0
 def testBatchModeTrueWithTwoDimensionalState(self):
   """Batch-mode `action` on a (1, 2) state returns the agent's (1, 3) action.

   The agent is an autospec'd `CaqlAgent` whose `best_action` is stubbed to
   return a 4-tuple with the action array in the first position.
   """
   fake_agent = mock.create_autospec(caql_agent.CaqlAgent, instance=True)
   stubbed_result = (np.arange(3).reshape(1, 3), None, None, True)
   fake_agent.best_action.return_value = stubbed_result

   batched_state = np.arange(2).reshape(1, 2)
   policy_under_test = agent_policy.AgentPolicy(self._action_spec, fake_agent)
   returned_action = policy_under_test.action(batched_state, batch_mode=True)

   # Compare against a freshly built array rather than the stubbed object so
   # the check does not depend on object identity.
   expected_action = np.arange(3).reshape(1, 3)
   self.assertAllEqual(expected_action, returned_action)
Exemplo n.º 2
0
    def testInitializeWithCheckpoint(self):
        """Checks that `agent.initialize` picks up the step from a checkpoint.

        A fresh agent must report step 0. After 17 training steps and a
        `saver.save`, a second agent initialized with the temp directory as
        checkpoint location must report step 17.
        """
        # Create an agent, train 17 steps, and save a checkpoint.
        env = utils.create_env('Pendulum')
        try:
            state_spec, action_spec = utils.get_state_and_action_specs(env)
        finally:
            # This environment is only needed to read the specs; the episode
            # loop below creates its own, so release it right away instead of
            # leaking it.
            if hasattr(env, 'close'):
                env.close()

        replay_memory = replay_memory_lib.ReplayMemory(name='ReplayBuffer',
                                                       capacity=100000)
        save_path = os.path.join(self.get_temp_dir(), 'test_checkpoint')
        with self.test_session(graph=tf.Graph()) as sess:
            agent = _create_agent(sess, state_spec, action_spec)
            saver = tf.train.Saver()
            step = agent.initialize(saver)
            self.assertEqual(0, step)

            greedy_policy = agent_policy.AgentPolicy(action_spec, agent)
            behavior_policy = gaussian_noise_policy.GaussianNoisePolicy(
                greedy_policy, 1., .99, .01)

            # Fill the replay memory with 100 episodes of experience.
            for _ in range(100):
                env = utils.create_env('Pendulum')
                try:
                    episode, _, _ = utils._collect_episode(
                        env=env,
                        time_out=200,
                        discount_factor=.99,
                        behavior_policy=behavior_policy)
                    replay_memory.extend(episode)
                finally:
                    # Close the environment even if collection fails.
                    if hasattr(env, 'close'):
                        env.close()

            while step < 17:
                minibatch = replay_memory.sample_with_replacement(64)
                # Only the best-action labels (4th element) are needed to
                # train the action function.
                (_, _, _, best_train_label_batch, _,
                 _) = (agent.train_q_function_network(minibatch, None, None))
                agent.train_action_function_network(best_train_label_batch)
                step += 1
            saver.save(sess, save_path)

        # Create an agent and restore TF variables from the checkpoint.
        with self.test_session() as sess:
            agent = _create_agent(sess, state_spec, action_spec)
            saver = tf.train.Saver()
            self.assertEqual(17, agent.initialize(saver, self.get_temp_dir()))
Exemplo n.º 3
0
def main(_):
    """Runs the CAQL training loop configured through `FLAGS`.

    Restores the replay memory from `FLAGS.checkpoint_dir`, builds a
    `CaqlAgent`, initializes it via `agent.initialize(saver, checkpoint_dir)`
    and then alternates between collecting experience with the behavior
    policy and training the agent's networks until `FLAGS.max_iterations`
    iterations have been completed. When `FLAGS.result_dir` is set, the
    hyper-parameters are saved there and per-iteration summaries are written.

    Args:
      _: Unused positional argument.

    Raises:
      ValueError: If `FLAGS.exploration_policy` is not one of 'egreedy',
        'gaussian' or 'none'.
    """
    logging.set_verbosity(logging.INFO)

    # The replay memory must be larger than the number of samples drawn in a
    # single training iteration.
    assert FLAGS.replay_memory_capacity > FLAGS.batch_size * FLAGS.train_steps_per_iteration
    replay_memory = replay_memory_lib.ReplayMemory(
        name='ReplayBuffer', capacity=FLAGS.replay_memory_capacity)
    replay_memory.restore(FLAGS.checkpoint_dir)

    env = utils.create_env(FLAGS.env_name)
    summary_writer = None
    # Everything below runs under try/finally so that the environment and the
    # summary writer are released on early return and on errors as well.
    try:
        state_spec, action_spec = utils.get_state_and_action_specs(
            env, action_bounds=FLAGS.action_bounds)

        hidden_layers = [int(h) for h in FLAGS.hidden_layers]

        if FLAGS.result_dir is not None:
            hparam_dict = {
                'env_name': FLAGS.env_name,
                'discount_factor': FLAGS.discount_factor,
                'time_out': FLAGS.time_out,
                'action_bounds': FLAGS.action_bounds,
                'max_iterations': FLAGS.max_iterations,
                'num_episodes_per_iteration':
                FLAGS.num_episodes_per_iteration,
                'collect_experience_parallelism':
                FLAGS.collect_experience_parallelism,
                'hidden_layers': FLAGS.hidden_layers,
                'batch_size': FLAGS.batch_size,
                'train_steps_per_iteration': FLAGS.train_steps_per_iteration,
                'target_update_steps': FLAGS.target_update_steps,
                'learning_rate': FLAGS.learning_rate,
                'learning_rate_action': FLAGS.learning_rate_action,
                'learning_rate_ga': FLAGS.learning_rate_ga,
                'action_maximization_iterations':
                FLAGS.action_maximization_iterations,
                'tau_copy': FLAGS.tau_copy,
                'clipped_target': FLAGS.clipped_target,
                'hard_update_steps': FLAGS.hard_update_steps,
                'l2_loss_flag': FLAGS.l2_loss_flag,
                'simple_lambda_flag': FLAGS.simple_lambda_flag,
                'dual_filter_clustering_flag':
                FLAGS.dual_filter_clustering_flag,
                'solver': FLAGS.solver,
                'initial_lambda': FLAGS.initial_lambda,
                'tolerance_init': FLAGS.tolerance_init,
                'tolerance_min': FLAGS.tolerance_min,
                'tolerance_max': FLAGS.tolerance_max,
                'tolerance_decay': FLAGS.tolerance_decay,
                'warmstart': FLAGS.warmstart,
                'dual_q_label': FLAGS.dual_q_label,
                'seed': FLAGS.seed,
            }
            # Record only the exploration parameters that are actually used.
            if FLAGS.exploration_policy == 'egreedy':
                hparam_dict.update({
                    'epsilon': FLAGS.epsilon,
                    'epsilon_decay': FLAGS.epsilon_decay,
                    'epsilon_min': FLAGS.epsilon_min,
                })
            elif FLAGS.exploration_policy == 'gaussian':
                hparam_dict.update({
                    'sigma': FLAGS.sigma,
                    'sigma_decay': FLAGS.sigma_decay,
                    'sigma_min': FLAGS.sigma_min,
                })

            utils.save_hparam_config(hparam_dict, FLAGS.result_dir)
            summary_writer = tf.summary.FileWriter(FLAGS.result_dir)

        with tf.Session() as sess:
            agent = caql_agent.CaqlAgent(
                session=sess,
                state_spec=state_spec,
                action_spec=action_spec,
                discount_factor=FLAGS.discount_factor,
                hidden_layers=hidden_layers,
                learning_rate=FLAGS.learning_rate,
                learning_rate_action=FLAGS.learning_rate_action,
                learning_rate_ga=FLAGS.learning_rate_ga,
                action_maximization_iterations=(
                    FLAGS.action_maximization_iterations),
                tau_copy=FLAGS.tau_copy,
                clipped_target_flag=FLAGS.clipped_target,
                hard_update_steps=FLAGS.hard_update_steps,
                batch_size=FLAGS.batch_size,
                l2_loss_flag=FLAGS.l2_loss_flag,
                simple_lambda_flag=FLAGS.simple_lambda_flag,
                dual_filter_clustering_flag=FLAGS.dual_filter_clustering_flag,
                solver=FLAGS.solver,
                dual_q_label=FLAGS.dual_q_label,
                initial_lambda=FLAGS.initial_lambda,
                tolerance_min_max=[FLAGS.tolerance_min, FLAGS.tolerance_max])

            saver = tf.train.Saver(max_to_keep=None)
            step = agent.initialize(saver, FLAGS.checkpoint_dir)

            # Number of complete iterations already covered by `step`.
            iteration = int(step // FLAGS.train_steps_per_iteration)
            if iteration >= FLAGS.max_iterations:
                return

            greedy_policy = agent_policy.AgentPolicy(action_spec, agent)
            # The initial exploration parameter is decayed by the number of
            # iterations already done, so a resumed run continues the
            # schedule instead of restarting it.
            if FLAGS.exploration_policy == 'egreedy':
                epsilon_init = max(
                    FLAGS.epsilon * (FLAGS.epsilon_decay**iteration),
                    FLAGS.epsilon_min)
                behavior_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
                    greedy_policy, epsilon_init, FLAGS.epsilon_decay,
                    FLAGS.epsilon_min)
            elif FLAGS.exploration_policy == 'gaussian':
                sigma_init = max(FLAGS.sigma * (FLAGS.sigma_decay**iteration),
                                 FLAGS.sigma_min)
                behavior_policy = gaussian_noise_policy.GaussianNoisePolicy(
                    greedy_policy, sigma_init, FLAGS.sigma_decay,
                    FLAGS.sigma_min)
            elif FLAGS.exploration_policy == 'none':
                behavior_policy = greedy_policy
            else:
                # Fail with a clear message instead of an unbound-variable
                # error on first use of `behavior_policy`.
                raise ValueError('Unknown exploration_policy: %r' %
                                 (FLAGS.exploration_policy, ))

            logging.info('Start with iteration %d, step %d, %s', iteration,
                         step, behavior_policy.params_debug_str())

            min_replay_size = (FLAGS.batch_size *
                               FLAGS.train_steps_per_iteration)

            while iteration < FLAGS.max_iterations:
                utils.collect_experience_parallel(
                    num_episodes=FLAGS.num_episodes_per_iteration,
                    session=sess,
                    behavior_policy=behavior_policy,
                    time_out=FLAGS.time_out,
                    discount_factor=FLAGS.discount_factor,
                    replay_memory=replay_memory)

                # Keep collecting until there is enough data for a full
                # training iteration.
                if replay_memory.size < min_replay_size:
                    continue

                tf_summary = None
                if summary_writer:
                    tf_summary = tf.Summary()

                q_function_losses = []
                q_vals = []
                lambda_function_losses = []
                action_function_losses = []
                portion_active_data = []
                portion_active_data_and_clusters = []
                ts_begin = time.time()

                # The decay factor depends only on the iteration, so compute
                # it once per iteration rather than once per training step.
                if FLAGS.tolerance_decay is not None:
                    tolerance_decay = FLAGS.tolerance_decay**iteration
                else:
                    tolerance_decay = None

                # 'step' can be started from any number if the program is
                # restored from a checkpoint after crash or pre-emption.
                local_step = step % FLAGS.train_steps_per_iteration
                while local_step < FLAGS.train_steps_per_iteration:
                    minibatch = replay_memory.sample(FLAGS.batch_size)

                    # Leave summary only for the last one.
                    agent_tf_summary_vals = None
                    if local_step == FLAGS.train_steps_per_iteration - 1:
                        agent_tf_summary_vals = []

                    # train q_function and lambda_function networks
                    (q_function_loss, target_q_vals, lambda_function_loss,
                     best_train_label_batch, portion_active_constraint,
                     portion_active_constraint_and_cluster) = (
                         agent.train_q_function_network(
                             minibatch, FLAGS.tolerance_init,
                             tolerance_decay, FLAGS.warmstart,
                             agent_tf_summary_vals))

                    action_function_loss = (
                        agent.train_action_function_network(
                            best_train_label_batch))

                    q_function_losses.append(q_function_loss)
                    q_vals.append(target_q_vals)
                    lambda_function_losses.append(lambda_function_loss)
                    action_function_losses.append(action_function_loss)
                    portion_active_data.append(portion_active_constraint)
                    portion_active_data_and_clusters.append(
                        portion_active_constraint_and_cluster)

                    local_step += 1
                    step += 1
                    if step % FLAGS.target_update_steps == 0:
                        agent.update_target_network()
                    if (FLAGS.clipped_target
                            and step % FLAGS.hard_update_steps == 0):
                        agent.update_target_network2()

                elapsed_secs = time.time() - ts_begin
                steps_per_sec = FLAGS.train_steps_per_iteration / elapsed_secs

                iteration += 1
                logging.info(
                    'Iteration: %d, steps per sec: %.2f, replay memory size: %d, %s, '
                    'avg q_function loss: %.3f, '
                    'avg lambda_function loss: %.3f, '
                    'avg action_function loss: %.3f '
                    'avg portion active data: %.3f '
                    'avg portion active data and cluster: %.3f ', iteration,
                    steps_per_sec, replay_memory.size,
                    behavior_policy.params_debug_str(),
                    np.mean(q_function_losses),
                    np.mean(lambda_function_losses),
                    np.mean(action_function_losses),
                    np.mean(portion_active_data),
                    np.mean(portion_active_data_and_clusters))

                if tf_summary:
                    if agent_tf_summary_vals:
                        tf_summary.value.extend(agent_tf_summary_vals)
                    tf_summary.value.extend([
                        tf.Summary.Value(tag='steps_per_sec',
                                         simple_value=steps_per_sec),
                        # Average over all steps of the iteration, matching
                        # the value logged above (not just the last step).
                        tf.Summary.Value(
                            tag='avg_q_loss',
                            simple_value=np.mean(q_function_losses)),
                        tf.Summary.Value(tag='avg_q_val',
                                         simple_value=np.mean(q_vals)),
                        tf.Summary.Value(
                            tag='avg_portion_active_data',
                            simple_value=np.mean(portion_active_data)),
                        tf.Summary.Value(
                            tag='avg_portion_active_data_and_cluster',
                            simple_value=np.mean(
                                portion_active_data_and_clusters))
                    ])

                behavior_policy.update_params()
                utils.periodic_updates(iteration=iteration,
                                       train_step=step,
                                       replay_memories=(replay_memory, ),
                                       greedy_policy=greedy_policy,
                                       use_action_function=True,
                                       saver=saver,
                                       sess=sess,
                                       time_out=FLAGS.time_out,
                                       tf_summary=tf_summary)

                if summary_writer and tf_summary:
                    summary_writer.add_summary(tf_summary, step)

        logging.info('Training is done.')
    finally:
        # Flush pending summaries and release the environment regardless of
        # how the function exits.
        if summary_writer is not None:
            summary_writer.close()
        env.close()