Пример #1
0
def train_eval(
    root_dir,
    random_seed=None,
    env_name='sawyer_push',
    eval_env_name=None,
    env_load_fn=get_env,
    max_episode_steps=1000,
    eval_episode_steps=1000,
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    reset_goal_frequency=1000,  # virtual episode size for reset-free training
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    # reset-free parameters
    use_minimum=True,
    reset_lagrange_learning_rate=3e-4,
    value_threshold=None,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    # Td3 parameters
    actor_update_period=1,
    exploration_noise_std=0.1,
    target_policy_noise=0.1,
    target_policy_noise_clip=0.1,
    dqda_clipping=None,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    # video recording for the environment
    video_record_interval=10000,
    num_videos=0,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """Trains a forward and a reverse agent on one reset-free environment.

  Two agents of the type selected by `FLAGS.agent_type` ('sac' or 'td3') are
  built with identical hyper-parameters. They share a single training
  environment wrapped in `reset_free_wrapper.ResetFreeWrapper`; at every step
  the wrapped environment's `_forward_or_reset_goal` attribute decides whether
  the forward or the reverse collect driver acts, and the transition goes to
  that agent's own replay buffer. Both agents are trained every iteration.
  Only the forward agent's greedy policy is evaluated.

  Besides its arguments the function reads `FLAGS.agent_type`,
  `FLAGS.num_action_samples`, `FLAGS.use_reset_goals`,
  `FLAGS.num_reset_candidates` and `FLAGS.reset_goal_frequency`.

  Outputs are written below `root_dir`: summaries and checkpoints in `train/`
  (`forward`, `forward/policy`, `reverse`, `replay_buffer`,
  `replay_buffer_rev`), eval summaries plus `eval_interval.npy`,
  `average_eval_return.npy` and `average_eval_success.npy` in `eval/`, and
  videos in `eval/videos/`.

  Args:
    root_dir: Base output directory; `~` is expanded.
    random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
    env_name: Name given to `env_load_fn` for the training (and video) env.
    eval_env_name: Name of the evaluation env; defaults to `env_name`.
    env_load_fn: Callable returning a 4-tuple
      `(env, train_metrics, eval_metrics, aux_info)`. It is called with
      `name`, `max_episode_steps` and, for training, `gym_env_wrappers`.
    max_episode_steps: Used as the wrapper's `full_reset_frequency` and as
      the episode limit of the video environment.
    eval_episode_steps: Episode limit of the evaluation environment and
      maximum length of recorded videos.
    num_iterations: Number of collect/train iterations.
    actor_fc_layers: Fully connected layer sizes of the actor networks.
    critic_obs_fc_layers: Observation-branch layer sizes of the critics.
    critic_action_fc_layers: Action-branch layer sizes of the critics.
    critic_joint_fc_layers: Joint layer sizes of the critics.
    initial_collect_steps: Number of random-policy steps collected when the
      forward replay buffer is empty.
    collect_steps_per_iteration: Environment steps per iteration.
    replay_buffer_capacity: `max_length` of each replay buffer.
    target_update_tau: Forwarded to the agents.
    target_update_period: Forwarded to the agents.
    reset_goal_frequency: Forwarded to the reset-free wrapper.
    train_steps_per_iteration: Train steps (per agent) per iteration.
    batch_size: Training batch size.
    actor_learning_rate: Adam learning rate of the actors.
    critic_learning_rate: Adam learning rate of the critics.
    alpha_learning_rate: Adam learning rate of the SAC alpha optimizer.
    use_minimum: Forwarded to `ResetGoalGenerator`.
    reset_lagrange_learning_rate: If > 0, a `ResetGoalGenerator` with an Adam
      optimizer at this rate is used; otherwise a `FixedResetGoal`.
    value_threshold: Forwarded to `ResetGoalGenerator`.
    td_errors_loss_fn: Forwarded to the agents.
    gamma: Discount factor forwarded to the agents.
    reward_scale_factor: Forwarded to the agents.
    actor_update_period: TD3 only; forwarded to the agents.
    exploration_noise_std: TD3 only; forwarded to the agents.
    target_policy_noise: TD3 only; forwarded to the agents.
    target_policy_noise_clip: TD3 only; forwarded to the agents.
    dqda_clipping: TD3 only; forwarded to the agents.
    gradient_clipping: Forwarded to the agents.
    use_tf_functions: If True, drivers, train steps and the reset-goal value
      and distance functions are wrapped with `common.function`.
    num_eval_episodes: Episodes per evaluation and metric buffer size.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Global-step interval for agent checkpoints.
    policy_checkpoint_interval: Global-step interval for policy checkpoints.
    rb_checkpoint_interval: Global-step interval for replay-buffer
      checkpoints.
    video_record_interval: Global-step interval for video recording.
    num_videos: Number of videos recorded at each recording interval.
    log_interval: Global-step interval for loss / throughput logging.
    summary_interval: Global-step interval at which summaries are recorded.
    summaries_flush_secs: Flush period of the summary writers, in seconds.
    debug_summaries: Forwarded to the agents.
    summarize_grads_and_vars: Forwarded to the agents.
    eval_metrics_callback: Optional callable invoked after each evaluation
      with `(results, global_step_value)`.

  Returns:
    The result of the last forward-agent train step, or None when no training
    iteration was executed.

  Raises:
    ValueError: If `FLAGS.agent_type` is neither 'sac' nor 'td3'.
  """
  start_time = time.time()

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')
  video_dir = os.path.join(eval_dir, 'videos')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if random_seed is not None:
      tf.compat.v1.set_random_seed(random_seed)
    # The training env gets no episode limit of its own; the reset-free
    # wrapper receives max_episode_steps as its full_reset_frequency instead.
    env, env_train_metrics, env_eval_metrics, aux_info = env_load_fn(
        name=env_name,
        max_episode_steps=None,
        gym_env_wrappers=(functools.partial(
            reset_free_wrapper.ResetFreeWrapper,
            reset_goal_frequency=reset_goal_frequency,
            full_reset_frequency=max_episode_steps),))

    tf_env = tf_py_environment.TFPyEnvironment(env)
    # Python environment queried for the forward/reverse switch and fed with
    # reset goals and reset candidates.
    train_py_env = tf_env.pyenv.envs[0]
    eval_env_name = eval_env_name or env_name
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        env_load_fn(name=eval_env_name,
                    max_episode_steps=eval_episode_steps)[0])

    eval_metrics += env_eval_metrics

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    def build_critic(name):
      """Builds a critic network with the shared layer configuration."""
      return critic_network.CriticNetwork(
          (observation_spec, action_spec),
          observation_fc_layer_params=critic_obs_fc_layers,
          action_fc_layer_params=critic_action_fc_layers,
          joint_fc_layer_params=critic_joint_fc_layers,
          kernel_initializer='glorot_uniform',
          last_kernel_initializer='glorot_uniform',
          name=name)

    def build_sac_agent(direction):
      """Builds a SAC agent whose modules are named '<direction>_*'."""
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=functools.partial(
              tanh_normal_projection_network.TanhNormalProjectionNetwork,
              std_transform=std_clip_transform),
          name=direction + '_actor')
      critic_net = build_critic(direction + '_critic')
      return SacAgent(
          time_step_spec,
          action_spec,
          num_action_samples=FLAGS.num_action_samples,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=alpha_learning_rate),
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name=direction + '_agent')

    def build_td3_agent(direction):
      """Builds a TD3 agent whose modules are named '<direction>_*'."""
      actor_net = actor_network.ActorNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          name=direction + '_actor')
      critic_net = build_critic(direction + '_critic')
      return Td3Agent(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          actor_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=actor_learning_rate),
          critic_optimizer=tf.compat.v1.train.AdamOptimizer(
              learning_rate=critic_learning_rate),
          exploration_noise_std=exploration_noise_std,
          target_update_tau=target_update_tau,
          target_update_period=target_update_period,
          actor_update_period=actor_update_period,
          dqda_clipping=dqda_clipping,
          td_errors_loss_fn=td_errors_loss_fn,
          gamma=gamma,
          reward_scale_factor=reward_scale_factor,
          target_policy_noise=target_policy_noise,
          target_policy_noise_clip=target_policy_noise_clip,
          gradient_clipping=gradient_clipping,
          debug_summaries=debug_summaries,
          summarize_grads_and_vars=summarize_grads_and_vars,
          train_step_counter=global_step,
          name=direction + '_agent')

    if FLAGS.agent_type == 'sac':
      build_agent = build_sac_agent
    elif FLAGS.agent_type == 'td3':
      build_agent = build_td3_agent
    else:
      raise ValueError(
          'Unsupported agent_type %r; expected "sac" or "td3".' %
          (FLAGS.agent_type,))

    # Both directions need their own agent whatever the agent type: the
    # reverse agent is initialized, checkpointed and trained below.
    tf_agent = build_agent('forward')
    tf_agent_rev = build_agent('reverse')

    tf_agent.initialize()
    tf_agent_rev.initialize()

    if FLAGS.use_reset_goals:
      # distance to initial state distribution
      initial_state_distance = state_distribution_distance.L2Distance(
          initial_state_shape=aux_info['reset_state_shape'])
      initial_state_distance.update(
          tf.constant(aux_info['reset_states'], dtype=tf.float32),
          update_type='complete')

      if use_tf_functions:
        initial_state_distance.distance = common.function(
            initial_state_distance.distance)
        tf_agent.compute_value = common.function(tf_agent.compute_value)

      # initialize reset / practice goal proposer
      if reset_lagrange_learning_rate > 0:
        reset_goal_generator = ResetGoalGenerator(
            goal_dim=aux_info['reset_state_shape'][0],
            num_reset_candidates=FLAGS.num_reset_candidates,
            compute_value_fn=tf_agent.compute_value,
            distance_fn=initial_state_distance,
            use_minimum=use_minimum,
            value_threshold=value_threshold,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=reset_lagrange_learning_rate),
            name='reset_goal_generator')
      else:
        reset_goal_generator = FixedResetGoal(
            distance_fn=initial_state_distance)

      # if use_tf_functions:
      #   reset_goal_generator.get_reset_goal = common.function(
      #       reset_goal_generator.get_reset_goal)

      # modify the reset-free wrapper to use the reset goal generator
      train_py_env.set_reset_goal_fn(reset_goal_generator.get_reset_goal)

    # Make the replay buffers, one per direction.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    replay_buffer_rev = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent_rev.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer_rev = [replay_buffer_rev.add_batch]

    def sample_reset_candidates():
      """Samples `FLAGS.num_reset_candidates` reset-goal candidates.

      Half of the candidates come from each replay buffer. While the reverse
      buffer is still empty, all of them are drawn from the forward buffer
      so that sampling never touches an empty buffer.
      """
      num_candidates = FLAGS.num_reset_candidates
      if not replay_buffer_rev.num_frames():
        return replay_buffer.get_next(sample_batch_size=num_candidates)[0]

      forward_sample = replay_buffer.get_next(
          sample_batch_size=num_candidates // 2)[0]
      reverse_sample = replay_buffer_rev.get_next(
          sample_batch_size=num_candidates // 2)[0]
      # Concatenate the two samples leaf by leaf along the batch axis and
      # restore the trajectory structure.
      merged_leaves = [
          tf.concat([forward_leaf, reverse_leaf], axis=0)
          for forward_leaf, reverse_leaf in zip(
              tf.nest.flatten(forward_sample),
              tf.nest.flatten(reverse_sample))
      ]
      return tf.nest.pack_sequence_as(forward_sample, merged_leaves)

    # initialize metrics and observers
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]
    train_metrics += env_train_metrics
    train_metrics_rev = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodesRev'),
        tf_metrics.EnvironmentSteps(name='EnvironmentStepsRev'),
        tf_metrics.AverageReturnMetric(
            name='AverageReturnRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthRev',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]
    train_metrics_rev += aux_info['train_metrics_rev']

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    # NOTE(review): videos are recorded with tf_agent.policy, not with the
    # greedy eval_policy used for the evaluation metrics -- confirm intended.
    eval_py_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_agent.policy, use_tf_function=True)

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_policy_rev = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy
    collect_policy_rev = tf_agent_rev.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'forward', 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)
    # reverse policy savers
    train_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'reverse'),
        agent=tf_agent_rev,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics_rev,
                                          'train_metrics_rev'))
    rb_checkpointer_rev = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer_rev'),
        max_to_keep=1,
        replay_buffer=replay_buffer_rev)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    train_checkpointer_rev.initialize_or_restore()
    rb_checkpointer_rev.initialize_or_restore()

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)
    collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy_rev,
        observers=replay_observer_rev + train_metrics_rev,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      collect_driver.run = common.function(collect_driver.run)
      collect_driver_rev.run = common.function(collect_driver_rev.run)
      tf_agent.train = common.function(tf_agent.train)
      tf_agent_rev.train = common.function(tf_agent_rev.train)

    # Initial random collection only happens when the forward buffer is
    # empty after the checkpoint restore above.
    if replay_buffer.num_frames() == 0:
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=1)
      initial_collect_driver_rev = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy_rev,
          observers=replay_observer_rev + train_metrics_rev,
          num_steps=1)
      # does not work for some reason
      if use_tf_functions:
        initial_collect_driver.run = common.function(initial_collect_driver.run)
        initial_collect_driver_rev.run = common.function(
            initial_collect_driver_rev.run)

      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps with '
          'a random policy.', initial_collect_steps)
      for iter_idx_initial in range(initial_collect_steps):
        if train_py_env._forward_or_reset_goal:
          initial_collect_driver.run()
        else:
          initial_collect_driver_rev.run()
        # NOTE(review): the refresh period comes from
        # FLAGS.reset_goal_frequency, not from the reset_goal_frequency
        # argument given to the wrapper -- confirm the two are kept in sync.
        if (FLAGS.use_reset_goals and
            iter_idx_initial % FLAGS.reset_goal_frequency == 0):
          train_py_env.set_reset_candidates(sample_reset_candidates())

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    dataset_rev = replay_buffer_rev.as_dataset(
        sample_batch_size=batch_size, num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator_rev = iter(dataset_rev)

    def train_step_rev():
      experience_rev, _ = next(iterator_rev)
      return tf_agent_rev.train(experience_rev)

    if use_tf_functions:
      train_step = common.function(train_step)
      train_step_rev = common.function(train_step_rev)

    # manual data save for plotting utils
    np_on_cns_save(os.path.join(eval_dir, 'eval_interval.npy'), eval_interval)
    try:
      average_eval_return = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_return.npy')).tolist()
      average_eval_success = np_on_cns_load(
          os.path.join(eval_dir, 'average_eval_success.npy')).tolist()
    except Exception:  # pylint: disable=broad-except
      # Best effort: with no readable history (e.g. a fresh run) start from
      # empty curves. Unlike a bare `except`, this lets KeyboardInterrupt and
      # SystemExit propagate.
      logging.info('No previous evaluation history loaded from %s; '
                   'starting with empty curves.', eval_dir)
      average_eval_return = []
      average_eval_success = []

    # Defined up front so that the final return (and the log statement) never
    # hit an unbound name when no training step is executed.
    train_loss = None
    train_loss_rev = None

    print('initialization_time:', time.time() - start_time)
    for iter_idx in range(num_iterations):
      start_time = time.time()
      # The wrapped environment decides which agent acts at this step.
      if train_py_env._forward_or_reset_goal:
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
      else:
        time_step, policy_state = collect_driver_rev.run(
            time_step=time_step,
            policy_state=policy_state,
        )

      # reset goal generator updates
      if FLAGS.use_reset_goals and iter_idx % (
          FLAGS.reset_goal_frequency * collect_steps_per_iteration) == 0:
        train_py_env.set_reset_candidates(sample_reset_candidates())
        if reset_lagrange_learning_rate > 0:
          reset_goal_generator.update_lagrange_multipliers()

      for _ in range(train_steps_per_iteration):
        train_loss_rev = train_step_rev()
        train_loss = train_step()

      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val, train_loss.loss)
        logging.info('step = %d, loss_rev = %f', global_step_val,
                     train_loss_rev.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      # Forward metrics first, then reverse metrics; the first two metrics of
      # each group (episodes, environment steps) serve as step metrics.
      for metrics_group in (train_metrics, train_metrics_rev):
        step_metrics = metrics_group[:2]
        for train_metric in metrics_group:
          # Heatmap metrics are only summarized on summary steps.
          if ('Heatmap' in train_metric.name and
              global_step_val % summary_interval != 0):
            continue
          train_metric.tf_summaries(
              train_step=global_step, step_metrics=step_metrics)

      if global_step_val % summary_interval == 0 and FLAGS.use_reset_goals:
        reset_goal_generator.update_summaries(step_counter=global_step)

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

        # numpy saves for plotting
        average_eval_return.append(results['AverageReturn'].numpy())
        average_eval_success.append(results['EvalSuccessfulEpisodes'].numpy())
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_return.npy'),
            average_eval_return)
        np_on_cns_save(
            os.path.join(eval_dir, 'average_eval_success.npy'),
            average_eval_success)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)
        train_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer_rev.save(global_step=global_step_val)

      if global_step_val % video_record_interval == 0:
        for video_idx in range(num_videos):
          video_name = os.path.join(video_dir, str(global_step_val),
                                    'video_' + str(video_idx) + '.mp4')
          record_video(
              lambda: env_load_fn(  # pylint: disable=g-long-lambda
                  name=env_name,
                  max_episode_steps=max_episode_steps)[0],
              video_name,
              eval_py_policy,
              max_episode_length=eval_episode_steps)

    return train_loss
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    agent_name='sac',  # agent's name
    num_iterations=int(1e7),
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    mask_names=['birdeye'],  # names for masks
    preprocessing_combiner=tf.keras.layers.Add(
    ),  # takes a flat list of tensors and combines them
    actor_lstm_size=(40, ),  # lstm size for actor
    critic_lstm_size=(40, ),  # lstm size for critic
    actor_output_fc_layers=(100, ),  # lstm output
    critic_output_fc_layers=(100, ),  # lstm output
    epsilon_greedy=0.1,  # exploration parameter for DQN
    q_learning_rate=1e-3,  # q learning rate for DQN
    ou_stddev=0.2,  # exploration paprameter for DDPG
    ou_damping=0.15,  # exploration parameter for DDPG
    dqda_clipping=None,  # for DDPG
    exploration_noise_std=0.1,  # exploration paramter for td3
    actor_update_period=2,  # for td3
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=int(1e5),
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    initial_model_train_steps=100000,  # initial model training
    batch_size=256,
    model_batch_size=32,  # model training batch size
    sequence_length=4,  # number of timesteps to train model
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    model_learning_rate=1e-4,  # learning rate for model training
    td_errors_loss_fn=tf.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=10000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1
):  # Name of single observation channel, ['camera', 'lidar', 'birdeye']
    # Setup GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    # Get train and eval directories
    root_dir = os.path.expanduser(root_dir)
    root_dir = os.path.join(root_dir, env_name, experiment_name)

    # Get summary writers
    summary_writer = tf.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    # Eval metrics
    eval_metrics = [
        tf_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # Whether to record for summary
    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create Carla environment
        if agent_name == 'latent_sac':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names +
                                                 mask_names,
                                                 action_repeat=action_repeat)
        elif agent_name == 'dqn':
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 discrete=True,
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)
        else:
            py_env, eval_py_env = load_carla_env(env_name='carla-v0',
                                                 obs_channels=input_names,
                                                 action_repeat=action_repeat)

        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

        # Specs
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        ## Make tf agent
        if agent_name == 'latent_sac':
            # Get model network for latent sac
            if model_network_ctor_type == 'hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelHierarchical
            elif model_network_ctor_type == 'non-hierarchical':
                model_network_ctor = sequential_latent_network.SequentialLatentModelNonHierarchical
            else:
                raise NotImplementedError
            model_net = model_network_ctor(input_names,
                                           input_names + mask_names)

            # Get the latent spec
            latent_size = model_net.latent_size
            latent_observation_spec = tensor_spec.TensorSpec((latent_size, ),
                                                             dtype=tf.float32)
            latent_time_step_spec = ts.time_step_spec(
                observation_spec=latent_observation_spec)

            # Get actor and critic net
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                latent_observation_spec,
                action_spec,
                fc_layer_params=actor_fc_layers,
                continuous_projection_net=normal_projection_net)
            critic_net = critic_network.CriticNetwork(
                (latent_observation_spec, action_spec),
                observation_fc_layer_params=critic_obs_fc_layers,
                action_fc_layer_params=critic_action_fc_layers,
                joint_fc_layer_params=critic_joint_fc_layers)

            # Build the inner SAC agent based on latent space
            inner_agent = sac_agent.SacAgent(
                latent_time_step_spec,
                action_spec,
                actor_network=actor_net,
                critic_network=critic_net,
                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=actor_learning_rate),
                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=critic_learning_rate),
                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=alpha_learning_rate),
                target_update_tau=target_update_tau,
                target_update_period=target_update_period,
                td_errors_loss_fn=td_errors_loss_fn,
                gamma=gamma,
                reward_scale_factor=reward_scale_factor,
                gradient_clipping=gradient_clipping,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step)
            inner_agent.initialize()

            # Build the latent sac agent
            tf_agent = latent_sac_agent.LatentSACAgent(
                time_step_spec,
                action_spec,
                inner_agent=inner_agent,
                model_network=model_net,
                model_optimizer=tf.compat.v1.train.AdamOptimizer(
                    learning_rate=model_learning_rate),
                model_batch_size=model_batch_size,
                num_images_per_summary=num_images_per_summary,
                sequence_length=sequence_length,
                gradient_clipping=gradient_clipping,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                fps=fps)

        else:
            # Set up preprocessing layers for dictionary observation inputs
            preprocessing_layers = collections.OrderedDict()
            for name in input_names:
                preprocessing_layers[name] = Preprocessing_Layer(32, 256)
            if len(input_names) < 2:
                preprocessing_combiner = None

            if agent_name == 'dqn':
                q_rnn_net = q_rnn_network.QRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = dqn_agent.DqnAgent(
                    time_step_spec,
                    action_spec,
                    q_network=q_rnn_net,
                    epsilon_greedy=epsilon_greedy,
                    n_step_update=1,
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=q_learning_rate),
                    td_errors_loss_fn=common.element_wise_squared_loss,
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            elif agent_name == 'ddpg' or agent_name == 'td3':
                actor_rnn_net = multi_inputs_actor_rnn_network.MultiInputsActorRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                if agent_name == 'ddpg':
                    tf_agent = ddpg_agent.DdpgAgent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        ou_stddev=ou_stddev,
                        ou_damping=ou_damping,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars)
                elif agent_name == 'td3':
                    tf_agent = td3_agent.Td3Agent(
                        time_step_spec,
                        action_spec,
                        actor_network=actor_rnn_net,
                        critic_network=critic_rnn_net,
                        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=actor_learning_rate),
                        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                            learning_rate=critic_learning_rate),
                        exploration_noise_std=exploration_noise_std,
                        target_update_tau=target_update_tau,
                        target_update_period=target_update_period,
                        actor_update_period=actor_update_period,
                        dqda_clipping=dqda_clipping,
                        td_errors_loss_fn=None,
                        gamma=gamma,
                        reward_scale_factor=reward_scale_factor,
                        gradient_clipping=gradient_clipping,
                        debug_summaries=debug_summaries,
                        summarize_grads_and_vars=summarize_grads_and_vars,
                        train_step_counter=global_step)

            elif agent_name == 'sac':
                actor_distribution_rnn_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec,
                    action_spec,
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    input_fc_layer_params=actor_fc_layers,
                    lstm_size=actor_lstm_size,
                    output_fc_layer_params=actor_output_fc_layers,
                    continuous_projection_net=normal_projection_net)

                critic_rnn_net = multi_inputs_critic_rnn_network.MultiInputsCriticRnnNetwork(
                    (observation_spec, action_spec),
                    preprocessing_layers=preprocessing_layers,
                    preprocessing_combiner=preprocessing_combiner,
                    action_fc_layer_params=critic_action_fc_layers,
                    joint_fc_layer_params=critic_joint_fc_layers,
                    lstm_size=critic_lstm_size,
                    output_fc_layer_params=critic_output_fc_layers)

                tf_agent = sac_agent.SacAgent(
                    time_step_spec,
                    action_spec,
                    actor_network=actor_distribution_rnn_net,
                    critic_network=critic_rnn_net,
                    actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=actor_learning_rate),
                    critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=critic_learning_rate),
                    alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                        learning_rate=alpha_learning_rate),
                    target_update_tau=target_update_tau,
                    target_update_period=target_update_period,
                    td_errors_loss_fn=tf.math.
                    squared_difference,  # make critic loss dimension compatible
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=global_step)

            else:
                raise NotImplementedError

        tf_agent.initialize()

        # Get replay buffer
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,  # No parallel environments
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        # Train metrics
        env_steps = tf_metrics.EnvironmentSteps()
        average_return = tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        # Get policies
        # eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        eval_policy = tf_agent.policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)
        collect_policy = tf_agent.collect_policy

        # Checkpointers
        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step,
                                                  max_to_keep=2)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)
        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Collect driver
        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        # Optimize the performance by using tf functions
        initial_collect_driver.run = common.function(
            initial_collect_driver.run)
        collect_driver.run = common.function(collect_driver.run)
        tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if (env_steps.result() == 0 or replay_buffer.num_frames() == 0):
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps'
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        if agent_name == 'latent_sac':
            compute_summaries(eval_metrics,
                              eval_tf_env,
                              eval_policy,
                              train_step=global_step,
                              summary_writer=summary_writer,
                              num_episodes=1,
                              num_episodes_to_render=1,
                              model_net=model_net,
                              fps=10,
                              image_keys=input_names + mask_names)
        else:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=1,
                train_step=env_steps.result(),
                summary_writer=summary_writer,
                summary_prefix='Eval',
            )
            metric_utils.log_metrics(eval_metrics)

        # Dataset generates trajectories with shape [Bxslx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        # Get train step
        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        train_step = common.function(train_step)

        if agent_name == 'latent_sac':

            def train_model_step():
                experience, _ = next(iterator)
                return tf_agent.train_model(experience)

            train_model_step = common.function(train_model_step)

        # Training initializations
        time_step = None
        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Start training
        for iteration in range(num_iterations):
            start_time = time.time()

            if agent_name == 'latent_sac' and iteration < initial_model_train_steps:
                train_model_step()
            else:
                # Run collect
                time_step, _ = collect_driver.run(time_step=time_step)

                # Train an iteration
                for _ in range(train_steps_per_iteration):
                    train_step()

            time_acc += time.time() - start_time

            # Log training information
            if global_step.numpy() % log_interval == 0:
                logging.info('env steps = %d, average return = %f',
                             env_steps.result(), average_return.result())
                env_steps_per_sec = (env_steps.result().numpy() -
                                     env_steps_before) / time_acc
                logging.info('%.3f env steps/sec', env_steps_per_sec)
                tf.summary.scalar(name='env_steps_per_sec',
                                  data=env_steps_per_sec,
                                  step=env_steps.result())
                time_acc = 0
                env_steps_before = env_steps.result().numpy()

            # Get training metrics
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=env_steps.result())

            # Evaluation
            if global_step.numpy() % eval_interval == 0:
                # Log evaluation metrics
                if agent_name == 'latent_sac':
                    compute_summaries(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        train_step=global_step,
                        summary_writer=summary_writer,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        model_net=model_net,
                        fps=10,
                        image_keys=input_names + mask_names)
                else:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    metric_utils.log_metrics(eval_metrics)

            # Save checkpoints
            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
