Example #1
    def __init__(
        self,
        environment_factory: jax_types.EnvironmentFactory,
        network_factory: NetworkFactory,
        td3_fd_config: TD3fDConfig,
        lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
        seed: int,
        num_actors: int,
        environment_spec: specs.EnvironmentSpec,
        max_number_of_steps: Optional[int] = None,
        log_to_bigtable: bool = False,
        log_every: float = 10.0,
        evaluator_factories: Optional[Sequence[
            distributed_layout.EvaluatorFactory]] = None,
    ):
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')

        td3_config = td3_fd_config.td3_config
        lfd_config = td3_fd_config.lfd_config
        td3_builder = td3.TD3Builder(td3_config, logger_fn=logger_fn)
        lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn,
                                         lfd_config)

        action_specs = environment_spec.actions
        policy_network_fn = functools.partial(td3.get_default_behavior_policy,
                                              action_specs=action_specs,
                                              sigma=td3_config.sigma)

        if evaluator_factories is None:
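            # Default to a single evaluator whose policy acts without
            # exploration noise (sigma=0.).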
            eval_network_fn = functools.partial(
                td3.get_default_behavior_policy,
                action_specs=action_specs,
                sigma=0.)
            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=environment_factory,
                    network_factory=network_factory,
                    policy_factory=eval_network_fn,
                    log_to_bigtable=log_to_bigtable)
            ]
        super().__init__(
            seed=seed,
            environment_factory=environment_factory,
            network_factory=network_factory,
            builder=lfd_builder,
            policy_network=policy_network_fn,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            max_number_of_steps=max_number_of_steps,
            prefetch_size=td3_config.prefetch_size,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every),
        )
Example #2
    def __init__(self, spec: specs.EnvironmentSpec,
                 network: ail_networks.AILNetworks, config: DACConfig, *args,
                 **kwargs):
        td3_agent = td3.TD3Builder(config.td3_config)

        # The DAC discriminator loss: the GAIL loss augmented with a gradient
        # penalty that pushes the discriminator's gradient norm towards 1.
        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        kwargs['discriminator_loss'] = dac_loss
        super().__init__(spec, td3_agent, network, config.ail_config, *args,
                         **kwargs)
Example #3
File: dac.py Project: deepmind/acme
    def __init__(self, config: DACConfig,
                 make_demonstrations: Callable[[int],
                                               Iterator[types.Transition]]):

        td3_builder = td3.TD3Builder(config.td3_config)
        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        super().__init__(td3_builder,
                         config=config.ail_config,
                         discriminator_loss=dac_loss,
                         make_demonstrations=make_demonstrations)
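The make_demonstrations argument is a callable that takes a batch size and returns an iterator over demonstration transitions. Below is a minimal sketch of such a factory built from an in-memory list of types.Transition objects; the demo_transitions name is hypothetical, and whether the iterator should yield stacked batches or single transitions may differ between Acme versions.

import functools
import itertools
from typing import Iterator, List

import numpy as np
import tree

from acme import types


def _demonstrations_from_list(demo_transitions: List[types.Transition],
                              batch_size: int) -> Iterator[types.Transition]:
    """Cycles over recorded transitions, yielding stacked batches forever."""
    iterator = itertools.cycle(demo_transitions)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        # Stack the per-field arrays of batch_size transitions into one batch.
        yield tree.map_structure(lambda *xs: np.stack(xs), *batch)


# Usage (demo_transitions is a hypothetical pre-loaded list of transitions):
#   make_demonstrations = functools.partial(_demonstrations_from_list,
#                                           demo_transitions)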
Example #4
    def __init__(self, environment_factory: jax_types.EnvironmentFactory,
                 config: DACConfig, *args, **kwargs):
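        # Note: 'log_to_bigtable' and 'log_every' are expected to be present in
        # **kwargs, since they are read below when building the learner logger.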
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'direct_learner',
                                      kwargs['log_to_bigtable'],
                                      time_delta=kwargs['log_every'],
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')
        td3_agent = td3.TD3Builder(config.td3_config, logger_fn=logger_fn)

        dac_loss = losses.add_gradient_penalty(
            losses.gail_loss(entropy_coefficient=config.entropy_coefficient),
            gradient_penalty_coefficient=config.gradient_penalty_coefficient,
            gradient_penalty_target=1.)
        kwargs['discriminator_loss'] = dac_loss
        super().__init__(environment_factory, td3_agent, config.ail_config,
                         *args, **kwargs)
Example #5
    def __init__(self,
                 spec: specs.EnvironmentSpec,
                 td3_network: td3.TD3Networks,
                 td3_fd_config: TD3fDConfig,
                 lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
                 seed: int,
                 counter: Optional[counting.Counter] = None):
        """New instance of a TD3fD agent."""
        td3_config = td3_fd_config.td3_config
        lfd_config = td3_fd_config.lfd_config
        td3_builder = td3.TD3Builder(td3_config)
        lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn,
                                         lfd_config)

        min_replay_size = td3_config.min_replay_size
        # The local layout (actually agent.Agent) makes sure that the buffer is
        # populated with min_replay_size initial transitions, so no tolerance
        # rate is needed. To avoid deadlocks, rate limiting inside the
        # TD3Builder must be disabled, which the following two lines achieve.
        td3_config.samples_per_insert_tolerance_rate = float('inf')
        td3_config.min_replay_size = 1

        behavior_policy = td3.get_default_behavior_policy(
            networks=td3_network,
            action_specs=spec.actions,
            sigma=td3_config.sigma)

        self.builder = lfd_builder
        super().__init__(
            seed=seed,
            environment_spec=spec,
            builder=lfd_builder,
            networks=td3_network,
            policy_network=behavior_policy,
            batch_size=td3_config.batch_size,
            prefetch_size=td3_config.prefetch_size,
            samples_per_insert=td3_config.samples_per_insert,
            min_replay_size=min_replay_size,
            num_sgd_steps_per_step=td3_config.num_sgd_steps_per_step,
            counter=counter,
        )
Example #6
def build_experiment_config():
    """Builds TD3 experiment config which can be executed in different ways."""
    # Create an environment, grab the spec, and use it to create networks.

    suite, task = FLAGS.env_name.split(':', 1)
    network_factory = (lambda spec: td3.make_networks(
        spec, hidden_layer_sizes=(256, 256, 256)))

    # Construct the agent.
    config = td3.TD3Config(
        policy_learning_rate=3e-4,
        critic_learning_rate=3e-4,
    )
    td3_builder = td3.TD3Builder(config)
    # pylint:disable=g-long-lambda
    return experiments.ExperimentConfig(
        builder=td3_builder,
        environment_factory=lambda seed: helpers.make_environment(suite, task),
        network_factory=network_factory,
        seed=FLAGS.seed,
        max_num_actor_steps=FLAGS.num_steps)
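For reference, the ExperimentConfig returned above is typically handed to Acme's experiment runner. A sketch of an entry point, assuming the run_experiment helper from acme.jax.experiments (its exact arguments may vary between Acme versions):

from absl import app

from acme.jax import experiments


def main(_):
    experiment_config = build_experiment_config()
    # Runs the actor/learner loop in a single process; a distributed run would
    # go through experiments.make_distributed_experiment instead.
    experiments.run_experiment(experiment=experiment_config, eval_every=10_000)


if __name__ == '__main__':
    app.run(main)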
Example #7
    def __init__(self, td3_fd_config: TD3fDConfig,
                 lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]]):
        td3_builder = td3.TD3Builder(td3_fd_config.td3_config)
        super().__init__(td3_builder, lfd_iterator_fn, td3_fd_config.lfd_config)