Ejemplo n.º 1
0
    def apply(self, func, new_field_name=None, **kwargs):
        """Run ``func`` on each instance of the DataSet and collect the outputs.

        :param func: a callable that takes one instance as its argument.
        :param str new_field_name: when given, the outputs are stored under this field name (an existing field
            of that name is overwritten) instead of being returned.
        :param **kwargs: recognised keys are
            (1) is_input: boolean, only used when new_field_name is given; forwarded for the stored field.
            (2) is_target: boolean, only used when new_field_name is given; forwarded for the stored field.
        :return results: the list of outputs of ``func`` when new_field_name is None; nothing otherwise.
        :raises ValueError: if new_field_name is given and no call of ``func`` produced a non-None value.
        """
        results = []
        for instance in self._inner_iter():
            results.append(func(instance))

        if new_field_name is None:
            return results

        if all(value is None for value in results):  # nothing worth storing
            raise ValueError("{} always return None.".format(get_func_signature(func=func)))

        # Flags passed explicitly by the caller always win.
        flags = {key: kwargs[key] for key in ('is_input', 'is_target') if key in kwargs}
        if new_field_name in self.field_arrays:
            # Overwriting a field: flags the caller left out are inherited from the field being replaced.
            previous = self.field_arrays[new_field_name]
            if 'is_input' not in flags:
                flags['is_input'] = previous.is_input
            if 'is_target' not in flags:
                flags['is_target'] = previous.is_target
        self.add_field(name=new_field_name, fields=results, is_input=flags.get('is_input', None),
                       is_target=flags.get('is_target', None))
Ejemplo n.º 2
0
def _check_forward_error(model_func, check_level, batch_x):
    """Compare the keys of ``batch_x`` with the arguments of ``model_func`` and report mismatches.

    The comparison itself is delegated to ``_check_arg_dict_list``; this function only turns its
    ``missing`` / ``unused`` results into errors or warnings.

    :param model_func: the callable the batch is going to be fed to.
    :param check_level: compared against STRICT_CHECK_LEVEL and WARNING_CHECK_LEVEL to decide how unused keys
        are reported; any other value silences them.
    :param batch_x: dict of inputs whose keys are checked.
    :raises TypeError: if ``missing`` is non-empty. Unused keys are appended to the message only in strict mode.
    :raises ValueError: if only unused keys were found and ``check_level`` is STRICT_CHECK_LEVEL.
    """
    check_res = _check_arg_dict_list(model_func, batch_x)
    _missing = ''
    _unused = ''
    func_signature = get_func_signature(model_func)
    if len(check_res.missing) != 0:
        _missing = "Function {} misses {}, only provided with {}, " \
                   ".\n".format(func_signature, check_res.missing,
                                list(batch_x.keys()))
    if len(check_res.unused) != 0:
        if len(check_res.unused) > 1:
            _unused = "{} are not used ".format(check_res.unused)
        else:
            _unused = "{} is not used ".format(check_res.unused)
        _unused += "in function {}.\n".format(func_signature)
    if _missing:
        # Compare against the requested level: testing the constant itself would make the outcome
        # independent of what the caller asked for.
        if _unused and check_level == STRICT_CHECK_LEVEL:
            _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
        else:
            _error_str = _missing
        # TODO: a custom Error type may be needed here
        raise TypeError(_error_str)
    if _unused:
        if check_level == STRICT_CHECK_LEVEL:
            # TODO: a custom Error type may be needed here
            raise ValueError(_unused)
        elif check_level == WARNING_CHECK_LEVEL:
            warnings.warn(message=_unused)
Ejemplo n.º 3
0
    def test(self):
        """Run the model over ``self.data`` batch by batch and collect the result of every metric.

        Each batch is moved to ``self._model_device``, passed through ``self._predict_func`` and handed to
        every metric in ``self.metrics``; afterwards ``get_metric()`` is read from each metric.

        :return: dict mapping the class name of each metric to the dict returned by its ``get_metric()``.
            Two metrics of the same class share a key, so the later one replaces the earlier one.
        :raises TypeError: if the prediction function or a metric's ``get_metric`` does not return a dict.
        """
        # turn on the testing mode; clean up the history
        network = self._model
        self._mode(network, is_test=True)
        eval_results = {}
        try:
            data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
            with torch.no_grad():
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    pred_dict = self._data_forward(self._predict_func, batch_x)
                    if not isinstance(pred_dict, dict):
                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " 
                                        f"must be `dict`, got {type(pred_dict)}.")
                    for metric in self.metrics:
                        metric(pred_dict, batch_y)
                for metric in self.metrics:
                    eval_result = metric.get_metric()
                    if not isinstance(eval_result, dict):
                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
                                        f"`dict`, got {type(eval_result)}")
                    metric_name = metric.__class__.__name__
                    eval_results[metric_name] = eval_result
        except CheckError as e:
            # Let _check_loss_evaluate report the argument mismatch carried by the CheckError.
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)
        finally:
            # Undo the mode switch even when an error propagates, so that a failed evaluation does not
            # leave the model in testing mode.
            self._mode(network, is_test=False)

        if self.verbose >= 1:
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
        return eval_results
