Example #1
0
def collect_data(
        log_dir,
        data_config,
        n_samples=int(1e6),
        env_name='HalfCheetah-v2',
        log_freq=int(1e4),
        n_eval_episodes=20,
):
    """Fill a fixed-size transition dataset from a mix of policies.

    Each entry of `data_config` is unpacked as a 3-tuple
    `(policy_name, policy_cfg, last)`. The last element of every entry is
    passed, together with `n_samples`, to `get_sample_counts` to obtain the
    number of transitions each policy contributes (presumably a relative
    share per policy -- confirm against `get_sample_counts`).

    Every policy is first evaluated for `n_eval_episodes` episodes, then used
    to collect its share of transitions. The filled dataset is written as a
    checkpoint named 'data' under `log_dir`.

    Args:
        log_dir: Directory that receives the dataset checkpoint.
        data_config: Sequence of `(policy_name, policy_cfg, last)` entries.
        n_samples: Capacity of the dataset and total transitions to collect.
        env_name: Name handed to `train_eval_utils.env_factory`.
        log_freq: Forwarded to `collect_n_transitions`.
        n_eval_episodes: Episodes used to evaluate each policy beforehand.
    """
    env = train_eval_utils.env_factory(env_name)
    obs_spec = env.observation_spec()
    act_spec = env.action_spec()

    # Work out how many transitions each configured policy has to supply.
    per_policy_counts = get_sample_counts(
        n_samples, [entry[-1] for entry in data_config])

    # The non-circular buffer is created on the CPU device.
    with tf.device('/cpu:0'):
        data = dataset.Dataset(obs_spec, act_spec, n_samples, circular=False)
    checkpoint = tf.train.Checkpoint(data=data)
    checkpoint_path = os.path.join(log_dir, 'data')

    # Evaluate, then harvest transitions from, one policy at a time.
    started_at = time.time()
    eval_stats = collections.OrderedDict()
    for entry, quota in zip(data_config, per_policy_counts):
        policy_name, raw_cfg, _ = entry
        parsed_cfg = policy_loader.parse_policy_cfg(raw_cfg)
        policy = policy_loader.load_policy(parsed_cfg, act_spec)

        logging.info('Testing policy %s...', policy_name)
        eval_mean, eval_std = train_eval_utils.eval_policy_episodes(
            env, policy, n_eval_episodes)
        eval_stats[policy_name] = [eval_mean, eval_std]
        logging.info('Return mean %.4g, std %.4g.', eval_mean, eval_std)

        logging.info('Collecting data from policy %s...', policy_name)
        collect_n_transitions(env, policy, data, quota, log_freq)

    # The buffer must be exactly full before it is persisted.
    assert data.size == data.capacity
    checkpoint.write(checkpoint_path)
    elapsed = time.time() - started_at
    logging.info(
        'Finished: %d transitions collected, '
        'saved at %s, '
        'time cost %.4gs.', n_samples, checkpoint_path, elapsed)
Example #2
0
def get_offline_data(tf_env):
    """Load the environment's offline dataset into a `dataset.Dataset`.

    The raw dataset is obtained from `get_dataset()` on the first wrapped gym
    environment. Step `t` becomes a transition
    `(s[t], a[t]) -> (s[t + 1], a[t + 1])` only when it is not flagged
    terminal and is not the last step of the raw data, so every kept step has
    a successor to pair with. The stored discount of a transition is
    `1 - terminals[t + 1]` and its reward is `rewards[t]`.

    Args:
        tf_env: Environment exposing `pyenv.envs`, `observation_spec()` and
            `action_spec()`.

    Returns:
        A `dataset.Dataset` created with `size` equal to the raw dataset
        length and filled with the kept transitions.
    """
    gym_env = tf_env.pyenv.envs[0]
    # NOTE: the dataset is read from the wrapped env, not `gym_env.unwrapped`.
    offline_dataset = gym_env.get_dataset()
    observations = offline_dataset["observations"]
    actions = offline_dataset["actions"]
    dataset_size = len(observations)
    tf_dataset = dataset.Dataset(
        tf_env.observation_spec(), tf_env.action_spec(), size=dataset_size
    )
    observation_dtype = tf_env.observation_spec().dtype
    action_dtype = tf_env.action_spec().dtype

    # Squeeze into locals instead of writing back into the dict, so the
    # mapping returned by the environment is left untouched.
    terminals = np.squeeze(offline_dataset["terminals"])
    rewards = np.squeeze(offline_dataset["rewards"])

    # Keep steps that are not terminal and that have a following step.
    (nonterminal_steps,) = np.where(
        np.logical_and(
            np.logical_not(terminals),
            np.arange(dataset_size) < dataset_size - 1,
        )
    )
    next_steps = nonterminal_steps + 1
    # Pass the values as logging arguments so formatting is deferred to the
    # logging framework.
    logging.info(
        "Found %d non-terminal steps out of a total of %d steps.",
        len(nonterminal_steps),
        dataset_size,
    )

    s1 = tf.convert_to_tensor(
        observations[nonterminal_steps], dtype=observation_dtype
    )
    s2 = tf.convert_to_tensor(observations[next_steps], dtype=observation_dtype)
    a1 = tf.convert_to_tensor(actions[nonterminal_steps], dtype=action_dtype)
    a2 = tf.convert_to_tensor(actions[next_steps], dtype=action_dtype)
    discount = tf.convert_to_tensor(
        1.0 - terminals[next_steps], dtype=tf.float32
    )
    reward = tf.convert_to_tensor(rewards[nonterminal_steps], dtype=tf.float32)

    transitions = dataset.Transition(s1, s2, a1, a2, discount, reward)

    tf_dataset.add_transitions(transitions)
    return tf_dataset
