Example #1
    def test(self):
        # turn on the testing mode; clean up the history
        network = self._model
        self._mode(network, is_test=True)
        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
        eval_results = {}
        try:
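            # gradients are not needed while evaluating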
            with torch.no_grad():
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    pred_dict = self._data_forward(self._predict_func, batch_x)
                    if not isinstance(pred_dict, dict):
                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} " 
                                                         f"must be `dict`, got {type(pred_dict)}.")
                    for metric in self.metrics:
                        metric(pred_dict, batch_y)
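                # all batches processed; collect the final value of each metric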
                for metric in self.metrics:
                    eval_result = metric.get_metric()
                    if not isinstance(eval_result, dict):
                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
                                        f"`dict`, got {type(eval_result)}")
                    metric_name = metric.__class__.__name__
                    eval_results[metric_name] = eval_result
        except CheckError as e:
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)

        if self.verbose >= 1:
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
        self._mode(network, is_test=False)
        return eval_results
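
A minimal sketch of how this test() method is typically driven. The import path, AccuracyMetric, and the dev_data/model objects below are illustrative assumptions, not taken from the code above:

from fastNLP import Tester, AccuracyMetric  # assumed import location

# dev_data: a fastNLP DataSet with input and target fields already set
# model:    a torch.nn.Module whose predict/forward returns a dict
tester = Tester(data=dev_data,
                model=model,
                metrics=AccuracyMetric(),  # consumed inside test()
                batch_size=32,
                verbose=1)                 # verbose >= 1 prints the formatted results
eval_results = tester.test()               # e.g. {'AccuracyMetric': {'acc': ...}}
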
Example #2
def _check_code(dataset,
                model,
                losser,
                metrics,
                batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                metric_key=None,
                check_level=0):
    # check the get_loss method
    model_device = next(model.parameters()).device

    batch = Batch(dataset=dataset,
                  batch_size=batch_size,
                  sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _move_dict_value_to_device(batch_x, batch_y, device=model_device)
        # forward check
        if batch_count == 0:
            info_str = ""
            input_fields = _get_value_info(batch_x)
            target_fields = _get_value_info(batch_y)
            if len(input_fields) > 0:
                info_str += "input fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(input_fields)
                info_str += '\n'
            else:
                raise RuntimeError("There is no input field.")
            if len(target_fields) > 0:
                info_str += "target fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(target_fields)
                info_str += '\n'
            else:
                info_str += 'There is no target field.'
            print(info_str)
            _check_forward_error(forward_func=model.forward,
                                 dataset=dataset,
                                 batch_x=batch_x,
                                 check_level=check_level)

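        # keep only the entries of batch_x that match model.forward's signature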
        refined_batch_x = _build_args(model.forward, **batch_x)
        pred_dict = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        if not isinstance(pred_dict, dict):
            raise TypeError(
                f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`."
            )

        # loss check
        try:
            loss = losser(pred_dict, batch_y)
            # check loss output
            if batch_count == 0:
                if not isinstance(loss, torch.Tensor):
                    raise TypeError(
                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
                if len(loss.size()) != 0:
                    raise ValueError(
                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
                        f"should be torch.size([])")
            loss.backward()
        except CheckError as e:
            # TODO: another error raised if CheckError caught
            pre_func_signature = get_func_signature(model.forward)
            _check_loss_evaluate(prev_func_signature=pre_func_signature,
                                 func_signature=e.func_signature,
                                 check_res=e.check_res,
                                 pred_dict=pred_dict,
                                 target_dict=batch_y,
                                 dataset=dataset,
                                 check_level=check_level)
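        # discard the gradients produced by this check pass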
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

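    # when dev data is provided, dry-run the metrics via a Tester on a small slice of the dataset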
    if dev_data is not None:
        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH],
                        model=model,
                        metrics=metrics,
                        batch_size=batch_size,
                        verbose=-1)
        evaluate_results = tester.test()
        _check_eval_results(metrics=evaluate_results,
                            metric_key=metric_key,
                            metric_list=metrics)
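
A hedged sketch of how _check_code might be called before training starts. The loss and metric objects, the data sets, and the keyword values below are illustrative assumptions, not taken from the code above:

from fastNLP import CrossEntropyLoss, AccuracyMetric  # assumed import location

_check_code(dataset=train_data,           # DataSet providing batch_x / batch_y
            model=model,                  # torch.nn.Module whose forward returns a dict
            losser=CrossEntropyLoss(),    # loss object exposing get_loss()
            metrics=[AccuracyMetric()],   # metric list, also handed to the Tester
            dev_data=dev_data,            # optional: triggers the evaluation dry run
            metric_key="acc",             # key inspected by _check_eval_results (assumed)
            check_level=0)                # strictness of the diagnostics (assumed default)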