Пример #3
0
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v2',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-3,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO, run through a TF1-style session.

    Builds the collect/train graph once, then alternates between collecting
    episodes and training on the gathered trajectories until the environment
    step metric reaches `num_environment_steps`. Evaluation, logging and
    checkpointing are triggered from the value of the global step.

    Args:
        root_dir: Base output directory (`~` is expanded). Summaries and
            checkpoints go to its `train` and `eval` subdirectories.
        tf_master: Target passed to `tf.compat.v1.Session`.
        env_name: Environment name handed to `env_load_fn`.
        env_load_fn: Callable that takes `env_name` and returns an environment.
        random_seed: If not None, passed to `tf.compat.v1.set_random_seed`.
        actor_fc_layers: Fully connected layer sizes of the actor network.
        value_fc_layers: Fully connected layer sizes of the value network.
        use_rnns: If True, build the RNN variants of the actor/value networks
            instead of the feed-forward ones.
        num_environment_steps: Training continues while the environment step
            metric is below this value.
        collect_episodes_per_iteration: Episodes collected by the driver in
            each iteration.
        num_parallel_environments: Number of environment copies used for both
            collection and evaluation; also the batch size of the replay
            buffer and of the metrics.
        replay_buffer_capacity: `max_length` of the replay buffer.
        num_epochs: Forwarded to the PPO agent.
        learning_rate: Learning rate of the Adam optimizer.
        num_eval_episodes: Episodes per evaluation, and buffer size of the
            evaluation metrics.
        eval_interval: Evaluate when `global_step % eval_interval == 0` at the
            start of an iteration.
        train_checkpoint_interval: Save the train checkpoint when the global
            step is a multiple of this value.
        policy_checkpoint_interval: Save the policy checkpoint when the global
            step is a multiple of this value.
        log_interval: Log loss and throughput when the global step is a
            multiple of this value.
        summary_interval: Summaries are recorded only when the global step is
            a multiple of this value.
        summaries_flush_secs: Flush period of the summary writers, in seconds.
        debug_summaries: Forwarded to the PPO agent.
        summarize_grads_and_vars: Forwarded to the PPO agent.
        eval_metrics_callback: Forwarded as `callback` to
            `metric_utils.compute_summaries`.

    Raises:
        AttributeError: If `root_dir` is None.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    # Only record summaries on steps that are a multiple of summary_interval.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)
        # Separate environment sets for evaluation (Python) and collection (TF).
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                lstm_size=(40, ),
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.0,
            importance_ratio_clipping=0.2,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        # Chain the ops so that running train_op trains first and then clears
        # the replay buffer.
        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics))

        # Eval summaries go to their own writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            # Seed the step value so the final evaluation below still works
            # when the loop body never runs (e.g. a restored run that already
            # reached num_environment_steps).
            global_step_val = timed_at_step
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                # Time collection and training separately for the log output.
                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss, _ = sess.run([train_op, summary_ops])
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    # Restart the throughput measurement window.
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
Пример #4
0
def train_eval_doom_simple(
        # Params for collect
        num_environment_steps=30000000,
        collect_episodes_per_iteration=32,
        num_parallel_environments=32,
        replay_buffer_capacity=301,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=4e-4,
        # Params for eval
        eval_interval=500,
        num_video_episodes=10,
        # Params for summaries and logging
        log_interval=50):
    """Train a PPO agent on `CSGOEnvironment` until a step budget is reached.

    Each iteration collects whole episodes into a replay buffer, trains the
    agent on everything gathered, empties the buffer, and periodically logs
    the loss and throughput. Nothing is returned.

    Args:
        num_environment_steps: Loop runs while the environment step metric is
            below this value.
        collect_episodes_per_iteration: Episodes collected per iteration.
        num_parallel_environments: Batch size of the replay buffer.
        replay_buffer_capacity: `max_length` of the replay buffer.
        num_epochs: Forwarded to the PPO agent.
        learning_rate: Learning rate of the Adam optimizer.
        eval_interval: Currently unused; no evaluation is performed here.
        num_video_episodes: Currently unused; video creation is disabled.
        log_interval: Log when the global step is a multiple of this value.
    """
    # NOTE(review): the environment wraps a single CSGOEnvironment while the
    # replay buffer below uses batch_size=num_parallel_environments - confirm
    # that the environment's batch size actually matches.
    env = tf_py_environment.TFPyEnvironment(CSGOEnvironment())
    action_spec = env.action_spec()

    actor_net, value_net = create_networks(env.observation_spec(), action_spec)

    train_counter = tf.compat.v1.train.get_or_create_global_step()
    adam = tf.compat.v1.train.AdamOptimizer(
        learning_rate=learning_rate,
        epsilon=1e-5,
    )

    agent = ppo_agent.PPOAgent(
        env.time_step_spec(),
        action_spec,
        adam,
        actor_net,
        value_net,
        num_epochs=num_epochs,
        train_step_counter=train_counter,
        discount_factor=0.99,
        gradient_clipping=0.5,
        entropy_regularization=1e-2,
        importance_ratio_clipping=0.2,
        use_gae=True,
        use_td_lambda_return=True,
    )
    agent.initialize()

    # The step counter doubles as the loop's stopping criterion.
    env_steps = tf_metrics.EnvironmentSteps()
    counters = [tf_metrics.NumberOfEpisodes(), env_steps]

    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent_data_spec := agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity,
    ) if False else tf_uniform_replay_buffer.TFUniformReplayBuffer(
        agent.collect_data_spec,
        batch_size=num_parallel_environments,
        max_length=replay_buffer_capacity,
    )
    driver = dynamic_episode_driver.DynamicEpisodeDriver(
        env,
        agent.collect_policy,
        observers=[buffer.add_batch] + counters,
        num_episodes=collect_episodes_per_iteration,
    )

    # Wall-clock accumulators for the current logging window.
    collect_seconds = 0
    train_seconds = 0
    window_start_step = train_counter.numpy()

    while env_steps.result() < num_environment_steps:
        tick = time.time()
        driver.run()
        collect_seconds += time.time() - tick

        # Train on everything collected, then drop it (on-policy data).
        tick = time.time()
        total_loss, _ = agent.train(experience=buffer.gather_all())
        buffer.clear()
        train_seconds += time.time() - tick

        current_step = train_counter.numpy()
        if current_step % log_interval != 0:
            continue

        logging.info('step = %d, loss = %f', current_step, total_loss)
        throughput = ((current_step - window_start_step) /
                      (collect_seconds + train_seconds))
        logging.info('%.3f steps/sec', throughput)
        logging.info('collect_time = {}, train_time = {}'.format(
            collect_seconds, train_seconds))

        # Open a new measurement window.
        window_start_step = current_step
        collect_seconds = 0
        train_seconds = 0
Пример #5
0
def run():
    """Train a DQN agent on ``SnakeEnv`` with logging, checkpoints and eval.

    Builds a Q-network agent, a uniform replay buffer and a step driver,
    restores the train and replay-buffer checkpoints when they exist, seeds
    the buffer with random-policy experience if it holds fewer than
    ``initial_collect_steps`` frames, and then alternates collection and
    training for ``num_iterations`` iterations.

    Hyperparameters (``learning_rate``, ``batch_size``, ``num_iterations``,
    ``replay_buffer_max_length``, the various ``*_interval`` values, ...) and
    the ``capture_run`` helper are not defined in this function; they are
    looked up in the enclosing scope.

    Returns:
        None. The ``(step, average return)`` pairs accumulated in
        ``avg_returns`` stay local to the function.
    """
    tf_env = tf_py_environment.TFPyEnvironment(SnakeEnv())
    eval_env = tf_py_environment.TFPyEnvironment(SnakeEnv(step_limit=50))

    q_net = q_network.QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        conv_layer_params=(),
        fc_layer_params=(512, 256, 128),
    )

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    global_counter = tf.compat.v1.train.get_or_create_global_step()

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_counter,
        gamma=0.95,
        epsilon_greedy=0.1,
        n_step_update=1,
    )

    root_dir = os.path.join('/tf-logs', 'snake')
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    agent.initialize()

    # The first two metrics double as the step metrics for train summaries.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length,
    )

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration,
    )

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=agent,
        global_step=global_counter,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
    )

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_counter,
    )

    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer,
    )

    # Restore agent/metric state and replay buffer contents before any
    # collection so a resumed run continues where it stopped.
    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    collect_driver.run = common.function(collect_driver.run)
    agent.train = common.function(agent.train)

    random_policy = random_tf_policy.RandomTFPolicy(tf_env.time_step_spec(),
                                                    tf_env.action_spec())

    # Only seed with random experience when the (possibly restored) buffer
    # does not already hold enough frames.
    if replay_buffer.num_frames() >= initial_collect_steps:
        logging.info("We loaded memories, not doing random seed")
    else:
        logging.info("Capturing %d steps to seed with random memories",
                     initial_collect_steps)

        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            random_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

    train_summary_writer = tf.summary.create_file_writer(train_dir)
    train_summary_writer.set_as_default()
    # A single writer shared by every evaluation round; building a new one
    # per evaluation would open a fresh writer (and event file) each time.
    eval_summary_writer = tf.summary.create_file_writer(eval_dir)

    avg_returns = []
    avg_return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes)
    eval_metrics = [
        avg_return_metric,
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]
    logging.info("Running initial evaluation")
    metric_utils.eager_compute(
        eval_metrics,
        eval_env,
        agent.policy,
        num_episodes=num_eval_episodes,
        train_step=global_counter,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    avg_returns.append(
        (global_counter.numpy(), avg_return_metric.result().numpy()))
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_counter.numpy()
    time_acc = 0

    # num_steps=2 yields pairs of adjacent steps, matching n_step_update=1.
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)

    iterator = iter(dataset)

    @common.function
    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    # Stays None when train_steps_per_iteration is 0, in which case the loss
    # line is skipped instead of failing on an unbound name.
    train_loss = None

    for _ in range(num_iterations):
        start_time = time.time()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )

        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        step = global_counter.numpy()

        if step % log_interval == 0:
            if train_loss is not None:
                logging.info("step = %d, loss = %f", step, train_loss.loss)
            steps_per_sec = (step - timed_at_step) / time_acc
            logging.info("%.3f steps/sec", steps_per_sec)
            timed_at_step = step
            time_acc = 0

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_counter,
                                      step_metrics=train_metrics[:2])

        if step % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=step)

        if step % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=step)

        if step % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=step)

        if step % capture_interval == 0:
            print("Capturing run:")
            capture_run(os.path.join(root_dir, "snake" + str(step) + ".mp4"),
                        eval_env, agent.policy)

        if step % eval_interval == 0:
            print("EVALUTION TIME:")
            metric_utils.eager_compute(
                eval_metrics,
                eval_env,
                agent.policy,
                num_episodes=num_eval_episodes,
                train_step=global_counter,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)
            avg_returns.append(
                (global_counter.numpy(), avg_return_metric.result().numpy()))