Ejemplo n.º 4
0
 def get_loss(self, **kwargs):
     """Return the entry of ``kwargs`` stored under ``self.loss_key``.

     :param kwargs: keyword arguments expected to contain ``self.loss_key``.
     :return: ``kwargs[self.loss_key]``.
     :raises CheckError: if ``self.loss_key`` is not among ``kwargs``; the key is reported as missing.
     """
     if self.loss_key not in kwargs:
         # The entry has the shape "name(assign to `arg` in `Class`)"; the closing parenthesis keeps it
         # balanced. The backticks are left exactly as they were.
         check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` "
                                                       f"in `{self.__class__.__name__}`)"],
                              unused=[],
                              duplicated=[],
                              required=[],
                              all_needed=[],
                              varargs=[])
         raise CheckError(check_res=check_res,
                          func_signature=get_func_signature(self.get_loss))
     return kwargs[self.loss_key]
Ejemplo n.º 5
0
    def __call__(self, pred_dict, target_dict):
        """Map the entries of ``pred_dict`` and ``target_dict`` onto the arguments of ``self.evaluate`` and call it.

        If ``self._fast_param_map`` returns a non-empty result, ``self.evaluate`` is called with it directly and
        nothing else happens. Otherwise, as long as no call has succeeded yet (``self._checked`` is False), the
        following is verified first:
            (1) every key of ``self.param_map`` is an argument of ``self.evaluate``;
            (2) no argument needed by ``self.evaluate`` is missing from pred_dict and target_dict;
            (3) no mapped argument is provided by both pred_dict and target_dict;
            (4) ``_check_arg_dict_list`` reports no varargs for ``self.evaluate``.
        Entries are renamed through the reversed ``self.param_map`` and passed through ``_build_args`` before
        ``self.evaluate`` is called.

        :param pred_dict: usually the output of forward or prediction function
        :param target_dict: usually features set as target.
        :return: None
        :raises TypeError: if ``self.evaluate`` is not callable.
        :raises NameError: if ``self.param_map`` names an argument ``self.evaluate`` does not have.
        :raises CheckError: if the first-call check finds missing or duplicated arguments, or varargs.
        """
        if not callable(self.evaluate):
            raise TypeError(
                f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}."
            )

        fast_param = self._fast_param_map(pred_dict=pred_dict,
                                          target_dict=target_dict)
        if fast_param:
            self.evaluate(**fast_param)
            return

        if not self._checked:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = {arg for arg in func_spect.args if arg != 'self'}
            for func_arg in self.param_map:
                if func_arg not in func_args:
                    raise NameError(
                        f"`{func_arg}` not in {get_func_signature(self.evaluate)}."
                    )

            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {
                input_arg: func_arg
                for func_arg, input_arg in self.param_map.items()
            }

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(
                list(pred_dict.keys()) + list(target_dict.keys())):
            is_mapped = input_arg in self._reverse_param_map
            mapped_arg = self._reverse_param_map[input_arg] if is_mapped else input_arg
            in_pred = input_arg in pred_dict
            in_target = input_arg in target_dict
            if in_pred:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
            if in_target:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
            # Only keys known to the reversed param_map are reported when both dicts provide them.
            if is_mapped and in_pred and in_target:
                duplicated.append(input_arg)

        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(
                self.evaluate, [mapped_pred_dict, mapped_target_dict])
            # only check missing.
            # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                                                        f"in `{self.__class__.__name__}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated or check_res.varargs:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(
                                     self.evaluate))

        # Merge before unpacking: a key present in both dicts that was not reported as duplicated above
        # would otherwise make the call fail with "got multiple values for keyword argument".
        merged_args = {**mapped_pred_dict, **mapped_target_dict}
        refined_args = _build_args(self.evaluate, **merged_args)

        self.evaluate(**refined_args)
        self._checked = True

        return
