def test(self):
    """Evaluate the model on ``self.data`` and collect every metric's result.

    The model is switched into test mode for the duration of the call and is
    switched back afterwards, even when the evaluation raises.

    :return: dict mapping each metric's class name to the dict returned by
        that metric's ``get_metric()``.
    :raises TypeError: if the predict function, or a metric's ``get_metric``,
        returns something that is not a ``dict``.
    """
    network = self._model
    # turn on the testing mode; clean up the history
    self._mode(network, is_test=True)
    try:
        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
        eval_results = {}
        try:
            # No gradients are needed while evaluating.
            with torch.no_grad():
                # First pass: feed every batch to every metric.
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    pred_dict = self._data_forward(self._predict_func, batch_x)
                    if not isinstance(pred_dict, dict):
                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
                                        f"must be `dict`, got {type(pred_dict)}.")
                    for metric in self.metrics:
                        metric(pred_dict, batch_y)
                # Second pass: read the aggregated result of each metric.
                for metric in self.metrics:
                    eval_result = metric.get_metric()
                    if not isinstance(eval_result, dict):
                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
                                        f"`dict`, got {type(eval_result)}")
                    # Results are keyed by class name, so two metrics of the
                    # same class overwrite each other here.
                    metric_name = metric.__class__.__name__
                    eval_results[metric_name] = eval_result
        except CheckError as e:
            # Hand the failed check over to the shared diagnostic helper,
            # using the most recent prediction/target pair as context.
            prev_func_signature = get_func_signature(self._predict_func)
            _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=self.data, check_level=0)

        if self.verbose >= 1:
            print("[tester] \n{}".format(self._format_eval_results(eval_results)))
        return eval_results
    finally:
        # Always leave test mode, including when an exception propagates,
        # so a failed evaluation does not leave the model in test mode.
        self._mode(network, is_test=False)
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None,
                check_level=0):
    """Dry-run a few batches through model, loss and metrics to catch wiring errors early.

    At most ``DEFAULT_CHECK_NUM_BATCH`` batches of ``dataset`` are run through
    ``model`` and ``losser`` (forward and backward). When ``dev_data`` is given,
    a short evaluation is run on a slice of it as well.

    :param dataset: data the forward/loss check batches are drawn from.
    :param model: model under check; it must have at least one parameter,
        because the device is read from its first parameter.
    :param losser: callable invoked as ``losser(pred_dict, batch_y)``.
    :param metrics: metrics handed to ``Tester`` and ``_check_eval_results``.
    :param batch_size: batch size used for the check.
    :param dev_data: optional evaluation data; the evaluation check is skipped
        when it is ``None``.
    :param metric_key: forwarded to ``_check_eval_results``.
    :param check_level: forwarded to ``_check_forward_error`` and
        ``_check_loss_evaluate``.
    :raises RuntimeError: if the first batch has no input field.
    :raises TypeError: if the model output is not a ``dict`` or the loss of the
        first batch is not a ``torch.Tensor``.
    :raises ValueError: if the loss of the first batch is not a scalar tensor.
    """
    # check the get_loss method
    model_device = next(model.parameters()).device

    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _move_dict_value_to_device(batch_x, batch_y, device=model_device)
        # forward check
        if batch_count == 0:
            # Describe the fields of the first batch so the user can see what
            # is actually fed to the model.
            info_str = ""
            input_fields = _get_value_info(batch_x)
            target_fields = _get_value_info(batch_y)
            if len(input_fields) > 0:
                info_str += "input fields after batch(if batch size is {}):\n".format(batch_size)
                info_str += "\n".join(input_fields)
                info_str += '\n'
            else:
                raise RuntimeError("There is no input field.")
            if len(target_fields) > 0:
                info_str += "target fields after batch(if batch size is {}):\n".format(batch_size)
                info_str += "\n".join(target_fields)
                info_str += '\n'
            else:
                info_str += 'There is no target field.'
            print(info_str)
            _check_forward_error(forward_func=model.forward, dataset=dataset,
                                 batch_x=batch_x, check_level=check_level)

        # Keep only the arguments that model.forward accepts.
        refined_batch_x = _build_args(model.forward, **batch_x)
        pred_dict = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        if not isinstance(pred_dict, dict):
            raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")

        # loss check
        try:
            loss = losser(pred_dict, batch_y)
            # check loss output (type and shape are validated on the first batch only)
            if batch_count == 0:
                if not isinstance(loss, torch.Tensor):
                    raise TypeError(
                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
                if len(loss.size()) != 0:
                    raise ValueError(
                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
                        f"should be torch.size([])")
            loss.backward()
        except CheckError as e:
            # TODO: another error raised if CheckError caught
            pre_func_signature = get_func_signature(model.forward)
            _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
                                 check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                 dataset=dataset, check_level=check_level)
        # Discard the gradients produced by the check.
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        # Evaluate on a slice of dev_data itself (not the training dataset) so
        # that problems specific to the dev set are caught.
        # NOTE(review): assumes dev_data supports slicing the same way
        # `dataset` does - confirm against callers.
        tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                        batch_size=batch_size, verbose=-1)
        evaluate_results = tester.test()
        _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)