Пример #6
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        critic_learning_rate=3e-4,
        train_sequence_length=20,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and evaluate an RNN SAC agent on a DM Control task.

    Collects whole episodes with the agent's collect policy, trains on
    fixed-length sequences sampled from a uniform replay buffer, and
    periodically logs, evaluates and checkpoints.

    Args:
        root_dir: Directory for summaries and checkpoints; ``~`` is expanded.
        env_name: Environment name passed as the first argument to
            ``suite_dm_control.load``.
        task_name: ``task_name`` passed to ``suite_dm_control.load``.
        observations_allowlist: Observation key handed (wrapped in a list) to
            ``FlattenObservationsWrapper``; ``None`` disables the wrapper.
        eval_env_name: Environment name used for evaluation; falls back to
            ``env_name`` when falsy.
        num_iterations: Number of collect-then-train iterations.
        actor_fc_layers: Input FC layer sizes of the actor network.
        actor_output_fc_layers: Output FC layer sizes of the actor network.
        actor_lstm_size: LSTM size(s) of the actor network.
        critic_obs_fc_layers: Observation FC layer sizes of the critic.
        critic_action_fc_layers: Action FC layer sizes of the critic.
        critic_joint_fc_layers: Joint FC layer sizes of the critic.
        critic_output_fc_layers: Output FC layer sizes of the critic.
        critic_lstm_size: LSTM size(s) of the critic network.
        num_parallel_environments: With 1 a single environment is loaded;
            otherwise a ``ParallelPyEnvironment`` of that many copies.
        initial_collect_episodes: Episodes collected with a random policy to
            seed the replay buffer.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        target_update_tau: Forwarded to ``SacAgent``.
        target_update_period: Forwarded to ``SacAgent``.
        train_steps_per_iteration: Train steps run for each environment step
            collected in the iteration.
        batch_size: Number of sequences per training batch.
        critic_learning_rate: Adam learning rate for the critic.
        train_sequence_length: Sampled sequences contain
            ``train_sequence_length + 1`` steps.
        actor_learning_rate: Adam learning rate for the actor.
        alpha_learning_rate: Adam learning rate for alpha.
        td_errors_loss_fn: Forwarded to ``SacAgent``.
        gamma: Forwarded to ``SacAgent``.
        reward_scale_factor: Forwarded to ``SacAgent``.
        gradient_clipping: Forwarded to ``SacAgent``.
        use_tf_functions: Wrap the driver ``run`` methods, ``tf_agent.train``
            and the train step in ``common.function``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the train return/length metrics.
        eval_interval: Evaluate when ``global_step`` is a multiple of this.
        train_checkpoint_interval: Save the train checkpoint when
            ``global_step`` is a multiple of this.
        policy_checkpoint_interval: Same, for the policy checkpoint.
        rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
        log_interval: Log throughput when ``global_step`` is a multiple of
            this.
        summary_interval: Summaries are recorded only when ``global_step`` is
            a multiple of this.
        summaries_flush_secs: Flush period of the summary writer, in seconds.
        debug_summaries: Forwarded to ``SacAgent``.
        summarize_grads_and_vars: Forwarded to ``SacAgent``.
        eval_metrics_callback: Optional callable invoked as
            ``callback(results, env_steps.result())`` after each evaluation.
    """
    root_dir = os.path.expanduser(root_dir)

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d episodes '
                'with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Reduce filter_fn over full trajectory sampled. The sequence is kept only
            # if all elements except for the last one pass the filter. This is to
            # allow training on terminal steps.
            return tf.reduce_all(~trajectories.is_boundary()[:-1])

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape
        # [B x (train_sequence_length + 1) x ...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            # TODO(b/152648849)
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                # Accumulate only the time elapsed since the previous sample.
                # Measuring from the start of the outer iteration on every
                # pass would count the collect phase and earlier train steps
                # repeatedly and deflate the reported env steps/sec.
                now = time.time()
                time_acc += now - start_time
                start_time = now

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        # `env_steps` is the metric object, not a tensor; use
                        # result() exactly as the initial evaluation does.
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Пример #7
0
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=100000,
    fc_layer_params=(100,),
    # Params for collect
    initial_collect_steps=1000,
    collect_steps_per_iteration=1,
    epsilon_greedy=0.1,
    replay_buffer_capacity=100000,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=64,
    learning_rate=1e-3,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    agent_class=dqn_agent.DqnAgent,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for DQN, driven through a `tf.compat.v1.Session`.

  All collect/train/summary ops are built first; the loop then runs them via
  session callables, logging, checkpointing and evaluating whenever the
  global step is a multiple of the corresponding interval.

  Args:
    root_dir: Base directory (`~` is expanded); summaries and checkpoints go
      to its `train` and `eval` subdirectories.
    env_name: Name passed to `suite_gym.load` for both the training and the
      evaluation environment.
    num_iterations: Number of collect-then-train iterations.
    fc_layer_params: `fc_layer_params` of the `QNetwork`.
    initial_collect_steps: Steps collected with a random policy before
      training starts.
    collect_steps_per_iteration: Steps collected with the agent's collect
      policy in each iteration.
    epsilon_greedy: Forwarded to the agent constructor.
    replay_buffer_capacity: `max_length` of the replay buffer.
    target_update_tau: Forwarded to the agent constructor.
    target_update_period: Forwarded to the agent constructor.
    train_steps_per_iteration: Train-op executions per iteration.
    batch_size: `sample_batch_size` used when sampling the replay buffer.
    learning_rate: Learning rate of the Adam optimizer.
    gamma: Forwarded to the agent constructor.
    reward_scale_factor: Forwarded to the agent constructor.
    gradient_clipping: Forwarded to the agent constructor.
    num_eval_episodes: Episodes per evaluation and buffer size of the eval
      metrics.
    eval_interval: Evaluate when the global step is a multiple of this.
    train_checkpoint_interval: Save the train checkpoint when the global step
      is a multiple of this.
    policy_checkpoint_interval: Same, for the policy checkpoint.
    rb_checkpoint_interval: Same, for the replay-buffer checkpoint.
    log_interval: Log loss and throughput when the global step is a multiple
      of this.
    summary_interval: Summaries are recorded only when the global step is a
      multiple of this.
    summaries_flush_secs: Flush period of the summary writers, in seconds.
    agent_class: Callable used to construct the agent; called with the
      time-step spec, the action spec and the keyword arguments below.
    debug_summaries: Forwarded to the agent constructor.
    summarize_grads_and_vars: Forwarded to the agent constructor.
    eval_metrics_callback: Passed as `callback` to
      `metric_utils.compute_summaries`.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  # Evaluation uses Python-side metrics together with a Python environment
  # and a PyTFPolicy wrapper (see eval_py_env / eval_py_policy below).
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Everything built inside this context only records summaries on steps
  # that are multiples of summary_interval.
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
    eval_py_env = suite_gym.load(env_name)

    q_net = q_network.QNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=fc_layer_params)

    # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
    tf_agent = agent_class(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        epsilon_greedy=epsilon_greedy,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    # The first two metrics are reused below as the step metrics for the
    # train summaries (train_metrics[:2]).
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    replay_observer = [replay_buffer.add_batch]
    # The driver run() results are kept as ops and executed later through the
    # session (sess.run / sess.make_callable), not here.
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_policy = tf_agent.collect_policy
    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)

    # The iterator must be initialized in the session before train_op runs
    # (see sess.run(iterator.initializer) below).
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    experience, _ = iterator.get_next()
    train_op = common.function(tf_agent.train)(experience=experience)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    # Train-metric summary ops; they are run together with train_op.
    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    # Eval summaries go to their own writer and are always recorded,
    # overriding the summary_interval gate of the enclosing context.
    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      sess.run(iterator.initializer)
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())
      # Seed the replay buffer with random-policy experience.
      sess.run(initial_collect_op)

      # Baseline evaluation before any training iteration.
      global_step_val = sess.run(global_step)
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
          log=True,
      )

      # Callables for the ops executed on every iteration of the loop.
      collect_call = sess.make_callable(collect_op)
      global_step_call = sess.make_callable(global_step)
      train_step_call = sess.make_callable([train_op, summary_ops])

      timed_at_step = global_step_call()
      collect_time = 0
      train_time = 0
      # The throughput value is computed in Python and fed through this
      # placeholder into the summary op.
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        # Train/collect/eval.
        start_time = time.time()
        collect_call()
        collect_time += time.time() - start_time
        start_time = time.time()
        for _ in range(train_steps_per_iteration):
          loss_info_value, _ = train_step_call()
        train_time += time.time() - start_time

        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val,
                       loss_info_value.loss)
          steps_per_sec = (
              (global_step_val - timed_at_step) / (collect_time + train_time))
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          logging.info('%.3f steps/sec', steps_per_sec)
          logging.info('%s', 'collect_time = {}, train_time = {}'.format(
              collect_time, train_time))
          # Restart the timing window for the next log interval.
          timed_at_step = global_step_val
          collect_time = 0
          train_time = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
Пример #8
0
def train_eval(
    root_dir,
    env_name='sawyer_reach',
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    gamma=0.99,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=200000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    random_seed=0,
    max_future_steps=50,
    actor_std=None,
    log_subset=None,
    ):
  """Trains and periodically evaluates a `CLearningAgent`.

  Args:
    root_dir: Output directory. Summaries and checkpoints go to its `train`
      and `eval` subdirectories; the operative gin config is written to
      `operative.gin` inside it.
    env_name: Environment name passed to `c_learning_envs.load`.
    num_iterations: Training stops once the global step reaches this value.
    actor_fc_layers: Fully connected layer sizes of the actor network.
    critic_obs_fc_layers: Observation-branch layer sizes of the critic.
    critic_action_fc_layers: Action-branch layer sizes of the critic.
    critic_joint_fc_layers: Joint layer sizes of the critic.
    initial_collect_steps: Number of random-policy steps collected when the
      replay buffer is empty.
    collect_steps_per_iteration: Environment steps collected per iteration.
    replay_buffer_capacity: `max_length` of the replay buffer.
    target_update_tau: Forwarded to the agent.
    target_update_period: Forwarded to the agent.
    train_steps_per_iteration: Number of `train_step` calls per iteration.
    batch_size: Batch size used when sampling and re-batching replay data.
    actor_learning_rate: Learning rate of the actor Adam optimizer.
    critic_learning_rate: Learning rate of the critic Adam optimizer.
    gamma: Forwarded to both the agent and `c_learning_utils.goal_fn`.
    gradient_clipping: Forwarded to the agent.
    use_tf_functions: If True, wrap the drivers, `tf_agent.train` and the
      train step with `common.function`.
    num_eval_episodes: Episodes per evaluation; also used as the metric
      buffer size.
    eval_interval: Evaluate whenever the global step is a multiple of this.
    train_checkpoint_interval: Save a training checkpoint whenever the global
      step is a multiple of this.
    log_interval: Log loss and throughput whenever the global step is a
      multiple of this.
    summary_interval: Summaries are recorded only when the global step is a
      multiple of this.
    summaries_flush_secs: Flush period of the summary writers, in seconds.
    debug_summaries: Forwarded to the agent.
    summarize_grads_and_vars: Forwarded to the agent.
    random_seed: Seed for NumPy and TensorFlow.
    max_future_steps: Length (`num_steps`) of sequences sampled from the
      replay buffer.
    actor_std: If not None, the actor's projection network uses this constant
      as its `std_transform` output instead of the default.
    log_subset: Optional `(start_index, end_index)` pair. When given, extra
      `Subset*Distance` metrics restricted to that index range are added to
      both train and eval metrics.

  Returns:
    The result of the last `train_step` call, or None if no training step was
    executed (e.g. a restored checkpoint already reached `num_iterations`).
  """
  np.random.seed(random_seed)
  tf.random.set_seed(random_seed)

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env, eval_tf_env, obs_dim = c_learning_envs.load(env_name)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    if actor_std is None:
      proj_net = tanh_normal_projection_network.TanhNormalProjectionNetwork
    else:
      # Replace the std transform with a constant equal to `actor_std`.
      proj_net = functools.partial(
          tanh_normal_projection_network.TanhNormalProjectionNetwork,
          std_transform=lambda t: actor_std * tf.ones_like(t))

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=proj_net)
    critic_net = c_learning_utils.ClassifierCriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    tf_agent = c_learning_agent.CLearningAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=bce_loss,
        gamma=gamma,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
        c_learning_utils.FinalDistance(
            buffer_size=num_eval_episodes, obs_dim=obs_dim),
        c_learning_utils.MinimumDistance(
            buffer_size=num_eval_episodes, obs_dim=obs_dim),
        c_learning_utils.DeltaDistance(
            buffer_size=num_eval_episodes, obs_dim=obs_dim),
    ]
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        c_learning_utils.InitialDistance(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            obs_dim=obs_dim),
        c_learning_utils.FinalDistance(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            obs_dim=obs_dim),
        c_learning_utils.MinimumDistance(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            obs_dim=obs_dim),
        c_learning_utils.DeltaDistance(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size,
            obs_dim=obs_dim),
    ]
    if log_subset is not None:
      # Also track the distance metrics on the [start_index, end_index) slice
      # selected by `log_subset`, for both train and eval.
      # NOTE(review): the eval batch size is hard-coded to 10 here -- confirm
      # it matches the eval environment.
      start_index, end_index = log_subset
      for name, metrics in [('train', train_metrics), ('eval', eval_metrics)]:
        metrics.extend([
            c_learning_utils.InitialDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size if name == 'train' else 10,
                obs_dim=obs_dim,
                start_index=start_index,
                end_index=end_index,
                name='SubsetInitialDistance'),
            c_learning_utils.FinalDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size if name == 'train' else 10,
                obs_dim=obs_dim,
                start_index=start_index,
                end_index=end_index,
                name='SubsetFinalDistance'),
            c_learning_utils.MinimumDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size if name == 'train' else 10,
                obs_dim=obs_dim,
                start_index=start_index,
                end_index=end_index,
                name='SubsetMinimumDistance'),
            c_learning_utils.DeltaDistance(
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size if name == 'train' else 10,
                obs_dim=obs_dim,
                start_index=start_index,
                end_index=end_index,
                name='SubsetDeltaDistance'),
        ])

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=None)

    train_checkpointer.initialize_or_restore()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Save the hyperparameters
    operative_filename = os.path.join(root_dir, 'operative.gin')
    with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
      f.write(gin.operative_config_str())
      logging.info(gin.operative_config_str())

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    # Evaluate once before any training step of this run.
    metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    def _filter_invalid_transition(trajectories, unused_arg1):
      # Drop sampled sequences whose first step is an episode boundary.
      return ~trajectories.is_boundary()[0]
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=max_future_steps)
    # Filter per sequence, then rebuild fixed-size batches.
    dataset = dataset.unbatch().filter(_filter_invalid_transition)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    goal_fn = functools.partial(
        c_learning_utils.goal_fn,
        batch_size=batch_size,
        obs_dim=obs_dim,
        gamma=gamma)
    dataset = dataset.map(goal_fn)
    dataset = dataset.prefetch(5)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # Bound up front so the final `return` is valid even when the loop body
    # never runs (e.g. the restored global step is already >= num_iterations).
    train_loss = None

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      # The first two train metrics (episodes, env steps) serve as the
      # step metrics for the summaries.
      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

    return train_loss
Пример #9
0
def train_eval(
        root_dir,
        env_name='Blob2d-v1',
        num_iterations=100000,
        train_sequence_length=1,
        collect_steps_per_iteration=1,
        initial_collect_steps=1500,
        replay_buffer_max_length=10000,
        batch_size=64,
        learning_rate=1e-3,
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for QNetwork
        fc_layer_params=(100, ),
        use_tf_functions=False,
        ## train params
        train_steps_per_iteration=1,
        train_checkpoint_interval=1000,
        policy_checkpoint_interval=1000,
        rb_checkpoint_interval=1000,
        n_step_update=1,
        ## Params for Summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Trains and periodically evaluates a DQN agent on a gym environment.

    Args:
        root_dir: Output directory; `train` and `eval` subdirectories hold
            summaries and checkpoints.
        env_name: Gym environment id loaded with `suite_gym.load` for both the
            training and the evaluation environment.
        num_iterations: Number of collect/train iterations to run.
        train_sequence_length: The replay dataset samples sequences of
            `train_sequence_length + 1` steps.
        collect_steps_per_iteration: Environment steps collected per iteration.
        initial_collect_steps: Random-policy steps collected before training.
        replay_buffer_max_length: `max_length` of the replay buffer.
        batch_size: Sample batch size of the replay dataset.
        learning_rate: Learning rate of the Adam optimizer.
        num_eval_episodes: Episodes per evaluation; also the eval metric
            buffer size.
        eval_interval: Evaluate whenever the global step is a multiple of this.
        fc_layer_params: Fully connected layer sizes of the Q-network.
        use_tf_functions: If True, wrap the train step with `common.function`.
        train_steps_per_iteration: Number of train steps per iteration.
        train_checkpoint_interval: Interval (in global steps) for saving the
            training checkpoint.
        policy_checkpoint_interval: Interval for saving the policy checkpoint.
        rb_checkpoint_interval: Interval for saving the replay buffer.
        n_step_update: Only validated against `train_sequence_length`; it is
            not forwarded to the agent in this function.
        log_interval: Interval for logging loss and throughput.
        summary_interval: Used by the `record_if` condition (see the note at
            the `with` block below).
        summaries_flush_secs: Flush period of the summary writers, in seconds.
        debug_summaries: Unused in this function.
        summarize_grads_and_vars: Unused in this function.
        eval_metrics_callback: Optional callable invoked as
            `eval_metrics_callback(results, global_step_value)` after each
            evaluation.

    Returns:
        The result of the last train step, or None if no train step ran.

    Raises:
        NotImplementedError: If both `train_sequence_length` and
            `n_step_update` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    # NOTE(review): only environment construction happens inside this
    # `record_if` scope, so `summary_interval` does not gate the summaries
    # written further down -- confirm whether that is intended.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

    # Build the network from the caller-supplied `fc_layer_params` on the
    # environment selected by `env_name` (no hard-coded overrides).
    q_net = q_network.QNetwork(tf_env.observation_spec(),
                               tf_env.action_spec(),
                               fc_layer_params=fc_layer_params)

    agent = dqn_agent.DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step)
    agent.initialize()

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_max_length)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=collect_steps_per_iteration)

    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step,
                                             metrics=metric_utils.MetricsGroup(
                                                 train_metrics,
                                                 'train_metrics'))
    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=eval_policy,
                                              global_step=global_step)
    rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'replay_buffer'),
                                          max_to_keep=1,
                                          replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=initial_collect_steps).run()

    # Evaluate once before any training step of this run.
    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )

    if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=train_sequence_length +
                                       1).prefetch(3)
    iterator = iter(dataset)

    def train_step():
        experience, _ = next(iterator)
        return agent.train(experience)

    if use_tf_functions:
        train_step = common.function(train_step)

    # Bound up front so the final `return` is valid when no iteration runs.
    train_loss = None

    # Main Training loop.
    for _ in range(num_iterations):
        start_time = time.time()
        time_step, policy_state = collect_driver.run(
            time_step=time_step,
            policy_state=policy_state,
        )
        for _ in range(train_steps_per_iteration):
            train_loss = train_step()
        time_acc += time.time() - start_time

        if global_step.numpy() % log_interval == 0:
            logging.info('step = %d, loss = %f', global_step.numpy(),
                         train_loss.loss)
            steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
            logging.info('%.3f steps/sec', steps_per_sec)
            tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                        data=steps_per_sec,
                                        step=global_step)
            timed_at_step = global_step.numpy()
            time_acc = 0

        # The first two train metrics (episodes, env steps) serve as the
        # step metrics for the summaries.
        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        if global_step.numpy() % train_checkpoint_interval == 0:
            train_checkpointer.save(global_step=global_step.numpy())

        if global_step.numpy() % policy_checkpoint_interval == 0:
            policy_checkpointer.save(global_step=global_step.numpy())

        if global_step.numpy() % rb_checkpoint_interval == 0:
            rb_checkpointer.save(global_step=global_step.numpy())

        if global_step.numpy() % eval_interval == 0:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics',
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics)
    return train_loss
Пример #10
0
def train(root_dir,
          agent,
          environment,
          training_loops,
          steps_per_loop,
          additional_metrics=(),
          training_data_spec_transformation_fn=None):
    """Run `training_loops` collect-and-train rounds, checkpointing each one.

    After every round the metrics are logged and written as summaries, the
    checkpoint manager saves, and the agent's policy is exported under
    `root_dir` in a directory named after the current step-metric value.

    Args:
        root_dir: path to the directory where checkpoints, summaries and
            saved policies are written.
        agent: an instance of `TFAgent`.
        environment: an instance of `TFEnvironment`.
        training_loops: an integer, the number of training rounds to run.
        steps_per_loop: an integer, the number of driver steps executed and
            presented to the trainer in each round.
        additional_metrics: metric objects to log on top of the default
            `NumberOfEpisodes`, `AverageEpisodeLengthMetric` and average
            return metric.
        training_data_spec_transformation_fn: optional function applied to the
            trajectory spec and to each data item before it reaches the
            replay buffer.
    """

    # TODO(b/127641485): create evaluation loop with configurable metrics.
    transform_fn = training_data_spec_transformation_fn
    trajectory_spec = agent.policy.trajectory_spec
    if transform_fn is not None:
        trajectory_spec = transform_fn(trajectory_spec)
    replay_buffer = get_replay_buffer(trajectory_spec, environment.batch_size,
                                      steps_per_loop)

    # `step_metric` records the number of individual rounds of bandit interaction;
    # that is, (number of trajectories) * batch_size.
    step_metric = tf_metrics.EnvironmentSteps()
    metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=environment.batch_size),
    ]
    metrics.extend(additional_metrics)

    # A dict-valued reward spec needs the multi-metric variant.
    if isinstance(environment.reward_spec(), dict):
        return_metric = tf_metrics.AverageReturnMultiMetric(
            reward_spec=environment.reward_spec(),
            batch_size=environment.batch_size)
    else:
        return_metric = tf_metrics.AverageReturnMetric(
            batch_size=environment.batch_size)
    metrics.append(return_metric)

    if transform_fn is None:
        add_batch_fn = replay_buffer.add_batch
    else:

        def add_batch_fn(data):
            # Transform each item before storing it in the replay buffer.
            return replay_buffer.add_batch(transform_fn(data))

    observers = [add_batch_fn, step_metric]
    observers.extend(metrics)

    total_driver_steps = steps_per_loop * environment.batch_size
    driver = dynamic_step_driver.DynamicStepDriver(
        env=environment,
        policy=agent.collect_policy,
        num_steps=total_driver_steps,
        observers=observers)

    training_loop = get_training_loop_fn(driver, replay_buffer, agent,
                                         steps_per_loop)
    checkpoint_manager = restore_and_get_checkpoint_manager(
        root_dir, agent, metrics, step_metric)
    saver = policy_saver.PolicySaver(agent.policy)

    summary_writer = tf.summary.create_file_writer(root_dir)
    summary_writer.set_as_default()

    completed_loops = 0
    while completed_loops < training_loops:
        training_loop()
        metric_utils.log_metrics(metrics)
        for tracked_metric in metrics:
            tracked_metric.tf_summaries(train_step=step_metric.result())
        checkpoint_manager.save()
        policy_dir = os.path.join(root_dir,
                                  'policy_%d' % step_metric.result())
        saver.save(policy_dir)
        completed_loops += 1
Пример #11
0
def train_eval(root_dir, tf_env, eval_tf_env, agent, num_iterations,
               initial_collect_steps, collect_steps_per_iteration,
               replay_buffer_capacity, train_steps_per_iteration, batch_size,
               use_tf_functions, num_eval_episodes, eval_interval,
               train_checkpoint_interval, policy_checkpoint_interval,
               rb_checkpoint_interval, log_interval, summary_interval,
               summaries_flush_secs):
    """Trains and evaluates a pre-built agent, keeping the best policy.

    The policy checkpoint is saved only when a periodic evaluation yields a
    higher "AverageReturn" than any earlier periodic evaluation of this run.
    Training stops early when the training loss becomes NaN or infinite.

    Args:
        root_dir: Output directory; `train` and `eval` subdirectories hold
            summaries and checkpoints.
        tf_env: Environment used for data collection.
        eval_tf_env: Environment used for evaluation.
        agent: The agent to train; it must already be constructed.
        num_iterations: Maximum number of collect/train iterations.
        initial_collect_steps: Random-policy steps collected before training.
        collect_steps_per_iteration: Environment steps collected per iteration.
        replay_buffer_capacity: `max_length` of the replay buffer.
        train_steps_per_iteration: Number of train steps per iteration.
        batch_size: Sample batch size of the replay dataset.
        use_tf_functions: If True, wrap the collect driver, `agent.train` and
            the train step with `common.function`.
        num_eval_episodes: Episodes per evaluation; also the eval metric
            buffer size.
        eval_interval: Evaluate whenever the global step is a multiple of this.
        train_checkpoint_interval: Interval (in global steps) for saving the
            training checkpoint.
        policy_checkpoint_interval: Unused in this function; the policy is
            saved on evaluation improvement instead.
        rb_checkpoint_interval: Interval for saving the replay buffer.
        log_interval: Interval for printing loss and throughput.
        summary_interval: Summaries are recorded only when the global step is
            a multiple of this.
        summaries_flush_secs: Flush period of the summary writers, in seconds.

    Returns:
        The result of the last train step, or None if no train step ran.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        #tf_metrics.ChosenActionHistogram(buffer_size=num_eval_episodes),
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        #tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        tf_agent = agent

        train_metrics = [
            #tf_metrics.ChosenActionHistogram(),
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=1),
            #tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            max_to_keep=1,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  max_to_keep=1,
                                                  global_step=global_step)

        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        # Best evaluation return seen so far. Start at -inf so the first
        # periodic evaluation always counts, whatever the reward scale is.
        best_policy = float('-inf')

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Evaluate once before any training step of this run. The result is
        # only logged; it does not seed `best_policy`.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Bound up front so the final `return` is valid when no iteration
        # runs (num_iterations == 0).
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Stop as soon as the loss contains NaN or +/-Inf.
            if not np.isfinite(train_loss.loss).all():
                print('Training diverged (non-finite loss) at step {0}'.format(
                    global_step.numpy()))
                break

            if global_step.numpy() % log_interval == 0:
                print('step = {0}, loss = {1}'.format(global_step.numpy(),
                                                      train_loss.loss))

                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                print('{0} steps/sec'.format(steps_per_sec))
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            # The first two train metrics (episodes, env steps) serve as the
            # step metrics for the summaries.
            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

                # Checkpoint the policy only when it beats the best return.
                average_return = results["AverageReturn"].numpy()
                if average_return > best_policy:
                    print("New best policy found")
                    print(average_return)
                    best_policy = average_return
                    policy_checkpointer.save(global_step=global_step.numpy())

                metric_utils.log_metrics(eval_metrics)
        return train_loss