Ejemplo n.º 6
0
def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
    """Report mismatches between the arguments of ``func`` and the keys of ``output`` / ``batch_y``.

    The comparison is done by ``_check_arg_dict_list``; its ``missing``, ``duplicated`` and ``unused`` results
    are turned into messages here.

    :param prev_func: the callable that produced ``output``; only its signature is used, in messages.
    :param func: the callable whose arguments are checked.
    :param check_level: compared against STRICT_CHECK_LEVEL and WARNING_CHECK_LEVEL to decide how unused keys
        are reported.
    :param output: dict produced by ``prev_func``.
    :param batch_y: dict of targets.
    :raises ValueError: if arguments are missing or duplicated, or if keys are unused in strict mode.
    """
    check_res = _check_arg_dict_list(func, [output, batch_y])
    func_signature = get_func_signature(func)
    prev_func_signature = get_func_signature(prev_func)

    missing_msg = ''
    if check_res.missing:
        missing_template = ("function {} misses argument {}, \n\t only provided with {}(from {}) and "
                            "{}(from target in Dataset).")
        missing_msg = missing_template.format(func_signature, check_res.missing,
                                              list(output.keys()), prev_func_signature,
                                              list(batch_y.keys()))

    unused_msg = ''
    if check_res.unused:
        verb = "are" if len(check_res.unused) > 1 else "is"
        unused_msg = "{} {} not used ".format(check_res.unused, verb)
        unused_msg += "in function {}.\n".format(func_signature)

    duplicated_msg = ''
    if check_res.duplicated:
        if len(check_res.duplicated) > 1:
            duplicated_template = ("duplicated keys {} are detected when calling function {}. "
                                   "\n\tDon't set {} as target and output "
                                   "them in {} at the same time.")
        else:
            duplicated_template = ("duplicated key {} is detected when calling function {}. "
                                   "\n\tDon't set {} as target and output "
                                   "it in {} at the same time.")
        duplicated_msg = duplicated_template.format(check_res.duplicated, func_signature,
                                                    check_res.duplicated, prev_func_signature)

    # Fixed reporting order: missing, duplicated, unused.
    findings = [msg for msg in (missing_msg, duplicated_msg, unused_msg) if msg]
    if not findings:
        return

    errors = []
    if len(findings) == 1:
        if unused_msg:
            # Unused keys alone are an error only in strict mode, a warning in warning mode.
            if check_level == STRICT_CHECK_LEVEL:
                # TODO: a custom Error type may be needed here
                errors.append(unused_msg)
            elif check_level == WARNING_CHECK_LEVEL:
                warnings.warn(unused_msg.strip())
        else:
            errors.extend(findings)
    else:
        # Several findings: number them. The unused message (always last) is dropped unless strict.
        if unused_msg and not check_level == STRICT_CHECK_LEVEL:
            findings = findings[:-1]
        order_words = ['Firstly', 'Secondly', 'Thirdly']
        for word, msg in zip(order_words, findings):
            errors.append('{}, {}'.format(word, msg))

    if errors:
        raise ValueError('\n' + '\n'.join(errors))
