Example #1
"""Behavioral Cloning (derived from MARWIL).

Simply uses the MARWIL agent with beta force-set to 0.0.
"""
from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
    DEFAULT_CONFIG as MARWIL_CONFIG
from ray.rllib.utils.typing import TrainerConfigDict

# yapf: disable
# __sphinx_doc_begin__
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
    MARWIL_CONFIG, {
        # No need to calculate advantages (or do anything else with the
        # rewards).
        "beta": 0.0,
        # Advantages (calculated during postprocessing) not important for
        # behavioral cloning.
        "postprocess_inputs": False,
        # No reward estimation.
        "input_evaluation": [],
    })
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["beta"] != 0.0:
        raise ValueError(
            "For behavioral cloning, `beta` parameter must be 0.0!")

Example #2
"""Behavioral Cloning (derived from MARWIL).

Simply uses the MARWIL agent with beta force-set to 0.0.
"""
from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
    DEFAULT_CONFIG as MARWIL_CONFIG
from ray.rllib.utils.typing import TrainerConfigDict

# yapf: disable
# __sphinx_doc_begin__
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
    MARWIL_CONFIG, {
        "beta": 0.0,
        "input": '~/cartpole-out/output-2020-12-02_02-28-22_worker-1_34.json'
    })
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["beta"] != 0.0:
        raise ValueError(
            "For behavioral cloning, `beta` parameter must be 0.0!")


BCTrainer = MARWILTrainer.with_updates(
    name="BC",
    default_config=BC_DEFAULT_CONFIG,
    validate_config=validate_config,
)
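Example #2 points "input" at a JSON file previously recorded with RLlib's offline output writer. A minimal usage sketch for such a trainer (the environment, path, and iteration count are illustrative, and it assumes BCTrainer is exported from ray.rllib.agents.marwil as in this RLlib version):

import ray
from ray.rllib.agents.marwil import BCTrainer

ray.init()
trainer = BCTrainer(
    env="CartPole-v0",  # only provides the observation/action spaces; BC never samples it
    config={
        "input": "~/cartpole-out",  # placeholder: directory (or file) of recorded JSON episodes
        "input_evaluation": [],     # no off-policy estimation for plain BC
        "framework": "torch",
    })
for _ in range(5):
    print(trainer.train())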
Example #3
from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
    DEFAULT_CONFIG as MARWIL_CONFIG
from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance

torch, _ = try_import_torch()


def marwil_loss(policy, model, dist_class, train_batch):
    # Preamble (assumed, following RLlib's stock MARWIL torch loss): run the
    # model, build the action distribution, and read advantages/actions from
    # the offline batch.
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)
    state_values = model.value_function()
    advantages = train_batch[Postprocessing.ADVANTAGES]
    actions = train_batch[SampleBatch.ACTIONS]

    # Value loss: fit the value head to the (return-based) advantages.
    adv = advantages - state_values
    policy.v_loss = 0.5 * torch.mean(torch.pow(adv, 2.0))

    # Exponentially weighted advantages; `policy.ma_adv_norm` is the moving
    # average of the squared advantage norm maintained by MARWILTorchPolicy.
    exp_advs = torch.exp(policy.config["beta"] *
                         (adv / (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)
    policy.p_loss = -1.0 * torch.mean(exp_advs.detach() * logprobs)

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
        policy.v_loss
    explained_var = explained_variance(advantages, state_values)
    policy.explained_variance = torch.mean(explained_var)

    return policy.total_loss


MARWILSTorchPolicy = MARWILTorchPolicy.with_updates(
    loss_fn=marwil_loss,
    # postprocess_fn=postprocess_advantages
)


def get_policy_class(config):
    return MARWILSTorchPolicy


MARWILSTrainer = MARWILTrainer.with_updates(
    name="MARWILS",
    default_config=MARWIL_CONFIG,
    get_policy_class=get_policy_class,
    default_policy=MARWILSTorchPolicy,
)
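The with_updates pattern above simply produces a regular trainer class, so the customized MARWILS trainer is constructed and trained like stock MARWIL. A hypothetical usage sketch (offline data path and environment are placeholders):

trainer = MARWILSTrainer(
    env="CartPole-v0",
    config={
        "input": "~/cartpole-out",  # placeholder path to recorded episodes
        "beta": 1.0,                # keep advantage weighting switched on
        "framework": "torch",       # only a torch policy was wired in above
    })
print(trainer.train())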
Example #4
"""Behavioral Cloning (derived from MARWIL).

Simply uses the MARWIL agent with beta force-set to 0.0.
"""
from ray.rllib.agents.marwil.marwil import MARWILTrainer, \
    DEFAULT_CONFIG as MARWIL_CONFIG
from ray.rllib.utils.typing import TrainerConfigDict

# yapf: disable
# __sphinx_doc_begin__
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
    MARWIL_CONFIG, {
        "beta": 0.0,
    })
# __sphinx_doc_end__
# yapf: enable


def validate_config(config: TrainerConfigDict):
    if config["beta"] != 0.0:
        raise ValueError(
            "For behavioral cloning, `beta` parameter must be 0.0!")


BCTrainer = MARWILTrainer.with_updates(
    name="BC",
    default_config=BC_DEFAULT_CONFIG,
    validate_config=validate_config,
)
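Because BC never interacts with the live environment during training, a common pattern is to let RLlib evaluate the cloned policy in a real environment at a fixed interval. A sketch assuming the standard evaluation config keys are available in this RLlib version (paths and values are illustrative):

import ray
from ray.rllib.agents.marwil import BCTrainer

ray.init()
trainer = BCTrainer(
    env="CartPole-v0",
    config={
        "input": "~/cartpole-out",                  # placeholder: recorded expert episodes
        "input_evaluation": [],                     # skip off-policy estimators
        "evaluation_interval": 1,                   # evaluate after every train() call
        "evaluation_num_episodes": 10,
        "evaluation_config": {"input": "sampler"},  # roll out in the real env
    })
result = trainer.train()
print(result["evaluation"]["episode_reward_mean"])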