            total_timesteps += len(batch)
            # Add NEXT_OBS if not available. This is slightly hacked
            # as for the very last time step, we will use next-obs=zeros
            # and therefore force-set DONE=True to avoid this missing
            # next-obs to cause learning problems.
            if SampleBatch.NEXT_OBS not in batch:
                obs = batch[SampleBatch.OBS]
                batch[SampleBatch.NEXT_OBS] = \
                    np.concatenate([obs[1:], np.zeros_like(obs[0:1])])
                batch[SampleBatch.DONES][-1] = True
            replay_buffer.add_batch(batch)
        print(
            f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
            f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
    else:
        raise ValueError(
            "Unknown offline input! config['input'] must either be list of "
            "offline files (json) or a D4RL-specific InputReader specifier "
            "(e.g. 'd4rl.hopper-medium-v0').")


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTFPolicy,
    get_policy_class=get_policy_class,
    after_init=after_init,
    execution_plan=execution_plan,
)
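
# Hedged usage sketch (not part of the original module): assuming d4rl is
# installed and using the "input" convention from the error message above,
# CQLTrainer can be run directly on a D4RL dataset. The env name, dataset
# name, and iteration count are illustrative placeholders only.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = CQLTrainer(config={
        "env": "Hopper-v2",  # Assumption: only used for spaces/evaluation.
        "input": "d4rl.hopper-medium-v0",
        "framework": "torch",
    })
    for _ in range(5):
        print(trainer.train())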
# Assumed import path for the Ape-X execution overrides (mirrors apex_ddpg.py).
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
from ray.rllib.agents.sac.sac import SACTrainer
from ray.rllib.agents.sac.sac import DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.utils import merge_dicts

APEX_SAC_DEFAULT_CONFIG = merge_dicts(
    SAC_CONFIG,  # see also the options in ddpg.py, which are also supported
    {
        "optimizer": merge_dicts(
            SAC_CONFIG["optimizer"], {
                "max_weight_sync_delay": 400,
                "num_replay_buffer_shards": 4,
                "debug": False
            }),
        "num_gpus": 0,
        "num_workers": 32,
        "buffer_size": 2000000,
        "learning_starts": 50000,
        "train_batch_size": 512,
        "sample_batch_size": 50,
        "target_network_update_freq": 500000,
        "timesteps_per_iteration": 25000,
        "per_worker_exploration": True,
        "worker_side_prioritization": True,
        "min_iter_time_s": 30,
    },
)

ApexSACTrainer = SACTrainer.with_updates(
    name="APEX_SAC",
    default_config=APEX_SAC_DEFAULT_CONFIG,
    **APEX_TRAINER_PROPERTIES)
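
# Hedged usage sketch (not in the original file): the trainer class returned
# by with_updates() can be handed directly to Tune. The env, stop criterion,
# and worker count below are illustrative placeholders, not tuned values.
if __name__ == "__main__":
    import ray
    from ray import tune

    ray.init()
    tune.run(
        ApexSACTrainer,
        stop={"timesteps_total": 1000000},
        config={"env": "Pendulum-v0", "num_workers": 4})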
    # Num Actions to sample for CQL Loss.
    "num_actions": 10,
    # Whether to use the Lagrangian for Alpha Prime (in CQL Loss).
    "lagrangian": False,
    # Lagrangian Threshold.
    "lagrangian_thresh": 5.0,
    # Min Q Weight multiplier.
    "min_q_weight": 5.0,
})
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["framework"] == "tf":
        raise ValueError("Tensorflow CQL not implemented yet!")


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    if config["framework"] == "torch":
        return CQLTorchPolicy


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTorchPolicy,
    get_policy_class=get_policy_class,
)
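
# Illustrative only (not part of the original module): the CQL-specific keys
# above can be overridden like any other config entries. The values below are
# arbitrary examples, not recommended settings.
example_cql_config = dict(CQL_DEFAULT_CONFIG, **{
    "framework": "torch",
    "num_actions": 4,       # Fewer sampled actions in the CQL loss.
    "lagrangian": True,     # Auto-tune alpha prime against the threshold.
    "lagrangian_thresh": 10.0,
    "min_q_weight": 1.0,
})
validate_config(example_cql_config)  # Only raises for framework="tf".
assert get_policy_class(example_cql_config) is CQLTorchPolicy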
def make_policy_optimizer(workers, config):
    """Create the single-process replay-based policy optimizer.

    Returns:
        SyncReplayOptimizerModified: Used for generic off-policy Trainers.
    """
    return SyncReplayOptimizerModified(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        prioritized_replay=config["prioritized_replay"],
        prioritized_replay_alpha=config["prioritized_replay_alpha"],
        prioritized_replay_beta=config["prioritized_replay_beta"],
        prioritized_replay_beta_annealing_timesteps=config[
            "prioritized_replay_beta_annealing_timesteps"],
        final_prioritized_replay_beta=config["final_prioritized_replay_beta"],
        prioritized_replay_eps=config["prioritized_replay_eps"],
        train_batch_size=config["train_batch_size"],
        **config["optimizer"])


DiCESACTrainer = SACTrainer.with_updates(
    name="DiCESACTrainer",
    default_config=dice_sac_default_config,
    default_policy=DiCESACPolicy,
    get_policy_class=lambda _: DiCESACPolicy,
    after_init=setup_policies_pool,
    after_optimizer_step=after_optimizer_step,
    validate_config=validate_config,
    make_policy_optimizer=make_policy_optimizer)
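
# Hedged illustration (not part of the original file): these are the config
# keys that make_policy_optimizer() above reads, layered on top of
# dice_sac_default_config. All values are arbitrary examples, not tuned
# settings for DiCE SAC.
example_replay_overrides = {
    "learning_starts": 1500,
    "buffer_size": 100000,
    "prioritized_replay": True,
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_beta_annealing_timesteps": 20000,
    "final_prioritized_replay_beta": 1.0,
    "prioritized_replay_eps": 1e-6,
    "train_batch_size": 256,
}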