def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build the demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env
  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)

  # Map each episode to single-step transitions.
  transition = functools.partial(_n_step_transition_from_episode,
                                 n_step=1,
                                 additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)

  # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
  # policy.
  evaluator_network = snt.Sequential([
      policy_network,
      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
  ])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Create the actor, which defines how we take actions.
  evaluation_network = actors_tf2.FeedForwardActor(evaluator_network)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluation_network,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # Run the environment loop, alternating learner steps with evaluation.
  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)

  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
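# NOTE: `_n_step_transition_from_episode` is defined elsewhere in this
# example. As a rough, hypothetical sketch of what the n_step=1 case computes
# (the episode layout and field order below are assumptions, not the
# example's actual code):
def _one_step_transition_sketch(observations, actions, rewards, discounts):
  """Hypothetical: sample a single (o_t, a_t, r_t, d_t, o_{t+1}) tuple."""
  max_index = tf.shape(rewards)[0] - 1
  # Pick a random timestep within the episode.
  t = tf.random.uniform((), minval=0, maxval=max_index, dtype=tf.int32)
  return (observations[t], actions[t], rewards[t], discounts[t],
          observations[t + 1])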
def load_offline_bsuite_dataset(
    bsuite_id: str,
    random_prob: float,
    path: str,
    batch_size: int,
    valid_batch_size: int,
    num_shards: int = 1,
    num_valid_shards: int = 1,
    num_threads: int = 1,
    single_precision_wrapper: bool = True,
    shuffle_buffer_size: int = 100000,
    shuffle: bool = True,
    repeat: bool = True
) -> Tuple[tf.data.Dataset, tf.data.Dataset, dm_env.Environment]:
  """Loads a bsuite offline dataset."""
  # Data file path format: {path}-?????-of-{num_shards:05d}
  # The dataset is not deterministic and not repeated if shuffle = False.
  environment = bsuite.load_from_id(bsuite_id)
  if single_precision_wrapper:
    environment = single_precision.SinglePrecisionWrapper(environment)
  if random_prob > 0.:
    environment = RandomActionWrapper(environment, random_prob)
  params = bsuite_offline_dataset.dataset_params(environment)
  if os.path.basename(path):
    path += '_'
  train_path = path + 'train'
  train_dataset = bsuite_offline_dataset.dataset(
      path=train_path,
      num_threads=num_threads,
      batch_size=batch_size,
      num_shards=num_shards,
      shuffle_buffer_size=shuffle_buffer_size,
      shuffle=shuffle,
      repeat=repeat,
      **params)
  valid_path = path + 'valid'
  valid_dataset = bsuite_offline_dataset.dataset(
      path=valid_path,
      num_threads=num_threads,
      batch_size=valid_batch_size,
      num_shards=num_valid_shards,
      shuffle=False,
      repeat=False,
      **params)
  return train_dataset, valid_dataset, environment
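# A minimal usage sketch, assuming demonstration shards already exist under
# the given path. The bsuite_id, path, and batch sizes are placeholders.
train_dataset, valid_dataset, environment = load_offline_bsuite_dataset(
    bsuite_id='catch/0',
    random_prob=0.,
    path='/tmp/bsuite_demos/catch',
    batch_size=64,
    valid_batch_size=256)
for batch in train_dataset.take(1):
  pass  # First training batch; its structure depends on dataset_params.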
def make_environment(task,
                     end_on_success,
                     max_episode_steps,
                     distance_fn,
                     goal_image,
                     baseline_distance=None,
                     eval_mode=False,
                     logdir=None,
                     counter=None,
                     record_every=100,
                     num_episodes_to_record=3):
  """Create the environment and its wrappers."""
  env = gym.make(task)
  env = gym_wrapper.GymWrapper(env)
  if end_on_success:
    env = env_wrappers.EndOnSuccessWrapper(env)
  env = wrappers.StepLimitWrapper(env, max_episode_steps)
  env = env_wrappers.ReshapeImageWrapper(env)
  if distance_fn.history_length > 1:
    env = wrappers.FrameStackingWrapper(env, distance_fn.history_length)
  env = env_wrappers.GoalConditionedWrapper(env, goal_image)
  env = env_wrappers.DistanceModelWrapper(
      env,
      distance_fn,
      max_episode_steps,
      baseline_distance,
      distance_reward_weight=FLAGS.distance_reward_weight,
      environment_reward_weight=FLAGS.environment_reward_weight)
  if FLAGS.use_true_distance:
    env = env_wrappers.RewardWrapper(env)
  if logdir:
    env = env_wrappers.RecordEpisodesWrapper(
        env,
        counter,
        logdir,
        record_every=record_every,
        num_to_record=num_episodes_to_record,
        eval_mode=eval_mode)
  env = env_wrappers.VisibleStateWrapper(env, eval_mode)
  return single_precision.SinglePrecisionWrapper(env)
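# Hedged usage sketch of the factory above. The gym task id is a placeholder,
# and _StubDistanceFn stands in for a learned distance model (the real model
# presumably exposes at least `history_length` and is callable on
# observations).
import numpy as np

class _StubDistanceFn:
  history_length = 1  # No frame stacking.

  def __call__(self, observation, goal):
    return 0.0  # Constant placeholder distance.

env = make_environment(
    task='Pusher-v2',                            # placeholder gym task id
    end_on_success=True,
    max_episode_steps=200,
    distance_fn=_StubDistanceFn(),
    goal_image=np.zeros((64, 64, 3), np.uint8),  # placeholder goal image
    baseline_distance=None)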
def make_environment(training: bool = True):
  """Loads a single-precision deep_sea environment from bsuite."""
  del training  # Unused: the same environment serves training and evaluation.
  env = bsuite.load(experiment_name='deep_sea', kwargs={'size': 10})
  return single_precision.SinglePrecisionWrapper(env)
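# Example: the factory pairs naturally with acme's spec utilities when
# building networks (a sketch; `specs` as imported in the examples above).
environment = make_environment()
environment_spec = specs.make_environment_spec(environment)
print(environment_spec.actions)       # Two discrete actions for deep_sea.
print(environment_spec.observations)  # 10x10 one-hot grid for size=10.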
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build the demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env
  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)

  # Map each episode to single-step transitions.
  transition = functools.partial(_n_step_transition_from_episode,
                                 n_step=1,
                                 additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = tfds.as_numpy(dataset)

  # Create the network to optimize.
  policy_network = make_policy_network(environment_spec.actions)
  policy_network = hk.without_apply_rng(hk.transform(policy_network))

  # If the agent is non-autoregressive, use epsilon=0, which yields a greedy
  # policy.
  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    action_values = policy_network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      optimizer=optax.adam(FLAGS.learning_rate),
      obs_spec=environment.observation_spec(),
      dataset=dataset,
      counter=learner_counter,
      rng=hk.PRNGSequence(FLAGS.seed))

  # Create the actor, which defines how we take actions.
  variable_client = variable_utils.VariableClient(learner, '')
  evaluator = actors.FeedForwardActor(
      evaluator_network,
      variable_client=variable_client,
      rng=hk.PRNGSequence(FLAGS.seed))

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # Run the environment loop, alternating learner steps with evaluation.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
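# Standalone illustration of the rlax.epsilon_greedy API used by
# evaluator_network above; assumes `jax` is imported alongside jnp/rlax,
# and the values here are arbitrary.
key = jax.random.PRNGKey(0)
action_values = jnp.array([0.1, 0.9, 0.3])
action = rlax.epsilon_greedy(0.1).sample(key, action_values)

# Assuming the standard absl entry point (not shown in this excerpt), the
# script would end with:
if __name__ == '__main__':
  app.run(main)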