Example #1
    def test_discrete_actions(self, loss_name):
        with chex.fake_pmap_and_jit():

            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.DiscreteEnvironment(num_actions=10,
                                                    num_observations=100,
                                                    obs_shape=(10, ),
                                                    obs_dtype=np.float32)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec, discrete_actions=True)

            def logp_fn(logits, actions):
                # Numerically stable log-probability of the chosen actions
                # under a categorical distribution defined by the logits.
                max_logits = jnp.max(logits, axis=-1, keepdims=True)
                logits = logits - max_logits
                logits_actions = jnp.sum(
                    jax.nn.one_hot(actions, spec.actions.num_values) * logits,
                    axis=-1)

                log_prob = logits_actions - special.logsumexp(logits, axis=-1)
                return log_prob

            if loss_name == 'logp':
                loss_fn = bc.logp(logp_fn=logp_fn)

            elif loss_name == 'rcal':
                base_loss_fn = bc.logp(logp_fn=logp_fn)
                loss_fn = bc.rcal(base_loss_fn, discount=0.99, alpha=0.1)

            else:
                raise ValueError(f'Unknown loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
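
The snippet above omits its imports and the local make_networks helper from the surrounding test module. A minimal import block that makes the remaining names resolve might look as follows; the Acme module paths are a best guess and may differ between library versions.

    from acme import specs
    from acme import types
    from acme.agents.jax import bc
    from acme.testing import fakes
    import chex
    import jax
    import jax.numpy as jnp
    import numpy as np
    import optax
    from scipy import special
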
Example #2
    def test_continuous_actions(self, loss_name):
        with chex.fake_pmap_and_jit():
            num_sgd_steps_per_step = 1
            num_steps = 5

            # Create a fake environment to test with.
            environment = fakes.ContinuousEnvironment(episode_length=10,
                                                      bounded=True,
                                                      action_dim=6)

            spec = specs.make_environment_spec(environment)
            dataset_demonstration = fakes.transition_dataset(environment)
            dataset_demonstration = dataset_demonstration.map(
                lambda sample: types.Transition(*sample.data))
            dataset_demonstration = dataset_demonstration.batch(
                8).as_numpy_iterator()

            # Construct the agent.
            network = make_networks(spec)

            if loss_name == 'logp':
                loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
            elif loss_name == 'mse':
                loss_fn = bc.mse(
                    sample_fn=lambda dist_params, key: dist_params.sample(seed=key))
            elif loss_name == 'peerbc':
                base_loss_fn = bc.logp(
                    logp_fn=lambda dist_params, actions: dist_params.log_prob(actions))
                loss_fn = bc.peerbc(base_loss_fn, zeta=0.1)
            else:
                raise ValueError(f'Unknown loss name: {loss_name}')

            learner = bc.BCLearner(
                network=network,
                random_key=jax.random.PRNGKey(0),
                loss_fn=loss_fn,
                optimizer=optax.adam(0.01),
                demonstrations=dataset_demonstration,
                num_sgd_steps_per_step=num_sgd_steps_per_step)

            # Train the agent
            for _ in range(num_steps):
                learner.step()
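
Both BC examples build their policy through a make_networks helper defined elsewhere in the test module. The sketch below is one plausible implementation, assuming Acme's Haiku utilities (acme.jax.networks, acme.jax.utils) and arbitrary hidden-layer sizes; it is not the library's own definition.

    import haiku as hk
    import numpy as np
    from acme.jax import networks as networks_lib
    from acme.jax import utils

    def make_networks(spec, discrete_actions=False):
        """Feed-forward policy for the fake environment (sketch)."""
        if discrete_actions:
            num_outputs = spec.actions.num_values
        else:
            num_outputs = np.prod(spec.actions.shape, dtype=int)

        def _actor_fn(obs, is_training=False, key=None):
            del is_training, key  # No dropout or noise in this sketch.
            if discrete_actions:
                # Unnormalized logits over the discrete action set.
                return hk.nets.MLP([64, 64, num_outputs])(obs)
            # Tanh-squashed Gaussian over bounded continuous actions.
            return hk.Sequential([
                networks_lib.LayerNormMLP([64, 64], activate_final=True),
                networks_lib.NormalTanhDistribution(num_outputs),
            ])(obs)

        policy = hk.without_apply_rng(hk.transform(_actor_fn))
        dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
        return networks_lib.FeedForwardNetwork(
            init=lambda key: policy.init(key, dummy_obs), apply=policy.apply)
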
Example #3
  def test_full_learner(self):
    # Create dataset.
    environment = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10)
    spec = specs.make_environment_spec(environment)
    dataset = fakes.transition_dataset(environment).batch(2)

    # Build network.
    g_network = _make_network(spec.actions)
    q_network = _make_network(spec.actions)
    network = discrete_networks.DiscreteFilteredQNetwork(g_network=g_network,
                                                         q_network=q_network,
                                                         threshold=0.5)
    tf2_utils.create_variables(network, [spec.observations])

    # Build learner.
    counter = counting.Counter()
    learner = bcq.DiscreteBCQLearner(
        network=network,
        dataset=dataset,
        learning_rate=1e-4,
        discount=0.99,
        importance_sampling_exponent=0.2,
        target_update_period=100,
        counter=counter)

    # Run a learner step.
    learner.step()

    # Check counts from BC and BCQ learners.
    counts = counter.get_counts()
    self.assertEqual(1, counts['bc_steps'])
    self.assertEqual(1, counts['bcq_steps'])

    # Check learner state.
    self.assertEqual(1, learner.state['bc_num_steps'].numpy())
    self.assertEqual(1, learner.state['bcq_num_steps'].numpy())
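
Example #3 uses Acme's TensorFlow/Sonnet stack. Its _make_network helper and imports are not part of the snippet; a minimal sketch, with assumed module paths and hidden-layer sizes, is shown below.

    import numpy as np
    import sonnet as snt
    from acme import specs
    from acme.agents.tf import bcq
    from acme.testing import fakes
    from acme.tf import networks as discrete_networks  # DiscreteFilteredQNetwork (assumed path)
    from acme.tf import utils as tf2_utils
    from acme.utils import counting

    def _make_network(action_spec: specs.DiscreteArray) -> snt.Module:
        """Small MLP producing one value per discrete action."""
        return snt.Sequential([
            snt.Flatten(),
            snt.nets.MLP([50, 50, action_spec.num_values]),
        ])
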
Example #4
  def test_full_learner(self):
    # Create dataset.
    environment = fakes.DiscreteEnvironment(
        num_actions=5,
        num_observations=10,
        obs_dtype=np.float32,
        episode_length=10)
    spec = specs.make_environment_spec(environment)
    dataset = fakes.transition_dataset(environment).batch(
        2, drop_remainder=True)

    # Build network.
    network = networks.IQNNetwork(
        torso=_make_torso_network(num_outputs=2),
        head=_make_head_network(num_outputs=spec.actions.num_values),
        latent_dim=2,
        num_quantile_samples=1)
    tf2_utils.create_variables(network, [spec.observations])

    # Build learner.
    counter = counting.Counter()
    learner = iqn.IQNLearner(
        network=network,
        target_network=copy.deepcopy(network),
        dataset=dataset,
        learning_rate=1e-4,
        discount=0.99,
        importance_sampling_exponent=0.2,
        target_update_period=1,
        counter=counter)

    # Run a learner step.
    learner.step()

    # Check counts from IQN learner.
    counts = counter.get_counts()
    self.assertEqual(1, counts['steps'])

    # Check learner state.
    self.assertEqual(1, learner.state['num_steps'].numpy())
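
Example #4's torso and head builders are likewise defined outside the snippet. A sketch consistent with the shapes above (the torso's output size matches latent_dim=2), using Sonnet modules and assumed layer sizes:

    import copy
    import sonnet as snt
    from acme.agents.tf import iqn   # provides IQNLearner (assumed path)
    from acme.tf import networks     # provides IQNNetwork (assumed path)

    def _make_torso_network(num_outputs: int) -> snt.Module:
        """Embeds flattened observations into a small latent vector."""
        return snt.Sequential([snt.Flatten(), snt.Linear(num_outputs)])

    def _make_head_network(num_outputs: int) -> snt.Module:
        """Maps the quantile-modulated latent to per-action values."""
        return snt.nets.MLP([50, num_outputs])
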