def get_loss(self, **kwargs):
    """Return the value stored under ``self.loss_key`` in ``kwargs``.

    :param kwargs: keyword arguments that are expected to contain an entry
        named ``self.loss_key``.
    :return: ``kwargs[self.loss_key]``, returned unchanged.
    :raises CheckError: if ``self.loss_key`` is not one of the keyword
        arguments. The attached ``CheckRes`` lists the key as missing.
    """
    if self.loss_key not in kwargs:
        # Keep the backticks exactly as they are. The closing parenthesis
        # makes this entry match the "name(assign to `arg` in `Class`)" form
        # used for missing parameters by the `__call__` methods in this file.
        missing_entry = self.loss_key + (
            f"(assign to `{self.loss_key}` "
            f"in `{self.__class__.__name__}`)"
        )
        check_res = CheckRes(
            missing=[missing_entry],
            unused=[],
            duplicated=[],
            required=[],
            all_needed=[],
            varargs=[],
        )
        raise CheckError(
            check_res=check_res,
            func_signature=get_func_signature(self.get_loss),
        )
    return kwargs[self.loss_key]
def __call__(self, pred_dict, target_dict):
    """
    Map the entries of ``pred_dict`` and ``target_dict`` onto the parameters
    of ``self.evaluate`` and call it.

    If ``self._fast_param_map`` returns a non-empty result, ``self.evaluate``
    is called with it directly and no further checking is done.

    Otherwise, on the first call (while ``self._checked`` is false) the
    inputs are validated before ``self.evaluate`` is invoked:

    (1) every key of ``self.param_map`` must be a parameter of
        ``self.evaluate``;
    (2) parameters needed by ``self.evaluate`` must be present in
        ``pred_dict`` / ``target_dict``;
    (3) a mapped input name must not appear in both ``pred_dict`` and
        ``target_dict``;
    (4) ``self.evaluate`` must not report varargs (not supported).

    Before the call, the merged inputs are passed through ``_build_args``,
    which filters out entries that ``self.evaluate`` does not use (no
    filtering happens if ``self.evaluate`` accepts ``**kwargs``).

    :param pred_dict: usually the output of the forward or prediction function.
    :param target_dict: usually the features set as target.
    :return: None
    :raises TypeError: if ``self.evaluate`` is not callable.
    :raises NameError: if ``self.param_map`` names a parameter that
        ``self.evaluate`` does not have.
    :raises CheckError: if the first-call validation finds missing,
        duplicated or varargs entries.
    """
    if not callable(self.evaluate):
        raise TypeError(
            f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}."
        )

    fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
    if fast_param:
        self.evaluate(**fast_param)
        return

    if not self._checked:
        # 1. check consistence between signature and param_map
        func_spect = inspect.getfullargspec(self.evaluate)
        func_args = {arg for arg in func_spect.args if arg != 'self'}
        for func_arg in self.param_map:
            if func_arg not in func_args:
                raise NameError(
                    f"`{func_arg}` not in {get_func_signature(self.evaluate)}."
                )

        # 2. only part of the param_map are passed, left are not
        for arg in func_args:
            if arg not in self.param_map:
                self.param_map[arg] = arg  # This param does not need mapping.
        self._evaluate_args = func_args
        self._reverse_param_map = {
            input_arg: func_arg
            for func_arg, input_arg in self.param_map.items()
        }

    # need to wrap inputs in dict.
    mapped_pred_dict = {}
    mapped_target_dict = {}
    duplicated = []
    # De-duplicate the keys while keeping first-seen order (dicts preserve
    # insertion order on Python 3.7+), so that the content of `duplicated`
    # and the result of any mapped-name collision do not depend on set/hash
    # ordering.
    input_args = dict.fromkeys([*pred_dict.keys(), *target_dict.keys()])
    for input_arg in input_args:
        is_mapped = input_arg in self._reverse_param_map
        mapped_arg = self._reverse_param_map[input_arg] if is_mapped else input_arg
        in_pred = input_arg in pred_dict
        in_target = input_arg in target_dict
        if in_pred:
            mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
        if in_target:
            mapped_target_dict[mapped_arg] = target_dict[input_arg]
        # A name is a duplicate only when it is a mapped input that was
        # supplied by both dictionaries.
        if is_mapped and in_pred and in_target:
            duplicated.append(input_arg)

    # missing
    if not self._checked:
        check_res = _check_arg_dict_list(
            self.evaluate, [mapped_pred_dict, mapped_target_dict])
        # only check missing.
        # replace missing: report the user-facing input name together with
        # the evaluate parameter it is assigned to.
        # Don't delete `` in this information, nor add ``
        replaced_missing = [
            f"{self.param_map[func_arg]}(assign to `{func_arg}` "
            f"in `{self.__class__.__name__}`)"
            for func_arg in check_res.missing
        ]

        check_res = CheckRes(
            missing=replaced_missing,
            unused=check_res.unused,
            duplicated=duplicated,
            required=check_res.required,
            all_needed=check_res.all_needed,
            varargs=check_res.varargs,
        )

        if check_res.missing or check_res.duplicated or check_res.varargs:
            raise CheckError(
                check_res=check_res,
                func_signature=get_func_signature(self.evaluate),
            )

    refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

    self.evaluate(**refined_args)
    # Only mark as checked after a full, successful evaluate call.
    self._checked = True

    return