Ejemplo n.º 7
0
def _check_code(dataset,
                model,
                batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                check_level=WARNING_CHECK_LEVEL):
    """Dry-run a few batches through ``model`` to surface interface mistakes early.

    For up to DEFAULT_CHECK_NUM_BATCH batches of ``dataset`` the model is run forward, ``model.get_loss`` is
    computed and back-propagated, and gradients are zeroed. On the first batch the argument names and the
    type/size of the loss are validated. If ``dev_data`` is given, predictions over a few of its batches are
    collected and passed to ``model.evaluate``.

    :param dataset: data used for the forward/loss check.
    :param model: object providing ``forward`` and ``get_loss``; ``evaluate`` (and optionally ``predict``)
        when ``dev_data`` is given.
    :param batch_size: batch size used for the check.
    :param dev_data: data used for the evaluation check; None skips that check.
    :param check_level: forwarded to ``_check_forward_error`` and ``_check_loss_evaluate``.
    :raises AttributeError: if ``model`` lacks ``get_loss``, or lacks ``evaluate`` while ``dev_data`` is given.
    :raises ValueError: if the loss is not a ``torch.Tensor`` or is not a scalar.
    :raises AssertionError: if forward/predict/evaluate does not return a dict.
    """
    # check the get_loss method
    model_name = model.__class__.__name__
    if not hasattr(model, 'get_loss'):
        raise AttributeError(
            "{} has to have a 'get_loss' function.".format(model_name))

    batch = Batch(dataset=dataset,
                  batch_size=batch_size,
                  sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _syn_model_data(model, batch_x, batch_y)
        # forward check
        if batch_count == 0:
            _check_forward_error(model_func=model.forward,
                                 check_level=check_level,
                                 batch_x=batch_x)

        refined_batch_x = _build_args(model.forward, **batch_x)
        output = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        assert isinstance(
            output, dict), "The return value of {} should be dict.".format(
                func_signature)

        # loss check
        if batch_count == 0:
            _check_loss_evaluate(prev_func=model.forward,
                                 func=model.get_loss,
                                 check_level=check_level,
                                 output=output,
                                 batch_y=batch_y)
        loss_input = _build_args(model.get_loss, **output, **batch_y)
        loss = model.get_loss(**loss_input)

        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise ValueError(
                    "The return value of {}.get_loss() should be torch.Tensor, but {} got."
                    .format(model_name, type(loss)))
            if len(loss.size()) != 0:
                raise ValueError(
                    "The size of return value of {}.get_loss() is {}, should be torch.size([])"
                    .format(model_name, loss.size()))
        loss.backward()
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        if not hasattr(model, 'evaluate'):
            raise AttributeError(
                "{} has to have a 'evaluate' function to do evaluation. Or set "
                "dev_data to 'None'.".format(model_name))
        outputs, truths = defaultdict(list), defaultdict(list)
        # The evaluation check runs on the dev data that was passed in, not on the training data.
        dev_batch = Batch(dataset=dev_data,
                          batch_size=batch_size,
                          sampler=SequentialSampler())
        # Prefer `predict` when the model has one; chosen once so it is bound even if dev_batch is empty.
        prev_func = model.predict if hasattr(model, 'predict') else model.forward
        with torch.no_grad():
            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
                _syn_model_data(model, batch_x, batch_y)

                refined_batch_x = _build_args(prev_func, **batch_x)
                output = prev_func(**refined_batch_x)
                assert isinstance(
                    output,
                    dict), "The return value of {} should be dict.".format(
                        get_func_signature(prev_func))
                for k, v in output.items():
                    outputs[k].append(v)
                for k, v in batch_y.items():
                    truths[k].append(v)
                # Same batch limit as the training loop above.
                if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
                    break
            # Flatten the per-batch lists into one iterable per key.
            for k, v in outputs.items():
                outputs[k] = itertools.chain(*v)
            for k, v in truths.items():
                truths[k] = itertools.chain(*v)
            _check_loss_evaluate(prev_func=prev_func,
                                 func=model.evaluate,
                                 check_level=check_level,
                                 output=outputs,
                                 batch_y=truths)
            refined_input = _build_args(model.evaluate, **outputs, **truths)
            metrics = model.evaluate(**refined_input)
            func_signature = get_func_signature(model.evaluate)
            assert isinstance(metrics, dict), "The return value of {} should be dict.". \
                format(func_signature)
Ejemplo n.º 8
0
def _check_code(dataset,
                model,
                losser,
                metrics,
                batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                metric_key=None,
                check_level=0):
    """Dry-run a few batches through ``model``, ``losser`` and ``metrics`` to surface interface mistakes early.

    For up to DEFAULT_CHECK_NUM_BATCH batches of ``dataset`` the model is run forward and ``losser`` is
    evaluated and back-propagated. On the first batch the input/target fields are printed, the forward
    arguments are checked and the type/size of the loss is validated. If ``dev_data`` is given, a ``Tester``
    is run over a slice of it and its results are passed to ``_check_eval_results``.

    :param dataset: data used for the forward/loss check.
    :param model: the model; its first parameter determines the device batches are moved to.
    :param losser: callable taking ``(pred_dict, batch_y)`` and returning the loss.
    :param metrics: metrics handed to the ``Tester`` and to ``_check_eval_results``.
    :param batch_size: batch size used for the check.
    :param dev_data: data used for the evaluation check; None skips that check.
    :param metric_key: forwarded to ``_check_eval_results``.
    :param check_level: forwarded to ``_check_forward_error`` and ``_check_loss_evaluate``.
    :raises RuntimeError: if the first batch has no input field.
    :raises TypeError: if the model does not return a dict or the loss is not a ``torch.Tensor``.
    :raises ValueError: if the loss is not a scalar.
    """
    # check the get_loss method
    model_device = next(model.parameters()).device

    batch = Batch(dataset=dataset,
                  batch_size=batch_size,
                  sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _move_dict_value_to_device(batch_x, batch_y, device=model_device)
        # forward check
        if batch_count == 0:
            info_str = ""
            input_fields = _get_value_info(batch_x)
            target_fields = _get_value_info(batch_y)
            if len(input_fields) > 0:
                info_str += "input fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(input_fields)
                info_str += '\n'
            else:
                raise RuntimeError("There is no input field.")
            if len(target_fields) > 0:
                info_str += "target fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(target_fields)
                info_str += '\n'
            else:
                info_str += 'There is no target field.'
            print(info_str)
            _check_forward_error(forward_func=model.forward,
                                 dataset=dataset,
                                 batch_x=batch_x,
                                 check_level=check_level)

        refined_batch_x = _build_args(model.forward, **batch_x)
        pred_dict = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        if not isinstance(pred_dict, dict):
            raise TypeError(
                f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`."
            )

        # loss check
        try:
            loss = losser(pred_dict, batch_y)
            # check loss output
            if batch_count == 0:
                if not isinstance(loss, torch.Tensor):
                    raise TypeError(
                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
                if len(loss.size()) != 0:
                    raise ValueError(
                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
                        f"should be torch.size([])")
            loss.backward()
        except CheckError as e:
            # TODO: another error raised if CheckError caught
            pre_func_signature = get_func_signature(model.forward)
            _check_loss_evaluate(prev_func_signature=pre_func_signature,
                                 func_signature=e.func_signature,
                                 check_res=e.check_res,
                                 pred_dict=pred_dict,
                                 target_dict=batch_y,
                                 dataset=dataset,
                                 check_level=check_level)
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        # Check evaluation on the dev data that was passed in, not on the training data.
        # NOTE(review): assumes dev_data supports the same slicing as `dataset` - confirm.
        tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH],
                        model=model,
                        metrics=metrics,
                        batch_size=batch_size,
                        verbose=-1)
        evaluate_results = tester.test()
        _check_eval_results(metrics=evaluate_results,
                            metric_key=metric_key,
                            metric_list=metrics)
Ejemplo n.º 9
0
    def __call__(self, pred_dict, target_dict, check=False):
        """Map the entries of ``pred_dict`` and ``target_dict`` onto the arguments of ``self.get_loss`` and call it.

        If ``self._fast_param_map`` returns a non-empty result, ``self.get_loss`` is called with it directly.
        Otherwise the mapping in ``self.param_map`` is validated against the signature of ``self.get_loss``
        (on the first call, or on every call made with ``check=True``), entries are renamed through the
        reversed ``self.param_map`` and passed through ``_build_args`` before ``self.get_loss`` is called.

        :param pred_dict: A dict from forward function of the network.
        :param target_dict: A dict from DataSet.batch_y.
        :param check: Boolean. Force to check the mapping functions when it is running, even after an
            earlier call has already passed the check.
        :return: the value returned by ``self.get_loss``.
        :raises NameError: if ``self.param_map`` names an argument ``self.get_loss`` does not have.
        :raises CheckError: if the check finds missing or duplicated arguments.
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            loss = self.get_loss(**fast_param)
            return loss

        # `check` forces the validation that is otherwise only done until the first successful call.
        need_check = check or not self._checked

        if need_check:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.get_loss)
            func_args = {arg for arg in func_spect.args if arg != 'self'}
            for func_arg in self.param_map:
                if func_arg not in func_args:
                    raise NameError(
                        f"`{func_arg}` not in {get_func_signature(self.get_loss)}."
                    )

            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {
                input_arg: func_arg
                for func_arg, input_arg in self.param_map.items()
            }

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(
                list(pred_dict.keys()) + list(target_dict.keys())):
            is_mapped = input_arg in self._reverse_param_map
            mapped_arg = self._reverse_param_map[input_arg] if is_mapped else input_arg
            in_pred = input_arg in pred_dict
            in_target = input_arg in target_dict
            if in_pred:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
            if in_target:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
            # Only keys known to the reversed param_map are reported when both dicts provide them.
            if is_mapped and in_pred and in_target:
                duplicated.append(input_arg)

        # missing
        if need_check:
            check_res = _check_arg_dict_list(
                self.get_loss, [mapped_pred_dict, mapped_target_dict])
            # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                                                        f"in `{self.__class__.__name__}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(
                                     self.get_loss))

        # Merge before unpacking: a key present in both dicts that was not reported as duplicated above
        # would otherwise make the call fail with "got multiple values for keyword argument".
        merged_args = {**mapped_pred_dict, **mapped_target_dict}
        refined_args = _build_args(self.get_loss, **merged_args)

        loss = self.get_loss(**refined_args)
        self._checked = True

        return loss