def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    td3_fd_config: TD3fDConfig,
    lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
    seed: int,
    num_actors: int,
    environment_spec: specs.EnvironmentSpec,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  td3_config = td3_fd_config.td3_config
  lfd_config = td3_fd_config.lfd_config
  td3_builder = td3.TD3Builder(td3_config, logger_fn=logger_fn)
  # Wrap the TD3 builder so demonstrations from lfd_iterator_fn are mixed
  # into the replay data (learning from demonstrations).
  lfd_builder = builder.LfdBuilder(td3_builder, lfd_iterator_fn, lfd_config)

  action_specs = environment_spec.actions
  # Actors explore with noise scale sigma; evaluation uses sigma=0.
  policy_network_fn = functools.partial(td3.get_default_behavior_policy,
                                        action_specs=action_specs,
                                        sigma=td3_config.sigma)

  if evaluator_factories is None:
    eval_network_fn = functools.partial(td3.get_default_behavior_policy,
                                        action_specs=action_specs,
                                        sigma=0.)
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_network_fn,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=lfd_builder,
      policy_network=policy_network_fn,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=td3_config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
  )
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    config: ppo_config.PPOConfig,
    seed: int,
    num_actors: int,
    normalize_input: bool = False,
    logger_fn: Optional[Callable[[], loggers.Logger]] = None,
    save_reverb_logs: bool = False,
    log_every: float = 10.0,
    max_number_of_steps: Optional[int] = None,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = logger_fn or functools.partial(
      loggers.make_default_logger,
      'learner', save_reverb_logs,
      time_delta=log_every, asynchronous=True,
      serialize_fn=utils.fetch_devicearray, steps_key='learner_steps')
  ppo_builder = builder.PPOBuilder(config, logger_fn=logger_fn)

  if normalize_input:
    dummy_seed = 1
    environment_spec = specs.make_environment_spec(
        environment_factory(dummy_seed))
    # Two batch dimensions: [num_sequences, num_steps, ...]
    batch_dims = (0, 1)
    ppo_builder = normalization.NormalizationBuilder(
        ppo_builder,
        environment_spec,
        is_sequence_based=True,
        batch_dims=batch_dims)

  if evaluator_factories is None:
    eval_policy_factory = (
        lambda networks: ppo_networks.make_inference_fn(networks, True))
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=save_reverb_logs)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=ppo_builder,
      policy_network=ppo_networks.make_inference_fn,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      prefetch_size=config.prefetch_size,
      max_number_of_steps=max_number_of_steps,
      log_to_bigtable=save_reverb_logs,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          save_reverb_logs, log_every),
  )
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    config: sac_config.SACConfig,
    seed: int,
    num_actors: int,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    normalize_input: bool = True,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  sac_builder = builder.SACBuilder(config, logger_fn=logger_fn)

  if normalize_input:
    dummy_seed = 1
    environment_spec = specs.make_environment_spec(
        environment_factory(dummy_seed))
    # One batch dimension: [batch_size, ...]
    batch_dims = (0,)
    sac_builder = normalization.NormalizationBuilder(
        sac_builder,
        environment_spec,
        is_sequence_based=False,
        batch_dims=batch_dims)

  if evaluator_factories is None:
    eval_policy_factory = (
        lambda n: networks.apply_policy_and_sample(n, True))
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=sac_builder,
      policy_network=networks.apply_policy_and_sample,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
      checkpointing_config=distributed_layout.CheckpointingConfig(),
      make_snapshot_models=networks.default_models_to_snapshot,
  )
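# Usage sketch (not part of the original source): how a distributed agent
# built on distributed_layout, such as the SAC variant above, is typically
# instantiated and launched. The class name DistributedSAC and the
# make_environment helper are assumptions; only the constructor arguments
# shown above are taken from the source.
#
# import launchpad as lp
# from acme.agents.jax import sac
#
# agent = DistributedSAC(
#     environment_factory=lambda seed: make_environment(seed),
#     network_factory=sac.make_networks,
#     config=sac.SACConfig(),
#     seed=0,
#     num_actors=4)
# program = agent.build()  # Launchpad program with replay, learner, actors.
# lp.launch(program)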
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    rl_agent: builders.GenericActorLearnerBuilder,
    config: ail_config.AILConfig,
    network_factory: NetworkFactory,
    seed: int,
    batch_size: int,
    make_demonstrations: Callable[[int], Iterator[types.Transition]],
    policy_network: Any,
    evaluator_policy_network: Any,
    num_actors: int,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    prefetch_size: int = 4,
    discriminator_loss: Optional[losses.Loss] = None,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  # A discriminator loss must be provided despite the Optional annotation.
  assert discriminator_loss is not None
  logger_fn = functools.partial(
      loggers.make_default_logger,
      'learner', log_to_bigtable,
      time_delta=log_every, asynchronous=True,
      serialize_fn=utils.fetch_devicearray, steps_key='learner_steps')
  ail_builder = builder.AILBuilder(
      rl_agent=rl_agent,
      config=config,
      discriminator_loss=discriminator_loss,
      make_demonstrations=make_demonstrations,
      logger_fn=logger_fn)

  if evaluator_factories is None:
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=evaluator_policy_network,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=ail_builder,
      policy_network=policy_network,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
  )
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    environment_spec: specs.EnvironmentSpec,
    network_factory: NetworkFactory,
    config: r2d2_config.R2D2Config,
    seed: int,
    num_actors: int,
    workdir: str = '~/acme',
    device_prefetch: bool = False,
    log_to_bigtable: bool = True,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
    max_number_of_steps: Optional[int] = None,
):
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  r2d2_builder = builder.R2D2Builder(
      networks=network_factory(environment_spec),
      config=config,
      logger_fn=logger_fn)

  policy_network_factory = (
      lambda n: r2d2_networks.make_behavior_policy(n, config))

  if evaluator_factories is None:
    evaluator_policy_network_factory = (
        lambda n: r2d2_networks.make_behavior_policy(n, config, True))
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=evaluator_policy_network_factory,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=r2d2_builder,
      policy_network=policy_network_factory,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      environment_spec=environment_spec,
      device_prefetch=device_prefetch,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
      prefetch_size=config.prefetch_size,
      checkpointing_config=distributed_layout.CheckpointingConfig(
          directory=workdir, add_uid=(workdir == '~/acme')),
      max_number_of_steps=max_number_of_steps)
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    sac_fd_config: SACfDConfig,
    lfd_iterator_fn: Callable[[], Iterator[builder.LfdStep]],
    seed: int,
    num_actors: int,
    environment_spec: Optional[specs.EnvironmentSpec] = None,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  sac_config = sac_fd_config.sac_config
  lfd_config = sac_fd_config.lfd_config
  sac_builder = sac.SACBuilder(sac_config, logger_fn=logger_fn)
  # Wrap the SAC builder so demonstrations from lfd_iterator_fn are mixed
  # into the replay data (learning from demonstrations).
  lfd_builder = builder.LfdBuilder(sac_builder, lfd_iterator_fn, lfd_config)

  if evaluator_factories is None:
    eval_policy_factory = (
        lambda n: sac.apply_policy_and_sample(n, True))
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      environment_spec=environment_spec,
      builder=lfd_builder,
      policy_network=sac.apply_policy_and_sample,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=sac_config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every))
def __init__(
    self,
    environment_factory: jax_types.EnvironmentFactory,
    network_factory: NetworkFactory,
    config: value_dice_config.ValueDiceConfig,
    make_demonstrations: Callable[[int], Iterator[types.Transition]],
    seed: int,
    num_actors: int,
    max_number_of_steps: Optional[int] = None,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  dummy_seed = 1
  spec = specs.make_environment_spec(environment_factory(dummy_seed))
  value_dice_builder = builder.ValueDiceBuilder(
      config=config,
      logger_fn=logger_fn,
      make_demonstrations=make_demonstrations)

  if evaluator_factories is None:
    eval_policy_factory = (
        lambda n: networks.apply_policy_and_sample(n, True))
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=seed,
      environment_spec=spec,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=value_dice_builder,
      policy_network=networks.apply_policy_and_sample,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
  )
def __init__(
    self,
    environment_factory,
    network_factory,
    config,
    seed,
    num_actors,
    max_number_of_steps=None,
    log_to_bigtable=False,
    log_every=10.0,
    evaluator_factories=None,
):
  # Check that the environment-specific parts of the config have been set.
  assert config.max_episode_steps > 0
  assert config.obs_dim > 0

  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  contrastive_builder = builder.ContrastiveBuilder(config, logger_fn=logger_fn)

  if evaluator_factories is None:
    eval_policy_factory = (
        lambda n: networks.apply_policy_and_sample(n, True))
    eval_observers = [
        contrastive_utils.SuccessObserver(),
        contrastive_utils.DistanceObserver(
            obs_dim=config.obs_dim,
            start_index=config.start_index,
            end_index=config.end_index)
    ]
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=environment_factory,
            network_factory=network_factory,
            policy_factory=eval_policy_factory,
            log_to_bigtable=log_to_bigtable,
            observers=eval_observers)
    ]
    if config.local:
      evaluator_factories = []

  actor_observers = [
      contrastive_utils.SuccessObserver(),
      contrastive_utils.DistanceObserver(
          obs_dim=config.obs_dim,
          start_index=config.start_index,
          end_index=config.end_index)
  ]

  super().__init__(
      seed=seed,
      environment_factory=environment_factory,
      network_factory=network_factory,
      builder=contrastive_builder,
      policy_network=networks.apply_policy_and_sample,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      max_number_of_steps=max_number_of_steps,
      prefetch_size=config.prefetch_size,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
      observers=actor_observers,
      checkpointing_config=distributed_layout.CheckpointingConfig())
def __init__(
    self,
    environment_factory: Callable[[bool], dm_env.Environment],
    network_factory: NetworkFactory,
    random_seed: int,
    num_actors: int,
    environment_spec: Optional[specs.EnvironmentSpec] = None,
    batch_size: int = 256,
    prefetch_size: int = 2,
    min_replay_size: int = 1000,
    max_replay_size: int = 1000000,
    samples_per_insert: float = 32.0,
    n_step: int = 5,
    sigma: float = 0.3,
    clipping: bool = True,
    discount: float = 0.99,
    target_update_period: int = 100,
    device_prefetch: bool = True,
    log_to_bigtable: bool = False,
    log_every: float = 10.0,
    evaluator_factories: Optional[Sequence[
        distributed_layout.EvaluatorFactory]] = None,
):
  # Bundle the flat keyword arguments into a D4PG config.
  config = d4pg_config.D4PGConfig(
      discount=discount,
      learning_rate=1e-4,
      batch_size=batch_size,
      prefetch_size=prefetch_size,
      target_update_period=target_update_period,
      min_replay_size=min_replay_size,
      max_replay_size=max_replay_size,
      samples_per_insert=samples_per_insert,
      n_step=n_step,
      sigma=sigma,
      clipping=clipping,
  )
  logger_fn = functools.partial(loggers.make_default_logger,
                                'learner', log_to_bigtable,
                                time_delta=log_every, asynchronous=True,
                                serialize_fn=utils.fetch_devicearray,
                                steps_key='learner_steps')
  builder = d4pg_builder.D4PGBuilder(config, logger_fn=logger_fn)

  def _policy_network(networks):
    return d4pg_networks.get_default_behavior_policy(networks, config=config)

  if evaluator_factories is None:

    def _eval_policy_network(networks):
      return d4pg_networks.get_default_eval_policy(networks)

    # The environment factory's boolean flag selects the evaluation variant
    # of the environment; evaluators use it, actors use the training variant.
    evaluator_factories = [
        distributed_layout.default_evaluator_factory(
            environment_factory=lambda seed: environment_factory(True),
            network_factory=network_factory,
            policy_factory=_eval_policy_network,
            log_to_bigtable=log_to_bigtable)
    ]

  super().__init__(
      seed=random_seed,
      environment_factory=lambda seed: environment_factory(False),
      network_factory=network_factory,
      builder=builder,
      policy_network=_policy_network,
      evaluator_factories=evaluator_factories,
      num_actors=num_actors,
      environment_spec=environment_spec,
      device_prefetch=device_prefetch,
      log_to_bigtable=log_to_bigtable,
      actor_logger_fn=distributed_layout.get_default_logger_fn(
          log_to_bigtable, log_every),
      prefetch_size=config.prefetch_size,
  )
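# Sketch (assumption, not from the original source): unlike the other agents
# above, the D4PG constructor expects an environment factory keyed on a
# boolean rather than a seed; the flag selects the evaluation variant of the
# environment. A minimal factory might look as follows, where
# make_control_environment is a hypothetical user helper.
#
# def my_environment_factory(evaluation: bool) -> dm_env.Environment:
#   return make_control_environment(evaluation=evaluation)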