def main(argv):
  """Computes and saves one shard of an empirical successor representation.

  The discretized start states are split evenly across `FLAGS.num_shards`
  shards; this process handles the slice selected by `FLAGS.shard_idx`. For
  each start state in the slice it runs `FLAGS.num_rollouts_per_start_state`
  uniformly-random-action rollouts of `FLAGS.rollout_length` steps,
  accumulates the discounted visits to each bin, and averages over rollouts.
  The stacked per-state averages are written to
  `<output_dir>/sr_<start_idx>-<end_idx>.np`; shard 0 additionally writes
  `metadata.json`.

  Args:
    argv: Command-line arguments left over after flag parsing. Anything
      beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional command-line arguments were passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  puddles = arenas.get_arena(FLAGS.arena_name)
  pw = puddle_world.PuddleWorld(
      puddles=puddles, goal_position=geometry.Point((1.0, 1.0)))
  dpw = pw_utils.DiscretizedPuddleWorld(pw, FLAGS.num_bins)
  num_states = dpw.num_states

  # We want to avoid rounding errors when calculating shards,
  # so we use fractions.
  percent_work_to_complete = fractions.Fraction(1, FLAGS.num_shards)
  start_idx = int(FLAGS.shard_idx * percent_work_to_complete * num_states)
  end_idx = int(
      (FLAGS.shard_idx + 1) * percent_work_to_complete * num_states)
  logging.info('start idx: %d, end idx: %d', start_idx, end_idx)

  result_matrices = []

  # TODO(joshgreaves): utils has helpful functions for generating rollouts.
  for start_state in range(start_idx, end_idx):
    logging.info('Starting iteration %d', start_state)
    # One row per rollout; each row holds the discounted visit totals per bin.
    result_matrix = np.zeros(
        (FLAGS.num_rollouts_per_start_state, num_states), dtype=np.float32)

    for i in range(FLAGS.num_rollouts_per_start_state):
      current_gamma = 1.0
      s = dpw.sample_state_in_bin(start_state)

      for _ in range(FLAGS.rollout_length):
        action = random.randrange(pw_utils.NUM_ACTIONS)
        transition = dpw.transition(s, action)
        s = transition.next_state
        # Credit the state we just arrived in, then decay the discount.
        result_matrix[i, s.bin_idx] += current_gamma
        current_gamma *= FLAGS.gamma

    result_matrices.append(np.mean(result_matrix, axis=0))

  # Before saving, make sure the path exists. `parents=True` lets a nested
  # output directory be created even when its ancestors are missing.
  output_dir = epath.Path(FLAGS.output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)

  if FLAGS.shard_idx == 0:
    # Write some metadata to make analysis easier at the end.
    metadata = {
        'arena_name': FLAGS.arena_name,
        'num_bins': FLAGS.num_bins,
        'num_shards': FLAGS.num_shards,
    }
    json_file_path = output_dir / 'metadata.json'
    with json_file_path.open('w') as f:
      json.dump(metadata, f)

  file_path = output_dir / f'sr_{start_idx}-{end_idx}.np'
  with file_path.open('wb') as f:
    np.save(f, np.stack(result_matrices, axis=0))
def test_puddle_world_calculates_transition_with_multiple_slow_puddles(
    self):
  """Overlapping slow puddles compound their slow-down along one step."""
  # Two concentric circular slow puddles sit in the middle of the arena:
  # an outer one of radius 0.25 and an inner one of radius 0.1. Inside only
  # the outer puddle we expect to move at 0.5x base speed; inside both we
  # expect to move at 0.25x base speed.
  outer_puddle = puddle_world.SlowPuddle(
      shape=puddle_world.circle(geometry.Point((0.5, 0.5)), 0.25))
  inner_puddle = puddle_world.SlowPuddle(
      shape=puddle_world.circle(geometry.Point((0.5, 0.5)), 0.1))
  # The thrust is deliberately large so a single step crosses both edges.
  world = puddle_world.PuddleWorld(
      puddles=(outer_puddle, inner_puddle),
      goal_position=geometry.Point((0.5, 0.5)),
      noise=0.0,
      thrust=0.5)

  origin = geometry.Point((0.2, 0.5))
  step = world.transition(origin, puddle_world.Action.RIGHT)

  # Budget of 0.5 total movement:
  #   0.05 is used reaching the outer circle's edge (x = 0.25),
  #   0.30 is used reaching the inner circle's edge (x = 0.40),
  #   0.15 is left at 25% efficiency => 0.0375 into the inner circle.
  destination = geometry.Point((0.4375, 0.5))

  self.assertEqual(step.state, origin)
  self.assertPointsAlmostEqual(step.next_state, destination)
def test_puddle_world_obeys_arena_boundaries(self, start_position, action):
  """A step in a puddle-free world must land inside the unit square."""
  world = puddle_world.PuddleWorld(
      (),
      goal_position=geometry.Point((0.5, 0.5)),
      noise=0.0,
      thrust=0.1)

  step = world.transition(start_position, action)

  # Both coordinates of the resulting state have to stay within [0, 1].
  self.assertBetween(step.next_state.x, 0.0, 1.0)
  self.assertBetween(step.next_state.y, 0.0, 1.0)
def test_puddle_world_calculates_correct_transition_with_no_puddles(
    self, start_position, action, expected_end_position):
  """With no puddles and no noise, a step lands on the expected point."""
  world = puddle_world.PuddleWorld(
      puddles=(),
      goal_position=geometry.Point((0.5, 0.5)),
      noise=0.0)

  step = world.transition(start_position, action)

  # The transition records where it started as well as where it ended.
  self.assertEqual(step.state, start_position)
  self.assertPointsAlmostEqual(step.next_state, expected_end_position)
def main(argv):
  """Starts either an eval worker or a train worker on a puddle world.

  Reads the experiment config, connects a reverb client to
  `FLAGS.reverb_address`, builds the discretized puddle world described by
  `config.env.pw`, and hands all three to `eval_worker` when `FLAGS.eval` is
  set or to `train_worker` otherwise.

  Args:
    argv: Command-line arguments left over after flag parsing. Anything
      beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional command-line arguments were passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  config = _CONFIG.value

  # The reverb connection is made before the environment is built.
  client = reverb.Client(FLAGS.reverb_address)

  env_config = config.env.pw
  world = puddle_world.PuddleWorld(
      puddles=arenas.get_arena(env_config.arena),
      goal_position=geometry.Point((1.0, 1.0)))
  discretized_world = pw_utils.DiscretizedPuddleWorld(
      world, env_config.num_bins)

  run_worker = eval_worker if FLAGS.eval else train_worker
  run_worker(discretized_world, client, config)
