def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    td3_fd_config: TD3fDConfig,
    lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
    seed: int,
    num_actors: int,
    environment_spec: specs.EnvironmentSpec,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = functools.partial(
      loggers.make_default_logger,
      'learner',
      log_to_bigtable,
      time_delta=log_every,
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')

  td3_config = td3_fd_config.td3_config
  lfd_config = td3_fd_config.lfd_config
  td3_builder = td3.TD3Builder(td3_config, logger_fn=logger_fn)
  lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn, lfd_config)

  action_specs = environment_spec.actions
  policy_network_fn = functools.partial(
      td3.get_default_behavior_policy,
      action_specs=action_specs,
      sigma=td3_config.sigma)

  if evaluator_factories is None:
    eval_network_fn = functools.partial(
        td3.get_default_behavior_policy,
        action_specs=action_specs,
        sigma=0.)
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_network_fn,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=lfd_builder,
      policy_network=policy_network_fn,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=td3_config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
  )
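# Usage sketch (not part of the original file): assuming this __init__ belongs
# to a distributed_layout.DistributedLayout subclass here called
# DistributedTD3fD, and that the factories, spec, config and demonstrations
# iterator below are constructed elsewhere (all names are hypothetical), the
# agent could be turned into a Launchpad program and launched roughly like so:
#
#   import launchpad as lp
#
#   agent = DistributedTD3fD(
#       environment_factory=make_environment,
#       network_factory=td3.make_networks,
#       td3_fd_config=td3_fd_config,
#       lfd_iterator_fn=make_demonstration_iterator,
#       seed=0,
#       num_actors=4,
#       environment_spec=environment_spec)
#   program = agent.build()
#   lp.launch(program)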
def __init__(self,
             spec: specs.EnvironmentSpec,
             network: ail_networks.AILNetworks,
             config: DACConfig,
             *args,
             **kwargs):
  td3_agent = td3.TD3Builder(config.td3_config)
  dac_loss = losses.add_gradient_penalty(
      losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  kwargs['discriminator_loss'] = dac_loss
  super().__init__(spec, td3_agent, network, config.ail_config, *args,
                   **kwargs)
def __init__(self, config: DACConfig,
             make_demonstrations: Callable[[int], Iterator[types.Transition]]):
  td3_builder = td3.TD3Builder(config.td3_config)
  dac_loss = losses.add_gradient_penalty(
      losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  super().__init__(
      td3_builder,
      config=config.ail_config,
      discriminator_loss=dac_loss,
      make_demonstrations=make_demonstrations)
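# Illustrative sketch (not part of the original file): a demonstrations
# factory matching the Callable[[int], Iterator[types.Transition]] signature
# expected above. The builder class name (DACBuilder) and the in-memory
# `demonstration_transitions` list of types.Transition are assumptions; any
# source of expert transitions would do.
#
#   import itertools
#
#   def make_demonstrations(batch_size: int) -> Iterator[types.Transition]:
#     # Cycle through stored expert transitions; a real implementation would
#     # typically yield transitions batched to `batch_size`.
#     return itertools.cycle(demonstration_transitions)
#
#   dac_builder = DACBuilder(
#       config=dac_config, make_demonstrations=make_demonstrations)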
def __init__(self, environment_factory: jax_types.EnvironmentFactory,
             config: DACConfig, *args, **kwargs):
  logger_fn = functools.partial(
      loggers.make_default_logger,
      'direct_learner',
      kwargs['log_to_bigtable'],
      time_delta=kwargs['log_every'],
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')
  td3_agent = td3.TD3Builder(config.td3_config, logger_fn=logger_fn)
  dac_loss = losses.add_gradient_penalty(
      losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
      gradient_penalty_coefficient=config.gradient_penalty_coefficient,
      gradient_penalty_target=1.)
  kwargs['discriminator_loss'] = dac_loss
  super().__init__(environment_factory, td3_agent, config.ail_config, *args,
                   **kwargs)
def __init__(self,
             spec: specs.EnvironmentSpec,
             td3_network: td3.TD3Networks,
             td3_fd_config: TD3fDConfig,
             lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
             seed: int,
             counter: Optional[counting.Counter] = None):
  """New instance of a TD3fD agent."""
  td3_config = td3_fd_config.td3_config
  lfd_config = td3_fd_config.lfd_config
  td3_builder = td3.TD3Builder(td3_config)
  lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn, lfd_config)

  min_replay_size = td3_config.min_replay_size
  # The local layout (actually agent.Agent) already ensures that the replay
  # buffer is populated with min_replay_size initial transitions, so no
  # tolerance rate is needed. To avoid deadlocks, we also disable the rate
  # limiting that happens inside the TD3Builder. The following two lines
  # achieve both.
  td3_config.samples_per_insert_tolerance_rate = float('inf')
  td3_config.min_replay_size = 1

  behavior_policy = td3.get_default_behavior_policy(
      networks=td3_network, action_specs=spec.actions, sigma=td3_config.sigma)

  self.builder = lfd_builder
  super().__init__(
      seed=seed,
      environment_spec=spec,
      builder=lfd_builder,
      networks=td3_network,
      policy_network=behavior_policy,
      batch_size=td3_config.batch_size,
      prefetch_size=td3_config.prefetch_size,
      samples_per_insert=td3_config.samples_per_insert,
      min_replay_size=min_replay_size,
      num_sgd_steps_per_step=td3_config.num_sgd_steps_per_step,
      counter=counter,
  )
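# Usage sketch (not part of the original file): assuming this __init__ belongs
# to a local-layout agent class here called TD3fD, and that `environment`,
# `td3_networks`, `td3_fd_config` and `make_demonstration_iterator` exist
# elsewhere (all hypothetical names), the agent can be driven by the standard
# Acme environment loop since it implements the actor/learner interface:
#
#   import acme
#   from acme import specs
#
#   spec = specs.make_environment_spec(environment)
#   agent = TD3fD(
#       spec=spec,
#       td3_network=td3_networks,
#       td3_fd_config=td3_fd_config,
#       lfd_iterator_fn=make_demonstration_iterator,
#       seed=0)
#   loop = acme.EnvironmentLoop(environment, agent)
#   loop.run(num_episodes=100)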
def build_experiment_config():
  """Builds TD3 experiment config which can be executed in different ways."""
  # Create an environment, grab the spec, and use it to create networks.
  suite, task = FLAGS.env_name.split(':', 1)

  network_factory = (
      lambda spec: td3.make_networks(spec, hidden_layer_sizes=(256, 256, 256)))

  # Construct the agent.
  config = td3.TD3Config(
      policy_learning_rate=3e-4,
      critic_learning_rate=3e-4,
  )
  td3_builder = td3.TD3Builder(config)

  # pylint:disable=g-long-lambda
  return experiments.ExperimentConfig(
      builder=td3_builder,
      environment_factory=lambda seed: helpers.make_environment(suite, task),
      network_factory=network_factory,
      seed=FLAGS.seed,
      max_num_actor_steps=FLAGS.num_steps)
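# One possible way to run the config above (a sketch, not part of the original
# file): Acme's JAX experiment runner accepts an ExperimentConfig directly.
# The FLAGS.eval_every flag and the evaluation settings are illustrative.
#
#   from acme.jax import experiments
#
#   def main(_):
#     experiment_config = build_experiment_config()
#     experiments.run_experiment(
#         experiment=experiment_config,
#         eval_every=FLAGS.eval_every,
#         num_eval_episodes=10)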
def __init__(self, td3_fd_config: TD3fDConfig,
             lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]):
  td3_builder = td3.TD3Builder(td3_fd_config.td3_config)
  super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config)