Пример #12
0
def train_eval(
        root_dir,
        experiment_name,
        train_eval_dir=None,
        universe='gym',
        env_name='HalfCheetah-v2',
        domain_name='cheetah',
        task_name='run',
        action_repeat=1,
        num_iterations=int(1e7),
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        model_network_ctor=model_distribution_network.ModelDistributionNetwork,
        critic_input='state',
        actor_input='state',
        compressor_descriptor='preprocessor_32_3',
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=int(1e5),
        # increase if necessary since buffers with images are huge
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        model_train_steps_per_iteration=1,
        initial_model_train_steps=100000,
        batch_size=256,
        model_batch_size=32,
        sequence_length=4,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        num_images_per_summary=1,
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=0,  # enable if necessary since buffers with images are huge
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        gpu_allow_growth=False,
        gpu_memory_limit=None):
    """A simple train and eval for SLAC.

    Builds the whole TF1-style graph (environments, networks, agent, replay
    buffer, collect/train ops, summaries, checkpointers) and then runs the
    training loop inside a `tf.compat.v1.Session`.

    Args:
        root_dir: Forwarded to `get_train_eval_dir` when `train_eval_dir` is
            None.
        experiment_name: Forwarded to `get_train_eval_dir` when
            `train_eval_dir` is None.
        train_eval_dir: Directory under which the 'train' and 'eval'
            sub-directories are created. Derived via `get_train_eval_dir`
            when None.
        universe, env_name, domain_name, task_name: Forwarded to
            `load_environments` (and to `get_train_eval_dir`).
        action_repeat: Forwarded to `load_environments`; also multiplies the
            environment control timestep used to derive the video fps.
        num_iterations: The loop runs over iteration indices from the
            restored global step up to
            `initial_model_train_steps + num_iterations`.
        actor_fc_layers: `fc_layer_params` of the actor network.
        critic_obs_fc_layers, critic_action_fc_layers, critic_joint_fc_layers:
            Layer parameters of the critic network.
        model_network_ctor: Callable invoked as
            `model_network_ctor(observation_spec, action_spec)`.
        critic_input, actor_input: '__'-separated lists of input names. Each
            name must be one of 'state', 'latent', 'feature',
            'sequence_feature' or 'sequence_action_feature'.
        compressor_descriptor: One of 'model', 'd4pg',
            'preprocessor_<filters>_<n_layers>', 'compressor_<filters>' or
            'softlearning_<filters>_<n_layers>'.
        initial_collect_steps: Number of steps collected with a random policy
            when starting from global step 0.
        collect_steps_per_iteration: Steps collected per loop iteration.
        replay_buffer_capacity: `max_length` of the replay buffer.
        target_update_tau, target_update_period: Forwarded to `SlacAgent`.
        train_steps_per_iteration: Number of agent train steps run per loop
            iteration.
        model_train_steps_per_iteration: Must be 0 or equal to
            `train_steps_per_iteration`; a non-zero value marks the model as
            trainable.
        initial_model_train_steps: Number of leading iterations that run only
            the model train op.
        batch_size: Sample batch size and final batch size of the dataset.
        model_batch_size, sequence_length: Forwarded to `SlacAgent`;
            `sequence_length + 1` steps are sampled per trajectory.
        actor_learning_rate, critic_learning_rate, alpha_learning_rate,
            model_learning_rate: Learning rates of the four Adam optimizers.
        td_errors_loss_fn, gamma, reward_scale_factor, gradient_clipping:
            Forwarded to `SlacAgent`.
        num_eval_episodes: Episodes per evaluation, also the eval metric
            buffer size.
        eval_interval: Evaluate whenever the global step is a multiple of
            this value; a falsy value disables evaluation.
        num_images_per_summary: Forwarded to `SlacAgent` and used as the
            number of evaluation episodes to render.
        train_checkpoint_interval, policy_checkpoint_interval,
            rb_checkpoint_interval: Checkpoint periods in global steps; a
            falsy value disables the respective checkpoint.
        log_interval: Logging period in global steps; falsy disables logging.
        summary_interval: Training summaries are recorded when the global
            step is a multiple of this value.
        summaries_flush_secs: Flush period of both summary writers.
        debug_summaries, summarize_grads_and_vars: Forwarded to `SlacAgent`.
        gpu_allow_growth: If true, enable memory growth on every GPU.
        gpu_memory_limit: If truthy, memory limit applied to every GPU via a
            virtual device configuration.

    Raises:
        NotImplementedError: For an unsupported
            `model_train_steps_per_iteration`, `compressor_descriptor`, or
            actor/critic input name.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    if train_eval_dir is None:
        train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                            domain_name, task_name,
                                            experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_greedy_metrics = [
        py_metrics.AverageReturnMetric(name='AverageReturnEvalGreedyPolicy',
                                       buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(
            name='AverageEpisodeLengthEvalGreedyPolicy',
            buffer_size=num_eval_episodes),
    ]
    eval_summary_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        trainable_model = model_train_steps_per_iteration != 0
        state_only = (actor_input == 'state' and critic_input == 'state'
                      and not trainable_model
                      and initial_model_train_steps == 0)
        # Save time from unnecessarily rendering observations.
        observations_whitelist = ['state'] if state_only else None
        py_env, eval_py_env = load_environments(
            universe,
            env_name=env_name,
            domain_name=domain_name,
            task_name=task_name,
            observations_whitelist=observations_whitelist,
            action_repeat=action_repeat)
        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)
        original_control_timestep = get_control_timestep(eval_py_env)
        control_timestep = original_control_timestep * float(action_repeat)
        # Observation videos advance one frame per (repeated) action, rendered
        # videos one frame per original control step.
        fps = int(np.round(1.0 / control_timestep))
        render_fps = int(np.round(1.0 / original_control_timestep))

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        if model_train_steps_per_iteration not in (0,
                                                   train_steps_per_iteration):
            raise NotImplementedError
        model_net = model_network_ctor(observation_spec, action_spec)
        compressor_net = _build_compressor_net(compressor_descriptor,
                                               model_net)

        actor_state_size = _policy_input_size(actor_input, observation_spec,
                                              action_spec, model_net,
                                              compressor_net, sequence_length)
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        critic_state_size = _policy_input_size(critic_input, observation_spec,
                                               action_spec, model_net,
                                               compressor_net,
                                               sequence_length)
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec, action_spec, fc_layer_params=actor_fc_layers)
        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = slac_agent.SlacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            sequence_length=sequence_length,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            trainable_model=trainable_model,
            critic_input=critic_input,
            actor_input=actor_input,
            model_batch_size=model_batch_size,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)
        eval_greedy_py_policy = py_tf_policy.PyTFPolicy(
            greedy_policy.GreedyPolicy(tf_agent.policy))

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageReturnMetric(buffer_size=1)),
            tf_py_metric.TFPyMetric(
                py_metrics.AverageEpisodeLengthMetric(buffer_size=1)),
        ]

        collect_policy = tf_agent.collect_policy
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        initial_policy_state = initial_collect_policy.get_initial_state(1)
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run(
                policy_state=initial_policy_state)

        policy_state = collect_policy.get_initial_state(1)
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run(
                policy_state=policy_state)

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep only sequences whose second-to-last step is not a boundary.
            return ~trajectories.is_boundary()[-2]

        # Sample batches, split them back into single sequences so they can be
        # filtered one by one, then re-batch to a fixed size.
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=sequence_length +
            1).unbatch().filter(_filter_invalid_transition).batch(
                batch_size, drop_remainder=True).prefetch(3)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()

        train_op = tf_agent.train(trajectories)
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if initial_model_train_steps:
            with tf.name_scope('initial'):
                model_train_op = tf_agent.train_model(trajectories)
                # Every v2 summary op registered so far except the train
                # metric summaries collected above.
                model_summary_ops = []
                for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
                    if summary_op not in summary_ops:
                        model_summary_ops.append(summary_op)

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics + eval_greedy_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])
            if eval_interval:
                # A single rank-5 uint8 placeholder feeds all four video
                # summaries.
                eval_images_ph = tf.compat.v1.placeholder(dtype=tf.uint8,
                                                          shape=[None] * 5)
                eval_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalPolicy', eval_images_ph, 1, fps)
                eval_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalPolicy', eval_images_ph, 1, render_fps)
                eval_greedy_images_summary = gif_utils.gif_summary_v2(
                    'ObservationVideoEvalGreedyPolicy', eval_images_ph, 1, fps)
                eval_greedy_render_images_summary = gif_utils.gif_summary_v2(
                    'VideoEvalGreedyPolicy', eval_images_ph, 1, render_fps)

        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=2)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step,
            max_to_keep=2)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:

            def run_eval():
                """Evaluate the sampled and the greedy policy, then flush."""
                # Only called when `eval_interval` is truthy, which is also
                # the condition under which the image summaries exist.
                eval_setups = (
                    (eval_metrics, eval_py_policy,
                     eval_render_images_summary, eval_images_summary),
                    (eval_greedy_metrics, eval_greedy_py_policy,
                     eval_greedy_render_images_summary,
                     eval_greedy_images_summary),
                )
                for (setup_metrics, setup_policy, setup_render_summary,
                     setup_images_summary) in eval_setups:
                    compute_summaries(
                        setup_metrics,
                        eval_py_env,
                        setup_policy,
                        num_episodes=num_eval_episodes,
                        num_episodes_to_render=num_images_per_summary,
                        images_ph=eval_images_ph,
                        render_images_summary=setup_render_summary,
                        images_summary=setup_images_summary)
                sess.run(eval_summary_flush_op)

            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)

            global_step_val = sess.run(global_step)

            if global_step_val == 0:
                if eval_interval:
                    # Initial eval of randomly initialized policy
                    run_eval()

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            policy_state_val = sess.run(policy_state)
            collect_call = sess.make_callable(collect_op,
                                              feed_list=[policy_state])
            train_step_call = sess.make_callable([train_op, summary_ops])
            if initial_model_train_steps:
                model_train_step_call = sess.make_callable(
                    [model_train_op, model_summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # steps_per_second summary should always be recorded since it's only called every log_interval steps
            with tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            for iteration in range(global_step_val,
                                   initial_model_train_steps + num_iterations):
                start_time = time.time()
                if iteration < initial_model_train_steps:
                    # Model-only phase: no collection, no policy training.
                    total_loss_val, _ = model_train_step_call()
                else:
                    _, policy_state_val = collect_call(policy_state_val)
                    for _ in range(train_steps_per_iteration):
                        total_loss_val, _ = train_step_call()

                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if log_interval and global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss_val.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if (train_checkpoint_interval
                        and global_step_val % train_checkpoint_interval == 0):
                    train_checkpointer.save(global_step=global_step_val)

                # Evaluation and policy/replay-buffer checkpoints are skipped
                # during the model-only phase.
                if iteration < initial_model_train_steps:
                    continue

                if eval_interval and global_step_val % eval_interval == 0:
                    run_eval()

                if (policy_checkpoint_interval
                        and global_step_val % policy_checkpoint_interval == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                if (rb_checkpoint_interval
                        and global_step_val % rb_checkpoint_interval == 0):
                    rb_checkpointer.save(global_step=global_step_val)


def _build_compressor_net(compressor_descriptor, model_net):
    """Return the compressor network selected by `compressor_descriptor`."""
    if compressor_descriptor == 'model':
        return model_net.compressor

    # Raw-string patterns: '\d' in a plain literal is an invalid escape.
    match = re.match(r'preprocessor_(\d+)_(\d+)', compressor_descriptor)
    if match:
        filters, n_layers = (int(group) for group in match.groups())
        return compressor_network.Preprocessor(filters, n_layers=n_layers)

    match = re.match(r'compressor_(\d+)', compressor_descriptor)
    if match:
        filters, = match.groups()
        return compressor_network.Compressor(int(filters))

    match = re.match(r'softlearning_(\d+)_(\d+)', compressor_descriptor)
    if match:
        filters, n_layers = (int(group) for group in match.groups())
        return compressor_network.SoftlearningPreprocessor(
            filters, n_layers=n_layers)

    if compressor_descriptor == 'd4pg':
        return compressor_network.D4pgPreprocessor()

    raise NotImplementedError(compressor_descriptor)


def _policy_input_size(input_descriptor, observation_spec, action_spec,
                       model_net, compressor_net, sequence_length):
    """Return the summed size of the '__'-separated inputs of a network."""
    total_size = 0
    for input_name in input_descriptor.split('__'):
        if input_name == 'state':
            state_size, = observation_spec['state'].shape
            total_size += state_size
        elif input_name == 'latent':
            total_size += model_net.state_size
        elif input_name == 'feature':
            total_size += compressor_net.feature_size
        elif input_name in ('sequence_feature', 'sequence_action_feature'):
            total_size += compressor_net.feature_size * sequence_length
            if input_name == 'sequence_action_feature':
                # One action between each pair of consecutive features.
                total_size += tf.compat.dimension_value(
                    action_spec.shape[0]) * (sequence_length - 1)
        else:
            raise NotImplementedError(input_name)
    return total_size
Пример #13
0
    def __init__(self,
                 env,
                 algorithm,
                 observation_transformer: Callable = None,
                 observers=None,
                 use_rollout_state=False,
                 metrics=None,
                 training=True,
                 greedy_predict=False,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 summarize_action_distributions=False,
                 train_step_counter=None):
        """Create a PolicyDriver.

        Args:
            env (TFEnvironment): A TFEnvironment
            algorithm (OnPolicyAlgorithm): The algorithm for training
            observation_transformer (Callable): An optional callable stored
                on the driver as ``_observation_transformer``; how it is
                applied is decided outside this constructor.
            observers (list[Callable]): An optional list of observers that are
                updated after every step in the environment. Each observer is a
                callable(time_step.Trajectory). None means no extra observers.
            use_rollout_state (bool): Include the RNN state for the experiences
                used for off-policy training
            metrics (list[TFStepMetric]): An optional list of metrics that is
                appended to the four standard metrics. None means no extra
                metrics.
            training (bool): True for training, false for evaluating
            greedy_predict (bool): use greedy action for evaluation (i.e.
                training==False).
            debug_summaries (bool): A bool to gather debug summaries.
            summarize_grads_and_vars (bool): If True, gradient and network
                variable summaries will be written during training.
            summarize_action_distributions (bool): If True, generate summaries
                for the action distributions.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
        """
        # None replaces the former mutable `[]` defaults; copying also keeps
        # the caller's sequences untouched and accepts tuples.
        observers = [] if observers is None else list(observers)
        metrics = [] if metrics is None else list(metrics)

        metric_buf_size = max(10, env.batch_size)
        standard_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=metric_buf_size),
            tf_metrics.AverageEpisodeLengthMetric(buffer_size=metric_buf_size),
        ]
        self._metrics = standard_metrics + metrics
        self._exp_observers = []
        self._use_rollout_state = use_rollout_state

        # The metrics are registered as observers too; no policy is handed to
        # the base driver (None).
        super(PolicyDriver, self).__init__(env, None,
                                           observers + self._metrics)

        self._algorithm = algorithm
        self._training = training
        self._greedy_predict = greedy_predict
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._summarize_action_distributions = summarize_action_distributions
        self._observation_transformer = observation_transformer
        self._train_step_counter = common.get_global_counter(
            train_step_counter)
        self._proc = psutil.Process(os.getpid())
        # The policy state layout depends on whether the driver trains or
        # only predicts.
        if training:
            self._policy_state_spec = algorithm.train_state_spec
        else:
            self._policy_state_spec = algorithm.predict_state_spec
        self._initial_state = self.get_initial_policy_state()
Пример #14
0
class T48GymTensorflowContext:
    """Training context for the T48 gym environment with a DDQN agent.

    Bundles the train/eval environments, Q-network, agent, replay buffer,
    collect driver, checkpointers and policy saver, and restores any
    existing checkpoints from `train_dir` on construction.
    """

    max_episode_steps = 20000  # @param {type:"integer"}
    replay_buffer_max_length = 100000  # @param {type:"integer"}
    learning_rate = 1e-3  # @param {type:"number"}
    train_dir = "/Volumes/SECONDARY/t48_gym_tensorflow"
    # Created once at class-definition time, so every instance shares the
    # same metric objects.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    def __init__(self):
        """Build environments, agent, replay pipeline and checkpointers."""
        config = T48GymTensorflowContext
        policy_dir = os.path.join(config.train_dir, 'policy')

        self._train_py_env = suite_gym.load(
            T48GymEnv.GYM_ENV_NAME,
            max_episode_steps=config.max_episode_steps)
        self._eval_py_env = suite_gym.load(
            T48GymEnv.GYM_ENV_NAME,
            max_episode_steps=config.max_episode_steps)
        self._train_env = tf_py_environment.TFPyEnvironment(self._train_py_env)
        self._eval_env = tf_py_environment.TFPyEnvironment(self._eval_py_env)

        self._global_step = tf.compat.v1.train.get_or_create_global_step()

        self._q_net = q_network.QNetwork(
            self._train_env.observation_spec(),
            self._train_env.action_spec(),
            fc_layer_params=(100,))
        # NOTE(review): epsilon_greedy=0.0 makes the collect policy take no
        # random exploratory actions - confirm this is intended.
        self._agent = dqn_agent.DdqnAgent(
            self._train_env.time_step_spec(),
            self._train_env.action_spec(),
            q_network=self._q_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=config.learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=self._global_step,
            epsilon_greedy=0.0)
        # Initialized exactly once; the original repeated this call after
        # building the dataset without any variable change in between.
        self._agent.initialize()
        self._agent.train = common.function(self._agent.train)

        self._replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self._agent.collect_data_spec,
            batch_size=self._train_env.batch_size,
            max_length=config.replay_buffer_max_length)
        # Two consecutive steps per sample form one transition for training.
        self._dataset = self._replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=self._train_env.batch_size,
            num_steps=2).prefetch(3)

        self._iterator = iter(self._dataset)

        self._RANDOM_POLICY = random_tf_policy.RandomTFPolicy(
            self._train_env.time_step_spec(),
            self._train_env.action_spec())
        self._collect_policy = self._agent.collect_policy
        self._eval_policy = self._agent.policy

        self._collect_driver = dynamic_step_driver.DynamicStepDriver(
            self._train_env,
            self._collect_policy,
            observers=[self._replay_buffer.add_batch] + config.train_metrics,
            num_steps=2)

        self._train_checkpointer = common.Checkpointer(
            ckpt_dir=config.train_dir,
            global_step=self._global_step,
            agent=self._agent,
            metrics=metric_utils.MetricsGroup(config.train_metrics,
                                              'train_metrics'))
        self._policy_checkpointer = common.Checkpointer(
            ckpt_dir=policy_dir,
            global_step=self._global_step,
            policy=self._eval_policy)
        self._rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(config.train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=self._replay_buffer)

        self._tf_policy_saver = policy_saver.PolicySaver(self._agent.policy)

        self._train_checkpointer.initialize_or_restore()
        self._policy_checkpointer.initialize_or_restore()
        self._rb_checkpointer.initialize_or_restore()

    def run_episodes(self, episode_count):
        """Collect and train for `episode_count` episodes.

        Each loop pass runs the collect driver (2 environment steps), trains
        on one sampled batch, exports the policy and renders the training
        environment. All three checkpointers are saved after every episode.

        Args:
            episode_count: Number of episodes to run.
        """
        policy_dir = os.path.join(T48GymTensorflowContext.train_dir, 'policy')

        for episode_num in range(episode_count):

            print("StartingEpisode:", episode_num)
            time_step = self._train_env.reset()
            policy_state = self._collect_policy.get_initial_state(
                self._train_env.batch_size)
            print("ResetEnvironment")

            while not time_step.is_last():
                time_step, policy_state = self._collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )
                print("TimeStep:", str(time_step))
                print("PolicyState:", policy_state)

                experience, unused_info = next(self._iterator)
                print("Experience:", experience)
                print("UnusedInfo:", unused_info)

                train_loss = self._agent.train(experience)
                print("TrainLoss:", str(train_loss))
                print()

                # NOTE(review): the policy is exported on every loop pass,
                # into the same directory the policy checkpointer uses -
                # confirm this frequency and location are intended.
                self._tf_policy_saver.save(policy_dir)
                self._train_py_env.render()

            print("EndOfEpisode:", episode_num)
            print()
            # Read the step once so all three checkpoints share one value.
            step = self._global_step.numpy()
            self._train_checkpointer.save(global_step=step)
            self._rb_checkpointer.save(global_step=step)
            self._policy_checkpointer.save(global_step=step)
Пример #15
0
def train_eval(

        ##############################################
        # types of params:
        # 0: specific to algorithm (gin file 0)
        # 1: specific to environment (gin file 1)
        # 2: specific to experiment (gin file 2 + command line)

        # Note: there are other important params
        # in eg ModelDistributionNetwork that the gin files specify
        # like sparse vs dense rewards, latent dimensions, etc.
        ##############################################

    # basic params for running/logging experiment
    root_dir,  # 2
        experiment_name,  # 2
        num_iterations=int(1e7),  # 2
        seed=1,  # 2
        gpu_allow_growth=False,  # 2
        gpu_memory_limit=None,  # 2
        verbose=True,  # 2
        policy_checkpoint_freq_in_iter=100,  # policies needed for future eval                             # 2
        train_checkpoint_freq_in_iter=0,  #default don't save                                              # 2
        rb_checkpoint_freq_in_iter=0,  #default don't save                                                 # 2
        logging_freq_in_iter=10,  # printing to terminal                                                   # 2
        summary_freq_in_iter=10,  # saving to tb                                                           # 2
        num_images_per_summary=2,  # 2
        summaries_flush_secs=10,  # 2
        max_episode_len_override=None,  # 2
        num_trials_to_render=1,  # 2

        # environment, action mode, etc.
    env_name='HalfCheetah-v2',  # 1
        action_repeat=1,  # 1
        action_mode='joint_position',  # joint_position or joint_delta_position                           # 1
        double_camera=False,  # camera input                                                               # 1
        universe='gym',  # default
        task_reward_dim=1,  # default

        # dims for all networks
    actor_fc_layers=(256, 256),  # 1
        critic_obs_fc_layers=None,  # 1
        critic_action_fc_layers=None,  # 1
        critic_joint_fc_layers=(256, 256),  # 1
        num_repeat_when_concatenate=None,  # 1

        # networks
    critic_input='state',  # 0
        actor_input='state',  # 0

        # specifying tasks and eval
    episodes_per_trial=1,  # 2
        num_train_tasks=10,  # 2
        num_eval_tasks=10,  # 2
        num_eval_trials=10,  # 2
        eval_interval=10,  # 2
        eval_on_holdout_tasks=True,  # 2

        # data collection/buffer
    init_collect_trials_per_task=None,  # 2
        collect_trials_per_task=None,  # 2
        num_tasks_to_collect_per_iter=5,  # 2
        replay_buffer_capacity=int(1e5),  # 2

        # training
    init_model_train_ratio=0.8,  # 2
        model_train_ratio=1,  # 2
        model_train_freq=1,  # 2
        ac_train_ratio=1,  # 2
        ac_train_freq=1,  # 2
        num_tasks_per_train=5,  # 2
        train_trials_per_task=5,  # 2
        model_bs_in_steps=256,  # 2
        ac_bs_in_steps=128,  # 2

        # default AC learning rates, gamma, etc.
    target_update_tau=0.005,
        target_update_period=1,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        log_image_strips=False,
        stop_model_training=1E10,
        eval_only=False,  # evaluate checkpoints ONLY
        log_image_observations=False,
        load_offline_data=False,  # whether to use offline data
        offline_data_dir=None,  # replay buffer's dir
        offline_episode_len=None,  # episode len of episodes stored in rb
        offline_ratio=0,  # ratio of data that is from offline buffer
):

    g = tf.Graph()

    # register all gym envs
    max_steps_dict = {
        "HalfCheetahVel-v0": 50,
        "SawyerReach-v0": 40,
        "SawyerReachMT-v0": 40,
        "SawyerPeg-v0": 40,
        "SawyerPegMT-v0": 40,
        "SawyerPegMT4box-v0": 40,
        "SawyerShelfMT-v0": 40,
        "SawyerKitchenMT-v0": 40,
        "SawyerShelfMT-v2": 40,
        "SawyerButtons-v0": 40,
    }
    if max_episode_len_override:
        max_steps_dict[env_name] = max_episode_len_override
    register_all_gym_envs(max_steps_dict)

    # set max_episode_len based on our env
    max_episode_len = max_steps_dict[env_name]

    ######################################################
    # Calculate additional params
    ######################################################

    # convert to number of steps
    env_steps_per_trial = episodes_per_trial * max_episode_len
    real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1)
    env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial
    per_task_collect_steps = collect_trials_per_task * env_steps_per_trial

    # initial collect + train
    init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial
    init_model_train_steps = int(init_collect_env_steps *
                                 init_model_train_ratio)

    # collect + train
    collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps
    model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio)
    ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio)

    # other
    global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter
    sample_episodes_per_task = train_trials_per_task * episodes_per_trial  # number of episodes to sample from each replay
    model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial

    # assertions that make sure parameters make sense
    assert model_bs_in_trials > 0, "model batch size need to be at least as big as one full real trial"
    assert num_tasks_to_collect_per_iter <= num_train_tasks, "when sampling replace=False"
    assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one batch model train"
    assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one batch ac train"

    ######################################################
    # Print a summary of params
    ######################################################
    MELD_summary_string = f"""\n\n\n
==============================================================
==============================================================
  \n
  MELD algorithm summary:

  * each trial consists of {episodes_per_trial} episodes
  * episode length: {max_episode_len}, trial length: {env_steps_per_trial}
  * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks}
  * environment: {env_name}
  
  For each of {num_train_tasks} tasks:
    Do {init_collect_trials_per_task} trials of initial collect
  (total {init_collect_env_steps} env steps)
  
  Do {init_model_train_steps} steps of initial model training
    
  For i in range(inf):
    For each of {num_tasks_to_collect_per_iter} randomly selected tasks:
      Do {collect_trials_per_task} trials of collect
    (which is {collect_trials_per_task*env_steps_per_trial} env steps per task)
    (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration)
    
    if i % model_train_freq(={model_train_freq}):
      Do {model_train_steps_per_iter} steps of model training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {model_bs_in_trials} trials, train model on whole trials.
    
    if i % ac_train_freq(={ac_train_freq}):
      Do {ac_train_steps_per_iter} steps of ac training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {ac_bs_in_steps} transitions, not including between trial transitions, 
          to train ac.
  
  
  * Other important params:
  Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps
  Average evaluation across {num_eval_trials} trials
  Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps
  Checkpoint:
   - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
   - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints
   - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
    
  \n
