Example #1
    def test_train(self, policy_loss_coeff_fn):
        seed = 0
        num_iterations = 5
        batch_size = 64
        grad_updates_per_batch = 1

        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True,
                                                  action_dim=6)
        spec = specs.make_environment_spec(environment)

        # Construct the learner.
        networks = crr.make_networks(spec,
                                     policy_layer_sizes=(8, 8),
                                     critic_layer_sizes=(8, 8))
        key = jax.random.PRNGKey(seed)
        dataset = fakes.transition_iterator(environment)
        learner = crr.CRRLearner(networks,
                                 key,
                                 discount=0.95,
                                 target_update_period=2,
                                 policy_loss_coeff_fn=policy_loss_coeff_fn,
                                 iterator=dataset(batch_size *
                                                  grad_updates_per_batch),
                                 policy_optimizer=optax.adam(1e-4),
                                 critic_optimizer=optax.adam(1e-4),
                                 grad_updates_per_batch=grad_updates_per_batch)

        # Train the learner.
        for _ in range(num_iterations):
            learner.step()
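The snippet above assumes the usual Acme test imports. A minimal set that would make it self-contained, assuming a recent Acme release that ships the JAX CRR agent (exact module paths may vary between versions):

import jax
import optax

from acme import specs
from acme.agents.jax import crr
from acme.testing import fakes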
Example #2
    def test_d4pg(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)
        spec = specs.make_environment_spec(environment)

        # Create the networks.
        networks = make_networks(spec)

        config = d4pg.D4PGConfig(
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10,
            samples_per_insert_tolerance_rate=float('inf'))
        counter = counting.Counter()
        agent = d4pg.D4PG(spec,
                          networks,
                          config=config,
                          random_seed=0,
                          counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
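The `make_networks` factory above is defined by the test module itself. The JAX D4PG agent ships a similar helper that could be used instead; a sketch assuming the `acme.agents.jax.d4pg` API (argument names and defaults may differ by version):

from acme.agents.jax import d4pg

# Small policy/critic MLPs and a narrow value-distribution support,
# sized for a smoke test rather than for learning anything useful.
networks = d4pg.make_networks(
    spec,
    policy_layer_sizes=(8, 8),
    critic_layer_sizes=(8, 8),
    vmin=-10.0,
    vmax=10.0)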
Example #3
    def test_mompo(self, distributional_critic):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Create objectives.
        reward_objectives, qvalue_objectives = make_objectives()
        num_critic_heads = len(reward_objectives)

        # Create networks.
        agent_networks = make_networks(
            spec.actions,
            num_critic_heads=num_critic_heads,
            distributional_critic=distributional_critic)

        # Construct the agent.
        agent = mompo.MultiObjectiveMPO(
            reward_objectives,
            qvalue_objectives,
            spec,
            policy_network=agent_networks['policy'],
            critic_network=agent_networks['critic'],
            batch_size=10,
            samples_per_insert=2,
            min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example #4
    def test_make_dataset_with_sequence_length_and_batch_size(self):
        sequence_length = 6
        batch_size = 4
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            batch_size=batch_size,
            sequence_length=sequence_length)

        def make_tensor_spec(spec):
            return tf.TensorSpec(
                shape=(batch_size, sequence_length) + spec.shape,
                dtype=spec.dtype)

        expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

        expected_spec = adders.Step(observation=expected_spec.observations,
                                    action=expected_spec.actions,
                                    reward=expected_spec.rewards,
                                    discount=expected_spec.discounts,
                                    start_of_episode=specs.Array(
                                        shape=(batch_size, sequence_length),
                                        dtype=bool),
                                    extras=())

        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
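The `_check_specs` helper used by this and the other dataset tests is not shown. One plausible implementation, sketched under the assumption that it only verifies structure plus per-leaf shape/dtype compatibility (the real test helper may differ):

import tensorflow as tf
import tree


def _check_specs(expected, actual) -> bool:
  """Returns True iff `actual` matches `expected` in structure, shape and dtype."""
  # Normalize expected leaves (dm_env Arrays or tf.TensorSpecs) to tf.TensorSpec.
  expected = tree.map_structure(
      lambda s: tf.TensorSpec(shape=s.shape, dtype=s.dtype), expected)
  try:
    # check_types=False lets e.g. a plain tuple match a namedtuple of specs.
    tree.assert_same_structure(expected, actual, check_types=False)
  except (TypeError, ValueError):
    return False
  return all(
      e.is_compatible_with(a)
      for e, a in zip(tree.flatten(expected), tree.flatten(actual)))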
Example #5
  def test_td3_fd(self):
    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, action_dim=3, observation_dim=5, bounded=True)
    spec = specs.make_environment_spec(environment)

    # Create the networks.
    td3_network = td3.make_networks(spec)

    batch_size = 10
    td3_config = td3.TD3Config(
        batch_size=batch_size,
        min_replay_size=1)
    lfd_config = lfd.LfdConfig(initial_insert_count=0,
                               demonstration_ratio=0.2)
    td3_fd_config = lfd.TD3fDConfig(lfd_config=lfd_config,
                                    td3_config=td3_config)
    counter = counting.Counter()
    agent = lfd.TD3fD(
        spec=spec,
        td3_network=td3_network,
        td3_fd_config=td3_fd_config,
        lfd_iterator_fn=fake_demonstration_iterator,
        seed=0,
        counter=counter)

    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent, counter=counter)
    loop.run(num_episodes=20)
Example #6
    def test_train(self):
        seed = 0
        num_iterations = 2
        batch_size = 64

        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  bounded=True,
                                                  action_dim=6)
        spec = specs.make_environment_spec(environment)

        # Construct the agent.
        networks = cql.make_networks(spec, hidden_layer_sizes=(8, 8))
        dataset = fakes.transition_iterator(environment)
        key = jax.random.PRNGKey(seed)
        learner = cql.CQLLearner(batch_size,
                                 networks,
                                 key,
                                 demonstrations=dataset(batch_size),
                                 policy_optimizer=optax.adam(3e-5),
                                 critic_optimizer=optax.adam(3e-4),
                                 cql_lagrange_threshold=1.,
                                 target_entropy=0.1,
                                 num_sgd_steps_per_step=1)

        # Train the agent
        for _ in range(num_iterations):
            learner.step()
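As in the CRR example, `fakes.transition_iterator(environment)` returns a factory that, given a batch size, yields batched `types.Transition` tuples built from episodes of the fake environment. A minimal sketch of inspecting one batch (imports assumed):

from acme.testing import fakes

environment = fakes.ContinuousEnvironment(
    episode_length=10, bounded=True, action_dim=6)
make_dataset = fakes.transition_iterator(environment)
batch = next(make_dataset(4))  # A types.Transition of batched arrays.
print(batch.observation.shape, batch.action.shape, batch.reward.shape)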
Example #7
    def test_value_dice(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10,
                                                  action_dim=3,
                                                  observation_dim=5,
                                                  bounded=True)

        spec = specs.make_environment_spec(environment)

        # Create the networks.
        network = value_dice.make_networks(spec)

        config = value_dice.ValueDiceConfig(batch_size=10, min_replay_size=1)
        counter = counting.Counter()
        agent = value_dice.ValueDice(
            spec=spec,
            network=network,
            config=config,
            make_demonstrations=fakes.transition_iterator(environment),
            seed=0,
            counter=counter)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        loop.run(num_episodes=2)
Example #8
    def test_make_dataset_simple(self):
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client, environment_spec=environment_spec)

        self.assertTrue(
            _check_specs(tuple(environment_spec), dataset.element_spec.data))
Example #9
    def test_make_dataset_transition_adder(self):
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            transition_adder=True)

        environment_spec = tuple(environment_spec) + (
            environment_spec.observations, )

        self.assertTrue(
            _check_specs(tuple(environment_spec), dataset.element_spec.data))
Example #10
    def test_make_dataset_simple(self):
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client, environment_spec=environment_spec)

        expected_spec = adders.Step(observation=environment_spec.observations,
                                    action=environment_spec.actions,
                                    reward=environment_spec.rewards,
                                    discount=environment_spec.discounts,
                                    start_of_episode=specs.Array(shape=(),
                                                                 dtype=bool),
                                    extras=())
        self.assertTrue(_check_specs(expected_spec, dataset.element_spec.data))
Example #11
    def test_make_dataset_with_batch_size(self):
        batch_size = 4
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            client=self.tf_client,
            environment_spec=environment_spec,
            batch_size=batch_size)

        def make_tensor_spec(spec):
            return tf.TensorSpec(shape=(None, ) + spec.shape, dtype=spec.dtype)

        expected_spec = tree.map_structure(make_tensor_spec, environment_spec)

        self.assertTrue(
            _check_specs(tuple(expected_spec), dataset.element_spec.data))
Example #12
    def test_make_dataset_transition_adder(self):
        environment = fakes.ContinuousEnvironment()
        environment_spec = specs.make_environment_spec(environment)
        dataset = reverb_dataset.make_dataset(
            server_address=self.server_address,
            environment_spec=environment_spec,
            transition_adder=True)

        environment_spec = types.Transition(
            observation=environment_spec.observations,
            action=environment_spec.actions,
            reward=environment_spec.rewards,
            discount=environment_spec.discounts,
            next_observation=environment_spec.observations,
            extras=())

        self.assertTrue(
            _check_specs(environment_spec, dataset.element_spec.data))
Example #13
    def test_continuous_actions(self, loss_name):
        with chex.fake_pmap_and_jit():
            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.ContinuousEnvironment(episode_length=10,
                                                      bounded=True,
                                                      action_dim=6)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec)

            if loss_name == 'logp':
                loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
            elif loss_name == 'mse':
                loss_fn = bc.mse(
                    sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
            elif loss_name == 'peerbc':
                base_loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
                loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
            else:
                raise ValueError(f'Unknown loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
Example #14
  def test_continuous(self):
    env = wrappers.SinglePrecisionWrapper(
        fakes.ContinuousEnvironment(
            action_dim=0, dtype=np.float64, reward_dtype=np.float64))

    self.assertTrue(np.issubdtype(env.observation_spec().dtype, np.float32))
    self.assertTrue(np.issubdtype(env.action_spec().dtype, np.float32))
    self.assertTrue(np.issubdtype(env.reward_spec().dtype, np.float32))
    self.assertTrue(np.issubdtype(env.discount_spec().dtype, np.float32))

    timestep = env.reset()
    self.assertEqual(timestep.reward, None)
    self.assertEqual(timestep.discount, None)
    self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32))

    timestep = env.step(0.0)
    self.assertTrue(np.issubdtype(timestep.reward.dtype, np.float32))
    self.assertTrue(np.issubdtype(timestep.discount.dtype, np.float32))
    self.assertTrue(np.issubdtype(timestep.observation.dtype, np.float32))
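For contrast, the unwrapped fake environment reports the double-precision specs it was constructed with; the wrapper is what casts them down to float32. A small sketch (imports assumed):

import numpy as np

from acme.testing import fakes

raw_env = fakes.ContinuousEnvironment(
    action_dim=0, dtype=np.float64, reward_dtype=np.float64)
assert np.issubdtype(raw_env.observation_spec().dtype, np.float64)
assert np.issubdtype(raw_env.reward_spec().dtype, np.float64)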
Example #15
    def test_dmpo(self):
        # Create a fake environment to test with.
        environment = fakes.ContinuousEnvironment(episode_length=10)
        spec = specs.make_environment_spec(environment)

        # Create networks.
        agent_networks = make_networks(spec.actions)

        # Construct the agent.
        agent = dmpo.DistributionalMPO(spec,
                                       policy_network=agent_networks['policy'],
                                       critic_network=agent_networks['critic'],
                                       batch_size=10,
                                       samples_per_insert=2,
                                       min_replay_size=10)

        # Try running the environment loop. We have no assertions here because all
        # we care about is that the agent runs without raising any errors.
        loop = acme.EnvironmentLoop(environment, agent)
        loop.run(num_episodes=2)
Example #16
    def test_control_suite(self):
        """Tests that the agent can run on the control suite without crashing."""

        agent = svg0_prior.DistributedSVG0(
            environment_factory=lambda x: fakes.ContinuousEnvironment(),
            network_factory=make_networks,
            num_actors=2,
            batch_size=32,
            min_replay_size=32,
            max_replay_size=1000,
        )
        program = agent.build()

        (learner_node, ) = program.groups['learner']
        learner_node.disable_run()

        lp.launch(program, launch_type='test_mt')

        learner: acme.Learner = learner_node.create_handle().dereference()

        for _ in range(5):
            learner.step()
Example #17
    def test_agent(self):

        agent = mpo.DistributedMPO(
            environment_factory=lambda x: fakes.ContinuousEnvironment(
                bounded=True),
            network_factory=make_networks,
            num_actors=2,
            batch_size=32,
            min_replay_size=32,
            max_replay_size=1000,
        )
        program = agent.build()

        (learner_node, ) = program.groups['learner']
        learner_node.disable_run()

        lp.launch(program, launch_type='test_mt')

        learner: acme.Learner = learner_node.create_handle().dereference()

        for _ in range(5):
            learner.step()
Example #18
  def test_sac(self):
    # Create a fake environment to test with.
    environment = fakes.ContinuousEnvironment(
        episode_length=10, action_dim=3, observation_dim=5, bounded=True)
    spec = specs.make_environment_spec(environment)

    # Create the networks.
    network = networks.make_networks(spec)

    batch_size = 10
    config = sac_config.SACConfig(
        batch_size=batch_size,
        target_entropy=sac_config.target_entropy_from_env_spec(spec),
        min_replay_size=1)
    counter = counting.Counter()
    agent = agents.SAC(
        spec=spec, network=network, config=config, seed=0, normalize_input=True,
        counter=counter)

    # Try running the environment loop. We have no assertions here because all
    # we care about is that the agent runs without raising any errors.
    loop = acme.EnvironmentLoop(environment, agent, counter=counter)
    loop.run(num_episodes=2)
Example #19
def main(_):
  environment = fakes.ContinuousEnvironment(action_dim=8,
                                            observation_dim=87,
                                            episode_length=10000000)
  spec = specs.make_environment_spec(environment)
  replay_tables = make_replay_tables(spec)
  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')
  adder = make_adder(replay_client)

  timestep = environment.reset()
  adder.add_first(timestep)
  # TODO(raveman): Consider also filling the table to say 1M (too slow).
  for steps in range(10000):
    if steps % 1000 == 0:
      logging.info('Processed %s steps', steps)
    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
    next_timestep = environment.step(action)
    adder.add(action, next_timestep, extras=())

  for batch_size in [256, 256 * 8, 256 * 64]:
    for prefetch_size in [0, 1, 4]:
      print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
      ds = datasets.make_reverb_dataset(
          table='default',
          server_address=replay_client.server_address,
          batch_size=batch_size,
          prefetch_size=prefetch_size,
      )
      it = ds.as_numpy_iterator()

      for iteration in range(3):
        t = time.time()
        for _ in range(1000):
          _ = next(it)
        print(f'Iteration {iteration} finished in {time.time() - t}s')
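`make_replay_tables` and `make_adder` are defined elsewhere in the benchmark script. A hypothetical sketch of what they might look like, assuming a single uniformly sampled table fed by an N-step transition adder (the real script's table and adder settings may differ):

import reverb

from acme.adders import reverb as adders_reverb


def make_replay_tables(environment_spec):
  return [
      reverb.Table(
          name='default',
          sampler=reverb.selectors.Uniform(),
          remover=reverb.selectors.Fifo(),
          max_size=100_000,
          rate_limiter=reverb.rate_limiters.MinSize(1),
          signature=adders_reverb.NStepTransitionAdder.signature(
              environment_spec))
  ]


def make_adder(replay_client):
  return adders_reverb.NStepTransitionAdder(
      replay_client, n_step=1, discount=1.0)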
Example #20
 def make_env(seed):
   del seed
   return fakes.ContinuousEnvironment(
       episode_length=10, action_dim=3, observation_dim=5, bounded=True)
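Such a seed-ignoring factory can be handed to anything that expects an `environment_factory`. A quick smoke test running one episode with uniform random actions (a sketch; imports assumed):

import numpy as np

from acme import specs

env = make_env(seed=0)
action_spec = specs.make_environment_spec(env).actions  # A BoundedArray.

timestep = env.reset()
while not timestep.last():
  action = np.random.uniform(
      action_spec.minimum, action_spec.maximum,
      size=action_spec.shape).astype(action_spec.dtype)
  timestep = env.step(action)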
Example #21
    def test_ail_flax(self):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Fake continuous environment (standing in for a MuJoCo task) and its
        # associated demonstration dataset.
        environment = fakes.ContinuousEnvironment(
            episode_length=EPISODE_LENGTH,
            action_dim=CONTINUOUS_ACTION_DIM,
            observation_dim=CONTINUOUS_OBS_DIM,
            bounded=True)
        spec = specs.make_environment_spec(environment)

        networks = sac.make_networks(spec=spec)
        config = sac.SACConfig(batch_size=batch_size,
                               samples_per_insert_tolerance_rate=float('inf'),
                               min_replay_size=1)
        base_builder = sac.SACBuilder(config=config)
        direct_rl_batch_size = batch_size
        behavior_policy = sac.apply_policy_and_sample(networks)

        discriminator_module = DiscriminatorModule(spec, linen.Dense(1))

        def apply_fn(params: networks_lib.Params,
                     policy_params: networks_lib.Params,
                     state: networks_lib.Params, transitions: types.Transition,
                     is_training: bool,
                     rng: networks_lib.PRNGKey) -> networks_lib.Logits:
            del policy_params
            variables = dict(params=params, **state)
            return discriminator_module.apply(variables,
                                              transitions.observation,
                                              transitions.action,
                                              transitions.next_observation,
                                              is_training=is_training,
                                              rng=rng,
                                              mutable=state.keys())

        def init_fn(rng):
            variables = discriminator_module.init(rng,
                                                  dummy_obs,
                                                  dummy_actions,
                                                  dummy_obs,
                                                  is_training=False,
                                                  rng=rng)
            init_state, discriminator_params = variables.pop('params')
            return discriminator_params, init_state

        dummy_obs = utils.zeros_like(spec.observations)
        dummy_obs = utils.add_batch_dim(dummy_obs)
        dummy_actions = utils.zeros_like(spec.actions)
        dummy_actions = utils.add_batch_dim(dummy_actions)
        discriminator_network = networks_lib.FeedForwardNetwork(init=init_fn,
                                                                apply=apply_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(is_sequence_based=False,
                                 share_iterator=True,
                                 direct_rl_batch_size=direct_rl_batch_size,
                                 discriminator_batch_size=2,
                                 policy_variable_name=None,
                                 min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        counter = counting.Counter()
        # Construct the agent.
        agent = local_layout.LocalLayout(
            seed=0,
            environment_spec=spec,
            builder=builder,
            networks=networks,
            policy_network=behavior_policy,
            min_replay_size=1,
            batch_size=batch_size,
            counter=counter,
        )

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent, counter=counter)
        train_loop.run(num_episodes=1)
Example #22
    def test_ail(self,
                 algo,
                 airl_discriminator=False,
                 subtract_logpi=False,
                 dropout=0.,
                 lipschitz_coeff=None):
        shutil.rmtree(flags.FLAGS.test_tmpdir, ignore_errors=True)
        batch_size = 8
        # Fake environment (standing in for a MuJoCo task) and its associated
        # demonstration dataset.
        if algo == 'ppo':
            environment = fakes.DiscreteEnvironment(
                num_actions=NUM_DISCRETE_ACTIONS,
                num_observations=NUM_OBSERVATIONS,
                obs_shape=OBS_SHAPE,
                obs_dtype=OBS_DTYPE,
                episode_length=EPISODE_LENGTH)
        else:
            environment = fakes.ContinuousEnvironment(
                episode_length=EPISODE_LENGTH,
                action_dim=CONTINUOUS_ACTION_DIM,
                observation_dim=CONTINUOUS_OBS_DIM,
                bounded=True)
        spec = specs.make_environment_spec(environment)

        if algo == 'sac':
            networks = sac.make_networks(spec=spec)
            config = sac.SACConfig(
                batch_size=batch_size,
                samples_per_insert_tolerance_rate=float('inf'),
                min_replay_size=1)
            base_builder = sac.SACBuilder(config=config)
            direct_rl_batch_size = batch_size
            behavior_policy = sac.apply_policy_and_sample(networks)
        elif algo == 'ppo':
            unroll_length = 5
            distribution_value_networks = make_ppo_networks(spec)
            networks = ppo.make_ppo_networks(distribution_value_networks)
            config = ppo.PPOConfig(unroll_length=unroll_length,
                                   num_minibatches=2,
                                   num_epochs=4,
                                   batch_size=batch_size)
            base_builder = ppo.PPOBuilder(config=config)
            direct_rl_batch_size = batch_size * unroll_length
            behavior_policy = jax.jit(ppo.make_inference_fn(networks),
                                      backend='cpu')
        else:
            raise ValueError(f'Unexpected algorithm {algo}')

        if subtract_logpi:
            assert algo == 'sac'
            logpi_fn = make_sac_logpi(networks)
        else:
            logpi_fn = None

        if algo == 'ppo':
            embedding = lambda x: jnp.reshape(x, list(x.shape[:-2]) + [-1])
        else:
            embedding = lambda x: x

        def discriminator(*args, **kwargs) -> networks_lib.Logits:
            if airl_discriminator:
                return ail.AIRLModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    discount=.99,
                    g_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    h_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)
            else:
                return ail.DiscriminatorModule(
                    environment_spec=spec,
                    use_action=True,
                    use_next_obs=True,
                    network_core=ail.DiscriminatorMLP(
                        [4, 4],
                        hidden_dropout_rate=dropout,
                        spectral_normalization_lipschitz_coeff=lipschitz_coeff
                    ),
                    observation_embedding=embedding)(*args, **kwargs)

        discriminator_transformed = hk.without_apply_rng(
            hk.transform_with_state(discriminator))

        discriminator_network = ail.make_discriminator(
            environment_spec=spec,
            discriminator_transformed=discriminator_transformed,
            logpi_fn=logpi_fn)

        networks = ail.AILNetworks(discriminator_network, lambda x: x,
                                   networks)

        builder = ail.AILBuilder(
            base_builder,
            config=ail.AILConfig(
                is_sequence_based=(algo == 'ppo'),
                share_iterator=True,
                direct_rl_batch_size=direct_rl_batch_size,
                discriminator_batch_size=2,
                policy_variable_name='policy' if subtract_logpi else None,
                min_replay_size=1),
            discriminator_loss=ail.losses.gail_loss(),
            make_demonstrations=fakes.transition_iterator(environment))

        # Construct the agent.
        agent = local_layout.LocalLayout(seed=0,
                                         environment_spec=spec,
                                         builder=builder,
                                         networks=networks,
                                         policy_network=behavior_policy,
                                         min_replay_size=1,
                                         batch_size=batch_size)

        # Train the agent.
        train_loop = acme.EnvironmentLoop(environment, agent)
        train_loop.run(num_episodes=(10 if algo == 'ppo' else 1))