Example #1
File: cql.py  Project: hngenc/ray
            total_timesteps += len(batch)
            # Add NEXT_OBS if not available. This is a slight hack: for the
            # very last time step we use next-obs=zeros and therefore
            # force-set DONE=True, so that the missing next-obs does not
            # cause learning problems.
            if SampleBatch.NEXT_OBS not in batch:
                obs = batch[SampleBatch.OBS]
                batch[SampleBatch.NEXT_OBS] = \
                    np.concatenate([obs[1:], np.zeros_like(obs[0:1])])
                batch[SampleBatch.DONES][-1] = True
            replay_buffer.add_batch(batch)
        print(
            f"Loaded {num_batches} batches ({total_timesteps} ts) into the "
            f"replay buffer, which has capacity {replay_buffer.buffer_size}.")
    else:
        raise ValueError(
            "Unknown offline input! config['input'] must either be list of "
            "offline files (json) or a D4RL-specific InputReader specifier "
            "(e.g. 'd4rl.hopper-medium-v0').")


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTFPolicy,
    get_policy_class=get_policy_class,
    after_init=after_init,
    execution_plan=execution_plan,
)
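A trainer class produced by SACTrainer.with_updates behaves like any other RLlib Trainer. The following is a minimal usage sketch, not part of the original cql.py; the config values, offline-data path, and environment name are placeholders chosen for illustration.

# Minimal usage sketch (hypothetical values; assumes the imports in cql.py).
config = CQL_DEFAULT_CONFIG.copy()
config["framework"] = "torch"
config["input"] = ["/tmp/offline_data.json"]  # hypothetical offline JSON files

trainer = CQLTrainer(config=config, env="Pendulum-v0")  # hypothetical env name
for _ in range(5):
    result = trainer.train()
    print(result["timesteps_total"])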
Example #2
from ray.rllib.agents.sac.sac import SACTrainer
from ray.rllib.agents.sac.sac import DEFAULT_CONFIG as SAC_CONFIG
# NOTE: assumed import for APEX_TRAINER_PROPERTIES (mirrors apex_ddpg.py in
# the same Ray versions); the original snippet omits this import.
from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES
from ray.rllib.utils import merge_dicts

APEX_SAC_DEFAULT_CONFIG = merge_dicts(
    SAC_CONFIG,  # see also the options in ddpg.py, which are also supported
    {
        "optimizer": merge_dicts(
            SAC_CONFIG["optimizer"], {
                "max_weight_sync_delay": 400,
                "num_replay_buffer_shards": 4,
                "debug": False
            }),
        "num_gpus": 0,
        "num_workers": 32,
        "buffer_size": 2000000,
        "learning_starts": 50000,
        "train_batch_size": 512,
        "sample_batch_size": 50,
        "target_network_update_freq": 500000,
        "timesteps_per_iteration": 25000,
        "per_worker_exploration": True,
        "worker_side_prioritization": True,
        "min_iter_time_s": 30,
    },
)

ApexSACTrainer = SACTrainer.with_updates(
    name="APEX_SAC",
    default_config=APEX_SAC_DEFAULT_CONFIG,
    **APEX_TRAINER_PROPERTIES)
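For context on the snippet above: ray.rllib.utils.merge_dicts performs a deep (recursive) merge in the Ray versions these examples target, so the nested "optimizer" overrides are layered on top of SAC_CONFIG["optimizer"] rather than replacing it. A small standalone sketch with made-up keys:

# Deep-merge illustration (made-up keys, not SAC's real config).
from ray.rllib.utils import merge_dicts

base = {"optimizer": {"debug": False, "lr": 3e-4}, "num_workers": 2}
override = {"optimizer": {"debug": True}, "num_workers": 32}

merged = merge_dicts(base, override)
# merged == {"optimizer": {"debug": True, "lr": 3e-4}, "num_workers": 32}
print(merged)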
Example #3
File: cql.py  Project: tuyulers5/jav44
        # Num Actions to sample for CQL Loss
        "num_actions": 10,
        # Whether to use the Lagrangian for Alpha Prime (in CQL Loss)
        "lagrangian": False,
        # Lagrangian Threshold
        "lagrangian_thresh": 5.0,
        # Min Q Weight multiplier
        "min_q_weight": 5.0,
    })
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["framework"] == "tf":
        raise ValueError("Tensorflow CQL not implemented yet!")


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    if config["framework"] == "torch":
        return CQLTorchPolicy


CQLTrainer = SACTrainer.with_updates(
    name="CQL",
    default_config=CQL_DEFAULT_CONFIG,
    validate_config=validate_config,
    default_policy=CQLTorchPolicy,
    get_policy_class=get_policy_class,
)
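As a usage sketch (not part of the original file), a trainer built this way can also be passed straight to ray.tune; the environment name and stopping criterion below are hypothetical, and the offline "input" specifier follows the D4RL-style string mentioned in Example #1.

# Hypothetical tune launch; keys like num_actions/lagrangian are the
# CQL-specific options layered into CQL_DEFAULT_CONFIG above.
from ray import tune

tune.run(
    CQLTrainer,
    config={
        "env": "Hopper-v2",                # hypothetical environment name
        "framework": "torch",              # "tf" is rejected by validate_config
        "input": "d4rl.hopper-medium-v0",  # D4RL-style offline input specifier
        "num_actions": 10,
        "lagrangian": False,
    },
    stop={"timesteps_total": 500000},
)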
Example #4
def make_policy_optimizer(workers, config):
    """Create the single process DQN policy optimizer.

    Returns:
        SyncReplayOptimizer: Used for generic off-policy Trainers.
    """
    return SyncReplayOptimizerModified(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        prioritized_replay=config["prioritized_replay"],
        prioritized_replay_alpha=config["prioritized_replay_alpha"],
        prioritized_replay_beta=config["prioritized_replay_beta"],
        prioritized_replay_beta_annealing_timesteps=config[
            "prioritized_replay_beta_annealing_timesteps"],
        final_prioritized_replay_beta=config["final_prioritized_replay_beta"],
        prioritized_replay_eps=config["prioritized_replay_eps"],
        train_batch_size=config["train_batch_size"],
        **config["optimizer"])


DiCESACTrainer = SACTrainer.with_updates(
    name="DiCESACTrainer",
    default_config=dice_sac_default_config,
    default_policy=DiCESACPolicy,
    get_policy_class=lambda _: DiCESACPolicy,
    after_init=setup_policies_pool,
    after_optimizer_step=after_optimizer_step,
    validate_config=validate_config,
    make_policy_optimizer=make_policy_optimizer)