Example 1
def get_modules(model_params, action_spec):
    """Gets Tensorflow modules for Q-function, policy, and behavior."""
    # Unpack (layer specs, n_q_fns, max_perturbation) before validating the layer specs.
    model_params, n_q_fns, max_perturbation = model_params
    if len(model_params) == 1:
        model_params = tuple([model_params[0]] * 3)
    elif len(model_params) < 3:
        raise ValueError('Bad model parameters %s.' % model_params)

    def q_net_factory():
        return networks.CriticNetwork(fc_layer_params=model_params[0])

    def p_net_factory():
        return networks.BCQActorNetwork(
            action_spec,
            fc_layer_params=model_params[1],
            max_perturbation=max_perturbation,
        )

    def b_net_factory():
        return networks.BCQVAENetwork(action_spec,
                                      fc_layer_params=model_params[2])

    modules = utils.Flags(
        q_net_factory=q_net_factory,
        p_net_factory=p_net_factory,
        b_net_factory=b_net_factory,
        n_q_fns=n_q_fns,
    )
    return modules
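For orientation, a minimal usage sketch of the factory pattern above. The layer sizes and the max_perturbation value below are illustrative assumptions, not values taken from these examples; only the nesting of model_params as (layer specs, n_q_fns, max_perturbation) follows the code.

# Hypothetical configuration: one fc spec each for the Q, policy, and VAE nets,
# two Q-functions, and a small perturbation bound.
model_params = (((300, 300), (300, 300), (750, 750)), 2, 0.05)

tf_env = train_eval_utils.env_factory('HalfCheetah-v2')
action_spec = tf_env.action_spec()

modules = get_modules(model_params, action_spec)
q_net = modules.q_net_factory()  # networks.CriticNetwork
p_net = modules.p_net_factory()  # networks.BCQActorNetwork
b_net = modules.b_net_factory()  # networks.BCQVAENetwork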
Example 2
  def _get_agent_args(self):
    """Gets agent parameters associated with config."""
    agent_flags = self._agent_flags
    agent_args = utils.Flags(
        action_spec=agent_flags.action_spec,
        optimizers=agent_flags.optimizers,
        batch_size=agent_flags.batch_size,
        weight_decays=agent_flags.weight_decays,
        update_rate=agent_flags.update_rate,
        update_freq=agent_flags.update_freq,
        discount=agent_flags.discount,
        train_data=agent_flags.train_data,
        )
    agent_args.modules = self._get_modules()
    return agent_args
Example 3
def get_modules(model_params, action_spec):
  """Creates modules for Q-value and policy."""
  model_params, n_q_fns = model_params
  if len(model_params) == 1:
    model_params = tuple([model_params[0]] * 2)
  elif len(model_params) < 2:
    raise ValueError('Bad model parameters %s.' % model_params)
  def q_net_factory():
    return networks.CriticNetwork(
        fc_layer_params=model_params[0])
  def p_net_factory():
    return networks.ActorNetwork(
        action_spec,
        fc_layer_params=model_params[1])
  modules = utils.Flags(
      q_net_factory=q_net_factory,
      p_net_factory=p_net_factory,
      n_q_fns=n_q_fns,
      )
  return modules
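For contrast with Example 1, this variant expects only two entries in the outer tuple. A hypothetical call, reusing action_spec from the sketch under Example 1 and matching the default model_params used by the train_eval functions further down this page:

# A single fc spec is replicated for both the Q-networks and the policy network.
model_params = (((200, 200),), 2)
modules = get_modules(model_params, action_spec)
critic = modules.q_net_factory()  # networks.CriticNetwork, fc_layer_params=(200, 200)
actor = modules.p_net_factory()   # networks.ActorNetwork, fc_layer_params=(200, 200)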
Example 4
def get_modules(model_params, action_spec):
  """Gets Tensorflow modules for Q-function, policy, and discriminator."""
  model_params, n_q_fns = model_params
  if len(model_params) == 1:
    model_params = tuple([model_params[0]] * 3)
  elif len(model_params) < 3:
    raise ValueError('Bad model parameters %s.' % model_params)
  def q_net_factory():
    return networks.CriticNetwork(
        fc_layer_params=model_params[0])
  def p_net_factory():
    return networks.ActorNetwork(
        action_spec,
        fc_layer_params=model_params[1])
  def c_net_factory():
    return networks.CriticNetwork(
        fc_layer_params=model_params[2])
  modules = utils.Flags(
      q_net_factory=q_net_factory,
      p_net_factory=p_net_factory,
      c_net_factory=c_net_factory,
      n_q_fns=n_q_fns,
      )
  return modules
Example 5
def train_eval_online(
    # Basic args.
    log_dir,
    agent_module,
    env_name='HalfCheetah-v2',
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(1e8),
    eval_freq=5000,
    n_eval_episodes=20,
    # For saving a partially trained policy.
    eval_target=None,  # Return threshold for saving a partially trained policy.
    eval_target_n=2,  # Save after n consecutive evals at or above eval_target.
    # Agent train args.
    initial_explore_steps=10000,
    replay_buffer_size=int(1e6),
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Training a policy with online interaction."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  tf_env_test = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Initialize dataset.
  with tf.device('/cpu:0'):
    train_data = dataset.Dataset(
        observation_spec,
        action_spec,
        replay_buffer_size,
        circular=True,
        )
  data_ckpt = tf.train.Checkpoint(data=train_data)
  data_ckpt_name = os.path.join(log_dir, 'replay')

  time_st_total = time.time()
  time_st = time.time()
  timed_at_step = 0

  # Collect data from random policy.
  explore_policy = policies.ContinuousRandomPolicy(action_spec)
  steps_collected = 0
  log_freq = 5000
  logging.info('Collecting data ...')
  collector = train_eval_utils.DataCollector(tf_env, explore_policy, train_data)
  while steps_collected < initial_explore_steps:
    count = collector.collect_transition()
    steps_collected += count
    if (steps_collected % log_freq == 0
        or steps_collected == initial_explore_steps) and count > 0:
      steps_per_sec = ((steps_collected - timed_at_step)
                       / (time.time() - time_st))
      timed_at_step = steps_collected
      time_st = time.time()
      logging.info('(%d/%d) steps collected at %.4g steps/s.', steps_collected,
                   initial_explore_steps, steps_per_sec)

  # Construct agent.
  agent_flags = utils.Flags(
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))

  # Prepare savers for models and results.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  agent_ckpt_name = os.path.join(log_dir, 'agent')
  eval_results = []

  # Train agent.
  logging.info('Start training ....')
  time_st = time.time()
  timed_at_step = 0
  target_partial_policy_saved = False
  collector = train_eval_utils.DataCollector(
      tf_env, agent.online_policy, train_data)
  for _ in range(total_train_steps):
    collector.collect_transition()
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()

    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env_test, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      # Decide whether to save a partially trained policy based on current
      # model performance.
      if (eval_target is not None and len(eval_results) >= eval_target_n
          and not target_partial_policy_saved):
        evals_ = list([eval_results[-(i + 1)][1]
                       for i in range(eval_target_n)])
        evals_ = np.array(evals_)
        if np.min(evals_) >= eval_target:
          agent.save(agent_ckpt_name + '_partial_target')
          dataset.save_copy(train_data, data_ckpt_name + '_partial_target')
          logging.info('A partially trained policy was saved at step %d,'
                       ' with episodic return %.4g.', step, evals_[-1])
          target_partial_policy_saved = True
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key + ': '))
        utils.write_summary(eval_summary_writers[policy_key], step, policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name + '-' + str(step))

  # Final save after training.
  agent.save(agent_ckpt_name + '_final')
  data_ckpt.write(data_ckpt_name + '_final')
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
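A hedged sketch of how this entry point might be driven from a launcher script; the agent module, its import path, and the log directory are assumptions for illustration. Any module exposing Config and Agent in the style of Example 2 should work.

import sac_agent  # assumed: an agent module providing Config(...) and Agent(...)

results = train_eval_online(
    log_dir='/tmp/online_run',   # hypothetical output directory
    agent_module=sac_agent,
    env_name='HalfCheetah-v2',
    total_train_steps=int(1e6),
    eval_target=4000,            # optional: save a partial policy once this return is sustained
    eval_target_n=2,
)
# Each row of `results` is [step] followed by the per-policy evaluation results.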
Example 6
def train_eval_offline(
        # Basic args.
        log_dir,
        data_file,
        agent_module,
        env_name='HalfCheetah-v2',
        n_train=int(1e6),
        shuffle_steps=0,
        seed=0,
        use_seed_for_data=False,
        # Train and eval args.
        total_train_steps=int(1e6),
        summary_freq=100,
        print_freq=1000,
        save_freq=int(2e4),
        eval_freq=5000,
        n_eval_episodes=20,
        # Agent args.
        model_params=(((200, 200), ), 2),
        behavior_ckpt_file=None,
        value_penalty=True,
        alpha=1.0,
        #model_params=((200, 200),),
        optimizers=(('adam', 0.001), ),
        batch_size=256,
        weight_decays=(0.0, ),
        update_freq=1,
        update_rate=0.005,
        discount=0.99,
):
    """Training a policy with a fixed dataset."""
    # Create tf_env to get specs.
    print('[train_eval_offline.py] env_name=', env_name)
    print('[train_eval_offline.py] data_file=', data_file)
    print('[train_eval_offline.py] agent_module=', agent_module)
    print('[train_eval_offline.py] model_params=', model_params)
    print('[train_eval_offline.py] optimizers=', optimizers)
    print('[train_eval_offline.py] bckpt_file=', behavior_ckpt_file)
    print('[train_eval_offline.py] value_penalty=', value_penalty)

    tf_env = env_factory(env_name)
    observation_spec = tf_env.observation_spec()
    action_spec = tf_env.action_spec()

    # Prepare data.
    full_data = get_offline_data(tf_env)

    # Split data.
    n_train = min(n_train, full_data.size)
    logging.info('n_train %s.', n_train)
    if use_seed_for_data:
        rand = np.random.RandomState(seed)
    else:
        rand = np.random.RandomState(0)
    shuffled_indices = utils.shuffle_indices_with_steps(n=full_data.size,
                                                        steps=shuffle_steps,
                                                        rand=rand)
    train_indices = shuffled_indices[:n_train]
    train_data = full_data.create_view(train_indices)

    # Create agent.
    agent_flags = utils.Flags(observation_spec=observation_spec,
                              action_spec=action_spec,
                              model_params=model_params,
                              optimizers=optimizers,
                              batch_size=batch_size,
                              weight_decays=weight_decays,
                              update_freq=update_freq,
                              update_rate=update_rate,
                              discount=discount,
                              train_data=train_data)
    agent_args = agent_module.Config(agent_flags).agent_args
    my_agent_arg_dict = dict(vars(agent_args))
    if 'brac_primal' in agent_module.__name__:
        my_agent_arg_dict['behavior_ckpt_file'] = behavior_ckpt_file
        my_agent_arg_dict['value_penalty'] = value_penalty
        my_agent_arg_dict['alpha'] = alpha
    print('agent:', agent_module.__name__)
    print('agent_args:', my_agent_arg_dict)
    #agent = agent_module.Agent(**vars(agent_args))
    agent = agent_module.Agent(**my_agent_arg_dict)
    agent_ckpt_name = os.path.join(log_dir, 'agent')

    # Restore agent from checkpoint if there exists one.
    if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
        logging.info('Checkpoint found at %s.', agent_ckpt_name)
        agent.restore(agent_ckpt_name)

    # Train agent.
    train_summary_dir = os.path.join(log_dir, 'train')
    eval_summary_dir = os.path.join(log_dir, 'eval')
    train_summary_writer = tf0.compat.v2.summary.create_file_writer(
        train_summary_dir)
    eval_summary_writers = collections.OrderedDict()
    for policy_key in agent.test_policies.keys():
        eval_summary_writer = tf0.compat.v2.summary.create_file_writer(
            os.path.join(eval_summary_dir, policy_key))
        eval_summary_writers[policy_key] = eval_summary_writer
    eval_results = []

    time_st_total = time.time()
    time_st = time.time()
    step = agent.global_step
    timed_at_step = step
    while step < total_train_steps:
        agent.train_step()
        step = agent.global_step
        if step % summary_freq == 0 or step == total_train_steps:
            agent.write_train_summary(train_summary_writer)
        if step % print_freq == 0 or step == total_train_steps:
            agent.print_train_info()
        if step % eval_freq == 0 or step == total_train_steps:
            time_ed = time.time()
            time_cost = time_ed - time_st
            logging.info('Training at %.4g steps/s.',
                         (step - timed_at_step) / time_cost)
            eval_result, eval_infos = train_eval_utils.eval_policies(
                tf_env, agent.test_policies, n_eval_episodes)
            eval_results.append([step] + eval_result)
            with open(os.path.join(log_dir, 'results.txt'), 'a') as logfile:
                logfile.write(str(eval_result) + '\n')
            logging.info('Testing at step %d:', step)
            for policy_key, policy_info in eval_infos.items():
                logging.info(
                    utils.get_summary_str(step=None,
                                          info=policy_info,
                                          prefix=policy_key + ': '))
                utils.write_summary(eval_summary_writers[policy_key], step,
                                    policy_info)
            time_st = time.time()
            timed_at_step = step
        if step % save_freq == 0:
            agent.save(agent_ckpt_name)
            logging.info('Agent saved at %s.', agent_ckpt_name)

    agent.save(agent_ckpt_name)
    time_cost = time.time() - time_st_total
    logging.info('Training finished, time cost %.4gs.', time_cost)
    return np.array(eval_results)
Example 7
def train_eval_offline(
    # Basic args.
    log_dir,
    data_file,
    agent_module,
    env_name='HalfCheetah-v2',
    n_train=int(1e6),
    shuffle_steps=0,
    seed=0,
    use_seed_for_data=False,
    # Train and eval args.
    total_train_steps=int(1e6),
    summary_freq=100,
    print_freq=1000,
    save_freq=int(2e4),
    eval_freq=5000,
    n_eval_episodes=20,
    # Agent args.
    model_params=(((200, 200),), 2),
    optimizers=(('adam', 0.001),),
    batch_size=256,
    weight_decays=(0.0,),
    update_freq=1,
    update_rate=0.005,
    discount=0.99,
    ):
  """Training a policy with a fixed dataset."""
  # Create tf_env to get specs.
  tf_env = train_eval_utils.env_factory(env_name)
  observation_spec = tf_env.observation_spec()
  action_spec = tf_env.action_spec()

  # Prepare data.
  logging.info('Loading data from %s ...', data_file)
  data_size = utils.load_variable_from_ckpt(data_file, 'data._capacity')
  with tf.device('/cpu:0'):
    full_data = dataset.Dataset(observation_spec, action_spec, data_size)
  data_ckpt = tf.train.Checkpoint(data=full_data)
  data_ckpt.restore(data_file)
  # Split data.
  n_train = min(n_train, full_data.size)
  logging.info('n_train %s.', n_train)
  if use_seed_for_data:
    rand = np.random.RandomState(seed)
  else:
    rand = np.random.RandomState(0)
  shuffled_indices = utils.shuffle_indices_with_steps(
      n=full_data.size, steps=shuffle_steps, rand=rand)
  train_indices = shuffled_indices[:n_train]
  train_data = full_data.create_view(train_indices)

  # Create agent.
  agent_flags = utils.Flags(
      observation_spec=observation_spec,
      action_spec=action_spec,
      model_params=model_params,
      optimizers=optimizers,
      batch_size=batch_size,
      weight_decays=weight_decays,
      update_freq=update_freq,
      update_rate=update_rate,
      discount=discount,
      train_data=train_data)
  agent_args = agent_module.Config(agent_flags).agent_args
  agent = agent_module.Agent(**vars(agent_args))
  agent_ckpt_name = os.path.join(log_dir, 'agent')

  # Restore agent from checkpoint if there exists one.
  if tf.io.gfile.exists('{}.index'.format(agent_ckpt_name)):
    logging.info('Checkpoint found at %s.', agent_ckpt_name)
    agent.restore(agent_ckpt_name)

  # Train agent.
  train_summary_dir = os.path.join(log_dir, 'train')
  eval_summary_dir = os.path.join(log_dir, 'eval')
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_summary_dir)
  eval_summary_writers = collections.OrderedDict()
  for policy_key in agent.test_policies.keys():
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(eval_summary_dir, policy_key))
    eval_summary_writers[policy_key] = eval_summary_writer
  eval_results = []

  time_st_total = time.time()
  time_st = time.time()
  step = agent.global_step
  timed_at_step = step
  while step < total_train_steps:
    agent.train_step()
    step = agent.global_step
    if step % summary_freq == 0 or step == total_train_steps:
      agent.write_train_summary(train_summary_writer)
    if step % print_freq == 0 or step == total_train_steps:
      agent.print_train_info()
    if step % eval_freq == 0 or step == total_train_steps:
      time_ed = time.time()
      time_cost = time_ed - time_st
      logging.info(
          'Training at %.4g steps/s.', (step - timed_at_step) / time_cost)
      eval_result, eval_infos = train_eval_utils.eval_policies(
          tf_env, agent.test_policies, n_eval_episodes)
      eval_results.append([step] + eval_result)
      logging.info('Testing at step %d:', step)
      for policy_key, policy_info in eval_infos.items():
        logging.info(utils.get_summary_str(
            step=None, info=policy_info, prefix=policy_key+': '))
        utils.write_summary(eval_summary_writers[policy_key], step, policy_info)
      time_st = time.time()
      timed_at_step = step
    if step % save_freq == 0:
      agent.save(agent_ckpt_name)
      logging.info('Agent saved at %s.', agent_ckpt_name)

  agent.save(agent_ckpt_name)
  time_cost = time.time() - time_st_total
  logging.info('Training finished, time cost %.4gs.', time_cost)
  return np.array(eval_results)
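A similar hedged sketch for the offline variant. The paths, the agent module, and the model parameters are illustrative; data_file is expected to be a dataset checkpoint prefix such as the 'replay_final' file written by the online run in Example 5.

import bcq_agent  # assumed: an offline agent module providing Config(...) and Agent(...)

results = train_eval_offline(
    log_dir='/tmp/offline_run',                # hypothetical output directory
    data_file='/tmp/online_run/replay_final',  # replay checkpoint from Example 5
    agent_module=bcq_agent,
    env_name='HalfCheetah-v2',
    # BCQ-style nesting from Example 1; the sizes here are hypothetical.
    model_params=(((200, 200), (300, 300), (750, 750)), 2, 0.05),
    n_train=int(1e6),
)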