def test_puddle_world_correctly_applies_wall_puddles(self):
  """A wall puddle nested in a slow puddle halts movement at its edge."""
  slow_outer = puddle_world.SlowPuddle(
      shape=puddle_world.circle(geometry.Point((0.5, 0.5)), 0.25))
  wall_inner = puddle_world.WallPuddle(
      shape=puddle_world.circle(geometry.Point((0.5, 0.5)), 0.1))
  # The thrust is deliberately large so a single step reaches both circles.
  world = puddle_world.PuddleWorld(
      puddles=(slow_outer, wall_inner),
      goal_position=geometry.Point((0.5, 0.5)),
      noise=0.0,
      thrust=0.5)

  origin = geometry.Point((0.2, 0.5))
  step = world.transition(origin, puddle_world.Action.RIGHT)

  # Movement should end on the boundary of the inner wall puddle (x = 0.4).
  destination = geometry.Point((0.4, 0.5))

  self.assertEqual(step.state, origin)
  self.assertPointsAlmostEqual(step.next_state, destination)
def main(argv):
  """Estimates this shard's rows of the empirical successor representation.

  Builds the discretized puddle world, asks `_compute_start_and_end_indices`
  which start states belong to this shard, and for each of them averages the
  estimates from `_NUM_ROLLOUTS_PER_START_STATE` rollouts. The stacked
  averages are passed to `_save_results`, after which
  `_maybe_combine_shards` is invoked.

  Args:
    argv: Command-line arguments left over after flag parsing. Anything
      beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional command-line arguments were passed.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Build the environment.
  world = puddle_world.PuddleWorld(
      puddles=arenas.get_arena(_ARENA_NAME.value),
      goal_position=geometry.Point((1.0, 1.0)))
  discretized = pw_utils.DiscretizedPuddleWorld(world, _NUM_BINS.value)
  state_count = discretized.num_states

  # Determine the slice of start states owned by this shard.
  shard_idx = _SHARD_IDX.value
  num_shards = _NUM_SHARDS.value
  first_state, last_state = _compute_start_and_end_indices(
      shard_idx, num_shards, state_count)
  logging.info('start idx: %d, end idx: %d', first_state, last_state)

  rollouts_per_state = _NUM_ROLLOUTS_PER_START_STATE.value
  per_state_means = []

  for start_state in range(first_state, last_state):
    logging.info('Starting iteration %d', start_state)
    # One float32 row per rollout, one column per discretized state.
    estimates = np.zeros(
        (rollouts_per_state, state_count), dtype=np.float32)

    for row in range(rollouts_per_state):
      initial_state = discretized.sample_state_in_bin(start_state)
      rollout = pw_utils.generate_rollout(
          discretized, _ROLLOUT_LENGTH.value, initial_state)
      estimates[row] = (
          pw_utils.calculate_empricial_successor_representation(
              discretized, rollout, _GAMMA.value))

    per_state_means.append(estimates.mean(axis=0))

  output_dir = epath.Path(_OUTPUT_DIR.value)
  stacked_means = np.stack(per_state_means, axis=0)
  _save_results(output_dir, shard_idx, num_shards, stacked_means)
  _maybe_combine_shards(output_dir, shard_idx, num_shards)
def create_puddle_world_experiment(config):
  """Creates a Puddle World experiment.

  Loads precomputed data from
  `config.puddle_world_path / config.puddle_world_arena`:
    * `metadata.json`: supplies `arena_name` and `num_bins`.
    * `sr.np`: the matrix `Psi`, indexed by bin index along its first axis.
    * `svd.np`: a matrix whose first `config.d` columns become the optimal
      subspace.

  Note that `config` is mutated: `config.T` is overwritten with the number
  of columns of `Psi` when they disagree, and `config.use_tabular_gradient`
  is always set to False.

  Args:
    config: Experiment configuration. This function reads `seed`,
      `puddle_world_arena`, `puddle_world_path`, `d`, `T`,
      `use_center_states_only`, `phi_hidden_layers` and
      `phi_hidden_layer_width`.

  Returns:
    A `SyntheticExperiment` bundling the feature network (`compute_phi` and
    its initial `params`), the lookup-based `compute_psi`, the state sampler,
    the evaluation states (bin centers), the optimal subspace and the
    remaining PRNG key.

  Raises:
    ValueError: If the arena name is not one of the known arenas, or if
      `config.puddle_world_path` is empty.
  """
  key = jax.random.PRNGKey(config.seed)
  network_key, key = jax.random.split(key)

  all_arenas = {'sutton_10', 'sutton_20', 'sutton_100'}
  if config.puddle_world_arena not in all_arenas:
    raise ValueError(f'Unknown arena {config.puddle_world_arena}.')

  if not config.puddle_world_path:
    raise ValueError('puddle_world_path wasn\'t supplied. Please pass a path.')

  path = epath.Path(config.puddle_world_path)
  path = path / config.puddle_world_arena

  with (path / 'metadata.json').open('r') as f:
    metadata = json.load(f)
  num_bins = metadata['num_bins']

  puddles = arenas.get_arena(metadata['arena_name'])
  pw = puddle_world.PuddleWorld(
      puddles, goal_position=geometry.Point((1.0, 1.0)))
  dpw = pw_utils.DiscretizedPuddleWorld(pw, num_bins)

  with (path / 'sr.np').open('rb') as f:
    Psi = np.load(f)
  Psi = jnp.asarray(Psi, dtype=jnp.float32)

  with (path / 'svd.np').open('rb') as f:
    optimal_subspace = np.load(f)
  optimal_subspace = jnp.asarray(optimal_subspace[:, :config.d])

  if config.T != Psi.shape[1]:
    logging.warning(
        'Num tasks T (%d) does not match columns of Psi (%d). Overwriting.',
        config.T, Psi.shape[1])
    config.T = Psi.shape[1]

  # Ensure we don't use the tabular gradient, since we will always use
  # neural networks with puddle world experiments.
  config.use_tabular_gradient = False

  # The evaluation states are the (x, y) centers of every bin.
  eval_states = []
  for i in range(dpw.num_states):
    bottom_left, top_right = dpw.get_bin_corners_by_bin_idx(i)
    mid_x = (bottom_left.x + top_right.x) / 2
    mid_y = (bottom_left.y + top_right.y) / 2
    eval_states.append([mid_x, mid_y])
  eval_states = jnp.asarray(eval_states, dtype=jnp.float32)

  def sample_states_continuous(key, num_samples):
    """Samples a random (x, y) coordinate."""
    sample_key, key = jax.random.split(key)
    samples = jax.random.uniform(
        sample_key, (num_samples, 2), dtype=jnp.float32)
    return samples, key

  def sample_states_discrete(key, num_samples):
    """Samples from the (x, y) coordinates at the center of a bin."""
    sample_key, key = jax.random.split(key)
    samples = jax.random.choice(
        sample_key, eval_states, (num_samples,))
    return samples, key

  if config.use_center_states_only:
    sample_states = sample_states_discrete
  else:
    sample_states = sample_states_continuous

  def compute_psi(states, tasks=None):
    """Looks up the rows (and optionally columns) of Psi for `states`."""
    # First, we get which column and row the x and y falls into, and then
    # clip to make sure the edge cases when x=1.0 or y=1.0 falls into a
    # valid bin. The bounds are passed positionally because the keyword
    # names of jnp.clip differ between JAX versions.
    cols_and_rows = jnp.clip(
        jnp.floor(states * num_bins), 0, num_bins - 1)
    # Bin indices are assigned starting in the bottom left moving right, and
    # then advancing upwards after finishing each row.
    # e.g. in a 10x10 grid, zero-based row 1 and col 3 correspond to bin 13.
    bin_indices = cols_and_rows[:, 0] + cols_and_rows[:, 1] * num_bins
    bin_indices = bin_indices.astype(jnp.int32)

    if tasks is None:
      return Psi[bin_indices, :]
    return Psi[bin_indices, tasks]

  class Module(nn.Module):
    """MLP mapping a state to `config.d` features."""

    @nn.compact
    def __call__(self, x):
      for _ in range(config.phi_hidden_layers):
        x = jax.nn.relu(nn.Dense(config.phi_hidden_layer_width)(x))
      return nn.Dense(config.d)(x)

  module = Module()
  # A dummy batch of 10 two-dimensional states is used to initialize params.
  params = module.init(
      network_key, jnp.zeros((10, 2), dtype=jnp.float32))

  compute_phi = module.apply

  return SyntheticExperiment(
      compute_phi=compute_phi,
      compute_psi=compute_psi,
      sample_states=sample_states,
      eval_states=eval_states,
      optimal_subspace=optimal_subspace,
      params=params,
      key=key,
  )