Beispiel #1
0
def main(_):
    """Evaluate a pretrained policy checkpoint and log the results.

    Builds the environment named by ``FLAGS.env_name``, restores the policy
    weights saved at ``FLAGS.ckpt_timesteps`` and writes the average return
    and average episode length over 100 evaluation episodes as TensorBoard
    scalars under ``<FLAGS.save_dir>/tb/<hparam string>``.

    Args:
        _: Unused positional argument (presumably the argv list passed by the
            absl app runner -- confirm against the entry point).

    Raises:
        ValueError: If the environment name, the algorithm name or the
            Procgen checkpoint timestep is not one this script can load.
    """
    if FLAGS.eager:
        tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    tf.random.set_seed(FLAGS.seed)

    # Decide once whether this is a Procgen run so that environment creation,
    # policy loading and evaluation all agree on the answer.
    is_procgen = 'procgen' in FLAGS.env_name

    if is_procgen:
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=int(train_levels),
            start_level=0)

    elif FLAGS.env_name.startswith('pixels-dm'):
        # Distractor variants carry one extra '-'-separated field.
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        # Per-domain action repeat; only replaces the flag's default value.
        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)
    else:
        # Previously this fell through and failed later with a NameError on
        # `env`; fail early with a descriptive error instead.
        raise ValueError('Unsupported env: %s' % FLAGS.env_name)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.env_name,
                                          ckpt_timesteps=FLAGS.ckpt_timesteps)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    if is_procgen:
        # map env string to digit [1,16]
        env_id = list(PROCGEN_ENVS).index(env_name) + 1

        # Training timesteps -> checkpoint directory suffix.
        ckpt_iters = {
            10_000_000: '0000020480',
            25_000_000: '0000051200',
        }
        if FLAGS.ckpt_timesteps not in ckpt_iters:
            raise ValueError(
                'No Procgen checkpoint for ckpt_timesteps=%s; expected one '
                'of %s' % (FLAGS.ckpt_timesteps, sorted(ckpt_iters)))
        ckpt_iter = ckpt_iters[FLAGS.ckpt_timesteps]

        policy_weights_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/checkpoints/'
                              'policy_checkpoint_%s/' % (env_id, ckpt_iter))
        policy_def_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
        model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            policy_def_dir,
            time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
            action_spec=env._action_spec,  # pylint: disable=protected-access
            policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
            info_spec=tf.TensorSpec(shape=(None, )),
            load_specs_from_pbtxt=False)
        model.update_from_checkpoint(policy_weights_dir)
        model = TfAgentsPolicy(model)
    else:
        # Order matters: the checks are substring matches, so 'sac_v1' and
        # 'asac' must be tested before the plain 'sac' branch.
        if 'ddpg' in FLAGS.algo_name:
            model = ddpg.DDPG(env.observation_spec(),
                              env.action_spec(),
                              cross_norm='crossnorm' in FLAGS.algo_name)
        elif 'crr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='bin_max')
        elif 'awr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='exp_mean')
        elif 'sac_v1' in FLAGS.algo_name:
            model = sac_v1.SAC(env.observation_spec(),
                               env.action_spec(),
                               target_entropy=-env.action_spec().shape[0])
        elif 'asac' in FLAGS.algo_name:
            model = asac.ASAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])
        elif 'sac' in FLAGS.algo_name:
            model = sac.SAC(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0],
                            cross_norm='crossnorm' in FLAGS.algo_name,
                            pcl_actor_update='pc' in FLAGS.algo_name)
        elif 'pcl' in FLAGS.algo_name:
            model = pcl.PCL(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0])
        else:
            # Without this branch `model` stayed unbound and the failure only
            # showed up as a NameError at load_weights().
            raise ValueError('Unsupported algo: %s' % FLAGS.algo_name)

        if 'distractor' in FLAGS.env_name:
            ckpt_path = os.path.join(
                ('experiments/20210622_2023.policy_weights_sac'
                 '_1M_dmc_distractor_hard_pixel/'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
        else:
            ckpt_path = os.path.join(
                ('experiments/20210607_2023.'
                 'policy_weights_dmc_1M_SAC_pixel'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

        model.load_weights(ckpt_path)
    print('Loaded model weights')

    with summary_writer.as_default():
        if is_procgen:
            # Fresh env with num_levels=0 instead of the training level
            # count (presumably "all levels" in Procgen -- TODO confirm).
            eval_env = procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,
                env_name=env_name,
                num_levels=0,
                start_level=0)
        else:
            # `env_name` only exists for Procgen runs; building a Procgen env
            # here unconditionally raised NameError for pixels-dm tasks, so
            # evaluate on the environment that was already loaded.
            eval_env = env
        (avg_returns,
         avg_len) = evaluation.evaluate(eval_env,
                                        model,
                                        num_episodes=100,
                                        return_distributions=False)
        tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
        tf.summary.scalar('evaluation/length-all', avg_len, step=0)
def main(_):
  """Train an agent online on a pixel-based DM Control task.

  Loads a training and an evaluation environment for ``FLAGS.env_name``,
  fills a uniform replay buffer (first with a random policy, afterwards with
  the model's actor), updates the model from sampled batches, periodically
  evaluates it and saves its weights every 50_000 (action-repeat scaled)
  steps under ``<FLAGS.save_dir>/results``.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      absl app runner -- confirm against the entry point).

  Raises:
    ValueError: If ``FLAGS.env_name`` is not a 'pixels-dm' environment or
      ``FLAGS.algo_name`` matches none of the known algorithms.
  """
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # Guard the divisor the same way `step_mult` below does, so that an
  # action_repeat of 0 does not raise ZeroDivisionError.
  # NOTE(review): this runs before the per-domain action_repeat defaults are
  # applied further down, so the rescaling uses the pre-override value while
  # `step_mult` uses the overridden one -- confirm the order is intended.
  FLAGS.set_default('max_timesteps',
                    FLAGS.max_timesteps // max(FLAGS.action_repeat, 1))

  if 'pixels-dm' in FLAGS.env_name:
    # Distractor variants carry one extra '-'-separated field.
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    # Per-domain action repeat; only replaces the flag's default value.
    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)
    print('Loading env')
    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack)
    eval_env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                 FLAGS.action_repeat, FLAGS.frame_stack)
    print('Env loaded')
  else:
    # ValueError is still caught by callers handling `Exception`.
    raise ValueError('Unsupported env')

  # Rank-3 observation specs are treated as images and get augmentation.
  is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and
                  len(env.observation_spec().shape) == 3)

  spec = (
      env.observation_spec(),
      env.action_spec(),
      env.reward_spec(),
      env.reward_spec(),  # discount spec
      env.observation_spec()  # next observation spec
  )
  print('Init replay')

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer)
  print('Replay created')

  @tf.function
  def add_to_replay(state, action, reward, discount, next_states):
    """Append one transition tuple to the replay buffer."""
    replay_buffer.add_batch((state, action, reward, discount, next_states))

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters, seed=FLAGS.seed, env_name=FLAGS.env_name,
      algo_name=FLAGS.algo_name)
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  results_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))
  print('Init actor')
  # Order matters: the checks are substring matches, so 'sac_v1' and 'asac'
  # must be tested before the plain 'sac' branch.
  if 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(
        env.observation_spec(),
        env.action_spec(),
        cross_norm='crossnorm' in FLAGS.algo_name)
  elif 'crr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='exp_mean')
  elif 'sac_v1' in FLAGS.algo_name:
    model = sac_v1.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        cross_norm='crossnorm' in FLAGS.algo_name,
        pcl_actor_update='pc' in FLAGS.algo_name)
  elif 'pcl' in FLAGS.algo_name:
    model = pcl.PCL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  else:
    # Without this branch `model` stayed unbound and the run only failed
    # later, with a NameError, on the first use of the model.
    raise ValueError('Unsupported algo: %s' % FLAGS.algo_name)

  print('Init random policy for warmup')
  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      env.time_step_spec(), env.action_spec())
  print('Init replay buffer')
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=tf.data.AUTOTUNE,
      sample_batch_size=FLAGS.sample_batch_size)
  if is_image_obs:
    dataset = dataset.map(image_aug,
                          num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=False).prefetch(3)
  else:
    dataset = dataset.prefetch(3)

  def repack(*data):
    """Keep only the first element of each dataset item."""
    return data[0]
  dataset = dataset.map(repack)
  replay_buffer_iter = iter(dataset)

  previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0
  # Multiplier that converts loop iterations into logged step counts.
  step_mult = 1 if FLAGS.action_repeat < 1 else FLAGS.action_repeat

  print('Starting training')
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    # Step value used for every summary and checkpoint of this iteration.
    scaled_step = (i + 1) * step_mult

    # Collect a whole deployment batch of transitions at once, every
    # `deployment_batch_size` iterations.
    if i % FLAGS.deployment_batch_size == 0:
      for _ in range(FLAGS.deployment_batch_size):
        last_timestep = timestep.is_last()
        if last_timestep:

          if episode_timesteps > 0:
            current_time = time.time()
            with summary_writer.as_default():
              tf.summary.scalar(
                  'train/returns',
                  episode_return,
                  step=scaled_step)
              tf.summary.scalar(
                  'train/FPS',
                  episode_timesteps / (current_time - previous_time),
                  step=scaled_step)

          timestep = env.reset()
          episode_return = 0
          episode_timesteps = 0
          previous_time = time.time()

        num_frames = replay_buffer.num_frames()
        if (num_frames < FLAGS.num_random_actions or
            num_frames < FLAGS.deployment_batch_size):
          # Use policy only after the first deployment.
          policy_step = initial_collect_policy.action(timestep)
          action = policy_step.action
        else:
          action = model.actor(timestep.observation, sample=True)

        next_timestep = env.step(action)

        add_to_replay(timestep.observation, action, next_timestep.reward,
                      next_timestep.discount, next_timestep.observation)

        episode_return += next_timestep.reward[0]
        episode_timesteps += 1

        timestep = next_timestep

    if i + 1 >= FLAGS.start_training_timesteps:
      with summary_writer.as_default():
        info_dict = model.update_step(replay_buffer_iter)
      if (i + 1) % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v, step=scaled_step)

    if (i + 1) % FLAGS.eval_interval == 0:
      logging.info('Performing policy eval.')
      average_returns, evaluation_timesteps = evaluation.evaluate(
          eval_env, model)

      with results_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns', average_returns, step=scaled_step)
        tf.summary.scalar(
            'evaluation/length', evaluation_timesteps, step=scaled_step)
      logging.info('Eval at %d: ave returns=%f, ave episode length=%f',
                   scaled_step, average_returns, evaluation_timesteps)

    # Checkpoint whenever the scaled step count hits a multiple of 50k.
    if scaled_step % 50_000 == 0:
      model.save_weights(
          os.path.join(FLAGS.save_dir, 'results',
                       FLAGS.env_name + '__' + str(scaled_step)))
