def main(argv):
    """Estimates successor representations for one shard of start states.

    The puddle world for `FLAGS.arena_name` is discretized with
    `FLAGS.num_bins`, and this shard (`FLAGS.shard_idx` of `FLAGS.num_shards`)
    is assigned a contiguous range [start_idx, end_idx) of start-state bins.
    For every start bin, `FLAGS.num_rollouts_per_start_state` rollouts of
    `FLAGS.rollout_length` uniformly random actions are simulated from a state
    sampled inside that bin. Step t (counting from 0) adds `FLAGS.gamma ** t`
    to the bin the agent lands in, and the per-rollout vectors are averaged.
    The stacked (end_idx - start_idx, num_states) float32 array is saved to
    `FLAGS.output_dir/sr_{start_idx}-{end_idx}.np`. Shard 0 additionally writes
    a `metadata.json` describing the run.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: If extra arguments are passed, or if `FLAGS.shard_idx`
        is not in [0, FLAGS.num_shards).
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    if not 0 <= FLAGS.shard_idx < FLAGS.num_shards:
        # An out-of-range shard would otherwise yield a negative or empty index
        # range (or a ZeroDivisionError when num_shards == 0).
        raise app.UsageError(
            f'shard_idx must be in [0, {FLAGS.num_shards}), '
            f'got {FLAGS.shard_idx}.')

    puddles = arenas.get_arena(FLAGS.arena_name)
    pw = puddle_world.PuddleWorld(puddles=puddles,
                                  goal_position=geometry.Point((1.0, 1.0)))

    dpw = pw_utils.DiscretizedPuddleWorld(pw, FLAGS.num_bins)
    num_states = dpw.num_states

    # Exact fractions avoid float rounding when computing shard boundaries, so
    # shard k's end index equals shard k+1's start index and the last shard
    # ends exactly at num_states.
    percent_work_to_complete = fractions.Fraction(1, FLAGS.num_shards)
    start_idx = int(FLAGS.shard_idx * percent_work_to_complete * num_states)
    end_idx = int(
        (FLAGS.shard_idx + 1) * percent_work_to_complete * num_states)
    logging.info('start idx: %d, end idx: %d', start_idx, end_idx)

    result_matrices = []

    # TODO(joshgreaves): utils has helpful functions for generating rollouts.
    for start_state in range(start_idx, end_idx):
        logging.info('Starting iteration %d', start_state)
        # One row per rollout; column j accumulates discounted visits to bin j.
        result_matrix = np.zeros(
            (FLAGS.num_rollouts_per_start_state, num_states), dtype=np.float32)

        for i in range(FLAGS.num_rollouts_per_start_state):
            current_gamma = 1.0

            s = dpw.sample_state_in_bin(start_state)

            for _ in range(FLAGS.rollout_length):
                action = random.randrange(pw_utils.NUM_ACTIONS)
                transition = dpw.transition(s, action)
                s = transition.next_state

                # The first credited state (weight 1.0) is the one reached
                # after the first step, not the sampled start state.
                result_matrix[i, s.bin_idx] += current_gamma
                current_gamma *= FLAGS.gamma

        result_matrices.append(np.mean(result_matrix, axis=0))

    # Before saving, make sure the path (including any missing parent
    # directories) exists.
    output_dir = epath.Path(FLAGS.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if FLAGS.shard_idx == 0:
        # Write some metadata to make analysis easier at the end.
        metadata = {
            'arena_name': FLAGS.arena_name,
            'num_bins': FLAGS.num_bins,
            'num_shards': FLAGS.num_shards,
        }
        json_file_path = output_dir / 'metadata.json'
        with json_file_path.open('w') as f:
            json.dump(metadata, f)

    if result_matrices:
        results = np.stack(result_matrices, axis=0)
    else:
        # np.stack rejects an empty list; a shard owns no start states when
        # num_shards exceeds num_states.
        results = np.zeros((0, num_states), dtype=np.float32)

    file_path = output_dir / f'sr_{start_idx}-{end_idx}.np'
    with file_path.open('wb') as f:
        np.save(f, results)
Esempio n. 2
0
    def test_puddle_world_calculates_transition_with_multiple_slow_puddles(
            self):
        """Speed reductions compound when moving through nested slow puddles.

        Two concentric circular slow puddles sit at (0.5, 0.5): the outer one
        has a radius of 0.25 and the inner one a radius of 0.1. Inside only
        the outer puddle the agent should move at 0.5x base speed; inside both
        it should move at 0.25x base speed.
        """
        def circle_at_center(radius):
            return puddle_world.circle(geometry.Point((0.5, 0.5)), radius)

        world = puddle_world.PuddleWorld(
            puddles=(
                puddle_world.SlowPuddle(shape=circle_at_center(0.25)),
                puddle_world.SlowPuddle(shape=circle_at_center(0.1)),
            ),
            goal_position=geometry.Point((0.5, 0.5)),
            noise=0.0,
            # Large enough for a single step to cross both puddle boundaries.
            thrust=0.5)

        origin = geometry.Point((0.2, 0.5))
        result = world.transition(origin, puddle_world.Action.RIGHT)

        # Movement budget of 0.5:
        #   x 0.2 -> 0.25 at full speed consumes 0.05 (0.45 left).
        #   x 0.25 -> 0.4 at half speed consumes 0.3 (0.15 left).
        #   0.15 left at quarter speed moves 0.0375 further, to x = 0.4375.
        self.assertEqual(result.state, origin)
        self.assertPointsAlmostEqual(result.next_state,
                                     geometry.Point((0.4375, 0.5)))
Esempio n. 3
0
    def test_puddle_world_obeys_arena_boundaries(self, start_position, action):
        """A single noiseless step must leave the agent inside [0, 1] x [0, 1]."""
        world = puddle_world.PuddleWorld(
            (),
            goal_position=geometry.Point((0.5, 0.5)),
            noise=0.0,
            thrust=0.1,
        )

        landing = world.transition(start_position, action).next_state

        for coordinate in (landing.x, landing.y):
            self.assertBetween(coordinate, 0.0, 1.0)
Esempio n. 4
0
    def test_puddle_world_calculates_correct_transition_with_no_puddles(
            self, start_position, action, expected_end_position):
        """Without puddles or noise, a step lands at the expected position.

        The transition must also record the unmodified start position as its
        `state`.
        """
        world = puddle_world.PuddleWorld(
            puddles=(),
            goal_position=geometry.Point((0.5, 0.5)),
            noise=0.0,
        )

        step = world.transition(start_position, action)

        self.assertEqual(step.state, start_position)
        self.assertPointsAlmostEqual(step.next_state, expected_end_position)
Esempio n. 5
0
def main(argv):
    """Connects to Reverb and runs either the eval or the train worker.

    Builds the puddle world described by `_CONFIG.value.env.pw` (its `arena`
    and `num_bins`). The discretized world, a Reverb client for
    `FLAGS.reverb_address`, and the config are handed to `eval_worker` when
    `FLAGS.eval` is set, and to `train_worker` otherwise.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: If extra arguments are passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    config = _CONFIG.value

    # Open the Reverb connection before building the environment.
    reverb_client = reverb.Client(FLAGS.reverb_address)

    pw_config = config.env.pw
    world = puddle_world.PuddleWorld(
        puddles=arenas.get_arena(pw_config.arena),
        goal_position=geometry.Point((1.0, 1.0)))
    dpw = pw_utils.DiscretizedPuddleWorld(world, pw_config.num_bins)

    run_worker = eval_worker if FLAGS.eval else train_worker
    run_worker(dpw, reverb_client, config)
