Ejemplo n.º 1
0
 def setUpClass(cls):
   """Parse the classic four-rooms gin config and register the environment.

   The id returned by `mon_minigrid.register_environment()` is stored on the
   class as `env_id`.
   """
   super(MonMiniGridEnvTest, cls).setUpClass()
   config_path = os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                              '{}.gin'.format('classic_fourrooms'))
   # No extra bindings; unknown configurables in the file are an error.
   gin.parse_config_files_and_bindings([config_path],
                                       bindings=[],
                                       skip_unknown=False)
   cls.env_id = mon_minigrid.register_environment()
Ejemplo n.º 2
0
def main(_):
    """Build the output directory and environment, then run `TrainRunner`.

    Reads its whole configuration from `FLAGS`:

    * `base_dir` (required) is extended with `custom_base_dir_from_hparams`
      when set, otherwise with the positive `xm_wid` work-unit id if that flag
      exists.
    * `xm_parameters`, when defined and non-empty, is parsed as JSON and may
      override `env_name`.
    * With `env_name` set, the matching gin file is parsed and a wrapped
      MiniGrid environment is created; otherwise a
      `random_mdp.RandomMDP(num_states, num_actions)` is used.

    Results are written under `<base_dir>/<env or S_A>/PVF/<estimator>`.

    Args:
      _: Unused positional command-line arguments.
    """
    flags.mark_flags_as_required(['base_dir'])
    if FLAGS.custom_base_dir_from_hparams is not None:
        FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                      FLAGS.custom_base_dir_from_hparams)
    elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
        # Add work unit to base directory path, if it exists.
        FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

    xm_parameters = (None
                     if 'xm_parameters' not in FLAGS else FLAGS.xm_parameters)
    if xm_parameters:
        xm_params = json.loads(xm_parameters)
        if 'env_name' in xm_params:
            FLAGS.env_name = xm_params['env_name']

    if FLAGS.env_name is None:
        base_dir = os.path.join(
            FLAGS.base_dir, '{}_{}'.format(FLAGS.num_states,
                                           FLAGS.num_actions))
    else:
        base_dir = os.path.join(FLAGS.base_dir, FLAGS.env_name)
    base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator)
    if not tf.io.gfile.exists(base_dir):
        tf.io.gfile.makedirs(base_dir)

    # Exactly one of the two branches builds the environment.
    if FLAGS.env_name is not None:
        gin.add_config_file_search_path(_ENV_CONFIG_PATH)
        gin.parse_config_files_and_bindings(
            config_files=[f'{FLAGS.env_name}.gin'],
            bindings=FLAGS.gin_bindings,
            skip_unknown=False)
        env_id = mon_minigrid.register_environment()
        env = gym.make(env_id)
        env = RGBImgObsWrapper(env)  # Get pixel observations
        # Get tabular observation and drop the 'mission' field:
        env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
        env = coloring_wrapper.ColoringWrapper(env)
    else:
        env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)

    # We add the discount factor to the environment.
    env.gamma = FLAGS.gamma

    logging.set_verbosity(logging.INFO)

    # NOTE(review): the last two positional arguments are both
    # `FLAGS.epochs - 1`; their meaning is defined by TrainRunner, which is
    # not visible here -- confirm against its signature.
    runner = TrainRunner(base_dir, env, FLAGS.epochs, FLAGS.lr,
                         FLAGS.estimator, FLAGS.alpha, FLAGS.optimizer,
                         FLAGS.use_l2_reg, FLAGS.reg_coeff,
                         FLAGS.use_penalty, FLAGS.j, FLAGS.num_rows,
                         jax.random.PRNGKey(0), FLAGS.epochs - 1,
                         FLAGS.epochs - 1)
    runner.train()
