コード例 #1
0
 def __init__(self, model: SimpleMultiHeadModel, config: ActorCriticConfig):
     """Initialize the instance after verifying the model's task names.

     Args:
         model: Model whose ``task_names`` must be exactly ``"actor"`` and
             ``"critic"`` (compared as a set, so order is irrelevant).
         config: Configuration object passed through to the parent
             constructor together with ``model``.

     Raises:
         UnrecognizedTask: If ``model.task_names`` is ``None`` or is not
             exactly the pair ``"actor"`` / ``"critic"``.
     """
     required_tasks = {"actor", "critic"}
     names_ok = model.task_names is not None and set(model.task_names) == required_tasks
     if not names_ok:
         raise UnrecognizedTask(
             f"Expected model task names 'actor' and 'critic', but got {model.task_names}"
         )
     super().__init__(model, config)
コード例 #2
0
 def __init__(self, model: SimpleMultiHeadModel, config: DDPGConfig, explorer: NoiseExplorer = None):
     """Initialize the instance after verifying the model's task names.

     Args:
         model: Model whose ``task_names`` must be exactly ``"policy"`` and
             ``"q_value"`` (compared as a set, so order is irrelevant).
         config: Configuration object passed through to the parent
             constructor together with ``model``.
         explorer: Optional explorer stored on the instance as-is;
             defaults to ``None``.

     Raises:
         UnrecognizedTask: If ``model.task_names`` is ``None`` or is not
             exactly the pair ``"policy"`` / ``"q_value"``.
     """
     required_tasks = {"policy", "q_value"}
     names_ok = model.task_names is not None and set(model.task_names) == required_tasks
     if not names_ok:
         raise UnrecognizedTask(f"Expected model task names 'policy' and 'q_value', but got {model.task_names}")
     super().__init__(model, config)
     self._explorer = explorer
     # A copy of the model is kept only when the model is trainable.
     if model.trainable:
         self._target_model = model.copy()
     else:
         self._target_model = None
     self._train_cnt = 0
コード例 #3
0
 def __init__(self, model: SimpleMultiHeadModel, config: DQNConfig):
     """Initialize the instance, validating task names when an advantage type is set.

     The task-name check runs only if ``config.advantage_type`` is not
     ``None`` (the error message refers to this case as dueling DQN).

     Args:
         model: Model passed to the parent constructor. When the check
             applies, its ``task_names`` must be exactly ``"state_value"``
             and ``"advantage"`` (compared as a set).
         config: Configuration object; its ``advantage_type`` attribute
             decides whether the task names are validated.

     Raises:
         UnrecognizedTask: If the check applies and ``model.task_names`` is
             ``None`` or is not exactly the expected pair.
     """
     if config.advantage_type is not None:
         required_tasks = {"state_value", "advantage"}
         names_ok = model.task_names is not None and set(model.task_names) == required_tasks
         if not names_ok:
             raise UnrecognizedTask(
                 f"Expected model task names 'state_value' and 'advantage' since dueling DQN is used, "
                 f"got {model.task_names}")
     super().__init__(model, config)
     self._training_counter = 0
     # A copy of the model is kept only when the model is trainable.
     if model.trainable:
         self._target_model = model.copy()
     else:
         self._target_model = None
コード例 #4
0
 def validate_task_names(model_task_names, expected_task_names):
     """Check a model's task names against the expected ones.

     Both arguments are converted to sets before comparison, so ordering
     and duplicate entries do not matter. The comparison is skipped
     entirely when ``model_task_names`` has one entry or none.

     Args:
         model_task_names: Sized iterable of the model's task names.
         expected_task_names: Iterable of the task names that are required.

     Raises:
         UnrecognizedTask: If ``model_task_names`` has more than one entry
             and its set of names differs from the expected set.
     """
     task_names = set(model_task_names)
     expected_task_names = set(expected_task_names)
     # Models with at most one task name are never rejected.
     if len(model_task_names) <= 1:
         return
     if task_names == expected_task_names:
         return
     raise UnrecognizedTask(
         f"Expected task names {expected_task_names}, got {task_names}")