Esempio n. 6
0
    def test_puddle_world_correctly_applies_wall_puddles(self):
        """A wall puddle halts movement at its boundary.

        A slow puddle of radius 0.25 surrounds a wall puddle of radius 0.1,
        both centered at (0.5, 0.5). Moving right from x = 0.2 with a thrust of
        0.5 would otherwise carry the agent to x = 0.475 (0.05 at full speed,
        then 0.225 at half speed). The agent should therefore stop at the
        wall's edge, x = 0.4.
        """
        def circle_at_center(radius):
            return puddle_world.circle(geometry.Point((0.5, 0.5)), radius)

        slow_puddle = puddle_world.SlowPuddle(shape=circle_at_center(0.25))
        wall_puddle = puddle_world.WallPuddle(shape=circle_at_center(0.1))
        world = puddle_world.PuddleWorld(
            puddles=(slow_puddle, wall_puddle),
            goal_position=geometry.Point((0.5, 0.5)),
            noise=0.0,
            # Large enough that, without the wall, the step would cross x=0.4.
            thrust=0.5)

        origin = geometry.Point((0.2, 0.5))
        result = world.transition(origin, puddle_world.Action.RIGHT)

        self.assertEqual(result.state, origin)
        self.assertPointsAlmostEqual(result.next_state,
                                     geometry.Point((0.4, 0.5)))