Ejemplo n.º 3
0
def main(argv):
    """Play one episode interactively, saving a PNG of every frame.

    The environment named by `FLAGS.env_name` is configured through gin,
    wrapped for pixel and tabular observations, and then stepped with actions
    typed on stdin until the episode finishes or 500 valid steps were taken.
    Frames are stored as `obs_<t>.png` inside `FLAGS.file_path`.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    config_file = os.path.join(mon_minigrid.GIN_FILES_PREFIX,
                               '{}.gin'.format(FLAGS.env_name))
    gin.parse_config_files_and_bindings([config_file],
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=False)
    env = gym.make(mon_minigrid.register_environment())
    # Pixel observations first, then the tabular view (which also drops the
    # 'mission' field).
    env = RGBImgObsWrapper(env)
    env = tabular_wrapper.TabularWrapper(env, get_rgb=True)
    env.reset()

    steps_taken = 0
    step_limit = 500

    if not tf.io.gfile.exists(FLAGS.file_path):
        tf.io.gfile.makedirs(FLAGS.file_path)

    print('Available actions:')
    for key in ACTION_MAPPINGS:
        print('\t{}: "{}"'.format(ACTION_MAPPINGS[key], key))
    print()

    total_reward = 0
    while steps_taken < step_limit:
        draw_ascii_view(env)
        typed = input('action: ')
        if typed in ACTION_MAPPINGS:
            direction = env.DirectionalActions[ACTION_MAPPINGS[typed]]
            obs, reward, done, _ = env.step(direction.value)
            total_reward += reward
            steps_taken += 1

            print('t:', steps_taken, '   s:', obs['state'])
            # Dump the rendered frame to disk as a simple visualization.
            plt.imshow(obs['image'])
            frame_file = os.path.join(FLAGS.file_path,
                                      'obs_{}.png'.format(steps_taken))
            plt.savefig(frame_file)
            plt.clf()

            if done:
                break
        else:
            # Invalid input does not consume a step; prompt again.
            print('Unrecognized action.')

    print('Undiscounted return: %.2f' % total_reward)
    env.close()
Ejemplo n.º 4
0
def _value_iteration(rewards, transition_probs, gamma, tolerance):
    """Run synchronous value iteration until the update falls below tolerance.

    Every sweep applies the Bellman optimality backup to all states at once
    using the values of the previous sweep.

    Args:
      rewards: Array indexed as `rewards[s, a]`.
      transition_probs: Array indexed as `transition_probs[s, a, s_next]`.
      gamma: Discount factor.
      tolerance: Iteration stops once the largest absolute change of any
        state value in a sweep is not greater than this.

    Returns:
      A `(values, num_iterations)` tuple.
    """
    rewards = np.asarray(rewards)
    transition_probs = np.asarray(transition_probs)
    # Assumes the leading axis of `rewards` enumerates every state.
    values = np.zeros(rewards.shape[0])
    error = tolerance * 2
    num_iterations = 0
    while error > tolerance:
        # (S, A, S') @ (S',) -> (S, A): expected next-state value per action.
        q_values = rewards + gamma * np.matmul(transition_probs, values)
        # True maximum over actions. Starting the running maximum at 0 would
        # clamp states whose action values are all negative.
        new_values = np.max(q_values, axis=1)
        error = np.max(np.abs(new_values - values))
        values = new_values
        num_iterations += 1
        if num_iterations % 1000 == 0:
            print('Error after {} iterations: {}'.format(num_iterations,
                                                         error))
    return values, num_iterations


def main(argv):
    """Compute V* for a MiniGrid environment and optionally plot it.

    The environment named by `FLAGS.env` is configured through gin and wrapped
    as an MDP. Value iteration runs with `FLAGS.gamma` and `FLAGS.tolerance`,
    the resulting values are printed, and when `FLAGS.values_image_file` is
    set a colour-mapped rendering of the values is saved to that path.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_files_and_bindings([
        os.path.join(mon_minigrid.GIN_FILES_PREFIX, '{}.gin'.format(FLAGS.env))
    ],
                                        bindings=FLAGS.gin_bindings,
                                        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env)
    env = coloring_wrapper.ColoringWrapper(env)

    values, i = _value_iteration(env.rewards, env.transition_probs,
                                 FLAGS.gamma, FLAGS.tolerance)
    print('Found V* in {} iterations'.format(i))
    print(values)

    if FLAGS.values_image_file is not None:
        # `cm.get_cmap` was removed in Matplotlib 3.9; the pyplot variant
        # accepts the same (name, lut) arguments.
        cmap = plt.get_cmap('plasma', 256)
        norm = colors.Normalize(vmin=np.min(values), vmax=np.max(values))
        # NOTE(review): the meaning of boundary_values=[1.0, 4.5] is defined
        # by render_custom_observation, which is not visible here.
        obs_image = env.render_custom_observation(env.reset(),
                                                  values,
                                                  cmap,
                                                  boundary_values=[1.0, 4.5])
        # The mappable only exists to drive the colorbar scale.
        m = cm.ScalarMappable(cmap=cmap, norm=norm)
        m.set_array(obs_image)
        plt.imshow(obs_image)
        plt.colorbar(m)
        plt.savefig(FLAGS.values_image_file)
        plt.clf()
    env.close()
