def __init__(self, spec: specs.EnvironmentSpec,
             network: ail_networks.AILNetworks, config: GAILConfig, *args,
             **kwargs):
  """Creates a GAIL agent.

  Uses PPO as the direct RL learner and plugs the plain GAIL discriminator
  loss into the AIL base class.
  """
  direct_rl_builder = ppo.PPOBuilder(config.ppo_config)
  # The AIL parent consumes the discriminator loss via kwargs.
  kwargs['discriminator_loss'] = losses.gail_loss()
  super().__init__(spec, direct_rl_builder, network, config.ail_config,
                   *args, **kwargs)
def __init__(self, spec: specs.EnvironmentSpec,
             network: ail_networks.AILNetworks, config: DACConfig, *args,
             **kwargs):
  """Creates a DAC agent.

  Uses TD3 as the direct RL learner; the discriminator loss is the GAIL
  loss (with entropy bonus) wrapped with a gradient penalty toward 1.
  """
  direct_rl_builder = td3.TD3Builder(config.td3_config)
  base_loss = losses.gail_loss(
      entropy_coefficient=config.entropy_coefficient)
  kwargs['discriminator_loss'] = losses.add_gradient_penalty(
      base_loss,
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  super().__init__(spec, direct_rl_builder, network, config.ail_config,
                   *args, **kwargs)
def __init__(self, config: DACConfig,
             make_demonstrations: Callable[[int],
                                           Iterator[types.Transition]]):
  """Creates the DAC builder.

  Wraps a TD3 builder and hands the gradient-penalized GAIL loss plus the
  demonstration iterator factory to the AIL base builder.
  """
  inner_builder = td3.TD3Builder(config.td3_config)
  penalized_loss = losses.add_gradient_penalty(
      losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  super().__init__(
      inner_builder,
      config=config.ail_config,
      discriminator_loss=penalized_loss,
      make_demonstrations=make_demonstrations)
def __init__(self, environment_factory: jax_types.EnvironmentFactory,
             config: GAILConfig, *args, **kwargs):
  """Creates a distributed GAIL agent.

  Builds a dedicated logger for the PPO direct learner and plugs the plain
  GAIL discriminator loss into the distributed AIL base class.

  NOTE(review): 'log_to_bigtable' and 'log_every' are required keys in
  kwargs; they are read here and still forwarded untouched to the parent.
  """
  direct_logger_fn = functools.partial(
      loggers.make_default_logger,
      'direct_learner',
      kwargs['log_to_bigtable'],
      time_delta=kwargs['log_every'],
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')
  direct_rl_builder = ppo.PPOBuilder(
      config.ppo_config, logger_fn=direct_logger_fn)
  kwargs['discriminator_loss'] = losses.gail_loss()
  super().__init__(environment_factory, direct_rl_builder, config.ail_config,
                   *args, **kwargs)
def test_step(self):
  """Runs three jitted discriminator update steps and checks golden losses."""
  scalar_spec = specs.Array(shape=(), dtype=float)
  env_spec = specs.EnvironmentSpec(scalar_spec, scalar_spec, scalar_spec,
                                   scalar_spec)
  discriminator = _make_discriminator(env_spec)
  ail_network = ail_networks.AILNetworks(
      discriminator, imitation_reward_fn=lambda x: x, direct_rl_networks=None)

  optimizer = optax.adam(.01)
  update_fn = jax.jit(
      functools.partial(
          ail_learning.ail_update_step,
          optimizer=optimizer,
          ail_network=ail_network,
          loss_fn=losses.gail_loss()))

  def _batched_transition(observation_value):
    # Single-element transition, then add a leading batch dimension.
    transition = types.Transition(
        np.array([observation_value]), np.array([0.]), 0., 0., np.array([0.]))
    return utils.add_batch_dim(transition)

  zero_transition = _batched_transition(0.)
  one_transition = _batched_transition(1.)

  key = jax.random.PRNGKey(0)
  discriminator_params, discriminator_state = discriminator.init(key)
  state = ail_learning.DiscriminatorTrainingState(
      optimizer_state=optimizer.init(discriminator_params),
      discriminator_params=discriminator_params,
      discriminator_state=discriminator_state,
      policy_params=None,
      key=key,
      steps=0,
  )

  # Golden values for the first three updates on this fixed seed.
  for expected in [1.062, 1.057, 1.052]:
    state, metrics = update_fn(state, (one_transition, zero_transition))
    self.assertAlmostEqual(metrics['total_loss'], expected, places=3)
def __init__(self, environment_factory: jax_types.EnvironmentFactory,
             config: DACConfig, *args, **kwargs):
  """Creates a distributed DAC agent.

  Builds a dedicated logger for the TD3 direct learner and plugs the
  gradient-penalized GAIL loss into the distributed AIL base class.

  NOTE(review): 'log_to_bigtable' and 'log_every' are required keys in
  kwargs; they are read here and still forwarded untouched to the parent.
  """
  direct_logger_fn = functools.partial(
      loggers.make_default_logger,
      'direct_learner',
      kwargs['log_to_bigtable'],
      time_delta=kwargs['log_every'],
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')
  direct_rl_builder = td3.TD3Builder(
      config.td3_config, logger_fn=direct_logger_fn)
  base_loss = losses.gail_loss(
      entropy_coefficient=config.entropy_coefficient)
  kwargs['discriminator_loss'] = losses.add_gradient_penalty(
      base_loss,
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  super().__init__(environment_factory, direct_rl_builder, config.ail_config,
                   *args, **kwargs)