def train_eval_online(
    # Basic args.
    log_dir,
    agent_module,
    env_name='HalfCheetah-v2',
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(1e8),
    eval_freq=5000,
    n_eval_episodes=20,
    # For saving a partially trained policy.
    eval_target=None,  # Return threshold that triggers a partial snapshot.
    eval_target_n=2,  # Consecutive evals that must reach eval_target.
    # Agent train args.
    initial_explore_steps=10000,
    replay_buffer_size=int(1e6),
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Train an agent with online interaction, evaluating it periodically.

  The replay buffer is first seeded with `initial_explore_steps` transitions
  gathered by `policies.ContinuousRandomPolicy`. Afterwards every iteration
  calls `collect_transition()` once with the agent's online policy and runs
  one `agent.train_step()`.

  Args:
    log_dir: Directory for summaries, agent checkpoints and the replay
      checkpoint.
    agent_module: Module providing `Config` and `Agent`.
    env_name: Name handed to `train_eval_utils.env_factory`.
    total_train_steps: Number of training-loop iterations.
    summary_freq: Write train summaries every this many agent steps.
    print_freq: Print train info every this many agent steps.
    save_freq: Save a numbered agent checkpoint every this many agent steps.
    eval_freq: Evaluate the test policies every this many agent steps.
    n_eval_episodes: Episodes per evaluation.
    eval_target: If not None, the first time the last `eval_target_n`
      evaluations all reach this value, a '_partial_target' snapshot of the
      agent and replay data is saved once. Training continues afterwards.
    eval_target_n: Size of the evaluation window checked against
      `eval_target`.
    initial_explore_steps: Transitions collected with the random policy
      before training starts.
    replay_buffer_size: Capacity of the circular replay dataset.
    model_params: Forwarded to the agent config.
    optimizers: Forwarded to the agent config.
    batch_size: Forwarded to the agent config.
    weight_decays: Forwarded to the agent config.
    update_freq: Forwarded to the agent config.
    update_rate: Forwarded to the agent config.
    discount: Forwarded to the agent config.

  Returns:
    `np.array` with one row `[step] + eval_result` per evaluation.
  """
  # Create tf_env to get specs; a second env is used only for evaluation.
  tf_env = train_eval_utils.env_factory(env_name)
  tf_env_test = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize dataset.
  with tf.device('/cpu:0'):
    train_data = dataset.Dataset(
        observation_spec,
        action_spec,
        replay_buffer_size,
        circular=True,
        )
  data_ckpt = tf.train.Checkpoint(data=train_data)
  data_ckpt_name = os.path.join(log_dir, 'replay')

  time_st_total = time.time()
  time_st = time.time()
  timed_at_step = 0

  # Collect data from random policy.
  explore_policy = policies.ContinuousRandomPolicy(action_spec)
  steps_collected = 0
  log_freq = 5000
  logging.info('Collecting data ...')
  collector = train_eval_utils.DataCollector(tf_env, explore_policy, train_data)
  while steps_collected < initial_explore_steps:
    count = collector.collect_transition()
    steps_collected += count
    # `count > 0` keeps the same milestone from being logged again when a
    # call adds nothing.
    if (steps_collected % log_freq == 0
        or steps_collected == initial_explore_steps) and count > 0:
      steps_per_sec = ((steps_collected - timed_at_step)
                       / (time.time() - time_st))
      timed_at_step = steps_collected
      time_st = time.time()
      logging.info('(%d/%d) steps collected at %.4g steps/s.', steps_collected,
                   initial_explore_steps, steps_per_sec)

  # Construct agent.
  agent_flags = utils.Flags(
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))

  # Prepare savers for models and results.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  agent_ckpt_name = os.path.join(log_dir, 'agent')
  eval_results = []

  # Train agent.
  logging.info('Start training ....')
  time_st = time.time()
  timed_at_step = 0
  target_partial_policy_saved = False
  collector = train_eval_utils.DataCollector(
      tf_env, agent.online_policy, train_data)
  for _ in range(total_train_steps):
    collector.collect_transition()
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()

    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env_test, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      # Decide whether to save a partially trained policy based on current
      # model performance.
      if (eval_target is not None and len(eval_results) >= eval_target_n
          and not target_partial_policy_saved):
        # Column 1 of each row is the first entry of `eval_result`. The
        # window is ordered newest first, so `evals_[0]` is the evaluation
        # that was just run.
        evals_ = np.array(
            [eval_results[-(i + 1)][1] for i in range(eval_target_n)])
        if np.min(evals_) >= eval_target:
          agent.save(agent_ckpt_name + '_partial_target')
          dataset.save_copy(train_data, data_ckpt_name + '_partial_target')
          # Report the return measured at this step, not the oldest one in
          # the window.
          logging.info('A partially trained policy was saved at step %d,'
                       ' with episodic return %.4g.', step, evals_[0])
          target_partial_policy_saved = True
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(eval_summary_writers[policy_key], step, policy_info)
      # Restart the throughput timer so evaluation time is not counted.
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name + '-' + str(step))

  # Final save after training.
  agent.save(agent_ckpt_name + '_final')
  data_ckpt.write(data_ckpt_name + '_final')
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
Example #4
0
def train_eval_offline(
    # Basic args.
    log_dir,
    data_file,
    agent_module,
    env_name='HalfCheetah-v2',
    n_train=int(1e6),
    shuffle_steps=0,
    seed=0,
    use_seed_for_data=False,
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(2e4),
    eval_freq=5000,
    n_eval_episodes=20,
    # Agent args.
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Train an agent on a checkpointed dataset, evaluating it periodically.

  The dataset stored at `data_file` is restored, its indices are shuffled
  with `utils.shuffle_indices_with_steps`, and a view over the first
  `min(n_train, dataset size)` shuffled indices is used for training. The
  shuffle uses `seed` only when `use_seed_for_data` is true; otherwise a
  fixed seed of 0 is used.

  If `<log_dir>/agent.index` exists, the agent is restored from
  `<log_dir>/agent` and training continues from its `global_step`. The agent
  is saved to that same path every `save_freq` steps and once more at the
  end.

  Returns:
    `np.array` with one row `[step] + eval_result` per evaluation.
  """
  # The environment supplies the specs and is reused for evaluation.
  env = train_eval_utils.env_factory(env_name)
  obs_spec = env.observation_spec()
  act_spec = env.action_spec()

  # Restore the stored dataset into a buffer sized to its saved capacity.
  logging.info('Loading data from %s ...', data_file)
  capacity = utils.load_variable_from_ckpt(data_file, 'data._capacity')
  with tf.device('/cpu:0'):
    full_data = dataset.Dataset(obs_spec, act_spec, capacity)
  data_ckpt = tf.train.Checkpoint(data=full_data)
  data_ckpt.restore(data_file)

  # Pick the training subset from a shuffled index order.
  train_count = min(n_train, full_data.size)
  logging.info('n_train %s.', train_count)
  rng = np.random.RandomState(seed if use_seed_for_data else 0)
  index_order = utils.shuffle_indices_with_steps(
      n=full_data.size, steps=shuffle_steps, rand=rng)
  train_data = full_data.create_view(index_order[:train_count])

  # Build the agent from the module's config.
  flags = utils.Flags(
      observation_spec=obs_spec,
      action_spec=act_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent = agent_module.Agent(**vars(agent_module.Config(flags).agent_args))
  agent_ckpt_name = os.path.join(log_dir, 'agent')

  # Resume from an existing agent checkpoint when one is present.
  if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
    logging.info('Checkpoint found at %s.', agent_ckpt_name)
    agent.restore(agent_ckpt_name)

  # One summary writer for training and one per test policy.
  train_writer = tf.compat.v2.summary.create_file_writer(
      os.path.join(log_dir, 'train'))
  eval_root = os.path.join(log_dir, 'eval')
  eval_writers = collections.OrderedDict(
      (key, tf.compat.v2.summary.create_file_writer(
          os.path.join(eval_root, key)))
      for key in agent.test_policies.keys())
  eval_results = []

  run_started = time.time()
  window_started = time.time()
  step = agent.global_step
  window_start_step = step
  while step < total_train_steps:
    agent.train_step()
    step = agent.global_step
    is_final_step = step == total_train_steps

    if step % summary_freq == 0 or is_final_step:
      agent.write_train_summary(train_writer)
    if step % print_freq == 0 or is_final_step:
      agent.print_train_info()

    if step % eval_freq == 0 or is_final_step:
      elapsed = time.time() - window_started
      logging.info(
          'Training at %.4g steps/s.', (step - window_start_step) / elapsed)
      result, infos = train_eval_utils.eval_policies(
          env, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + result)
      logging.info('Testing at step %d:', step)
      for key, info in infos.items():
        summary_line = utils.get_summary_str(
            step=None, info=info, prefix=key + ': ')
        logging.info(summary_line)
        utils.write_summary(eval_writers[key], step, info)
      # Restart the throughput timer so evaluation time is not counted.
      window_started = time.time()
      window_start_step = step

    if step % save_freq == 0:
      agent.save(agent_ckpt_name)
      logging.info('Agent saved at %s.', agent_ckpt_name)

  agent.save(agent_ckpt_name)
  total_elapsed = time.time() - run_started
  logging.info('Training finished, time cost %.4gs.', total_elapsed)
  return np.array(eval_results)