Example #1
    def __init__(self, spec: specs.EnvironmentSpec,
                 network: ail_networks.AILNetworks, config: GAILConfig, *args,
                 **kwargs):
        # GAIL specializes the AIL base agent with a PPO direct RL learner
        # and the standard GAIL cross-entropy discriminator loss.
        ppo_agent = ppo.PPOBuilder(config.ppo_config)
        kwargs['discriminator_loss'] = losses.gail_loss()
        super().__init__(spec, ppo_agent, network, config.ail_config, *args,
                         **kwargs)
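The pattern here is loss injection: the GAIL subclass only chooses a direct RL builder (PPO) and a discriminator loss, and the AIL base class does the rest. A minimal standalone sketch of those two ingredients, assuming the acme imports used in these snippets and that PPOConfig is constructible with all-default fields:

from acme.agents.jax import ppo
from acme.agents.jax.ail import losses

# The two pieces GAIL adds on top of the generic AIL agent.
ppo_builder = ppo.PPOBuilder(ppo.PPOConfig())  # direct RL learner
discriminator_loss = losses.gail_loss()        # GAIL cross-entropy loss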
Example #2
    def __init__(self, spec: specs.EnvironmentSpec,
                 network: ail_networks.AILNetworks, config: DACConfig, *args,
                 **kwargs):
        # DAC pairs a TD3 direct RL learner with the GAIL loss augmented by
        # an entropy bonus and a gradient penalty on the discriminator.
        td3_agent = td3.TD3Builder(config.td3_config)

        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        kwargs['discriminator_loss'] = dac_loss
        super().__init__(spec, td3_agent, network, config.ail_config, *args,
                         **kwargs)
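The loss composition is the part worth isolating: add_gradient_penalty wraps the base GAIL loss so the discriminator is additionally penalized when its input gradients deviate from the target norm of 1. A standalone sketch using only the calls shown above; the 1e-3 and 10. coefficients are illustrative assumptions, not acme defaults:

from acme.agents.jax.ail import losses

dac_loss = losses.add_gradient_penalty(
    losses.gail_loss(entropy_coefficient=1e-3),  # base loss + entropy bonus
    gradient_penalty_coefficient=10.,            # weight of the penalty term
    gradient_penalty_target=1.)                  # target gradient norm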
Example #3
File: dac.py Project: deepmind/acme
    def __init__(self, config: DACConfig,
                 make_demonstrations: Callable[[int],
                                               Iterator[types.Transition]]):
        # Same DAC loss composition as above, wired through the builder API;
        # the demonstration iterator factory is forwarded to the AIL builder.
        td3_builder = td3.TD3Builder(config.td3_config)
        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        super().__init__(td3_builder,
                         config=config.ail_config,
                         discriminator_loss=dac_loss,
                         make_demonstrations=make_demonstrations)
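make_demonstrations maps a batch size to an iterator of expert transitions. A hypothetical stub for smoke-testing the builder; the constant zero transition is made up, and real code would iterate over a demonstration dataset instead:

import itertools
from typing import Iterator

import numpy as np
from acme import types


def make_demonstrations(batch_size: int) -> Iterator[types.Transition]:
    # One batched dummy transition, repeated forever.
    batch = types.Transition(
        observation=np.zeros((batch_size, 1)),
        action=np.zeros((batch_size, 1)),
        reward=np.zeros((batch_size,)),
        discount=np.ones((batch_size,)),
        next_observation=np.zeros((batch_size, 1)))
    return itertools.repeat(batch)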
Example #4
    def __init__(self, environment_factory: jax_types.EnvironmentFactory,
                 config: GAILConfig, *args, **kwargs):
        # Distributed GAIL: build a logger for the direct RL learner, then
        # defer to the distributed AIL agent with a PPO builder.
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'direct_learner',
                                      kwargs['log_to_bigtable'],
                                      time_delta=kwargs['log_every'],
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')
        ppo_agent = ppo.PPOBuilder(config.ppo_config, logger_fn=logger_fn)
        kwargs['discriminator_loss'] = losses.gail_loss()
        super().__init__(environment_factory, ppo_agent, config.ail_config,
                         *args, **kwargs)
Example #5
    def test_step(self):
        simple_spec = specs.Array(shape=(), dtype=float)

        # Environment spec with scalar observations, actions, rewards and
        # discounts.
        spec = specs.EnvironmentSpec(simple_spec, simple_spec, simple_spec,
                                     simple_spec)

        discriminator = _make_discriminator(spec)
        ail_network = ail_networks.AILNetworks(discriminator,
                                               imitation_reward_fn=lambda x: x,
                                               direct_rl_networks=None)

        loss = losses.gail_loss()

        optimizer = optax.adam(.01)

        # Jit a single discriminator update with the network, optimizer and
        # loss closed over via functools.partial.
        step = jax.jit(
            functools.partial(ail_learning.ail_update_step,
                              optimizer=optimizer,
                              ail_network=ail_network,
                              loss_fn=loss))

        # Two batched transitions that differ only in the observation, so the
        # discriminator has something to separate.
        zero_transition = types.Transition(np.array([0.]), np.array([0.]), 0.,
                                           0., np.array([0.]))
        zero_transition = utils.add_batch_dim(zero_transition)

        one_transition = types.Transition(np.array([1.]), np.array([0.]), 0.,
                                          0., np.array([0.]))
        one_transition = utils.add_batch_dim(one_transition)

        key = jax.random.PRNGKey(0)
        discriminator_params, discriminator_state = discriminator.init(key)

        state = ail_learning.DiscriminatorTrainingState(
            optimizer_state=optimizer.init(discriminator_params),
            discriminator_params=discriminator_params,
            discriminator_state=discriminator_state,
            policy_params=None,
            key=key,
            steps=0,
        )

        expected_loss = [1.062, 1.057, 1.052]

        # Run three updates, threading the training state through, and check
        # the loss against the golden values.
        for i in range(3):
            state, loss = step(state, (one_transition, zero_transition))
            self.assertAlmostEqual(loss['total_loss'],
                                   expected_loss[i],
                                   places=3)
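Note that the test seeds all randomness with jax.random.PRNGKey(0), so the three total_loss values are deterministic; the hard-coded expected_loss sequence, which decreases across the updates, only holds under that fixed key.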
Example #6
    def __init__(self, environment_factory: jax_types.EnvironmentFactory,
                 config: DACConfig, *args, **kwargs):
        # Distributed DAC: same logger setup as the distributed GAIL agent
        # above, but with a TD3 builder and the gradient-penalized GAIL loss.
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'direct_learner',
                                      kwargs['log_to_bigtable'],
                                      time_delta=kwargs['log_every'],
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')
        td3_agent = td3.TD3Builder(config.td3_config, logger_fn=logger_fn)

        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        kwargs['discriminator_loss'] = dac_loss
        super().__init__(environment_factory, td3_agent, config.ail_config,
                         *args, **kwargs)