=============================================================
=============================================================
  """

    print(MELD_summary_string)
    time.sleep(1)

    ######################################################
    # Seed + name + GPU configs + directories for saving
    ######################################################
    np.random.seed(int(seed))
    experiment_name += "_seed" + str(seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                        experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')
    eval_dir_2 = os.path.join(train_eval_dir, 'eval2')

    ######################################################
    # Train and Eval Summary Writers
    ######################################################
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_summary_flush_op = eval_summary_writer.flush()

    eval_logger = Logger(eval_dir_2)

    ######################################################
    # Train and Eval metrics
    ######################################################
    eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len  # across all eval trials in each evaluation
    eval_metrics = []
    for position in range(
            episodes_per_trial
    ):  # have metrics for each episode position, to track whether it is learning
        eval_metrics_pos = [
            py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' +
                                           str(position),
                                           buffer_size=eval_buffer_size),
            py_metrics.AverageEpisodeLengthMetric(
                name='f_AverageEpisodeLengthEval_' + str(position),
                buffer_size=eval_buffer_size),
            custom_metrics.AverageScoreMetric(
                name="d_AverageScoreMetricEval_" + str(position),
                buffer_size=eval_buffer_size),
        ]
        eval_metrics.extend(eval_metrics_pos)

    train_buffer_size = num_train_tasks * episodes_per_trial
    train_metrics = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'),
        tf_metrics.EnvironmentSteps(name='EnvironmentSteps'),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageReturnMetric(name="a_AverageReturnTrain",
                                           buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageEpisodeLengthMetric(
                name="e_AverageEpisodeLengthTrain",
                buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain",
                                              buffer_size=train_buffer_size)),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step(
    )  # will be use to record number of model grad steps + ac grad steps + env_step

    log_cond = get_log_condition_tensor(
        global_step, init_collect_trials_per_task, env_steps_per_trial,
        num_train_tasks, init_model_train_steps, collect_trials_per_task,
        num_tasks_to_collect_per_iter, model_train_steps_per_iter,
        ac_train_steps_per_iter, summary_freq_in_iter, eval_interval)

    with tf.compat.v2.summary.record_if(log_cond):

        ######################################################
        # Create env
        ######################################################
        py_env, eval_py_env, train_tasks, eval_tasks = load_environments(
            universe,
            action_mode,
            env_name=env_name,
            observations_whitelist=['state', 'pixels', "env_info"],
            action_repeat=action_repeat,
            num_train_tasks=num_train_tasks,
            num_eval_tasks=num_eval_tasks,
            eval_on_holdout_tasks=eval_on_holdout_tasks,
            return_multiple_tasks=True,
        )
        override_reward_func = None
        if load_offline_data:
            py_env.set_task_dict(train_tasks)
            override_reward_func = py_env.override_reward_func

        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)

        # Get data specs from env
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        original_control_timestep = get_control_timestep(eval_py_env)

        # fps
        control_timestep = original_control_timestep * float(action_repeat)
        render_fps = int(np.round(1.0 / original_control_timestep))

        ######################################################
        # Latent variable model
        ######################################################
        if verbose:
            print("-- start constructing model networks --")

        model_net = ModelDistributionNetwork(
            double_camera=double_camera,
            observation_spec=observation_spec,
            num_repeat_when_concatenate=num_repeat_when_concatenate,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            max_episode_len=max_episode_len
        )  # rest of arguments provided via gin

        if verbose:
            print("-- finish constructing AC networks --")

        ######################################################
        # Compressor Network for Actor/Critic
        # The model's compressor is also used by the AC
        # compressor function: images --> features
        ######################################################

        compressor_net = model_net.compressor

        ######################################################
        # Specs for Actor and Critic
        ######################################################
        if actor_input == 'state':
            actor_state_size = observation_spec['state'].shape[0]
        elif actor_input == 'latentSample':
            actor_state_size = model_net.state_size
        elif actor_input == "latentDistribution":
            actor_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        if critic_input == 'state':
            critic_state_size = observation_spec['state'].shape[0]
        elif critic_input == 'latentSample':
            critic_state_size = model_net.state_size
        elif critic_input == "latentDistribution":
            critic_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of gaussian, of two latents
        else:
            raise NotImplementedError
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        ######################################################
        # Actor and Critic Networks
        ######################################################
        if verbose:
            print("-- start constructing Actor and Critic networks --")

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
        )

        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        if verbose:
            print("-- finish constructing AC networks --")
            print("-- start constructing agent --")

        ######################################################
        # Create the agent
        ######################################################

        which_posterior_overwrite = None
        which_reward_overwrite = None

        meld_agent = MeldAgent(
            # specs
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            # step counter
            train_step_counter=
            global_step,  # will count number of model training steps
            # networks
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            # optimizers
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            # target update
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            # inputs
            critic_input=critic_input,
            actor_input=actor_input,
            # bs stuff
            model_batch_size=model_bs_in_steps,
            ac_batch_size=ac_bs_in_steps,
            # other
            num_tasks_per_train=num_tasks_per_train,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            # offline data
            override_reward_func=override_reward_func,
            offline_ratio=offline_ratio,
        )

        if verbose:
            print("-- finish constructing agent --")

        ######################################################
        # Replay buffers + observers to add data to them
        ######################################################
        replay_buffers = []
        replay_observers = []
        for _ in range(num_train_tasks):
            replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer(
                meld_agent.collect_policy.
                trajectory_spec,  # spec of each point stored in here (i.e. Trajectory)
                capacity=replay_buffer_capacity,
                completed_only=
                True,  # in as_dataset, if num_steps is None, this means return full episodes
                # device='GPU:0', # gpu not supported for some reason
                begin_episode_fn=lambda traj: traj.is_first()[
                    0],  # first step of seq we add should be is_first
                end_episode_fn=lambda traj: traj.is_last()[
                    0],  # last step of seq we add should be is_last
                dataset_drop_remainder=
                True,  #`as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items
            )
            replay_buffer = StatefulEpisodicReplayBuffer(
                replay_buffer_episodic)  # adding num_episodes here is bad
            replay_buffers.append(replay_buffer)
            replay_observers.append([replay_buffer.add_sequence])

        if load_offline_data:
            # for each task, has a separate replay buffer for relabeled data
            replay_buffers_withRelabel = []
            replay_observers_withRelabel = []
            for _ in range(num_train_tasks):
                replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer(
                    meld_agent.collect_policy.
                    trajectory_spec,  # spec of each point stored in here (i.e. Trajectory)
                    capacity=replay_buffer_capacity,
                    completed_only=
                    True,  # in as_dataset, if num_steps is None, this means return full episodes
                    # device='GPU:0', # gpu not supported for some reason
                    begin_episode_fn=lambda traj: traj.is_first()[
                        0],  # first step of seq we add should be is_first
                    end_episode_fn=lambda traj: traj.is_last()[
                        0],  # last step of seq we add should be is_last
                    dataset_drop_remainder=True,
                    # `as_dataset` makes the final batch be dropped if it does not contain exactly `sample_batch_size` items
                )
                replay_buffer_withRelabel = StatefulEpisodicReplayBuffer(
                    replay_buffer_episodic_withRelabel
                )  # adding num_episodes here is bad
                replay_buffers_withRelabel.append(replay_buffer_withRelabel)
                replay_observers_withRelabel.append(
                    [replay_buffer_withRelabel.add_sequence])

        if verbose:
            print("-- finish constructing replay buffers --")
            print("-- start constructing policies and collect ops --")

        ######################################################
        # Policies
        #####################################################

        # init collect policy (random)
        init_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        # eval
        eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy)

        ################################################################################
        # Collect ops : use policies to get data + have the observer put data into corresponding RB
        ################################################################################

        #init collection (with random policy)
        init_collect_ops = []
        for task_idx in range(num_train_tasks):
            # put init data into the rb + track with the train metric
            observers = replay_observers[task_idx] + train_metrics

            # initial collect op
            init_collect_op = DynamicTrialDriver(
                tf_env,
                init_collect_policy,
                num_trials_to_collect=init_collect_trials_per_task,
                observers=observers,
                episodes_per_trial=
                episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            init_collect_ops.append(init_collect_op)

        # data collection for training (with collect policy)
        collect_ops = []
        for task_idx in range(num_train_tasks):
            collect_op = DynamicTrialDriver(
                tf_env,
                meld_agent.collect_policy,
                num_trials_to_collect=collect_trials_per_task,
                observers=replay_observers[task_idx] +
                train_metrics,  # put data into 1st RB + track with 1st pol metrics
                episodes_per_trial=
                episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            collect_ops.append(collect_op)

        if verbose:
            print("-- finish constructing policies and collect ops --")
            print("-- start constructing replay buffer->training pipeline --")

        ######################################################
        # replay buffer --> dataset --> iterate to get trajecs for training
        ######################################################

        # get some data from all task replay buffers (even though won't actually train on all of them)
        dataset_iterators = []
        all_tasks_trajectories_fromdense = []
        for task_idx in range(num_train_tasks):
            dataset = replay_buffers[task_idx].as_dataset(
                sample_batch_size=
                sample_episodes_per_task,  # number of episodes to sample
                num_steps=max_episode_len + 1
            ).prefetch(
                3
            )  # +1 to include the last state: a trajectory with n transition has n+1 states
            # iterator to go through the data
            dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                dataset)
            dataset_iterators.append(dataset_iterator)
            # get sample_episodes_per_task sequences, each of length num_steps
            trajectories_task_i, _ = dataset_iterator.get_next()
            all_tasks_trajectories_fromdense.append(trajectories_task_i)

        if load_offline_data:
            # have separate dataset for relabel data
            dataset_iterators_withRelabel = []
            all_tasks_trajectories_fromdense_withRelabel = []
            for task_idx in range(num_train_tasks):
                dataset = replay_buffers_withRelabel[task_idx].as_dataset(
                    sample_batch_size=
                    sample_episodes_per_task,  # number of episodes to sample
                    num_steps=offline_episode_len + 1
                ).prefetch(
                    3
                )  # +1 to include the last state: a trajectory with n transition has n+1 states
                # iterator to go through the data
                dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                    dataset)
                dataset_iterators_withRelabel.append(dataset_iterator)
                # get sample_episodes_per_task sequences, each of length num_steps
                trajectories_task_i, _ = dataset_iterator.get_next()
                all_tasks_trajectories_fromdense_withRelabel.append(
                    trajectories_task_i)

        if verbose:
            print("-- finish constructing replay buffer->training pipeline --")
            print("-- start constructing model and AC training ops --")

        ######################################
        # Decoding latent samples into rewards
        ######################################

        latent_samples_1_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent1_size))
        latent_samples_2_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent2_size))
        decode_rews_op = meld_agent._model_network.decode_latents_into_reward(
            latent_samples_1_ph, latent_samples_2_ph)

        ######################################
        # Model/Actor/Critic train + summary ops
        ######################################

        # train AC on data from replay buffer
        if load_offline_data:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if verbose:
            print("-- finish constructing AC training ops --")

        ############################
        # Model train + summary ops
        ############################

        # train model on data from replay buffer
        if load_offline_data:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense)

        model_summary_ops, model_summary_ops_2 = [], []
        for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
            if summary_op not in summary_ops:
                model_summary_ops.append(summary_op)

        if verbose:
            print("-- finish constructing model training ops --")
            print("-- start constructing checkpointers --")

        ########################
        # Eval metrics
        ########################

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])

        ########################
        # Create savers
        ########################
        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        ########################
        # Create checkpointers
        ########################

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=meld_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=1)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=meld_agent.policy,
            global_step=global_step,
            max_to_keep=99999999999
        )  # keep many policy checkpoints, in case of future eval
        rb_checkpointers = []
        for buffer_idx in range(len(replay_buffers)):
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffers/',
                                      "task" + str(buffer_idx)),
                max_to_keep=1,
                replay_buffer=replay_buffers[buffer_idx])
            rb_checkpointers.append(rb_checkpointer)

        if load_offline_data:  # for LOADING data not for checkpointing. No new data going in anyways
            rb_checkpointers_withRelabel = []
            for buffer_idx in range(len(replay_buffers_withRelabel)):
                ckpt_dir = os.path.join(offline_data_dir,
                                        "task" + str(buffer_idx))
                rb_checkpointer = common.Checkpointer(
                    ckpt_dir=ckpt_dir,
                    max_to_keep=99999999999,
                    replay_buffer=replay_buffers_withRelabel[buffer_idx])
                rb_checkpointers_withRelabel.append(rb_checkpointer)
            # Notice: these replay buffers need to follow the same sequence of tasks as the current one

        if verbose:
            print("-- finish constructing checkpointers --")
            print("-- start main training loop --")

        with tf.compat.v1.Session() as sess:

            ########################
            # Initialize
            ########################

            if eval_only:
                sess.run(eval_summary_writer.init())
                load_eval_log(
                    train_eval_dir=train_eval_dir,
                    meld_agent=meld_agent,
                    global_step=global_step,
                    sess=sess,
                    eval_metrics=eval_metrics,
                    eval_py_env=eval_py_env,
                    eval_py_policy=eval_py_policy,
                    num_eval_trials=num_eval_trials,
                    max_episode_len=max_episode_len,
                    episodes_per_trial=episodes_per_trial,
                    log_image_strips=log_image_strips,
                    num_trials_to_render=num_trials_to_render,
                    train_tasks=
                    train_tasks,  # in case want to eval on a train task
                    eval_tasks=eval_tasks,
                    model_net=model_net,
                    render_fps=render_fps,
                    decode_rews_op=decode_rews_op,
                    latent_samples_1_ph=latent_samples_1_ph,
                    latent_samples_2_ph=latent_samples_2_ph,
                )
                return

            # Initialize checkpointing
            train_checkpointer.initialize_or_restore(sess)
            for rb_checkpointer in rb_checkpointers:
                rb_checkpointer.initialize_or_restore(sess)

            if load_offline_data:
                for rb_checkpointer in rb_checkpointers_withRelabel:
                    rb_checkpointer.initialize_or_restore(sess)

            # Initialize dataset iterators
            for dataset_iterator in dataset_iterators:
                sess.run(dataset_iterator.initializer)

            if load_offline_data:
                for dataset_iterator in dataset_iterators_withRelabel:
                    sess.run(dataset_iterator.initializer)

            # Initialize variables
            common.initialize_uninitialized_variables(sess)

            # Initialize summary writers
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Initialize savers
            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)
            # Get value of step counter
            global_step_val = sess.run(global_step)

            if verbose:
                print("====== finished initialization ======")

            ################################################################
            # If this is start of new exp (i.e., 1st step) and not continuing old exp
            # eval rand policy + do initial data collection
            ################################################################
            fresh_start = (global_step_val == 0)

            if fresh_start:

                ########################
                # Evaluate initial policy
                ########################

                if eval_interval:
                    logging.info(
                        '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks',
                        num_eval_trials)
                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=num_eval_tasks,
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_val,
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )
                    sess.run(eval_summary_flush_op)
                    logging.info(
                        'Done with evaluation of initial (random) policy.\n\n')

                ########################
                # Initial data collection
                ########################

                logging.info(
                    '\n\nGlobal step %d: Beginning init collect op with random policy. Collecting %dx {%d, %d} trials for each task',
                    global_step_val, init_collect_trials_per_task,
                    max_episode_len, episodes_per_trial)

                init_increment_global_step_op = global_step.assign_add(
                    env_steps_per_trial * init_collect_trials_per_task)

                for task_idx in range(num_train_tasks):
                    logging.info('on task %d / %d', task_idx + 1,
                                 num_train_tasks)
                    py_env.set_task_for_env(train_tasks[task_idx])
                    sess.run([
                        init_collect_ops[task_idx],
                        init_increment_global_step_op
                    ])  # incremented gs in granularity of task

                rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished init collect.\n\n')

            else:
                logging.info(
                    '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n',
                    global_step_val)

            #########################
            # Create calls
            #########################

            # [1] calls for running the policies to collect training data
            collect_calls = []
            increment_global_step_op = global_step.assign_add(
                env_steps_per_trial * collect_trials_per_task)
            for task_idx in range(num_train_tasks):
                collect_calls.append(
                    sess.make_callable(
                        [collect_ops[task_idx], increment_global_step_op]))

            # [2] call for doing a training step (A + C)
            ac_train_step_call = sess.make_callable([ac_train_op, summary_ops])

            # [3] call for doing a training step (model)
            model_train_step_call = sess.make_callable(
                [model_train_op, check_step_types, model_summary_ops])

            # [4] call for evaluating what global_step number we're on
            global_step_call = sess.make_callable(global_step)

            # reset keeping track of steps/time
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            with train_summary_writer.as_default(
            ), tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            #################################
            # init model training
            #################################
            if fresh_start:
                logging.info(
                    '\n\nPerforming %d steps of init model training, each step on %d random tasks',
                    init_model_train_steps, num_tasks_per_train)
                for i in range(init_model_train_steps):

                    temp_start = time.time()
                    if i % 100 == 0:
                        print(".... init model training ", i, "/",
                              init_model_train_steps)

                    # init model training
                    total_loss_value_model, check_step_types, _ = model_train_step_call(
                    )

                    if PRINT_TIMING:
                        print("single model train step: ",
                              time.time() - temp_start)

            if verbose:
                print("\n\n\n-- start training loop --\n")

            #################################
            # Training Loop
            #################################
            start_time = time.time()
            for iteration in range(num_iterations):

                if iteration > 0:
                    g.finalize()

                # print("\n\n\niter", iteration, sess.run(curr_iter))
                print("global step", global_step_call())

                logging.info("Iteration: %d, Global step: %d\n", iteration,
                             global_step_val)

                ####################
                # collect data
                ####################
                logging.info(
                    '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks',
                    collect_trials_per_task, max_episode_len,
                    episodes_per_trial, num_tasks_to_collect_per_iter)

                # randomly select tasks to collect this iteration
                list_of_collect_task_idxs = np.random.choice(
                    len(train_tasks),
                    num_tasks_to_collect_per_iter,
                    replace=False)
                for count, task_idx in enumerate(list_of_collect_task_idxs):
                    logging.info('on randomly selected task %d / %d',
                                 count + 1, num_tasks_to_collect_per_iter)

                    # set task for the env
                    py_env.set_task_for_env(train_tasks[task_idx])

                    # collect data with collect policy
                    _, policy_state_val = collect_calls[task_idx]()

                logging.info('Finish data collection. Global step: %d\n',
                             global_step_call())

                ####################
                # train model
                ####################
                if (iteration
                        == 0) or ((iteration % model_train_freq == 0) and
                                  (global_step_val < stop_model_training)):
                    logging.info(
                        '\n\nPerforming %d steps of model training, each on %d random tasks',
                        model_train_steps_per_iter, num_tasks_per_train)
                    for model_iter in range(model_train_steps_per_iter):
                        temp_start_2 = time.time()

                        # train model
                        total_loss_value_model, _, _ = model_train_step_call()

                        # print("is logging step", model_iter, sess.run(is_logging_step))
                        if PRINT_TIMING:
                            print("2: single model train step: ",
                                  time.time() - temp_start_2)
                    logging.info('Finish model training. Global step: %d\n',
                                 global_step_call())
                else:
                    print("SKIPPING MODEL TRAINING")

                ####################
                # train actor critic
                ####################
                if iteration % ac_train_freq == 0:
                    logging.info(
                        '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n',
                        ac_train_steps_per_iter, num_tasks_per_train)
                    for ac_iter in range(ac_train_steps_per_iter):
                        temp_start_2_ac = time.time()

                        # train ac
                        total_loss_value_ac, _ = ac_train_step_call()
                        if PRINT_TIMING:
                            print("2: single AC train step: ",
                                  time.time() - temp_start_2_ac)
                logging.info('Finish AC training. Global step: %d\n',
                             global_step_call())

                # add up time
                time_acc += time.time() - start_time

                ####################
                # logging/summaries
                ####################

                ### Eval
                if eval_interval and (iteration % eval_interval == 0):
                    logging.info(
                        '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks',
                        num_eval_trials)

                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=
                        num_trials_to_render,  # hardcoded: or gif will get too long
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_call(),
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )

                ### steps_per_second_summary
                global_step_val = global_step_call()
                if logging_freq_in_iter and (iteration % logging_freq_in_iter
                                             == 0):
                    # log step number + speed (steps/sec)
                    logging.info(
                        'step = %d, loss = %f', global_step_val,
                        total_loss_value_ac.loss + total_loss_value_model.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f env_steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})

                    # reset keeping track of steps/time
                    timed_at_step = global_step_val
                    time_acc = 0

                ### train_checkpoint
                if train_checkpoint_freq_in_iter and (
                        iteration % train_checkpoint_freq_in_iter == 0):
                    train_checkpointer.save(global_step=global_step_val)

                ### policy_checkpointer
                if policy_checkpoint_freq_in_iter and (
                        iteration % policy_checkpoint_freq_in_iter == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                ### rb_checkpointer
                if rb_checkpoint_freq_in_iter and (
                        iteration % rb_checkpoint_freq_in_iter == 0):
                    for rb_checkpointer in rb_checkpointers:
                        rb_checkpointer.save(global_step=global_step_val)
Пример #16
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=5e5,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(
            64,
            64,
        ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(6, ),
        output_fc_layer_params=(30, ),

        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=6,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=6,
        batch_size=32,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=1,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a DQN agent that uses a `GATNetwork`.

    The environment is loaded through `suite_gym.load` and receives the graph
    read from 'graph.gpickle' and the clusters read from 'clusters.pickle'
    (both relative to the current working directory) as gym kwargs. After
    training, the graph-centroid heuristic exposed by the gym environment is
    replayed and its final reward printed next to the best result seen during
    training.

    Args:
      root_dir: Output root; '~' is expanded. Summaries and checkpoints are
        written under `<root_dir>/train` and `<root_dir>/eval`.
      env_name: Gym environment id passed to `suite_gym.load`.
      num_iterations: Number of collect/train iterations. Coerced with `int()`
        because the default is written as a float.
      train_sequence_length: Sequence length sampled from the replay buffer
        (plus one). When it is not greater than 1 it is replaced by
        `n_step_update`.
      fc_layer_params, input_fc_layer_params, lstm_size,
        output_fc_layer_params: Retained for interface compatibility only; the
        Q-network is always a `GATNetwork`, so these are not used.
      initial_collect_steps: Steps collected with a random policy before
        training starts.
      collect_steps_per_iteration: Environment steps collected per iteration.
      epsilon_greedy, n_step_update, target_update_tau, target_update_period,
        gamma, reward_scale_factor, gradient_clipping, debug_summaries,
        summarize_grads_and_vars: Forwarded unchanged to `dqn_agent.DqnAgent`.
      replay_buffer_capacity: `max_length` of the uniform replay buffer.
      train_steps_per_iteration: Agent train steps per iteration.
      batch_size: Sample batch size drawn from the replay buffer.
      learning_rate: Learning rate of the Adam optimizer.
      use_tf_functions: If true, wraps the collect driver, the agent's train
        method and the local train step with `common.function`.
      num_eval_episodes: Episodes per evaluation (also the metric buffer size).
      eval_interval: Evaluate when the global step is a multiple of this value.
      train_checkpoint_interval, policy_checkpoint_interval,
        rb_checkpoint_interval: Save the respective checkpoint when the global
        step is a multiple of the value.
      log_interval: Log loss and steps/sec when the global step is a multiple
        of this value.
      summary_interval: Summaries are recorded when the global step is a
        multiple of this value.
      summaries_flush_secs: Flush period of both summary writers, in seconds.
      eval_metrics_callback: Optional callable invoked as
        `callback(results, global_step_value)` after each evaluation.

    Returns:
      The result of the last train step, or None if no train step was run.

    Raises:
      NotImplementedError: If both `train_sequence_length` and `n_step_update`
        differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    # The default (5e5) is a float, which range() rejects with a TypeError.
    num_iterations = int(num_iterations)

    # NOTE(security): pickle can execute arbitrary code while loading; these
    # files must come from a trusted source.
    with open('clusters.pickle', 'rb') as clusters_file:
        clusters = pickle.load(clusters_file)
    # NOTE(review): nx.read_gpickle is not available in NetworkX >= 3.0 --
    # confirm the NetworkX version this project pins.
    graph = nx.read_gpickle('graph.gpickle')
    print(graph.nodes)

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        gym_kwargs = {'graph': graph, 'clusters': clusters}
        tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=dict(gym_kwargs)))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name, gym_kwargs=dict(gym_kwargs)))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length <= 1:
            # Non-recurrent setup: sampled sequences follow the n-step length.
            train_sequence_length = n_step_update

        # The Q-network is always the graph-attention network; the recurrent /
        # feed-forward networks that used to be built here were discarded
        # immediately and have been removed.
        q_net = GATNetwork(tf_env.observation_spec(), tf_env.action_spec(),
                           graph)

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
            tf_metrics.MaxReturnMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Baseline evaluation before any training step.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Stays None when no train step runs, so the final return cannot hit
        # an unbound local.
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step.numpy())

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        # Compare what training found against the graph-centroid heuristic
        # provided by the underlying gym environment.
        gym_env = tf_env.envs[0]._gym_env
        print(gym_env.best_controllers)
        print(gym_env.best_reward)
        gym_env.reset()
        centroid_controllers, _heuristic_distance = (
            gym_env.graphCentroidAction())
        print(centroid_controllers)

        # Step through the heuristic controllers; only the reward of the last
        # step is reported. None if the heuristic yields no controllers.
        reward_final = None
        for cont in centroid_controllers:
            (_, reward_final, _, _) = gym_env.step(cont)
        print(gym_env.controllers, reward_final)
        return train_loss
Пример #17
0
def train_eval(
    root_dir,
    env_name='cartpole',
    task_name='balance',
    observations_allowlist='position',
    num_iterations=100000,
    actor_fc_layers=(400, 300),
    actor_output_fc_layers=(100,),
    actor_lstm_size=(40,),
    critic_obs_fc_layers=(400,),
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(300,),
    critic_output_fc_layers=(100,),
    critic_lstm_size=(40,),
    # Params for collect
    initial_collect_episodes=1,
    collect_episodes_per_iteration=1,
    replay_buffer_capacity=100000,
    ou_stddev=0.2,
    ou_damping=0.15,
    # Params for target update
    target_update_tau=0.05,
    target_update_period=5,
    # Params for train
    train_steps_per_iteration=200,
    batch_size=64,
    train_sequence_length=10,
    actor_learning_rate=1e-4,
    critic_learning_rate=1e-3,
    dqda_clipping=None,
    td_errors_loss_fn=None,
    gamma=0.995,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for checkpoints, summaries, and logging
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=True,
    summarize_grads_and_vars=True,
    eval_metrics_callback=None):
  """Trains and periodically evaluates a DDPG agent with RNN networks.

  Builds an `ActorRnnNetwork` / `CriticRnnNetwork` pair, wraps them in a
  `DdpgAgent`, fills a `TFUniformReplayBuffer` with initial episodes and then
  alternates between collecting episodes and running training steps on
  sequences sampled from the buffer. Evaluation runs once before the loop and
  afterwards whenever the global step is a multiple of `eval_interval`.

  Args:
    root_dir: Base output directory (`~` is expanded). Train summaries are
      written to `<root_dir>/train`, eval summaries to `<root_dir>/eval`.
    env_name: First argument passed to `suite_dm_control.load`.
    task_name: Second argument passed to `suite_dm_control.load`.
    observations_allowlist: A single observation key. When not `None` it is
      wrapped in a one-element list and handed to
      `FlattenObservationsWrapper`; `None` installs no wrapper.
    num_iterations: Number of collect-then-train iterations.
    actor_fc_layers: `input_fc_layer_params` of the actor network.
    actor_output_fc_layers: `output_fc_layer_params` of the actor network.
    actor_lstm_size: `lstm_size` of the actor network.
    critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
    critic_action_fc_layers: `action_fc_layer_params` of the critic.
    critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
    critic_output_fc_layers: `output_fc_layer_params` of the critic.
    critic_lstm_size: `lstm_size` of the critic network.
    initial_collect_episodes: Episodes collected before training starts.
    collect_episodes_per_iteration: Episodes collected in every iteration.
    replay_buffer_capacity: `max_length` of the replay buffer.
    ou_stddev: Forwarded to `DdpgAgent`.
    ou_damping: Forwarded to `DdpgAgent`.
    target_update_tau: Forwarded to `DdpgAgent`.
    target_update_period: Forwarded to `DdpgAgent`.
    train_steps_per_iteration: Training steps run after each collect phase.
    batch_size: `sample_batch_size` used when sampling the replay buffer.
    train_sequence_length: Sampled sequences have
      `train_sequence_length + 1` steps.
    actor_learning_rate: Learning rate of the actor's Adam optimizer.
    critic_learning_rate: Learning rate of the critic's Adam optimizer.
    dqda_clipping: Forwarded to `DdpgAgent`.
    td_errors_loss_fn: Forwarded to `DdpgAgent`.
    gamma: Forwarded to `DdpgAgent`.
    reward_scale_factor: Forwarded to `DdpgAgent`.
    gradient_clipping: Forwarded to `DdpgAgent`.
    use_tf_functions: When true, the driver `run` methods, `tf_agent.train`
      and the local train step are wrapped with `common.function`.
    num_eval_episodes: Episodes per evaluation; also the eval metric buffer
      size.
    eval_interval: Evaluate when `global_step % eval_interval == 0`.
    log_interval: Log loss and throughput when
      `global_step % log_interval == 0`.
    summary_interval: Summaries are recorded only when
      `global_step % summary_interval == 0`.
    summaries_flush_secs: Flush period of both summary writers, in seconds.
    debug_summaries: Forwarded to `DdpgAgent`.
    summarize_grads_and_vars: Forwarded to `DdpgAgent`.
    eval_metrics_callback: Optional callable invoked after every evaluation
      as `eval_metrics_callback(results, global_step_value)`.

  Returns:
    The value returned by the last training step (its `.loss` attribute is
    what gets logged), or `None` when no training step was executed because
    `num_iterations` or `train_steps_per_iteration` is zero.
  """
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    if observations_allowlist is not None:
      env_wrappers = [
          functools.partial(
              wrappers.FlattenObservationsWrapper,
              observations_allowlist=[observations_allowlist])
      ]
    else:
      env_wrappers = []

    # Separate environment instances for collection and for evaluation.
    tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        suite_dm_control.load(env_name, task_name, env_wrappers=env_wrappers))

    actor_net = actor_rnn_network.ActorRnnNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        input_fc_layer_params=actor_fc_layers,
        lstm_size=actor_lstm_size,
        output_fc_layer_params=actor_output_fc_layers)

    critic_net_input_specs = (tf_env.time_step_spec().observation,
                              tf_env.action_spec())

    critic_net = critic_rnn_network.CriticRnnNetwork(
        critic_net_input_specs,
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        lstm_size=critic_lstm_size,
        output_fc_layer_params=critic_output_fc_layers,
    )

    tf_agent = ddpg_agent.DdpgAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        ou_stddev=ou_stddev,
        ou_damping=ou_damping,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        dqda_clipping=dqda_clipping,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # The first two metrics double as the step metrics for the summaries.
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=initial_collect_episodes)

    collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    def _evaluate():
      """Runs one evaluation pass, then the callback and metric logging."""
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    # Collect initial replay data.
    # NOTE: the message says "random policy", but the driver above is built
    # with the agent's `collect_policy`.
    logging.info(
        'Initializing replay buffer by collecting experience for %d episodes '
        'with a random policy.', initial_collect_episodes)
    initial_collect_driver.run()

    # Baseline evaluation before any training step.
    _evaluate()

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [BxTx...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=train_sequence_length + 1).prefetch(3)
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    # Stays None when no training step runs, so the final `return` (and the
    # logging branch) never touch an unbound name.
    train_loss = None
    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      # Nothing below advances the global step, so read it once.
      step = global_step.numpy()

      if train_loss is not None and step % log_interval == 0:
        logging.info('step = %d, loss = %f', step, train_loss.loss)
        steps_per_sec = (step - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = step
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if step % eval_interval == 0:
        _evaluate()

    return train_loss
Пример #18
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        env_load_fn=suite_gym.load,
        random_seed=0,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Run PPO training with periodic evaluation.

    The loop keeps going until the `EnvironmentSteps` train metric reaches
    `num_environment_steps`. Every pass collects
    `collect_episodes_per_iteration` episodes from
    `num_parallel_environments` parallel environments, trains on everything
    gathered from the replay buffer and then clears the buffer. Nothing is
    returned.

    Args:
        root_dir: Output directory (`~` is expanded); summaries are written
            to its `train` and `eval` sub-directories. Must not be `None`.
        env_name: Name handed to `env_load_fn`.
        env_load_fn: Callable that builds an environment from `env_name`.
        random_seed: Seed given to `tf.compat.v1.set_random_seed`.
        actor_fc_layers: Fully connected layer sizes of the actor network.
        value_fc_layers: Fully connected layer sizes of the value network.
        use_rnns: Selects the RNN variants of the actor/value networks.
        num_environment_steps: Environment-step budget that ends training.
        collect_episodes_per_iteration: Episodes collected per pass.
        num_parallel_environments: Number of parallel environments; also
            the replay buffer batch size.
        replay_buffer_capacity: `max_length` of the replay buffer.
        num_epochs: Forwarded to `PPOAgent`.
        learning_rate: Learning rate of the Adam optimizer.
        num_eval_episodes: Episodes per evaluation and eval metric buffer
            size.
        eval_interval: Evaluate when the step read at the start of a pass
            is a multiple of this value.
        log_interval: Log when the step read at the start of a pass is a
            multiple of this value.
        summary_interval: Summaries are recorded only when
            `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of the summary writers (seconds).
        use_tf_functions: Wrap the collect driver and `tf_agent.train` with
            `common.function(..., autograph=False)`.
        debug_summaries: Forwarded to `PPOAgent`.
        summarize_grads_and_vars: Forwarded to `PPOAgent`.

    Raises:
        AttributeError: If `root_dir` is `None`.
    """
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_log_dir = os.path.join(root_dir, 'train')
    eval_log_dir = os.path.join(root_dir, 'eval')
    flush_millis = summaries_flush_secs * 1000

    train_writer = tf.compat.v2.summary.create_file_writer(
        train_log_dir, flush_millis=flush_millis)
    train_writer.set_as_default()

    eval_writer = tf.compat.v2.summary.create_file_writer(
        eval_log_dir, flush_millis=flush_millis)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()

    def _on_summary_step():
        # Summaries are only recorded on multiples of `summary_interval`.
        return tf.math.equal(global_step % summary_interval, 0)

    with tf.compat.v2.summary.record_if(_on_summary_step):
        tf.compat.v1.set_random_seed(random_seed)

        # Single environment for evaluation, a parallel batch for collection.
        eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        env_constructors = (
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(env_constructors))
        adam = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

        observation_spec = tf_env.observation_spec
        action_spec = tf_env.action_spec
        if not use_rnns:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                observation_spec(),
                action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                observation_spec(), fc_layer_params=value_fc_layers)
        else:
            actor_net = (
                actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                    observation_spec(),
                    action_spec(),
                    input_fc_layer_params=actor_fc_layers,
                    output_fc_layer_params=None))
            value_net = value_rnn_network.ValueRnnNetwork(
                observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            action_spec(),
            adam,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # The environment-step counter drives the loop's stop condition.
        env_steps = tf_metrics.EnvironmentSteps()
        step_metrics = [tf_metrics.NumberOfEpisodes(), env_steps]
        train_metrics = list(step_metrics)
        train_metrics.append(tf_metrics.AverageReturnMetric())
        train_metrics.append(tf_metrics.AverageEpisodeLengthMetric())

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(
                collect_driver.run, autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)

        def _evaluate():
            """Evaluate the current policy and write the eval summaries."""
            metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_writer,
                summary_prefix='Metrics',
            )

        collect_secs = 0
        train_secs = 0
        last_logged_step = global_step.numpy()

        while env_steps.result() < num_environment_steps:
            # Read once per pass, before training advances the counter.
            step_at_start = global_step.numpy()
            if step_at_start % eval_interval == 0:
                _evaluate()

            tick = time.time()
            collect_driver.run()
            collect_secs += time.time() - tick

            tick = time.time()
            experience = replay_buffer.gather_all()
            total_loss, _ = tf_agent.train(experience=experience)
            replay_buffer.clear()
            train_secs += time.time() - tick

            for metric in train_metrics:
                metric.tf_summaries(
                    train_step=global_step, step_metrics=step_metrics)

            if step_at_start % log_interval != 0:
                continue

            logging.info('step = %d, loss = %f', step_at_start, total_loss)
            elapsed = collect_secs + train_secs
            steps_per_sec = (step_at_start - last_logged_step) / elapsed
            logging.info('%.3f steps/sec', steps_per_sec)
            logging.info('collect_time = {}, train_time = {}'.format(
                collect_secs, train_secs))
            # Override the outer record_if so this scalar is always written.
            with tf.compat.v2.summary.record_if(True):
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)

            last_logged_step = step_at_start
            collect_secs = 0
            train_secs = 0

        # One final eval before exiting.
        _evaluate()
