def main(_):
  """Online RL training on pixel-based DeepMind Control ('pixels-dm') envs.

  Collects experience with the selected agent (or a random policy during
  warmup), stores transitions in a TF-Agents uniform replay buffer, and
  periodically trains, evaluates, logs to TensorBoard, and saves weights.
  """
  tf.config.experimental_run_functions_eagerly(FLAGS.eager)

  # Seed every RNG source used here for reproducibility.
  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  # Express the step budget in post-action-repeat env steps.
  # NOTE(review): this runs BEFORE action_repeat may be overridden per-domain
  # below, so the division uses the pre-override value — confirm intended.
  FLAGS.set_default('max_timesteps', FLAGS.max_timesteps // FLAGS.action_repeat)

  if 'pixels-dm' in FLAGS.env_name:
    # Env name has five '-'-separated fields when a distractor suffix is
    # present, otherwise four; the third field is the DMC domain name.
    if 'distractor' in FLAGS.env_name:
      _, _, domain_name, _, _ = FLAGS.env_name.split('-')
    else:
      _, _, domain_name, _ = FLAGS.env_name.split('-')

    # Per-domain action-repeat defaults (still overridable via the flag).
    if domain_name in ['cartpole']:
      FLAGS.set_default('action_repeat', 8)
    elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
      FLAGS.set_default('action_repeat', 4)
    elif domain_name in ['finger', 'walker']:
      FLAGS.set_default('action_repeat', 2)
    print('Loading env')
    # Separate train and eval environment instances with identical config.
    env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed, FLAGS.action_repeat,
                            FLAGS.frame_stack)
    eval_env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                 FLAGS.action_repeat, FLAGS.frame_stack)
    print('Env loaded')
  else:
    raise Exception('Unsupported env')

  # A rank-3 observation spec (H, W, C) is treated as an image observation,
  # which enables image augmentation in the dataset pipeline below.
  is_image_obs = (isinstance(env.observation_spec(), TensorSpec) and
                  len(env.observation_spec().shape) == 3)

  # Transition spec: (s, a, r, discount, s'). The discount reuses the
  # reward spec since both are scalar floats.
  spec = (
      env.observation_spec(),
      env.action_spec(),
      env.reward_spec(),
      env.reward_spec(),  # discount spec
      env.observation_spec()  # next observation spec
  )
  print('Init replay')

  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      spec, batch_size=1, max_length=FLAGS.max_length_replay_buffer)
  print('Replay created')
  @tf.function
  def add_to_replay(state, action, reward, discount, next_states):
    # Graph-compiled insert of a single (batch-size-1) transition.
    replay_buffer.add_batch((state, action, reward, discount, next_states))

  hparam_str = utils.make_hparam_string(
      FLAGS.xm_parameters, seed=FLAGS.seed, env_name=FLAGS.env_name,
      algo_name=FLAGS.algo_name)
  # Training metrics and evaluation results go to separate writers/dirs.
  summary_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'tb', hparam_str))
  results_writer = tf.summary.create_file_writer(
      os.path.join(FLAGS.save_dir, 'results', hparam_str))
  print('Init actor')
  # Agent selection by substring match on algo_name. Order matters:
  # 'sac_v1' and 'asac' are tested before the plain 'sac' substring.
  if 'ddpg' in FLAGS.algo_name:
    model = ddpg.DDPG(
        env.observation_spec(),
        env.action_spec(),
        cross_norm='crossnorm' in FLAGS.algo_name)
  elif 'crr' in FLAGS.algo_name:
    # CRR is AWR with a binary-max advantage transform.
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='bin_max')
  elif 'awr' in FLAGS.algo_name:
    model = awr.AWR(
        env.observation_spec(),
        env.action_spec(), f='exp_mean')
  elif 'sac_v1' in FLAGS.algo_name:
    model = sac_v1.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'asac' in FLAGS.algo_name:
    model = asac.ASAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])
  elif 'sac' in FLAGS.algo_name:
    model = sac.SAC(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0],
        cross_norm='crossnorm' in FLAGS.algo_name,
        pcl_actor_update='pc' in FLAGS.algo_name)
  elif 'pcl' in FLAGS.algo_name:
    model = pcl.PCL(
        env.observation_spec(),
        env.action_spec(),
        target_entropy=-env.action_spec().shape[0])

  print('Init random policy for warmup')
  # Random policy used until enough transitions have been collected.
  initial_collect_policy = random_tf_policy.RandomTFPolicy(
      env.time_step_spec(), env.action_spec())
  print('Init replay buffer')
  dataset = replay_buffer.as_dataset(
      num_parallel_calls=tf.data.AUTOTUNE,
      sample_batch_size=FLAGS.sample_batch_size)
  if is_image_obs:
    # Apply image augmentation to sampled pixel observations.
    dataset = dataset.map(image_aug,
                          num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=False).prefetch(3)
  else:
    dataset = dataset.prefetch(3)

  def repack(*data):
    # Drop the replay buffer's sample-info element; keep only the data tuple.
    return data[0]
  dataset = dataset.map(repack)
  replay_buffer_iter = iter(dataset)

  previous_time = time.time()
  timestep = env.reset()
  episode_return = 0
  episode_timesteps = 0
  # Multiplier to report summaries in raw (pre-action-repeat) env steps.
  step_mult = 1 if FLAGS.action_repeat < 1 else FLAGS.action_repeat

  print('Starting training')
  for i in tqdm.tqdm(range(FLAGS.max_timesteps)):
    # Collect a full deployment batch of env steps every
    # deployment_batch_size iterations.
    if i % FLAGS.deployment_batch_size == 0:
      for _ in range(FLAGS.deployment_batch_size):
        last_timestep = timestep.is_last()
        if last_timestep:

          # Log episode return and collection FPS at episode boundaries.
          if episode_timesteps > 0:
            current_time = time.time()
            with summary_writer.as_default():
              tf.summary.scalar(
                  'train/returns',
                  episode_return,
                  step=(i + 1) * step_mult)
              tf.summary.scalar(
                  'train/FPS',
                  episode_timesteps / (current_time - previous_time),
                  step=(i + 1) * step_mult)

          timestep = env.reset()
          episode_return = 0
          episode_timesteps = 0
          previous_time = time.time()

        if (replay_buffer.num_frames() < FLAGS.num_random_actions or
            replay_buffer.num_frames() < FLAGS.deployment_batch_size):
          # Use policy only after the first deployment.
          policy_step = initial_collect_policy.action(timestep)
          action = policy_step.action
        else:
          # Sample a stochastic action from the learned actor.
          action = model.actor(timestep.observation, sample=True)

        next_timestep = env.step(action)

        add_to_replay(timestep.observation, action, next_timestep.reward,
                      next_timestep.discount, next_timestep.observation)

        episode_return += next_timestep.reward[0]
        episode_timesteps += 1

        timestep = next_timestep

    # One gradient update per iteration once warmup is done.
    if i + 1 >= FLAGS.start_training_timesteps:
      with summary_writer.as_default():
        info_dict = model.update_step(replay_buffer_iter)
      if (i + 1) % FLAGS.log_interval == 0:
        with summary_writer.as_default():
          for k, v in info_dict.items():
            tf.summary.scalar(f'training/{k}', v, step=(i + 1) * step_mult)

    # Periodic policy evaluation on the separate eval env.
    if (i + 1) % FLAGS.eval_interval == 0:
      logging.info('Performing policy eval.')
      average_returns, evaluation_timesteps = evaluation.evaluate(
          eval_env, model)

      with results_writer.as_default():
        tf.summary.scalar(
            'evaluation/returns', average_returns, step=(i + 1) * step_mult)
        tf.summary.scalar(
            'evaluation/length', evaluation_timesteps, step=(i+1) * step_mult)
      logging.info('Eval at %d: ave returns=%f, ave episode length=%f',
                   (i + 1) * step_mult, average_returns, evaluation_timesteps)

    # Save weights every 50k raw env steps.
    if ((i + 1) * step_mult) % 50_000 == 0:
      model.save_weights(
          os.path.join(FLAGS.save_dir, 'results', FLAGS.env_name + '__' + str(
              (i + 1) * step_mult)))
