Beispiel #1
0
def main(_):
  """Roll out pretrained policies and record their transitions as a dataset.

  Two environment families are supported, selected from `FLAGS.env_name`:
    * procgen (`<prefix>-<game>-<num_train_levels>-<suffix>`): a PPO policy is
      restored from a SavedModel plus a checkpoint directory.
    * `pixels-dm-...`: an off-policy agent chosen by `FLAGS.algo_name` is
      restored with `load_weights`.

  The policy is run for `FLAGS.max_timesteps` environment steps. Every finished
  episode is written transition by transition to `observer` (a
  `tf_utils.NumpyObserver` when `FLAGS.numpy_dataset`, otherwise a
  `DummyObserver` given a `.tfrecord.shard-*` path). When `ckpt_steps` lists
  several checkpoints, the step budget is split evenly between them and the
  next checkpoint is loaded at the first episode boundary after each split.

  Args:
    _: Unused (presumably the argv list passed by `app.run`).

  Raises:
    ValueError: if `FLAGS.env_name` is neither a procgen nor a `pixels-dm`
      environment, if a procgen checkpoint has no known iteration directory,
      or if `FLAGS.algo_name` matches no supported agent.
  """
  if FLAGS.eager:
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  tf.random.set_seed(FLAGS.seed)
  # NOTE: numpy's global RNG is not seeded, so the np.random.uniform draw used
  # for epsilon-greedy exploration below is not reproducible across runs.
  # np.random.seed(FLAGS.seed)
  # random.seed(FLAGS.seed)

  print('Env name: %s'%FLAGS.env_name)

  if 'procgen' in FLAGS.env_name:
    _, env_name, train_levels, _ = FLAGS.env_name.split('-')
    env = procgen_wrappers.TFAgentsParallelProcGenEnv(
        1,
        normalize_rewards=True,
        env_name=env_name,
        num_levels=int(train_levels),
        start_level=0)
    state_env = None

    timestep_spec = trajectories.time_step_spec(
        observation_spec=specs.ArraySpec(env._observation_spec.shape, np.uint8),  # pylint: disable=protected-access
        reward_spec=specs.ArraySpec(shape=(), dtype=np.float32))

    # The policy-step `info` field carries the procgen level seed (see the
    # `level_seed` lookup in the rollout loop).
    data_spec = trajectory.from_transition(
        timestep_spec,
        policy_step.PolicyStep(
            action=env._action_spec,  # pylint: disable=protected-access
            info=specs.ArraySpec(shape=(), dtype=np.int32)), timestep_spec)

    n_state = None
    # ckpt_steps = [10_000_000,15_000_000,20_000_000,25_000_000]
    ckpt_steps = [25_000_000]
  elif FLAGS.env_name.startswith('pixels-dm'):
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)

    env, state_env = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)

    # NOTE(review): `ckpt_steps` is later indexed (`ckpt_steps[-1]`,
    # `ckpt_steps[k_model]`) and measured with len(); confirm that
    # FLAGS.ckpt_timesteps[0] is a sequence of checkpoint steps.
    if FLAGS.obs_type == 'pixels':
      data_spec = trajectory.from_transition(
          env.time_step_spec(), policy_step.PolicyStep(env.action_spec()),
          env.time_step_spec())
      # Only read when obs_type == 'state'; bound here so the closure below
      # never refers to an unassigned name.
      n_state = None
      ckpt_steps = FLAGS.ckpt_timesteps[0]
    else:
      data_spec = trajectory.from_transition(
          state_env.time_step_spec(),
          policy_step.PolicyStep(state_env.action_spec()),
          state_env.time_step_spec())
      n_state = state_env.observation_spec().shape[0]
      ckpt_steps = FLAGS.ckpt_timesteps[0]
  else:
    raise ValueError('Unsupported env_name: %s' % FLAGS.env_name)

  if FLAGS.numpy_dataset:
    tf.io.gfile.makedirs(os.path.join(FLAGS.save_dir, 'datasets'))
    def shard_fn(shard):
      return os.path.join(
          FLAGS.save_dir, 'datasets', FLAGS.env_name + '__%d__%d__%d.npy' %
          (int(ckpt_steps[-1]), FLAGS.max_timesteps, shard))

    observer = tf_utils.NumpyObserver(shard_fn, env)
    observer.allocate_arrays(FLAGS.max_timesteps)
  else:
    shard_fn = os.path.join(
        FLAGS.save_dir, 'datasets',
        FLAGS.env_name + '__%d__%d.tfrecord.shard-%d-of-%d' %
        (int(ckpt_steps[-1]), FLAGS.max_timesteps, FLAGS.worker_id,
         FLAGS.total_workers))
    observer = DummyObserver(
        shard_fn, data_spec, py_mode=True, compress_image=True)

  # Procgen PPO checkpoints, keyed by environment steps, mapped to the
  # zero-padded iteration suffix of their `policy_checkpoint_*` directory.
  procgen_ckpt_iters = {
      10_000_000: '0000020480',
      15_000_000: '0000030720',
      20_000_000: '0000040960',
      25_000_000: '0000051200',
  }

  def load_model(checkpoint):
    """Restore the policy saved after `checkpoint` environment steps.

    Procgen policies are TF-Agents SavedModels whose `action` method is exposed
    as `actor`; other environments use the agent class selected by
    `FLAGS.algo_name`, restored with `load_weights`.
    """
    checkpoint = int(checkpoint)
    print(checkpoint)
    if FLAGS.env_name.startswith('procgen'):
      # 1-based position of env_name in PROCGEN_ENVS; used as the run dir.
      env_id = [i for i, name in enumerate(
          PROCGEN_ENVS) if name == env_name][0]+1
      if checkpoint not in procgen_ckpt_iters:
        raise ValueError(
            'No procgen PPO checkpoint for %d steps; expected one of %s' %
            (checkpoint, sorted(procgen_ckpt_iters)))
      ckpt_iter = procgen_ckpt_iters[checkpoint]
      policy_weights_dir = ('ppo_darts/'
                            '2021-06-22-16-36-54/%d/policies/checkpoints/'
                            'policy_checkpoint_%s/' % (env_id, ckpt_iter))
      policy_def_dir = ('ppo_darts/'
                        '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
      model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
          policy_def_dir,
          time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
          action_spec=env._action_spec,  # pylint: disable=protected-access
          policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
          info_spec=tf.TensorSpec(shape=(None,)),
          load_specs_from_pbtxt=False)
      model.update_from_checkpoint(policy_weights_dir)
      model.actor = model.action
    else:
      # Order matters: 'sac_v1' and 'asac' must be tested before 'sac'.
      if 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(
            env.observation_spec(),
            env.action_spec(),
            cross_norm='crossnorm' in FLAGS.algo_name)
      elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='bin_max')
      elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='exp_mean')
      elif 'sac_v1' in FLAGS.algo_name:
        model = sac_v1.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0],
            cross_norm='crossnorm' in FLAGS.algo_name,
            pcl_actor_update='pc' in FLAGS.algo_name)
      elif 'pcl' in FLAGS.algo_name:
        model = pcl.PCL(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      else:
        raise ValueError('Unsupported algo_name: %s' % FLAGS.algo_name)
      if 'distractor' in FLAGS.env_name:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210622_2023.policy_weights_sac_1M_dmc_distractor_hard_pixel/'),
            'results', FLAGS.env_name+'__'+str(checkpoint))
      else:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210607_2023.policy_weights_dmc_1M_SAC_pixel'), 'results',
            FLAGS.env_name + '__' + str(checkpoint))

      model.load_weights(ckpt_path)
    print('Loaded model weights')
    return model

  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0

  # Per-episode buffers: time_steps[k] is the k-th step of the current episode
  # and actions[k] the policy step taken from it; the terminal step is appended
  # at episode end, so transition k is (time_steps[k], actions[k],
  # time_steps[k + 1]).
  actions = []
  time_steps = []

  def get_state_or_pixels(obs, obs_type):
    """Slice either the state vector or the stacked RGB frames out of `obs`."""
    # obs of shape 1 x 84 x 84 x (n_state*frame_stack + 3*frame_stack)
    if len(obs.shape) == 4:
      obs = obs[0]
    if obs_type == 'state':
      obs = obs[0, 0, :n_state]
    else:
      # Each stacked frame occupies n_state state channels followed by 3 RGB
      # channels; keep only the RGB channels of every frame.
      obs_tmp = []
      for i in range(FLAGS.frame_stack):
        obs_tmp.append(obs[:, :, (i + 1) * (n_state) +
                           i * 3:((i + 1) * (n_state) + (i + 1) * 3)])
      obs = np.concatenate(obs_tmp, axis=-1)
    return obs

  k_model = 0
  model = load_model(ckpt_steps[k_model])
  reload_model = False

  def linear_scheduling(t):  # pylint: disable=unused-variable
    return 0.1 - 3.96e-9* t

  # Split the step budget evenly between checkpoints. max(1, ...) avoids a
  # zero modulus when there are more checkpoints than timesteps.
  mixture_freq = max(1, FLAGS.max_timesteps // len(ckpt_steps))
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    if (i % mixture_freq) == 0 and i > 0:
      reload_model = True
    if np.all(timestep.is_last()):
      if FLAGS.env_name.startswith('procgen'):
        # Drop the batch dimension and store observations as uint8 pixels.
        timestep = trajectories.TimeStep(
            timestep.step_type[0], timestep.reward[0], timestep.discount[0],
            (timestep.observation[0] * 255).astype(np.uint8))

      # `observation` is an array attribute, not a method (was
      # `timestep.observation()[0]`, which raised on every episode end).
      time_steps.append(
          ts.termination(
              get_state_or_pixels(timestep.observation[0], 'state')
              if FLAGS.obs_type == 'state' else timestep.observation,
              timestep.reward if timestep.reward is not None else 1.0))

      # Write the episode into the dataset. Rewards are summed over a window of
      # at most FLAGS.n_step_returns steps, truncated before the final time
      # step, and discounted relative to the start of the window.
      # NOTE(review): the summed reward is only stored for time steps whose
      # observation still has rank 4 (a leading batch dimension); rank-3
      # observations (e.g. procgen, unbatched above) keep their raw reward.
      # Confirm this asymmetry is intended.
      for l in range(len(time_steps) - 1):
        t_ = min(l + FLAGS.n_step_returns, len(time_steps) - 1)
        n_step_return = 0.
        for j in range(l, t_):
          if len(time_steps[j].reward.shape) == 1:
            r_t = time_steps[j].reward[0]
          else:
            r_t = time_steps[j].reward

          # gamma^0 for the first reward of the window (was gamma^j).
          n_step_return += FLAGS.discount**(j - l) * r_t

        t_ = min(l + 1 + FLAGS.n_step_returns, len(time_steps) - 1)
        n_step_return_tp1 = 0.
        for j in range(l + 1, t_):
          if len(time_steps[j].reward.shape) == 1:
            r_t = time_steps[j].reward[0]
          else:
            r_t = time_steps[j].reward

          # Window starts at l + 1 (was gamma^j).
          n_step_return_tp1 += FLAGS.discount**(j - l - 1) * r_t

        if len(time_steps[l].observation.shape) == 4:
          if len(time_steps[l].reward.shape) == 1:
            time_steps[l] = trajectories.TimeStep(time_steps[l].step_type[0],
                                                  n_step_return,
                                                  time_steps[l].discount[0],
                                                  time_steps[l].observation[0])
          else:
            time_steps[l] = trajectories.TimeStep(time_steps[l].step_type,
                                                  n_step_return,
                                                  time_steps[l].discount,
                                                  time_steps[l].observation[0])
        if len(time_steps[l + 1].observation.shape) == 4:
          if len(time_steps[l + 1].reward.shape) == 1:
            time_steps[l + 1] = trajectories.TimeStep(
                time_steps[l + 1].step_type[0], n_step_return_tp1,
                time_steps[l + 1].discount[0], time_steps[l + 1].observation[0])
          else:
            time_steps[l + 1] = trajectories.TimeStep(
                time_steps[l + 1].step_type, n_step_return_tp1,
                time_steps[l + 1].discount, time_steps[l + 1].observation[0])
        traj = trajectory.from_transition(time_steps[l], actions[l],
                                          time_steps[l + 1])
        if FLAGS.numpy_dataset:
          traj = Traj(traj, next_obs=time_steps[l+1].observation)
          observer(traj)
        else:
          observer(traj)

      timestep = env.reset()
      print(episode_return)
      episode_return = 0
      episode_timesteps = 0

      actions = []
      time_steps = []

      if reload_model:
        # Only advance while another checkpoint remains; a remainder of
        # max_timesteps not divisible by len(ckpt_steps) would otherwise
        # index past the end of ckpt_steps.
        if k_model + 1 < len(ckpt_steps):
          k_model += 1
          model = load_model(ckpt_steps[k_model])
        reload_model = False
    if FLAGS.env_name.startswith('procgen'):
      timestep = trajectories.TimeStep(
          timestep.step_type[0], timestep.reward[0], timestep.discount[0],
          (timestep.observation[0] * 255).astype(np.uint8))

    if episode_timesteps == 0:
      time_steps.append(
          ts.restart(
              get_state_or_pixels(timestep.observation, 'state') if FLAGS
              .obs_type == 'state' else (timestep.observation)))
    elif not timestep.is_last():
      time_steps.append(
          ts.transition(
              get_state_or_pixels(timestep.observation[0], 'state')
              if FLAGS.obs_type == 'state' else (timestep.observation),
              timestep.reward if timestep.reward is not None else 0.0,
              timestep.discount))

    if FLAGS.env_name.startswith('procgen'):
      # eps_t = linear_scheduling(i)
      eps_t = 0
      u = np.random.uniform(0, 1, size=1)
      if u > eps_t:
        # The PPO policy expects float observations in [0, 1].
        timestep_act = trajectories.TimeStep(
            timestep.step_type, timestep.reward, timestep.discount,
            timestep.observation.astype(np.float32) / 255.)
        action = model.actor(timestep_act)
        action = action.action
      else:
        action = np.random.choice(
            env.action_spec().maximum.item() + 1, size=1)[0]
      next_timestep = env.step(action)
      info_arr = np.array(env._infos[0]['level_seed'], dtype=np.int32)  # pylint: disable=protected-access
      actions.append(policy_step.PolicyStep(action=action, state=(),
                                            info=info_arr))
    else:
      # In 'state' mode the dataset stores states, but the pretrained actor
      # still consumes the RGB frames.
      action = model.actor(
          tf.expand_dims(
              get_state_or_pixels(timestep.observation[0], 'pixel')
              if FLAGS.obs_type == 'state' else (timestep.observation[0]), 0),
          sample=True)
      next_timestep = env.step(action)
      actions.append(
          policy_step.PolicyStep(action=action.numpy()[0], state=(), info=()))

    episode_return += next_timestep.reward[0]
    episode_timesteps += 1

    timestep = next_timestep

  if FLAGS.numpy_dataset:
    observer.save(n_shards=10)
Beispiel #2
0
def main(_):
    """Evaluate a pretrained policy and log its average return and length.

    Restores either a procgen PPO policy (SavedModel + checkpoint, wrapped in
    `TfAgentsPolicy`) or a `pixels-dm` agent selected by `FLAGS.algo_name`.
    It then runs `evaluation.evaluate` for 100 episodes and writes the results
    as the scalars `evaluation/returns-all` and `evaluation/length-all` to a
    summary writer under `FLAGS.save_dir/tb/<hparam_str>`.

    For procgen, evaluation uses a fresh environment built with
    `num_levels=0` rather than the training-level environment. Other
    environments are evaluated on the environment that was loaded.

    Args:
        _: Unused (presumably the argv list passed by `app.run`).

    Raises:
        ValueError: if `FLAGS.env_name` is not a procgen or `pixels-dm`
            environment, if `FLAGS.ckpt_timesteps` has no known procgen
            checkpoint, or if `FLAGS.algo_name` matches no supported agent.
    """
    if FLAGS.eager:
        tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    tf.random.set_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    if 'procgen' in FLAGS.env_name:
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=int(train_levels),
            start_level=0)

    elif FLAGS.env_name.startswith('pixels-dm'):
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)
    else:
        raise ValueError('Unsupported env_name: %s' % FLAGS.env_name)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.env_name,
                                          ckpt_timesteps=FLAGS.ckpt_timesteps)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    if FLAGS.env_name.startswith('procgen'):
        # map env string to digit [1,16]
        env_id = [
            i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
        ][0] + 1
        # Environment steps -> iteration suffix of the PPO checkpoint dir.
        procgen_ckpt_iters = {
            10_000_000: '0000020480',
            25_000_000: '0000051200',
        }
        if FLAGS.ckpt_timesteps not in procgen_ckpt_iters:
            raise ValueError(
                'No procgen PPO checkpoint for ckpt_timesteps=%s; expected '
                'one of %s' % (FLAGS.ckpt_timesteps,
                               sorted(procgen_ckpt_iters)))
        ckpt_iter = procgen_ckpt_iters[FLAGS.ckpt_timesteps]
        policy_weights_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/checkpoints/'
                              'policy_checkpoint_%s/' % (env_id, ckpt_iter))
        policy_def_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
        model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            policy_def_dir,
            time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
            action_spec=env._action_spec,  # pylint: disable=protected-access
            policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
            info_spec=tf.TensorSpec(shape=(None, )),
            load_specs_from_pbtxt=False)
        model.update_from_checkpoint(policy_weights_dir)
        model = TfAgentsPolicy(model)
    else:
        # Order matters: 'sac_v1' and 'asac' must be tested before 'sac'.
        if 'ddpg' in FLAGS.algo_name:
            model = ddpg.DDPG(env.observation_spec(),
                              env.action_spec(),
                              cross_norm='crossnorm' in FLAGS.algo_name)
        elif 'crr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='bin_max')
        elif 'awr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='exp_mean')
        elif 'sac_v1' in FLAGS.algo_name:
            model = sac_v1.SAC(env.observation_spec(),
                               env.action_spec(),
                               target_entropy=-env.action_spec().shape[0])
        elif 'asac' in FLAGS.algo_name:
            model = asac.ASAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])
        elif 'sac' in FLAGS.algo_name:
            model = sac.SAC(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0],
                            cross_norm='crossnorm' in FLAGS.algo_name,
                            pcl_actor_update='pc' in FLAGS.algo_name)
        elif 'pcl' in FLAGS.algo_name:
            model = pcl.PCL(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0])
        else:
            raise ValueError('Unsupported algo_name: %s' % FLAGS.algo_name)
        if 'distractor' in FLAGS.env_name:
            ckpt_path = os.path.join(
                ('experiments/20210622_2023.policy_weights_sac'
                 '_1M_dmc_distractor_hard_pixel/'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
        else:
            ckpt_path = os.path.join(
                ('experiments/20210607_2023.'
                 'policy_weights_dmc_1M_SAC_pixel'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

        model.load_weights(ckpt_path)
    print('Loaded model weights')

    with summary_writer.as_default():
        if 'procgen' in FLAGS.env_name:
            # Fresh procgen env with num_levels=0 instead of the training
            # levels. `env_name` is only bound in this case; the original
            # code used it unconditionally and crashed for pixels-dm.
            eval_env = procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,
                env_name=env_name,
                num_levels=0,
                start_level=0)
        else:
            eval_env = env
        (avg_returns,
         avg_len) = evaluation.evaluate(eval_env,
                                        model,
                                        num_episodes=100,
                                        return_distributions=False)
        tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
        tf.summary.scalar('evaluation/length-all', avg_len, step=0)
Beispiel #3
0
def main(_):
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))
    if FLAGS.env_name.startswith('procgen'):
        print('Test env: %s' % FLAGS.env_name)
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        print('Train env: %s' % FLAGS.env_name)
        env = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=int(train_levels),
                start_level=0))
        env_all = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=0,
                start_level=0))

        if int(train_levels) == 0:
            train_levels = '200'

    elif FLAGS.env_name.startswith('pixels-dm'):
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)

        if FLAGS.obs_type == 'pixels':
            env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)
        else:
            _, env = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)

    if FLAGS.obs_type != 'state':
        if FLAGS.env_name.startswith('procgen'):
            bcq = bcq_pixel
            cql = cql_pixel
            fisher_brac = fisher_brac_pixel
            deepmdp = deepmdp_pixel
            vpn = vpn_pixel
            cssc = cssc_pixel
            pse = pse_pixel
    else:
        bcq = bcq_state
        cql = cql_state
    print('Loading dataset')

    # Use load_tfrecord_dataset_sequence to load transitions of size k>=2.
    if FLAGS.numpy_dataset:
        n_shards = 10

        def shard_fn(shard):
            return ('experiments/'
                    '20210617_0105.dataset_dmc_50k,100k,'
                    '200k_SAC_pixel_numpy/datasets/'
                    '%s__%d__%d__%d.npy' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps,
                     shard))

        np_observer = tf_utils.NumpyObserver(shard_fn, env)
        dataset = np_observer.load(n_shards)
    else:
        if FLAGS.env_name.startswith('procgen'):
            if FLAGS.n_step_returns > 0:
                if FLAGS.max_timesteps == 100_000:
                    dataset_path = ('experiments/'
                                    '20210624_2033.dataset_procgen__ppo_pixel/'
                                    'datasets/%s__%d__%d.tfrecord' %
                                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                                     FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 3_000_000:
                    if int(train_levels) == 1:
                        print('Using dataset with 1 level')
                        dataset_path = (
                            'experiments/'
                            '20210713_1557.dataset_procgen__ppo_pixel_1_level/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                    elif int(train_levels) == 200:
                        print('Using dataset with 200 levels')
                        # Mixture dataset between 10M,15M,20M and 25M in equal amounts
                        # dataset_path = 'experiments/
                        # 20210718_1522.dataset_procgen__ppo_pixel_mixture10,15,20,25M/
                        # datasets/%s__%d__%d.tfrecord'%(FLAGS.env_name,
                        # FLAGS.ckpt_timesteps,FLAGS.max_timesteps)
                        # PPO after 25M steps
                        dataset_path = (
                            'experiments/'
                            '20210702_2234.dataset_procgen__ppo_pixel/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 5_000_000:
                    # epsilon-greedy, eps: 0.1->0.001
                    dataset_path = (
                        'experiments/'
                        '20210805_1958.dataset_procgen__ppo_pixel_'
                        'egreedy_levelIDs/datasets/'
                        '%s__%d__%d.tfrecord*' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))
                    # Pure greedy (epsilon=0)
                    # dataset_path = ('experiments/'
                    #                 '20210820_1348.dataset_procgen__ppo_pixel_'
                    #                 'egreedy_levelIDs/datasets/'
                    #                 '%s__%d__%d.tfrecord*' %
                    #                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))

        elif FLAGS.env_name.startswith('pixels-dm'):
            if 'distractor' in FLAGS.env_name:
                dataset_path = (
                    'experiments/'
                    '20210623_1749.dataset_dmc__sac_pixel/datasets/'
                    '%s__%d__%d.tfrecord' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                     FLAGS.max_timesteps))
            else:
                if FLAGS.obs_type == 'pixels':
                    dataset_path = (
                        'experiments/'
                        '20210612_1644.dataset_dmc_50k,100k,200k_SAC_pixel/'
                        'datasets/%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
                else:
                    dataset_path = (
                        'experiments/'
                        '20210621_1436.dataset_dmc__SAC_pixel/datasets/'
                        '%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
        shards = tf.io.gfile.glob(dataset_path)
        shards = [s for s in shards if not s.endswith('.spec')]
        print('Found %d shards under path %s' % (len(shards), dataset_path))
        if FLAGS.n_step_returns > 1:
            # Load sequences of length N
            dataset = load_tfrecord_dataset_sequence(
                shards,
                buffer_size_per_shard=FLAGS.dataset_size // len(shards),
                deterministic=False,
                compress_image=True,
                seq_len=FLAGS.n_step_returns)  # spec=data_spec,
            dataset = dataset.take(FLAGS.dataset_size).shuffle(
                buffer_size=FLAGS.batch_size,
                reshuffle_each_iteration=False).batch(
                    FLAGS.batch_size,
                    drop_remainder=True).prefetch(1).repeat()

            dataset_iter = iter(dataset)
        else:
            dataset_iter = tf_utils.create_data_iterator(
                ('experiments/20210805'
                 '_1958.dataset_procgen__ppo_pixel_egreedy_'
                 'levelIDs/datasets/%s__%d__%d.tfrecord.shard-*-of-*' %
                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000)),
                FLAGS.batch_size,
                shuffle_buffer_size=FLAGS.batch_size,
                obs_to_float=False)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(
        FLAGS.xm_parameters,
        algo_name=FLAGS.algo_name,
        seed=FLAGS.seed,
        task_name=FLAGS.env_name,
        ckpt_timesteps=FLAGS.ckpt_timesteps,
        rep_learn_keywords=FLAGS.rep_learn_keywords)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    pretrain = (FLAGS.pretrain > 0)

    if FLAGS.env_name.startswith('procgen'):
        # disable entropy reg for discrete spaces
        action_dim = env.action_spec().maximum.item() + 1
    else:
        action_dim = env.action_spec().shape[0]
    if 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        reg=FLAGS.f_reg,
                        target_entropy=-action_dim,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size)
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(),
                        env.action_spec(),
                        num_augmentations=FLAGS.num_data_augs)
    elif 'fbrac' in FLAGS.algo_name:
        model = fisher_brac.FBRAC(env.observation_spec(),
                                  env.action_spec(),
                                  target_entropy=-action_dim,
                                  f_reg=FLAGS.f_reg,
                                  reward_bonus=FLAGS.reward_bonus,
                                  num_augmentations=FLAGS.num_data_augs,
                                  env_name=FLAGS.env_name,
                                  batch_size=FLAGS.batch_size)
    elif 'ours' in FLAGS.algo_name:
        model = ours.OURS(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-action_dim,
                          f_reg=FLAGS.f_reg,
                          reward_bonus=FLAGS.reward_bonus,
                          num_augmentations=FLAGS.num_data_augs,
                          env_name=FLAGS.env_name,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size,
                          n_quantiles=FLAGS.n_quantiles,
                          temp=FLAGS.temp,
                          num_training_levels=train_levels)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            model_save_path = os.path.join(FLAGS.save_dir, 'weights',
                                           hparam_str)
            checkpoint = tf.train.Checkpoint(**model.model_dict)
            tf_step_counter = tf.Variable(0, dtype=tf.int32)
            manager = tf.train.CheckpointManager(
                checkpoint,
                directory=model_save_path,
                max_to_keep=1,
                checkpoint_interval=FLAGS.save_interval,
                step_counter=tf_step_counter)

            # Load the checkpoint in case it exists
            state = manager.restore_or_initialize()
            if state is not None:
                # loaded variables from checkpoint folder
                timesteps_already_done = int(
                    re.findall('ckpt-([0-9]*)',
                               state)[0])  #* FLAGS.save_interval
                print('Loaded model from timestep %d' % timesteps_already_done)
            else:
                print('Training from scratch')
                timesteps_already_done = 0

            tf_step_counter.assign(timesteps_already_done)

            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                # (quantile_states, quantile_bins)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

                tf_step_counter.assign(i)
                manager.save(checkpoint_number=i)
    elif 'bc' in FLAGS.algo_name:
        model = bc_pixel.BehavioralCloning(
            env.observation_spec(),
            env.action_spec(),
            mixture=False,
            encoder=None,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            env_name=FLAGS.env_name,
            batch_size=FLAGS.batch_size)
    elif 'deepmdp' in FLAGS.algo_name:
        model = deepmdp.DeepMdpLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            num_distributions=1,
            sequence_length=2,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'vpn' in FLAGS.algo_name:
        model = vpn.ValuePredictionNetworkLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'cssc' in FLAGS.algo_name:
        model = cssc.CSSC(env.observation_spec(),
                          env.action_spec(),
                          embedding_dim=512,
                          actor_lr=3e-4,
                          critic_lr=3e-4,
                          num_augmentations=FLAGS.num_data_augs,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size)
    elif 'pse' in FLAGS.algo_name:
        model = pse.PSE(env.observation_spec(),
                        env.action_spec(),
                        embedding_dim=512,
                        actor_lr=3e-4,
                        critic_lr=3e-4,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size,
                        temperature=FLAGS.temp)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

    if 'fbrac' in FLAGS.algo_name or FLAGS.algo_name == 'bc':
        # Either load the online policy:
        if FLAGS.load_bc and FLAGS.env_name.startswith('procgen'):
            env_id = [
                i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
            ][0] + 1  # map env string to digit [1,16]
            if FLAGS.ckpt_timesteps == 10_000_000:
                ckpt_iter = '0000020480'
            elif FLAGS.ckpt_timesteps == 25_000_000:
                ckpt_iter = '0000051200'
            policy_weights_dir = (
                'ppo_darts/'
                '2021-06-22-16-36-54/%d/policies/checkpoints/'
                'policy_checkpoint_%s/' % (env_id, ckpt_iter))
            policy_def_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/policy/' %
                              (env_id))
            bc = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                policy_def_dir,
                time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
                action_spec=env._action_spec,  # pylint: disable=protected-access
                policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
                info_spec=tf.TensorSpec(shape=(None, )),
                load_specs_from_pbtxt=False)
            bc.update_from_checkpoint(policy_weights_dir)
            model.bc.policy = tf_utils.TfAgentsPolicy(bc)
        else:
            if FLAGS.algo_name == 'fbrac':
                bc_pretraining_steps = 100_000
            elif FLAGS.algo_name == 'bc':
                bc_pretraining_steps = 1_000_000

            if 'fbrac' in FLAGS.algo_name:
                bc = model.bc
            else:
                bc = model
            for i in tqdm.tqdm(range(bc_pretraining_steps)):

                info_dict = bc.update_step(dataset_iter)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'bc/{k}', v, step=i)

                if FLAGS.algo_name == 'bc':
                    if (i + 1) % FLAGS.eval_interval == 0:
                        average_returns, average_length = evaluation.evaluate(
                            env, bc)  # (FLAGS.env_name.startswith('procgen'))
                        average_returns_all, average_length_all = evaluation.evaluate(
                            env_all, bc)

                        with result_writer.as_default():
                            tf.summary.scalar('evaluation/returns',
                                              average_returns,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length',
                                              average_length,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/returns-all',
                                              average_returns_all,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length-all',
                                              average_length_all,
                                              step=i + 1)

    if FLAGS.algo_name == 'bc':
        exit()

    if not (FLAGS.algo_name == 'ours' and pretrain):
        model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
        checkpoint = tf.train.Checkpoint(**model.model_dict)
        tf_step_counter = tf.Variable(0, dtype=tf.int32)
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_save_path,
            max_to_keep=1,
            checkpoint_interval=FLAGS.save_interval,
            step_counter=tf_step_counter)

        # Load the checkpoint in case it exists
        weights_path = tf.io.gfile.glob(model_save_path + '/ckpt-*.index')
        key_fn = lambda x: int(re.findall(r'(\d+)', x)[-1])
        weights_path.sort(key=key_fn)
        if weights_path:
            weights_path = weights_path[-1]  # take most recent
        state = manager.restore_or_initialize()  # restore(weights_path)
        if state is not None:
            # loaded variables from checkpoint folder
            timesteps_already_done = int(
                re.findall('ckpt-([0-9]*)', state)[0])  #* FLAGS.save_interval
            print('Loaded model from timestep %d' % timesteps_already_done)
        else:
            print('Training from scratch')
            timesteps_already_done = 0

    tf_step_counter.assign(timesteps_already_done)

    for i in tqdm.tqdm(range(timesteps_already_done, FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(
                dataset_iter, train_target='rl' if pretrain else 'both')
        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    v = tf.reduce_mean(v)
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            average_returns_all, average_length_all = evaluation.evaluate(
                env_all, model)

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns-200',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-200',
                                  average_length,
                                  step=i + 1)
                tf.summary.scalar('evaluation/returns-all',
                                  average_returns_all,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-all',
                                  average_length_all,
                                  step=i + 1)

        tf_step_counter.assign(i)
        manager.save(checkpoint_number=i)