def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = wrappers.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Create the networks to optimize.
  network = make_network(environment_spec.actions)

  agent = impala.IMPALA(
      environment_spec=environment_spec,
      network=network,
      sequence_length=3,
      sequence_period=3,
  )

  # Run the environment loop.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=environment.bsuite_num_episodes)  # pytype: disable=attribute-error
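# `make_network` is used above but not defined in this snippet. A minimal
# sketch of what it could look like for this bsuite IMPALA example, assuming
# the TF/Sonnet agent stack and acme.tf.networks.PolicyValueHead (an
# illustration, not the original definition):
import sonnet as snt

from acme import specs
from acme.tf import networks as tf_networks


def make_network(action_spec: specs.DiscreteArray) -> snt.RNNCore:
  # A small recurrent torso followed by a joint policy/value head, since
  # IMPALA needs both logits and a value estimate at every step.
  return snt.DeepRNN([
      snt.Flatten(),
      snt.nets.MLP([50, 50]),
      snt.LSTM(20),
      tf_networks.PolicyValueHead(action_spec.num_values),
  ])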
def test_ddpg(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(episode_length=10, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks to optimize (online) and target networks.
  agent_networks = make_networks(spec.actions)

  # Construct the agent.
  agent = ddpg.DDPG(
      environment_spec=spec,
      policy_network=agent_networks['policy'],
      critic_network=agent_networks['critic'],
      batch_size=10,
      samples_per_insert=2,
      min_replay_size=10,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=2)
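# `make_networks` above is assumed to return a dict with 'policy' and 'critic'
# entries. A minimal sketch for a bounded continuous action spec, built from
# standard acme.tf.networks blocks (illustrative only; layer sizes arbitrary):
import numpy as np
import sonnet as snt

from acme import specs
from acme.tf import networks as tf_networks


def make_networks(action_spec: specs.BoundedArray):
  num_dimensions = np.prod(action_spec.shape, dtype=int)
  policy_network = snt.Sequential([
      tf_networks.LayerNormMLP([50], activate_final=True),
      tf_networks.NearZeroInitializedLinear(num_dimensions),
      # Squash the output into the bounds of the action spec.
      tf_networks.TanhToSpec(action_spec),
  ])
  critic_network = snt.Sequential([
      # CriticMultiplexer concatenates observation and action inputs.
      tf_networks.CriticMultiplexer(
          critic_network=tf_networks.LayerNormMLP([50], activate_final=True)),
      snt.Linear(1),
  ])
  return {'policy': policy_network, 'critic': critic_network}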
def main(_):
  env = helpers.make_environment(level=FLAGS.level, oar_wrapper=True)
  env_spec = acme.make_environment_spec(env)

  config = impala.IMPALAConfig(
      batch_size=16,
      sequence_period=10,
      seed=FLAGS.seed,
  )

  networks = impala.make_atari_networks(env_spec)
  agent = impala.IMPALAFromConfig(
      environment_spec=env_spec,
      forward_fn=networks.forward_fn,
      unroll_init_fn=networks.unroll_init_fn,
      unroll_fn=networks.unroll_fn,
      initial_state_init_fn=networks.initial_state_init_fn,
      initial_state_fn=networks.initial_state_fn,
      config=config,
  )

  loop = acme.EnvironmentLoop(env, agent)
  loop.run(FLAGS.num_episodes)
def test_impala(self):
  # Create a fake environment to test with.
  environment = fakes.DiscreteEnvironment(
      num_actions=5,
      num_observations=10,
      obs_shape=(10, 5),
      obs_dtype=np.float32,
      episode_length=10)
  spec = specs.make_environment_spec(environment)

  def forward_fn(x, s):
    model = MyNetwork(spec.actions.num_values)
    return model(x, s)

  def initial_state_fn(batch_size: Optional[int] = None):
    model = MyNetwork(spec.actions.num_values)
    return model.initial_state(batch_size)

  def unroll_fn(inputs, state):
    model = MyNetwork(spec.actions.num_values)
    return hk.static_unroll(model, inputs, state)

  # Construct the agent.
  agent = impala.IMPALA(
      environment_spec=spec,
      forward_fn=forward_fn,
      initial_state_fn=initial_state_fn,
      unroll_fn=unroll_fn,
      sequence_length=3,
      sequence_period=3,
      batch_size=6,
  )

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent)
  loop.run(num_episodes=20)
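# `MyNetwork` is referenced but not defined in this snippet. IMPALA expects a
# recurrent core whose per-step output is a (logits, value) pair. A minimal
# self-contained Haiku sketch consistent with how it is called above
# (illustrative; the original test network may differ):
from typing import Optional

import haiku as hk
import jax.numpy as jnp


class MyNetwork(hk.RNNCore):
  """A tiny recurrent policy-value network."""

  def __init__(self, num_actions: int):
    super().__init__(name='my_network')
    self._torso = hk.Sequential([hk.Flatten(), hk.nets.MLP([50, 50])])
    self._core = hk.LSTM(20)
    self._num_actions = num_actions

  def __call__(self, inputs, state):
    embedding = self._torso(inputs)
    embedding, new_state = self._core(embedding, state)
    logits = hk.Linear(self._num_actions)(embedding)
    value = jnp.squeeze(hk.Linear(1)(embedding), axis=-1)
    return (logits, value), new_state

  def initial_state(self, batch_size: Optional[int]):
    return self._core.initial_state(batch_size)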
def test_td3(self):
  # Create a fake environment to test with.
  environment = fakes.ContinuousEnvironment(
      episode_length=10, action_dim=3, observation_dim=5, bounded=True)
  spec = specs.make_environment_spec(environment)

  # Create the networks.
  network = td3.make_networks(spec)
  config = td3.TD3Config(batch_size=10, min_replay_size=1)

  counter = counting.Counter()
  agent = td3.TD3(
      spec=spec, network=network, config=config, seed=0, counter=counter)

  # Try running the environment loop. We have no assertions here because all
  # we care about is that the agent runs without raising any errors.
  loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  loop.run(num_episodes=2)
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
) -> acme.EnvironmentLoop:
  """The actor process."""

  # Create the behavior policy.
  networks = self._network_factory(self._environment_spec.actions)
  networks.init(self._environment_spec)
  policy_network = networks.make_policy(
      environment_spec=self._environment_spec,
      sigma=self._sigma,
  )

  # Create the agent.
  actor = self._builder.make_actor(
      policy_network=policy_network,
      adder=self._builder.make_adder(replay),
      variable_source=variable_source,
  )

  # Create the environment.
  environment = self._environment_factory(False)

  # Create logger and counter; actors will not spam bigtable.
  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor',
      save_data=False,
      time_delta=self._log_every,
      steps_key='actor_steps')

  # Create the loop to connect environment and agent.
  return acme.EnvironmentLoop(environment, actor, counter, logger)
def main(_):
  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=_ENV_NAME.value)
  spec = specs.make_environment_spec(environment)

  key = jax.random.PRNGKey(_SEED.value)
  key, dataset_key, evaluator_key = jax.random.split(key, 3)

  # Load the dataset.
  dataset = tensorflow_datasets.load(_DATASET_NAME.value)['train']
  # Convert the demonstration episodes into batched timestep transitions.
  dataset = mbop.episodes_to_timestep_batched_transitions(
      dataset, return_horizon=10)
  dataset = tfds.JaxInMemoryRandomSampleIterator(
      dataset, key=dataset_key, batch_size=_BATCH_SIZE.value)

  # Apply normalization to the dataset.
  mean_std = mbop.get_normalization_stats(dataset,
                                          _NUM_NORMALIZATION_BATCHES.value)
  apply_normalization = jax.jit(
      functools.partial(running_statistics.normalize, mean_std=mean_std))
  dataset = (apply_normalization(sample) for sample in dataset)

  # Create the networks.
  networks = mbop.make_networks(
      spec, hidden_layer_sizes=tuple(_HIDDEN_LAYER_SIZES.value))

  # Use the default losses.
  losses = mbop.MBOPLosses()

  def logger_fn(label: str, steps_key: str):
    return loggers.make_default_logger(label, steps_key=steps_key)

  def make_learner(name, logger_fn, counter, rng_key, dataset, network, loss):
    return mbop.make_ensemble_regressor_learner(
        name,
        _NUM_NETWORKS.value,
        logger_fn,
        counter,
        rng_key,
        dataset,
        network,
        loss,
        optax.adam(_LEARNING_RATE.value),
        _NUM_SGD_STEPS_PER_STEP.value,
    )

  learner = mbop.MBOPLearner(
      networks, losses, dataset, key, logger_fn,
      functools.partial(make_learner, 'world_model'),
      functools.partial(make_learner, 'policy_prior'),
      functools.partial(make_learner, 'n_step_return'))

  planning_config = mbop.MPPIConfig()
  assert planning_config.n_trajectories % _NUM_NETWORKS.value == 0, (
      'Number of trajectories must be a multiple of the number of networks.')
  actor_core = mbop.make_ensemble_actor_core(
      networks, planning_config, spec, mean_std, use_round_robin=False)
  evaluator = mbop.make_actor(actor_core, evaluator_key, learner)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Train the agent.
  while True:
    for _ in range(_EVALUATE_EVERY.value):
      learner.step()
    eval_loop.run(_EVALUATION_EPISODES.value)
def train_and_evaluate(distance_fn, rng):
  """Train a policy on the learned distance function and evaluate task success.

  Args:
    distance_fn: function mapping a (state, goal)-pair to a state embedding and
      a distance estimate used for policy learning.
    rng: random key used to initialize evaluation actor.
  """
  goal_image = load_goal_image(FLAGS.robot_data_path)
  logdir = FLAGS.logdir
  video_dir = paths.process_path(logdir, 'videos')
  print('Writing videos to', video_dir)
  counter = counting.Counter()
  eval_counter = counting.Counter(counter, prefix='eval', time_delta=0.0)
  # Include training episodes and steps and walltime in the first eval logs.
  counter.increment(episodes=0, steps=0, walltime=0)

  environment = make_environment(
      task=FLAGS.task,
      end_on_success=FLAGS.end_on_success,
      max_episode_steps=FLAGS.max_episode_steps,
      distance_fn=distance_fn,
      goal_image=goal_image,
      baseline_distance=FLAGS.baseline_distance,
      logdir=video_dir,
      counter=counter,
      record_every=FLAGS.record_episodes_frequency,
      num_episodes_to_record=FLAGS.num_episodes_to_record)
  environment_spec = specs.make_environment_spec(environment)
  print('Environment spec')
  print(environment_spec)
  agent_networks = sac.make_networks(environment_spec)

  config = sac.SACConfig(
      target_entropy=sac.target_entropy_from_env_spec(environment_spec),
      num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
      min_replay_size=FLAGS.min_replay_size)
  agent = deprecated_sac.SAC(
      environment_spec,
      agent_networks,
      config=config,
      counter=counter,
      seed=FLAGS.seed)

  env_logger = loggers.CSVLogger(logdir, 'env_loop', flush_every=5)
  eval_env_logger = loggers.CSVLogger(logdir, 'eval_env_loop', flush_every=1)
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      label='train_loop',
      logger=env_logger,
      counter=counter)

  eval_actor = agent.builder.make_actor(
      random_key=rng,
      policy=sac.apply_policy_and_sample(agent_networks, eval_mode=True),
      environment_spec=environment_spec,
      variable_source=agent)

  eval_video_dir = paths.process_path(logdir, 'eval_videos')
  print('Writing eval videos to', eval_video_dir)
  if FLAGS.baseline_distance_from_goal_to_goal:
    state = goal_image
    if distance_fn.history_length > 1:
      state = np.stack([goal_image] * distance_fn.history_length, axis=-1)
    unused_embeddings, baseline_distance = distance_fn(state, goal_image)
    print('Baseline prediction', baseline_distance)
  else:
    baseline_distance = FLAGS.baseline_distance
  eval_env = make_environment(
      task=FLAGS.task,
      end_on_success=False,
      max_episode_steps=FLAGS.max_episode_steps,
      distance_fn=distance_fn,
      goal_image=goal_image,
      eval_mode=True,
      logdir=eval_video_dir,
      counter=eval_counter,
      record_every=FLAGS.num_eval_episodes,
      num_episodes_to_record=FLAGS.num_eval_episodes,
      baseline_distance=baseline_distance)

  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      label='eval_loop',
      logger=eval_env_logger,
      counter=eval_counter)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=FLAGS.num_eval_episodes)
def run_experiment(experiment: config.ExperimentConfig,
                   eval_every: int = 100,
                   num_eval_episodes: int = 1):
  """Runs a simple, single-threaded training loop using the default evaluators.

  It targets simplicity of the code and so only the basic features of the
  ExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many actor steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

  key = jax.random.PRNGKey(experiment.seed)

  # Create the environment and get its spec.
  environment = experiment.environment_factory(experiment.seed)
  environment_spec = experiment.environment_spec or specs.make_environment_spec(
      environment)

  # Create the networks and policy.
  networks = experiment.network_factory(environment_spec)
  policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=False)

  # Create the replay server and grab its address.
  replay_tables = experiment.builder.make_replay_tables(environment_spec,
                                                        policy)

  # Disable blocking of inserts by tables' rate limiters, as this function
  # executes learning (sampling from the table) and data generation
  # (inserting into the table) sequentially from the same thread, which could
  # result in a blocked insert making the algorithm hang.
  replay_tables, rate_limiters_max_diff = _disable_insert_blocking(
      replay_tables)

  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # The parent counter allows sharing step counts between the train and eval
  # loops and the learner, so that it is possible to plot, for example, the
  # evaluator's return as a function of the number of training episodes.
  parent_counter = counting.Counter(time_delta=0.)

  # Create the actor and learner for generating, storing, and consuming
  # data respectively.
  dataset = experiment.builder.make_dataset_iterator(replay_client)
  # We always use prefetch, as it provides an iterator with an additional
  # 'ready' method.
  dataset = utils.prefetch(dataset, buffer_size=1)
  learner_key, key = jax.random.split(key)
  learner = experiment.builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=experiment.logger_factory,
      environment_spec=environment_spec,
      replay_client=replay_client,
      counter=counting.Counter(parent_counter, prefix='learner',
                               time_delta=0.))

  adder = experiment.builder.make_adder(replay_client, environment_spec,
                                        policy)
  actor_key, key = jax.random.split(key)
  actor = experiment.builder.make_actor(
      actor_key, policy, environment_spec, variable_source=learner,
      adder=adder)

  # Create the environment loop used for training.
  train_counter = counting.Counter(
      parent_counter, prefix='train', time_delta=0.)
  train_logger = experiment.logger_factory('train',
                                           train_counter.get_steps_key(), 0)

  # Replace the actor with a LearningActor. This makes sure that every time
  # that `update` is called on the actor it checks to see whether there is
  # any new data to learn from and if so it runs a learner step. The rate
  # at which new data is released is controlled by the replay table's
  # rate_limiter which is created by the builder.make_replay_tables call above.
  actor = _LearningActor(actor, learner, dataset, replay_tables,
                         rate_limiters_max_diff)

  train_loop = acme.EnvironmentLoop(
      environment,
      actor,
      counter=train_counter,
      logger=train_logger,
      observers=experiment.observers)

  if num_eval_episodes == 0:
    # No evaluation. Just run the training loop.
    train_loop.run(num_steps=experiment.max_num_actor_steps)
    return

  # Create the evaluation actor and loop.
  eval_counter = counting.Counter(parent_counter, prefix='eval', time_delta=0.)
  eval_logger = experiment.logger_factory('eval', eval_counter.get_steps_key(),
                                          0)
  eval_policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=True)
  eval_actor = experiment.builder.make_actor(
      random_key=jax.random.PRNGKey(experiment.seed),
      policy=eval_policy,
      environment_spec=environment_spec,
      variable_source=learner)
  eval_loop = acme.EnvironmentLoop(
      environment,
      eval_actor,
      counter=eval_counter,
      logger=eval_logger,
      observers=experiment.observers)

  steps = 0
  while steps < experiment.max_num_actor_steps:
    eval_loop.run(num_episodes=num_eval_episodes)
    steps += train_loop.run(num_steps=eval_every)
  eval_loop.run(num_episodes=num_eval_episodes)
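# A schematic usage sketch for run_experiment. The factories below are
# placeholders to fill in with a concrete agent builder and environment; the
# field names follow acme's experiments.ExperimentConfig, but this fragment is
# not runnable as-is:
experiment = config.ExperimentConfig(
    builder=...,  # e.g. an agent builder such as sac.SACBuilder(...).
    environment_factory=...,  # Callable: seed -> dm_env.Environment.
    network_factory=...,  # Callable: EnvironmentSpec -> agent networks.
    seed=0,
    max_num_actor_steps=50_000)
run_experiment(experiment, eval_every=1_000, num_eval_episodes=5)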
def main(_):
  # TODO(yutian): Create environment.
  # # Create an environment and grab the spec.
  # raw_environment = bsuite.load_and_record_to_csv(
  #     bsuite_id=FLAGS.bsuite_id,
  #     results_dir=FLAGS.results_dir,
  #     overwrite=FLAGS.overwrite,
  # )
  # environment = single_precision.SinglePrecisionWrapper(raw_environment)
  # environment_spec = specs.make_environment_spec(environment)

  # TODO(yutian): Create dataset.
  # Build the dataset.
  # if hasattr(raw_environment, 'raw_env'):
  #   raw_environment = raw_environment.raw_env
  #
  # batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
  # # Combine with demonstration dataset.
  # transition = functools.partial(
  #     _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  #
  # dataset = batch_dataset.map(transition)
  #
  # # Batch and prefetch.
  # dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  # dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Create the networks to optimize.
  networks = make_networks(environment_spec.actions)
  treatment_net = networks['treatment_net']
  instrumental_net = networks['instrumental_net']
  policy_net = networks['policy_net']

  # If the agent is non-autoregressive use epsilon=0 which will be a greedy
  # policy.
  evaluator_net = snt.Sequential([
      policy_net,
      # Sample actions.
      acme_nets.StochasticSamplingHead()
  ])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_net, [environment_spec.observations])
  # TODO(liyuan): set the proper input spec using environment_spec.observations
  # and environment_spec.actions.
  tf2_utils.create_variables(treatment_net, [environment_spec.observations])
  tf2_utils.create_variables(
      instrumental_net,
      [environment_spec.observations, environment_spec.actions])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Create the actor which defines how we take actions.
  evaluator_net = actors.FeedForwardActor(evaluator_net)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator_net,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # The learner updates the parameters (and initializes them).
  learner = learning.DFIVLearner(
      treatment_net=treatment_net,
      instrumental_net=instrumental_net,
      policy_net=policy_net,
      treatment_learning_rate=FLAGS.treatment_learning_rate,
      instrumental_learning_rate=FLAGS.instrumental_learning_rate,
      policy_learning_rate=FLAGS.policy_learning_rate,
      dataset=dataset,
      counter=learner_counter)

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
def main(_):
  # Create an environment, grab the spec.
  environment = utils.make_environment(task=FLAGS.env_name)
  aqua_config = config.AquademConfig()
  spec = specs.make_environment_spec(environment)
  discretized_spec = aquadem_builder.discretize_spec(spec,
                                                     aqua_config.num_actions)

  # Create AQuaDem builder.
  loss_fn = dqn.losses.MunchausenQLearning(max_abs_reward=100.)
  dqn_config = dqn.DQNConfig(
      min_replay_size=1000,
      n_step=3,
      num_sgd_steps_per_step=8,
      learning_rate=1e-4,
      samples_per_insert=256)
  rl_agent = dqn.DQNBuilder(config=dqn_config, loss_fn=loss_fn)
  make_demonstrations = utils.get_make_demonstrations_fn(
      FLAGS.env_name, FLAGS.num_demonstrations, FLAGS.seed)
  builder = aquadem_builder.AquademBuilder(
      rl_agent=rl_agent,
      config=aqua_config,
      make_demonstrations=make_demonstrations)

  # Create networks.
  q_network = aquadem_networks.make_q_network(spec=discretized_spec)
  dqn_networks = dqn.DQNNetworks(
      policy_network=networks_lib.non_stochastic_network_to_typed(q_network))
  networks = aquadem_networks.make_action_candidates_network(
      spec=spec,
      num_actions=aqua_config.num_actions,
      discrete_rl_networks=dqn_networks)
  exploration_epsilon = 0.01
  discrete_policy = dqn.default_behavior_policy(dqn_networks,
                                                exploration_epsilon)
  behavior_policy = aquadem_builder.get_aquadem_policy(discrete_policy,
                                                       networks)

  # Create the environment loop used for training.
  agent = local_layout.LocalLayout(
      seed=FLAGS.seed,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      batch_size=dqn_config.batch_size * dqn_config.num_sgd_steps_per_step)

  train_logger = loggers.CSVLogger(FLAGS.workdir, label='train')
  train_loop = acme.EnvironmentLoop(environment, agent, logger=train_logger)

  # Create the evaluation actor and loop.
  eval_policy = dqn.default_behavior_policy(dqn_networks, 0.)
  eval_policy = aquadem_builder.get_aquadem_policy(eval_policy, networks)
  eval_actor = builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy=eval_policy,
      environment_spec=spec,
      variable_source=agent)
  eval_env = utils.make_environment(task=FLAGS.env_name, evaluation=True)

  eval_logger = loggers.CSVLogger(FLAGS.workdir, label='eval')
  eval_loop = acme.EnvironmentLoop(eval_env, eval_actor, logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=10)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=10)
def test_ail_flax(self):
  shutil.rmtree(flags.FLAGS.test_tmpdir)
  batch_size = 8
  # Mujoco environment and associated demonstration dataset.
  environment = fakes.ContinuousEnvironment(
      episode_length=EPISODE_LENGTH,
      action_dim=CONTINUOUS_ACTION_DIM,
      observation_dim=CONTINUOUS_OBS_DIM,
      bounded=True)
  spec = specs.make_environment_spec(environment)

  networks = sac.make_networks(spec=spec)
  config = sac.SACConfig(
      batch_size=batch_size,
      samples_per_insert_tolerance_rate=float('inf'),
      min_replay_size=1)
  base_builder = sac.SACBuilder(config=config)
  direct_rl_batch_size = batch_size
  behavior_policy = sac.apply_policy_and_sample(networks)

  discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

  def apply_fn(params: networks_lib.Params,
               policy_params: networks_lib.Params, state: networks_lib.Params,
               transitions: types.Transition, is_training: bool,
               rng: networks_lib.PRNGKey) -> networks_lib.Logits:
    del policy_params
    variables = dict(params=params, **state)
    return discriminator_module.apply(
        variables,
        transitions.observation,
        transitions.action,
        transitions.next_observation,
        is_training=is_training,
        rng=rng,
        mutable=state.keys())

  def init_fn(rng):
    variables = discriminator_module.init(
        rng, dummy_obs, dummy_actions, dummy_obs, is_training=False, rng=rng)
    init_state, discriminator_params = variables.pop('params')
    return discriminator_params, init_state

  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  dummy_actions = utils.zeros_like(spec.actions)
  dummy_actions = utils.add_batch_dim(dummy_actions)
  discriminator_network = networks_lib.FeedForwardNetwork(
      init=init_fn, apply=apply_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=False,
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name=None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  counter = counting.Counter()
  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size,
      counter=counter,
  )

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
  train_loop.run(num_episodes=1)
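# `DiscriminatorModule` is defined elsewhere in the test file. A minimal Flax
# sketch that is call-compatible with the apply_fn/init_fn above (positional
# (spec, core) construction; observation, action and next_observation inputs;
# unused is_training/rng arguments). Illustrative only:
from flax import linen
import jax.numpy as jnp

from acme import specs as acme_specs


class DiscriminatorModule(linen.Module):
  spec: acme_specs.EnvironmentSpec  # Unused in this toy sketch.
  core: linen.Module

  @linen.compact
  def __call__(self, observations, actions, next_observations, is_training,
               rng):
    del is_training, rng  # A stateless sketch: no dropout or batch norm.
    inputs = jnp.concatenate([observations, actions, next_observations],
                             axis=-1)
    return jnp.squeeze(self.core(inputs), axis=-1)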
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env

  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
                                                     stochastic=False)
  # Combine with demonstration dataset.
  transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)

  # If the agent is non-autoregressive use epsilon=0 which will be a greedy
  # policy.
  evaluator_network = snt.Sequential([
      policy_network,
      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
  ])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Create the actor which defines how we take actions.
  evaluation_network = actors.FeedForwardActor(evaluator_network)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluation_network,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
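# `make_policy_network` is not shown in this snippet. For behavioral cloning
# on bsuite's discrete action spaces, a plain MLP over flattened observations
# suffices; a minimal sketch (illustrative, not the original definition):
import sonnet as snt

from acme import specs


def make_policy_network(action_spec: specs.DiscreteArray) -> snt.Module:
  return snt.Sequential([
      snt.Flatten(),
      snt.nets.MLP([64, 64, action_spec.num_values]),
  ])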
def pong_experiment():
  seeds = [0, 42, 69, 360, 420]
  seed = seeds[0]
  start_time = time.strftime("%Y-%m-%d_%H-%M-%S")

  # Setting the torch random seed here to make sure random initialization is
  # the same.
  torch.random.manual_seed(seed)

  # Creating the environment.
  env_name = "PongNoFrameskip-v4"
  env_train = make_environment_atari(env_name, seed)
  env_test = make_environment_atari(env_name, seed)
  env_train_spec = acme.make_environment_spec(env_train)

  # Creating the neural network.
  network = PositionNetworkSingleHead(
      env_train_spec.observations[0].shape,
      env_train_spec.observations[1].shape[0] *
      env_train_spec.actions[1].num_values,
      env_train_spec.actions[0].num_values *
      env_train_spec.actions[1].num_values)

  # Creating the loggers.
  training_logger = TensorBoardLogger("runs/DQN-train-" + env_name +
                                      f"-rnd{seed}-" + start_time)
  testing_logger = TensorBoardLogger("runs/DQN-test-" + env_name +
                                     f"-rnd{seed}-" + start_time)

  # Creating the agent.
  agent = VanillaPartialDQN(
      network, [
          env_train_spec.actions[0].num_values,
          env_train_spec.actions[1].num_values
      ],
      training_logger,
      gradient_clipping=True,
      device='gpu',
      seed=seed,
      replay_start_size=100)

  training_loop = acme.EnvironmentLoop(env_train, agent,
                                       logger=training_logger)
  testing_loop = acme.EnvironmentLoop(
      env_test, agent, logger=testing_logger, should_update=False)

  for epoch in range(200):
    agent.training()
    training_loop.run(num_steps=250000)
    torch.save(
        network.state_dict(),
        "runs/DQN-train-" + env_name + f"-rnd{seed}-" + start_time +
        f"/ep{epoch}.model")
    agent.testing()
    testing_loop.run(num_episodes=30)

  training_logger.close()
  testing_logger.close()
def main(_):
  key = jax.random.PRNGKey(FLAGS.seed)
  key_demonstrations, key_learner = jax.random.split(key, 2)

  # Create an environment and grab the spec.
  environment = gym_helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Get a demonstrations dataset with next_actions extra.
  transitions = tfds.get_tfds_dataset(FLAGS.dataset_name,
                                      FLAGS.num_demonstrations)
  double_transitions = rlds.transformations.batch(
      transitions, size=2, shift=1, drop_remainder=True)
  transitions = double_transitions.map(_add_next_action_extras)
  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
      transitions, key=key_demonstrations, batch_size=FLAGS.batch_size)

  # Create the networks to optimize.
  networks = crr.make_networks(environment_spec)

  # CRR policy loss function.
  policy_loss_coeff_fn = crr.policy_loss_coeff_advantage_exp

  # Create the learner.
  learner = crr.CRRLearner(
      networks=networks,
      random_key=key_learner,
      discount=FLAGS.discount,
      target_update_period=FLAGS.target_update_period,
      policy_loss_coeff_fn=policy_loss_coeff_fn,
      iterator=demonstrations,
      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
      grad_updates_per_batch=FLAGS.grad_updates_per_batch,
      use_sarsa_target=FLAGS.use_sarsa_target)

  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    dist_params = networks.policy_network.apply(params, observation)
    return networks.sample_eval(dist_params, key)

  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
      evaluator_network)
  variable_client = variable_utils.VariableClient(
      learner, 'policy', device='cpu')
  evaluator = actors.GenericActor(
      actor_core, key, variable_client, backend='cpu')

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      logger=loggers.TerminalLogger('evaluation', time_delta=0.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    eval_loop.run(FLAGS.evaluation_episodes)
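# `_add_next_action_extras` is defined elsewhere in this example. Given pairs
# of consecutive transitions (from the size-2 rlds batching above), it keeps
# the first transition and stores the second transition's action as a
# 'next_action' extra for SARSA-style targets. A sketch under that assumption:
from acme import types


def _add_next_action_extras(
    double_transitions: types.Transition) -> types.Transition:
  return types.Transition(
      observation=double_transitions.observation[0],
      action=double_transitions.action[0],
      reward=double_transitions.reward[0],
      discount=double_transitions.discount[0],
      next_observation=double_transitions.next_observation[0],
      extras={'next_action': double_transitions.action[1]})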
def actor(
    self,
    replay: reverb.Client,
    variable_source: acme.VariableSource,
    counter: counting.Counter,
    actor_id: int,
) -> acme.EnvironmentLoop:
  """The actor process."""
  action_spec = self._environment_spec.actions
  observation_spec = self._environment_spec.observations

  # Create environment and target networks to act with.
  environment = self._environment_factory(False)
  agent_networks = self._network_factory(action_spec)

  # Make sure observation network is defined.
  observation_network = agent_networks.get('observation', tf.identity)

  # Create a stochastic behavior policy.
  behavior_network = snt.Sequential([
      observation_network,
      agent_networks['policy'],
      networks.StochasticSamplingHead(),
  ])

  # Ensure network variables are created.
  tf2_utils.create_variables(behavior_network, [observation_spec])
  policy_variables = {'policy': behavior_network.variables}

  # Create the variable client responsible for keeping the actor up-to-date.
  variable_client = tf2_variable_utils.VariableClient(
      variable_source,
      policy_variables,
      update_period=self._variable_update_period)

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  variable_client.update_and_wait()

  # Component to add things into replay.
  adder = adders.NStepTransitionAdder(
      client=replay,
      n_step=self._n_step,
      discount=self._additional_discount)

  # Create the agent.
  actor = actors.FeedForwardActor(
      policy_network=behavior_network,
      adder=adder,
      variable_client=variable_client)

  # Create logger and counter; only the first actor stores logs to bigtable.
  save_data = actor_id == 0
  counter = counting.Counter(counter, 'actor')
  logger = loggers.make_default_logger(
      'actor',
      save_data=save_data,
      time_delta=self._log_every,
      steps_key='actor_steps')
  observers = self._make_observers() if self._make_observers else ()

  # Create the run loop and return it.
  return acme.EnvironmentLoop(
      environment, actor, counter, logger, observers=observers)
def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env

  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment)
  # Combine with demonstration dataset.
  transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  dataset = tfds.as_numpy(dataset)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)
  policy_network = hk.without_apply_rng(hk.transform(policy_network))

  # If the agent is non-autoregressive use epsilon=0 which will be a greedy
  # policy.
  def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                        observation: jnp.DeviceArray) -> jnp.DeviceArray:
    action_values = policy_network.apply(params, observation)
    return rlax.epsilon_greedy(FLAGS.epsilon).sample(key, action_values)

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      optimizer=optax.adam(FLAGS.learning_rate),
      obs_spec=environment.observation_spec(),
      dataset=dataset,
      counter=learner_counter,
      rng=hk.PRNGSequence(FLAGS.seed))

  # Create the actor which defines how we take actions.
  variable_client = variable_utils.VariableClient(learner, '')
  evaluator = actors.FeedForwardActor(
      evaluator_network,
      variable_client=variable_client,
      rng=hk.PRNGSequence(FLAGS.seed))

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluator,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # Run the environment loop.
  while True:
    for _ in range(FLAGS.evaluate_every):
      learner.step()
    learner_counter.increment(learner_steps=FLAGS.evaluate_every)
    eval_loop.run(FLAGS.evaluation_episodes)
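# As in the Sonnet variant of this example, `make_policy_network` is not
# shown. Here it must return a plain function suitable for hk.transform; a
# minimal sketch (illustrative, assuming Q-values over discrete actions):
import haiku as hk

from acme import specs


def make_policy_network(action_spec: specs.DiscreteArray):

  def policy(observations):
    network = hk.Sequential([
        hk.Flatten(),
        hk.nets.MLP([64, 64, action_spec.num_values]),
    ])
    return network(observations)

  return policy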
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)
  agent_networks = ppo.make_continuous_networks(environment_spec)

  # Construct the agent.
  ppo_config = ppo.PPOConfig(
      unroll_length=FLAGS.unroll_length,
      num_minibatches=FLAGS.ppo_num_minibatches,
      num_epochs=FLAGS.ppo_num_epochs,
      batch_size=FLAGS.transition_batch_size // FLAGS.unroll_length,
      learning_rate=0.0003,
      entropy_cost=0,
      gae_lambda=0.8,
      value_cost=0.25)
  ppo_networks = ppo.make_continuous_networks(environment_spec)
  if FLAGS.pretrain:
    ppo_networks = add_bc_pretraining(ppo_networks)

  discriminator_batch_size = FLAGS.transition_batch_size
  ail_config = ail.AILConfig(
      direct_rl_batch_size=ppo_config.batch_size * ppo_config.unroll_length,
      discriminator_batch_size=discriminator_batch_size,
      is_sequence_based=True,
      num_sgd_steps_per_step=FLAGS.num_discriminator_steps_per_step,
      share_iterator=FLAGS.share_iterator,
  )

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    # Note: observation embedding is not needed for e.g. Mujoco.
    return ail.DiscriminatorModule(
        environment_spec=environment_spec,
        use_action=True,
        use_next_obs=True,
        network_core=ail.DiscriminatorMLP([4, 4]),
    )(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  ail_network = ail.AILNetworks(
      ail.make_discriminator(environment_spec, discriminator_transformed),
      imitation_reward_fn=ail.rewards.gail_reward(),
      direct_rl_networks=ppo_networks)

  agent = ail.GAIL(
      spec=environment_spec,
      network=ail_network,
      config=ail.GAILConfig(ail_config, ppo_config),
      seed=FLAGS.seed,
      batch_size=ppo_config.batch_size,
      make_demonstrations=functools.partial(
          helpers.make_demonstration_iterator,
          dataset_name=FLAGS.dataset_name),
      policy_network=ppo.make_inference_fn(ppo_networks))

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=ppo.make_inference_fn(agent_networks, evaluation=True),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
def main(_):
  # Create an environment, grab the spec, and use it to create networks.
  environment = helpers.make_environment(task=FLAGS.env_name)
  environment_spec = specs.make_environment_spec(environment)

  # Construct the agent.
  # Local layout makes sure that we populate the buffer with min_replay_size
  # initial transitions and that there's no need for tolerance_rate. In order
  # for deadlocks not to happen we need to disable rate limiting that happens
  # inside the TD3Builder. This is achieved by the min_replay_size and
  # samples_per_insert_tolerance_rate arguments.
  td3_config = td3.TD3Config(
      num_sgd_steps_per_step=FLAGS.num_sgd_steps_per_step,
      min_replay_size=1,
      samples_per_insert_tolerance_rate=float('inf'))
  td3_networks = td3.make_networks(environment_spec)
  if FLAGS.pretrain:
    td3_networks = add_bc_pretraining(td3_networks)

  ail_config = ail.AILConfig(
      direct_rl_batch_size=td3_config.batch_size *
      td3_config.num_sgd_steps_per_step)
  dac_config = ail.DACConfig(ail_config, td3_config)

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    return ail.DiscriminatorModule(
        environment_spec=environment_spec,
        use_action=True,
        use_next_obs=True,
        network_core=ail.DiscriminatorMLP([4, 4]),
    )(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  ail_network = ail.AILNetworks(
      ail.make_discriminator(environment_spec, discriminator_transformed),
      imitation_reward_fn=ail.rewards.gail_reward(),
      direct_rl_networks=td3_networks)

  agent = ail.DAC(
      spec=environment_spec,
      network=ail_network,
      config=dac_config,
      seed=FLAGS.seed,
      batch_size=td3_config.batch_size * td3_config.num_sgd_steps_per_step,
      make_demonstrations=functools.partial(
          helpers.make_demonstration_iterator,
          dataset_name=FLAGS.dataset_name),
      policy_network=td3.get_default_behavior_policy(
          td3_networks,
          action_specs=environment_spec.actions,
          sigma=td3_config.sigma))

  # Create the environment loop used for training.
  train_logger = experiment_utils.make_experiment_logger(
      label='train', steps_key='train_steps')
  train_loop = acme.EnvironmentLoop(
      environment,
      agent,
      counter=counting.Counter(prefix='train'),
      logger=train_logger)

  # Create the evaluation actor and loop.
  # TODO(lukstafi): sigma=0 for eval?
  eval_logger = experiment_utils.make_experiment_logger(
      label='eval', steps_key='eval_steps')
  eval_actor = agent.builder.make_actor(
      random_key=jax.random.PRNGKey(FLAGS.seed),
      policy_network=td3.get_default_behavior_policy(
          td3_networks, action_specs=environment_spec.actions, sigma=0.),
      variable_source=agent)
  eval_env = helpers.make_environment(task=FLAGS.env_name)
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      counter=counting.Counter(prefix='eval'),
      logger=eval_logger)

  assert FLAGS.num_steps % FLAGS.eval_every == 0
  for _ in range(FLAGS.num_steps // FLAGS.eval_every):
    eval_loop.run(num_episodes=5)
    train_loop.run(num_steps=FLAGS.eval_every)
  eval_loop.run(num_episodes=5)
def run_offline_experiment(experiment: config.OfflineExperimentConfig,
                           eval_every: int = 100,
                           num_eval_episodes: int = 1):
  """Runs a simple, single-threaded training loop using the default evaluators.

  It targets simplicity of the code and so only the basic features of the
  OfflineExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many learner steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

  key = jax.random.PRNGKey(experiment.seed)

  # Create the environment and get its spec.
  environment = experiment.environment_factory(experiment.seed)
  environment_spec = experiment.environment_spec or specs.make_environment_spec(
      environment)

  # Create the networks and policy.
  networks = experiment.network_factory(environment_spec)

  # The parent counter allows sharing step counts between the train and eval
  # loops and the learner, so that it is possible to plot, for example, the
  # evaluator's return as a function of the number of training episodes.
  parent_counter = counting.Counter(time_delta=0.)

  # Create the demonstrations dataset.
  dataset_key, key = jax.random.split(key)
  dataset = experiment.demonstration_dataset_factory(dataset_key)

  # Create the learner.
  learner_key, key = jax.random.split(key)
  learner = experiment.builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=experiment.logger_factory,
      environment_spec=environment_spec,
      counter=counting.Counter(parent_counter, prefix='learner',
                               time_delta=0.))

  # Define the evaluation loop.
  eval_loop = None
  if num_eval_episodes > 0:
    # Create the evaluation actor and loop.
    eval_logger = experiment.logger_factory('eval', 'eval_steps', 0)
    eval_key, key = jax.random.split(key)
    eval_actor = experiment.builder.make_actor(
        random_key=eval_key,
        policy=experiment.builder.make_policy(networks, environment_spec,
                                              True),
        environment_spec=environment_spec,
        variable_source=learner)
    eval_loop = acme.EnvironmentLoop(
        environment,
        eval_actor,
        counter=counting.Counter(parent_counter, prefix='eval', time_delta=0.),
        logger=eval_logger,
        observers=experiment.observers)

  # Run the training loop.
  if eval_loop:
    eval_loop.run(num_eval_episodes)
  steps = 0
  while steps < experiment.max_num_learner_steps:
    learner_steps = min(eval_every, experiment.max_num_learner_steps - steps)
    for _ in range(learner_steps):
      learner.step()
    if eval_loop:
      eval_loop.run(num_eval_episodes)
    steps += learner_steps
import gym

import acme
from acme import specs
from acme import wrappers
from acme.utils import loggers
from acme.wrappers import gym_wrapper

from agents.dqn_agent import DQNAgent
from networks.models import Models

from tensorflow.python.client import device_lib

print(device_lib.list_local_devices())


def render(env):
  return env.environment.render(mode='rgb_array')


environment = gym_wrapper.GymWrapper(gym.make('LunarLander-v2'))
environment = wrappers.SinglePrecisionWrapper(environment)
environment_spec = specs.make_environment_spec(environment)

model = Models.sequential_model(
    input_shape=environment_spec.observations.shape,
    num_outputs=environment_spec.actions.num_values,
    hidden_layers=3,
    layer_size=300)

agent = DQNAgent(environment_spec=environment_spec, network=model)

logger = loggers.TerminalLogger(time_delta=10.)

loop = acme.EnvironmentLoop(environment=environment, actor=agent,
                            logger=logger)
loop.run()
def test_ail(self,
             algo,
             airl_discriminator=False,
             subtract_logpi=False,
             dropout=0.,
             lipschitz_coeff=None):
  shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
  batch_size = 8
  # Mujoco environment and associated demonstration dataset.
  if algo == 'ppo':
    environment = fakes.DiscreteEnvironment(
        num_actions=NUM_DISCRETE_ACTIONS,
        num_observations=NUM_OBSERVATIONS,
        obs_shape=OBS_SHAPE,
        obs_dtype=OBS_DTYPE,
        episode_length=EPISODE_LENGTH)
  else:
    environment = fakes.ContinuousEnvironment(
        episode_length=EPISODE_LENGTH,
        action_dim=CONTINUOUS_ACTION_DIM,
        observation_dim=CONTINUOUS_OBS_DIM,
        bounded=True)
  spec = specs.make_environment_spec(environment)

  if algo == 'sac':
    networks = sac.make_networks(spec=spec)
    config = sac.SACConfig(
        batch_size=batch_size,
        samples_per_insert_tolerance_rate=float('inf'),
        min_replay_size=1)
    base_builder = sac.SACBuilder(config=config)
    direct_rl_batch_size = batch_size
    behavior_policy = sac.apply_policy_and_sample(networks)
  elif algo == 'ppo':
    unroll_length = 5
    distribution_value_networks = make_ppo_networks(spec)
    networks = ppo.make_ppo_networks(distribution_value_networks)
    config = ppo.PPOConfig(
        unroll_length=unroll_length,
        num_minibatches=2,
        num_epochs=4,
        batch_size=batch_size)
    base_builder = ppo.PPOBuilder(config=config)
    direct_rl_batch_size = batch_size * unroll_length
    behavior_policy = jax.jit(ppo.make_inference_fn(networks), backend='cpu')
  else:
    raise ValueError(f'Unexpected algorithm {algo}')

  if subtract_logpi:
    assert algo == 'sac'
    logpi_fn = make_sac_logpi(networks)
  else:
    logpi_fn = None

  if algo == 'ppo':
    embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
  else:
    embedding = lambda x: x

  def discriminator(*args, **kwargs) -> networks_lib.Logits:
    if airl_discriminator:
      return ail.AIRLModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          discount=.99,
          g_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          h_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)
    else:
      return ail.DiscriminatorModule(
          environment_spec=spec,
          use_action=True,
          use_next_obs=True,
          network_core=ail.DiscriminatorMLP(
              [4, 4],
              hidden_dropout_rate=dropout,
              spectral_normalization_lipschitz_coeff=lipschitz_coeff),
          observation_embedding=embedding)(*args, **kwargs)

  discriminator_transformed = hk.without_apply_rng(
      hk.transform_with_state(discriminator))

  discriminator_network = ail.make_discriminator(
      environment_spec=spec,
      discriminator_transformed=discriminator_transformed,
      logpi_fn=logpi_fn)

  networks = ail.AILNetworks(discriminator_network, lambda x: x, networks)

  builder = ail.AILBuilder(
      base_builder,
      config=ail.AILConfig(
          is_sequence_based=(algo == 'ppo'),
          share_iterator=True,
          direct_rl_batch_size=direct_rl_batch_size,
          discriminator_batch_size=2,
          policy_variable_name='policy' if subtract_logpi else None,
          min_replay_size=1),
      discriminator_loss=ail.losses.gail_loss(),
      make_demonstrations=fakes.transition_iterator(environment))

  # Construct the agent.
  agent = local_layout.LocalLayout(
      seed=0,
      environment_spec=spec,
      builder=builder,
      networks=networks,
      policy_network=behavior_policy,
      min_replay_size=1,
      batch_size=batch_size)

  # Train the agent.
  train_loop = acme.EnvironmentLoop(environment, agent)
  train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))