Example #1
    def host_call_fn(**kwargs):
        """Host_call_fn.

    Args:
      **kwargs: dict of summary name to tf.Tensor mapping. The value we see here
        is the tensor across all cores, concatenated along axis 0. This function
        will take make a scalar summary that is the mean of the whole tensor (as
        all the values are the same - the mean, trait of
        tpu.CrossShardOptimizer).

    Returns:
      A merged summary op.
    """
        gs = kwargs.pop('global_step')[0]
        with tf_summary.create_file_writer(model_dir).as_default():
            with tf_summary.record_if(tf.equal(gs % 10, 0)):
                for name, tensor in kwargs.items():
                    # Take the mean across cores.
                    tensor = tf.reduce_mean(tensor)
                    tf_summary.scalar(name, tensor, step=gs)
                return tf.summary.all_v2_summary_ops()
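For context, a host_call_fn like the one above is normally handed to the estimator through the host_call argument of TPUEstimatorSpec. The sketch below shows a minimal, assumed wiring in TF 1.x style (the toy model, loss, and tensor names are hypothetical; host_call_fn and model_dir are assumed to be in scope, and only the host_call plumbing mirrors the snippet above):

import tensorflow.compat.v1 as tf

def model_fn(features, labels, mode, params):
    # Hypothetical one-variable model; only the host_call wiring below
    # reflects the snippet above.
    w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
    loss = tf.reduce_mean(tf.square(tf.cast(labels, tf.float32) - w))
    gs = tf.train.get_or_create_global_step()
    optimizer = tf.tpu.CrossShardOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=0.01))
    train_op = optimizer.minimize(loss, global_step=gs)
    # Tensors passed through host_call need a leading (per-core) dimension,
    # hence the reshape of the scalars to shape [1].
    host_call_args = {
        'global_step': tf.reshape(gs, [1]),
        'loss': tf.reshape(loss, [1]),
    }
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        host_call=(host_call_fn, host_call_args))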
Example #2
def host_call_fn(model_dir, **kwargs):
    """host_call function used for creating training summaries when using TPU.

  Args:
    model_dir: String indicating the output_dir to save summaries in.
    **kwargs: Set of metric names and tensor values for all desired summaries.

  Returns:
    Summary op to be passed to the host_call arg of the estimator function.
  """
    gs = kwargs.pop('global_step')[0]
    with summary.create_file_writer(model_dir).as_default():
        # Always record summaries.
        with summary.record_if(True):
            for name, tensor in kwargs.items():
                if name.startswith(IMG_SUMMARY_PREFIX):
                    summary.image(name.replace(IMG_SUMMARY_PREFIX, ''),
                                  tensor,
                                  max_images=1)
                else:
                    summary.scalar(name, tensor[0], step=gs)
            # all_v2_summary_ops is the TF 1.x API for collecting v2 summary
            # ops, so we use it here.
            return tf.summary.all_v2_summary_ops()
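Since this variant takes model_dir as an explicit first argument, it is typically bound with functools.partial before being passed as the host_call. A minimal sketch under that assumption (the loss and image tensors are stand-ins; IMG_SUMMARY_PREFIX and host_call_fn are assumed to be in scope, and model_dir is a hypothetical path):

import functools
import tensorflow.compat.v1 as tf

model_dir = '/tmp/tpu_model'  # hypothetical output directory
gs = tf.train.get_or_create_global_step()
loss = tf.constant(0.0)  # stand-in for the model's scalar loss
host_call = (
    functools.partial(host_call_fn, model_dir),
    {
        'global_step': tf.reshape(gs, [1]),
        'loss': tf.reshape(loss, [1]),
        # Keys starting with IMG_SUMMARY_PREFIX are written as image summaries.
        IMG_SUMMARY_PREFIX + 'inputs': tf.zeros([1, 64, 64, 3]),
    })
# `host_call` is then passed to tf.estimator.tpu.TPUEstimatorSpec as in the
# previous sketch.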
Example #3
def train_eval(tf_agent,
               num_iterations,
               batch_size,
               tf_env,
               eval_tf_env,
               train_metrics,
               step_metrics,
               eval_metrics,
               global_step,
               steps_per_episode,
               num_parallel_environments,
               collect_per_iteration,
               train_steps_per_iteration,
               train_dir,
               saved_model_dir,
               eval_summary_writer,
               num_eval_episodes,
               num_eval_seeds=1,
               eval_metrics_callback=None,
               train_sequence_length=1,
               initial_collect_steps=1000,
               log_interval=100,
               eval_interval=400,
               policy_checkpoint_interval=400,
               train_checkpoint_interval=1200,
               rb_checkpoint_interval=2000,
               train_model=True,
               use_tf_functions=True,
               eval_early_stopping=False,
               seed=12345):
    """ Train and evaluation function of a TF Agent given the properties provided """

    # Define seed for each environment
    for i, env in enumerate(tf_env.envs):
        env.seed(seed + i)
    for i, env in enumerate(eval_tf_env.envs):
        env.seed(seed + i)
    tf_env.reset()
    eval_tf_env.reset()

    tf_agent.initialize()
    agent_name = tf_agent.__dict__['_name']

    # Define policies
    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    # Define Replay Buffer
    replay_buffer_capacity = steps_per_episode * \
        collect_per_iteration // num_parallel_environments + 1
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=num_parallel_environments,  # batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    # Define Dynamic driver to go through the environment depending on the agent
    if train_model:
        if agent_name in ['dqn_agent']:
            collect_driver = dynamic_step_driver.DynamicStepDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                num_steps=collect_per_iteration)
        elif agent_name in ['ppo_agent']:
            collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                num_episodes=collect_per_iteration)
        else:
            raise NotImplementedError(
                f'{agent_name} agent not yet implemented')

    # Define Checkpointers for train and policy
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=tf_agent,
                                             global_step=global_step,
                                             metrics=metric_utils.MetricsGroup(
                                                 train_metrics,
                                                 'train_metrics'))
    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=eval_policy,
                                              global_step=global_step)
    saved_model = policy_saver.PolicySaver(eval_policy, train_step=global_step)
    # rb_checkpointer = common.Checkpointer(
    #     ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
    #     max_to_keep=1,
    #     replay_buffer=replay_buffer)

    policy_checkpointer.initialize_or_restore()  # TODO: To be tested
    train_checkpointer.initialize_or_restore()
    # rb_checkpointer.initialize_or_restore()

    if train_model:

        if eval_metrics_callback is not None:
            eval_metrics_callback.add_checkpointer(policy_checkpointer)
            eval_metrics_callback.add_checkpointer(train_checkpointer)
            # eval_metrics_callback.add_checkpointer(rb_checkpointer)

        # TODO: should they use autograph=False?? as in tf_agents/agents/ppo/examples/v2/train_eval_clip_agent.py
        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Only run Replay buffer initialization if using one of the following agents
        if agent_name in ['dqn_agent']:
            initial_collect_policy = random_tf_policy.RandomTFPolicy(
                tf_env.time_step_spec(), tf_env.action_spec())

            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)
            dynamic_step_driver.DynamicStepDriver(
                tf_env,
                initial_collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                num_steps=initial_collect_steps).run()

        # num_eval_episodes = eval_tf_env.envs[0].frame_bound[-1] // eval_tf_env.envs[0].steps_per_episode
        logging.info('Running initial evaluation.')
        results = evaluate(eval_metrics,
                           eval_tf_env,
                           eval_policy,
                           num_eval_episodes,
                           num_eval_seeds,
                           global_step,
                           eval_summary_writer,
                           summary_prefix='Metrics',
                           seed=seed)

        if eval_early_stopping and not isinstance(eval_metrics_callback,
                                                  AgentEarlyStopping):
            raise ValueError(
                'eval_early_stopping requires eval_metrics_callback to be an '
                'AgentEarlyStopping instance')

        # Once evaluate has been done call eval metrics callback
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())

        # Initialize training variables
        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        collect_time = 0
        train_time = 0
        summary_time = 0

        # Define train_step and generate dataset if required
        if agent_name in ['dqn_agent']:
            # The dataset yields trajectories with shape [B, 2, ...].
            logging.info('Building trajectory dataset from the replay buffer.')
            dataset = replay_buffer.as_dataset(
                num_parallel_calls=3,
                sample_batch_size=batch_size,
                # single_deterministic_pass=True,
                num_steps=train_sequence_length + 1).prefetch(3)
            iterator = iter(dataset)

            def train_step():
                experience, _ = next(iterator)
                return tf_agent.train(experience)
        elif agent_name in ['ppo_agent']:

            def train_step():
                trajectories = replay_buffer.gather_all()
                return tf_agent.train(experience=trajectories)
        else:
            raise NotImplementedError(
                f'{agent_name} agent not yet implemented')

        if use_tf_functions:
            train_step = common.function(train_step)

        logging.info('Starting training...')
        for _ in range(num_iterations):

            # Collect data
            start_time = time.time()
            if agent_name in ['dqn_agent']:
                time_step, policy_state = collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )
            elif agent_name in ['ppo_agent']:
                collect_driver.run()
            else:
                raise NotImplementedError(
                    f'{agent_name} agent not yet implemented')

            collect_time += time.time() - start_time

            # Train on collected data
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            train_time += time.time() - start_time

            # Write on Tensorboard the training results
            start_time = time.time()
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)
            summary_time += time.time() - start_time

            # Print out metrics and reset variables
            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() - timed_at_step) / \
                    (train_time + collect_time + summary_time)
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    'collect_time = %.3f, train_time = %.3f, summary_time = %.3f',
                    collect_time, train_time, summary_time)
                summary.scalar(name='global_steps_per_sec',
                               data=steps_per_sec,
                               step=global_step)
                timed_at_step = global_step.numpy()
                collect_time = 0
                train_time = 0
                summary_time = 0

            # Save train checkpoint
            if global_step.numpy() % train_checkpoint_interval == 0:
                start_time = time.time()
                train_checkpointer.save(global_step=global_step.numpy())
                logging.info(
                    f'Saving Train lasts: {time.time() - start_time:.3f} s')

            # Save policy checkpoint
            if global_step.numpy() % policy_checkpoint_interval == 0:
                start_time = time.time()
                policy_checkpointer.save(global_step=global_step.numpy())
                saved_model_path = os.path.join(
                    saved_model_dir,
                    'policy_' + ('%d' % global_step.numpy()).zfill(9))
                saved_model.save(saved_model_path)
                logging.info(
                    f'Saving Policy lasts: {time.time() - start_time:.3f} s')

            # if global_step.numpy() % rb_checkpoint_interval == 0:
            #   start_time = time.time()
            #   rb_checkpointer.save(global_step=global_step.numpy())
            #   logging.info(
            #     f'Saving Replay Buffer lasts: {time.time() - start_time:.3f} s'
            #   )

            # Evaluate on evaluation environment
            if global_step.numpy() % eval_interval == 0:
                start_time = time.time()
                results = evaluate(eval_metrics,
                                   eval_tf_env,
                                   eval_policy,
                                   num_eval_episodes,
                                   num_eval_seeds,
                                   global_step,
                                   eval_summary_writer,
                                   summary_prefix='Metrics',
                                   seed=seed)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                logging.info(
                    f'Calculate Evaluation lasts {time.time() - start_time:.3f} s'
                )

                # Stop training if EarlyStopping says so
                if eval_early_stopping and eval_metrics_callback.stop_training:
                    logging.info(
                        f'Training stopped due to Agent Early Stopping at step: {global_step.numpy()}'
                    )
                    logging.info(
                        f'Best {eval_metrics_callback.monitor} was {eval_metrics_callback.best:.5f} at step {eval_metrics_callback.best_step}'
                    )

                    def load_best_checkpoint(checkpointer, ckpt_dir=None):
                        """Restores the checkpoint saved at the best evaluation step."""
                        latest_dir = checkpointer._manager.latest_checkpoint
                        if latest_dir is not None:
                            best_dir = latest_dir.split('-')
                            best_dir[-1] = str(eval_metrics_callback.best_step)
                            best_dir = '-'.join(best_dir)
                        elif ckpt_dir is not None:
                            best_dir = os.path.join(
                                ckpt_dir,
                                f'ckpt-{eval_metrics_callback.best_step}')
                        else:
                            raise ValueError(
                                'A checkpointer with previous checkpoints or a '
                                'ckpt_dir must be provided')

                        # Restore the checkpointer that was passed in.
                        checkpointer._checkpoint.restore(best_dir)

                    # Load the policy and train state with the best evaluation
                    # metric according to EarlyStopping.
                    load_best_checkpoint(policy_checkpointer,
                                         os.path.join(train_dir, 'policy'))
                    load_best_checkpoint(train_checkpointer, train_dir)
                    # load_best_checkpoint(rb_checkpointer, os.path.join(train_dir, 'replay_buffer'))

                    eval_metrics_callback.reset()

                    break
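train_eval above only relies on an interface for eval_metrics_callback: it must be callable with (results, step), expose monitor, best, best_step and stop_training, and provide add_checkpointer and reset. Below is a minimal sketch of a callback satisfying that contract, assuming results is a dict mapping metric names to scalar values; the project's actual AgentEarlyStopping implementation may differ.

class SimpleEarlyStopping:
    """Hedged sketch of the callback contract used by train_eval above."""

    def __init__(self, monitor='AverageReturn', patience=5):
        self.monitor = monitor
        self.patience = patience
        self._checkpointers = []
        self.reset()

    def add_checkpointer(self, checkpointer):
        # train_eval registers checkpointers so the best step can be reloaded.
        self._checkpointers.append(checkpointer)

    def __call__(self, results, step):
        # Assumes higher is better for the monitored metric (e.g. AverageReturn).
        value = float(results[self.monitor])
        if self.best is None or value > self.best:
            self.best, self.best_step, self._wait = value, step, 0
        else:
            self._wait += 1
            self.stop_training = self._wait >= self.patience

    def reset(self):
        self.best = None
        self.best_step = None
        self.stop_training = False
        self._wait = 0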
Example #4
    def train_fn(data_path):
        """A train_fn to train the planet model."""
        nonlocal iterator
        nonlocal optimizer

        with strategy.scope():
            global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
            checkpoint = tf.train.Checkpoint(global_step=global_step,
                                             optimizer=optimizer,
                                             **model.get_trackables())
            manager = tf.train.CheckpointManager(checkpoint,
                                                 model_dir,
                                                 max_to_keep=1)
            checkpoint.restore(manager.latest_checkpoint)
        if iterator is None:
            dataset = npz.load_dataset_from_directory(data_path, duration,
                                                      batch)
            dataset = strategy.experimental_distribute_dataset(dataset)
            iterator = dataset

        writer = tfs.create_file_writer(model_dir)
        tfs.experimental.set_step(global_step)
        true_rewards, pred_rewards = None, None
        with writer.as_default():
            for step, obs in enumerate(iterator):
                if step > train_steps:
                    if save_rewards:
                        # We are only saving the last training batch.
                        reward_dir = os.path.join(model_dir, 'train_rewards')
                        true_rewards = strategy.experimental_local_results(
                            true_rewards)
                        pred_reward = strategy.experimental_local_results(
                            pred_rewards)
                        true_rewards = np.concatenate(
                            [x.numpy() for x in true_rewards])
                        pred_reward = np.concatenate(
                            [x.numpy() for x in pred_reward])
                        rewards_to_save = {
                            'true': true_rewards,
                            'pred': pred_reward
                        }
                        npz.save_dictionary(rewards_to_save, reward_dir)
                    break
                (loss, reward_loss, divergence, frames, pred_rewards,
                 true_rewards, frame_loss) = train_step(obs)
                if step % 100 == 0:
                    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss)
                    reward_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                  reward_loss)
                    divergence = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                 divergence)
                    frame_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                 frame_loss)
                    frames = strategy.experimental_local_results(frames)
                    frames = tf.concat(frames, axis=0)
                    pred_reward = strategy.experimental_local_results(
                        pred_rewards)
                    pred_reward = tf.concat(pred_reward, axis=0)
                    tf.logging.info('loss at step %d: %f', step, loss)
                    tfs.scalar('loss/total', loss)
                    tfs.scalar('loss/reward', reward_loss)
                    tfs.scalar('loss/divergence', divergence)
                    tfs.scalar('loss/frames', frame_loss)
                    tfs.experimental.write_raw_pb(
                        visualization.py_gif_summary(tag='predictions/frames',
                                                     images=frames.numpy(),
                                                     max_outputs=6,
                                                     fps=20))
                    ground_truth_rewards = (tf.concat(
                        strategy.experimental_local_results(obs['reward']),
                        axis=0)[:, :, 0])
                    rewards = pred_reward[:, :, 0]
                    signals = tf.stack([ground_truth_rewards, rewards], axis=1)
                    visualization.py_plot_1d_signal(
                        name='predictions/reward',
                        signals=signals.numpy(),
                        labels=['ground_truth', 'prediction'],
                        max_outputs=6)
                global_step.assign_add(1)

            manager.save(global_step)
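The loop above leans on two tf.distribute idioms: Strategy.reduce to average per-replica scalars (the losses) and experimental_local_results plus tf.concat to gather per-replica tensors (frames, predicted rewards). The following is a small standalone illustration of that pattern, not tied to the planet model; note that current TF releases require an explicit axis argument to Strategy.reduce.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(tf.random.uniform([32, 4])).batch(8)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def step_fn(x):
    # Stand-in per-replica computation returning a scalar and a tensor.
    return tf.reduce_mean(x), x

for batch in dist_dataset:
    per_replica_loss, per_replica_x = strategy.run(step_fn, args=(batch,))
    # Scalar: average across replicas, as done for the losses above.
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
    # Tensor: gather the per-replica results and concatenate along the batch
    # axis, as done for frames and predicted rewards above.
    x_all = tf.concat(strategy.experimental_local_results(per_replica_x), axis=0)
    break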