Example #1
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin wrapper around imitation.scripts.eval_policy."""

from imitation.scripts import eval_policy

from evaluating_rewards.scripts import script_utils

if __name__ == "__main__":
    script_utils.add_logging_config(eval_policy.eval_policy_ex, "eval_policy")
    script_utils.experiment_main(eval_policy.eval_policy_ex,
                                 "eval_policy",
                                 sacred_symlink=False)
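
# Illustrative sketch (not part of the script above): the thin-wrapper pattern
# amounts to registering extra config on a Sacred experiment and then handing
# control to Sacred's command-line entry point. The experiment and defaults
# below are hypothetical, not the actual script_utils implementation.
import sacred

my_ex = sacred.Experiment("my_experiment")


@my_ex.config
def default_config():
    total_timesteps = 4096  # override from the CLI with `with total_timesteps=...`
    _ = locals()  # quieten lint warnings, as in the configs below
    del _


@my_ex.main
def run(total_timesteps: int):
    return f"would train for {total_timesteps} timesteps"


if __name__ == "__main__":
    my_ex.run_commandline()  # parses `with ...` overrides, then calls run()
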
Example #2

@train_regress_ex.named_config
def test():
    """Small number of epochs, finish quickly, intended for tests / debugging."""
    total_timesteps = 8192  # noqa: F841  pylint:disable=unused-variable


@train_regress_ex.named_config
def dataset_random_transition():
    """Randomly samples state and action and computes next state from dynamics."""
    dataset_factory = (  # noqa: F841  pylint:disable=unused-variable
        datasets.transitions_factory_from_random_model)
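
# Illustrative sketch (not part of the script above): named configs such as
# `test` and `dataset_random_transition` only take effect when selected, e.g.
# `with test dataset_random_transition` on the command line. Assuming the
# train_regress_ex experiment object is importable, a programmatic equivalent is:
sketch_run = train_regress_ex.run(named_configs=["test", "dataset_random_transition"])
print(sketch_run.config["total_timesteps"])  # 8192, from the `test` named config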


script_utils.add_logging_config(train_regress_ex, "train_regress")


@train_regress_ex.main
def train_regress(
    _seed: int,  # pylint:disable=invalid-name
    # Dataset
    env_name: str,
    discount: float,
    dataset_factory: datasets.TransitionsFactory,
    dataset_factory_kwargs: Dict[str, Any],
    # Target specification
    target_reward_type: str,
    target_reward_path: str,
    # Model parameters
    model_reward_type: regress_utils.EnvRewardFactory,
"""Thin wrapper around imitation.scripts.train_adversarial."""

import os

from imitation.scripts import train_adversarial

from evaluating_rewards import serialize
from evaluating_rewards.scripts import script_utils


@train_adversarial.train_ex.named_config
def point_maze():
    """IRL config for PointMaze environment."""
    env_name = "imitation/PointMazeLeftVel-v0"
    rollout_path = os.path.join(
        serialize.get_output_dir(),
        "train_experts/ground_truth/20201203_105631_297835/imitation_PointMazeLeftVel-v0",
        "evaluating_rewards_PointMazeGroundTruthWithCtrl-v0/best/rollouts/final.pkl",
    )
    total_timesteps = 1e6
    _ = locals()
    del _


if __name__ == "__main__":
    script_utils.add_logging_config(train_adversarial.train_ex,
                                    "train_adversarial")
    script_utils.experiment_main(train_adversarial.train_ex,
                                 "train_adversarial",
                                 sacred_symlink=False)
Example #4
    batch_timesteps = 10000  # total number of timesteps in each batch
    learning_rate = 1e-2
    weight_l2_reg = 0.0  # scaling factor for weight/parameter regularization
    reward_l2_reg = 1e-4  # scaling factor for regularization of output
    accuracy_threshold = 0.5  # minimum probability in correct direction to count as success

    _ = locals()  # quieten flake8 unused variable warning
    del _


FAST_CONFIG = dict(total_timesteps=1e4)
# Duplicate the config under both names for a consistent interface
train_preferences_ex.add_named_config("test", FAST_CONFIG)
train_preferences_ex.add_named_config("fast", FAST_CONFIG)

script_utils.add_logging_config(train_preferences_ex, "train_preferences")
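
# Illustrative sketch (not part of the script above): the defaults above
# (learning_rate, batch_timesteps, ...) can also be overridden per run, on top
# of a named config. Assuming train_preferences_ex is importable:
sketch_run = train_preferences_ex.run(
    named_configs=["fast"],  # apply FAST_CONFIG's total_timesteps
    config_updates={"learning_rate": 1e-3},  # override a single hyperparameter
)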


@train_preferences_ex.main
def train_preferences(
    _seed: int,  # pylint:disable=invalid-name
    # Dataset
    env_name: str,
    discount: float,
    num_vec: int,
    policy_type: str,
    policy_path: str,
    # Target specification
    target_reward_type: str,
    target_reward_path: str,
    # Model parameters
Example #5
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin wrapper around imitation.scripts.expert_demos."""

from imitation.scripts import expert_demos

from evaluating_rewards.scripts import script_utils

if __name__ == "__main__":
    script_utils.add_logging_config(expert_demos.expert_demos_ex,
                                    "expert_demos")
    script_utils.experiment_main(expert_demos.expert_demos_ex,
                                 "expert_demos",
                                 sacred_symlink=False)
Example #6
    locals().update(**FAST_CONFIG)  # merge FAST_CONFIG's entries into this named config
    visitations_factory_kwargs = {  # noqa: F841  pylint:disable=unused-variable
        "env_name": "evaluating_rewards/PointMassLine-v0",
        "parallel": False,
        "policy_type": "random",
        "policy_path": "dummy",
    }


@npec_distance_ex.named_config
def high_precision():
    """Increase number of timesteps to increase change of convergence."""
    total_timesteps = int(1e6)  # noqa: F841  pylint:disable=unused-variable


script_utils.add_logging_config(npec_distance_ex, "npec")


@ray.remote
def npec_worker(
    seed: int,
    # Dataset
    env_name: str,
    discount: float,
    visitations_factory,
    visitations_factory_kwargs: Dict[str, Any],
    # Models to compare
    source_reward_cfg: common_config.RewardCfg,
    target_reward_cfg: common_config.RewardCfg,
    # Model parameters
    comparison_class: Type[comparisons.RegressModel],
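
# Illustrative sketch (not part of the script above): npec_worker is declared as
# a Ray remote function, so each reward comparison can run in its own worker
# process. The toy task below stands in for npec_worker to show the dispatch pattern.
import ray


@ray.remote
def square(x: int) -> int:
    """Toy remote task standing in for npec_worker."""
    return x * x


if __name__ == "__main__":
    ray.init()  # start (or connect to) a local Ray cluster
    futures = [square.remote(i) for i in range(4)]  # schedule tasks asynchronously
    print(ray.get(futures))  # [0, 1, 4, 9]; blocks until all results are ready
    ray.shutdown()
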
Example #7
    affine_size = 512
    total_timesteps = 8192
    _ = locals()  # quieten flake8 unused variable warning
    del _


@model_comparison_ex.named_config
def dataset_random_transition():
    """Randomly samples state and action and computes next state from dynamics."""
    dataset_factory = datasets.transitions_factory_from_random_model
    dataset_factory_kwargs = {}
    _ = locals()  # quieten flake8 unused variable warning
    del _


script_utils.add_logging_config(model_comparison_ex, "model_comparison")


@model_comparison_ex.main
def model_comparison(
    _seed: int,  # pylint:disable=invalid-name
    # Dataset
    env_name: str,
    discount: float,
    dataset_factory: datasets.TransitionsFactory,
    dataset_factory_kwargs: Dict[str, Any],
    # Source specification
    source_reward_type: str,
    source_reward_path: str,
    # Target specification
    target_reward_type: str,
Example #8
    pos_density = 9  # number of points along position axis = number of plots
    lim = 1.0  # points span [-lim, lim]
    pos_lim = lim  # position point range
    vel_lim = lim  # velocity point range
    act_lim = lim  # action point range

    # Figure parameters
    styles = ["paper", "pointmass-2col", "tex"]
    ncols = 3  # number of heatmaps per row
    cbar_kwargs = {"fraction": 0.07, "pad": 0.02}
    fmt = "pdf"  # file type
    _ = locals()  # quieten flake8 unused variable warning
    del _
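
# Illustrative sketch (not part of the script above): cbar_kwargs are presumably
# forwarded to Matplotlib's colorbar call when the heatmaps are drawn; `fraction`
# and `pad` control the colorbar's width and its gap from the axes.
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
im = ax.imshow(np.random.rand(9, 9))
fig.colorbar(im, ax=ax, fraction=0.07, pad=0.02)
fig.savefig("heatmap.pdf")  # `fmt = "pdf"` above selects the output type similarly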


script_utils.add_logging_config(plot_pm_reward_ex, "plot_pm_reward")


@plot_pm_reward_ex.config
def logging_config(log_root, models, reward_type, reward_path):
    """Default logging configuration."""
    data_root = os.path.join(log_root, "model_comparison")
    if models is None:
        log_dir = os.path.join(
            log_root,
            reward_type.replace("/", "_"),
            reward_path.replace("/", "_"),
        )
    _ = locals()  # quieten flake8 unused variable warning
    del _
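
# Illustrative sketch (not part of the script above): Sacred config scopes like
# logging_config receive previously defined config values (log_root, models,
# reward_type, reward_path) as arguments and can derive new entries from them.
# The experiment and keys below are hypothetical.
import os

import sacred

demo_ex = sacred.Experiment("derived_config_demo")


@demo_ex.config
def base():
    log_root = "/tmp/output"  # hypothetical default
    reward_type = "dummy/Reward-v0"  # hypothetical placeholder
    _ = locals()
    del _


@demo_ex.config
def derived(log_root, reward_type):
    # Values already in the config are injected as arguments here.
    log_dir = os.path.join(log_root, reward_type.replace("/", "_"))
    _ = locals()
    del _


@demo_ex.main
def show(log_dir):
    print(log_dir)


if __name__ == "__main__":
    demo_ex.run_commandline()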