Code Example #1
    def __init__(
        self,
        environment_factory: jax_types.EnvironmentFactory,
        network_factory: NetworkFactory,
        td3_fd_config: TD3fDConfig,
        lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
        seed: int,
        num_actors: int,
        environment_spec: specs.EnvironmentSpec,
        max_number_of_steps: Optional[int] = None,
        log_to_bigtable: bool = False,
        log_every: float = 10.0,
        evaluator_factories: Optional[Sequence[
            distributed_layout.EvaluatorFactory]] = None,
    ):
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')

        td3_config = td3_fd_config.td3_config
        lfd_config = td3_fd_config.lfd_config
        td3_builder = td3.TD3Builder(td3_config, logger_fn=logger_fn)
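        # Wrap the TD3 builder with the LfD builder so that demonstration data
        # is mixed with the actor's experience during training.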
        lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn,
                                         lfd_config)

        action_specs = environment_spec.actions
        policy_network_fn = functools.partial(td3.get_default_behavior_policy,
                                              action_specs=action_specs,
                                              sigma=td3_config.sigma)

        if evaluator_factories is None:
            eval_network_fn = functools.partial(
                td3.get_default_behavior_policy,
                action_specs=action_specs,
                sigma=0.)
            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=environment_factory,
                    network_factory=network_factory,
                    policy_factory=eval_network_fn,
                    log_to_bigtable=log_to_bigtable)
            ]
        super().__init__(
            seed=seed,
            environment_factory=environment_factory,
            network_factory=network_factory,
            builder=lfd_builder,
            policy_network=policy_network_fn,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            max_number_of_steps=max_number_of_steps,
            prefetch_size=td3_config.prefetch_size,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every),
        )
Code Example #2
 def __init__(
     self,
     environment_factory: jax_types.EnvironmentFactory,
     network_factory: NetworkFactory,
     config: ppo_config.PPOConfig,
     seed: int,
     num_actors: int,
     normalize_input: bool = False,
     logger_fn: Optional[Callable[[], loggers.Logger]] = None,
     save_reverb_logs: bool = False,
     log_every: float = 10.0,
     max_number_of_steps: Optional[int] = None,
     evaluator_factories: Optional[Sequence[
         distributed_layout.EvaluatorFactory]] = None,
 ):
     logger_fn = logger_fn or functools.partial(
         loggers.make_default_logger,
         'learner',
         save_reverb_logs,
         time_delta=log_every,
         asynchronous=True,
         serialize_fn=utils.fetch_devicearray,
         steps_key='learner_steps')
     ppo_builder = builder.PPOBuilder(config, logger_fn=logger_fn)
     if normalize_input:
         dummy_seed = 1
         environment_spec = specs.make_environment_spec(
             environment_factory(dummy_seed))
         # Two batch dimensions: [num_sequences, num_steps, ...]
         batch_dims = (0, 1)
         ppo_builder = normalization.NormalizationBuilder(
             ppo_builder,
             environment_spec,
             is_sequence_based=True,
             batch_dims=batch_dims)
     if evaluator_factories is None:
      eval_policy_factory = (
          lambda networks: ppo_networks.make_inference_fn(networks, True))
         evaluator_factories = [
             distributed_layout.default_evaluator_factory(
                 environment_factory=environment_factory,
                 network_factory=network_factory,
                 policy_factory=eval_policy_factory,
                 log_to_bigtable=save_reverb_logs)
         ]
     super().__init__(
         seed=seed,
         environment_factory=environment_factory,
         network_factory=network_factory,
         builder=ppo_builder,
         policy_network=ppo_networks.make_inference_fn,
         evaluator_factories=evaluator_factories,
         num_actors=num_actors,
         prefetch_size=config.prefetch_size,
         max_number_of_steps=max_number_of_steps,
         log_to_bigtable=save_reverb_logs,
         actor_logger_fn=distributed_layout.get_default_logger_fn(
             save_reverb_logs, log_every),
     )
Code Example #3
File: agents.py  Project: vishalbelsare/acme
 def __init__(
     self,
     environment_factory: jax_types.EnvironmentFactory,
     network_factory: NetworkFactory,
     config: sac_config.SACConfig,
     seed: int,
     num_actors: int,
     max_number_of_steps: Optional[int] = None,
     log_to_bigtable: bool = False,
     log_every: float = 10.0,
     normalize_input: bool = True,
     evaluator_factories: Optional[Sequence[
         distributed_layout.EvaluatorFactory]] = None,
 ):
     logger_fn = functools.partial(loggers.make_default_logger,
                                   'learner',
                                   log_to_bigtable,
                                   time_delta=log_every,
                                   asynchronous=True,
                                   serialize_fn=utils.fetch_devicearray,
                                   steps_key='learner_steps')
     sac_builder = builder.SACBuilder(config, logger_fn=logger_fn)
     if normalize_input:
         dummy_seed = 1
         environment_spec = specs.make_environment_spec(
             environment_factory(dummy_seed))
         # One batch dimension: [batch_size, ...]
         batch_dims = (0, )
         sac_builder = normalization.NormalizationBuilder(
             sac_builder,
             environment_spec,
             is_sequence_based=False,
             batch_dims=batch_dims)
     if evaluator_factories is None:
         eval_policy_factory = (
             lambda n: networks.apply_policy_and_sample(n, True))
         evaluator_factories = [
             distributed_layout.default_evaluator_factory(
                 environment_factory=environment_factory,
                 network_factory=network_factory,
                 policy_factory=eval_policy_factory,
                 log_to_bigtable=log_to_bigtable)
         ]
     super().__init__(
         seed=seed,
         environment_factory=environment_factory,
         network_factory=network_factory,
         builder=sac_builder,
         policy_network=networks.apply_policy_and_sample,
         evaluator_factories=evaluator_factories,
         num_actors=num_actors,
         max_number_of_steps=max_number_of_steps,
         prefetch_size=config.prefetch_size,
         log_to_bigtable=log_to_bigtable,
         actor_logger_fn=distributed_layout.get_default_logger_fn(
             log_to_bigtable, log_every),
         checkpointing_config=distributed_layout.CheckpointingConfig(),
         make_snapshot_models=networks.default_models_to_snapshot,
     )
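
The constructors above only describe how the distributed agent is assembled; a separate step builds and launches the Launchpad program. The following is a minimal launch sketch, not taken from the examples: it assumes this SAC constructor is exposed as sac.DistributedSAC, that the DistributedLayout base class provides a build() method returning an lp.Program, and that make_environment and make_networks are user-supplied factories.

import launchpad as lp
from acme.agents.jax import sac

def make_environment(seed: int):
    # Hypothetical factory returning a dm_env.Environment for the task.
    ...

def make_networks(spec):
    # Hypothetical factory returning SAC networks for the environment spec.
    ...

agent = sac.DistributedSAC(
    environment_factory=make_environment,
    network_factory=make_networks,
    config=sac.SACConfig(),
    seed=0,
    num_actors=4,
)

# build() assembles the learner, actor, replay and evaluator nodes into an
# lp.Program; lp.launch() then starts the corresponding processes.
program = agent.build()
lp.launch(program)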
Code Example #4
 def __init__(
     self,
     environment_factory: jax_types.EnvironmentFactory,
     rl_agent: builders.GenericActorLearnerBuilder,
     config: ail_config.AILConfig,
     network_factory: NetworkFactory,
     seed: int,
     batch_size: int,
     make_demonstrations: Callable[[int], Iterator[types.Transition]],
     policy_network: Any,
     evaluator_policy_network: Any,
     num_actors: int,
     max_number_of_steps: Optional[int] = None,
     log_to_bigtable: bool = False,
     log_every: float = 10.0,
     prefetch_size: int = 4,
     discriminator_loss: Optional[losses.Loss] = None,
     evaluator_factories: Optional[Sequence[
         distributed_layout.EvaluatorFactory]] = None,
 ):
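    # A discriminator loss must be provided even though the argument defaults to None.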
   assert discriminator_loss is not None
   logger_fn = functools.partial(
       loggers.make_default_logger,
       'learner',
       log_to_bigtable,
       time_delta=log_every,
       asynchronous=True,
       serialize_fn=utils.fetch_devicearray,
       steps_key='learner_steps')
   ail_builder = builder.AILBuilder(
       rl_agent=rl_agent,
       config=config,
       discriminator_loss=discriminator_loss,
       make_demonstrations=make_demonstrations,
       logger_fn=logger_fn)
   if evaluator_factories is None:
     evaluator_factories = [
         distributed_layout.default_evaluator_factory(
             environment_factory=environment_factory,
             network_factory=network_factory,
             policy_factory=evaluator_policy_network,
             log_to_bigtable=log_to_bigtable)
     ]
   super().__init__(
       seed=seed,
       environment_factory=environment_factory,
       network_factory=network_factory,
       builder=ail_builder,
       policy_network=policy_network,
       evaluator_factories=evaluator_factories,
       num_actors=num_actors,
       max_number_of_steps=max_number_of_steps,
       prefetch_size=prefetch_size,
       log_to_bigtable=log_to_bigtable,
       actor_logger_fn=distributed_layout.get_default_logger_fn(
           log_to_bigtable, log_every),
   )
Code Example #5
 def __init__(
     self,
     environment_factory: jax_types.EnvironmentFactory,
     environment_spec: specs.EnvironmentSpec,
     network_factory: NetworkFactory,
     config: r2d2_config.R2D2Config,
     seed: int,
     num_actors: int,
     workdir: str = '~/acme',
     device_prefetch: bool = False,
     log_to_bigtable: bool = True,
     log_every: float = 10.0,
     evaluator_factories: Optional[Sequence[
         distributed_layout.EvaluatorFactory]] = None,
     max_number_of_steps: Optional[int] = None,
 ):
     logger_fn = functools.partial(loggers.make_default_logger,
                                   'learner',
                                   log_to_bigtable,
                                   time_delta=log_every,
                                   asynchronous=True,
                                   serialize_fn=utils.fetch_devicearray,
                                   steps_key='learner_steps')
     r2d2_builder = builder.R2D2Builder(
         networks=network_factory(environment_spec),
         config=config,
         logger_fn=logger_fn)
     policy_network_factory = (
         lambda n: r2d2_networks.make_behavior_policy(n, config))
     if evaluator_factories is None:
         evaluator_policy_network_factory = (
             lambda n: r2d2_networks.make_behavior_policy(n, config, True))
         evaluator_factories = [
             distributed_layout.default_evaluator_factory(
                 environment_factory=environment_factory,
                 network_factory=network_factory,
                 policy_factory=evaluator_policy_network_factory,
                 log_to_bigtable=log_to_bigtable)
         ]
     super().__init__(
         seed=seed,
         environment_factory=environment_factory,
         network_factory=network_factory,
         builder=r2d2_builder,
         policy_network=policy_network_factory,
         evaluator_factories=evaluator_factories,
         num_actors=num_actors,
         environment_spec=environment_spec,
         device_prefetch=device_prefetch,
         log_to_bigtable=log_to_bigtable,
         actor_logger_fn=distributed_layout.get_default_logger_fn(
             log_to_bigtable, log_every),
         prefetch_size=config.prefetch_size,
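          # Checkpoints go under workdir; a unique id is appended only when the
          # default '~/acme' directory is used.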
         checkpointing_config=distributed_layout.CheckpointingConfig(
             directory=workdir, add_uid=(workdir == '~/acme')),
         max_number_of_steps=max_number_of_steps)
Code Example #6
File: sacfd_agents.py  Project: vishalbelsare/acme
    def __init__(
        self,
        environment_factory: jax_types.EnvironmentFactory,
        network_factory: NetworkFactory,
        sac_fd_config: SACfDConfig,
        lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
        seed: int,
        num_actors: int,
        environment_spec: Optional[specs.EnvironmentSpec] = None,
        max_number_of_steps: Optional[int] = None,
        log_to_bigtable: bool = False,
        log_every: float = 10.0,
        evaluator_factories: Optional[Sequence[
            distributed_layout.EvaluatorFactory]] = None,
    ):
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')

        sac_config = sac_fd_config.sac_config
        lfd_config = sac_fd_config.lfd_config
        sac_builder = sac.SACBuilder(sac_config, logger_fn=logger_fn)
        lfd_builder = builder.LfdBuilder(sac_builder, lfd_iterator_fn,
                                         lfd_config)

        if evaluator_factories is None:
            eval_policy_factory = (
                lambda n: sac.apply_policy_and_sample(n, True))
            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=environment_factory,
                    network_factory=network_factory,
                    policy_factory=eval_policy_factory,
                    log_to_bigtable=log_to_bigtable)
            ]

        super().__init__(
            seed=seed,
            environment_factory=environment_factory,
            network_factory=network_factory,
            environment_spec=environment_spec,
            builder=lfd_builder,
            policy_network=sac.apply_policy_and_sample,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            max_number_of_steps=max_number_of_steps,
            prefetch_size=sac_config.prefetch_size,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every))
Code Example #7
 def __init__(
     self,
     environment_factory: jax_types.EnvironmentFactory,
     network_factory: NetworkFactory,
     config: value_dice_config.ValueDiceConfig,
     make_demonstrations: Callable[[int], Iterator[types.Transition]],
     seed: int,
     num_actors: int,
     max_number_of_steps: Optional[int] = None,
     log_to_bigtable: bool = False,
     log_every: float = 10.0,
     evaluator_factories: Optional[Sequence[
         distributed_layout.EvaluatorFactory]] = None,
 ):
     logger_fn = functools.partial(loggers.make_default_logger,
                                   'learner',
                                   log_to_bigtable,
                                   time_delta=log_every,
                                   asynchronous=True,
                                   serialize_fn=utils.fetch_devicearray,
                                   steps_key='learner_steps')
     dummy_seed = 1
     spec = specs.make_environment_spec(environment_factory(dummy_seed))
     value_dice_builder = builder.ValueDiceBuilder(
         config=config,
         logger_fn=logger_fn,
         make_demonstrations=make_demonstrations)
     if evaluator_factories is None:
         eval_policy_factory = (
             lambda n: networks.apply_policy_and_sample(n, True))
         evaluator_factories = [
             distributed_layout.default_evaluator_factory(
                 environment_factory=environment_factory,
                 network_factory=network_factory,
                 policy_factory=eval_policy_factory,
                 log_to_bigtable=log_to_bigtable)
         ]
     super().__init__(
         seed=seed,
         environment_spec=spec,
         environment_factory=environment_factory,
         network_factory=network_factory,
         builder=value_dice_builder,
         policy_network=networks.apply_policy_and_sample,
         evaluator_factories=evaluator_factories,
         num_actors=num_actors,
         max_number_of_steps=max_number_of_steps,
         prefetch_size=config.prefetch_size,
         log_to_bigtable=log_to_bigtable,
         actor_logger_fn=distributed_layout.get_default_logger_fn(
             log_to_bigtable, log_every),
     )
Code Example #8
    def __init__(
        self,
        environment_factory,
        network_factory,
        config,
        seed,
        num_actors,
        max_number_of_steps=None,
        log_to_bigtable=False,
        log_every=10.0,
        evaluator_factories=None,
    ):
        # Check that the environment-specific parts of the config have been set.
        assert config.max_episode_steps > 0
        assert config.obs_dim > 0

        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')
        contrastive_builder = builder.ContrastiveBuilder(config,
                                                         logger_fn=logger_fn)
        if evaluator_factories is None:
            eval_policy_factory = (
                lambda n: networks.apply_policy_and_sample(n, True))
            eval_observers = [
                contrastive_utils.SuccessObserver(),
                contrastive_utils.DistanceObserver(
                    obs_dim=config.obs_dim,
                    start_index=config.start_index,
                    end_index=config.end_index)
            ]
            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=environment_factory,
                    network_factory=network_factory,
                    policy_factory=eval_policy_factory,
                    log_to_bigtable=log_to_bigtable,
                    observers=eval_observers)
            ]
            if config.local:
                evaluator_factories = []
        actor_observers = [
            contrastive_utils.SuccessObserver(),
            contrastive_utils.DistanceObserver(obs_dim=config.obs_dim,
                                               start_index=config.start_index,
                                               end_index=config.end_index)
        ]
        super().__init__(
            seed=seed,
            environment_factory=environment_factory,
            network_factory=network_factory,
            builder=contrastive_builder,
            policy_network=networks.apply_policy_and_sample,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            max_number_of_steps=max_number_of_steps,
            prefetch_size=config.prefetch_size,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every),
            observers=actor_observers,
            checkpointing_config=distributed_layout.CheckpointingConfig())
Code Example #9
File: agents.py  Project: vishalbelsare/acme
    def __init__(
        self,
        environment_factory: Callable[[bool], dm_env.Environment],
        network_factory: NetworkFactory,
        random_seed: int,
        num_actors: int,
        environment_spec: Optional[specs.EnvironmentSpec] = None,
        batch_size: int = 256,
        prefetch_size: int = 2,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        samples_per_insert: float = 32.0,
        n_step: int = 5,
        sigma: float = 0.3,
        clipping: bool = True,
        discount: float = 0.99,
        target_update_period: int = 100,
        device_prefetch: bool = True,
        log_to_bigtable: bool = False,
        log_every: float = 10.0,
        evaluator_factories: Optional[Sequence[
            distributed_layout.EvaluatorFactory]] = None,
    ):
        config = d4pg_config.D4PGConfig(
            discount=discount,
            learning_rate=1e-4,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            target_update_period=target_update_period,
            min_replay_size=min_replay_size,
            max_replay_size=max_replay_size,
            samples_per_insert=samples_per_insert,
            n_step=n_step,
            sigma=sigma,
            clipping=clipping,
        )
        logger_fn = functools.partial(loggers.make_default_logger,
                                      'learner',
                                      log_to_bigtable,
                                      time_delta=log_every,
                                      asynchronous=True,
                                      serialize_fn=utils.fetch_devicearray,
                                      steps_key='learner_steps')
        builder = d4pg_builder.D4PGBuilder(config, logger_fn=logger_fn)

        def _policy_network(networks):
            return d4pg_networks.get_default_behavior_policy(networks,
                                                             config=config)

        if evaluator_factories is None:

            def _eval_policy_network(networks):
                return d4pg_networks.get_default_eval_policy(networks)

            evaluator_factories = [
                distributed_layout.default_evaluator_factory(
                    environment_factory=lambda seed: environment_factory(True),
                    network_factory=network_factory,
                    policy_factory=_eval_policy_network,
                    log_to_bigtable=log_to_bigtable)
            ]

        super().__init__(
            seed=random_seed,
            environment_factory=lambda seed: environment_factory(False),
            network_factory=network_factory,
            builder=builder,
            policy_network=_policy_network,
            evaluator_factories=evaluator_factories,
            num_actors=num_actors,
            environment_spec=environment_spec,
            device_prefetch=device_prefetch,
            log_to_bigtable=log_to_bigtable,
            actor_logger_fn=distributed_layout.get_default_logger_fn(
                log_to_bigtable, log_every),
            prefetch_size=config.prefetch_size,
        )
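
Note that, unlike the other examples, this constructor takes an environment_factory keyed on a bool (True selects the evaluation environment) and wraps it in seed-discarding lambdas before passing it to the base layout. Below is a minimal sketch of such a factory; the body is left as a stub since the concrete environment is application-specific.

import dm_env

def environment_factory(evaluation: bool) -> dm_env.Environment:
    # Hypothetical factory: return the evaluation variant of the task when
    # `evaluation` is True, otherwise the training variant. The constructor
    # above wraps this in lambdas that discard the seed argument expected by
    # the base layout.
    ...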