Пример #19
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v1',
        num_iterations=1000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=100,
        eval_interval=500000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=10000,
        summary_interval=10000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC, driven through a `tf.compat.v1.Session`.

    All collect/train ops are built first and then executed repeatedly via
    session callables. Checkpoints for the agent, the policy and the replay
    buffer are restored on start-up and saved at their respective intervals.
    The function has no return statement.

    Args:
        root_dir: Base output directory (`~` is expanded). Summaries and
            checkpoints live under `<root_dir>/train`; eval summaries under
            `<root_dir>/eval`.
        env_name: Name passed to `suite_mujoco.load`.
        num_iterations: Number of collect-then-train iterations.
        actor_fc_layers: `fc_layer_params` of the actor network.
        critic_obs_fc_layers: `observation_fc_layer_params` of the critic.
        critic_action_fc_layers: `action_fc_layer_params` of the critic.
        critic_joint_fc_layers: `joint_fc_layer_params` of the critic.
        initial_collect_steps: Steps collected once when starting from
            global step 0.
        collect_steps_per_iteration: Steps collected in every iteration.
        replay_buffer_capacity: `max_length` of the replay buffer.
        target_update_tau: Forwarded to `SacAgent`.
        target_update_period: Forwarded to `SacAgent`.
        train_steps_per_iteration: Train op runs per iteration.
        batch_size: Batch size of the filtered training dataset.
        actor_learning_rate: Learning rate of the actor's Adam optimizer.
        critic_learning_rate: Learning rate of the critic's Adam optimizer.
        alpha_learning_rate: Learning rate of the alpha Adam optimizer.
        td_errors_loss_fn: Forwarded to `SacAgent`.
        gamma: Forwarded to `SacAgent`.
        reward_scale_factor: Forwarded to `SacAgent`.
        gradient_clipping: Forwarded to `SacAgent`.
        num_eval_episodes: Episodes per evaluation and eval metric buffer
            size.
        eval_interval: Evaluate when `global_step % eval_interval == 0`.
        train_checkpoint_interval: Save the train checkpoint on multiples of
            this step count.
        policy_checkpoint_interval: Save the policy checkpoint on multiples
            of this step count.
        rb_checkpoint_interval: Save the replay buffer checkpoint on
            multiples of this step count.
        log_interval: Log loss and throughput on multiples of this step
            count.
        summary_interval: Summaries are recorded only when
            `global_step % summary_interval == 0`.
        summaries_flush_secs: Flush period of the summary writers (seconds).
        debug_summaries: Forwarded to `SacAgent`.
        summarize_grads_and_vars: Forwarded to `SacAgent`.
        eval_metrics_callback: Passed as `callback` to
            `metric_utils.compute_summaries`.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    # Evaluation uses Python-side metrics together with a Python environment
    # and a PyTFPolicy (see eval_py_env / eval_py_policy below).
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        # Create the environment.
        tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
        eval_py_env = suite_mujoco.load(env_name)

        # Get the data specs from the environment
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        # `normal_projection_net` is defined outside this function.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        # NOTE(review): `tf_agent.initialize()` is not called anywhere in this
        # function - confirm that the variable initialization done inside the
        # session below is sufficient for this agent.

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # The first two metrics are reused as step metrics for the summaries.
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        collect_policy = tf_agent.collect_policy

        # `.run()` is called here, at construction time; the returned values
        # are executed later through the session (sess.run / make_callable).
        # Both the initial and the per-iteration collection use the agent's
        # collect policy.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            # Keep a two-step sample only if its first step is not a boundary.
            return ~trajectories.is_boundary()[0]

        # Over-sample (5 * batch_size two-step sequences), split them into
        # single samples, drop the invalid ones and re-batch to batch_size.
        dataset = replay_buffer.as_dataset(
            sample_batch_size=5 * batch_size,
            num_steps=2).apply(tf.data.experimental.unbatch()).filter(
                _filter_invalid_transition).batch(batch_size).prefetch(
                    batch_size * 5)
        dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
            dataset)
        trajectories, unused_info = dataset_iterator.get_next()
        train_op = tf_agent.train(trajectories)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        # Eval summaries go to the eval writer and are always recorded,
        # regardless of the outer record_if condition.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        # Three independent checkpointers: agent + train metrics, the policy
        # alone, and the replay buffer (only the latest buffer is kept).
        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        with tf.compat.v1.Session() as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            policy_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)

            # Initialize training.
            sess.run(dataset_iterator.initializer)
            common_utils.initialize_uninitialized_variables(sess)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            global_step_val = sess.run(global_step)

            # A global step of 0 is treated as a fresh run: evaluate the
            # untrained policy and fill the replay buffer. Any other value
            # skips both steps.
            if global_step_val == 0:
                # Initial eval of randomly initialized policy
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    eval_py_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    callback=eval_metrics_callback,
                    log=True,
                )

                # Run initial collect.
                logging.info('Global step %d: Running initial collect op.',
                             global_step_val)
                sess.run(initial_collect_op)

                # Checkpoint the initial replay buffer contents.
                rb_checkpointer.save(global_step=global_step_val)

                logging.info('Finished initial collect.')
            else:
                logging.info('Global step %d: Skipping initial collect op.',
                             global_step_val)

            # Session callables for the ops executed on every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            # Throughput is computed in Python and fed into the summary op
            # through this placeholder.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # NOTE(review): `tf.contrib` is a TensorFlow 1.x module, while the
            # writers above use `tf.compat.v2.summary` - confirm the targeted
            # TensorFlow version.
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                # `total_loss` is only bound if this loop runs at least once,
                # i.e. train_steps_per_iteration > 0.
                for _ in range(train_steps_per_iteration):
                    total_loss = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    # Restart the throughput measurement window.
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Пример #20
0
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))
    eval_tf_env = tf_py_environment.TFPyEnvironment(suite_mujoco.load(env_name))

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    # Collect initial replay data.
    logging.info(
        'Initializing replay buffer by collecting experience for %d steps with '
        'a random policy.', initial_collect_steps)
    initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    for _ in range(num_iterations):
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        experience, _ = next(iterator)
        train_loss = tf_agent.train(experience)
      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(),
                     train_loss.loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss
