Example #1
def _make_env():
    # Function to create a TF environment.
    return tf_py_environment.TFPyEnvironment(
        suite_gym.load("MountainCarContinuous-v0"))
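
A minimal usage sketch for the helper above (the two environment instances and the print calls are illustrative additions, not part of the original snippet):

# Sketch: call _make_env() once per environment instance needed.
train_env = _make_env()
eval_env = _make_env()
print(train_env.time_step_spec())
print(train_env.action_spec())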
Example #2
train_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
eval_py_env = gym_wrapper.GymWrapper(
      ChangeRewardMountainCarEnv(),
      discount=1,
      spec_dtype_map=None,
      auto_reset=True,
      render_kwargs=None,
  )
train_py_env = wrappers.TimeLimit(train_py_env, duration=200)
eval_py_env = wrappers.TimeLimit(eval_py_env, duration=200)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

RL_train(train_env, eval_env, fc_layer_params=(48, 64), name='_train')

"""Set num_iterations to 50000+ will let agent converge to less than 110 steps"""

iterations = range(len(returns))
plt.figure()
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')

iterations = range(len(steps))
plt.figure()
plt.plot(iterations, steps)
plt.ylabel('Average Step')
plt.xlabel('Iterations')
plt.show()
Example #3
def main(unused_argv):
  tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

  with tf.device('/CPU:0'):  # due to b/128333994
    if FLAGS.normalize_reward_fns:
      action_reward_fns = (
          environment_utilities.normalized_sliding_linear_reward_fn_generator(
              CONTEXT_DIM, NUM_ACTIONS, REWARD_NOISE_VARIANCE))
    else:
      action_reward_fns = (
          environment_utilities.sliding_linear_reward_fn_generator(
              CONTEXT_DIM, NUM_ACTIONS, REWARD_NOISE_VARIANCE))

    env = sspe.StationaryStochasticPyEnvironment(
        functools.partial(
            environment_utilities.context_sampling_fn,
            batch_size=BATCH_SIZE,
            context_dim=CONTEXT_DIM),
        action_reward_fns,
        batch_size=BATCH_SIZE)
    environment = tf_py_environment.TFPyEnvironment(env)

    optimal_reward_fn = functools.partial(
        environment_utilities.tf_compute_optimal_reward,
        per_action_reward_fns=action_reward_fns)

    optimal_action_fn = functools.partial(
        environment_utilities.tf_compute_optimal_action,
        per_action_reward_fns=action_reward_fns)

    network = q_network.QNetwork(
        input_tensor_spec=environment.time_step_spec().observation,
        action_spec=environment.action_spec(),
        fc_layer_params=LAYERS)

    if FLAGS.agent == 'LinUCB':
      agent = lin_ucb_agent.LinearUCBAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          alpha=AGENT_ALPHA,
          dtype=tf.float32)
    elif FLAGS.agent == 'LinTS':
      agent = lin_ts_agent.LinearThompsonSamplingAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          alpha=AGENT_ALPHA,
          dtype=tf.float32)
    elif FLAGS.agent == 'epsGreedy':
      agent = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          reward_network=network,
          optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
          epsilon=EPSILON)
    elif FLAGS.agent == 'Mix':
      emit_policy_info = policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN
      agent_linucb = lin_ucb_agent.LinearUCBAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          emit_policy_info=emit_policy_info,
          alpha=AGENT_ALPHA,
          dtype=tf.float32)
      agent_lints = lin_ts_agent.LinearThompsonSamplingAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          emit_policy_info=emit_policy_info,
          alpha=AGENT_ALPHA,
          dtype=tf.float32)
      agent_epsgreedy = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
          time_step_spec=environment.time_step_spec(),
          action_spec=environment.action_spec(),
          reward_network=network,
          optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
          emit_policy_info=emit_policy_info,
          epsilon=EPSILON)
      agent = exp3_mixture_agent.Exp3MixtureAgent(
          (agent_linucb, agent_lints, agent_epsgreedy))

    regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
    suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
        optimal_action_fn)

    trainer.train(
        root_dir=FLAGS.root_dir,
        agent=agent,
        environment=environment,
        training_loops=TRAINING_LOOPS,
        steps_per_loop=STEPS_PER_LOOP,
        additional_metrics=[regret_metric, suboptimal_arms_metric])
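
An example invocation, assuming the file above is run as an absl app (the script name is hypothetical; only the flag names are taken from the code):

# Hypothetical invocation; flag names appear in the code above.
#   python stationary_linear_bandit_train_eval.py --root_dir=/tmp/bandit --agent=LinUCB
#   python stationary_linear_bandit_train_eval.py --root_dir=/tmp/bandit --agent=Mix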
Example #4
def main():
    env = SquigglesEnvironment(num_notes=2)
    env = tf_py_environment.TFPyEnvironment(env)

    N = env.observation_spec().shape[0]

    _, the_hits, actions = get_beats(N, ITER, env, policy_saved_filename)

    fpsClock = pygame.time.Clock()
    pygame.init()

    DISPLAY = pygame.display.set_mode((WIDTH, HEIGHT))
    pygame.display.set_caption("Squigs")
    """ Here's different sounds to use
    ,
    "sound_effects/19827__cabled-mess__glockenspiel/348882__cabled-mess__glockenspiel-18-g3-04.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348889__cabled-mess__glockenspiel-23-a3-05.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348895__cabled-mess__glockenspiel-24-bb3-01.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348904__cabled-mess__glockenspiel-29-b3-02.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348914__cabled-mess__glockenspiel-39-d4-04.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348918__cabled-mess__glockenspiel-40-e4-01.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348921__cabled-mess__glockenspiel-43-f4-01.wav"
    "sound_effects/19827__cabled-mess__glockenspiel/348870__cabled-mess__glockenspiel-04-d3-04.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348871__cabled-mess__glockenspiel-06-e3-01.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348878__cabled-mess__glockenspiel-11-f3-02.wav",
    "sound_effects/19827__cabled-mess__glockenspiel/348908__cabled-mess__glockenspiel-33-c4-02.wav"
    """
    """
    "sound_effects/9008__jamieblam__metallophone/146077__jamieblam__1d-hard.wav"
    "sound_effects/9008__jamieblam__metallophone/146079__jamieblam__1c-hard.wav"
    """
    """
    "sound_effects/21030__samulis__vsco-2-ce-percussion-marimba/373577__samulis__marimba-b3-marimba-hit-outrigger-b2-loud-01.wav"
    "sound_effects/21030__samulis__vsco-2-ce-percussion-marimba/373582__samulis__marimba-e-2-marimba-hit-outrigger-f1-loud-01.wav"
    """
    """
    "sound_effects/9008__jamieblam__metallophone/146096__jamieblam__2e-hard.wav",
    ,
    "sound_effects/9008__jamieblam__metallophone/146100__jamieblam__2f-hard.wav"
    ,
    ,
    ,
    "sound_effects/9008__jamieblam__metallophone/146082__jamieblam__1f-hard.wav",

    ,
    "sound_effects/9008__jamieblam__metallophone/146091__jamieblam__2c-hard.wav",
    "sound_effects/9008__jamieblam__metallophone/146093__jamieblam__2b-hard.wav"
    """

    env_slider = SoundSlider(
        sound_list=the_hits,
        position_x=0,
        position_y=HEIGHT // 3,
        height=HEIGHT // 7,
        width=WIDTH,
        color=(100, 100, 255),
        soundfile_name=
        [  #"sound_effects/9008__jamieblam__metallophone/146079__jamieblam__1c-hard.wav"#"sound_effects/9008__jamieblam__metallophone/146097__jamieblam__2d-hard.wav"
            "sound_effects/drum11.wav"
        ])
    agent_slider = SoundSlider(
        sound_list=actions,
        position_x=0,
        position_y=HEIGHT * 2 // 3,
        height=HEIGHT // 7,
        width=WIDTH,
        color=(255, 150, 30),
        soundfile_name=
        [  #"sound_effects/9008__jamieblam__metallophone/146084__jamieblam__1e-hard.wav" #"sound_effects/9008__jamieblam__metallophone/146087__jamieblam__1g-hard.wav"
            "sound_effects/first_clap.wav"
        ])

    barrier = SoundBarrier(position_x=WIDTH * 2 // 3,
                           position_y=HEIGHT // 4,
                           height=HEIGHT * 5 // 8,
                           width=WIDTH // 56,
                           color=(255, 100, 100),
                           slider_list=[env_slider, agent_slider])

    start = False
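    # Main loop: pressing SPACE starts slider playback; the window-close event quits.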
    while True:
        DISPLAY.fill((0, 0, 0))
        pygame.event.pump()

        for event in pygame.event.get():
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    start = True
            if event.type == QUIT:
                pygame.quit()
                sys.exit()
        if start:
            env_slider.update()
            agent_slider.update()
            barrier.update()

        env_slider.render(DISPLAY)
        agent_slider.render(DISPLAY)
        barrier.render(DISPLAY)

        pygame.display.update()
        fpsClock.tick(FPS)
Example #5
def main(_):
  tf.random.set_seed(FLAGS.seed)

  if FLAGS.models_dir is None:
    raise ValueError('You must set a value for models_dir.')

  env = suite_mujoco.load(FLAGS.env_name)
  env.seed(FLAGS.seed)
  env = tf_py_environment.TFPyEnvironment(env)

  sac = actor_lib.Actor(env.observation_spec().shape[0], env.action_spec())

  model_filename = os.path.join(FLAGS.models_dir, 'DM-' + FLAGS.env_name,
                                str(FLAGS.model_seed), '1000000')
  sac.load_weights(model_filename)

  if FLAGS.std is None:
    if 'Reacher' in FLAGS.env_name:
      std = 0.5
    elif 'Ant' in FLAGS.env_name:
      std = 0.4
    elif 'Walker' in FLAGS.env_name:
      std = 2.0
    else:
      std = 0.75
  else:
    std = FLAGS.std

  def get_action(state):
    _, action, log_prob = sac(state, std)
    return action, log_prob

  dataset = dict(
      model_filename=model_filename,
      behavior_std=std,
      trajectories=dict(
          states=[],
          actions=[],
          log_probs=[],
          next_states=[],
          rewards=[],
          masks=[]))

  for i in range(FLAGS.num_episodes):
    timestep = env.reset()
    trajectory = dict(
        states=[],
        actions=[],
        log_probs=[],
        next_states=[],
        rewards=[],
        masks=[])

    while not timestep.is_last():
      action, log_prob = get_action(timestep.observation)
      next_timestep = env.step(action)

      trajectory['states'].append(timestep.observation)
      trajectory['actions'].append(action)
      trajectory['log_probs'].append(log_prob)
      trajectory['next_states'].append(next_timestep.observation)
      trajectory['rewards'].append(next_timestep.reward)
      trajectory['masks'].append(next_timestep.discount)

      timestep = next_timestep

    for k, v in trajectory.items():
      dataset['trajectories'][k].append(tf.concat(v, 0).numpy())

    logging.info('%d trajectories', i + 1)

  data_save_dir = os.path.join(FLAGS.save_dir, FLAGS.env_name,
                               str(FLAGS.model_seed))
  if not tf.io.gfile.isdir(data_save_dir):
    tf.io.gfile.makedirs(data_save_dir)

  save_filename = os.path.join(data_save_dir, f'dualdice_{FLAGS.std}.pckl')
  with tf.io.gfile.GFile(save_filename, 'wb') as f:
    pickle.dump(dataset, f)
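
A minimal sketch of reading the saved trajectories back, assuming save_filename still points at the pickle written above (nothing here is part of the original example):

# Sketch: reload the pickled dataset written above.
with tf.io.gfile.GFile(save_filename, 'rb') as f:
  data = pickle.load(f)
states = data['trajectories']['states']  # one [num_steps, obs_dim] array per episode
print(len(states), 'episodes; first episode length:', states[0].shape[0])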
Example #6
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    eval_env_name=None,
    env_load_fn=suite_mujoco.load,
    num_iterations=1000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
    gamma=0.99,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]
  eval_summary_flush_op = eval_summary_writer.flush()

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    # Create the environment.
    tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
    eval_env_name = eval_env_name or env_name
    eval_py_env = env_load_fn(eval_env_name)

    # Get the data specs from the environment
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=normal_projection_net)
    critic_net = critic_network.CriticNetwork(
        (observation_spec, action_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers)

    tf_agent = sac_agent.SacAgent(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=1,
        max_length=replay_buffer_capacity)
    replay_observer = [replay_buffer.add_batch]

    eval_py_policy = py_tf_policy.PyTFPolicy(
        greedy_policy.GreedyPolicy(tf_agent.policy))

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
        tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
    ]

    collect_policy = tf_agent.collect_policy
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())

    initial_collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps).run()

    collect_op = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration).run()

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]
    dataset = replay_buffer.as_dataset(
        sample_batch_size=5 * batch_size,
        num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(
                batch_size * 5)
    dataset_iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    trajectories, unused_info = dataset_iterator.get_next()
    train_op = tf_agent.train(trajectories)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    with tf.compat.v1.Session() as sess:
      # Initialize graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)

      # Initialize training.
      sess.run(dataset_iterator.initializer)
      common.initialize_uninitialized_variables(sess)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      global_step_val = sess.run(global_step)

      if global_step_val == 0:
        # Initial eval of randomly initialized policy
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            eval_py_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            callback=eval_metrics_callback,
            log=True,
        )
        sess.run(eval_summary_flush_op)

        # Run initial collect.
        logging.info('Global step %d: Running initial collect op.',
                     global_step_val)
        sess.run(initial_collect_op)

        # Checkpoint the initial replay buffer contents.
        rb_checkpointer.save(global_step=global_step_val)

        logging.info('Finished initial collect.')
      else:
        logging.info('Global step %d: Skipping initial collect op.',
                     global_step_val)

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      global_step_call = sess.make_callable(global_step)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        for _ in range(train_steps_per_iteration):
          total_loss, _ = train_step_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()
        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
              log=True,
          )
          sess.run(eval_summary_flush_op)

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)
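
A hedged sketch of calling this function directly (the root_dir value and iteration count are assumptions):

# Sketch: direct invocation with assumed arguments.
train_eval('/tmp/sac_halfcheetah', env_name='HalfCheetah-v2', num_iterations=100000)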
Example #7
  def testPyenv(self):
    py_env = PYEnvironmentMock()
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    self.assertIsInstance(tf_env.pyenv,
                          batched_py_environment.BatchedPyEnvironment)
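
The same property can be checked outside a test; a minimal sketch, assuming a Gym environment loaded through suite_gym (not part of the test above):

# Sketch: a single PyEnvironment is wrapped in a batch-1 BatchedPyEnvironment.
tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
print(type(tf_env.pyenv))   # batched_py_environment.BatchedPyEnvironment
print(tf_env.batch_size)    # 1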
Example #8
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        eval_env_name=None,
        num_iterations=1000000,
        # Params for networks.
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        num_parallel_environments=1,
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        train_sequence_length=20,
        critic_learning_rate=3e-4,
        actor_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=_DEFAULT_REWARD_SCALE,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for RNN SAC on DM control."""
    root_dir = os.path.expanduser(root_dir)

    if reward_scale_factor == _DEFAULT_REWARD_SCALE:
        # Use value recommended by https://arxiv.org/abs/1801.01290
        if env_name.startswith('Humanoid'):
            reward_scale_factor = 20.0
        else:
            reward_scale_factor = 5.0

    summary_writer = tf.compat.v2.summary.create_file_writer(
        root_dir, flush_millis=summaries_flush_secs * 1000)
    summary_writer.set_as_default()

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []

        env_load_fn = functools.partial(suite_dm_control.load,
                                        task_name=task_name,
                                        env_wrappers=env_wrappers)

        if num_parallel_environments == 1:
            py_env = env_load_fn(env_name)
        else:
            py_env = parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_env_name = eval_env_name or env_name
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
            observation_spec,
            action_spec,
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers,
            continuous_projection_net=normal_projection_net)

        critic_net = critic_rnn_network.CriticRnnNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size * num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        env_steps = tf_metrics.EnvironmentSteps(prefix='Train')
        average_return = tf_metrics.AverageReturnMetric(
            prefix='Train',
            buffer_size=num_eval_episodes,
            batch_size=tf_env.batch_size)
        train_metrics = [
            tf_metrics.NumberOfEpisodes(prefix='Train'),
            env_steps,
            average_return,
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='Train',
                buffer_size=num_eval_episodes,
                batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(root_dir, 'train'),
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            root_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        if env_steps.result() == 0 or replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d '
                'episodes with a random policy.', initial_collect_episodes)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=env_steps.result(),
            summary_writer=summary_writer,
            summary_prefix='Eval',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, env_steps.result())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        time_acc = 0
        env_steps_before = env_steps.result().numpy()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            start_env_steps = env_steps.result()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            episode_steps = env_steps.result() - start_env_steps
            for _ in range(episode_steps):
                for _ in range(train_steps_per_iteration):
                    train_step()
                time_acc += time.time() - start_time

                if global_step.numpy() % log_interval == 0:
                    logging.info('env steps = %d, average return = %f',
                                 env_steps.result(), average_return.result())
                    env_steps_per_sec = (env_steps.result().numpy() -
                                         env_steps_before) / time_acc
                    logging.info('%.3f env steps/sec', env_steps_per_sec)
                    tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                                data=env_steps_per_sec,
                                                step=env_steps.result())
                    time_acc = 0
                    env_steps_before = env_steps.result().numpy()

                for train_metric in train_metrics:
                    train_metric.tf_summaries(train_step=env_steps.result())

                if global_step.numpy() % eval_interval == 0:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=env_steps.result(),
                        summary_writer=summary_writer,
                        summary_prefix='Eval',
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, env_steps.result())
                    metric_utils.log_metrics(eval_metrics)

                global_step_val = global_step.numpy()
                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)
Example #9
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=20000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        agent_class=dqn_agent.DqnAgent,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_py_env = suite_gym.load(env_name)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = agent_class(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate),
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec(),
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy())

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        global_step = tf.train.get_or_create_global_step()

        replay_observer = [replay_buffer.add_batch]
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        initial_collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps).run()

        collect_policy = tf_agent.collect_policy()
        collect_op = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration).run()

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)

        iterator = dataset.make_initializable_iterator()
        trajectories, _ = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories,
                                  train_step_counter=global_step)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=tf.contrib.checkpoint.List(train_metrics))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy(),
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])
        summary_op = tf.contrib.summary.all_summary_ops()

        with eval_summary_writer.as_default(), \
             tf.contrib.summary.always_record_summaries():
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            tf.contrib.summary.initialize(session=sess)
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(
                [train_op, summary_op, global_step])

            timed_at_step = sess.run(global_step)
            collect_time = 0
            train_time = 0
            steps_per_second_ph = tf.placeholder(tf.float32,
                                                 shape=(),
                                                 name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                # Train/collect/eval.
                start_time = time.time()
                collect_call()
                collect_time += time.time() - start_time
                start_time = time.time()
                for _ in range(train_steps_per_iteration):
                    loss_info_value, _, global_step_val = train_step_call()
                train_time += time.time() - start_time

                if global_step_val % log_interval == 0:
                    tf.logging.info('step = %d, loss = %f', global_step_val,
                                    loss_info_value.loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    tf.logging.info('%.3f steps/sec' % steps_per_sec)
                    tf.logging.info(
                        'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                    )
Example #10
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_whitelist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_steps=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        ou_stddev=0.2,
        ou_damping=0.15,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        dqda_clipping=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries, and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=10000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        eval_metrics_callback=None):
    """A simple train and eval for DDPG."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_whitelist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_whitelist=[observations_whitelist])
            ]
        else:
            env_wrappers = []
        environment = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)
        tf_env = tf_py_environment.TFPyEnvironment(environment)
        eval_py_env = suite_dm_control.load(env_name,
                                            task_name,
                                            env_wrappers=env_wrappers)

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            ou_stddev=ou_stddev,
            ou_damping=ou_damping,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            dqda_clipping=dqda_clipping,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            debug_summaries=debug_summaries,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # TODO(oars): Refactor drivers to better handle policy states. Remove the
        # policy reset and passing down an empty policy state to the driver.
        collect_policy = tf_agent.collect_policy
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        initial_collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_steps).run(policy_state=policy_state)

        policy_state = collect_policy.get_initial_state(tf_env.batch_size)
        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration).run(
                policy_state=policy_state)

        # Need extra step to generate transitions of train_sequence_length.
        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=train_sequence_length +
                                           1).prefetch(3)

        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        trajectories, unused_info = iterator.get_next()
        train_op = tf_agent.train(experience=trajectories)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries(step_metrics=train_metrics[:2])

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries()

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session() as sess:
            # Initialize the graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            sess.run(iterator.initializer)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())
            sess.run(initial_collect_op)

            global_step_val = sess.run(global_step)
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )

            collect_call = sess.make_callable(collect_op)
            train_step_call = sess.make_callable(train_op)
            global_step_call = sess.make_callable(global_step)

            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            for _ in range(num_iterations):
                start_time = time.time()
                collect_call()
                for _ in range(train_steps_per_iteration):
                    loss_info_value = train_step_call()
                time_acc += time.time() - start_time
                global_step_val = global_step_call()

                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 loss_info_value.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    timed_at_step = global_step_val
                    time_acc = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
Example #11
print(env.action_spec())

time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# AGENT

fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1


# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        # Kernel initializer as in the TF-Agents DQN tutorial (assumed completion).
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0, mode='fan_in', distribution='truncated_normal'))
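
The helper above is typically combined into a full Q-network; a minimal sketch of that assembly in the style of the TF-Agents DQN tutorial (the sequential import and the output-layer initializer values are assumptions, not shown in this snippet):

# Sketch: stack dense_layer() outputs plus a linear Q-value head.
# Assumes: from tf_agents.networks import sequential
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])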
Example #12
def main():
    parser = argparse.ArgumentParser()

    ## Essential parameters
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model stats and checkpoints will be written."
    )
    parser.add_argument("--env",
                        default=None,
                        type=str,
                        required=True,
                        help="The environment to train the agent on")
    parser.add_argument("--approx_env_boundaries",
                        default=False,
                        type=bool,
                        help="Whether to get the env boundaries approximately")
    parser.add_argument("--max_horizon", default=5, type=int)
    parser.add_argument("--atari",
                        default=False,
                        type=bool,
                        help="Gets some data Types correctly")

    ##agent parameters
    parser.add_argument("--reward_scale_factor", default=1.0, type=float)
    parser.add_argument("--debug_summaries", default=True, type=bool)
    parser.add_argument("--summarize_grads_and_vars", default=True, type=bool)

    ##transformer parameters
    parser.add_argument("--d_model", default=64, type=int)
    parser.add_argument("--num_layers", default=3, type=int)
    parser.add_argument("--dff", default=256, type=int)

    ##Training parameters
    parser.add_argument('--num_iterations',
                        type=int,
                        default=150000,
                        help="steps in the env")
    parser.add_argument('--num_iparallel',
                        type=int,
                        default=1,
                        help="how many envs should run in parallel")
    parser.add_argument("--collect_steps_per_iteration", default=1, type=int)
    parser.add_argument("--train_steps_per_iteration", default=1, type=int)

    ## Other parameters
    parser.add_argument("--num_eval_episodes", default=10, type=int)
    parser.add_argument("--eval_interval", default=1000, type=int)
    parser.add_argument("--log_interval", default=1000, type=int)
    parser.add_argument("--summary_interval", default=10, type=int)
    parser.add_argument("--run_graph_mode", default=True, type=bool)
    parser.add_argument("--checkpoint_interval", default=10000, type=int)
    parser.add_argument("--summary_flush", default=10,
                        type=int)  #what does this exactly do?

    # HP opt params
    parser.add_argument("--doubleQ",
                        default=True,
                        type=bool,
                        help="Whether to use a  DoubleQ agent")
    parser.add_argument("--custom_last_layer", default=True, type=bool)
    parser.add_argument("--custom_layer_init", default=0.5, type=float)
    parser.add_argument("--initial_collect_steps", default=1000, type=int)
    parser.add_argument("--loss_function",
                        default="element_wise_huber_loss",
                        type=str)
    parser.add_argument("--num_heads", default=4, type=int)
    parser.add_argument("--normalize_env", default=False, type=bool)
    parser.add_argument('--custom_lr_schedule',
                        default="No",
                        type=str,
                        help="whether to use a custom LR schedule")
    parser.add_argument("--epsilon_greedy", default=0.1, type=float)
    parser.add_argument("--target_update_period", default=1, type=int)
    parser.add_argument(
        "--rate", default=0.1, type=float
    )  # Dropout rate (may be unused, depending on the Q-network). Setting it to
    # 0.0 currently breaks the code; select a network without dropout instead.
    parser.add_argument("--gradient_clipping", default=None, type=bool)
    parser.add_argument("--replay_buffer_max_length", default=100000, type=int)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--learning_rate", default=1e-5, type=float)
    parser.add_argument("--encoder_type",
                        default=2,
                        type=int,
                        help="Which Type of encoder is used for the model")
    parser.add_argument("--layer_type",
                        default=1,
                        type=int,
                        help="Which Type of layer is used for the encoder")
    parser.add_argument("--target_update_tau", default=1, type=float)
    parser.add_argument("--gamma", default=1.0, type=float)

    args = parser.parse_args()
    # Global training step counter shared by the agent and the checkpoints.
    global_step = tf.compat.v1.train.get_or_create_global_step()

    baseEnv = gym.make(args.env)
    env = suite_gym.load(args.env)
    eval_env = suite_gym.load(args.env)
    if args.normalize_env:
        env = NormalizeWrapper(env, args.approx_env_boundaries, args.env)
        eval_env = NormalizeWrapper(eval_env, args.approx_env_boundaries,
                                    args.env)
    env = PyhistoryWrapper(env, args.max_horizon, args.atari)
    eval_env = PyhistoryWrapper(eval_env, args.max_horizon, args.atari)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_env)

    q_net = QTransformer(tf_env.observation_spec(),
                         baseEnv.action_space.n,
                         num_layers=args.num_layers,
                         d_model=args.d_model,
                         num_heads=args.num_heads,
                         dff=args.dff,
                         rate=args.rate,
                         encoderType=args.encoder_type,
                         enc_layer_type=args.layer_type,
                         max_horizon=args.max_horizon,
                         custom_layer=args.custom_layer_init,
                         custom_last_layer=args.custom_last_layer)

    if args.custom_lr_schedule == "Transformer":  # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(args.d_model,
                                       int(args.num_iterations / 10))
        optimizer = tf.keras.optimizers.Adam(learning_rate,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    elif args.custom_lr_schedule == "Transformer_low":  # builds a lr schedule according to the original usage for the transformer
        learning_rate = CustomSchedule(
            int(args.d_model / 2),
            int(args.num_iterations /
                10))  # --> same schedule with lower general lr
        optimizer = tf.keras.optimizers.Adam(learning_rate,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    elif args.custom_lr_schedule == "Linear":
        lrs = LinearCustomSchedule(args.learning_rate, args.num_iterations)
        optimizer = tf.keras.optimizers.Adam(lrs,
                                             beta_1=0.9,
                                             beta_2=0.98,
                                             epsilon=1e-9)

    else:
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=args.learning_rate)

    if args.loss_function == "element_wise_huber_loss":
        lf = element_wise_huber_loss
    elif args.loss_function == "element_wise_squared_loss":
        lf = element_wise_squared_loss

    if not args.doubleQ:
        agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            target_update_period=args.target_update_period,
            td_errors_loss_fn=lf,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    else:
        agent = dqn_agent.DdqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            epsilon_greedy=args.epsilon_greedy,
            target_update_tau=args.target_update_tau,
            td_errors_loss_fn=lf,
            target_update_period=args.target_update_period,
            optimizer=optimizer,
            gamma=args.gamma,
            reward_scale_factor=args.reward_scale_factor,
            gradient_clipping=args.gradient_clipping,
            debug_summaries=args.debug_summaries,
            summarize_grads_and_vars=args.summarize_grads_and_vars,
            train_step_counter=global_step)
    agent.initialize()

    count_weights(q_net)

    train_eval(root_dir=args.output_dir,
               tf_env=tf_env,
               eval_tf_env=eval_tf_env,
               agent=agent,
               num_iterations=args.num_iterations,
               initial_collect_steps=args.initial_collect_steps,
               collect_steps_per_iteration=args.collect_steps_per_iteration,
               replay_buffer_capacity=args.replay_buffer_max_length,
               train_steps_per_iteration=args.train_steps_per_iteration,
               batch_size=args.batch_size,
               use_tf_functions=args.run_graph_mode,
               num_eval_episodes=args.num_eval_episodes,
               eval_interval=args.eval_interval,
               train_checkpoint_interval=args.checkpoint_interval,
               policy_checkpoint_interval=args.checkpoint_interval,
               rb_checkpoint_interval=args.checkpoint_interval,
               log_interval=args.log_interval,
               summary_interval=args.summary_interval,
               summaries_flush_secs=args.summary_flush)

    pickle.dump(args, open(args.output_dir + "/training_args.p", "wb"))
    print("Successfully trained and evaluation.")
Example #13
def train_eval(
        root_dir,
        env_name='MultiGrid-Empty-5x5-v0',
        env_load_fn=multiagent_gym_suite.load,
        random_seed=0,
        # Architecture params
        actor_fc_layers=(64, 64),
        value_fc_layers=(64, 64),
        lstm_size=(64, ),
        conv_filters=64,
        conv_kernel=3,
        direction_fc=5,
        entropy_regularization=0.,
        use_attention_networks=False,
        # Specialized agents
        inactive_agent_ids=tuple(),
        # Params for collect
        num_environment_steps=25000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=5,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=2,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=2,
        eval_interval=5,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=100,
        log_interval=10,
        summary_interval=10,
        summaries_flush_secs=1,
        use_tf_functions=True,
        debug_summaries=True,
        summarize_grads_and_vars=True,
        eval_metrics_callback=None,
        reinit_checkpoint_dir=None,
        debug=True):
    """A simple train and eval for PPO."""
    tf.compat.v1.enable_v2_behavior()

    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    if debug:
        logging.info('In debug mode, turning tf_functions off')
        use_tf_functions = False

    for a in inactive_agent_ids:
        logging.info('Fixing and not training agent %d', a)

    # Load multiagent gym environment and determine number of agents
    gym_env = env_load_fn(env_name)
    n_agents = gym_env.n_agents

    # Set up logging
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')
    saved_model_dir = os.path.join(root_dir, 'policy_saved_model')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        multiagent_metrics.AverageReturnMetric(n_agents,
                                               buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if random_seed is not None:
            tf.compat.v1.set_random_seed(random_seed)

        logging.info('Creating %d environments...', num_parallel_environments)
        wrappers = []
        if use_attention_networks:
            wrappers = [
                lambda env: utils.LSTMStateWrapper(env, lstm_size=lstm_size)
            ]

        eval_tf_env = tf_py_environment.TFPyEnvironment(
            env_load_fn(env_name,
                        gym_kwargs=dict(seed=random_seed),
                        gym_env_wrappers=wrappers))
        # pylint: disable=g-complex-comprehension
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                functools.partial(env_load_fn,
                                  environment_name=env_name,
                                  gym_env_wrappers=wrappers,
                                  gym_kwargs=dict(seed=random_seed * 1234 + i))
                for i in range(num_parallel_environments)
            ]))

        logging.info('Preparing to train...')
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]

        train_metrics = step_metrics + [
            multiagent_metrics.AverageReturnMetric(
                n_agents, batch_size=num_parallel_environments),
            tf_metrics.AverageEpisodeLengthMetric(
                batch_size=num_parallel_environments)
        ]

        logging.info('Creating agent...')
        tf_agent = multiagent_ppo.MultiagentPPO(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            n_agents=n_agents,
            learning_rate=learning_rate,
            actor_fc_layers=actor_fc_layers,
            value_fc_layers=value_fc_layers,
            lstm_size=lstm_size,
            conv_filters=conv_filters,
            conv_kernel=conv_kernel,
            direction_fc=direction_fc,
            entropy_regularization=entropy_regularization,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
            inactive_agent_ids=inactive_agent_ids,
            use_attention_networks=use_attention_networks)
        tf_agent.initialize()
        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        logging.info('Allocating replay buffer ...')
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        logging.info('RB capacity: %i', replay_buffer.capacity)

        # If reinit_checkpoint_dir is provided, the full agent is restored from that
        # checkpoint, the last agent is then re-created as an untrained novice, and
        # the remaining agents are kept fixed (non-learning).
        # Otherwise, all agents are restored from train_dir.
        if reinit_checkpoint_dir:
            reinit_checkpointer = common.Checkpointer(
                ckpt_dir=reinit_checkpoint_dir,
                agent=tf_agent,
            )
            reinit_checkpointer.initialize_or_restore()
            temp_dir = os.path.join(train_dir, 'tmp')
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir,
                agent=tf_agent.agents[:-1],
            )
            agent_checkpointer.save(global_step=0)
            tf_agent = multiagent_ppo.MultiagentPPO(
                tf_env.time_step_spec(),
                tf_env.action_spec(),
                n_agents=n_agents,
                learning_rate=learning_rate,
                actor_fc_layers=actor_fc_layers,
                value_fc_layers=value_fc_layers,
                lstm_size=lstm_size,
                conv_filters=conv_filters,
                conv_kernel=conv_kernel,
                direction_fc=direction_fc,
                entropy_regularization=entropy_regularization,
                num_epochs=num_epochs,
                debug_summaries=debug_summaries,
                summarize_grads_and_vars=summarize_grads_and_vars,
                train_step_counter=global_step,
                inactive_agent_ids=inactive_agent_ids,
                non_learning_agents=list(range(n_agents - 1)),
                use_attention_networks=use_attention_networks)
            agent_checkpointer = common.Checkpointer(
                ckpt_dir=temp_dir, agent=tf_agent.agents[:-1])
            agent_checkpointer.initialize_or_restore()
            tf.io.gfile.rmtree(temp_dir)
            eval_policy = tf_agent.policy
            collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=multiagent_metrics.MultiagentMetricsGroup(
                train_metrics, 'train_metrics'))
        if not reinit_checkpoint_dir:
            train_checkpointer.initialize_or_restore()
        logging.info('Successfully initialized train checkpointer')

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        saved_model = policy_saver.PolicySaver(eval_policy,
                                               train_step=global_step)
        logging.info('Successfully initialized policy saver.')

        print('Using TFDriver')
        if use_attention_networks:
            collect_driver = utils.StateTFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)
        else:
            collect_driver = tf_driver.TFDriver(
                tf_env,
                collect_policy,
                observers=[replay_buffer.add_batch] + train_metrics,
                max_episodes=collect_episodes_per_iteration,
                disable_tf_function=not use_tf_functions)

        def train_step():
            trajectories = replay_buffer.gather_all()
            return tf_agent.train(experience=trajectories)

        if use_tf_functions:
            tf_agent.train = common.function(tf_agent.train, autograph=False)
            train_step = common.function(train_step)

        collect_time = 0
        train_time = 0
        timed_at_step = global_step.numpy()

        # Number of consecutive steps for which the loss has diverged.
        loss_divergence_counter = 0

        # Save operative config as late as possible to include used configurables.
        if global_step.numpy() == 0:
            config_filename = os.path.join(
                train_dir,
                'operative_config-{}.gin'.format(global_step.numpy()))
            with tf.io.gfile.GFile(config_filename, 'wb') as f:
                f.write(gin.operative_config_str())

        total_episodes = 0
        logging.info('Commencing train loop!')
        while environment_steps_metric.result() < num_environment_steps:
            global_step_val = global_step.numpy()

            # Evaluation
            if global_step_val % eval_interval == 0:
                if debug:
                    logging.info('Performing evaluation at step %d',
                                 global_step_val)
                results = multiagent_metrics.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                    use_function=use_tf_functions,
                    use_attention_networks=use_attention_networks)
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                multiagent_metrics.log_metrics(eval_metrics)

            # Collect data
            if debug:
                logging.info('Collecting at step %d', global_step_val)
            start_time = time.time()
            time_step = tf_env.reset()
            policy_state = collect_policy.get_initial_state(tf_env.batch_size)
            if use_attention_networks:
                # Attention networks require previous policy state to compute attention
                # weights.
                time_step.observation['policy_state'] = (
                    policy_state['actor_network_state'][0],
                    policy_state['actor_network_state'][1])
            collect_driver.run(time_step, policy_state)
            collect_time += time.time() - start_time

            total_episodes += collect_episodes_per_iteration
            if debug:
                logging.info('Have collected a total of %d episodes',
                             total_episodes)

            # Train
            if debug:
                logging.info('Training at step %d', global_step_val)
            start_time = time.time()
            total_loss, extra_loss = train_step()
            replay_buffer.clear()
            train_time += time.time() - start_time

            # Check for exploding losses.
            if (math.isnan(total_loss) or math.isinf(total_loss)
                    or total_loss > MAX_LOSS):
                loss_divergence_counter += 1
                if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
                    logging.info(
                        'Loss diverged for too many timesteps, breaking...')
                    break
            else:
                loss_divergence_counter = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=step_metrics)

            if global_step_val % log_interval == 0:
                logging.info('step = %d, total loss = %f', global_step_val,
                             total_loss)
                for a in range(n_agents):
                    if not inactive_agent_ids or a not in inactive_agent_ids:
                        logging.info('Loss for agent %d = %f', a,
                                     extra_loss[a].loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info('collect_time = %.3f, train_time = %.3f',
                             collect_time, train_time)
                with tf.compat.v2.summary.record_if(True):
                    tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                                data=steps_per_sec,
                                                step=global_step)

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)
                    saved_model_path = os.path.join(
                        saved_model_dir,
                        'policy_' + ('%d' % global_step_val).zfill(9))
                    saved_model.save(saved_model_path)

                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

        # One final eval before exiting.
        results = multiagent_metrics.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
            use_function=use_tf_functions,
            use_attention_networks=use_attention_networks)
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        multiagent_metrics.log_metrics(eval_metrics)
Example #14
def train_level(level,
                consecutive_wins_flag=5,
                collect_random_steps=True,
                max_iterations=num_iterations):
    """
    Create a (double) DQN agent and train it on one level of the game.
    :param level: level of the game
    :param consecutive_wins_flag: number of consecutive wins in evaluation
    signifying the training is done
    :param collect_random_steps: whether to collect random steps at the beginning,
    always set to 'True' when the global step is 0.
    :param max_iterations: stop the training when it reaches the max iteration
    regardless of the result
    """
    global saving_time
    cells = query_level(level)
    size = len(cells)
    env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))
    eval_env = tf_py_environment.TFPyEnvironment(GameEnv(size, cells))

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    fc_layer_params = (neuron_num_mapper[size], )

    q_net = q_network.QNetwork(env.observation_spec()[0],
                               env.action_spec(),
                               fc_layer_params=fc_layer_params,
                               activation_fn=tf.keras.activations.relu)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DdqnAgent(
        env.time_step_spec(),
        env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=global_step,
        observation_and_action_constraint_splitter=GameEnv.
        obs_and_mask_splitter)
    agent.initialize()

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=env.batch_size,
        max_length=replay_buffer_max_length)

    # drivers
    collect_driver = dynamic_step_driver.DynamicStepDriver(
        env,
        policy=agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=collect_steps_per_iteration)

    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
        eval_env,
        policy=agent.policy,
        observers=eval_metrics,
        num_episodes=num_eval_episodes)

    # checkpointer of the replay buffer and policy
    train_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        dir_path, 'trained_policies/train_lv{0}'.format(level)),
                                             max_to_keep=1,
                                             agent=agent,
                                             policy=agent.policy,
                                             global_step=global_step,
                                             replay_buffer=replay_buffer)

    # policy saver
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)

    train_checkpointer.initialize_or_restore()

    # optimize by wrapping some of the code in a graph using TF function
    agent.train = common.function(agent.train)
    collect_driver.run = common.function(collect_driver.run)
    eval_driver.run = common.function(eval_driver.run)

    # collect initial replay data
    if collect_random_steps:
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec=env.time_step_spec(),
            action_spec=env.action_spec(),
            observation_and_action_constraint_splitter=GameEnv.
            obs_and_mask_splitter)

        dynamic_step_driver.DynamicStepDriver(
            env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch],
            num_steps=initial_collect_steps).run()

    # Dataset generates trajectories with shape [Bx2x...]
    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # train until `consecutive_wins_flag` consecutive evaluations have an average return above 10
    consecutive_eval_win = 0
    train_iterations = 0
    while consecutive_eval_win < consecutive_wins_flag and train_iterations < max_iterations:
        collect_driver.run()

        for _ in range(collect_steps_per_iteration):
            experience, _ = next(iterator)
            train_loss = agent.train(experience).loss

        # evaluate the training at intervals
        step = global_step.numpy()
        if step % eval_interval == 0:
            eval_driver.run()
            average_return = eval_metrics[0].result().numpy()
            average_len = eval_metrics[1].result().numpy()
            print("level: {0} step: {1} AverageReturn: {2} AverageLen: {3}".
                  format(level, step, average_return, average_len))

            # evaluate consecutive wins
            if average_return > 10:
                consecutive_eval_win += 1
            else:
                consecutive_eval_win = 0

        if step % save_interval == 0:
            start = time.time()
            train_checkpointer.save(global_step=step)
            saving_time += time.time() - start

        train_iterations += 1

    # save the policy
    train_checkpointer.save(global_step=global_step.numpy())
    tf_policy_saver.save(
        os.path.join(dir_path, 'trained_policies/policy_lv{0}'.format(level)))
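`GameEnv.obs_and_mask_splitter` is not shown in this example. In TF-Agents, an `observation_and_action_constraint_splitter` simply separates the network input from a legal-action mask; a hypothetical sketch, assuming the environment emits an `(observation, mask)` pair as the use of `env.observation_spec()[0]` above suggests:

def obs_and_mask_splitter(observation):
    # In the original code this is a @staticmethod on GameEnv. It splits the raw
    # observation into (network_observation, action_mask); the agent's policy
    # uses the mask to restrict action selection to legal moves.
    obs, mask = observation
    return obs, mask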
Example #15
def get_env_and_policy(load_dir,
                       env_name,
                       alpha,
                       env_seed=0,
                       tabular_obs=False):
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = taxi.get_taxi_policy(load_dir,
                                                           env,
                                                           alpha=alpha,
                                                           py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'low_rank':
        env = low_rank.LowRank()
        env.seed(env_seed)
        policy_fn, policy_info_spec = low_rank.get_low_rank_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'lowrank_tree':
        env = tree.Tree(branching=2, depth=3, duplicate=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name.startswith('bandit'):
        num_arms = int(env_name[6:]) if len(env_name) > 6 else 2
        env = bandit.Bandit(num_arms=num_arms)
        env.seed(env_seed)
        policy_fn, policy_info_spec = bandit.get_bandit_policy(
            env, epsilon_explore=1 - alpha, py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'small_tree':
        env = tree.Tree(branching=2, depth=3, loop=True)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'CartPole-v0':
        tf_env, policy = get_env_and_dqn_policy(
            env_name,
            os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
    elif env_name == 'cartpole':  # Infinite-horizon cartpole.
        tf_env, policy = get_env_and_dqn_policy(
            'CartPole-v0',
            os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
        env = InfiniteCartPole()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name == 'FrozenLake-v0':
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
    elif env_name == 'frozenlake':  # Infinite-horizon frozenlake.
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
        env = InfiniteFrozenLake()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name in ['Reacher-v2', 'reacher']:
        if env_name == 'Reacher-v2':
            env = suites.load_mujoco(env_name)
        else:
            env = gym_wrapper.GymWrapper(InfiniteReacher())
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, 'Reacher-v2', 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.4 - 0.3 * alpha,
                                emit_log_probability=True)
    elif env_name == 'HalfCheetah-v2':
        env = suites.load_mujoco(env_name)
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, env_name, 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.2 - 0.1 * alpha,
                                emit_log_probability=True)
    else:
        raise ValueError('Unrecognized environment %s.' % env_name)

    return tf_env, policy
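Each tabular branch above repeats the same GymWrapper → TFPyEnvironment → TFAgentsWrappedPolicy boilerplate. A small helper along these lines (purely illustrative; the name `_wrap_tabular_policy` is not part of the original module, and it reuses the snippet's own imports) would remove most of the duplication:

def _wrap_tabular_policy(env, policy_fn, policy_info_spec):
    # Wrap a seeded python environment and python policy into TF-Agents objects.
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
    return tf_env, policy

Each branch would then reduce to building and seeding its environment plus a single _wrap_tabular_policy(env, policy_fn, policy_info_spec) call.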
Example #16
def main(unused_argv):
    tf.compat.v1.enable_v2_behavior()  # The trainer only runs with V2 enabled.

    class LinearNormalReward(object):
        def __init__(self, theta):
            self.theta = theta

        def __call__(self, x):
            mu = np.dot(x, self.theta)
            return np.random.normal(mu, 1)

    def _global_context_sampling_fn():
        return np.random.randint(-10, 10, [4]).astype(np.float32)

    def _arm_context_sampling_fn():
        return np.random.randint(-2, 3, [5]).astype(np.float32)

    reward_fn = LinearNormalReward(HIDDEN_PARAM)

    env = sspe.StationaryStochasticPerArmPyEnvironment(
        _global_context_sampling_fn,
        _arm_context_sampling_fn,
        NUM_ACTIONS,
        reward_fn,
        batch_size=BATCH_SIZE)
    environment = tf_py_environment.TFPyEnvironment(env)

    obs_spec = environment.observation_spec()
    if FLAGS.network == 'commontower':
        network = (global_and_arm_feature_network.
                   create_feed_forward_common_tower_network(
                       obs_spec, (4, 3), (3, 4), (4, 2)))
    elif FLAGS.network == 'dotproduct':
        network = (global_and_arm_feature_network.
                   create_feed_forward_dot_product_network(
                       obs_spec, (4, 3, 6), (3, 4, 6)))
    if FLAGS.drop_arm_obs:

        def drop_arm_feature_fn(traj):
            transformed_traj = copy.deepcopy(traj)
            del transformed_traj.observation[
                bandit_spec_utils.PER_ARM_FEATURE_KEY]
            return transformed_traj
    else:
        drop_arm_feature_fn = None
    agent = neural_epsilon_greedy_agent.NeuralEpsilonGreedyAgent(
        time_step_spec=environment.time_step_spec(),
        action_spec=environment.action_spec(),
        reward_network=network,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
        epsilon=EPSILON,
        accepts_per_arm_features=True,
        training_data_spec_transformation_fn=drop_arm_feature_fn,
        emit_policy_info=policy_utilities.InfoFields.PREDICTED_REWARDS_MEAN)

    optimal_reward_fn = functools.partial(optimal_reward,
                                          hidden_param=HIDDEN_PARAM)
    optimal_action_fn = functools.partial(optimal_action,
                                          hidden_param=HIDDEN_PARAM)
    regret_metric = tf_bandit_metrics.RegretMetric(optimal_reward_fn)
    suboptimal_arms_metric = tf_bandit_metrics.SuboptimalArmsMetric(
        optimal_action_fn)

    trainer.train(root_dir=FLAGS.root_dir,
                  agent=agent,
                  environment=environment,
                  training_loops=TRAINING_LOOPS,
                  steps_per_loop=STEPS_PER_LOOP,
                  additional_metrics=[regret_metric, suboptimal_arms_metric],
                  training_data_spec_transformation_fn=drop_arm_feature_fn)
Example #17
def record_env():
    gif_path = "./images/test.gif"
    frames_path = "./images/episode-{i}-timestep-{t}.jpg"
    gating_bitmap = "./scenario-1/bitmaps/gating_mask.bmp"

    # pos_init = np.array([2.7, 2.0, 0.0])  # desired start state
    # pos_end_targ = np.array([-2.6, -1.5, 2.5])  # desired end state
    # state_init = np.array([2.7, 2.0, 0.0, 0.0, 0.0, 0.0])  # desired start state
    # state_end_targ = np.array([-2.6, -1.5, 2.5, 0.0, 0.0, 0.0])  # desired end state
    start_state = np.array(
        [2.7, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0])  # desired start state
    target_state = np.array(
        [-2.6, -1.5, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
         0.0])  # desired end state

    eval_py_env = Quadcopter3DEnv(start_state=start_state,
                                  target_state=target_state,
                                  gating_bitmap=gating_bitmap)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # eval_env = py_environment.PyEnvironment(eval_py_env)
    # eval_env = eval_py_env

    def policy(i):
        # return tf.reshape(
        #     tf.constant([-0.1, 0.0, 0.0, 0.0], dtype=float_type),
        #     [1, -1]
        #     # tf.constant([-0.1, 0.01, 0.01, 0.01], dtype=float_type), [1, -1]
        # )
        rng = np.random.default_rng(i)
        action = (tf_agents.specs.array_spec.sample_bounded_spec(
            eval_py_env.action_spec(), rng) * 1000)
        print(action)
        return tf.reshape(action, [1, -1])

    num_episodes = 3

    ts = []
    for i in range(num_episodes):
        t = 0
        print("Episode {i}".format(i=i))
        time_step = eval_env.reset()
        while not time_step.is_last():
            action = policy(i * t)
            time_step = eval_env.step(action)
            state = time_step.observation
            # fig, ax = eval_py_env.render(state)
            fig = eval_py_env.render(state)
            fig.savefig(frames_path.format(i=i, t=t))
            t = t + 1
        ts.append(t - 1)
        # writer.append_data(imageio.imread(frames_path.format(i=i, t=t)))
    # kwargs_write = {'fps':1.0, 'quantizer':'nq'}
    # imageio.mimsave('./powers.gif', [plot_for_offset(i/4, 100) for i in range(10)], fps=1)

    with imageio.get_writer(gif_path, mode="I", fps=8) as writer:
        for i, t in zip(range(num_episodes), ts):
            for t_ in range(t):
                writer.append_data(
                    imageio.imread(frames_path.format(i=i, t=t_)))
Example #18
def simulate():
    # Set up the environments for the agent to train and test its performance
    envTrain = ComputerSnake.Snake()
    envEval = ComputerSnake.Snake(persistence = True)

    # Convert and wrap in TFPyEnvironment training and evaluation environments
    train_env = tf_py_environment.TFPyEnvironment(envTrain)
    eval_env = tf_py_environment.TFPyEnvironment(envEval)

    # Set up q network with necessary parameters
    fc_layer_params = (100,)
    q_net = q_network.QNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=fc_layer_params
    )
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) # look up
    train_step_counter = tf.Variable(0)

    # Set up and initialize the DQN learning agent. It takes in the time_step spec,
    # action spec, the q network, the optimizer, a loss function, and train_step_counter
    agent = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer, # look up
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=train_step_counter
    )
    agent.initialize()

    # Set up policies the agent can use
    eval_policy = agent.policy
    collect_policy = agent.collect_policy

    # Policy which randomly selects actions for each step
    random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                    train_env.action_spec())

    # Buffer to store previous states
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    # Dataset generates trajectories with shape [Bx2x...] This is so that the agent has access to both the current
    # and previous state to compute loss. Parallel calls and prefetching are used to optimize process.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # (Optional) Optimize by wrapping some of the code in a graph using TF function.
    agent.train = common.function(agent.train)

    # Reset the train step
    agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)

    # Initially fill the replay buffer with random-policy steps to bootstrap the agent
    collect_data(train_env, random_policy, replay_buffer, steps=5000)
    train_env.reset()

    # Here, we run the simulation to train the agent
    scores_list = []
    num_steps_arr = []
    for currStep in range(num_iterations):
        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        # Number of training steps so far
        step = agent.train_step_counter.numpy()

        # Prints every `log_interval` steps made by the training agent
        if step % log_interval == 0:
            print('Moves made = {0}'.format(step))

        # Evaluates the agent's policy every `eval_interval` steps, prints the
        # results, and saves them so they can be plotted later
        if step % eval_interval == 0:
            avg_return = 0
            for i in range(num_eval_episodes):
                curr_return = compute_avg_return(eval_env, agent.policy, 1)
                scores_list.append(curr_return)
                num_steps_arr.append(currStep)
                avg_return += curr_return
            avg_return = avg_return / num_eval_episodes
            print('step = {0}: Average Return = {1}'.format(step, avg_return))
    plt.scatter(num_steps_arr, scores_list)
    plt.xlabel('Number of Steps Trained')
    plt.ylabel('Score')
    plt.title('Snake Reinforcement Learning')
    plt.show()
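This example calls `compute_avg_return`, `collect_step`, and `collect_data`, which are not shown. Typical implementations, modeled on the standard TF-Agents DQN tutorial helpers (a sketch of what the script presumably uses, not the original definitions):

from tf_agents.trajectories import trajectory

def compute_avg_return(environment, policy, num_episodes=10):
    # Average undiscounted episode return of `policy` over `num_episodes`.
    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return
    return (total_return / num_episodes).numpy()[0]

def collect_step(environment, policy, buffer):
    # Take one step with `policy` and store the resulting transition.
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    buffer.add_batch(
        trajectory.from_transition(time_step, action_step, next_time_step))

def collect_data(env, policy, buffer, steps):
    for _ in range(steps):
        collect_step(env, policy, buffer)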
Example #19
num_train_iterations = 2000  # @param
num_epochs = 5  # @param
learning_rate = 1e-4  # @param
# Params for summaries and logging
log_interval = 1  # @param
use_tf_functions = True
debug_summaries = False
summarize_grads_and_vars = False

num_eval_episodes = 10  # @param
eval_interval = 10  # @param

global_step = tf.compat.v1.train.get_or_create_global_step()
tf.compat.v1.set_random_seed(0)
eval_py_env = suite_gym.load(env_name)
tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

actor_net = actor_distribution_network.ActorDistributionNetwork(
    tf_env.observation_spec(),
    tf_env.action_spec(),
    fc_layer_params=actor_fc_layers)
value_net = value_network.ValueNetwork(tf_env.observation_spec(),
                                       fc_layer_params=value_fc_layers)

tf_agent = ppo_agent.PPOAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    optimizer,
    actor_net=actor_net,
    value_net=value_net,
    num_epochs=num_epochs,
    debug_summaries=debug_summaries,
    summarize_grads_and_vars=summarize_grads_and_vars,
    train_step_counter=global_step)
tf_agent.initialize()
            policy_save_handler.save("policy")
            with open("checkpoint/train_loss.pickle", "wb") as f:
                pickle.dump(all_train_loss, f)
            with open("checkpoint/all_metrics.pickle", "wb") as f:
                pickle.dump(all_metrics, f)
            with open("checkpoint/returns.pickle", "wb") as f:
                pickle.dump(returns, f)


if __name__ == '__main__':
    # tf_env = tf_py_environment.TFPyEnvironment(
    #   parallel_py_environment.ParallelPyEnvironment(
    #       [BombermanEnvironment] * N_PARALLEL_ENVIRONMENTS
    #   ))

    tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())
    eval_tf_env = tf_py_environment.TFPyEnvironment(BombermanEnvironment())

    q_net = QNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     conv_layer_params=[(32, 3, 1), (32, 3, 1)],
                     fc_layer_params=[128, 64, 32])

    train_step = tf.Variable(0)
    update_period = 4
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # todo fine tune

    epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1.0,
        decay_steps=250000 // update_period,
        end_learning_rate=0.01)
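The fragment ends before `epsilon_fn` is used. In TF-Agents the schedule is typically handed to the agent as a callable so that exploration decays with the training step; a sketch of the usual wiring, reusing the objects defined above and assuming the standard `DqnAgent` constructor:

from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step,
    # epsilon_greedy accepts a callable, so the decayed epsilon is re-evaluated
    # from the current train_step at every collect step.
    epsilon_greedy=lambda: epsilon_fn(train_step))
agent.initialize()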
Example #21
 def testMethodPropagation(self):
     env = self._get_py_env(True, False, batch_size=1)
     env.foo = mock.Mock()
     tf_env = tf_py_environment.TFPyEnvironment(env)
     tf_env.foo()
     env.foo.assert_called_once()
Example #22
def train_eval(
        root_dir,
        env_name='gym_solventx-v0',
        eval_env_name=None,
        env_load_fn=suite_gym.load,
        # The SAC paper reported:
        # Hopper and Cartpole results up to 1000000 iters,
        # Humanoid results up to 10000000 iters,
        # Other mujoco tasks up to 3000000 iters.
        num_iterations=3000000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Params for collect
        # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
        # HalfCheetah and Ant take 10000 initial collection steps.
        # Other mujoco tasks take 1000.
        # Different choices roughly keep the initial episodes about the same.
        initial_collect_steps=10000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=0.99,
        reward_scale_factor=0.1,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=5000,
        policy_checkpoint_interval=2500,
        rb_checkpoint_interval=25000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        eval_env_name = eval_env_name or env_name
        gym_env = gym.make(env_name, config_file=config_file)
        py_env = suite_gym.wrap_env(gym_env, max_episode_steps=100)
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        eval_gym_env = gym.make(eval_env_name, config_file=config_file)
        eval_py_env = suite_gym.wrap_env(eval_gym_env, max_episode_steps=100)
        eval_tf_env = tf_py_environment.TFPyEnvironment(eval_py_env)

        #tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
        #eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(eval_env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network.
            TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=1,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data, unless the buffer was restored from a checkpoint.
        if replay_buffer.num_frames() == 0:
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps with '
                'a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Prepare replay buffer as dataset with invalid transitions filtered.
        def _filter_invalid_transition(trajectories, unused_arg1):
            return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size, num_steps=2).unbatch().filter(
                _filter_invalid_transition).batch(batch_size).prefetch(5)
        # Dataset generates trajectories with shape [Bx2x...]
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step_val)
                metric_utils.log_metrics(eval_metrics)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
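A train_eval like the one above is usually launched from an absl entry point; the sketch below is a hedged, minimal launcher, assuming the function takes root_dir as its first argument like the other train_eval examples in this listing (the flag name and default directory are illustrative).

# Hypothetical launcher for the train_eval shown above (assumed flag and path).
from absl import app, flags

flags.DEFINE_string('root_dir', '/tmp/sac_train_eval',
                    'Directory for checkpoints and summaries.')
FLAGS = flags.FLAGS


def main(_):
    # Forward only the root directory; every other hyperparameter keeps its default.
    train_eval(FLAGS.root_dir)


if __name__ == '__main__':
    app.run(main)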
Example #23
  max_episode_steps = 5000000
  # env = get_env(name='point_mass_full_goal', env_type='y', reward_type='sparse')
  # env = get_env(name='kitchen')
  env = get_env(name='playpen_reduced', task_list='rc_o', reward_type='sparse')

  base_dir = os.path.abspath('experiments/env_logs/playpen_reduced/symmetric/')
  env_log_dir = os.path.join(base_dir, 'rc_o/traj1/')
  # env = ResetFreeWrapper(env, reset_goal_frequency=500, full_reset_frequency=max_episode_steps)
  env = GoalTerminalResetWrapper(
      env,
      episodes_before_full_reset=max_episode_steps // 500,
      goal_reset_frequency=500)
  # env = Monitor(env, env_log_dir, video_callable=lambda x: x % 1 == 0, force=True)

  env = wrap_env(env)
  tf_env = tf_py_environment.TFPyEnvironment(env)
  tf_env.render = env.render
  time_step_spec = tf_env.time_step_spec()
  action_spec = tf_env.action_spec()
  policy = random_tf_policy.RandomTFPolicy(
      action_spec=action_spec, time_step_spec=time_step_spec)
  collect_data_spec = trajectory.Trajectory(
      step_type=time_step_spec.step_type,
      observation=time_step_spec.observation,
      action=action_spec,
      policy_info=policy.info_spec,
      next_step_type=time_step_spec.step_type,
      reward=time_step_spec.reward,
      discount=time_step_spec.discount)
  offline_data = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      data_spec=collect_data_spec, batch_size=1, max_length=int(1e5))
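The snippet above stops after constructing the random policy and an empty offline_data buffer; the following hedged continuation shows one way to fill it with random-policy transitions via a DynamicStepDriver (the import and the 1000-step count are assumptions, not part of the original).

# Hedged continuation: stream random-policy transitions into offline_data.
from tf_agents.drivers import dynamic_step_driver

offline_collect_driver = dynamic_step_driver.DynamicStepDriver(
    tf_env,
    policy,
    observers=[offline_data.add_batch],  # each step is appended as a batched Trajectory
    num_steps=1000)  # assumed number of warm-up steps
offline_collect_driver.run()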
Example #24
File: utils.py  Project: hr0nix/trackdays
def as_tf_env(env):
    return tf_py_environment.TFPyEnvironment(env)
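A one-line usage sketch of the helper above; the CartPole environment is just an illustrative choice.

# Wrap a py_environment so it exposes the tensor-based (TF) environment API.
from tf_agents.environments import suite_gym

tf_env = as_tf_env(suite_gym.load('CartPole-v0'))
print(tf_env.time_step_spec())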
Example #25
def train_eval(
    root_dir,
    env_name='CartPole-v0',
    num_iterations=1000,
    # TODO(b/127576522): rename to policy_fc_layers.
    actor_fc_layers=(100,),
    value_net_fc_layers=(100,),
    use_value_network=False,
    # Params for collect
    collect_episodes_per_iteration=2,
    replay_buffer_capacity=2000,
    # Params for train
    learning_rate=1e-3,
    gamma=0.9,
    gradient_clipping=None,
    normalize_returns=True,
    value_estimation_loss_coef=0.2,
    # Params for eval
    num_eval_episodes=10,
    eval_interval=100,
    # Params for checkpoints, summaries, and logging
    train_checkpoint_interval=100,
    policy_checkpoint_interval=100,
    rb_checkpoint_interval=200,
    log_interval=100,
    summary_interval=100,
    summaries_flush_secs=1,
    debug_summaries=True,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for Reinforce."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    eval_py_env = suite_gym.load(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))

    # TODO(b/127870767): Handle distributions without gin.
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        tf_env.time_step_spec().observation,
        tf_env.action_spec(),
        fc_layer_params=actor_fc_layers)

    if use_value_network:
      value_net = value_network.ValueNetwork(
          tf_env.time_step_spec().observation,
          fc_layer_params=value_net_fc_layers)

    tf_agent = reinforce_agent.ReinforceAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        actor_network=actor_net,
        value_network=value_net if use_value_network else None,
        value_estimation_loss_coef=value_estimation_loss_coef,
        gamma=gamma,
        optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate),
        normalize_returns=normalize_returns,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        tf_agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    collect_policy = tf_agent.collect_policy

    collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
        tf_env,
        collect_policy,
        observers=[replay_buffer.add_batch] + train_metrics,
        num_episodes=collect_episodes_per_iteration).run()

    experience = replay_buffer.gather_all()
    train_op = tf_agent.train(experience)
    clear_rb_op = replay_buffer.clear()

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=tf_agent.policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    summary_ops = []
    for train_metric in train_metrics:
      summary_ops.append(train_metric.tf_summaries(
          train_step=global_step, step_metrics=train_metrics[:2]))

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
      for eval_metric in eval_metrics:
        eval_metric.tf_summaries(train_step=global_step)

    init_agent_op = tf_agent.initialize()

    with tf.compat.v1.Session() as sess:
      # Initialize the graph.
      train_checkpointer.initialize_or_restore(sess)
      rb_checkpointer.initialize_or_restore(sess)
      # TODO(b/126239733): Remove once Periodically can be saved.
      common.initialize_uninitialized_variables(sess)

      sess.run(init_agent_op)
      sess.run(train_summary_writer.init())
      sess.run(eval_summary_writer.init())

      # Compute evaluation metrics.
      global_step_call = sess.make_callable(global_step)
      global_step_val = global_step_call()
      metric_utils.compute_summaries(
          eval_metrics,
          eval_py_env,
          eval_py_policy,
          num_episodes=num_eval_episodes,
          global_step=global_step_val,
          callback=eval_metrics_callback,
      )

      collect_call = sess.make_callable(collect_op)
      train_step_call = sess.make_callable([train_op, summary_ops])
      clear_rb_call = sess.make_callable(clear_rb_op)

      timed_at_step = global_step_call()
      time_acc = 0
      steps_per_second_ph = tf.compat.v1.placeholder(
          tf.float32, shape=(), name='steps_per_sec_ph')
      steps_per_second_summary = tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_second_ph,
          step=global_step)

      for _ in range(num_iterations):
        start_time = time.time()
        collect_call()
        total_loss, _ = train_step_call()
        clear_rb_call()
        time_acc += time.time() - start_time
        global_step_val = global_step_call()

        if global_step_val % log_interval == 0:
          logging.info('step = %d, loss = %f', global_step_val, total_loss.loss)
          steps_per_sec = (global_step_val - timed_at_step) / time_acc
          logging.info('%.3f steps/sec', steps_per_sec)
          sess.run(
              steps_per_second_summary,
              feed_dict={steps_per_second_ph: steps_per_sec})
          timed_at_step = global_step_val
          time_acc = 0

        if global_step_val % train_checkpoint_interval == 0:
          train_checkpointer.save(global_step=global_step_val)

        if global_step_val % policy_checkpoint_interval == 0:
          policy_checkpointer.save(global_step=global_step_val)

        if global_step_val % rb_checkpoint_interval == 0:
          rb_checkpointer.save(global_step=global_step_val)

        if global_step_val % eval_interval == 0:
          metric_utils.compute_summaries(
              eval_metrics,
              eval_py_env,
              eval_py_policy,
              num_episodes=num_eval_episodes,
              global_step=global_step_val,
              callback=eval_metrics_callback,
          )
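A hedged smoke-test call for the REINFORCE train_eval above; the directory and the small override values are assumptions chosen only to make a quick local run cheap.

# Quick local run of the graph-mode REINFORCE example (assumed settings).
train_eval(
    root_dir='/tmp/reinforce_cartpole',
    num_iterations=200,
    collect_episodes_per_iteration=2,
    eval_interval=100,
    log_interval=100)
Example #26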
def train_eval(
        root_dir,
        env_name='SocialBot-ICubWalkPID-v0',
        num_iterations=10000000,
        actor_fc_layers=(256, 128),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 128),
        # Params for collect
        initial_collect_steps=2000,
        collect_steps_per_iteration=1,
        replay_buffer_capacity=1000000,
        num_parallel_environments=12,
        # Params for target update
        target_update_tau=0.005,
        target_update_period=1,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=256,
        actor_learning_rate=5e-4,
        critic_learning_rate=5e-4,
        alpha_learning_rate=5e-4,
        td_errors_loss_fn=tf.compat.v1.losses.mean_squared_error,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        rb_checkpoint_interval=50000,
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=True,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for SAC."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: suite_socialbot.load(env_name, wrap_with_process=False)]
                * num_parallel_environments))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_socialbot.load(env_name))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=normal_projection_net)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_py_metric.TFPyMetric(py_metrics.AverageReturnMetric()),
            tf_py_metric.TFPyMetric(py_metrics.AverageEpisodeLengthMetric()),
        ]

        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        collect_policy = tf_agent.collect_policy

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=eval_policy,
            global_step=global_step)
        rb_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer,
            num_steps=initial_collect_steps)

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'the collect policy.', initial_collect_steps)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3, sample_batch_size=batch_size,
            num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (
                    global_step.numpy() - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_sec,
                    step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(
                    train_step=global_step, step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

            global_step_val = global_step.numpy()
            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
        return train_loss
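The actor network above passes continuous_projection_net=normal_projection_net, but that helper is not shown; below is a hedged reconstruction following the convention of older TF-Agents SAC examples, assuming the std_clip_transform helper that tf_agents.agents.sac.sac_agent provided in those releases.

# Hedged reconstruction of the missing normal_projection_net helper.
from tf_agents.agents.sac import sac_agent
from tf_agents.networks import normal_projection_network


def normal_projection_net(action_spec, init_means_output_factor=0.1):
    return normal_projection_network.NormalProjectionNetwork(
        action_spec,
        mean_transform=None,
        state_dependent_std=True,
        init_means_output_factor=init_means_output_factor,
        std_transform=sac_agent.std_clip_transform,  # clip log-std to a sane range
        scale_distribution=True)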
Example #27
def train_eval(
        root_dir,
        env_name='cartpole',
        task_name='balance',
        observations_allowlist='position',
        num_iterations=100000,
        actor_fc_layers=(400, 300),
        actor_output_fc_layers=(100, ),
        actor_lstm_size=(40, ),
        critic_obs_fc_layers=(400, ),
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(300, ),
        critic_output_fc_layers=(100, ),
        critic_lstm_size=(40, ),
        # Params for collect
        initial_collect_episodes=1,
        collect_episodes_per_iteration=1,
        replay_buffer_capacity=100000,
        exploration_noise_std=0.1,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=200,
        batch_size=64,
        actor_update_period=2,
        train_sequence_length=10,
        actor_learning_rate=1e-4,
        critic_learning_rate=1e-3,
        td_errors_loss_fn=None,
        gamma=0.995,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        use_tf_functions=True,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=10000,
        # Params for checkpoints, summaries, and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for TD3."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        if observations_allowlist is not None:
            env_wrappers = [
                functools.partial(
                    wrappers.FlattenObservationsWrapper,
                    observations_allowlist=[observations_allowlist])
            ]
        else:
            env_wrappers = []

        tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_dm_control.load(env_name,
                                  task_name,
                                  env_wrappers=env_wrappers))

        actor_net = actor_rnn_network.ActorRnnNetwork(
            tf_env.time_step_spec().observation,
            tf_env.action_spec(),
            input_fc_layer_params=actor_fc_layers,
            lstm_size=actor_lstm_size,
            output_fc_layer_params=actor_output_fc_layers)

        critic_net_input_specs = (tf_env.time_step_spec().observation,
                                  tf_env.action_spec())

        critic_net = critic_rnn_network.CriticRnnNetwork(
            critic_net_input_specs,
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            lstm_size=critic_lstm_size,
            output_fc_layer_params=critic_output_fc_layers,
        )

        tf_agent = td3_agent.Td3Agent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            exploration_noise_std=exploration_noise_std,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            actor_update_period=actor_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step,
        )
        tf_agent.initialize()

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy
        collect_policy = tf_agent.collect_policy

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=initial_collect_episodes)

        collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_episodes=collect_episodes_per_iteration)

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d episodes '
            'with the collect policy.', initial_collect_episodes)
        initial_collect_driver.run()

        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(results, global_step.numpy())
        metric_utils.log_metrics(eval_metrics)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [BxTx...]
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=train_sequence_length + 1).prefetch(3)
        iterator = iter(dataset)

        def train_step():
            experience, _ = next(iterator)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss.loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                results = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    train_step=global_step,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(results, global_step.numpy())
                metric_utils.log_metrics(eval_metrics)

        return train_loss
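Because the TD3 actor and critic above are recurrent, the replay dataset yields sequences rather than single transitions; the hedged check below could be dropped inside train_eval after the iterator is built to confirm the shapes (the printed sizes assume the default batch_size=64 and train_sequence_length=10).

# Illustrative shape check for the recurrent replay dataset (not in the original).
sample_experience, _ = next(iterator)
# Each field is [batch_size, train_sequence_length + 1, ...], e.g. (64, 11, obs_dim).
print(sample_experience.observation.shape)
print(sample_experience.step_type.shape)  # (64, 11)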
Example #28
def train_agent(iterations, modeldir, logdir, policydir):
    """Train and convert the model using TF Agents."""

    train_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2
    )
    eval_py_env = planestrike_py_environment.PlaneStrikePyEnvironment(
        board_size=BOARD_SIZE, discount=DISCOUNT, max_steps=BOARD_SIZE**2
    )

    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    # Alternatively you could use ActorDistributionNetwork as actor_net
    actor_net = tfa.networks.Sequential(
        [
            tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]),
            tf.keras.layers.Dense(FC_LAYER_PARAMS, activation="relu"),
            tf.keras.layers.Dense(BOARD_SIZE**2),
            tf.keras.layers.Lambda(lambda t: tfp.distributions.Categorical(logits=t)),
        ],
        input_spec=train_py_env.observation_spec(),
    )

    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    train_step_counter = tf.Variable(0)

    tf_agent = reinforce_agent.ReinforceAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        optimizer=optimizer,
        normalize_returns=True,
        train_step_counter=train_step_counter,
    )

    tf_agent.initialize()

    eval_policy = tf_agent.policy
    collect_policy = tf_agent.collect_policy

    tf_policy_saver = policy_saver.PolicySaver(collect_policy)

    # Use reverb as replay buffer
    replay_buffer_signature = tensor_spec.from_spec(tf_agent.collect_data_spec)
    table = reverb.Table(
        REPLAY_BUFFER_TABLE_NAME,
        max_size=REPLAY_BUFFER_CAPACITY,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )  # specify signature here for validation at insertion time

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        tf_agent.collect_data_spec,
        sequence_length=None,
        table_name=REPLAY_BUFFER_TABLE_NAME,
        local_server=reverb_server,
    )

    replay_buffer_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, REPLAY_BUFFER_TABLE_NAME, REPLAY_BUFFER_CAPACITY
    )

    # Optimize by wrapping some of the code in a graph using TF function.
    tf_agent.train = common.function(tf_agent.train)

    # Evaluate the agent's policy once before training.
    avg_return, avg_episode_length = compute_avg_return_and_steps(
        eval_env, tf_agent.policy, NUM_EVAL_EPISODES
    )

    summary_writer = tf.summary.create_file_writer(logdir)

    for i in range(iterations):
        # Collect a few episodes using collect_policy and save to the replay buffer.
        collect_episode(
            train_py_env,
            collect_policy,
            COLLECT_EPISODES_PER_ITERATION,
            replay_buffer_observer,
        )

        # Use data from the buffer and update the agent's network.
        iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
        trajectories, _ = next(iterator)
        tf_agent.train(experience=trajectories)
        replay_buffer.clear()

        logger = tf.get_logger()
        if i % EVAL_INTERVAL == 0:
            avg_return, avg_episode_length = compute_avg_return_and_steps(
                eval_env, eval_policy, NUM_EVAL_EPISODES
            )
            with summary_writer.as_default():
                tf.summary.scalar("Average return", avg_return, step=i)
                tf.summary.scalar("Average episode length", avg_episode_length, step=i)
                summary_writer.flush()
            logger.info(
                "iteration = {0}: Average Return = {1}, Average Episode Length = {2}".format(
                    i, avg_return, avg_episode_length
                )
            )

    summary_writer.close()

    tf_policy_saver.save(policydir)
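train_agent calls compute_avg_return_and_steps and collect_episode without defining them; the sketch below is a hedged version of the evaluation helper only, assuming it averages undiscounted return and episode length over rollouts in the TF eval environment.

# Hedged sketch of the evaluation helper assumed by the example above.
def compute_avg_return_and_steps(environment, policy, num_episodes):
    total_return, total_steps = 0.0, 0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return, episode_steps = 0.0, 0
        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward.numpy()[0]
            episode_steps += 1
        total_return += episode_return
        total_steps += episode_steps
    return total_return / num_episodes, total_steps / num_episodes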
Example #29
def train_eval(
        root_dir,
        tf_master='',
        env_name='HalfCheetah-v1',
        env_load_fn=suite_mujoco.load,
        random_seed=0,
        # TODO(kbanoop): rename to policy_fc_layers.
        actor_fc_layers=(200, 100),
        value_fc_layers=(200, 100),
        use_rnns=False,
        # Params for collect
        num_environment_steps=10000000,
        collect_episodes_per_iteration=30,
        num_parallel_environments=30,
        replay_buffer_capacity=1001,  # Per-environment
        # Params for train
        num_epochs=25,
        learning_rate=1e-4,
        # Params for eval
        num_eval_episodes=30,
        eval_interval=500,
        # Params for summaries and logging
        train_checkpoint_interval=100,
        policy_checkpoint_interval=50,
        rb_checkpoint_interval=200,
        log_interval=50,
        summary_interval=50,
        summaries_flush_secs=1,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for PPO."""
    if root_dir is None:
        raise AttributeError('train_eval requires a root_dir.')

    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        batched_py_metric.BatchedPyMetric(
            AverageReturnMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
        batched_py_metric.BatchedPyMetric(
            AverageEpisodeLengthMetric,
            metric_args={'buffer_size': num_eval_episodes},
            batch_size=num_parallel_environments),
    ]
    eval_summary_writer_flush_op = eval_summary_writer.flush()

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        tf.compat.v1.set_random_seed(random_seed)
        eval_py_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * num_parallel_environments)
        tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment(
                [lambda: env_load_fn(env_name)] * num_parallel_environments))
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate)

        if use_rnns:
            actor_net = actor_distribution_rnn_network.ActorDistributionRnnNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                input_fc_layer_params=actor_fc_layers,
                output_fc_layer_params=None)
            value_net = value_rnn_network.ValueRnnNetwork(
                tf_env.observation_spec(),
                input_fc_layer_params=value_fc_layers,
                output_fc_layer_params=None)
        else:
            actor_net = actor_distribution_network.ActorDistributionNetwork(
                tf_env.observation_spec(),
                tf_env.action_spec(),
                fc_layer_params=actor_fc_layers)
            value_net = value_network.ValueNetwork(
                tf_env.observation_spec(), fc_layer_params=value_fc_layers)

        tf_agent = ppo_agent.PPOAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            optimizer,
            actor_net=actor_net,
            value_net=value_net,
            num_epochs=num_epochs,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)

        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            tf_agent.collect_data_spec,
            batch_size=num_parallel_environments,
            max_length=replay_buffer_capacity)

        eval_py_policy = py_tf_policy.PyTFPolicy(tf_agent.policy)

        # TODO(sguada): Reenable metrics when ready for batch data.
        environment_steps_metric = tf_metrics.EnvironmentSteps()
        environment_steps_count = environment_steps_metric.result()
        step_metrics = [
            tf_metrics.NumberOfEpisodes(),
            environment_steps_metric,
        ]
        train_metrics = step_metrics + [
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        # Add to replay buffer and other agent specific observers.
        replay_buffer_observer = [replay_buffer.add_batch]

        collect_policy = tf_agent.collect_policy

        collect_op = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            collect_policy,
            observers=replay_buffer_observer + train_metrics,
            num_episodes=collect_episodes_per_iteration).run()

        trajectories = replay_buffer.gather_all()

        train_op, _ = tf_agent.train(experience=trajectories)

        with tf.control_dependencies([train_op]):
            clear_replay_op = replay_buffer.clear()

        with tf.control_dependencies([clear_replay_op]):
            train_op = tf.identity(train_op)

        train_checkpointer = common_utils.Checkpointer(
            ckpt_dir=train_dir,
            agent=tf_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
        policy_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=tf_agent.policy,
            global_step=global_step)
        rb_checkpointer = common_utils.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
            max_to_keep=1,
            replay_buffer=replay_buffer)

        for train_metric in train_metrics:
            train_metric.tf_summaries()

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(step_metrics=step_metrics)

        init_agent_op = tf_agent.initialize()

        with tf.compat.v1.Session(tf_master) as sess:
            # Initialize graph.
            train_checkpointer.initialize_or_restore(sess)
            rb_checkpointer.initialize_or_restore(sess)
            # TODO(sguada) Remove once Periodically can be saved.
            common_utils.initialize_uninitialized_variables(sess)

            sess.run(init_agent_op)
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            collect_time = 0
            train_time = 0
            timed_at_step = sess.run(global_step)
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            steps_per_second_summary = tf.contrib.summary.scalar(
                name='global_steps/sec', tensor=steps_per_second_ph)

            while sess.run(environment_steps_count) < num_environment_steps:
                global_step_val = sess.run(global_step)
                if global_step_val % eval_interval == 0:
                    metric_utils.compute_summaries(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_episodes=num_eval_episodes,
                        global_step=global_step_val,
                        callback=eval_metrics_callback,
                        log=True,
                    )
                    sess.run(eval_summary_writer_flush_op)

                start_time = time.time()
                sess.run(collect_op)
                collect_time += time.time() - start_time
                start_time = time.time()
                total_loss = sess.run(train_op)
                train_time += time.time() - start_time

                global_step_val = sess.run(global_step)
                if global_step_val % log_interval == 0:
                    logging.info('step = %d, loss = %f', global_step_val,
                                 total_loss)
                    steps_per_sec = ((global_step_val - timed_at_step) /
                                     (collect_time + train_time))
                    logging.info('%.3f steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})
                    logging.info(
                        '%s', 'collect_time = {}, train_time = {}'.format(
                            collect_time, train_time))
                    timed_at_step = global_step_val
                    collect_time = 0
                    train_time = 0

                if global_step_val % train_checkpoint_interval == 0:
                    train_checkpointer.save(global_step=global_step_val)

                if global_step_val % policy_checkpoint_interval == 0:
                    policy_checkpointer.save(global_step=global_step_val)

                if global_step_val % rb_checkpoint_interval == 0:
                    rb_checkpointer.save(global_step=global_step_val)

            # One final eval before exiting.
            metric_utils.compute_summaries(
                eval_metrics,
                eval_py_env,
                eval_py_policy,
                num_episodes=num_eval_episodes,
                global_step=global_step_val,
                callback=eval_metrics_callback,
                log=True,
            )
            sess.run(eval_summary_writer_flush_op)
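The PPO example references AverageReturnMetric, AverageEpisodeLengthMetric, batched_py_metric, and common_utils without showing its imports; a plausible import block for the older TF-Agents layout this code targets is sketched below (module paths are assumptions and may differ across versions).

# Hedged import block for the PPO example above (paths assumed from older TF-Agents).
from tf_agents.metrics import batched_py_metric
from tf_agents.metrics.py_metrics import AverageEpisodeLengthMetric, AverageReturnMetric
from tf_agents.utils import common as common_utils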
Example #30
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        log_interval=1000,
        summary_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.contrib.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.contrib.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
    ]

    with tf.contrib.summary.record_summaries_every_n_global_steps(
            summary_interval):

        tf_env = tf_py_environment.TFPyEnvironment(suite_gym.load(env_name))
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            suite_gym.load(env_name))

        trajectory_spec = trajectory.from_transition(
            time_step=tf_env.time_step_spec(),
            action_step=policy_step.PolicyStep(action=tf_env.action_spec()),
            next_time_step=tf_env.time_step_spec())
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=trajectory_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)

        q_net = q_network.QNetwork(tf_env.time_step_spec().observation,
                                   tf_env.action_spec(),
                                   fc_layer_params=fc_layer_params)

        tf_agent = dqn_agent.DqnAgent(
            tf_env.time_step_spec(),
            tf_env.action_spec(),
            q_network=q_net,
            # TODO(kbanoop): Decay epsilon based on global step, cf. cl/188907839
            epsilon_greedy=epsilon_greedy,
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate),
            td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars)

        tf_agent.initialize()
        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(),
            tf_metrics.AverageEpisodeLengthMetric(),
        ]

        eval_policy = tf_agent.policy()
        collect_policy = tf_agent.collect_policy()

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=collect_steps_per_iteration)

        global_step = tf.compat.v1.train.get_or_create_global_step()

        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())

        # Collect initial replay data.
        logging.info(
            'Initializing replay buffer by collecting experience for %d steps with '
            'a random policy.', initial_collect_steps)
        dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=[replay_buffer.add_batch] + train_metrics,
            num_steps=initial_collect_steps).run()

        metrics = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
            eval_metrics_callback(metrics, global_step.numpy())

        time_step = None
        policy_state = ()

        timed_at_step = global_step.numpy()
        time_acc = 0

        # Dataset generates trajectories with shape [Bx2x...]
        dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                           sample_batch_size=batch_size,
                                           num_steps=2).prefetch(3)
        iterator = iter(dataset)

        for _ in range(num_iterations):
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            for _ in range(train_steps_per_iteration):
                experience, _ = next(iterator)
                train_loss = tf_agent.train(experience,
                                            train_step_counter=global_step)
            time_acc += time.time() - start_time

            if global_step.numpy() % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step.numpy(),
                             train_loss)
                steps_per_sec = (global_step.numpy() -
                                 timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.contrib.summary.scalar(name='global_steps/sec',
                                          tensor=steps_per_sec)
                timed_at_step = global_step.numpy()
                time_acc = 0

            for train_metric in train_metrics:
                train_metric.tf_summaries(step_metrics=train_metrics[:2])

            if global_step.numpy() % eval_interval == 0:
                metrics = metric_utils.eager_compute(
                    eval_metrics,
                    eval_tf_env,
                    eval_policy,
                    num_episodes=num_eval_episodes,
                    summary_writer=eval_summary_writer,
                    summary_prefix='Metrics',
                )
                if eval_metrics_callback is not None:
                    eval_metrics_callback(metrics, global_step.numpy())
        return train_loss