def make_demonstrations(env: dm_env.Environment,
                        batch_size: int) -> tf.data.Dataset:
  """Prepare the dataset of demonstrations."""
  batch_dataset = bsuite_demonstrations.make_dataset(env, stochastic=False)

  # Convert the recorded episodes into 1-step transitions.
  transition = functools.partial(_n_step_transition_from_episode,
                                 n_step=1,
                                 additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset
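# Usage sketch (an illustration, not part of the original scripts): build the
# demonstration dataset for a bsuite task and peek at one batch. The wrapper
# function name is hypothetical and 'deep_sea/0' is just an example bsuite_id.
import bsuite

def _demo_make_demonstrations():
  env = bsuite.load_from_id('deep_sea/0')  # A task with recorded demonstrations.
  demonstrations = make_demonstrations(env, batch_size=16)
  for batch in demonstrations.take(1):
    print(batch)  # One batch of 16 single-step transitions.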
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = wrappers.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Construct the agent.
  agent = dqfd.DQfD(
      environment_spec=environment_spec,
      network=make_network(environment_spec.actions),
      demonstration_dataset=bsuite_demonstrations.make_dataset(raw_environment),
      demonstration_ratio=FLAGS.demonstration_ratio,
      samples_per_insert=FLAGS.samples_per_insert,
      learning_rate=FLAGS.learning_rate)

  # Run the environment loop.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=environment.bsuite_num_episodes)  # pytype: disable=attribute-error
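# A plausible sketch of the make_network helper referenced above; the real
# definition is not shown in this listing. It assumes a small Sonnet MLP
# Q-network with illustrative layer sizes: observations are flattened and
# mapped to one value per discrete action.
import sonnet as snt
from acme import specs

def make_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([50, 50, action_spec.num_values]),
  ])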
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build the demonstration dataset, unwrapping to the raw bsuite env first.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env

  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
                                                     stochastic=False)

  # Convert the recorded episodes into 1-step transitions.
  transition = functools.partial(_n_step_transition_from_episode,
                                 n_step=1,
                                 additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)

  # Since the agent is non-autoregressive, epsilon=0 yields a greedy policy.
  evaluator_network = snt.Sequential([
      policy_network,
      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
  ])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Create the actor which defines how we take actions.
  evaluator = actors.FeedForwardActor(evaluator_network)
  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)

  # Alternate between learning and evaluation indefinitely.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
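# A plausible sketch of make_policy_network for the Sonnet/TF variant above
# (an assumption; the real helper is not shown). The evaluator wraps its
# output in trfl.epsilon_greedy, so the network must emit one action value
# per discrete action. Hidden sizes are illustrative.
import sonnet as snt
from acme import specs

def make_policy_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([64, 64, action_spec.num_values]),
  ])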
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build the demonstration dataset, unwrapping to the raw bsuite env first.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env

  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)

  # Convert the recorded episodes into 1-step transitions.
  transition = functools.partial(_n_step_transition_from_episode,
                                 n_step=1,
                                 additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch, then expose the data as NumPy for the JAX learner.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = tfds.as_numpy(dataset)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)
  policy_network = hk.without_apply_rng(hk.transform(policy_network))

  # Since the agent is non-autoregressive, epsilon=0 yields a greedy policy.
  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    action_values = policy_network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      optimizer=optax.adam(FLAGS.learning_rate),
      obs_spec=environment.observation_spec(),
      dataset=dataset,
      counter=learner_counter,
      rng=hk.PRNGSequence(FLAGS.seed))

  # Create the actor which defines how we take actions.
  variable_client = variable_utils.VariableClient(learner, '')
  evaluator = actors.FeedForwardActor(
      evaluator_network,
      variable_client=variable_client,
      rng=hk.PRNGSequence(FLAGS.seed))
  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # Alternate between learning and evaluation indefinitely.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
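# Two supporting sketches for the Haiku/JAX variant above, both assumptions
# rather than the original code. First, a plausible make_policy_network: a
# plain function from observations to per-action values, suitable for
# hk.transform; hidden sizes are illustrative.
import haiku as hk
from acme import specs

def make_policy_network(action_spec: specs.DiscreteArray):
  def network(obs):
    mlp = hk.Sequential(
        [hk.Flatten(), hk.nets.MLP([64, 64, action_spec.num_values])])
    return mlp(obs)
  return network

# Second, a plausible absl entry point. Every flag name below comes from a
# FLAGS reference in the scripts above, but the defaults and help strings are
# illustrative guesses, and app.run(main) simply wires up whichever main the
# enclosing script defines.
from absl import app, flags

flags.DEFINE_string('bsuite_id', 'deep_sea/0', 'bsuite id to load.')
flags.DEFINE_string('results_dir', '/tmp/bsuite', 'Directory for CSV results.')
flags.DEFINE_boolean('overwrite', False, 'Whether to overwrite existing results.')
flags.DEFINE_integer('batch_size', 16, 'Batch size for learning.')
flags.DEFINE_float('learning_rate', 1e-4, 'Optimizer learning rate.')
flags.DEFINE_float('epsilon', 0., 'Epsilon for the epsilon-greedy evaluator.')
flags.DEFINE_float('demonstration_ratio', 0.5,
                   'Fraction of replay drawn from demonstrations (DQfD).')
flags.DEFINE_float('samples_per_insert', 8., 'Replay samples per insert (DQfD).')
flags.DEFINE_integer('evaluate_every', 100, 'Learner steps between evaluations.')
flags.DEFINE_integer('evaluation_episodes', 10, 'Episodes per evaluation run.')
flags.DEFINE_integer('seed', 42, 'Random seed (JAX variant).')
FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)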