Beispiel #3
0
  def load_model(checkpoint):
    """Build a policy/model and restore the weights saved at `checkpoint`.

    For Procgen environments a saved TF-Agents policy is loaded and given an
    `actor` attribute aliasing its `action` method; otherwise the model class
    is chosen from `FLAGS.algo_name` and its weights are loaded from the
    experiment results directory. `env` and `env_name` are read from the
    enclosing scope, which is not visible in this block.

    Args:
      checkpoint: Training timestep of the checkpoint to load; converted
        with `int()`.

    Returns:
      The model with its weights restored.

    Raises:
      ValueError: If no Procgen checkpoint is known for `checkpoint`, or
        `FLAGS.algo_name` matches none of the known algorithms.
    """
    checkpoint = int(checkpoint)
    print(checkpoint)
    if FLAGS.env_name.startswith('procgen'):
      # Map the env name to its 1-based position in PROCGEN_ENVS.
      env_id = list(PROCGEN_ENVS).index(env_name) + 1

      # Training timesteps -> checkpoint directory suffix.
      ckpt_iters = {
          10_000_000: '0000020480',
          15_000_000: '0000030720',
          20_000_000: '0000040960',
          25_000_000: '0000051200',
      }
      if checkpoint not in ckpt_iters:
        # Previously `ckpt_iter` stayed unbound and the failure surfaced as
        # an UnboundLocalError while formatting the path below.
        raise ValueError(
            'No Procgen checkpoint for %d timesteps; expected one of %s' %
            (checkpoint, sorted(ckpt_iters)))
      ckpt_iter = ckpt_iters[checkpoint]

      policy_weights_dir = ('ppo_darts/'
                            '2021-06-22-16-36-54/%d/policies/checkpoints/'
                            'policy_checkpoint_%s/' % (env_id, ckpt_iter))
      policy_def_dir = ('ppo_darts/'
                        '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
      model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
          policy_def_dir,
          time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
          action_spec=env._action_spec,  # pylint: disable=protected-access
          policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
          info_spec=tf.TensorSpec(shape=(None,)),
          load_specs_from_pbtxt=False)
      model.update_from_checkpoint(policy_weights_dir)
      # Expose the policy under the same attribute name the other models use.
      model.actor = model.action
    else:
      # Order matters: the checks are substring matches, so 'sac_v1' and
      # 'asac' must be tested before the plain 'sac' branch.
      if 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(
            env.observation_spec(),
            env.action_spec(),
            cross_norm='crossnorm' in FLAGS.algo_name)
      elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='bin_max')
      elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='exp_mean')
      elif 'sac_v1' in FLAGS.algo_name:
        model = sac_v1.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0],
            cross_norm='crossnorm' in FLAGS.algo_name,
            pcl_actor_update='pc' in FLAGS.algo_name)
      elif 'pcl' in FLAGS.algo_name:
        model = pcl.PCL(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      else:
        # Without this branch `model` stayed unbound and load_weights()
        # below failed with an UnboundLocalError.
        raise ValueError('Unsupported algo: %s' % FLAGS.algo_name)

      if 'distractor' in FLAGS.env_name:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210622_2023.policy_weights_sac_1M_dmc_distractor_hard_pixel/'),
            'results', FLAGS.env_name+'__'+str(checkpoint))
      else:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210607_2023.policy_weights_dmc_1M_SAC_pixel'), 'results',
            FLAGS.env_name + '__' + str(checkpoint))

      model.load_weights(ckpt_path)
    print('Loaded model weights')
    return model