Пример #21
0
    def train(self,
              training_iterations=TRAINING_ITERATIONS,
              training_stock_list=None):
        """Train the agent with periodic evaluation, checkpointing and reporting.

        Steps, in order:
          1. Call ``self.reset(training_stock_list)`` and restore the train and
             replay-buffer checkpoints if they exist.
          2. Evaluate a random policy and the current agent policy with
             ``self.compute_avg_return`` and print the results.
          3. Fill the replay buffer with ``INIT_COLLECT_STEPS`` random steps.
          4. For up to ``training_iterations`` iterations, collect
             ``STEP_ITERATIONS`` environment steps and run ``STEP_ITERATIONS``
             train steps.  Every ``LOG_INTERVAL`` global steps the loss and
             throughput are printed; every ``EVAL_INTERVAL`` global steps the
             agent is evaluated and, unless the early-stop check fires,
             checkpoints are saved.
          5. Update the training report through ``util``, then call
             ``self.save()`` and ``self.reset()``.

        Args:
            training_iterations: Maximum number of collect/train iterations.
            training_stock_list: Forwarded unchanged to ``self.reset``.
        """
        self.reset(training_stock_list)

        train_dir = 'training_data_progress/train-' + self.name

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.tf_agent.collect_data_spec,
            batch_size=self.tf_training_env.batch_size,
            max_length=MAX_BUFFER_SIZE)

        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=NUM_EVAL_EPISODES),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=NUM_EVAL_EPISODES)
        ]

        def _loss_of(loss_info):
            """Return ``loss_info.loss``, or ``loss_info`` itself if it has none.

            ``_train_step`` below falls back to a plain number when training
            raises; that fallback has no ``.loss`` attribute, so reading the
            attribute directly would crash the logging code.
            """
            return getattr(loss_info, 'loss', loss_info)

        global_step = self.tf_agent.train_step_counter
        with tf.compat.v2.summary.record_if(
                lambda: tf.math.equal(global_step % LOG_INTERVAL, 0)):

            replay_observer = [replay_buffer.add_batch]

            train_metrics = [
                tf_metrics.NumberOfEpisodes(),
                tf_metrics.EnvironmentSteps(),
                tf_metrics.AverageReturnMetric(
                    buffer_size=NUM_EVAL_EPISODES,
                    batch_size=self.tf_training_env.batch_size),
                tf_metrics.AverageEpisodeLengthMetric(
                    buffer_size=NUM_EVAL_EPISODES,
                    batch_size=self.tf_training_env.batch_size),
            ]

            eval_policy = greedy_policy.GreedyPolicy(self.tf_agent.policy)
            initial_collect_policy = random_tf_policy.RandomTFPolicy(
                self.tf_training_env.time_step_spec(),
                self.tf_training_env.action_spec())
            collect_policy = self.tf_agent.collect_policy

            train_checkpointer = common.Checkpointer(
                ckpt_dir=train_dir,
                agent=self.tf_agent,
                global_step=global_step,
                metrics=metric_utils.MetricsGroup(train_metrics,
                                                  'train_metrics'))
            policy_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'policy'),
                policy=eval_policy,
                global_step=global_step)
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                max_to_keep=1,
                replay_buffer=replay_buffer)

            # Only the train state and the replay buffer are restored; the
            # policy checkpointer is used for saving only.
            train_checkpointer.initialize_or_restore()
            rb_checkpointer.initialize_or_restore()

            initial_collect_driver_random = dynamic_step_driver.DynamicStepDriver(
                self.tf_training_env,
                initial_collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=INIT_COLLECT_STEPS)
            initial_collect_driver_random.run = common.function(
                initial_collect_driver_random.run)

            collect_driver = dynamic_step_driver.DynamicStepDriver(
                self.tf_training_env,
                collect_policy,
                observers=replay_observer + train_metrics,
                num_steps=STEP_ITERATIONS)

            collect_driver.run = common.function(collect_driver.run)
            self.tf_agent.train = common.function(self.tf_agent.train)

            # Baseline 1: a random policy.
            random_policy = random_tf_policy.RandomTFPolicy(
                self.tf_training_env.time_step_spec(),
                self.tf_training_env.action_spec())
            avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                random_policy)
            print(
                'Random:\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
                .format(avg_return, avg_return_per_step, avg_daily_percentage))
            self.gym_training_env.save_feature_distribution(self.name)

            # Baseline 2: the agent before any training in this call.
            avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                self.tf_agent.policy)
            print(
                'Agent :\n\tAverage Return = {0}\n\tAverage Return Per Step = {1}\n\tPercent = {2}%'
                .format(avg_return, avg_return_per_step, avg_daily_percentage))
            self.eval_env.reset()
            self.eval_env.run_and_save_evaluation(str(0))
            self.gym_training_env.save_feature_distribution(self.name)

            # Histories, one entry per evaluation; written to the report below.
            evaluations = [self.get_evaluation()]
            returns = [self.eval_env.returns]
            actions_over_time_list = [self.eval_env.action_sets_over_time]

            # Collect initial replay data.
            print(
                'Initializing replay buffer by collecting experience for {} steps with '
                'a random policy.'.format(INIT_COLLECT_STEPS))
            initial_collect_driver_random.run()

            metric_utils.eager_compute(
                eval_metrics,
                self.tf_training_env,
                eval_policy,
                num_episodes=NUM_EVAL_EPISODES,
                train_step=global_step,
                summary_prefix='Metrics',
            )
            metric_utils.log_metrics(eval_metrics)

            time_step = None
            policy_state = collect_policy.get_initial_state(
                self.tf_training_env.batch_size)

            timed_at_step = global_step.numpy()
            time_acc = 0

            # Prepare replay buffer as dataset with invalid transitions filtered.
            def _filter_invalid_transition(trajectories, unused_arg1):
                return ~trajectories.is_boundary()[0]

            dataset = replay_buffer.as_dataset(
                sample_batch_size=BATCH_SIZE, num_steps=2).unbatch().filter(
                    _filter_invalid_transition).batch(BATCH_SIZE).prefetch(5)
            # Dataset generates trajectories with shape [Bx2x...]
            iterator = iter(dataset)

            def _train_step():
                # Best-effort: a failing train step is reported and replaced by
                # a tiny placeholder loss instead of aborting the whole run.
                try:
                    experience, _ = next(iterator)
                    return self.tf_agent.train(experience)
                except Exception as e:
                    print("Caught Exception:", e)
                    return 1e-20

            train_step = common.function(_train_step)

            for _ in range(training_iterations):
                start_time = time.time()
                time_step, policy_state = collect_driver.run(
                    time_step=time_step,
                    policy_state=policy_state,
                )
                for _ in range(STEP_ITERATIONS):
                    train_loss = train_step()
                time_acc += time.time() - start_time

                self.global_step_val = global_step.numpy()

                if self.global_step_val % LOG_INTERVAL == 0:
                    steps_per_sec = (self.global_step_val -
                                     timed_at_step) / time_acc
                    print(
                        self.name,
                        '\nstep = {0:d}:\n\tloss = {1:f}\n\t{2:.3f} steps/sec'.
                        format(self.global_step_val, _loss_of(train_loss),
                               steps_per_sec))
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)
                    timed_at_step = self.global_step_val
                    time_acc = 0

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=global_step,
                                              step_metrics=train_metrics[:2])

                if self.global_step_val % EVAL_INTERVAL == 0:
                    metric_utils.eager_compute(
                        eval_metrics,
                        self.tf_training_env,
                        eval_policy,
                        num_episodes=NUM_EVAL_EPISODES,
                        train_step=global_step,
                        summary_prefix='Metrics',
                    )
                    metric_utils.log_metrics(eval_metrics)

                    avg_return, avg_return_per_step, avg_daily_percentage = self.compute_avg_return(
                        self.tf_agent.policy)
                    print(
                        self.name,
                        '\nstep = {0}:\n\tloss = {1}\n\tAverage Return = {2}\n\tAverage Return Per Step = {3}\n\tPercent = {4}%'
                        .format(self.global_step_val, _loss_of(train_loss),
                                avg_return, avg_return_per_step,
                                avg_daily_percentage))
                    self.eval_env.reset()
                    self.eval_env.run_and_save_evaluation(
                        str(self.global_step_val // EVAL_INTERVAL))
                    self.gym_training_env.save_feature_distribution(self.name)

                    # NOTE(review): this compares avg_daily_percentage with the
                    # last stored self.eval_env.returns -- confirm both hold
                    # the same quantity.
                    if avg_daily_percentage == returns[-1]:
                        # Stopping here skips the history update and the
                        # checkpoint saves for this evaluation.
                        print(
                            "---- Average return did not change since last time. Breaking loop."
                        )
                        break

                    evaluations.append(self.get_evaluation())
                    returns.append(self.eval_env.returns)
                    actions_over_time_list.append(
                        self.eval_env.action_sets_over_time)

                    train_checkpointer.save(global_step=self.global_step_val)
                    policy_checkpointer.save(global_step=self.global_step_val)
                    rb_checkpointer.save(global_step=self.global_step_val)

        training_report = util.load_training_report()
        agent_report = training_report.get(self.name, dict())
        agent_report["Training Results"] = returns
        # Negative evaluation values are reported as 0.0.
        agent_report["Evaluations"] = [max(e, 0.0) for e in evaluations]
        # Ten bins of width 0.1 whose edges are nudged just below each
        # multiple of 0.1.
        bins = [0.1 * i - 0.0000001 for i in range(11)]
        agent_report["Histograms"] = [
            str(list(map(int,
                         np.histogram(actions, bins, density=True)[0])))
            for actions in actions_over_time_list
        ]
        training_report[self.name] = agent_report
        util.save_training_report(training_report)

        print("---- Average-daily-percentage over training period for",
              self.name)
        print("\t\t", avg_daily_percentage)
        self.save()
        self.reset()
Пример #22
0
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        actor_update_period=2,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """Train and evaluate a TD3 agent with RNN actor/critic networks.

    The whole computation is built as a TF1-style graph (drivers, dataset
    iterator, train op, summaries) and then executed inside a
    ``tf.compat.v1.Session``.

    Args:
        root_dir: Output root; ``~`` is expanded. Summaries go to
            ``<root_dir>/train`` and ``<root_dir>/eval``; checkpoints are
            written under ``<root_dir>/train``.
        env_name: Passed to ``suite_dm_control.load``.
        task_name: Passed to ``suite_dm_control.load``.
        observations_whitelist: If not None, the environments are wrapped
            with ``FlattenObservationsWrapper`` restricted to this single
            observation key.
        num_iterations: Number of collect-then-train iterations.
        actor_fc_layers: ``input_fc_layer_params`` of the actor network.
        actor_output_fc_layers: ``output_fc_layer_params`` of the actor.
        actor_lstm_size: ``lstm_size`` of the actor network.
        critic_obs_fc_layers: ``observation_fc_layer_params`` of the critic.
        critic_action_fc_layers: ``action_fc_layer_params`` of the critic.
        critic_joint_fc_layers: ``joint_fc_layer_params`` of the critic.
        critic_output_fc_layers: ``output_fc_layer_params`` of the critic.
        critic_lstm_size: ``lstm_size`` of the critic network.
        initial_collect_steps: Despite the name, this is passed as
            ``num_episodes`` to the initial ``DynamicEpisodeDriver``, so it
            counts episodes.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        exploration_noise_std: Forwarded to ``Td3Agent``.
        target_update_tau: Forwarded to ``Td3Agent``.
        target_update_period: Forwarded to ``Td3Agent``.
        train_steps_per_iteration: Train-op runs per iteration.
        batch_size: ``sample_batch_size`` of the replay dataset.
        actor_update_period: Forwarded to ``Td3Agent``.
        train_sequence_length: The dataset samples sequences of
            ``train_sequence_length + 1`` steps.
        actor_learning_rate: Learning rate of the actor Adam optimizer.
        critic_learning_rate: Learning rate of the critic Adam optimizer.
        dqda_clipping: Forwarded to ``Td3Agent``.
        gamma: Forwarded to ``Td3Agent``.
        reward_scale_factor: Forwarded to ``Td3Agent``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the eval metrics.
        eval_interval: Evaluate when ``global_step % eval_interval == 0``.
        train_checkpoint_interval: Same modulo rule for the train
            checkpointer.
        policy_checkpoint_interval: Same modulo rule for the policy
            checkpointer.
        rb_checkpoint_interval: Same modulo rule for the replay-buffer
            checkpointer.
        log_interval: Same modulo rule for loss/throughput logging.
        summary_interval: Summaries are recorded when
            ``global_step % summary_interval == 0``.
        summaries_flush_secs: Flush period of both summary writers, in
            seconds.
        debug_summaries: Forwarded to ``Td3Agent``.
        eval_metrics_callback: Passed as ``callback`` to
            ``metric_utils.compute_summaries``.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # The lambda looks up `global_step` only when it is called; the variable
    # is assigned further down in this block.
    # NOTE(review): relies on record_if not invoking the lambda on entry --
    # confirm.
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        # Two separately loaded environments: one wrapped for TF collection,
        # one plain Python environment for evaluation.
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        global_step = tf.compat.v1.train.get_or_create_global_step()
        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Both collection ops use the agent's collect policy; the initial one
        # runs `initial_collect_steps` *episodes* (see num_episodes below).
        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()

        train_fn = common.function(tf_agent.train)
        train_op = train_fn(experience=trajectories)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=tf_agent.policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        # Train-metric summary ops are run together with the train op below.
        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        # Eval summaries go to the eval writer and are always recorded.
        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            # Evaluate once before the training loop starts.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            # Callables for the ops that are run on every iteration.
            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable([train_op, summary_ops])
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            # The throughput is computed in Python and fed through a
            # placeholder into the summary op.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps_per_sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _ = train_step_call()
                time_acc += time.time() - start_time

                # All periodic actions below trigger only when the global
                # step is an exact multiple of their interval.
                global_step_val = global_step_call()
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
Пример #23
0
def train_eval(
    root_dir,
    experiment_name,  # experiment name
    env_name='carla-v0',
    num_iterations=int(1e7),
    model_network_ctor_type='non-hierarchical',  # model net
    input_names=['camera', 'lidar'],  # names for inputs
    reconstruct_names=['roadmap'],  # names for masks
    pixor_names=['vh_clas', 'vh_regr', 'pixor_state'],  # names for pixor outputs
    reconstruct_pixor_state=True,  # whether to reconstruct pixor_state
    extra_names=['state'],  # extra inputs
    obs_size=64,  # size of observation image
    pixor_size=64,  # size of pixor output image
    perception_weight=1.0,  # weight of perception part loss
    # Params for collect
    initial_collect_steps=1000,
    replay_buffer_capacity=int(5e4+1),
    # Params for train
    training=True,  # whether to train, or just evaluate
    model_batch_size=32,  # model training batch size
    sequence_length=10,  # number of timesteps to train model
    model_learning_rate=1e-4,  # learning rate for model training
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=2000,
    # Params for summaries and logging
    num_images_per_summary=1,  # images for each summary
    train_checkpoint_interval=2000,
    log_interval=200,
    summary_interval=2000,
    summaries_flush_secs=10,
    summarize_grads_and_vars=False,
    gpu_allow_growth=True,  # GPU memory growth
    gpu_memory_limit=None,  # GPU memory limit
    action_repeat=1):  # passed to the environment loader; also scales fps
  """Train and evaluate the perception model on a CARLA environment.

  Builds a ``PerceptionAgent`` around a sequential latent PIXOR model network
  and a state-based heuristic actor, evaluates it once, and -- when
  ``training`` is true -- fills a replay buffer with the agent's collect
  policy and trains the model for ``num_iterations`` steps with periodic
  logging, evaluation and checkpointing.

  The list-valued defaults are only read (concatenated into a new list) and
  never mutated by this function.

  Args:
    root_dir: Output root; results go to
      ``<root_dir>/<env_name>/<experiment_name>``.
    experiment_name: Sub-directory name for this run.
    env_name: Environment name, used for the output path and for loading the
      environment.
    num_iterations: Number of model train steps.
    model_network_ctor_type: Only ``'hierarchical'`` is implemented.
    input_names: Observation keys fed to the model.
    reconstruct_names: Observation keys the model reconstructs.
    pixor_names: Observation keys for the PIXOR outputs.
    reconstruct_pixor_state: Forwarded to the model network.
    extra_names: Additional observation keys requested from the environment.
    obs_size: Observation image size; also sets ``lidar_bin = 32 / obs_size``.
    pixor_size: PIXOR output image size.
    perception_weight: Forwarded to the model network.
    initial_collect_steps: Steps collected before training when starting
      from step 0 with an empty replay buffer.
    replay_buffer_capacity: ``max_length`` of the replay buffer.
    training: If false, only the initial evaluation is run.
    model_batch_size: Sample batch size of the replay dataset.
    sequence_length: The dataset samples ``sequence_length + 1`` steps.
    model_learning_rate: Learning rate of the Adam optimizer.
    gradient_clipping: Forwarded to the agent.
    num_eval_episodes: Episodes per evaluation.
    eval_interval: Evaluate when ``global_step % eval_interval == 0``.
    num_images_per_summary: Forwarded to the agent and used as
      ``num_episodes_to_render`` during evaluation.
    train_checkpoint_interval: Save train/model checkpoints when
      ``global_step % train_checkpoint_interval == 0``.
    log_interval: Log the loss when ``global_step % log_interval == 0``.
    summary_interval: Record summaries when
      ``global_step % summary_interval == 0``.
    summaries_flush_secs: Flush period of the summary writer, in seconds.
    summarize_grads_and_vars: Forwarded to the agent.
    gpu_allow_growth: Enable memory growth on every visible GPU.
    gpu_memory_limit: If truthy, per-GPU memory limit for a virtual device.
    action_repeat: Forwarded to the environment loader; also divides the
      frame rate passed to the agent.

  Raises:
    NotImplementedError: If ``model_network_ctor_type`` is not
      ``'hierarchical'`` (this includes the default value).
  """
  # Setup GPU
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpu_allow_growth:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  if gpu_memory_limit:
    for gpu in gpus:
      tf.config.experimental.set_virtual_device_configuration(
          gpu,
          [tf.config.experimental.VirtualDeviceConfiguration(
              memory_limit=gpu_memory_limit)])

  # Get the output directory for this experiment
  root_dir = os.path.expanduser(root_dir)
  root_dir = os.path.join(root_dir, env_name, experiment_name)

  # Get summary writers
  summary_writer = tf.summary.create_file_writer(
      root_dir, flush_millis=summaries_flush_secs * 1000)
  summary_writer.set_as_default()

  # Eval metrics
  eval_metrics = [
      tf_metrics.AverageReturnMetric(
        name='AverageReturnEvalPolicy', buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(
        name='AverageEpisodeLengthEvalPolicy',
        buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()

  # Whether to record for summary
  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create Carla environment. The requested env_name is honoured here (it
    # used to be hard-coded to 'carla-v0', which is still the default).
    py_env, eval_py_env = load_carla_env(
        env_name=env_name,
        lidar_bin=32/obs_size,
        pixor_size=pixor_size,
        obs_channels=list(
            set(input_names+reconstruct_names+pixor_names+extra_names)),
        action_repeat=action_repeat)

    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
    fps = int(np.round(1.0 / (py_env.dt * action_repeat)))

    # Specs
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    # Get model network
    if model_network_ctor_type == 'hierarchical':
      model_network_ctor = sequential_latent_pixor_network.PixorSLMHierarchical
    else:
      raise NotImplementedError(
          "model_network_ctor_type=%r is not supported; "
          "only 'hierarchical' is implemented." % (model_network_ctor_type,))
    model_net = model_network_ctor(
      input_names, reconstruct_names, obs_size=obs_size, pixor_size=pixor_size,
      reconstruct_pixor_state=reconstruct_pixor_state, perception_weight=perception_weight)

    # Build the perception agent
    actor_network = state_based_heuristic_actor_network.StateBasedHeuristicActorNetwork(
        observation_spec['state'],
        action_spec,
        desired_speed=9
        )

    tf_agent = perception_agent.PerceptionAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_network,
        model_network=model_net,
        model_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=model_learning_rate),
        num_images_per_summary=num_images_per_summary,
        sequence_length=sequence_length,
        gradient_clipping=gradient_clipping,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step,
        fps=fps)
    tf_agent.initialize()

    # Train metrics
    env_steps = tf_metrics.EnvironmentSteps()
    average_return = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes,
        batch_size=tf_env.batch_size)
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        env_steps,
        average_return,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size),
    ]

    # Get policies
    eval_policy = tf_agent.policy
    initial_collect_policy = tf_agent.collect_policy

    # Checkpointers
    train_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'train'),
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
        max_to_keep=2)
    train_checkpointer.initialize_or_restore()

    # The model checkpointer is only used for saving below.
    model_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(root_dir, 'model'),
        model=model_net,
        max_to_keep=2)

    # Evaluation
    compute_summaries(
      eval_metrics,
      eval_tf_env,
      eval_policy,
      train_step=global_step,
      summary_writer=summary_writer,
      num_episodes=num_eval_episodes,
      num_episodes_to_render=num_images_per_summary,
      model_net=model_net,
      fps=10,
      image_keys=['camera', 'lidar', 'roadmap'],
      pixor_size=pixor_size)

    # Collect/restore data and train
    if training:
      # Get replay buffer
      replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          data_spec=tf_agent.collect_data_spec,
          batch_size=1,  # No parallel environments
          max_length=replay_buffer_capacity)
      replay_observer = [replay_buffer.add_batch]

      # Replay buffer checkpointer
      rb_checkpointer = common.Checkpointer(
          ckpt_dir=os.path.join(root_dir, 'replay_buffer'),
          max_to_keep=1,
          replay_buffer=replay_buffer)
      rb_checkpointer.initialize_or_restore()

      # Collect driver
      initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
          tf_env,
          initial_collect_policy,
          observers=replay_observer + train_metrics,
          num_steps=initial_collect_steps)

      # Optimize the performance by using tf functions
      initial_collect_driver.run = common.function(initial_collect_driver.run)

      # Collect initial replay data, but only for a fresh run: nothing was
      # restored for the step counter and the replay buffer is empty.
      if (global_step.numpy() == 0 and replay_buffer.num_frames() == 0):
        logging.info(
            'Collecting experience for %d steps '
            'with a model-based policy.', initial_collect_steps)
        initial_collect_driver.run()
        rb_checkpointer.save(global_step=global_step.numpy())

      # Dataset generates trajectories with shape [Bxslx...]
      dataset = replay_buffer.as_dataset(
          num_parallel_calls=3,
          sample_batch_size=model_batch_size,
          num_steps=sequence_length + 1).prefetch(3)
      iterator = iter(dataset)

      # Get train model step
      def train_step():
        experience, _ = next(iterator)
        return tf_agent.train(experience)
      train_step = common.function(train_step)

      # Start training
      for _ in range(num_iterations):

        loss = train_step()

        # Read the step counter once per iteration; nothing below trains the
        # agent, so the value is reused for all periodic checks.
        global_step_val = global_step.numpy()

        # Log training information
        if global_step_val % log_interval == 0:
          logging.info('global steps = %d, model loss = %f', global_step_val, loss.loss)

        # Get training metrics
        for train_metric in train_metrics:
          train_metric.tf_summaries(train_step=global_step_val)

        # Evaluation
        if global_step_val % eval_interval == 0:
          # Log evaluation metrics
          compute_summaries(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            train_step=global_step,
            summary_writer=summary_writer,
            num_episodes=num_eval_episodes,
            num_episodes_to_render=num_images_per_summary,
            model_net=model_net,
            fps=10,
            image_keys=['camera', 'lidar', 'roadmap'],
            pixor_size=pixor_size)

        # Save checkpoints
        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)
          model_checkpointer.save(global_step=global_step_val)
Пример #24
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        train_sequence_length=1,
        # Params for QNetwork
        fc_layer_params=(100, ),
        # Params for QRnnNetwork
        input_fc_layer_params=(50, ),
        lstm_size=(20, ),
        output_fc_layer_params=(20, ),

        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a DQN agent on a Gym environment.

    Builds a training and an evaluation environment from ``env_name``, a
    Q-network, a ``DqnAgent`` and a uniform replay buffer. The buffer is
    first filled by a random policy; the loop then alternates between
    collecting experience with the agent's collect policy and training on
    batches sampled from the buffer. Logging, checkpointing and evaluation
    happen whenever the global step is a multiple of the matching interval.

    Args:
        root_dir: Base directory (``~`` is expanded). Summaries and
            checkpoints go to its ``train`` and ``eval`` subdirectories.
        env_name: Environment name passed to ``suite_gym.load``.
        num_iterations: Number of collect/train iterations.
        train_sequence_length: A value above 1 selects ``QRnnNetwork``;
            otherwise ``QNetwork`` is used and the sampled sequence length
            is taken from ``n_step_update``.
        fc_layer_params: ``fc_layer_params`` of the ``QNetwork``.
        input_fc_layer_params: Forwarded to ``QRnnNetwork``.
        lstm_size: Forwarded to ``QRnnNetwork``.
        output_fc_layer_params: Forwarded to ``QRnnNetwork``.
        initial_collect_steps: Steps collected with the random policy
            before the training loop starts.
        collect_steps_per_iteration: Steps collected per iteration.
        epsilon_greedy: Forwarded to ``DqnAgent``.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        target_update_tau: Forwarded to ``DqnAgent``.
        target_update_period: Forwarded to ``DqnAgent``.
        train_steps_per_iteration: Training steps per iteration.
        batch_size: ``sample_batch_size`` of the replay-buffer dataset.
        learning_rate: Learning rate of the Adam optimizer.
        n_step_update: Forwarded to ``DqnAgent``.
        gamma: Forwarded to ``DqnAgent``.
        reward_scale_factor: Forwarded to ``DqnAgent``.
        gradient_clipping: Forwarded to ``DqnAgent``.
        use_tf_functions: If true, the collect driver, ``tf_agent.train``
            and the train step are wrapped with ``common.function``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Interval for the train checkpoint.
        policy_checkpoint_interval: Interval for the policy checkpoint.
        rb_checkpoint_interval: Interval for the replay-buffer checkpoint.
        log_interval: Interval for loss and steps/sec logging.
        summary_interval: Summaries are recorded when the global step is a
            multiple of this value.
        summaries_flush_secs: Flush period of both summary writers.
        debug_summaries: Forwarded to ``DqnAgent``.
        summarize_grads_and_vars: Forwarded to ``DqnAgent``.
        eval_metrics_callback: Optional callable invoked after every
            evaluation as ``callback(results, global_step_value)``.

    Returns:
        The value returned by the last training step, or ``None`` when
        ``num_iterations`` is 0.

    Raises:
        NotImplementedError: If both ``train_sequence_length`` and
            ``n_step_update`` differ from 1.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        if train_sequence_length != 1 and n_step_update != 1:
            raise NotImplementedError(
                'train_eval does not currently support n-step updates with stateful '
                'networks (i.e., RNNs)')

        if train_sequence_length > 1:
            q_net = q_rnn_network.QRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=input_fc_layer_params,
                lstm_size=lstm_size,
                output_fc_layer_params=output_fc_layer_params)
        else:
            q_net = q_network.QNetwork(tf_env.observation_spec(),
                                       tf_env.action_spec(),
                                       fc_layer_params=fc_layer_params)
            # Feed-forward case: the dataset below samples
            # n_step_update + 1 consecutive steps.
            train_sequence_length = n_step_update

        # TODO(b/127301657): Decay epsilon based on global step, cf. cl/188907839
        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=epsilon_greedy,
            n_step_update=n_step_update,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=common.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        if use_tf_functions:
            # To speed up collect use common.function.
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        # Baseline evaluation before any training step of this run.
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Bound up front so the final return is well-defined even when the
        # loop body never executes (num_iterations == 0).
        train_loss = None

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            # Read the step counter once per iteration. The logging,
            # checkpointing and evaluation calls below are not expected to
            # advance it.
            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)
        return train_loss
Пример #25
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        num_iterations=2000000,
        actor_fc_layers=(400, 300),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_parallel_environments=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        td_errors_loss_fn=tf.losses.huber_loss,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a DDPG agent in a ``tf.Session``.

    The whole computation (collection, training, summaries) is first built
    as graph ops and then driven from Python inside one session. Each
    iteration runs the collect op once and the train op
    ``train_steps_per_iteration`` times. Logging, checkpointing and
    evaluation happen whenever the global step is a multiple of the
    matching interval. Evaluation runs a Python environment with a
    ``PyTFPolicy`` wrapper around the agent's policy.

    Args:
        root_dir: Base directory (``~`` is expanded). Summaries and
            checkpoints go to its ``train`` and ``eval`` subdirectories.
        env_name: Name passed to ``env_load_fn``.
        env_load_fn: Callable that builds a Python environment from
            ``env_name``.
        num_iterations: Number of collect/train iterations.
        actor_fc_layers: ``fc_layer_params`` of the actor network.
        critic_obs_fc_layers: ``observation_fc_layer_params`` of the critic.
        critic_action_fc_layers: ``action_fc_layer_params`` of the critic.
        critic_joint_fc_layers: ``joint_fc_layer_params`` of the critic.
        initial_collect_steps: Steps collected (with the agent's collect
            policy) before the training loop starts.
        collect_steps_per_iteration: Steps collected per iteration.
        num_parallel_environments: When greater than 1, training uses a
            ``ParallelPyEnvironment`` with this many environments.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        ou_stddev: Forwarded to ``DdpgAgent``.
        ou_damping: Forwarded to ``DdpgAgent``.
        target_update_tau: Forwarded to ``DdpgAgent``.
        target_update_period: Forwarded to ``DdpgAgent``.
        train_steps_per_iteration: Train-op runs per iteration.
        batch_size: ``sample_batch_size`` of the replay-buffer dataset.
        actor_learning_rate: Learning rate of the actor's Adam optimizer.
        critic_learning_rate: Learning rate of the critic's Adam optimizer.
        dqda_clipping: Forwarded to ``DdpgAgent``.
        td_errors_loss_fn: Forwarded to ``DdpgAgent``.
        gamma: Forwarded to ``DdpgAgent``.
        reward_scale_factor: Forwarded to ``DdpgAgent``.
        gradient_clipping: Forwarded to ``DdpgAgent``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Interval for the train checkpoint.
        policy_checkpoint_interval: Interval for the policy checkpoint.
        rb_checkpoint_interval: Interval for the replay-buffer checkpoint.
        log_interval: Interval for loss and steps/sec logging.
        summary_interval: Passed to
            ``record_summaries_every_n_global_steps``.
        summaries_flush_secs: Flush period of both summary writers.
        debug_summaries: Forwarded to ``DdpgAgent``.
        summarize_grads_and_vars: Forwarded to ``DdpgAgent``.
        eval_metrics_callback: Passed as ``callback`` to
            ``metric_utils.compute_summaries``.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):
        if num_parallel_environments > 1:
            tf_env = tf_py_environment.TFPyEnvironment(
                parallel_py_environment.ParallelPyEnvironment(
                    [lambda: env_load_fn(env_name)] *
                    num_parallel_environments))
        else:
            tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        # Separate Python environment used only for evaluation.
        eval_py_env = env_load_fn(env_name)

        actor_net = actor_network.ActorNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers,
        )

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_network.CriticNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
        )

        tf_agent = ddpg_agent.DdpgAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        collect_policy = tf_agent.collect_policy()
        # The initial collection does not update the train metrics; only
        # the replay buffer observes it.
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        # Gathered here, before the eval-metric summaries are defined
        # below; this list is what each train step runs.
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            # Baseline evaluation before the training loop.
            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            time_acc = 0
            # The steps/sec value is computed in Python, so it is fed into
            # the summary op through a placeholder.
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                time_acc += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    # Pass the value as a logging argument (lazy
                    # formatting), like the call above.
                    tf.logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
Пример #26
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(100, ),
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gradient_clipping=None,
        normalize_returns=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        rb_checkpoint_interval=200,
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a Reinforce agent in a session.

    Collection, training and buffer clearing are built as graph ops and
    driven from Python inside one ``tf.compat.v1.Session``. Each iteration
    collects ``collect_episodes_per_iteration`` episodes, trains once on
    everything in the replay buffer and then clears the buffer. Logging,
    checkpointing and evaluation happen whenever the global step is a
    multiple of the matching interval.

    Args:
        root_dir: Base directory (``~`` is expanded). Summaries and
            checkpoints go to its ``train`` and ``eval`` subdirectories.
        env_name: Environment name passed to ``suite_gym.load``.
        num_iterations: Number of collect/train iterations.
        actor_fc_layers: ``fc_layer_params`` of the actor network.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        learning_rate: Learning rate of the Adam optimizer.
        gradient_clipping: Forwarded to ``ReinforceAgent``.
        normalize_returns: Forwarded to ``ReinforceAgent``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        train_checkpoint_interval: Interval for the train checkpoint.
        policy_checkpoint_interval: Interval for the policy checkpoint.
        rb_checkpoint_interval: Interval for the replay-buffer checkpoint.
        log_interval: Interval for loss and steps/sec logging.
        summary_interval: Summaries are recorded when the global step is a
            multiple of this value.
        summaries_flush_secs: Flush period of both summary writers.
        debug_summaries: Forwarded to ``ReinforceAgent``.
        summarize_grads_and_vars: Forwarded to ``ReinforceAgent``.
        eval_metrics_callback: Passed as ``callback`` to
            ``metric_utils.compute_summaries``.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_py_env = suite_gym.load(env_name)
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

        # TODO(kbanoop): Handle distributions without gin.
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        # Train on everything currently in the buffer; the buffer is
        # cleared after each train step in the loop below.
        experience = replay_buffer.gather_all()
        train_op = tf_agent.train(experience)
        clear_rb_op = replay_buffer.clear()

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(train_step=global_step,
                                      step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Compute evaluation metrics.
            global_step_call = sess.make_callable(global_step)
            global_step_val = global_step_call()
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            clear_rb_call = sess.make_callable(clear_rb_op)

            timed_at_step = global_step_call()
            time_acc = 0
            # The steps/sec value is computed in Python, so it is fed into
            # the summary op through a placeholder.
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            # Uses the compat.v2 summary API like the rest of this
            # function, rather than tf.contrib; the tag is unchanged.
            steps_per_second_summary = tf.compat.v2.summary.scalar(
                name='global_steps/sec',
                data=steps_per_second_ph,
                step=global_step)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                total_loss = train_step_call()
                clear_rb_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
Пример #27
0
    agent.initialize()

    eval_policy = agent.policy  # greedy with respect to H
    collect_policy = agent.collect_policy  # epsilon greedy with respect to H

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length
    )

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    driver = dynamic_step_driver.DynamicStepDriver(
        train_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_steps=4
    )

    # collect some experience before training
    initial_collect_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(), train_env.action_spec())

    init_driver = dynamic_step_driver.DynamicStepDriver(
        train_env,
Пример #28
0
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=1000,
        actor_fc_layers=(100, ),
        value_net_fc_layers=(100, ),
        use_value_network=False,
        use_tf_functions=True,
        # Params for collect
        collect_episodes_per_iteration=2,
        replay_buffer_capacity=2000,
        # Params for train
        learning_rate=1e-3,
        gamma=0.9,
        gradient_clipping=None,
        normalize_returns=True,
        value_estimation_loss_coef=0.2,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=100,
        # Params for checkpoints, summaries, and logging
        log_interval=100,
        summary_interval=100,
        summaries_flush_secs=1,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """Train and periodically evaluate a Reinforce agent in eager mode.

    Each iteration collects ``collect_episodes_per_iteration`` episodes
    with the agent's collect policy, trains once on everything in the
    replay buffer and then clears the buffer. Logging and evaluation
    happen whenever the global step is a multiple of the matching
    interval. This function writes no checkpoints.

    Args:
        root_dir: Base directory (``~`` is expanded). Summaries go to its
            ``train`` and ``eval`` subdirectories.
        env_name: Environment name passed to ``suite_gym.load``.
        num_iterations: Number of collect/train iterations.
        actor_fc_layers: ``fc_layer_params`` of the actor network.
        value_net_fc_layers: ``fc_layer_params`` of the value network;
            only used when ``use_value_network`` is true.
        use_value_network: If true, a ``ValueNetwork`` is built and passed
            to the agent; otherwise the agent gets ``value_network=None``.
        use_tf_functions: If true, the collect driver, ``tf_agent.train``
            and the train step are wrapped with ``common.function``.
        collect_episodes_per_iteration: Episodes collected per iteration.
        replay_buffer_capacity: ``max_length`` of the replay buffer.
        learning_rate: Learning rate of the Adam optimizer.
        gamma: Forwarded to ``ReinforceAgent``.
        gradient_clipping: Forwarded to ``ReinforceAgent``.
        normalize_returns: Forwarded to ``ReinforceAgent``.
        value_estimation_loss_coef: Forwarded to ``ReinforceAgent``.
        num_eval_episodes: Episodes per evaluation; also the buffer size of
            the evaluation metrics.
        eval_interval: Evaluate when the global step is a multiple of this.
        log_interval: Interval for loss and steps/sec logging.
        summary_interval: Summaries are recorded when the global step is a
            multiple of this value.
        summaries_flush_secs: Flush period of both summary writers.
        debug_summaries: Forwarded to ``ReinforceAgent``.
        summarize_grads_and_vars: Forwarded to ``ReinforceAgent``.
        eval_metrics_callback: Optional callable invoked after every
            evaluation as ``callback(metrics, global_step_value)``.
    """
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Created before entering record_if: the condition lambda below closes
    # over this name, so it must be bound before any summary is recorded.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            fc_layer_params=actor_fc_layers)

        # The value network is optional; the agent receives None without it.
        value_net = None
        if use_value_network:
            value_net = value_network.ValueNetwork(
                tf_env.time_step_spec().observation,
                fc_layer_params=value_net_fc_layers)

        tf_agent = reinforce_agent.ReinforceAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            value_network=value_net,
            value_estimation_loss_coef=value_estimation_loss_coef,
            gamma=gamma,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            normalize_returns=normalize_returns,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            # Train on everything currently in the buffer; the loop below
            # clears the buffer after each call.
            experience = replay_buffer.gather_all()
            return tf_agent.train(experience)

        if use_tf_functions:
            # To speed up collect use TF function.
            collect_driver.run = common.function(collect_driver.run)
            # To speed up train use TF function.
            tf_agent.train = common.function(tf_agent.train)
            train_step = common.function(train_step)

        # Compute evaluation metrics.
        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        # TODO(b/126590894): Move this functionality into eager_compute_summaries
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            total_loss = train_step()
            replay_buffer.clear()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                # TODO(b/126590894): Move this functionality into
                # eager_compute_summaries.
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step_val)
Пример #29
0
def train(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    env_name=None,
    num_parallel_environments=1,  # pylint: disable=unused-argument
    agent_class=None,
    initial_collect_random=True,  # pylint: disable=unused-argument
    initial_collect_driver_class=None,
    collect_driver_class=None,
    num_global_steps=1000000,
    train_steps_per_iteration=1,
    train_metrics=None,
    # Safety Critic training args
    train_sc_steps=10,
    train_sc_interval=300,
    online_critic=False,
    # Params for eval
    run_eval=False,
    num_eval_episodes=30,
    eval_interval=1000,
    eval_metrics_callback=None,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=20000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    env_metric_factories=None):  # pylint: disable=unused-argument
  """A simple train and eval for SC-SAC.

  Builds the environment(s), agent, replay buffer(s), checkpointers and collect
  driver, then alternates between collecting experience and training the agent
  until the global step exceeds `num_global_steps`, `early_termination_fn`
  returns a truthy value, or the training loss diverges.

  Args:
    root_dir: Base output directory. Summaries and checkpoints are written
      under `<root_dir>/train` and, when `run_eval` is set, `<root_dir>/eval`.
    load_root_dir: Optional directory of a previous run. When truthy, the agent
      is loaded from `<load_root_dir>/train` with `misc.load_pi_ckpt`; the
      train checkpointer is only initialized/restored when this is None.
    env_load_fn: Callable invoked as `env_load_fn(env_name)` to build the
      training environment (wrapped in a `TFPyEnvironment` unless it already
      is one) and, when `run_eval` is set, a separate eval environment.
    env_name: Name passed to `env_load_fn`; also selects how the safety reward
      is obtained when training the online critic.
    num_parallel_environments: Unused.
    agent_class: Callable invoked as `agent_class(time_step_spec, action_spec,
      train_step_counter=global_step)`, plus `resample_metric=...` when
      `online_critic` is set, to build the agent.
    initial_collect_random: Unused.
    initial_collect_driver_class: Callable invoked as `cls(tf_env, policy,
      observers=...)`; its `run` is executed once with a random policy when no
      replay buffer checkpoint exists.
    collect_driver_class: Callable with the same signature, used to build the
      driver that collects experience on every loop iteration.
    num_global_steps: Training continues while `global_step` is less than or
      equal to this value.
    train_steps_per_iteration: Number of agent train steps per loop iteration.
    train_metrics: Optional list of py metrics; each is wrapped in a
      `TFPyMetric` and appended to both the train and the eval metrics.
    train_sc_steps: Number of safety critic train steps run each time the
      safety critic is trained.
    train_sc_interval: The safety critic is trained whenever `global_step` is
      a multiple of this value (only when `online_critic` is set).
    online_critic: If True, collects and evaluates with the agent's
      `_safe_policy`, mirrors experience into a second replay buffer and
      periodically trains the safety critic on it.
    run_eval: If True, evaluates `eval_policy` before training and every
      `eval_interval` global steps.
    num_eval_episodes: Number of episodes per evaluation; also used as the
      buffer size of the average return / episode length metrics.
    eval_interval: Evaluation period, in global steps.
    eval_metrics_callback: Optional callable invoked as
      `eval_metrics_callback(results, global_step_value)` after each eval.
    train_checkpoint_interval: Train checkpoint period, in global steps.
    policy_checkpoint_interval: Period, in global steps, at which both the
      policy and the safety critic are checkpointed.
    rb_checkpoint_interval: Replay buffer checkpoint period, in global steps.
    keep_rb_checkpoint: If False, the replay buffer checkpoint directory is
      cleaned up with `misc.cleanup_checkpoints` once training stops.
    log_interval: Logging period for loss and steps/sec, in global steps.
    summary_interval: Summaries are recorded when `global_step` is a multiple
      of this value.
    summaries_flush_secs: Flush period of the summary writers, in seconds.
    early_termination_fn: Optional zero-argument callable checked before every
      loop iteration; training stops when it returns a truthy value.
    env_metric_factories: Unused.

  Returns:
    The mean training loss over the train steps of the last completed loop
    iteration, or None if the training loop body never ran.

  Raises:
    ValueError: If the loss was NaN, infinite or above `MAX_LOSS` for more
      than `TERMINATE_AFTER_DIVERGED_LOSS_STEPS` consecutive iterations.
  """

  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    tf_env = env_load_fn(env_name)
    if not isinstance(tf_env, tf_py_environment.TFPyEnvironment):
      tf_env = tf_py_environment.TFPyEnvironment(tf_env)

    if run_eval:
      eval_py_env = env_load_fn(env_name)
      eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    print('obs spec:', observation_spec)
    print('action spec:', action_spec)

    if online_critic:
      # Counter handed to the agent; it is read and reset after every collect
      # call in the training loop below.
      resample_metric = tf_py_metric.TFPyMetric(
          py_metrics.CounterMetric('unsafe_ac_samples'))
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          train_step_counter=global_step,
          resample_metric=resample_metric)
    else:
      tf_agent = agent_class(
          time_step_spec, action_spec, train_step_counter=global_step)

    tf_agent.initialize()

    # Make the replay buffer.
    collect_data_spec = tf_agent.collect_data_spec

    logging.info('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, max_length=1000000)
    logging.info('RB capacity: %i', replay_buffer.capacity)
    logging.info('ReplayBuffer Collect data spec: %s', collect_data_spec)

    agent_observers = [replay_buffer.add_batch]
    if online_critic:
      # Second, smaller buffer that feeds the safety critic; it is cleared
      # after every critic train step.
      online_replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec, max_length=10000)

      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir,
          max_to_keep=1,
          replay_buffer=online_replay_buffer)

      clear_rb = common.function(online_replay_buffer.clear)
      agent_observers.append(online_replay_buffer.add_batch)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if not online_critic:
      eval_policy = tf_agent.policy
    else:
      eval_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        time_step_spec, action_spec)
    if not online_critic:
      collect_policy = tf_agent.collect_policy
    else:
      collect_policy = tf_agent._safe_policy  # pylint: disable=protected-access

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'safety_critic'),
        safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
        global_step=global_step)
    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_pi_ckpt(load_train_dir, tf_agent)  # loads tf_agent

    if load_root_dir is None:
      train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()
    safety_critic_checkpointer.initialize_or_restore()

    collect_driver = collect_driver_class(
        tf_env, collect_policy, observers=agent_observers + train_metrics)

    collect_driver.run = common.function(collect_driver.run)
    tf_agent.train = common.function(tf_agent.train)

    if not rb_checkpointer.checkpoint_exists:
      logging.info('Performing initial collection ...')
      common.function(
          initial_collect_driver_class(
              tf_env,
              initial_collect_policy,
              observers=agent_observers + train_metrics).run)()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      tf.print(
          replay_buffer._get_rows_for_id(last_id),  # pylint: disable=protected-access
          output_stream=logging.info)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='Metrics',
      )
      if eval_metrics_callback is not None:
        eval_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)
      if FLAGS.viz_pm:
        eval_fig_dir = os.path.join(eval_dir, 'figs')
        if not tf.io.gfile.isdir(eval_fig_dir):
          tf.io.gfile.makedirs(eval_fig_dir)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3, num_steps=2).prefetch(3)
    iterator = iter(dataset)
    if online_critic:
      online_dataset = online_replay_buffer.as_dataset(
          num_parallel_calls=3, num_steps=2).prefetch(3)
      online_iterator = iter(online_dataset)

      @common.function
      def critic_train_step():
        """Builds critic training step."""
        experience, buf_info = next(online_iterator)
        if env_name in [
            'IndianWell', 'IndianWell2', 'IndianWell3', 'DrunkSpider',
            'DrunkSpiderShort'
        ]:
          # These environments carry the safety reward in the observation.
          safe_rew = experience.observation['task_agn_rew']
        else:
          safe_rew = agents.process_replay_buffer(
              online_replay_buffer, as_tensor=True)
          safe_rew = tf.gather(safe_rew, tf.squeeze(buf_info.ids), axis=1)
        ret = tf_agent.train_sc(experience, safe_rew)
        clear_rb()
        return ret

    @common.function
    def train_step():
      """Trains the agent on the next sample from the main replay buffer."""
      experience, _ = next(iterator)
      ret = tf_agent.train(experience)
      return ret

    if not early_termination_fn:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    # Stays None if the loop below never runs (e.g. the restored global step
    # is already past `num_global_steps`), so the final return cannot hit an
    # unbound local.
    total_loss = None
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')
    if online_critic:
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_samples')
      resample_metric.reset()

    while (global_step.numpy() <= num_global_steps and
           not early_termination_fn()):
      # Collect and train.
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      if online_critic:
        mean_resample_ac(resample_metric.result())
        resample_metric.reset()
        if time_step.is_last():
          # Report the running mean once per episode, then start over.
          resample_ac_freq = mean_resample_ac.result()
          mean_resample_ac.reset_states()
          tf.compat.v2.summary.scalar(
              name='unsafe_ac_samples', data=resample_ac_freq, step=global_step)

      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      if online_critic:
        if global_step.numpy() % train_sc_interval == 0:
          for _ in range(train_sc_steps):
            critic_train_step()

      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()
      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      if global_step.numpy() % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step.numpy(), total_loss)
        steps_per_sec = (global_step.numpy() - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step.numpy()
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      global_step_val = global_step.numpy()
      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)
        safety_critic_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)
        rb_checkpointer.save(global_step=global_step_val)

      if run_eval and global_step.numpy() % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)
        if FLAGS.viz_pm:
          savepath = 'step{}.png'.format(global_step_val)
          savepath = os.path.join(eval_fig_dir, savepath)
          misc.record_episode_vis_summary(eval_tf_env, eval_policy, savepath)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
        total_loss, global_step.numpy()))

  return total_loss
Пример #30
0
def train_eval(
        root_dir,
        env_name='gym_hand_sim:MplThumbGraspTrain-v0',
        op_env_name='gym_hand_sim:MplThumbGraspOp-v0',
        env_load_fn=suite_mujoco.load,
        random_seed=None,
        # TODO(b/127576522): rename to policy_fc_layers.
        actor_fc_layers=(512, 256),
        value_fc_layers=(512, 256),
        use_rnns=True,
        # Params for collect
        num_environment_steps=2500000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=20,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
    num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=500,
        policy_checkpoint_interval=500,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=False):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        eval_py_env = env_load_fn(env_name)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)
        op_py_env = env_load_fn(op_env_name)
        op_tf_env = tf_py_environment.TFPyEnvironment(op_py_env)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))

        #tf_env = wrappers.ActionDiscretizeWrapper(tf_env, num_actions=11)
        #eval_tf_env = wrappers.ActionDiscretizeWrapper(eval_tf_env, num_actions=11)

        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                             epsilon=1e-5)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                preprocessing_layers={
                    'observation':
                    tf.keras.layers.experimental.preprocessing.Normalization()
                },
                input_fc_layer_params=actor_fc_layers,
                lstm_size=(256, ),
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                preprocessing_layers={
                    'observation':
                    tf.keras.layers.experimental.preprocessing.Normalization()
                },
                input_fc_layer_params=value_fc_layers,
                lstm_size=(256, ),
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers,
                activation_fn=tf.keras.activations.tanh)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(),
                fc_layer_params=value_fc_layers,
                activation_fn=tf.keras.activations.tanh)

        tf_agent = ppo_clip_agent.PPOClipAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            entropy_regularization=0.01,
            importance_ratio_clipping=0.03,
            normalize_observations=False,
            normalize_rewards=False,
            use_gae=True,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(
                batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)

        train_checkpointer.initialize_or_restore()

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            # TODO(b/123828980): Enable once the cause for slowdown was identified.
            collect_driver.run = common.function(collect_driver.run,
                                                 autograph=False)
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()
            if global_step_val % eval_interval == 0:
                metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )

            start_time = time.time()
            collect_driver.run()
            collect_time += time.time() - start_time

            start_time = time.time()
            total_loss, _ = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             total_loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )

        print("done")
        time.sleep(10)
        print("starting teleop")

        for _ in range(num_eval_episodes):
            time_step = eval_tf_env.reset()
            policy_state = eval_policy.get_initial_state(
                eval_tf_env.batch_size)
            while True:
                action_step = eval_policy.action(time_step, policy_state)
                time_step = eval_tf_env.step(action_step.action)
                eval_py_env.render()