# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin wrapper around imitation.scripts.eval_policy."""

from imitation.scripts import eval_policy

from evaluating_rewards.scripts import script_utils


def main() -> None:
    """Attach this project's logging config to the Sacred experiment, then run it."""
    script_utils.add_logging_config(eval_policy.eval_policy_ex, "eval_policy")
    script_utils.experiment_main(eval_policy.eval_policy_ex, "eval_policy", sacred_symlink=False)


if __name__ == "__main__":
    main()
# Sacred named configs: the local variables assigned inside each function body
# become config values when the config is activated on the command line. The
# noqa/pylint pragmas silence the "unused variable" warnings this idiom triggers.
@train_regress_ex.named_config
def test():
    """Small number of epochs, finish quickly, intended for tests / debugging."""
    total_timesteps = 8192  # noqa: F841 pylint:disable=unused-variable


@train_regress_ex.named_config
def dataset_random_transition():
    """Randomly samples state and action and computes next state from dynamics."""
    dataset_factory = (  # noqa: F841 pylint:disable=unused-variable
        datasets.transitions_factory_from_random_model)


# Attach the shared logging configuration (log_dir handling) under "train_regress".
script_utils.add_logging_config(train_regress_ex, "train_regress")


# Experiment entry point; Sacred injects each parameter from the active config.
# NOTE(review): signature continues beyond this chunk — remaining parameters and
# the function body are not visible here.
@train_regress_ex.main
def train_regress(
    _seed: int,  # pylint:disable=invalid-name
    # Dataset
    env_name: str,
    discount: float,
    dataset_factory: datasets.TransitionsFactory,
    dataset_factory_kwargs: Dict[str, Any],
    # Target specification
    target_reward_type: str,
    target_reward_path: str,
    # Model parameters
    model_reward_type: regress_utils.EnvRewardFactory,
"""Thin wrapper around imitation.scripts.train_adversarial.""" import os from imitation.scripts import train_adversarial from evaluating_rewards import serialize from evaluating_rewards.scripts import script_utils @train_adversarial.train_ex.named_config def point_maze(): """IRL config for PointMaze environment.""" env_name = "imitation/PointMazeLeftVel-v0" rollout_path = os.path.join( serialize.get_output_dir(), "train_experts/ground_truth/20201203_105631_297835/imitation_PointMazeLeftVel-v0", "evaluating_rewards_PointMazeGroundTruthWithCtrl-v0/best/rollouts/final.pkl", ) total_timesteps = 1e6 _ = locals() del _ if __name__ == "__main__": script_utils.add_logging_config(train_adversarial.train_ex, "train_adversarial") script_utils.experiment_main(train_adversarial.train_ex, "train_adversarial", sacred_symlink=False)
batch_timesteps = 10000 # total number of timesteps in each batch learning_rate = 1e-2 weight_l2_reg = 0.0 # scaling factor for weight/parameter regularization reward_l2_reg = 1e-4 # scaling factor for regularization of output accuracy_threshold = 0.5 # minimum probability in correct direction to count as success _ = locals() # quieten flake8 unused variable warning del _ FAST_CONFIG = dict(total_timesteps=1e4) # Duplicate to have consistent interface train_preferences_ex.add_named_config("test", FAST_CONFIG) train_preferences_ex.add_named_config("fast", FAST_CONFIG) script_utils.add_logging_config(train_preferences_ex, "train_preferences") @train_preferences_ex.main def train_preferences( _seed: int, # pylint:disable=invalid-name # Dataset env_name: str, discount: float, num_vec: int, policy_type: str, policy_path: str, # Target specification target_reward_type: str, target_reward_path: str, # Model parameters
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin wrapper around imitation.scripts.expert_demos."""

from imitation.scripts import expert_demos

from evaluating_rewards.scripts import script_utils


def main() -> None:
    """Attach this project's logging config to the Sacred experiment, then run it."""
    script_utils.add_logging_config(expert_demos.expert_demos_ex, "expert_demos")
    script_utils.experiment_main(expert_demos.expert_demos_ex, "expert_demos", sacred_symlink=False)


if __name__ == "__main__":
    main()
    # Tail of a Sacred config function (its `def` line is outside this chunk).
    # Pull in the shared fast-run settings, then override the visitation
    # distribution to a cheap random policy on the PointMass line environment.
    locals().update(**FAST_CONFIG)
    visitations_factory_kwargs = {  # noqa: F841 pylint:disable=unused-variable
        "env_name": "evaluating_rewards/PointMassLine-v0",
        "parallel": False,
        "policy_type": "random",
        "policy_path": "dummy",
    }


@npec_distance_ex.named_config
def high_precision():
    """Increase number of timesteps to increase chance of convergence."""
    total_timesteps = int(1e6)  # noqa: F841 pylint:disable=unused-variable


# Attach the shared logging configuration under "npec".
script_utils.add_logging_config(npec_distance_ex, "npec")


# Worker executed remotely via Ray; one invocation per (source, target) pair.
# NOTE(review): signature continues beyond this chunk — remaining parameters and
# the function body are not visible here.
@ray.remote
def npec_worker(
    seed: int,
    # Dataset
    env_name: str,
    discount: float,
    visitations_factory,
    visitations_factory_kwargs: Dict[str, Any],
    # Models to compare
    source_reward_cfg: common_config.RewardCfg,
    target_reward_cfg: common_config.RewardCfg,
    # Model parameters
    comparison_class: Type[comparisons.RegressModel],
    # Tail of a Sacred config function (its `def` line is outside this chunk);
    # each local below becomes a config value.
    affine_size = 512
    total_timesteps = 8192
    _ = locals()  # quieten flake8 unused variable warning
    del _


@model_comparison_ex.named_config
def dataset_random_transition():
    """Randomly samples state and action and computes next state from dynamics."""
    dataset_factory = datasets.transitions_factory_from_random_model
    dataset_factory_kwargs = {}
    _ = locals()  # quieten flake8 unused variable warning
    del _


# Attach the shared logging configuration under "model_comparison".
script_utils.add_logging_config(model_comparison_ex, "model_comparison")


# Experiment entry point; Sacred injects each parameter from the active config.
# NOTE(review): signature continues beyond this chunk — remaining parameters and
# the function body are not visible here.
@model_comparison_ex.main
def model_comparison(
    _seed: int,  # pylint:disable=invalid-name
    # Dataset
    env_name: str,
    discount: float,
    dataset_factory: datasets.TransitionsFactory,
    dataset_factory_kwargs: Dict[str, Any],
    # Source specification
    source_reward_type: str,
    source_reward_path: str,
    # Target specification
    target_reward_type: str,
pos_density = 9 # number of points along position axis = number of plots lim = 1.0 # points span [-lim, lim] pos_lim = lim # position point range vel_lim = lim # velocity point range act_lim = lim # action point range # Figure parameters styles = ["paper", "pointmass-2col", "tex"] ncols = 3 # number of heatmaps per row cbar_kwargs = {"fraction": 0.07, "pad": 0.02} fmt = "pdf" # file type _ = locals() # quieten flake8 unused variable warning del _ script_utils.add_logging_config(plot_pm_reward_ex, "plot_pm_reward") @plot_pm_reward_ex.config def logging_config(log_root, models, reward_type, reward_path): """Default logging configuration.""" data_root = os.path.join(log_root, "model_comparison") if models is None: log_dir = os.path.join( log_root, reward_type.replace("/", "_"), reward_path.replace("/", "_"), ) _ = locals() # quieten flake8 unused variable warning del _