def make_possibly_parallel_environment(env_name_):
    """Returns an env_name_ environment, possibly a parallel one."""
    if num_parallel_environments == 1:
        return env_load_fn(env_name_)
    else:
        return parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name_)] * num_parallel_environments)
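
# Note: ParallelPyEnvironment expects a list of zero-argument callables (one
# environment constructor per worker process), which is why the helper above
# multiplies a single `lambda: env_load_fn(env_name_)` rather than a list of
# already constructed environments.
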
def train_eval_bomberman(root_dir,
                         num_parallel_environments=4,
                         summary_interval=1000):
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    ckpt_dir = os.path.join(root_dir, 'checkpoint')
    policy_dir = os.path.join(root_dir, 'policy')

    train_summary_writer = tf.summary.create_file_writer(train_dir,
                                                         flush_millis=1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.summary.create_file_writer(eval_dir)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=10),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=10)
    ]

    global_step = tf.Variable(0)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [BombermanEnvironment] * num_parallel_environments))
        eval_tf_env = BombermanEnvironment()

        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
Example #3
def train_eval(root_dir, env_name, env_load_fn, num_parallel_environments):
    global_step = tf.compat.v1.train.get_or_create_global_step()

    print("oof")
    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * num_parallel_environments)
    print("oof")
Example #4
  def _build_configuration(self):
    """Builds a configuration using an SAC agent
    """
    self._scenario_generator = \
      DeterministicDroneChallengeGeneration(num_scenarios=3,
                                            random_seed=0,
                                            params=self._params)
    self._observer = CustomObserver(params=self._params)
    self._behavior_model = DynamicModel(model_name="TripleIntegratorModel",
                                        params=self._params)
    self._evaluator = CustomEvaluator(params=self._params)

    viewer = MPViewer(params=self._params,
                      x_range=[-20, 20],
                      y_range=[-20, 20],
                      follow_agent_id=True)
    self._viewer = viewer
    # self._viewer = VideoRenderer(renderer=viewer, world_step_time=0.2)
    self._runtime = RuntimeRL(action_wrapper=self._behavior_model,
                              observer=self._observer,
                              evaluator=self._evaluator,
                              step_time=0.2,
                              viewer=self._viewer,
                              scenario_generator=self._scenario_generator)
    # tfa_env = tf_py_environment.TFPyEnvironment(TFAWrapper(self._runtime))
    tfa_env = tf_py_environment.TFPyEnvironment(
      parallel_py_environment.ParallelPyEnvironment(
        [lambda: TFAWrapper(self._runtime)] * self._params["ML"]["Agent"]["num_parallel_environments"]))
    self._agent = SACAgent(tfa_env, params=self._params)
    self._runner = SACRunner(tfa_env,
                             self._agent,
                             params=self._params,
                             unwrapped_runtime=self._runtime)
Example #5
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30,
                       nonparallel=False):
    """Create environment.

    Args:
        env_name (str): env name
        env_load_fn (Callable): callable that creates an environment
        num_parallel_environments (int): number of parallel environments
        nonparallel (bool): force creating a single env in the current
            process. Used for correctly exposing game gin configs to tensorboard.

    Returns:
        TFPyEnvironment
    """
    if nonparallel:
        # Only one unwrapped env can be created at a time.

        # Create and step the env in a separate thread. For some simulation
        #   environments, such as social_bot (gazebo), `step` and `reset` must
        #   run in the same thread in which the env was created.
        py_env = ThreadPyEnvironment(lambda: env_load_fn(env_name))
        py_env.seed(np.random.randint(0, np.iinfo(np.int32).max))
    else:
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)

        py_env.seed([
            np.random.randint(0,
                              np.iinfo(np.int32).max)
            for i in range(num_parallel_environments)
        ])

    return tf_py_environment.TFPyEnvironment(py_env)
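
A minimal usage sketch for create_environment above (hypothetical calls; assumes the imports from the snippet are in scope):

# Batched training env backed by 4 worker processes.
train_env = create_environment('CartPole-v0', num_parallel_environments=4)
assert train_env.batch_size == 4

# Single env in the current process (stepped in its own thread), e.g. when gin
# configs must be exposed from this process.
eval_env = create_environment('CartPole-v0', nonparallel=True)
assert eval_env.batch_size == 1
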
Example #6
    def test_parallel_envs(self):
        env_num = 5

        ctors = [
            lambda: suite_socialbot.load('SocialBot-CartPole-v0',
                                         wrap_with_process=False)
        ] * env_num

        self._env = parallel_py_environment.ParallelPyEnvironment(
            env_constructors=ctors, start_serially=False)
        tf_env = tf_py_environment.TFPyEnvironment(self._env)

        self.assertTrue(tf_env.batched)
        self.assertEqual(tf_env.batch_size, env_num)

        random_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        replay_buffer_capacity = 100
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            random_policy.trajectory_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        steps = 100
        step_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch],
            num_steps=steps)
        step_driver.run = common.function(step_driver.run)
        step_driver.run()

        self.assertIsNotNone(replay_buffer.get_next())
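
As a follow-up to the test above, the collected transitions could also be read back in batches through the replay buffer's dataset API (a sketch, not part of the original test):

dataset = replay_buffer.as_dataset(sample_batch_size=32, num_steps=2)
experience, _ = next(iter(dataset))
# `experience` is a batched Trajectory; each field has shape [32, 2, ...].
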
Example #7
def get_env(env_name,
            max_episode_steps=None,
            constant_task=None,
            num_parallel_environments=1):
    """Loads the environment.

  Args:
    env_name: (str) name of the environment.
    max_episode_steps: (int) maximum number of steps per episode. Set to None
      to not include a limit.
    constant_task: specifies a fixed task to use for all episodes. Set to None
      to use tasks sampled from the task distribution.
    num_parallel_environments: (int) Number of parallel environments.
  Returns:
    tf_env: the environment, built from a dynamics and task distribution. This
      environment is an instance of TFPyEnvironment.
    task_distribution: the task distribution used for the environment.
  """
    def env_load_fn(return_task_distribution=False):
        py_env, task_distribution = get_py_env(
            env_name,
            max_episode_steps=max_episode_steps,
            constant_task=constant_task)
        if return_task_distribution:
            return (py_env, task_distribution)
        else:
            return py_env

    py_env, task_distribution = env_load_fn(return_task_distribution=True)
    if num_parallel_environments > 1:
        del py_env
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [env_load_fn] * num_parallel_environments)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    return tf_env, task_distribution
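
A hypothetical call to get_env above ('my_env' is a placeholder name; the task-distribution semantics come from the surrounding project):

tf_env, task_distribution = get_env('my_env',
                                    max_episode_steps=100,
                                    num_parallel_environments=4)
assert tf_env.batch_size == 4  # four parallel copies batched into one TF env
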
Example #8
def get_tf_env():
    def _load_env():
        return test_env.CountingEnv(steps_per_episode=10)

    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment([lambda: _load_env()] *
                                                      FP.NUM_PARALLEL_ENVS))
    return tf_env
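
The batched environment returned above can be stepped like any other TFPyEnvironment; a brief sketch, assuming the random_tf_policy import used elsewhere on this page:

env = get_tf_env()
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())
time_step = env.reset()
for _ in range(5):
    action_step = policy.action(time_step)
    time_step = env.step(action_step.action)  # batched over FP.NUM_PARALLEL_ENVS
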
Example #9
 def test_checks_constructors(self):
     self._set_default_specs()
     # pytype: disable=wrong-arg-types
     with self.assertRaisesRegex(TypeError, '.*non-callable.*'):
         parallel_py_environment.ParallelPyEnvironment([
             random_py_environment.RandomPyEnvironment(
                 self.observation_spec, self.action_spec)
         ])
Example #10
 def _make_parallel_py_environment(self,
                                   constructor=None,
                                   num_envs=2,
                                   blocking=True):
     self._set_default_specs()
     constructor = constructor or functools.partial(
         random_py_environment.RandomPyEnvironment, self.observation_spec,
         self.action_spec)
     return parallel_py_environment.ParallelPyEnvironment(
         env_constructors=[constructor] * num_envs, blocking=blocking)
Example #11
 def test_dmlab_env(self):
     ctor = lambda: suite_dmlab.load(scene='lt_chasm',
                                     gym_env_wrappers=[
                                         wrappers.FrameGrayScale,
                                         wrappers.FrameResize,
                                         wrappers.FrameStack
                                     ],
                                     wrap_with_process=False)
     self._env = parallel_py_environment.ParallelPyEnvironment([ctor] * 2)
     env = tf_py_environment.TFPyEnvironment(self._env)
     self.assertEqual((84, 84, 4), env.observation_spec().shape)
Example #12
 def _make_parallel_py_environment(self, constructor=None, num_envs=2):
     self.observation_spec = array_spec.ArraySpec((3, 3), np.float32)
     self.time_step_spec = ts.time_step_spec(self.observation_spec)
     self.action_spec = array_spec.BoundedArraySpec([7],
                                                    dtype=np.float32,
                                                    minimum=-1.0,
                                                    maximum=1.0)
     constructor = constructor or functools.partial(
         random_py_environment.RandomPyEnvironment, self.observation_spec,
         self.action_spec)
     return parallel_py_environment.ParallelPyEnvironment(
         env_constructors=[constructor] * num_envs, blocking=True)
Example #13
    def test_dmlab_env_run(self, scene):
        ctor = lambda: suite_dmlab.load(scene=scene,
                                        gym_env_wrappers=[wrappers.FrameResize],
                                        wrap_with_process=False)

        self._env = parallel_py_environment.ParallelPyEnvironment([ctor] * 4)
        env = tf_py_environment.TFPyEnvironment(self._env)
        self.assertEqual((84, 84, 3), env.observation_spec().shape)

        random_policy = random_tf_policy.RandomTFPolicy(
            env.time_step_spec(), env.action_spec())

        driver = dynamic_step_driver.DynamicStepDriver(env=env,
                                                       policy=random_policy,
                                                       observers=None,
                                                       num_steps=10)

        driver.run(maximum_iterations=10)
Example #14
 def _build_configuration(self):
     """Builds a configuration using an PPO agent
 """
     # self._runtime = RuntimeRL(action_wrapper=self._behavior_model,
     #                           observer=self._observer,
     #                           evaluator=self._evaluator,
     #                           step_time=0.2,
     #                           viewer=self._viewer,
     #                           scenario_generator=self._scenario_generator)
     self._runtime = gym.make('Pendulum-v0')
     # tfa_env = tf_py_environment.TFPyEnvironment(TFAWrapper(self._runtime))
     tfa_env = tf_py_environment.TFPyEnvironment(
         parallel_py_environment.ParallelPyEnvironment(
             [lambda: TFAWrapper(self._runtime)] *
             self._params["ML"]["Agent"]["num_parallel_environments", "",
                                         0]))
     self._agent = PPOAgent(tfa_env, params=self._params)
     self._runner = PPORunner(tfa_env,
                              self._agent,
                              params=self._params,
                              unwrapped_runtime=self._runtime)
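
A note on the parameter access in the snippet above: indexing `self._params` with a three-element key appears to follow a parameter-server convention of `params["name", "description", default]` (a description string plus a default value, not ordinary nested dict lookups), so the trailing `"", 0` supplies an empty description and a default of 0.
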
Example #15
def create_environment(env_name='CartPole-v0',
                       env_load_fn=suite_gym.load,
                       num_parallel_environments=30):
    """Create environment.

    Args:
        env_name (str): env name
        env_load_fn (Callable): callable that creates an environment
        num_parallel_environments (int): number of parallel environments
    """
    if num_parallel_environments == 1:
        py_env = env_load_fn(env_name)
    else:
        if env_load_fn == suite_socialbot.load:
            logging.info("suite_socialbot environment")
            # No need to wrap with process since ParallelPyEnvironment will do it
            env_load_fn = lambda env_name: suite_socialbot.load(
                env_name, wrap_with_process=False)
        py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
    return tf_py_environment.TFPyEnvironment(py_env)
Example #16
    def test_mario_env(self):
        ctor = lambda: suite_mario.load(
            'SuperMarioBros-Nes', 'Level1-1', wrap_with_process=False)

        self._env = parallel_py_environment.ParallelPyEnvironment([ctor] * 4)
        env = tf_py_environment.TFPyEnvironment(self._env)
        self.assertEqual(np.uint8, env.observation_spec().dtype)
        self.assertEqual((84, 84, 4), env.observation_spec().shape)

        random_policy = random_tf_policy.RandomTFPolicy(
            env.time_step_spec(), env.action_spec())

        metrics = [
            AverageReturnMetric(batch_size=4),
            AverageEpisodeLengthMetric(batch_size=4),
            EnvironmentSteps(),
            NumberOfEpisodes()
        ]
        driver = dynamic_step_driver.DynamicStepDriver(env, random_policy,
                                                       metrics, 10000)
        driver.run(maximum_iterations=10000)
Example #17
def create_envs(env_name,
                use_multiprocessing,
                num_parallel_envs,
                visualize_eval=False,
                mock_train_envs=False):
    def env_load_fn(env_map_name, visualize=False, mock=False):
        env = gym_wrapper.GymWrapper(
            gym_env=SC2GymEnv(map_name=env_map_name,
                              visualize=visualize,
                              mock=mock),
            spec_dtype_map={
                gym.spaces.Box: np.float32,
                gym.spaces.Discrete: np.int32,
                gym.spaces.MultiBinary: np.float32
            },
        )
        return env

    if num_parallel_envs == 1:
        par_env = env_load_fn(env_map_name=env_name, mock=mock_train_envs)
    elif use_multiprocessing:
        par_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_map_name=env_name, mock=mock_train_envs)
             ] * num_parallel_envs,
            start_serially=False)
    else:
        par_env = batched_py_environment.BatchedPyEnvironment(envs=[
            env_load_fn(env_map_name=env_name, mock=mock_train_envs)
            for _ in range(num_parallel_envs)
        ])
    tf_env = tf_py_environment.TFPyEnvironment(par_env)
    tf_env.reset()

    eval_env = env_load_fn(env_name, visualize=visualize_eval)
    eval_env = tf_py_environment.TFPyEnvironment(eval_env)
    eval_env.reset()

    return tf_env, eval_env
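
One detail worth noting in the branching above: ParallelPyEnvironment runs each constructor in its own worker process, while BatchedPyEnvironment batches already constructed environments inside the current process, so `use_multiprocessing` chooses between process-level parallelism and plain in-process batching.
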
Example #18
def test():
    num_episodes = 5
    py_env = parallel_py_environment.ParallelPyEnvironment(
        [lambda: point_mass.env_load_fn() for _ in range(num_episodes)])
    env = tf_py_environment.TFPyEnvironment(py_env)
    policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                             env.action_spec())

    traj_spec = trajectory.from_transition(env.time_step_spec(),
                                           policy.policy_step_spec,
                                           env.time_step_spec())
    rb = episodic_replay_buffer.EpisodicReplayBuffer(traj_spec)
    srb = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
        rb, num_episodes=num_episodes)
    rb2 = tf_uniform_replay_buffer.TFUniformReplayBuffer(traj_spec, 1)

    driver = safe_dynamic_episode_driver.SafeDynamicEpisodeDriver(
        env,
        policy,
        rb,
        rb2,
        observers=[srb.add_batch],
        num_episodes=num_episodes)
    driver.run()
Example #19
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        agent_class=multiagent_ppo.MultiagentPPO,
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64, ),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for PPO."""
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load multiagent gym environment and determine number of agents
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))
        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=random_seed * 1234 + i))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        bonus_metrics = [
            multiagent_metrics.MultiagentScalar(n_agents,
                                                name='UnscaledMultiagentBonus',
                                                buffer_size=1000),
        ]
        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        logging.info('Creating agent...')
        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            n_agents=n_agents,
            learning_rate=learning_rate,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            direction_fc=direction_fc,
            entropy_regularization=entropy_regularization,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            inactive_agent_ids=inactive_agent_ids)
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, the last agent in the checkpoint is
        # reinitialized. The other agents are novices.
        # Otherwise, all agents are reinitialized from train_dir.
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = agent_class(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                non_learning_agents=list(range(n_agents - 1)))
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics + bonus_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        collect_policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=collect_policy,
            global_step=global_step)
        collect_saved_model = policy_saver.PolicySaver(collect_policy,
                                                       train_step=global_step)

        logging.info('Successfully initialized policy saver.')

        print('Using TFDriver')
        if use_attention_networks:
            collect_driver = drivers.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # How many consecutive steps was loss diverged for.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                results = multiagent_metrics.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                    use_function=use_tf_functions,
                    use_attention_networks=use_attention_networks)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                multiagent_metrics.log_metrics(eval_metrics)

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute attention
                # weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics + bonus_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if not inactive_agent_ids or a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step_val).zfill(9))
                saved_model.save(saved_model_path)
                collect_policy_checkpointer.save(global_step=global_step_val)
                collect_saved_model_path = os.path.join(
                    saved_model_dir,
                    'collect_policy_' + ('%d' % global_step_val).zfill(9))
                collect_saved_model.save(collect_saved_model_path)

        # One final eval before exiting.
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)
Example #20
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        lstm_size=(20, ),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None,
                lstm_size=lstm_size)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
Example #21
    def __init__(self,
                 root_dir,
                 env_load_fn=suite_gym.load,
                 env_name='CartPole-v0',
                 num_parallel_environments=1,
                 agent_class=None,
                 num_eval_episodes=30,
                 write_summaries=True,
                 summaries_flush_secs=10,
                 eval_metrics_callback=None,
                 env_metric_factories=None):
        """Evaluate policy checkpoints as they are produced.

    Args:
      root_dir: Main directory for experiment files.
      env_load_fn: Function to load the environment specified by env_name.
      env_name: Name of environment to evaluate in.
      num_parallel_environments: Number of environments to evaluate on in
        parallel.
      agent_class: TFAgent class to instantiate for evaluation.
      num_eval_episodes: Number of episodes to average evaluation over.
      write_summaries: Whether to write summaries to the file system.
      summaries_flush_secs: How frequently to flush summaries (in seconds).
      eval_metrics_callback: A function that will be called with evaluation
        results for every checkpoint.
      env_metric_factories: An iterable of metric factories. Use this for eval
        metrics that need access to the evaluated environment. A metric
        factory is a function that takes an environment and buffer_size as
        keyword arguments and returns an instance of py_metric.

    Raises:
      ValueError: when num_parallel_environments > num_eval_episodes or
        agent_class is not set
    """
        if not agent_class:
            raise ValueError(
                'The `agent_class` parameter of Evaluator must be set.')
        if num_parallel_environments > num_eval_episodes:
            raise ValueError(
                'num_parallel_environments should not be greater than '
                'num_eval_episodes')

        self._num_eval_episodes = num_eval_episodes
        self._eval_metrics_callback = eval_metrics_callback
        # Flag that controls eval cycle. If set, evaluation will exit eval loop
        # before the max checkpoint number is reached.
        self._terminate_early = False

        # Save root dir to self so derived classes have access to it.
        self._root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(self._root_dir, 'train')
        self._eval_dir = os.path.join(self._root_dir, 'eval')

        self._global_step = tf.compat.v1.train.get_or_create_global_step()

        self._env_name = env_name
        if num_parallel_environments == 1:
            eval_env = env_load_fn(env_name)
        else:
            eval_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)

        if isinstance(eval_env, py_environment.PyEnvironment):
            self._eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)
            self._eval_py_env = eval_env
        else:
            self._eval_tf_env = eval_env
            self._eval_py_env = None  # Can't generically convert to PyEnvironment.

        self._eval_metrics = [
            tf_metrics.AverageReturnMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=self._num_eval_episodes,
                batch_size=self._eval_tf_env.batch_size),
        ]
        if env_metric_factories:
            if not self._eval_py_env:
                raise ValueError(
                    'The `env_metric_factories` parameter of Evaluator '
                    'can only be used with a PyEnvironment environment.')
            for metric_factory in env_metric_factories:
                py_metric = metric_factory(environment=self._eval_py_env,
                                           buffer_size=self._num_eval_episodes)
                self._eval_metrics.append(tf_py_metric.TFPyMetric(py_metric))

        if write_summaries:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                self._eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_summary_writer.set_as_default()
        else:
            self._eval_summary_writer = None

        environment_specs.set_observation_spec(
            self._eval_tf_env.observation_spec())
        environment_specs.set_action_spec(self._eval_tf_env.action_spec())

        # Agent params configured with gin.
        self._agent = agent_class(self._eval_tf_env.time_step_spec(),
                                  self._eval_tf_env.action_spec())

        self._eval_policy = greedy_policy.GreedyPolicy(self._agent.policy)
        self._eval_policy.action = common.function(self._eval_policy.action)

        # Run the agent on dummy data to force instantiation of the network. Keras
        # doesn't create variables until you first use the layer. This is needed
        # for checkpoint restoration to work.
        dummy_obs = tensor_spec.sample_spec_nest(
            self._eval_tf_env.observation_spec(),
            outer_dims=(self._eval_tf_env.batch_size, ))
        self._eval_policy.action(
            ts.restart(dummy_obs, batch_size=self._eval_tf_env.batch_size),
            self._eval_policy.get_initial_state(self._eval_tf_env.batch_size))

        self._policy_checkpoint = tf.train.Checkpoint(
            policy=self._agent.policy, global_step=self._global_step)
        self._policy_checkpoint_dir = os.path.join(train_dir, 'policy')
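
To illustrate the `env_metric_factories` contract described in the docstring above, a factory is simply a callable accepting `environment` and `buffer_size` keyword arguments and returning a py_metric instance; a hedged sketch using the standard py_metrics module (the Evaluator wraps the result in tf_py_metric.TFPyMetric itself):

from tf_agents.metrics import py_metrics

def average_return_factory(environment=None, buffer_size=10):
    # `environment` is accepted only to match the factory signature; this
    # particular metric does not need access to the evaluated environment.
    return py_metrics.AverageReturnMetric(buffer_size=buffer_size)

# evaluator = Evaluator(..., env_metric_factories=[average_return_factory])
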
Example #22
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is kept only
            # if all elements except for the last one pass the filter. This is to
            # allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])
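
        # A small illustration of the filter above: for a sampled window of
        # length train_sequence_length + 1, a boundary mask of
        #   [False, ..., False, False] or [False, ..., False, True]  -> kept
        #   [False, True, False, ...]                                -> dropped
        # i.e. a boundary step is tolerated only at the final position, so
        # training never straddles an episode boundary mid-sequence.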

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [B x (train_sequence_length + 1) x ...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
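
A minimal, self-contained sketch (plain TensorFlow, with made-up boundary flags) of what the _filter_invalid_transition predicate above keeps: a sampled sequence survives only if no step before its last one is an episode boundary.

import tensorflow as tf

# Hypothetical boundary flags for two sampled sequences of length
# train_sequence_length + 1 = 3.
boundary_on_last_step = tf.constant([False, False, True])
boundary_mid_sequence = tf.constant([False, True, False])

def keep_sequence(is_boundary):
    # Same reduction as _filter_invalid_transition: keep the sequence only if
    # every step except the last one is a non-boundary step.
    return tf.reduce_all(~is_boundary[:-1])

print(keep_sequence(boundary_on_last_step).numpy())  # True  -> kept
print(keep_sequence(boundary_mid_sequence).numpy())  # False -> dropped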
Example #23
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception:
                # Not every environment supports seeding; ignore if unavailable.
                pass

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]
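    # Note: with a list of thresholds, tf.keras.metrics.TruePositives and the
    # other confusion-matrix metrics above return one value per threshold,
    # e.g. TruePositives(thresholds=[0.2, 0.5]) yields a length-2 result().
    # That is why the summary loop further down indexes result()[0] and
    # result()[1] whenever a metric's result is not a scalar.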

    if seed:
        tf.compat.v1.set_random_seed(seed)

    summaries_flush_secs = 10
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        num_steps=2,
        sample_batch_size=batch_size // 2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size // 2,
        num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
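
A small illustration (plain TensorFlow, hypothetical failure flags) of the tf.roll trick used above to build before_fail_mask: shifting the failure mask by -1 along the time axis marks the step immediately preceding each failure.

import tensorflow as tf

# Hypothetical per-step task-agnostic failure flags laid out [batch, time],
# matching the layout of the trajectories gathered from the replay buffer.
fail_mask = tf.constant([[False, False, False, True, False, False]])

before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
print(before_fail_mask.numpy())
# [[False False  True False False False]]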
Example #24
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        eval_env_name=None,
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
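
A short sketch (toy stand-in for tf_agent.train) of why the logging and eval cadence above can simply poll global_step: the counter handed to the agent as train_step_counter is incremented once per train call.

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)

@tf.function
def fake_train_step():
    # Stand-in for tf_agent.train(experience); the agent bumps its
    # train_step_counter in exactly this way on every call.
    global_step.assign_add(1)

log_interval = 2
for _ in range(4):
    fake_train_step()
    if global_step.numpy() % log_interval == 0:
        print('step =', global_step.numpy())  # prints at steps 2 and 4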
def train_eval(
        root_dir,
        env_load_fn=get_env,
        random_seed=None,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=10,
        num_parallel_environments=10,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=10,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        policy_save_interval=10000,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):

    if random_seed is not None:
        tf.random.set_seed(random_seed)

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    logging.info('Running %d episodes in parallel', num_parallel_environments)
    logging.info('Collecting %d episodes per step',
                 collect_episodes_per_iteration)
    logging.info('Using replay buffer capacity of %d', replay_buffer_capacity)

    train_summary_writer = tf.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()
    eval_summary_writer = tf.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn())
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn()] * num_parallel_environments))

    actor_net, value_net = get_actor_and_value_network(
        tf_env.action_spec(), tf_env.observation_spec())

    train_steps = tf.Variable(0)
    with tf.summary.record_if(
            lambda: tf.math.equal(train_steps % summary_interval, 0)):
        tf_agent = get_agent(time_step_spec=tf_env.time_step_spec(),
                             action_spec=tf_env.action_spec(),
                             actor_net=actor_net,
                             value_net=value_net,
                             num_epochs=num_epochs,
                             step_counter=train_steps,
                             learning_rate=learning_rate)
        tf_agent.initialize()

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        step_metrics, train_metrics, eval_metrics = get_metrics(
            n_parallel_env=num_parallel_environments,
            num_eval_episodes=num_eval_episodes)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=train_steps,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=train_steps)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=train_steps)
        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = train_steps.numpy()
        # Assumes get_metrics() returns [NumberOfEpisodes, EnvironmentSteps]
        # as its step metrics; the second one counts total environment steps.
        environment_steps_metric = step_metrics[1]

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = train_steps.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=train_steps,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=train_steps,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=train_steps)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_save_interval == 0:
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=train_steps,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
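
A compact sketch (hypothetical log directory) of the tf.summary.record_if gate wrapped around the training code above: scalar summaries are only written when the step counter lands on a multiple of summary_interval.

import tensorflow as tf

summary_interval = 2
train_steps = tf.Variable(0, dtype=tf.int64)
writer = tf.summary.create_file_writer('/tmp/record_if_demo')  # hypothetical path

with writer.as_default(), tf.summary.record_if(
        lambda: tf.math.equal(train_steps % summary_interval, 0)):
    for _ in range(4):
        # Only the calls made at train_steps == 0 and 2 are recorded.
        tf.summary.scalar('loss', 1.0, step=train_steps)
        train_steps.assign_add(1)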
Example #26
def train_eval(
        root_dir,
        gpu=0,
        env_load_fn=None,
        model_ids=None,
        eval_env_mode='headless',
        num_iterations=1000000,
        conv_layer_params=None,
        encoder_fc_layers=[256],
        actor_fc_layers=[400, 300],
        critic_obs_fc_layers=[400],
        critic_action_fc_layers=None,
        critic_joint_fc_layers=[300],
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        eval_only=False,
        eval_deterministic=False,
        num_parallel_environments_eval=1,
        model_ids_eval=None,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=10000,
        rb_checkpoint_interval=50000,
        log_interval=100,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
        batched_py_metric.BatchedPyMetric(
            py_metrics.AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments_eval),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if model_ids is None:
            model_ids = [None] * num_parallel_environments
        else:
            assert len(model_ids) == num_parallel_environments, \
                'model ids provided, but length not equal to num_parallel_environments'

        if model_ids_eval is None:
            model_ids_eval = [None] * num_parallel_environments_eval
        else:
            assert len(model_ids_eval) == num_parallel_environments_eval,\
                'model ids eval provided, but length not equal to num_parallel_environments_eval'

        tf_py_env = [
            lambda model_id=model_ids[i]: env_load_fn(model_id, 'headless', gpu)
            for i in range(num_parallel_environments)
        ]
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(tf_py_env))

        if eval_env_mode == 'gui':
            assert num_parallel_environments_eval == 1, 'only one GUI env is allowed'
        eval_py_env = [
            lambda model_id=model_ids_eval[i]: env_load_fn(
                model_id, eval_env_mode, gpu)
            for i in range(num_parallel_environments_eval)
        ]
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            eval_py_env)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        print('observation_spec', observation_spec)
        print('action_spec', action_spec)

        glorot_uniform_initializer = (
            tf.compat.v1.keras.initializers.glorot_uniform())
        preprocessing_layers = {
            'depth_seg':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
            'sensor':
            tf.keras.Sequential(
                mlp_layers(
                    conv_layer_params=None,
                    fc_layer_params=encoder_fc_layers,
                    kernel_initializer=glorot_uniform_initializer,
                )),
        }
        preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)

        actor_net = actor_network.ActorNetwork(
            observation_spec,
            action_spec,
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            fc_layer_params=actor_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            preprocessing_layers=preprocessing_layers,
            preprocessing_combiner=preprocessing_combiner,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer=glorot_uniform_initializer,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.compat.v1.Session(config=config)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        if eval_deterministic:
            eval_py_policy = py_tf_policy.PyTFPolicy(
                greedy_policy.GreedyPolicy(tf_agent.policy))
        else:
            eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                buffer_size=100, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=100, batch_size=num_parallel_environments),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps * num_parallel_environments).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration *
            num_parallel_environments).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=5,
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        with eval_summary_writer.as_default(), tf.compat.v2.summary.record_if(
                True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        init_agent_op = tf_agent.initialize()
        with sess.as_default():
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)

            if eval_only:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=False,
                    log=True,
                )
                episodes = eval_py_env.get_stored_episodes()
                episodes = [
                    episode for sublist in episodes for episode in sublist
                ][:num_eval_episodes]
                metrics = episode_utils.get_metrics(episodes)
                for key in sorted(metrics.keys()):
                    print(key, ':', metrics[key])

                save_path = os.path.join(eval_dir, 'episodes_vis.pkl')
                episode_utils.save(episodes, save_path)
                print('EVAL DONE')
                return

            # Initialize training.
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=0,
                    callback=eval_metrics_callback,
                    tf_summaries=True,
                    log=True,
                )
                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = sess.run(global_step)
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # print('collect:', time.time() - start_time)

                # train_start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                # print('train:', time.time() - train_start_time)

                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        tf_summaries=True,
                        log=True,
                    )
                    with eval_summary_writer.as_default(
                    ), tf.compat.v2.summary.record_if(True):
                        with tf.name_scope('Metrics/'):
                            episodes = eval_py_env.get_stored_episodes()
                            episodes = [
                                episode for sublist in episodes
                                for episode in sublist
                            ][:num_eval_episodes]
                            metrics = episode_utils.get_metrics(episodes)
                            for key in sorted(metrics.keys()):
                                print(key, ':', metrics[key])
                                metric_op = tf.compat.v2.summary.scalar(
                                    name=key,
                                    data=metrics[key],
                                    step=global_step_val)
                                sess.run(metric_op)
                    sess.run(eval_summary_flush_op)

        sess.close()
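
A tiny pure-Python sketch of why the environment constructors above bind model_id through a lambda default argument: closures capture loop variables late, so without the default every constructor would end up with the last model id.

ids = ['model_a', 'model_b', 'model_c']  # hypothetical model ids

late_bound = [lambda: ids[i] for i in range(3)]
default_bound = [lambda mid=ids[i]: mid for i in range(3)]

print([f() for f in late_bound])     # ['model_c', 'model_c', 'model_c']
print([f() for f in default_bound])  # ['model_a', 'model_b', 'model_c']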
Example #27
def main():

    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()
    parser = argparse.ArgumentParser()

    ## Essential parameters
    parser.add_argument("--output_dir", default=None, type=str, required=True,help="The output directory where the model stats and checkpoints will be written.")
    parser.add_argument("--env", default=None, type=str, required=True,help="The environment to train the agent on")
    parser.add_argument("--max_horizon", default=4, type=int)
    parser.add_argument("--atari", default=False, type=bool, help = "Gets some data Types correctly")


    ##agent parameters
    parser.add_argument("--reward_scale_factor", default=1.0, type=float)
    parser.add_argument("--debug_summaries", default=False, type=bool)
    parser.add_argument("--summarize_grads_and_vars", default=False, type=bool)

    ##transformer parameters
    parser.add_argument("--d_model", default=64, type=int)
    parser.add_argument("--num_layers", default=3, type=int)
    parser.add_argument("--dff", default=256, type=int)

    ##Training parameters
    parser.add_argument('--num_iterations', type=int, default=100000,
                        help="steps in the env")
    parser.add_argument('--num_parallel', type=int, default=30,
                        help="how many envs should run in parallel")
    parser.add_argument("--collect_episodes_per_iteration", default=1, type=int)
    parser.add_argument('--num_epochs', type=int, default=25,
                        help='Number of epochs for computing policy updates.')


    ## Other parameters
    parser.add_argument("--num_eval_episodes", default=10, type=int)
    parser.add_argument("--eval_interval", default=1000, type=int)
    parser.add_argument("--log_interval", default=10, type=int)
    parser.add_argument("--summary_interval", default=1000, type=int)
    parser.add_argument("--run_graph_mode", default=True, type=bool)
    parser.add_argument("--checkpoint_interval", default=1000, type=int)
    parser.add_argument("--summary_flush", default=10, type=int)   #what does this exactly do? 

    # HP opt params
    #parser.add_argument("--doubleQ", default=True, type=bool,help="Whether to use a  DoubleQ agent")
    parser.add_argument("--custom_last_layer", default=True, type=bool)
    parser.add_argument("--custom_layer_init", default=1.0,type=    float)
    parser.add_argument("--initial_collect_steps", default=5000, type=int)
    #parser.add_argument("--loss_function", default="element_wise_huber_loss", type=str)
    parser.add_argument("--num_heads", default=4, type=int)
    parser.add_argument("--normalize_env", default=False, type=bool)  
    parser.add_argument('--custom_lr_schedule', default="No", type=str,
                        help="whether to use a custom LR schedule")
    #parser.add_argument("--epsilon_greedy", default=0.3, type=float)
    #parser.add_argument("--target_update_period", default=1000, type=int)
    parser.add_argument("--rate", default=0.1, type=float)  # dropout rate  (might be not used depending on the q network)  #Setting this to 0.0 somehow break the code. Not relevant tho just select a network without dropout
    parser.add_argument("--gradient_clipping", default=True, type=bool)
    parser.add_argument("--replay_buffer_max_length", default=1001, type=int)
    #parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--encoder_type", default=3, type=int,help="Which Type of encoder is used for the model")
    parser.add_argument("--layer_type", default=3, type=int,help="Which Type of layer is used for the encoder")
    #parser.add_argument("--target_update_tau", default=1, type=float)
    #parser.add_argument("--gamma", default=0.99, type=float)


    
    args = parser.parse_args()
    global_step = tf.compat.v1.train.get_or_create_global_step()
    
    baseEnv = gym.make(args.env)
    
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        PyhistoryWrapper(suite_gym.load(args.env), args.max_horizon, args.atari))
    tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: PyhistoryWrapper(suite_gym.load(args.env), args.max_horizon, args.atari)] * args.num_parallel))
    
    
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        fc_layer_params=(200, 100),
        activation_fn=tf.keras.activations.tanh)
    value_net = value_network.ValueNetwork(
        tf_env.observation_spec(),
        fc_layer_params=(200, 100),
        activation_fn=tf.keras.activations.tanh)
    
    
    
    actor_net = QTransformer(
        tf_env.observation_spec(),
        baseEnv.action_space.n,
        num_layers=args.num_layers,
        d_model=args.d_model,
        num_heads=args.num_heads, 
        dff=args.dff,
        rate = args.rate,
        encoderType = args.encoder_type,
        enc_layer_type=args.layer_type,
        max_horizon=args.max_horizon,
        custom_layer = args.custom_layer_init, 
        custom_last_layer = args.custom_last_layer)

    if args.custom_lr_schedule == "Transformer":    # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(args.d_model,int(args.num_iterations/10))
        optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    elif args.custom_lr_schedule == "Transformer_low":    # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(int(args.d_model/2),int(args.num_iterations/10)) # --> same schedule with lower general lr
        optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    elif args.custom_lr_schedule == "Linear": 
        lrs = LinearCustomSchedule(learning_rate,args.num_iterations)
        optimizer = tf.keras.optimizers.Adam(lrs, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    else:
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)
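    # For reference, a minimal sketch of what a Transformer-style CustomSchedule could
    # look like (assuming it follows the warmup formula from "Attention Is All You Need";
    # the actual CustomSchedule used above may differ):
    #
    #   class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    #       def __init__(self, d_model, warmup_steps=4000):
    #           super().__init__()
    #           self.d_model = tf.cast(d_model, tf.float32)
    #           self.warmup_steps = warmup_steps
    #
    #       def __call__(self, step):
    #           step = tf.cast(step, tf.float32)
    #           return tf.math.rsqrt(self.d_model) * tf.minimum(
    #               tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)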




    tf_agent = ppo_clip_agent.PPOClipAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        optimizer,
        actor_net=actor_net,
        value_net=value_net,
        entropy_regularization=0.0,
        importance_ratio_clipping=0.2,
        normalize_observations=False,
        normalize_rewards=False,
        use_gae=True,
        num_epochs=args.num_epochs,
        debug_summaries=args.debug_summaries,
        summarize_grads_and_vars=args.summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()


    
    train_eval(
    args.output_dir,
    0, # ??
    # TODO(b/127576522): rename to policy_fc_layers.
    tf_agent,
    eval_tf_env,
    tf_env,
    # Params for collect
    args.num_iterations,
    args.collect_episodes_per_iteration,
    args.num_parallel,
    args.replay_buffer_max_length,  # Per-environment
    # Params for train
    args.num_epochs,
    args.learning_rate,
    # Params for eval
    args.num_eval_episodes,
    args.eval_interval,
    # Params for summaries and logging
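    # The same interval is presumably reused for the train, policy and replay-buffer checkpoints.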
    args.checkpoint_interval,
    args.checkpoint_interval,
    args.checkpoint_interval,
    args.log_interval,
    args.summary_interval,
    args.summary_flush,
    args.debug_summaries,
    args.summarize_grads_and_vars,
    args.run_graph_mode,
    None)
    

    
    with open(args.output_dir + "/training_args.p", "wb") as f:
        pickle.dump(args, f)
    print("Training and evaluation finished successfully.")
Example #28
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
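    # BatchedPyMetric aggregates each Python metric across the parallel evaluation environments.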
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
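        # One batch entry per parallel collection environment; the driver's add_batch
        # observer stores a time step for every environment at once.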

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)
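        # Chain the ops so a single sess.run(train_op) performs the PPO update and
        # then clears the on-policy replay buffer for the next collection round.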

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=step_metrics)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
Example #29
def train_eval(
    root_dir,
    env_name='HalfCheetah-v1',
    env_load_fn=suite_mujoco.load,
    num_iterations=2000000,
    actor_fc_layers=(400, 300),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    num_parallel_environments=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=tf.losses.huber_loss,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):

  """A simple train and eval for DDPG."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.contrib.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  # TODO(kbanoop): Figure out if it is possible to avoid the with block.
  with tf.contrib.summary.record_summaries_every_n_global_steps(
      summary_interval):
    if num_parallel_environments > 1:
      tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * num_parallel_environments))
    else:
      tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_py_env = env_load_fn(env_name)

    actor_net = actor_network.ActorNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers,
    )

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_network.CriticNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec(),
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    global_step = tf.train.get_or_create_global_step()

    collect_policy = tf_agent.collect_policy()
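    # Seed the replay buffer with initial_collect_steps transitions from the
    # exploratory collect policy before training starts.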
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    iterator = dataset.make_initializable_iterator()
    trajectories, unused_info = iterator.get_next()
    train_op = tf_agent.train(
        experience=trajectories, train_step_counter=global_step)

    train_checkpointer = common_utils.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=tf.contrib.checkpoint.List(train_metrics))
    policy_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy(),
        global_step=global_step)
    rb_checkpointer = common_utils.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    for train_metric in train_metrics:
      train_metric.tf_summaries(step_metrics=train_metrics[:2])
    summary_op = tf.contrib.summary.all_summary_ops()

    with eval_summary_writer.as_default(), \
         tf.contrib.summary.always_record_summaries():
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries()

    init_agent_op = tf_agent.initialize()

    with tf.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      # TODO(sguada) Remove once Periodically can be saved.
      common_utils.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      tf.contrib.summary.initialize(session=sess)
      sess.run(initial_collect_op)

      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_op, global_step])

      timed_at_step = sess.run(global_step)
      time_acc = 0
      steps_per_second_ph = tf.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.contrib.summary.scalar(
          name='global_steps/sec', tensor=steps_per_second_ph)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _, global_step_val = train_step_call()
        time_acc += time.time() - start_time

        if global_step_val % log_interval == 0:
          tf.logging.info('step = %d, loss = %f', global_step_val,
                          loss_info_value.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          tf.logging.info('%.3f steps/sec' % steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
def main(_):
    tf.compat.v1.enable_resource_variables()
    if tf.executing_eagerly():
        # self.skipTest('b/123777119')  # Secondary bug: ('b/123775375')
        return
    # loop over game params to create different configs
    logging.set_verbosity(logging.INFO)
    # todo: when this training is done, try different learning rates and architectures
    for colors in COLORS:
        for ranks in RANKS:
            for num_players in NUM_PLAYERS:
                for hand_size in HAND_SIZES:
                    for max_information_tokens in MAX_INFORMATION_TOKENS:
                        for max_life_tokens in MAX_LIFE_TOKENS:  # 2 * 1 * 1 * 4 * 4 * 2 = 64 total iterations
                            for custom_reward in CUSTOM_REWARDS:
                                for penalty in PENALTIES_LAST_HINT_TOKEN:
                                    config = {
                                        "colors": colors,
                                        "ranks": ranks,
                                        "players": num_players,
                                        "hand_size": hand_size,
                                        "max_information_tokens":
                                        max_information_tokens,
                                        "max_life_tokens": max_life_tokens,
                                        "observation_type": OBSERVATION_TYPE,
                                        "custom_reward": custom_reward,
                                        "penalty_last_hint_token": penalty,
                                        "per_card_reward": True
                                    }
                                    # ################################################ #
                                    # --------------- Load Environments -------------- #
                                    # ################################################ #
                                    eval_py_env = parallel_py_environment.ParallelPyEnvironment(
                                        [lambda: load_hanabi_env(config)] *
                                        FLAGS.num_parallel_environments)

                                    tf_env = tf_py_environment.TFPyEnvironment(
                                        parallel_py_environment.ParallelPyEnvironment(
                                            [lambda: load_hanabi_env(config)] *
                                            FLAGS.num_parallel_environments))
                                    train_eval(
                                        root_dir=FLAGS.root_dir,
                                        summary_dir=FLAGS.summary_dir,
                                        game_config=config,
                                        tf_master=FLAGS.master,
                                        replay_buffer_capacity=FLAGS.replay_buffer_capacity,
                                        env_load_fn=load_hanabi_env,
                                        num_environment_steps=FLAGS.num_environment_steps,
                                        num_parallel_environments=FLAGS.num_parallel_environments,
                                        num_epochs=FLAGS.num_epochs,
                                        collect_episodes_per_iteration=FLAGS.collect_episodes_per_iteration,
                                        num_eval_episodes=FLAGS.num_eval_episodes,
                                        use_rnns=FLAGS.use_rnns,
                                        eval_py_env=eval_py_env,
                                        tf_env=tf_env)
                                    del eval_py_env
                                    del tf_env