Esempio n. 7
0
def main(argv):
    """Computes empirical successor representations for one shard of states.

    The puddle world for `_ARENA_NAME` is discretized with `_NUM_BINS`. This
    shard's start-state range comes from `_compute_start_and_end_indices`. For
    every start bin in that range:
      * `_NUM_ROLLOUTS_PER_START_STATE` rollouts of length `_ROLLOUT_LENGTH`
        are generated from a state sampled in the bin;
      * each rollout is converted into an empirical successor representation
        with discount `_GAMMA`;
      * the rollouts are averaged.
    The stacked per-start-state results are passed to `_save_results`, and
    `_maybe_combine_shards` is called afterwards.

    Args:
      argv: Positional command-line arguments left after flag parsing; only
        the program name is accepted.

    Raises:
      app.UsageError: If extra arguments are passed.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Set up the arena.
    puddles = arenas.get_arena(_ARENA_NAME.value)
    pw = puddle_world.PuddleWorld(puddles=puddles,
                                  goal_position=geometry.Point((1.0, 1.0)))
    dpw = pw_utils.DiscretizedPuddleWorld(pw, _NUM_BINS.value)
    num_states = dpw.num_states

    # Work out which work this shard should do.
    start_idx, end_idx = _compute_start_and_end_indices(
        _SHARD_IDX.value, _NUM_SHARDS.value, num_states)
    logging.info('start idx: %d, end idx: %d', start_idx, end_idx)

    num_rollouts = _NUM_ROLLOUTS_PER_START_STATE.value
    result_matrices = []
    for start_state in range(start_idx, end_idx):
        logging.info('Starting iteration %d', start_state)
        # One row per rollout, one column per discretized state.
        result_matrix = np.zeros((num_rollouts, num_states), dtype=np.float32)

        for i in range(num_rollouts):
            s = dpw.sample_state_in_bin(start_state)
            rollout = pw_utils.generate_rollout(dpw, _ROLLOUT_LENGTH.value, s)
            # NOTE: `empricial` is the (misspelled) helper name in pw_utils;
            # keep this call in sync with that module.
            result_matrix[i] = (
                pw_utils.calculate_empricial_successor_representation(
                    dpw, rollout, _GAMMA.value))

        result_matrices.append(np.mean(result_matrix, axis=0))

    if result_matrices:
        results = np.stack(result_matrices, axis=0)
    else:
        # np.stack rejects an empty list. This happens when the shard is
        # assigned no start states (presumably when there are more shards
        # than states), so save an empty, correctly-shaped array instead.
        results = np.zeros((0, num_states), dtype=np.float32)

    output_dir = epath.Path(_OUTPUT_DIR.value)
    _save_results(output_dir, _SHARD_IDX.value, _NUM_SHARDS.value, results)
    _maybe_combine_shards(output_dir, _SHARD_IDX.value, _NUM_SHARDS.value)
Esempio n. 8
0
def create_puddle_world_experiment(config):
  """Creates a Puddle World experiment.

  Loads precomputed data from `config.puddle_world_path/<arena>/`, where
  <arena> is `config.puddle_world_arena`:
    * metadata.json: the `arena_name` and `num_bins` used for discretization.
    * sr.np: the successor-representation matrix Psi, indexed as
      Psi[bin_index, task].
    * svd.np: a matrix whose first `config.d` columns are used as the optimal
      subspace.

  Side effects on `config`:
    * `config.T` is overwritten with the number of columns of Psi when they
      disagree.
    * `config.use_tabular_gradient` is forced to False.

  Args:
    config: Experiment configuration. Reads `seed`, `puddle_world_arena`,
      `puddle_world_path`, `d`, `T`, `use_center_states_only`,
      `phi_hidden_layers` and `phi_hidden_layer_width`.

  Returns:
    A `SyntheticExperiment` with:
      * `compute_phi`: a ReLU MLP over (x, y) coordinates.
      * `compute_psi`: looks up rows of Psi for the bin containing each state.
      * `eval_states`: the center of every bin.

  Raises:
    ValueError: If the arena is unknown or `puddle_world_path` is empty.
  """
  key = jax.random.PRNGKey(config.seed)
  network_key, key = jax.random.split(key)

  all_arenas = {'sutton_10', 'sutton_20', 'sutton_100'}
  if config.puddle_world_arena not in all_arenas:
    raise ValueError(f'Unknown arena {config.puddle_world_arena}.')

  if not config.puddle_world_path:
    raise ValueError('puddle_world_path wasn\'t supplied. Please pass a path.')

  path = epath.Path(config.puddle_world_path) / config.puddle_world_arena

  with (path / 'metadata.json').open('r') as f:
    metadata = json.load(f)
  num_bins = metadata['num_bins']

  puddles = arenas.get_arena(metadata['arena_name'])
  pw = puddle_world.PuddleWorld(
      puddles, goal_position=geometry.Point((1.0, 1.0)))
  dpw = pw_utils.DiscretizedPuddleWorld(pw, num_bins)

  with (path / 'sr.np').open('rb') as f:
    Psi = np.load(f)
  Psi = jnp.asarray(Psi, dtype=jnp.float32)

  with (path / 'svd.np').open('rb') as f:
    optimal_subspace = np.load(f)
  optimal_subspace = jnp.asarray(optimal_subspace[:, :config.d])

  # The number of tasks is dictated by the saved Psi, not by the config.
  if config.T != Psi.shape[1]:
    logging.warning(
        'Num tasks T (%d) does not match columns of Psi (%d). Overwriting.',
        config.T,
        Psi.shape[1])
    config.T = Psi.shape[1]

  # Ensure we don't use the tabular gradient, since we will always use
  # neural networks with puddle world experiments.
  config.use_tabular_gradient = False

  def _bin_center(bin_idx):
    """Returns the [x, y] midpoint of discretized bin `bin_idx`."""
    bottom_left, top_right = dpw.get_bin_corners_by_bin_idx(bin_idx)
    return [(bottom_left.x + top_right.x) / 2,
            (bottom_left.y + top_right.y) / 2]

  # One evaluation state per bin, located at that bin's center.
  eval_states = jnp.asarray(
      [_bin_center(i) for i in range(dpw.num_states)], dtype=jnp.float32)

  def sample_states_continuous(key, num_samples):
    """Samples random (x, y) coordinates uniformly from [0, 1)^2."""
    sample_key, key = jax.random.split(key)
    samples = jax.random.uniform(
        sample_key, (num_samples, 2), dtype=jnp.float32)
    return samples, key

  def sample_states_discrete(key, num_samples):
    """Samples (with replacement) from the bin-center (x, y) coordinates."""
    sample_key, key = jax.random.split(key)
    samples = jax.random.choice(
        sample_key, eval_states, (num_samples,))
    return samples, key

  if config.use_center_states_only:
    sample_states = sample_states_discrete
  else:
    sample_states = sample_states_continuous

  def compute_psi(states, tasks=None):
    """Looks up Psi for the bins containing `states`.

    Args:
      states: Array whose column 0 is x and column 1 is y.
      tasks: Optional task indices. When given, returns
        Psi[bin_indices, tasks] (advanced indexing, so the shapes must
        broadcast); otherwise returns every task column for each state.
    """
    # First, we get which column and row the x and y falls into, and then
    # clip to make sure the edge cases when x=1.0 or y=1.0 fall into a valid
    # bin. The bounds are passed positionally because the `a_min`/`a_max`
    # keywords are deprecated in recent JAX releases (renamed `min`/`max`);
    # the positional form works with both old and new versions.
    cols_and_rows = jnp.clip(jnp.floor(states * num_bins), 0, num_bins - 1)

    # Bin indices are assigned starting in the bottom left moving right, and
    # then advancing upwards after finishing each row: col + row * num_bins.
    # e.g. in a 10x10 grid, (0-indexed) row 2, col 3 is bin 3 + 2 * 10 = 23.
    bin_indices = cols_and_rows[:, 0] + cols_and_rows[:, 1] * num_bins
    bin_indices = bin_indices.astype(jnp.int32)

    if tasks is None:
      return Psi[bin_indices, :]
    return Psi[bin_indices, tasks]

  class Module(nn.Module):
    """MLP feature network: ReLU hidden layers followed by a d-dim output."""

    @nn.compact
    def __call__(self, x):
      for _ in range(config.phi_hidden_layers):
        x = jax.nn.relu(nn.Dense(config.phi_hidden_layer_width)(x))
      return nn.Dense(config.d)(x)

  module = Module()
  # Parameters are initialized with a dummy batch of ten (x, y) inputs.
  params = module.init(
      network_key, jnp.zeros((10, 2), dtype=jnp.float32))
  compute_phi = module.apply

  return SyntheticExperiment(
      compute_phi=compute_phi,
      compute_psi=compute_psi,
      sample_states=sample_states,
      eval_states=eval_states,
      optimal_subspace=optimal_subspace,
      params=params,
      key=key,
  )