Example #1
  def log_probs(self, states, actions):
    """Returns log-probabilities of `actions` under the current policy."""
    # Build a mid-episode TimeStep around the batch of states; the batch
    # size of 256 is hard-coded here and must match the caller's batch.
    ts_ = trajectories.TimeStep(
        tf.stack([trajectories.StepType.MID] * 256, 0),
        tf.constant([1.] * 256, dtype=tf.float32),
        tf.constant([1.] * 256, dtype=tf.float32),
        tf.cast(states, tf.float32))
    dist = self.policy.distribution(ts_)
    return dist.action.log_prob(actions)
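
A minimal calling sketch (names are hypothetical: `agent` stands for the object that owns `log_probs` above, and the two tensors are assumed to match the hard-coded batch size of 256):

# Hypothetical usage: score a stored batch of (state, action) pairs under
# the current policy. `agent`, `states_batch`, and `actions_batch` are
# placeholder names; shapes must match the policy's specs.
log_p = agent.log_probs(states_batch, actions_batch)  # tensor of shape [256]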
Example #2
    def act(self, states):
        """Acts on a length-1 batch of states.

        Args:
          states: batch of states (batch size 1).

        Returns:
          A tuple (action, log_prob): the sampled action and its
          log-probability under the policy.
        """
        # Unbatched TimeStep used to sample a single action.
        ts = trajectories.TimeStep(trajectories.StepType.MID,
                                   tf.constant(1, dtype=tf.float32),
                                   tf.constant(1, dtype=tf.float32),
                                   tf.cast(states[0], tf.float32))
        # Batched TimeStep (batch size 1) used to query the action distribution.
        ts2 = trajectories.TimeStep(
            tf.expand_dims(trajectories.StepType.MID, 0),
            tf.constant([1], dtype=tf.float32),
            tf.constant([1], dtype=tf.float32), tf.cast(states, tf.float32))

        act_d = self.policy.action(ts)
        action = tf.constant(act_d.action.item())
        log_prob = self.policy._policy.distribution(ts2).action.log_prob(  # pylint: disable=protected-access
            act_d.action)
        return action, log_prob
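
A hedged sketch of driving `act` from an environment loop (illustrative only: `agent` owns the method above, `env` is assumed to be a Python environment, and observations are passed as a length-1 batch):

# Illustrative driver loop, not part of the original example.
timestep = env.reset()
while not timestep.is_last():
  # `act` expects a batch of states, so add a leading batch dimension.
  action, log_prob = agent.act(timestep.observation[None])
  timestep = env.step(action.numpy())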
Example #3
def main(_):
  if FLAGS.eager:
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  # np.random.seed(FLAGS.seed)
  # random.seed(FLAGS.seed)

  print('Env name: %s' % FLAGS.env_name)

  if 'procgen' in FLAGS.env_name:
    _, env_name, train_levels, _ = FLAGS.env_name.split('-')
    env = procgen_wrappers.TFAgentsParallelProcGenEnv(
        1,
        normalize_rewards=True,
        env_name=env_name,
        num_levels=int(train_levels),
        start_level=0)
    state_env = None

    timestep_spec = trajectories.time_step_spec(
        observation_spec=specs.ArraySpec(env._observation_spec.shape, np.uint8),  # pylint: disable=protected-access
        reward_spec=specs.ArraySpec(shape=(), dtype=np.float32))

    data_spec = trajectory.from_transition(
        timestep_spec,
        policy_step.PolicyStep(
            action=env._action_spec,  # pylint: disable=protected-access
            info=specs.ArraySpec(shape=(), dtype=np.int32)), timestep_spec)

    n_state = None
    # ckpt_steps = [10_000_000,15_000_000,20_000_000,25_000_000]
    ckpt_steps = [25_000_000]
  elif FLAGS.env_name.startswith('pixels-dm'):
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    env, state_env = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)

    if FLAGS.obs_type == 'pixels':
      data_spec = trajectory.from_transition(
          env.time_step_spec(), policy_step.PolicyStep(env.action_spec()),
          env.time_step_spec())
      ckpt_steps = FLAGS.ckpt_timesteps[0]
    else:
      data_spec = trajectory.from_transition(
          state_env.time_step_spec(),
          policy_step.PolicyStep(state_env.action_spec()),
          state_env.time_step_spec())
      n_state = state_env.observation_spec().shape[0]
      ckpt_steps = FLAGS.ckpt_timesteps[0]

  if FLAGS.numpy_dataset:
    tf.io.gfile.makedirs(os.path.join(FLAGS.save_dir, 'datasets'))
    def shard_fn(shard):
      return os.path.join(
          FLAGS.save_dir, 'datasets', FLAGS.env_name + '__%d__%d__%d.npy' %
          (int(ckpt_steps[-1]), FLAGS.max_timesteps, shard))

    observer = tf_utils.NumpyObserver(shard_fn, env)
    observer.allocate_arrays(FLAGS.max_timesteps)
  else:
    shard_fn = os.path.join(
        FLAGS.save_dir, 'datasets',
        FLAGS.env_name + '__%d__%d.tfrecord.shard-%d-of-%d' %
        (int(ckpt_steps[-1]), FLAGS.max_timesteps, FLAGS.worker_id,
         FLAGS.total_workers))
    observer = DummyObserver(
        shard_fn, data_spec, py_mode=True, compress_image=True)

  def load_model(checkpoint):
    checkpoint = int(checkpoint)
    print(checkpoint)
    if FLAGS.env_name.startswith('procgen'):
      # 1-based index of the environment in PROCGEN_ENVS.
      env_id = PROCGEN_ENVS.index(env_name) + 1
      if checkpoint == 10_000_000:
        ckpt_iter = '0000020480'
      elif checkpoint == 15_000_000:
        ckpt_iter = '0000030720'
      elif checkpoint == 20_000_000:
        ckpt_iter = '0000040960'
      elif checkpoint == 25_000_000:
        ckpt_iter = '0000051200'
      policy_weights_dir = ('ppo_darts/'
                            '2021-06-22-16-36-54/%d/policies/checkpoints/'
                            'policy_checkpoint_%s/' % (env_id, ckpt_iter))
      policy_def_dir = ('ppo_darts/'
                        '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
      model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
          policy_def_dir,
          time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
          action_spec=env._action_spec,  # pylint: disable=protected-access
          policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
          info_spec=tf.TensorSpec(shape=(None,)),
          load_specs_from_pbtxt=False)
      model.update_from_checkpoint(policy_weights_dir)
      model.actor = model.action
    else:
      if 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(
            env.observation_spec(),
            env.action_spec(),
            cross_norm='crossnorm' in FLAGS.algo_name)
      elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='bin_max')
      elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='exp_mean')
      elif 'sac_v1' in FLAGS.algo_name:
        model = sac_v1.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0],
            cross_norm='crossnorm' in FLAGS.algo_name,
            pcl_actor_update='pc' in FLAGS.algo_name)
      elif 'pcl' in FLAGS.algo_name:
        model = pcl.PCL(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      if 'distractor' in FLAGS.env_name:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210622_2023.policy_weights_sac_1M_dmc_distractor_hard_pixel/'),
            'results', FLAGS.env_name+'__'+str(checkpoint))
      else:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210607_2023.policy_weights_dmc_1M_SAC_pixel'), 'results',
            FLAGS.env_name + '__' + str(checkpoint))

      model.load_weights(ckpt_path)
    print('Loaded model weights')
    return model

  # previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0

  actions = []
  time_steps = []

  def get_state_or_pixels(obs, obs_type):
    # obs of shape 1 x 84 x 84 x (n_state*frame_stack + 3*frame_stack)
    if len(obs.shape) == 4:
      obs = obs[0]
    if obs_type == 'state':
      obs = obs[0, 0, :n_state]
    else:
      obs_tmp = []
      for i in range(FLAGS.frame_stack):
        obs_tmp.append(obs[:, :, (i + 1) * (n_state) +
                           i * 3:((i + 1) * (n_state) + (i + 1) * 3)])
      obs = np.concatenate(obs_tmp, axis=-1)
    return obs

  k_model = 0
  model = load_model(ckpt_steps[k_model])
  reload_model = False

  def linear_scheduling(t):  # pylint: disable=unused-variable
    return 0.1 - 3.96e-9 * t

  mixture_freq = FLAGS.max_timesteps // len(ckpt_steps)
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    if (i % mixture_freq) == 0 and i > 0:
      reload_model = True
    if np.all(timestep.is_last()):
      if FLAGS.env_name.startswith('procgen'):
        timestep = trajectories.TimeStep(
            timestep.step_type[0], timestep.reward[0], timestep.discount[0],
            (timestep.observation[0] * 255).astype(np.uint8))

      time_steps.append(
          ts.termination(
              get_state_or_pixels(timestep.observation[0], 'state')
              if FLAGS.obs_type == 'state' else timestep.observation,
              timestep.reward if timestep.reward is not None else 1.0))
      # Write the episode into the TF Record

      for l in range(len(time_steps) - 1):
        t_ = min(l + FLAGS.n_step_returns, len(time_steps) - 1)
        n_step_return = 0.
        for j in range(l, t_):
          if len(time_steps[j].reward.shape) == 1:
            r_t = time_steps[j].reward[0]
          else:
            r_t = time_steps[j].reward

          n_step_return += FLAGS.discount**j * r_t

        t_ = min(l + 1 + FLAGS.n_step_returns, len(time_steps) - 1)
        n_step_return_tp1 = 0.
        for j in range(l + 1, t_):
          if len(time_steps[j].reward.shape) == 1:
            r_t = time_steps[j].reward[0]
          else:
            r_t = time_steps[j].reward

          n_step_return_tp1 += FLAGS.discount**j * r_t

        if len(time_steps[l].observation.shape) == 4:
          if len(time_steps[l].reward.shape) == 1:
            time_steps[l] = trajectories.TimeStep(time_steps[l].step_type[0],
                                                  n_step_return,
                                                  time_steps[l].discount[0],
                                                  time_steps[l].observation[0])
          else:
            time_steps[l] = trajectories.TimeStep(time_steps[l].step_type,
                                                  n_step_return,
                                                  time_steps[l].discount,
                                                  time_steps[l].observation[0])
        if len(time_steps[l + 1].observation.shape) == 4:
          if len(time_steps[l + 1].reward.shape) == 1:
            time_steps[l + 1] = trajectories.TimeStep(
                time_steps[l + 1].step_type[0], n_step_return_tp1,
                time_steps[l + 1].discount[0], time_steps[l + 1].observation[0])
          else:
            time_steps[l + 1] = trajectories.TimeStep(
                time_steps[l + 1].step_type, n_step_return_tp1,
                time_steps[l + 1].discount, time_steps[l + 1].observation[0])
        traj = trajectory.from_transition(time_steps[l], actions[l],
                                          time_steps[l + 1])
        if FLAGS.numpy_dataset:
          traj = Traj(traj, next_obs=time_steps[l + 1].observation)
        observer(traj)

      timestep = env.reset()
      print(episode_return)
      episode_return = 0
      episode_timesteps = 0
      # previous_time = time.time()

      actions = []
      time_steps = []

      if reload_model:
        k_model += 1
        model = load_model(ckpt_steps[k_model])
        reload_model = False
    if FLAGS.env_name.startswith('procgen'):
      timestep = trajectories.TimeStep(
          timestep.step_type[0], timestep.reward[0], timestep.discount[0],
          (timestep.observation[0] * 255).astype(np.uint8))

    if episode_timesteps == 0:
      time_steps.append(
          ts.restart(
              get_state_or_pixels(timestep.observation, 'state')
              if FLAGS.obs_type == 'state' else timestep.observation))
    elif not timestep.is_last():
      time_steps.append(
          ts.transition(
              get_state_or_pixels(timestep.observation[0], 'state')
              if FLAGS.obs_type == 'state' else (timestep.observation),
              timestep.reward if timestep.reward is not None else 0.0,
              timestep.discount))

    if FLAGS.env_name.startswith('procgen'):
      # eps_t = linear_scheduling(i)
      eps_t = 0
      u = np.random.uniform(0, 1, size=1)
      if u > eps_t:
        timestep_act = trajectories.TimeStep(
            timestep.step_type, timestep.reward, timestep.discount,
            timestep.observation.astype(np.float32) / 255.)
        action = model.actor(timestep_act)
        action = action.action
      else:
        action = np.random.choice(
            env.action_spec().maximum.item() + 1, size=1)[0]
      next_timestep = env.step(action)
      info_arr = np.array(env._infos[0]['level_seed'], dtype=np.int32)  # pylint: disable=protected-access
      actions.append(policy_step.PolicyStep(action=action, state=(),
                                            info=info_arr))
    else:
      action = model.actor(
          tf.expand_dims(
              get_state_or_pixels(timestep.observation[0], 'pixel')
              if FLAGS.obs_type == 'state' else (timestep.observation[0]), 0),
          sample=True)
      next_timestep = env.step(action)
      actions.append(
          policy_step.PolicyStep(action=action.numpy()[0], state=(), info=()))

    episode_return += next_timestep.reward[0]
    episode_timesteps += 1

    timestep = next_timestep

  if FLAGS.numpy_dataset:
    observer.save(n_shards=10)
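
The numpy branch writes one `.npy` file per shard at the path built by `shard_fn`. A hedged read-back sketch (the exact on-disk layout is determined by `tf_utils.NumpyObserver`, which is not shown here, so treating the file as a pickled object array is an assumption; the path below is a placeholder):

import numpy as np

# Hypothetical: inspect one saved shard. The filename follows shard_fn above:
# <save_dir>/datasets/<env_name>__<ckpt>__<max_timesteps>__<shard>.npy
shard_path = '/tmp/save_dir/datasets/example__25000000__100000__0.npy'  # placeholder
data = np.load(shard_path, allow_pickle=True)
print(type(data), getattr(data, 'shape', None))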