Example No. 1
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_from_id(FLAGS.bsuite_id)
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert the demonstration episodes into 1-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)

    # If the agent is non-autoregressive, epsilon=0 yields a greedy policy.
    evaluator_network = snt.Sequential([
        policy_network,
        lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
    ])

    # Ensure that we create the variables before proceeding (maybe not needed).
    tf2_utils.create_variables(policy_network, [environment_spec.observations])

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # Create the actor which defines how we take actions.
    evaluation_network = actors_tf2.FeedForwardActor(evaluator_network)

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluation_network,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 learning_rate=FLAGS.learning_rate,
                                 dataset=dataset,
                                 counter=learner_counter)

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
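
This example is written as an absl app, so it needs flag definitions and an entry point that are not shown above. A minimal sketch, assuming the flag names referenced in main() and using placeholder defaults that are not taken from the original script:

from absl import app
from absl import flags

# Hypothetical flag definitions matching the names used in main(); defaults are placeholders.
flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'bsuite task id, e.g. deep_sea/0.')
flags.DEFINE_integer('batch_size', 16, 'Batch size for the learner.')
flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon-greedy evaluator.')
flags.DEFINE_float('learning_rate', 1e-4, 'Learner learning rate.')
flags.DEFINE_integer('evaluate_every', 100, 'Learner steps between evaluations.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes per evaluation run.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)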
Example No. 2
def load_offline_bsuite_dataset(
    bsuite_id: str,
    random_prob: float,
    path: str,
    batch_size: int,
    valid_batch_size: int,
    num_shards: int = 1,
    num_valid_shards: int = 1,
    num_threads: int = 1,
    single_precision_wrapper: bool = True,
    shuffle_buffer_size: int = 100000,
    shuffle: bool = True,
    repeat: bool = True
) -> Tuple[tf.data.Dataset, tf.data.Dataset, dm_env.Environment]:
    """Load bsuite offline dataset."""
    # Data file path format: {path}-?????-of-{num_shards:05d}
    # The dataset is not deterministic and not repeated if shuffle = False.
    environment = bsuite.load_from_id(bsuite_id)
    if single_precision_wrapper:
        environment = single_precision.SinglePrecisionWrapper(environment)
    if random_prob > 0.:
        environment = RandomActionWrapper(environment, random_prob)
    params = bsuite_offline_dataset.dataset_params(environment)
    if os.path.basename(path):
        path += '_'
    train_path = path + 'train'
    train_dataset = bsuite_offline_dataset.dataset(
        path=train_path,
        num_threads=num_threads,
        batch_size=batch_size,
        num_shards=num_shards,
        shuffle_buffer_size=shuffle_buffer_size,
        shuffle=shuffle,
        repeat=repeat,
        **params)
    valid_path = path + 'valid'
    valid_dataset = bsuite_offline_dataset.dataset(path=valid_path,
                                                   num_threads=num_threads,
                                                   batch_size=valid_batch_size,
                                                   num_shards=num_valid_shards,
                                                   shuffle=False,
                                                   repeat=False,
                                                   **params)
    return train_dataset, valid_dataset, environment
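
A hedged usage sketch: the bsuite_id and dataset path below are placeholders, and the shard files are expected to follow the {path}_train-?????-of-{num_shards:05d} pattern described in the comment above.

# Hypothetical call; point path/bsuite_id at your own offline shards.
train_dataset, valid_dataset, environment = load_offline_bsuite_dataset(
    bsuite_id='catch/0',
    random_prob=0.0,
    path='/tmp/bsuite_offline/catch',
    batch_size=256,
    valid_batch_size=256)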
Example No. 3
def make_environment(task,
                     end_on_success,
                     max_episode_steps,
                     distance_fn,
                     goal_image,
                     baseline_distance=None,
                     eval_mode=False,
                     logdir=None,
                     counter=None,
                     record_every=100,
                     num_episodes_to_record=3):
    """Create the environment and its wrappers."""
    env = gym.make(task)
    env = gym_wrapper.GymWrapper(env)
    if end_on_success:
        env = env_wrappers.EndOnSuccessWrapper(env)
    env = wrappers.StepLimitWrapper(env, max_episode_steps)

    env = env_wrappers.ReshapeImageWrapper(env)
    if distance_fn.history_length > 1:
        env = wrappers.FrameStackingWrapper(env, distance_fn.history_length)
    env = env_wrappers.GoalConditionedWrapper(env, goal_image)
    env = env_wrappers.DistanceModelWrapper(
        env,
        distance_fn,
        max_episode_steps,
        baseline_distance,
        distance_reward_weight=FLAGS.distance_reward_weight,
        environment_reward_weight=FLAGS.environment_reward_weight)
    if FLAGS.use_true_distance:
        env = env_wrappers.RewardWrapper(env)
    if logdir:
        env = env_wrappers.RecordEpisodesWrapper(
            env,
            counter,
            logdir,
            record_every=record_every,
            num_to_record=num_episodes_to_record,
            eval_mode=eval_mode)
    env = env_wrappers.VisibleStateWrapper(env, eval_mode)

    return single_precision.SinglePrecisionWrapper(env)
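
A sketch of how this factory might be called. All argument values are placeholders: the gym task id is hypothetical, and distance_fn and goal_image are assumed to come from elsewhere (the distance model must expose a history_length attribute). Leaving logdir unset skips the episode-recording wrapper.

# Hypothetical invocation with placeholder arguments.
env = make_environment(
    task='FetchReach-v1',
    end_on_success=True,
    max_episode_steps=100,
    distance_fn=distance_fn,   # assumed learned distance model with .history_length
    goal_image=goal_image,     # assumed goal observation for goal conditioning
    eval_mode=True)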
Example No. 4
def make_environment(training: bool = True):
    del training
    env = bsuite.load(experiment_name='deep_sea', kwargs={'size': 10})
    return single_precision.SinglePrecisionWrapper(env)
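
A minimal sketch of plugging this environment factory into an Acme environment loop, assuming some agent is built from the environment spec (the agent itself is left as a placeholder):

import acme
from acme import specs

environment = make_environment()
environment_spec = specs.make_environment_spec(environment)
# agent = ...  # any Acme actor/agent built for environment_spec (placeholder)
# acme.EnvironmentLoop(environment, agent).run(num_episodes=100)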
Example No. 5
def main(_):
    # Create an environment and grab the spec.
    raw_environment = bsuite.load_and_record_to_csv(
        bsuite_id=FLAGS.bsuite_id,
        results_dir=FLAGS.results_dir,
        overwrite=FLAGS.overwrite,
    )
    environment = single_precision.SinglePrecisionWrapper(raw_environment)
    environment_spec = specs.make_environment_spec(environment)

    # Build demonstration dataset.
    if hasattr(raw_environment, 'raw_env'):
        raw_environment = raw_environment.raw_env

    batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
    # Convert the demonstration episodes into 1-step transitions.
    transition = functools.partial(_n_step_transition_from_episode,
                                   n_step=1,
                                   additional_discount=1.)

    dataset = batch_dataset.map(transition)

    # Batch and prefetch.
    dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = tfds.as_numpy(dataset)

    # Create the networks to optimize.
    policy_network = make_policy_network(environment_spec.actions)
    policy_network = hk.without_apply_rng(hk.transform(policy_network))

    # If the agent is non-autoregressive, epsilon=0 yields a greedy policy.
    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        action_values = policy_network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

    counter = counting.Counter()
    learner_counter = counting.Counter(counter, prefix='learner')

    # The learner updates the parameters (and initializes them).
    learner = learning.BCLearner(network=policy_network,
                                 optimizer=optax.adam(FLAGS.learning_rate),
                                 obs_spec=environment.observation_spec(),
                                 dataset=dataset,
                                 counter=learner_counter,
                                 rng=hk.PRNGSequence(FLAGS.seed))

    # Create the actor which defines how we take actions.
    variable_client = variable_utils.VariableClient(learner, '')
    evaluator = actors.FeedForwardActor(evaluator_network,
                                        variable_client=variable_client,
                                        rng=hk.PRNGSequence(FLAGS.seed))

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     counter=counter,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=1.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        learner_counter.increment(learner_steps=FLAGS.evaluate_every)
        eval_loop.run(FLAGS.evaluation_episodes)
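
The make_policy_network helper used above is defined elsewhere in the original script. A minimal Haiku sketch consistent with how it is transformed here (a function mapping observations to per-action values, using the same hk and specs modules as the example) might look like this; the layer sizes are placeholders:

# Hypothetical sketch of the assumed make_policy_network helper.
def make_policy_network(action_spec: specs.DiscreteArray):
    def policy(observation):
        network = hk.Sequential([
            hk.Flatten(),
            hk.nets.MLP([64, 64, action_spec.num_values]),
        ])
        return network(observation)
    return policy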