def __call__(self, pred_dict, target_dict, check=False):
    """
    Map the entries of ``pred_dict`` and ``target_dict`` onto the parameters
    of ``self.get_loss``, call it and return its result.

    If ``self._fast_param_map`` returns a non-empty result, ``self.get_loss``
    is called with it directly and no checking is done.

    Otherwise the inputs are validated on the first call (while
    ``self._checked`` is false), or on every call where ``check`` is true:
    every key of ``self.param_map`` must be a parameter of ``self.get_loss``,
    no required parameter may be missing, and a mapped input name must not
    appear in both ``pred_dict`` and ``target_dict``.

    :param pred_dict: A dict from forward function of the network.
    :param target_dict: A dict from DataSet.batch_y.
    :param check: Boolean. Force to check the mapping functions when it is
        running, even if an earlier call already passed the check.
    :return: the value returned by ``self.get_loss``.
    :raises NameError: if ``self.param_map`` names a parameter that
        ``self.get_loss`` does not have.
    :raises CheckError: if the validation finds missing or duplicated entries.
    """
    fast_param = self._fast_param_map(pred_dict, target_dict)
    if fast_param:
        loss = self.get_loss(**fast_param)
        return loss

    # `check` forces re-validation; otherwise validate only until the first
    # successful call has set `self._checked`.
    need_check = check or not self._checked

    if need_check:
        # 1. check consistence between signature and param_map
        func_spect = inspect.getfullargspec(self.get_loss)
        func_args = {arg for arg in func_spect.args if arg != 'self'}
        for func_arg in self.param_map:
            if func_arg not in func_args:
                raise NameError(
                    f"`{func_arg}` not in {get_func_signature(self.get_loss)}."
                )

        # 2. only part of the param_map are passed, left are not
        for arg in func_args:
            if arg not in self.param_map:
                self.param_map[arg] = arg  # This param does not need mapping.
        self._evaluate_args = func_args
        self._reverse_param_map = {
            input_arg: func_arg
            for func_arg, input_arg in self.param_map.items()
        }

    # need to wrap inputs in dict.
    mapped_pred_dict = {}
    mapped_target_dict = {}
    duplicated = []
    # De-duplicate the keys while keeping first-seen order (dicts preserve
    # insertion order on Python 3.7+), so that the content of `duplicated`
    # and the result of any mapped-name collision do not depend on set/hash
    # ordering.
    input_args = dict.fromkeys([*pred_dict.keys(), *target_dict.keys()])
    for input_arg in input_args:
        is_mapped = input_arg in self._reverse_param_map
        mapped_arg = self._reverse_param_map[input_arg] if is_mapped else input_arg
        in_pred = input_arg in pred_dict
        in_target = input_arg in target_dict
        if in_pred:
            mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
        if in_target:
            mapped_target_dict[mapped_arg] = target_dict[input_arg]
        # A name is a duplicate only when it is a mapped input that was
        # supplied by both dictionaries.
        if is_mapped and in_pred and in_target:
            duplicated.append(input_arg)

    # missing
    if need_check:
        check_res = _check_arg_dict_list(
            self.get_loss, [mapped_pred_dict, mapped_target_dict])
        # replace missing: report the user-facing input name together with
        # the get_loss parameter it is assigned to.
        # Don't delete `` in this information, nor add ``
        replaced_missing = [
            f"{self.param_map[func_arg]}(assign to `{func_arg}` "
            f"in `{self.__class__.__name__}`)"
            for func_arg in check_res.missing
        ]

        check_res = CheckRes(
            missing=replaced_missing,
            unused=check_res.unused,
            duplicated=duplicated,
            required=check_res.required,
            all_needed=check_res.all_needed,
            varargs=check_res.varargs,
        )

        if check_res.missing or check_res.duplicated:
            raise CheckError(
                check_res=check_res,
                func_signature=get_func_signature(self.get_loss),
            )

    refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)

    loss = self.get_loss(**refined_args)
    # Only mark as checked after a full, successful get_loss call.
    self._checked = True

    return loss