def __init__(self, model: SimpleMultiHeadModel, config: ActorCriticConfig):
    """Validate the model's task names, then delegate to the parent initializer.

    Args:
        model: Model whose ``task_names`` must be exactly ``"actor"`` and
            ``"critic"`` (compared as a set, so order does not matter).
        config: Configuration object forwarded unchanged to the parent
            initializer.

    Raises:
        UnrecognizedTask: If ``model.task_names`` is ``None`` or is not exactly
            the set ``{"actor", "critic"}``.
    """
    required_tasks = {"actor", "critic"}
    declared_tasks = model.task_names
    # A missing task-name list is rejected the same way as a mismatching one.
    names_match = declared_tasks is not None and set(declared_tasks) == required_tasks
    if not names_match:
        raise UnrecognizedTask(
            f"Expected model task names 'actor' and 'critic', but got {model.task_names}"
        )
    super().__init__(model, config)
def __init__(self, model: SimpleMultiHeadModel, config: DDPGConfig, explorer: NoiseExplorer = None):
    """Validate the model's task names and set up explorer, target model and counter.

    Args:
        model: Model whose ``task_names`` must be exactly ``"policy"`` and
            ``"q_value"`` (compared as a set, so order does not matter).
        config: Configuration object forwarded unchanged to the parent
            initializer.
        explorer: Optional explorer stored as ``self._explorer``; defaults to
            ``None``.

    Raises:
        UnrecognizedTask: If ``model.task_names`` is ``None`` or is not exactly
            the set ``{"policy", "q_value"}``.
    """
    required_tasks = {"policy", "q_value"}
    declared_tasks = model.task_names
    # A missing task-name list is rejected the same way as a mismatching one.
    if not (declared_tasks is not None and set(declared_tasks) == required_tasks):
        raise UnrecognizedTask(f"Expected model task names 'policy' and 'q_value', but got {model.task_names}")

    super().__init__(model, config)
    self._explorer = explorer
    # Only a trainable model gets a copy kept as the target model.
    if model.trainable:
        self._target_model = model.copy()
    else:
        self._target_model = None
    self._train_cnt = 0
def __init__(self, model: SimpleMultiHeadModel, config: DQNConfig):
    """Validate task names when an advantage type is configured, then initialize state.

    The task-name check only runs when ``config.advantage_type`` is not
    ``None``; otherwise any ``model.task_names`` value is accepted.

    Args:
        model: Model to wrap. When an advantage type is configured, its
            ``task_names`` must be exactly ``"state_value"`` and ``"advantage"``
            (compared as a set, so order does not matter).
        config: Configuration object; its ``advantage_type`` attribute is read
            here and the object is forwarded to the parent initializer.

    Raises:
        UnrecognizedTask: If an advantage type is configured and
            ``model.task_names`` is ``None`` or is not exactly the set
            ``{"state_value", "advantage"}``.
    """
    if config.advantage_type is not None:
        declared_tasks = model.task_names
        if declared_tasks is None or set(declared_tasks) != {"state_value", "advantage"}:
            raise UnrecognizedTask(
                f"Expected model task names 'state_value' and 'advantage' since dueling DQN is used, "
                f"got {model.task_names}"
            )

    super().__init__(model, config)
    self._training_counter = 0
    # Only a trainable model gets a copy kept as the target model.
    if model.trainable:
        self._target_model = model.copy()
    else:
        self._target_model = None
def validate_task_names(model_task_names, expected_task_names):
    """Check that a multi-task model's task names match the expected ones.

    Names are compared as sets, so ordering and duplicates are ignored. The
    check is skipped entirely when ``model_task_names`` holds at most one
    entry.

    Args:
        model_task_names: Sized iterable of the model's task names. ``None`` is
            not handled here and fails inside ``set()`` with a ``TypeError``.
        expected_task_names: Iterable of the task names that are required.

    Raises:
        UnrecognizedTask: If more than one task name is given and the set of
            names differs from the expected set.
    """
    actual = set(model_task_names)
    expected = set(expected_task_names)

    # Inputs with zero or one task name are never validated.
    if len(model_task_names) <= 1:
        return
    if actual == expected:
        return
    raise UnrecognizedTask(f"Expected task names {expected}, got {actual}")