Ejemplo n.º 5
0
def main(argv):
    """Run a random agent in classic four-rooms, saving a PNG of every frame.

    The environment is configured from `classic_fourrooms.gin`, wrapped for
    pixel and tabular observations, and stepped with uniformly sampled actions
    until the episode finishes or 500 steps were taken. Frames are stored as
    `<t>.png` inside the directory `FLAGS.file_path`.

    Args:
      argv: Command-line arguments; anything beyond the program name is
        rejected.

    Raises:
      app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    gin.parse_config_files_and_bindings(
        [os.path.join(mon_minigrid.GIN_FILES_PREFIX, 'classic_fourrooms.gin')],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = tabular_wrapper.TabularWrapper(env, get_rgb=True)
    env.reset()

    num_frames = 0
    max_num_frames = 500

    if not tf.io.gfile.exists(FLAGS.file_path):
        tf.io.gfile.makedirs(FLAGS.file_path)

    undisc_return = 0
    while num_frames < max_num_frames:
        # Act randomly
        obs, reward, done, _ = env.step(env.action_space.sample())
        undisc_return += reward
        num_frames += 1

        print('t:', num_frames, '   s:', obs['state'])
        # Draw environment frame just for simple visualization
        plt.imshow(obs['image'])
        # Join with the directory created above. Plain string concatenation
        # would write next to it whenever the flag has no trailing separator.
        path = os.path.join(FLAGS.file_path, '{}.png'.format(num_frames))
        plt.savefig(path)
        plt.clf()

        if done:
            break

    print('Undiscounted return: %.2f' % undisc_return)
    env.close()
Ejemplo n.º 6
0
def main(_):
  """Train a representation against normalized tasks and save diagnostics.

  Steps, all driven by `FLAGS`:

  1. Resolve the output directory
     `<base_dir>/<env or S_A>/PVF/<estimator>/lr_<lr>` and create it.
  2. Build either the gin-configured MiniGrid MDP (`env_name` set) or a
     `random_mdp.RandomMDP`.
  3. Form `Psi = (I - gamma * P)^-1`, where `P` is what
     `utils.transition_matrix` returns for `rl_basics.policy_random(env)`,
     and divide each column by its maximum.
  4. Call `train` starting from a random normal matrix of shape `(S, d)`.
  5. Compute the tracked metrics and store them, together with the
     representations, as a pickled dict in `<prefix>.npy`.

  Args:
    _: Unused positional command-line arguments.
  """
  flags.mark_flags_as_required(['base_dir'])
  if FLAGS.custom_base_dir_from_hparams is not None:
    FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                  FLAGS.custom_base_dir_from_hparams)
  elif 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
    # Append the work-unit id to the base directory when one is set.
    FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))

  raw_xm_parameters = (
      FLAGS.xm_parameters if 'xm_parameters' in FLAGS else None)
  if raw_xm_parameters:
    decoded_params = json.loads(raw_xm_parameters)
    if 'env_name' in decoded_params:
      FLAGS.env_name = decoded_params['env_name']

  use_random_mdp = FLAGS.env_name is None
  if use_random_mdp:
    run_name = '{}_{}'.format(FLAGS.num_states, FLAGS.num_actions)
  else:
    run_name = FLAGS.env_name
  base_dir = os.path.join(FLAGS.base_dir, run_name, 'PVF', FLAGS.estimator,
                          f'lr_{FLAGS.lr}')
  if not tf.io.gfile.exists(base_dir):
    tf.io.gfile.makedirs(base_dir)

  if use_random_mdp:
    env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
  else:
    gin.add_config_file_search_path(_ENV_CONFIG_PATH)
    gin.parse_config_files_and_bindings(
        config_files=[f'{FLAGS.env_name}.gin'],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env = gym.make(mon_minigrid.register_environment())
    # Pixel observations first, then the tabular MDP view (which also drops
    # the 'mission' field).
    env = RGBImgObsWrapper(env)
    env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
    env = coloring_wrapper.ColoringWrapper(env)

  # The discount factor is attached to the environment object itself.
  env.gamma = FLAGS.gamma

  transitions = utils.transition_matrix(env, rl_basics.policy_random(env))
  num_states = transitions.shape[0]
  identity = jnp.eye(num_states)
  # Solving (I - gamma * P) X = I yields the inverse of (I - gamma * P).
  psi = jnp.linalg.solve(identity - env.gamma * transitions, identity)
  # Normalize tasks so that they have maximum value 1.
  column_max = np.max(psi, axis=0)
  psi = psi / column_max

  left_vectors, _, _ = jnp.linalg.svd(psi)
  approx_error = utils.approx_error(left_vectors, FLAGS.d, psi)
  top_left_vectors = left_vectors[:, :FLAGS.d]

  # NOTE(review): the same PRNGKey(0) seeds both the initialization and the
  # call to `train`; kept as in the original so results stay reproducible.
  initial_phi = jax.random.normal(
      jax.random.PRNGKey(0),
      (num_states, FLAGS.d),
      dtype=jnp.float64)
  representations, grads = train(
      initial_phi, psi, FLAGS.epochs, FLAGS.lr, jax.random.PRNGKey(0),
      FLAGS.estimator, FLAGS.alpha, FLAGS.optimizer, FLAGS.use_l2_reg,
      FLAGS.reg_coeff, FLAGS.use_penalty, FLAGS.j, FLAGS.num_rows,
      FLAGS.skipsize_train)

  skip = FLAGS.skipsize
  gm_distances = calc_gm_distances(representations, top_left_vectors, skip)
  frob_norms = calc_frob_norms(representations, psi, skip)
  if FLAGS.d == 1:
    dot_products = calc_dot_products(representations, top_left_vectors, skip)
  else:
    # Dot products are only computed for d == 1; store zeros otherwise.
    dot_products = np.zeros((len(gm_distances),))
  grad_norms = calc_grad_norms(grads, skip)
  phi_norms = calc_Phi_norm(representations, skip)
  phi_ranks = calc_sranks(representations, skip)

  results = {
      'gm_distances': gm_distances,
      'dot_products': dot_products,
      'frob_norms': frob_norms,
      'approx_error': approx_error,
      'grad_norms': grad_norms,
      'representations': representations,
      'phi_norms': phi_norms,
      'phi_ranks': phi_ranks
  }
  prefix = f'alpha{FLAGS.alpha}_j{FLAGS.j}_d{FLAGS.d}_regcoeff{FLAGS.reg_coeff}'
  output_path = osp.join(base_dir, f'{prefix}.npy')
  with tf.io.gfile.GFile(output_path, 'wb') as f:
    # A dict is not an ndarray, so it is written as a pickled object.
    np.save(f, results, allow_pickle=True)