Example #1
    def testRunOnce(self, max_steps, max_episodes, expected_steps):
        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        replay_buffer_observer = MockReplayBufferObserver()
        transition_replay_buffer_observer = MockReplayBufferObserver()
        driver = tf_driver.TFDriver(
            tf_env,
            policy,
            observers=[replay_buffer_observer],
            transition_observers=[transition_replay_buffer_observer],
            max_steps=max_steps,
            max_episodes=max_episodes)

        initial_time_step = tf_env.reset()
        initial_policy_state = policy.get_initial_state(batch_size=1)
        self.evaluate(driver.run(initial_time_step, initial_policy_state))
        trajectories = replay_buffer_observer.gather_all()
        self.assertEqual(trajectories, self._trajectories[:expected_steps])

        transitions = transition_replay_buffer_observer.gather_all()
        self.assertLen(transitions, expected_steps)
        # TimeStep, Action, NextTimeStep
        self.assertLen(transitions[0], 3)
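
The test above receives `max_steps`, `max_episodes`, and `expected_steps` from a parameterized decorator that the excerpt omits. A minimal sketch of how such cases could be declared with `absl.testing.parameterized` (the case names and values below are illustrative, not taken from the source):

from absl.testing import parameterized


class TFDriverRunOnceSketch(parameterized.TestCase):

    # Hypothetical parameter sets: (max_steps, max_episodes, expected_steps).
    @parameterized.named_parameters(
        ('max_steps_only', 2, None, 2),
        ('max_episodes_only', None, 1, 3),
        ('both_limits', 5, 1, 3),
    )
    def testRunOnce(self, max_steps, max_episodes, expected_steps):
        # The real body builds the env, policy, and TFDriver as shown above.
        self.assertTrue((max_steps or 0) > 0 or (max_episodes or 0) > 0)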
Example #2
    def testMultipleRunMaxEpisodes(self):
        num_episodes = 2
        num_expected_steps = 6

        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        replay_buffer_observer = MockReplayBufferObserver()
        driver = tf_driver.TFDriver(
            tf_env,
            policy,
            observers=[replay_buffer_observer],
            max_steps=None,
            max_episodes=1,
        )

        time_step = tf_env.reset()
        policy_state = policy.get_initial_state(batch_size=1)
        for _ in range(num_episodes):
            time_step, policy_state = self.evaluate(
                driver.run(time_step, policy_state))
        trajectories = replay_buffer_observer.gather_all()
        self.assertEqual(trajectories, self._trajectories[:num_expected_steps])

    def test_with_tf_driver(self):
        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)
        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                                     policy.policy_step_spec,
                                                     tf_env.time_step_spec())

        tfrecord_observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, trajectory_spec)
        driver = tf_driver.TFDriver(tf_env,
                                    policy, [tfrecord_observer],
                                    max_steps=10)
        self.evaluate(tf.compat.v1.global_variables_initializer())

        time_step = self.evaluate(tf_env.reset())
        initial_policy_state = policy.get_initial_state(batch_size=1)
        self.evaluate(
            common.function(driver.run)(time_step, initial_policy_state))
        tfrecord_observer.flush()
        tfrecord_observer.close()

        dataset = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        iterator = eager_utils.dataset_iterator(dataset)
        sample = self.evaluate(eager_utils.get_next(iterator))
        self.assertIsInstance(sample, trajectory.Trajectory)
Example #4
  def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):

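    # Each field below is batched over the two mock environments, so every
    # array holds one entry for env1 and one for env2.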
    expected_trajectories = [
        trajectory.Trajectory(
            step_type=np.array([0, 0]),
            observation=np.array([0, 0]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([1, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([1., 1.])),
        trajectory.Trajectory(
            step_type=np.array([1, 1]),
            observation=np.array([2, 1]),
            action=np.array([1, 2]),
            policy_info=np.array([2, 4]),
            next_step_type=np.array([2, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([0., 1.])),
        trajectory.Trajectory(
            step_type=np.array([2, 1]),
            observation=np.array([3, 3]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([0, 2]),
            reward=np.array([0., 1.]),
            discount=np.array([1., 0.]))
    ]

    env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
    env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
    env = batched_py_environment.BatchedPyEnvironment([env1, env2])
    tf_env = tf_py_environment.TFPyEnvironment(env)

    policy = driver_test_utils.TFPolicyMock(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        batch_size=2,
        initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

    replay_buffer_observer = MockReplayBufferObserver()

    driver = tf_driver.TFDriver(
        tf_env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
    initial_time_step = tf_env.reset()
    initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
    self.evaluate(driver.run(initial_time_step, initial_policy_state))
    trajectories = replay_buffer_observer.gather_all()

    self.assertEqual(
        len(trajectories), len(expected_trajectories[:expected_length]))

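    # Trajectory fields are numpy arrays, so compare them field by field with
    # assertAllEqual rather than comparing whole namedtuples with assertEqual.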
    for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
      for t1_field, t2_field in zip(t1, t2):
        self.assertAllEqual(t1_field, t2_field)
Example #5
    def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
        env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(env)

        policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                                tf_env.action_spec())

        replay_buffer_observer = MockReplayBufferObserver()
        with self.assertRaises(ValueError):
            tf_driver.TFDriver(
                tf_env,
                policy,
                observers=[replay_buffer_observer],
                max_steps=max_steps,
                max_episodes=max_episodes,
            )
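
The invalid argument combinations come from a parameterized decorator that is not shown; TFDriver requires at least one positive limit. A minimal standalone sketch of the same contract, reusing the mock utilities from the tests above (the exact error message is not quoted from the source):

env = driver_test_utils.PyEnvironmentMock()
tf_env = tf_py_environment.TFPyEnvironment(env)
policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                        tf_env.action_spec())
try:
    # Neither max_steps nor max_episodes is positive, so construction fails.
    tf_driver.TFDriver(tf_env, policy, observers=[],
                       max_steps=None, max_episodes=None)
except ValueError as err:
    print('TFDriver rejected the limits:', err)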
Example #6
def eager_compute(metrics,
                  environment,
                  policy,
                  num_episodes=1,
                  train_step=None,
                  summary_writer=None,
                  summary_prefix='',
                  use_function=True):
    """Compute metrics using `policy` on the `environment`.

  *NOTE*: Because placeholders are not compatible with Eager mode we can not use
  python policies. Because we use tf_policies we need the environment time_steps
  to be tensors making it easier to use a tf_env for evaluations. Otherwise this
  method mirrors `compute` directly.

  Args:
    metrics: List of metrics to compute.
    environment: tf_environment instance.
    policy: tf_policy instance used to step the environment.
    num_episodes: Number of episodes to compute the metrics over.
    train_step: An optional step to write summaries against.
    summary_writer: An optional writer for generating metric summaries.
    summary_prefix: An optional prefix scope for metric summaries.
    use_function: Option to enable use of `tf.function` when collecting the
      metrics.
  Returns:
    A dictionary of results {metric_name: metric_value}
  """
    for metric in metrics:
        metric.reset()

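    # Metrics whose name contains 'Multiagent' additionally expose per-agent
    # results, which are appended to the output below.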
    multiagent_metrics = [m for m in metrics if 'Multiagent' in m.name]

    driver = tf_driver.TFDriver(environment,
                                policy,
                                observers=metrics,
                                max_episodes=num_episodes,
                                disable_tf_function=not use_function)

    def run_driver():
        time_step = environment.reset()
        policy_state = policy.get_initial_state(environment.batch_size)
        driver.run(time_step, policy_state)

    if use_function:
        common.function(run_driver)()
    else:
        run_driver()

    results = [(metric.name, metric.result()) for metric in metrics]
    for m in multiagent_metrics:
        for a in range(m.n_agents):
            results.append((m.name + '_agent' + str(a), m.result_for_agent(a)))

    # TODO(b/120301678) remove the summaries and merge with compute
    if train_step and summary_writer:
        with summary_writer.as_default():
            for m in metrics:
                tag = common.join_scope(summary_prefix, m.name)
                tf.compat.v2.summary.scalar(name=tag,
                                            data=m.result(),
                                            step=train_step)
                if 'Multiagent' in m.name:
                    for a in range(m.n_agents):
                        tf.compat.v2.summary.scalar(name=tag + '_agent' +
                                                    str(a),
                                                    data=m.result_for_agent(a),
                                                    step=train_step)
    # TODO(b/130249101): Add an option to log metrics.
    return collections.OrderedDict(results)
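
A minimal sketch of calling `eager_compute` on a single-agent setup; the CartPole environment and the random policy are illustrative stand-ins (assuming TF-Agents and Gym are installed), not part of the source:

from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.metrics import tf_metrics
from tf_agents.policies import random_tf_policy

# Illustrative environment and policy; any tf_env/tf_policy pair with
# matching specs would do.
eval_tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
eval_policy = random_tf_policy.RandomTFPolicy(eval_tf_env.time_step_spec(),
                                              eval_tf_env.action_spec())
eval_metrics = [
    tf_metrics.AverageReturnMetric(buffer_size=5),
    tf_metrics.AverageEpisodeLengthMetric(buffer_size=5),
]
results = eager_compute(eval_metrics, eval_tf_env, eval_policy, num_episodes=5)
for name, value in results.items():
    print(name, float(value))
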
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        agent_class=multiagent_ppo.MultiagentPPO,
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64, ),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for PPO."""
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load multiagent gym environment and determine number of agents
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))
        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=random_seed * 1234 + i))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        bonus_metrics = [
            multiagent_metrics.MultiagentScalar(n_agents,
                                                name='UnscaledMultiagentBonus',
                                                buffer_size=1000),
        ]
        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        logging.info('Creating agent...')
        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            n_agents=n_agents,
            learning_rate=learning_rate,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            direction_fc=direction_fc,
            entropy_regularization=entropy_regularization,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            inactive_agent_ids=inactive_agent_ids)
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is
        # reinitialized. The other agents are novices.
        # Otherwise, all agents are reinitialized from train_dir.
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = agent_class(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                non_learning_agents=list(range(n_agents - 1)))
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics + bonus_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        collect_policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=collect_policy,
            global_step=global_step)
        collect_saved_model = policy_saver.PolicySaver(collect_policy,
                                                       train_step=global_step)

        logging.info('Successfully initialized policy saver.')

        print('Using TFDriver')
        if use_attention_networks:
            collect_driver = drivers.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

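        # PPO-style update: train on every trajectory collected this iteration;
        # the buffer is cleared right after train_step() runs in the loop below.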
        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # Number of consecutive steps for which the loss has diverged.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                results = multiagent_metrics.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                    use_function=use_tf_functions,
                    use_attention_networks=use_attention_networks)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                multiagent_metrics.log_metrics(eval_metrics)

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute attention
                # weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics + bonus_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if not inactive_agent_ids or a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)
                collect_policy_checkpointer.save(global_step=global_step_val)
                collect_saved_model_path = os.path.join(
                    saved_model_dir,
                    'collect_policy_' + ('%d' % global_step_val).zfill(9))
                collect_saved_model.save(collect_saved_model_path)

        # One final eval before exiting.
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)
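
A minimal sketch of invoking `train_eval`; the output directory and the reduced collection/evaluation settings below are illustrative values, not taken from the source:

if __name__ == '__main__':
    train_eval(
        root_dir='/tmp/multigrid_ppo',  # hypothetical output location
        env_name='MultiGrid-Empty-5x5-v0',
        num_environment_steps=100000,
        collect_episodes_per_iteration=10,
        num_parallel_environments=2,
        num_eval_episodes=2,
        eval_interval=5,
        debug=True)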