# Example #2
def main(_):
    """Offline RL training from logged datasets (procgen or pixels-dm).

    Loads a fixed dataset of transitions from TFRecord (or numpy) shards,
    builds the selected offline agent, optionally pretrains (encoder-only or
    behavioral cloning), then runs the offline update loop with periodic
    evaluation and checkpointing.
    """
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))
    if FLAGS.env_name.startswith('procgen'):
        print('Test env: %s' % FLAGS.env_name)
        # Env name format: 'procgen-<game>-<train_levels>-<...>'.
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        print('Train env: %s' % FLAGS.env_name)
        # `env` is restricted to the training levels; `env_all` uses the
        # full level distribution (num_levels=0) for generalization eval.
        env = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=int(train_levels),
                start_level=0))
        env_all = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=0,
                start_level=0))

        # num_levels=0 means "all levels"; for dataset-path selection below
        # treat it as the 200-level setting.
        if int(train_levels) == 0:
            train_levels = '200'

    elif FLAGS.env_name.startswith('pixels-dm'):
        # Five '-'-separated fields with a distractor suffix, else four;
        # the third field is the DMC domain name.
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        # Per-domain action-repeat defaults (still overridable via the flag).
        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        # NOTE(review): the env is loaded here and then immediately loaded
        # again in the obs_type branch below — this first load looks
        # redundant; confirm.
        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)

        # load_env presumably returns (pixel_env, state_env); pick the one
        # matching obs_type — TODO confirm against utils.load_env.
        if FLAGS.obs_type == 'pixels':
            env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)
        else:
            _, env = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)

    # Alias algorithm modules to their pixel or state implementations.
    # NOTE(review): for pixels-dm with a non-state obs_type no aliases are
    # set here — presumably module-level defaults apply; confirm.
    if FLAGS.obs_type != 'state':
        if FLAGS.env_name.startswith('procgen'):
            bcq = bcq_pixel
            cql = cql_pixel
            fisher_brac = fisher_brac_pixel
            deepmdp = deepmdp_pixel
            vpn = vpn_pixel
            cssc = cssc_pixel
            pse = pse_pixel
    else:
        bcq = bcq_state
        cql = cql_state
    print('Loading dataset')

    # Use load_tfrecord_dataset_sequence to load transitions of size k>=2.
    if FLAGS.numpy_dataset:
        n_shards = 10

        def shard_fn(shard):
            # Path template for one numpy shard of the DMC SAC dataset.
            return ('experiments/'
                    '20210617_0105.dataset_dmc_50k,100k,'
                    '200k_SAC_pixel_numpy/datasets/'
                    '%s__%d__%d__%d.npy' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps,
                     shard))

        # NOTE(review): this path builds `dataset` but never `dataset_iter`,
        # which the training loop below consumes — confirm this branch is
        # still exercised.
        np_observer = tf_utils.NumpyObserver(shard_fn, env)
        dataset = np_observer.load(n_shards)
    else:
        if FLAGS.env_name.startswith('procgen'):
            if FLAGS.n_step_returns > 0:
                # Dataset path depends on which behavior-policy snapshot
                # (max_timesteps) generated the data.
                if FLAGS.max_timesteps == 100_000:
                    dataset_path = ('experiments/'
                                    '20210624_2033.dataset_procgen__ppo_pixel/'
                                    'datasets/%s__%d__%d.tfrecord' %
                                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                                     FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 3_000_000:
                    if int(train_levels) == 1:
                        print('Using dataset with 1 level')
                        dataset_path = (
                            'experiments/'
                            '20210713_1557.dataset_procgen__ppo_pixel_1_level/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                    elif int(train_levels) == 200:
                        print('Using dataset with 200 levels')
                        # Mixture dataset between 10M,15M,20M and 25M in equal amounts
                        # dataset_path = 'experiments/
                        # 20210718_1522.dataset_procgen__ppo_pixel_mixture10,15,20,25M/
                        # datasets/%s__%d__%d.tfrecord'%(FLAGS.env_name,
                        # FLAGS.ckpt_timesteps,FLAGS.max_timesteps)
                        # PPO after 25M steps
                        dataset_path = (
                            'experiments/'
                            '20210702_2234.dataset_procgen__ppo_pixel/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 5_000_000:
                    # epsilon-greedy, eps: 0.1->0.001
                    dataset_path = (
                        'experiments/'
                        '20210805_1958.dataset_procgen__ppo_pixel_'
                        'egreedy_levelIDs/datasets/'
                        '%s__%d__%d.tfrecord*' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))
                    # Pure greedy (epsilon=0)
                    # dataset_path = ('experiments/'
                    #                 '20210820_1348.dataset_procgen__ppo_pixel_'
                    #                 'egreedy_levelIDs/datasets/'
                    #                 '%s__%d__%d.tfrecord*' %
                    #                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))

        elif FLAGS.env_name.startswith('pixels-dm'):
            if 'distractor' in FLAGS.env_name:
                dataset_path = (
                    'experiments/'
                    '20210623_1749.dataset_dmc__sac_pixel/datasets/'
                    '%s__%d__%d.tfrecord' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                     FLAGS.max_timesteps))
            else:
                if FLAGS.obs_type == 'pixels':
                    dataset_path = (
                        'experiments/'
                        '20210612_1644.dataset_dmc_50k,100k,200k_SAC_pixel/'
                        'datasets/%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
                else:
                    dataset_path = (
                        'experiments/'
                        '20210621_1436.dataset_dmc__SAC_pixel/datasets/'
                        '%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
        # Resolve shard files; '.spec' sidecar files are not data shards.
        shards = tf.io.gfile.glob(dataset_path)
        shards = [s for s in shards if not s.endswith('.spec')]
        print('Found %d shards under path %s' % (len(shards), dataset_path))
        if FLAGS.n_step_returns > 1:
            # Load sequences of length N
            dataset = load_tfrecord_dataset_sequence(
                shards,
                buffer_size_per_shard=FLAGS.dataset_size // len(shards),
                deterministic=False,
                compress_image=True,
                seq_len=FLAGS.n_step_returns)  # spec=data_spec,
            dataset = dataset.take(FLAGS.dataset_size).shuffle(
                buffer_size=FLAGS.batch_size,
                reshuffle_each_iteration=False).batch(
                    FLAGS.batch_size,
                    drop_remainder=True).prefetch(1).repeat()

            dataset_iter = iter(dataset)
        else:
            # NOTE(review): this branch hardcodes the procgen egreedy dataset
            # path regardless of env/dataset_path above — confirm intended.
            dataset_iter = tf_utils.create_data_iterator(
                ('experiments/20210805'
                 '_1958.dataset_procgen__ppo_pixel_egreedy_'
                 'levelIDs/datasets/%s__%d__%d.tfrecord.shard-*-of-*' %
                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000)),
                FLAGS.batch_size,
                shuffle_buffer_size=FLAGS.batch_size,
                obs_to_float=False)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(
        FLAGS.xm_parameters,
        algo_name=FLAGS.algo_name,
        seed=FLAGS.seed,
        task_name=FLAGS.env_name,
        ckpt_timesteps=FLAGS.ckpt_timesteps,
        rep_learn_keywords=FLAGS.rep_learn_keywords)
    # Training metrics and evaluation results go to separate writers/dirs.
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    pretrain = (FLAGS.pretrain > 0)

    if FLAGS.env_name.startswith('procgen'):
        # disable entropy reg for discrete spaces
        action_dim = env.action_spec().maximum.item() + 1
    else:
        action_dim = env.action_spec().shape[0]
    # Agent selection by substring match on algo_name.
    if 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        reg=FLAGS.f_reg,
                        target_entropy=-action_dim,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size)
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(),
                        env.action_spec(),
                        num_augmentations=FLAGS.num_data_augs)
    elif 'fbrac' in FLAGS.algo_name:
        model = fisher_brac.FBRAC(env.observation_spec(),
                                  env.action_spec(),
                                  target_entropy=-action_dim,
                                  f_reg=FLAGS.f_reg,
                                  reward_bonus=FLAGS.reward_bonus,
                                  num_augmentations=FLAGS.num_data_augs,
                                  env_name=FLAGS.env_name,
                                  batch_size=FLAGS.batch_size)
    elif 'ours' in FLAGS.algo_name:
        model = ours.OURS(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-action_dim,
                          f_reg=FLAGS.f_reg,
                          reward_bonus=FLAGS.reward_bonus,
                          num_augmentations=FLAGS.num_data_augs,
                          env_name=FLAGS.env_name,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size,
                          n_quantiles=FLAGS.n_quantiles,
                          temp=FLAGS.temp,
                          num_training_levels=train_levels)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            # Encoder pretraining with its own checkpointing so it can be
            # resumed after preemption.
            model_save_path = os.path.join(FLAGS.save_dir, 'weights',
                                           hparam_str)
            checkpoint = tf.train.Checkpoint(**model.model_dict)
            tf_step_counter = tf.Variable(0, dtype=tf.int32)
            manager = tf.train.CheckpointManager(
                checkpoint,
                directory=model_save_path,
                max_to_keep=1,
                checkpoint_interval=FLAGS.save_interval,
                step_counter=tf_step_counter)

            # Load the checkpoint in case it exists
            state = manager.restore_or_initialize()
            if state is not None:
                # loaded variables from checkpoint folder
                # Recover the step count from the checkpoint filename.
                timesteps_already_done = int(
                    re.findall('ckpt-([0-9]*)',
                               state)[0])  #* FLAGS.save_interval
                print('Loaded model from timestep %d' % timesteps_already_done)
            else:
                print('Training from scratch')
                timesteps_already_done = 0

            tf_step_counter.assign(timesteps_already_done)

            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                # (quantile_states, quantile_bins)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

                tf_step_counter.assign(i)
                manager.save(checkpoint_number=i)
    elif 'bc' in FLAGS.algo_name:
        model = bc_pixel.BehavioralCloning(
            env.observation_spec(),
            env.action_spec(),
            mixture=False,
            encoder=None,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            env_name=FLAGS.env_name,
            batch_size=FLAGS.batch_size)
    elif 'deepmdp' in FLAGS.algo_name:
        model = deepmdp.DeepMdpLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            num_distributions=1,
            sequence_length=2,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'vpn' in FLAGS.algo_name:
        model = vpn.ValuePredictionNetworkLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'cssc' in FLAGS.algo_name:
        model = cssc.CSSC(env.observation_spec(),
                          env.action_spec(),
                          embedding_dim=512,
                          actor_lr=3e-4,
                          critic_lr=3e-4,
                          num_augmentations=FLAGS.num_data_augs,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size)
    elif 'pse' in FLAGS.algo_name:
        model = pse.PSE(env.observation_spec(),
                        env.action_spec(),
                        embedding_dim=512,
                        actor_lr=3e-4,
                        critic_lr=3e-4,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size,
                        temperature=FLAGS.temp)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            # PSE encoder pretraining (no checkpointing, unlike 'ours').
            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

    if 'fbrac' in FLAGS.algo_name or FLAGS.algo_name == 'bc':
        # Either load the online policy:
        if FLAGS.load_bc and FLAGS.env_name.startswith('procgen'):
            env_id = [
                i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
            ][0] + 1  # map env string to digit [1,16]
            # Map checkpoint timestep to the matching PPO policy iteration.
            if FLAGS.ckpt_timesteps == 10_000_000:
                ckpt_iter = '0000020480'
            elif FLAGS.ckpt_timesteps == 25_000_000:
                ckpt_iter = '0000051200'
            policy_weights_dir = (
                'ppo_darts/'
                '2021-06-22-16-36-54/%d/policies/checkpoints/'
                'policy_checkpoint_%s/' % (env_id, ckpt_iter))
            policy_def_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/policy/' %
                              (env_id))
            bc = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                policy_def_dir,
                time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
                action_spec=env._action_spec,  # pylint: disable=protected-access
                policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
                info_spec=tf.TensorSpec(shape=(None, )),
                load_specs_from_pbtxt=False)
            bc.update_from_checkpoint(policy_weights_dir)
            model.bc.policy = tf_utils.TfAgentsPolicy(bc)
        else:
            # ... or train a behavioral-cloning policy from the dataset.
            if FLAGS.algo_name == 'fbrac':
                bc_pretraining_steps = 100_000
            elif FLAGS.algo_name == 'bc':
                bc_pretraining_steps = 1_000_000

            if 'fbrac' in FLAGS.algo_name:
                bc = model.bc
            else:
                bc = model
            for i in tqdm.tqdm(range(bc_pretraining_steps)):

                info_dict = bc.update_step(dataset_iter)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'bc/{k}', v, step=i)

                # Pure BC: evaluate during this loop, since it is the whole
                # training run.
                if FLAGS.algo_name == 'bc':
                    if (i + 1) % FLAGS.eval_interval == 0:
                        average_returns, average_length = evaluation.evaluate(
                            env, bc)  # (FLAGS.env_name.startswith('procgen'))
                        average_returns_all, average_length_all = evaluation.evaluate(
                            env_all, bc)

                        with result_writer.as_default():
                            tf.summary.scalar('evaluation/returns',
                                              average_returns,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length',
                                              average_length,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/returns-all',
                                              average_returns_all,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length-all',
                                              average_length_all,
                                              step=i + 1)

    # BC has no offline RL phase; stop here.
    if FLAGS.algo_name == 'bc':
        exit()

    # Set up checkpointing for the RL phase, unless 'ours'+pretrain already
    # created a manager above.
    if not (FLAGS.algo_name == 'ours' and pretrain):
        model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
        checkpoint = tf.train.Checkpoint(**model.model_dict)
        tf_step_counter = tf.Variable(0, dtype=tf.int32)
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_save_path,
            max_to_keep=1,
            checkpoint_interval=FLAGS.save_interval,
            step_counter=tf_step_counter)

        # Load the checkpoint in case it exists
        weights_path = tf.io.gfile.glob(model_save_path + '/ckpt-*.index')
        key_fn = lambda x: int(re.findall(r'(\d+)', x)[-1])
        weights_path.sort(key=key_fn)
        if weights_path:
            weights_path = weights_path[-1]  # take most recent
        state = manager.restore_or_initialize()  # restore(weights_path)
        if state is not None:
            # loaded variables from checkpoint folder
            timesteps_already_done = int(
                re.findall('ckpt-([0-9]*)', state)[0])  #* FLAGS.save_interval
            print('Loaded model from timestep %d' % timesteps_already_done)
        else:
            print('Training from scratch')
            timesteps_already_done = 0

    # NOTE(review): relies on tf_step_counter/timesteps_already_done having
    # been bound in whichever branch ran above ('ours'+pretrain or the block
    # just before) — confirm both paths always define them.
    tf_step_counter.assign(timesteps_already_done)

    # Main offline RL loop. With pretraining done, only the RL head is
    # trained ('rl'); otherwise encoder and RL train jointly ('both').
    for i in tqdm.tqdm(range(timesteps_already_done, FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(
                dataset_iter, train_target='rl' if pretrain else 'both')
        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    v = tf.reduce_mean(v)
                    tf.summary.scalar(f'training/{k}', v, step=i)

        # Evaluate on both the training-level env and the all-levels env.
        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            average_returns_all, average_length_all = evaluation.evaluate(
                env_all, model)

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns-200',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-200',
                                  average_length,
                                  step=i + 1)
                tf.summary.scalar('evaluation/returns-all',
                                  average_returns_all,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-all',
                                  average_length_all,
                                  step=i + 1)

        tf_step_counter.assign(i)
        manager.save(checkpoint_number=i)
# Example #3
def main(_):
    """Evaluates a pretrained policy checkpoint and logs average returns.

    For procgen, restores a saved PPO TF-Agents policy from checkpoint; for
    pixels-dm, builds the matching agent and loads saved weights. Then runs
    100 evaluation episodes and writes results to TensorBoard.
    """
    if FLAGS.eager:
        tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    tf.random.set_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    if 'procgen' in FLAGS.env_name:
        # Env name format: 'procgen-<game>-<train_levels>-<...>'.
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=int(train_levels),
            start_level=0)

    elif FLAGS.env_name.startswith('pixels-dm'):
        # Five '-'-separated fields with a distractor suffix, else four;
        # the third field is the DMC domain name.
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        # Per-domain action-repeat defaults (still overridable via the flag).
        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.env_name,
                                          ckpt_timesteps=FLAGS.ckpt_timesteps)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    if FLAGS.env_name.startswith('procgen'):
        # map env string to digit [1,16]
        env_id = [
            i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
        ][0] + 1
        # Map checkpoint timestep to the matching PPO policy iteration.
        if FLAGS.ckpt_timesteps == 10_000_000:
            ckpt_iter = '0000020480'
        elif FLAGS.ckpt_timesteps == 25_000_000:
            ckpt_iter = '0000051200'
        policy_weights_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/checkpoints/'
                              'policy_checkpoint_%s/' % (env_id, ckpt_iter))
        policy_def_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
        # Restore the saved TF-Agents policy and wrap it so it exposes the
        # interface evaluation.evaluate expects.
        model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            policy_def_dir,
            time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
            action_spec=env._action_spec,  # pylint: disable=protected-access
            policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
            info_spec=tf.TensorSpec(shape=(None, )),
            load_specs_from_pbtxt=False)
        model.update_from_checkpoint(policy_weights_dir)
        model = TfAgentsPolicy(model)
    else:
        # Agent selection by substring match. Order matters: 'sac_v1' and
        # 'asac' are tested before the plain 'sac' substring.
        if 'ddpg' in FLAGS.algo_name:
            model = ddpg.DDPG(env.observation_spec(),
                              env.action_spec(),
                              cross_norm='crossnorm' in FLAGS.algo_name)
        elif 'crr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='bin_max')
        elif 'awr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='exp_mean')
        elif 'sac_v1' in FLAGS.algo_name:
            model = sac_v1.SAC(env.observation_spec(),
                               env.action_spec(),
                               target_entropy=-env.action_spec().shape[0])
        elif 'asac' in FLAGS.algo_name:
            model = asac.ASAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])
        elif 'sac' in FLAGS.algo_name:
            model = sac.SAC(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0],
                            cross_norm='crossnorm' in FLAGS.algo_name,
                            pcl_actor_update='pc' in FLAGS.algo_name)
        elif 'pcl' in FLAGS.algo_name:
            model = pcl.PCL(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0])
        # Weight checkpoints live in different experiment dirs depending on
        # whether the env uses distractors.
        if 'distractor' in FLAGS.env_name:
            ckpt_path = os.path.join(
                ('experiments/20210622_2023.policy_weights_sac'
                 '_1M_dmc_distractor_hard_pixel/'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
        else:
            ckpt_path = os.path.join(
                ('experiments/20210607_2023.'
                 'policy_weights_dmc_1M_SAC_pixel'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

        model.load_weights(ckpt_path)
    print('Loaded model weights')

    with summary_writer.as_default():
        # Evaluate on the full level distribution (num_levels=0).
        # NOTE(review): this always builds a procgen env using `env_name`,
        # which is only bound in the procgen branch — the pixels-dm path
        # would fail here; confirm this script is procgen-only in practice.
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=0,
            start_level=0)
        (avg_returns,
         avg_len) = evaluation.evaluate(env,
                                        model,
                                        num_episodes=100,
                                        return_distributions=False)
        tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
        tf.summary.scalar('evaluation/length-all', avg_len, step=0)
def main(_):
    """Train an offline-RL agent on a D4RL task and periodically evaluate it.

    Builds the environment and dataset iterator, constructs the model selected
    by ``FLAGS.algo_name``, optionally BC-pretrains (BRAC variants), then runs
    ``FLAGS.num_updates`` gradient steps with TensorBoard logging.
    """
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    gym_env, dataset = d4rl_utils.create_d4rl_env_and_dataset(
        task_name=FLAGS.task_name, batch_size=FLAGS.batch_size)

    # Wrap the raw gym env so tf-agents specs (observation/action) are available.
    env = gym_wrapper.GymWrapper(gym_env)
    env = tf_py_environment.TFPyEnvironment(env)

    dataset_iter = iter(dataset)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.task_name,
                                          data_name=FLAGS.data_name)
    # Separate writers: 'tb' holds training curves, 'results' holds eval metrics.
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    # Substring matching on algo_name: order matters ('fbrac' is checked
    # before the generic 'brac' branch; 'sac' would also match 'asac', so
    # 'asac' is tested first).
    if FLAGS.algo_name == 'bc':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec())
    elif FLAGS.algo_name == 'bc_mix':
        model = behavioral_cloning.BehavioralCloning(env.observation_spec(),
                                                     env.action_spec(),
                                                     mixture=True)
    elif 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(env.observation_spec(), env.action_spec())
    elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(env.observation_spec(), env.action_spec(), f='bin_max')
    elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(env.observation_spec(),
                        env.action_spec(),
                        f='exp_mean')
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(), env.action_spec())
    elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-env.action_spec().shape[0])
    elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        target_entropy=-env.action_spec().shape[0])
    elif 'brac' in FLAGS.algo_name:
        if 'fbrac' in FLAGS.algo_name:
            model = fisher_brac.FBRAC(
                env.observation_spec(),
                env.action_spec(),
                target_entropy=-env.action_spec().shape[0],
                f_reg=FLAGS.f_reg,
                reward_bonus=FLAGS.reward_bonus)
        else:
            model = brac.BRAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])

        # BRAC variants need a pretrained behavioral-cloning policy; pretrain
        # from scratch unless saved weights already exist for this config.
        model_folder = os.path.join(
            FLAGS.save_dir, 'models',
            f'{FLAGS.task_name}_{FLAGS.data_name}_{FLAGS.seed}')
        # Bug fix: was `tf.gfile.io.isdir`, which does not exist in TF2 and
        # raised AttributeError; the correct API is `tf.io.gfile.isdir`.
        if not tf.io.gfile.isdir(model_folder):
            bc_pretraining_steps = 1_000_000
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.bc.update_step(dataset_iter)

                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            # Negative steps keep BC-pretraining curves to the
                            # left of the main training curves (which start at 0).
                            tf.summary.scalar(f'training/{k}',
                                              v,
                                              step=i - bc_pretraining_steps)
            # NOTE(review): saving is disabled, so the load branch below reads
            # a folder nothing ever writes — confirm whether the save should
            # be re-enabled.
            # model.bc.policy.save_weights(os.path.join(model_folder, 'model'))
        else:
            model.bc.policy.load_weights(os.path.join(model_folder, 'model'))

    # Main offline training loop.
    for i in tqdm.tqdm(range(FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(dataset_iter)

        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            if FLAGS.data_name is None:
                # D4RL normalized score, scaled to a 0-100 range.
                average_returns = gym_env.get_normalized_score(
                    average_returns) * 100.0

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length',
                                  average_length,
                                  step=i + 1)