Example #1
def test(sentence, model):
    chars = []
    # data__ (a vocabulary-like object) and dict_labels (id -> label map) are globals defined elsewhere
    for sen in tqdm(sentence.split()):
        idx = data__.to_index(sen)
        if idx == 1:  # index 1 presumably marks the unknown-word id; retry lowercased
            idx = data__.to_index(sen.lower())
        chars.append(idx)

    seq_len = len(chars)
    target = [0] * seq_len

    target = torch.tensor([target], dtype=torch.long)
    seq_len = torch.tensor([seq_len], dtype=torch.long)
    chars = torch.tensor([chars], dtype=torch.long)

    z = {'target': target, 'seq_len': seq_len, 'chars': chars}
    x = _build_args(model.predict, **z)
    result = model.predict(**x)
    nx = result['pred'][0].numpy()
    result = []
    for i, j in zip(sentence.split(), nx):
        tmp = {}
        tmp['text'] = i
        tmp['value'] = dict_labels[j]
        result.append(tmp)
    print(result)
    return result
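A self-contained sketch of the single-sentence batching used above: the index list for one sentence is wrapped in an extra list so the model sees a batch of size 1 (the values here are made-up token indices).

import torch

chars = [5, 9, 2]                                # assumed token indices for one sentence
batch = torch.tensor([chars], dtype=torch.long)  # shape (1, seq_len)
print(batch.shape)                               # torch.Size([1, 3])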
def get_predictions(pred_model, input_data, batch_size, num_workers=4):
    texts = [[vocabs['char'].to_word(c) for c in sample['chars']] for sample in input_data]
    seq_lens = [sample['seq_len'] for sample in input_data]
    pred_model.to(device)
    sampler = SequentialSampler()
    data_iterator = DataSetIter(dataset=input_data, batch_size=batch_size, sampler=sampler,
                                num_workers=num_workers)
    with torch.no_grad():
        preds, golds = [], []
        pred_model.eval()

        for batch_x, batch_y in data_iterator:
            _move_dict_value_to_device(batch_x, batch_y, device=device)
            x = _build_args(pred_model.forward, **batch_x)
            y = pred_model.forward(**x)
            preds.extend(list(map(list, y['pred'].cpu().numpy())))
            golds.extend(list(map(list, batch_y['target'].cpu().numpy())))
    pred_seqs = [[vocabs['label'].to_word(y_) for y_ in pred] for pred in preds]
    gold_seqs = [[vocabs['label'].to_word(y_) for y_ in gold] for gold in golds]
    case_result = []
    for pred_seq, gold_seq, word_seq, seq_len in zip(pred_seqs, gold_seqs, texts, seq_lens):
        pred_seq = pred_seq[:seq_len]
        gold_seq = gold_seq[:seq_len]
        case_result.append((''.join(word_seq), extract_kvpairs_in_bmoes(gold_seq, word_seq),
                           extract_kvpairs_in_bmoes(pred_seq, word_seq)))

    # output for case study
    os.makedirs(f'../output/case_study/{args.dataset}', exist_ok=True)
    out_path = (f'../output/case_study/{args.dataset}/{args.dataset}_bert{args.use_bert}'
                f'_scheme{args.new_tag_scheme}_ple{args.ple_channel_num}'
                f'_plstm{int(args.use_ple_lstm)}_trainrate{args.train_dataset_rate}.casestudy')
    with open(out_path, 'w', encoding='utf8') as fout:  # close the file when done
        for word_seq, gold_pair, pred_pair in case_result:
            fout.write(word_seq + '\n' + str(gold_pair) + '\n' + str(pred_pair) + '\n\n')
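Every example in this collection routes keyword arguments through _build_args. A minimal sketch of its assumed behavior (not the library's exact implementation): keep only the keyword arguments that the target function's signature accepts, unless the function takes **kwargs, in which case everything is passed through.

import inspect

def build_args_sketch(func, **kwargs):
    spec = inspect.getfullargspec(func)
    if spec.varkw is not None:  # func accepts **kwargs: pass everything through
        return kwargs
    needed = set(spec.args)
    return {name: value for name, value in kwargs.items() if name in needed}

def predict(chars, seq_len):
    return chars, seq_len

print(build_args_sketch(predict, chars=[1, 2], seq_len=2, target=[0, 0]))
# {'chars': [1, 2], 'seq_len': 2} -- the unused 'target' key is dropped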
Example #3
    def test(self):
        # turn on the testing mode; clean up the history
        network = self._model
        self.mode(network, is_test=True)
        self.eval_history.clear()
        output, truths = defaultdict(list), defaultdict(list)
        data_iterator = Batch(self.data,
                              self.batch_size,
                              sampler=RandomSampler(),
                              as_numpy=False)

        with torch.no_grad():
            for batch_x, batch_y in data_iterator:
                prediction = self.data_forward(network, batch_x)
                assert isinstance(prediction, dict)
                for k, v in prediction.items():
                    output[k].append(v)
                for k, v in batch_y.items():
                    truths[k].append(v)
            for k, v in output.items():
                output[k] = itertools.chain(*v)
            for k, v in truths.items():
                truths[k] = itertools.chain(*v)
            args = _build_args(self._evaluator, **output, **truths)
            eval_results = self._evaluator(**args)
        print("[tester] {}".format(self.print_eval_results(eval_results)))
        self.mode(network, is_test=False)
        return eval_results
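A runnable sketch of the flattening step at the end of test(): each key maps to a list of per-batch sequences, and itertools.chain(*v) concatenates them into one stream.

import itertools
from collections import defaultdict

output = defaultdict(list)
output['pred'].append([0, 1])  # predictions from batch 1
output['pred'].append([1, 1])  # predictions from batch 2
for k, v in output.items():
    output[k] = itertools.chain(*v)
print(list(output['pred']))    # [0, 1, 1, 1]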
Example #4
    def process(self, dataset):
        self.model.eval()
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler())

        batch_output = defaultdict(list)
        if hasattr(self.model, "predict"):
            predict_func = self.model.predict
        else:
            predict_func = self.model.forward
        with torch.no_grad():
            for batch_x, _ in data_iterator:
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)
                seq_lens = batch_x[self.seq_len_field_name].tolist()

                for key, value in prediction.items():
                    tmp_batch = []
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                        batch_output[key].extend(value.tolist())
                    else:
                        for idx, seq_len in enumerate(seq_lens):
                            tmp_batch.append(value[idx, :seq_len])
                        batch_output[key].extend(tmp_batch)
                if self.seq_len_field_name not in prediction:
                    batch_output[self.seq_len_field_name].extend(seq_lens)

        # TODO: with the current implementation, downstream processors need to know the keys of the model's output
        for field_name, fields in batch_output.items():
            dataset.add_field(field_name, fields, is_input=True, is_target=False)

        return dataset
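A self-contained numpy sketch of the shape dispatch inside process(): 1-D or (N, 1) outputs are one value per sample and are extended directly, while higher-rank outputs are truncated to each sample's true sequence length so padding is discarded.

import numpy as np

value = np.arange(12).reshape(2, 6)   # (batch, max_len) sequence output with padding
seq_lens = [4, 6]                     # true lengths of the two samples
trimmed = [value[idx, :seq_len] for idx, seq_len in enumerate(seq_lens)]
print([t.tolist() for t in trimmed])  # [[0, 1, 2, 3], [6, 7, 8, 9, 10, 11]]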
Example #5
    def _data_forward(self, network, x):
        x = _build_args(network.forward, **x)
        y = network(**x)
        if not isinstance(y, dict):
            raise TypeError(
                f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}."
            )
        return y
Example #6
    def get_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.

        :param predict: prediction label vector
        :param truth: ground truth label vector
        :return: a scalar
        """
        assert isinstance(predict, dict) and isinstance(truth, dict)
        args = _build_args(self.loss_func, **predict, **truth)
        return self.loss_func(**args)
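A hedged usage sketch, assuming self.loss_func is torch.nn.functional.cross_entropy (which takes `input` and `target`): merging the two dicts and filtering them through _build_args leaves exactly the arguments the loss function declares.

import torch
import torch.nn.functional as F

predict = {'input': torch.randn(4, 3)}          # logits for 4 samples, 3 classes
truth = {'target': torch.tensor([0, 2, 1, 0])}  # gold labels
print(F.cross_entropy(**predict, **truth))      # 0-dim loss tensor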
Example #7
    def predict(self, data: DataSet, seq_len_field_name=None):
        r"""用已经训练好的模型进行inference.

        :param fastNLP.DataSet data: 待预测的数据集
        :param str seq_len_field_name: 表示序列长度信息的field名字
        :return: dict dict里面的内容为模型预测的结果
        """
        if not isinstance(data, DataSet):
            raise ValueError("Only Dataset class is allowed, not {}.".format(
                type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(
                seq_len_field_name, data))

        prev_training = self.network.training
        self.network.eval()
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
        data_iterator = DataSetIter(data,
                                    batch_size=self.batch_size,
                                    sampler=SequentialSampler(),
                                    as_numpy=False)

        if hasattr(self.network, "predict"):
            predict_func = self.network.predict
        else:
            predict_func = self.network.forward

        with torch.no_grad():
            for batch_x, _ in data_iterator:
                _move_dict_value_to_device(batch_x, _, device=network_device)
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)

                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2
                                                 and value.shape[1] == 1):
                        batch_output[key].extend(value.tolist())
                    else:
                        if seq_len_field_name is not None:
                            tmp_batch = []
                            for idx, seq_len in enumerate(seq_lens):
                                tmp_batch.append(value[idx, :seq_len])
                            batch_output[key].extend(tmp_batch)
                        else:
                            batch_output[key].append(value)

        self.network.train(prev_training)
        return batch_output
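A self-contained sketch of the train/eval bookkeeping in predict() above: the previous training flag is saved, eval() is entered for inference, and the original mode is restored afterwards, so a caller mid-training is not left in eval mode.

import torch.nn as nn

net = nn.Linear(2, 2)
prev_training = net.training  # True by default
net.eval()
# ... inference under torch.no_grad() would happen here ...
net.train(prev_training)
print(net.training)           # True: original mode restored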
Example #8
    def predict(self, data, seq_len_field_name=None):
        """Perform inference using the trained model.

        :param data: a DataSet object.
        :param str seq_len_field_name: field name indicating sequence lengths
        :return: list of batch outputs
        """
        if not isinstance(data, DataSet):
            raise ValueError("Only Dataset class is allowed, not {}.".format(
                type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(
                seq_len_field_name, data))

        self.network.eval()
        batch_output = defaultdict(list)
        data_iterator = Batch(data,
                              batch_size=self.batch_size,
                              sampler=SequentialSampler(),
                              as_numpy=False,
                              prefetch=False)

        if hasattr(self.network, "predict"):
            predict_func = self.network.predict
        else:
            predict_func = self.network.forward

        with torch.no_grad():
            for batch_x, _ in data_iterator:
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)

                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (len(value.shape) == 2
                                                 and value.shape[1] == 1):
                        batch_output[key].extend(value.tolist())
                    else:
                        if seq_len_field_name is not None:
                            tmp_batch = []
                            for idx, seq_len in enumerate(seq_lens):
                                tmp_batch.append(value[idx, :seq_len])
                            batch_output[key].extend(tmp_batch)
                        else:
                            batch_output[key].append(value)

        return batch_output
Example #9
    def predict(self, data: DataSet, seq_len_field_name=None):
        r"""
        """
        if not isinstance(data, DataSet):
            raise ValueError(
                "Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(
                seq_len_field_name, data))

        self.network.eval()  # self.network.module for multi-GPU
        network_device = _get_model_device(self.network)
        batch_output = defaultdict(list)
        data_iterator = DataSetIter(
            data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

        try:
            predict_func = self.network.predict
        except AttributeError:  # e.g. the model is wrapped in DataParallel, so predict lives on .module
            predict_func = self.network.module.predict

        with torch.no_grad():
            for batch_x, _ in tqdm(data_iterator, total=len(data_iterator)):
                _move_dict_value_to_device(batch_x, _, device=network_device)
                refined_batch_x = _build_args(predict_func, **batch_x)
                prediction = predict_func(**refined_batch_x)
                if seq_len_field_name is not None:
                    seq_lens = batch_x[seq_len_field_name].tolist()

                for key, value in prediction.items():
                    value = value.cpu().numpy()
                    if len(value.shape) == 1 or (
                            len(value.shape) == 2 and value.shape[1] == 1):
                        batch_output[key].extend(value.tolist())
                    else:
                        if seq_len_field_name is not None:
                            tmp_batch = []
                            for idx, seq_len in enumerate(seq_lens):
                                tmp_batch.append(value[idx, :seq_len])
                            batch_output[key].extend(tmp_batch)
                        else:
                            batch_output[key].append(value)
        return batch_output
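A minimal sketch of the DataParallel fallback above: a wrapped model exposes custom methods such as predict only on .module, so the attribute lookup fails on the wrapper and falls back (the Tagger class here is a made-up stand-in).

import torch.nn as nn

class Tagger(nn.Module):
    def forward(self, x):
        return x
    def predict(self, x):
        return x

wrapped = nn.DataParallel(Tagger())
try:
    predict_func = wrapped.predict
except AttributeError:  # ModuleAttributeError is a subclass of AttributeError
    predict_func = wrapped.module.predict
print(predict_func('ok'))  # 'ok'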
Example #10
    def get_loss(self, inputs, targets, hidden, dags):
        """Computes the loss for the same batch for M models.

        This amounts to an estimate of the loss, which is turned into an
        estimate for the gradients of the shared model.
        """
        if not isinstance(dags, list):
            dags = [dags]

        loss = 0
        for dag in dags:
            self.shared.setDAG(dag)
            inputs = _build_args(self.shared.forward, **inputs)
            inputs['hidden'] = hidden
            result = self.shared(**inputs)
            output, hidden, extra_out = (result['pred'], result['hidden'],
                                         result['extra_out'])

            self.callback_manager.on_loss_begin(targets, result)
            sample_loss = self._compute_loss(result, targets)
            loss += sample_loss

        assert len(dags) == 1, 'there are multiple `hidden` states for multiple `dags`'
        return loss, hidden, extra_out
Example #11
    def __call__(self, pred_dict, target_dict):
        """

        This method will call self.evaluate method.
        Before calling self.evaluate, it will first check the validity of output_dict, target_dict
            (1) whether self.evaluate has varargs, which is not supported.
            (2) whether params needed by self.evaluate is not included in output_dict,target_dict.
            (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict
            (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
        Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
            target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
            will be conducted.)
        This function also support _fast_param_map.
        :param pred_dict: usually the output of forward or prediction function
        :param target_dict: usually features set as target..
        :return:
        """
        if not callable(self.evaluate):
            raise TypeError(
                f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}."
            )

        fast_param = self._fast_param_map(pred_dict=pred_dict,
                                          target_dict=target_dict)
        if fast_param:
            self.evaluate(**fast_param)
            return

        if not self._checked:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
            func_args = set([arg for arg in func_spect.args if arg != 'self'])
            for func_arg, input_arg in self.param_map.items():
                if func_arg not in func_args:
                    raise NameError(
                        f"`{func_arg}` not in {get_func_signature(self.evaluate)}."
                    )

            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[
                        arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {
                input_arg: func_arg
                for func_arg, input_arg in self.param_map.items()
            }

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(
                list(pred_dict.keys()) + list(target_dict.keys())):
            not_duplicate_flag = 0
            if input_arg in self._reverse_param_map:
                mapped_arg = self._reverse_param_map[input_arg]
                not_duplicate_flag += 1
            else:
                mapped_arg = input_arg
            if input_arg in pred_dict:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
                not_duplicate_flag += 1
            if input_arg in target_dict:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
                not_duplicate_flag += 1
            if not_duplicate_flag == 3:
                duplicated.append(input_arg)

        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(
                self.evaluate, [mapped_pred_dict, mapped_target_dict])
            # only check missing.
            # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                                                        f"in `{self.__class__.__name__}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated or check_res.varargs:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(
                                     self.evaluate))
        refined_args = _build_args(self.evaluate, **mapped_pred_dict,
                                   **mapped_target_dict)

        self.evaluate(**refined_args)
        self._checked = True

        return
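A runnable sketch of the reverse-mapping step above (assumed semantics): param_map maps evaluate-argument names to input keys, so the reverse map renames incoming keys to the names evaluate expects.

param_map = {'pred': 'output', 'target': 'label'}         # evaluate arg -> input key
reverse_param_map = {v: k for k, v in param_map.items()}  # input key -> evaluate arg

pred_dict = {'output': [1, 0, 1]}
target_dict = {'label': [1, 1, 1]}
mapped = {reverse_param_map.get(k, k): v
          for d in (pred_dict, target_dict) for k, v in d.items()}
print(mapped)  # {'pred': [1, 0, 1], 'target': [1, 1, 1]}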
Example #12
def _check_code(dataset,
                model,
                batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                check_level=WARNING_CHECK_LEVEL):
    # check the get_loss method
    model_name = model.__class__.__name__
    if not hasattr(model, 'get_loss'):
        raise AttributeError(
            "{} has to have a 'get_loss' function.".format(model_name))

    batch = Batch(dataset=dataset,
                  batch_size=batch_size,
                  sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _syn_model_data(model, batch_x, batch_y)
        # forward check
        if batch_count == 0:
            _check_forward_error(model_func=model.forward,
                                 check_level=check_level,
                                 batch_x=batch_x)

        refined_batch_x = _build_args(model.forward, **batch_x)
        output = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        assert isinstance(
            output, dict), "The return value of {} should be dict.".format(
                func_signature)

        # loss check
        if batch_count == 0:
            _check_loss_evaluate(prev_func=model.forward,
                                 func=model.get_loss,
                                 check_level=check_level,
                                 output=output,
                                 batch_y=batch_y)
        loss_input = _build_args(model.get_loss, **output, **batch_y)
        loss = model.get_loss(**loss_input)

        # check loss output
        if batch_count == 0:
            if not isinstance(loss, torch.Tensor):
                raise ValueError(
                    "The return value of {}.get_loss() should be torch.Tensor, but got {}."
                    .format(model_name, type(loss)))
            if len(loss.size()) != 0:
                raise ValueError(
                    "The size of the return value of {}.get_loss() is {}, should be torch.Size([])."
                    .format(model_name, loss.size()))
        loss.backward()
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        if not hasattr(model, 'evaluate'):
            raise AttributeError(
                "{} has to have an 'evaluate' function to do evaluation. Or set "
                "dev_data to 'None'.".format(model_name))
        outputs, truths = defaultdict(list), defaultdict(list)
        dev_batch = Batch(dataset=dataset,
                          batch_size=batch_size,
                          sampler=SequentialSampler())
        with torch.no_grad():
            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
                _syn_model_data(model, batch_x, batch_y)

                if hasattr(model, 'predict'):
                    refined_batch_x = _build_args(model.predict, **batch_x)
                    prev_func = model.predict
                    output = prev_func(**refined_batch_x)
                    func_signature = get_func_signature(model.predict)
                    assert isinstance(
                        output,
                        dict), "The return value of {} should be dict.".format(
                            func_signature)
                else:
                    refined_batch_x = _build_args(model.forward, **batch_x)
                    prev_func = model.forward
                    output = prev_func(**refined_batch_x)
                for k, v in output.items():
                    outputs[k].append(v)
                for k, v in batch_y.items():
                    truths[k].append(v)
                if batch_count + 1 > DEFAULT_CHECK_NUM_BATCH:
                    break
            for k, v in outputs.items():
                outputs[k] = itertools.chain(*v)
            for k, v in truths.items():
                truths[k] = itertools.chain(*v)
            _check_loss_evaluate(prev_func=prev_func,
                                 func=model.evaluate,
                                 check_level=check_level,
                                 output=outputs,
                                 batch_y=truths)
            refined_input = _build_args(model.evaluate, **outputs, **truths)
            metrics = model.evaluate(**refined_input)
            func_signature = get_func_signature(model.evaluate)
            assert isinstance(metrics, dict), "The return value of {} should be dict.". \
                format(func_signature)
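A self-contained sketch of the scalar-loss check performed above: get_loss must return a 0-dimensional tensor, i.e. its size is torch.Size([]).

import torch

loss = torch.tensor(0.5)      # what a well-formed get_loss returns
assert isinstance(loss, torch.Tensor)
assert len(loss.size()) == 0  # 0-dim: passes the check
print(loss.size())            # torch.Size([])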
Example #13
    def data_forward(self, network, x):
        x = _build_args(network.forward, **x)
        y = network(**x)
        return y
Example #14
def _check_code(dataset,
                model,
                losser,
                metrics,
                batch_size=DEFAULT_CHECK_BATCH_SIZE,
                dev_data=None,
                metric_key=None,
                check_level=0):
    # check the get_loss method
    model_device = next(model.parameters()).device

    batch = Batch(dataset=dataset,
                  batch_size=batch_size,
                  sampler=SequentialSampler())
    for batch_count, (batch_x, batch_y) in enumerate(batch):
        _move_dict_value_to_device(batch_x, batch_y, device=model_device)
        # forward check
        if batch_count == 0:
            info_str = ""
            input_fields = _get_value_info(batch_x)
            target_fields = _get_value_info(batch_y)
            if len(input_fields) > 0:
                info_str += "input fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(input_fields)
                info_str += '\n'
            else:
                raise RuntimeError("There is no input field.")
            if len(target_fields) > 0:
                info_str += "target fields after batch(if batch size is {}):\n".format(
                    batch_size)
                info_str += "\n".join(target_fields)
                info_str += '\n'
            else:
                info_str += 'There is no target field.'
            print(info_str)
            _check_forward_error(forward_func=model.forward,
                                 dataset=dataset,
                                 batch_x=batch_x,
                                 check_level=check_level)

        refined_batch_x = _build_args(model.forward, **batch_x)
        pred_dict = model(**refined_batch_x)
        func_signature = get_func_signature(model.forward)
        if not isinstance(pred_dict, dict):
            raise TypeError(
                f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`."
            )

        # loss check
        try:
            loss = losser(pred_dict, batch_y)
            # check loss output
            if batch_count == 0:
                if not isinstance(loss, torch.Tensor):
                    raise TypeError(
                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                        f"but got `{type(loss)}`.")
                if len(loss.size()) != 0:
                    raise ValueError(
                        f"The size of the return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
                        f"should be torch.Size([]).")
            loss.backward()
        except CheckError as e:
            # TODO: another error raised if CheckError caught
            pre_func_signature = get_func_signature(model.forward)
            _check_loss_evaluate(prev_func_signature=pre_func_signature,
                                 func_signature=e.func_signature,
                                 check_res=e.check_res,
                                 pred_dict=pred_dict,
                                 target_dict=batch_y,
                                 dataset=dataset,
                                 check_level=check_level)
        model.zero_grad()
        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
            break

    if dev_data is not None:
        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH],
                        model=model,
                        metrics=metrics,
                        batch_size=batch_size,
                        verbose=-1)
        evaluate_results = tester.test()
        _check_eval_results(metrics=evaluate_results,
                            metric_key=metric_key,
                            metric_list=metrics)
Example #15
    def __call__(self, pred_dict, target_dict, check=False):
        """
        :param pred_dict: A dict from forward function of the network.
        :param target_dict: A dict from DataSet.batch_y.
        :param check: Boolean. Force checking of the mapping functions while running.
        :return:
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            loss = self.get_loss(**fast_param)
            return loss

        if not self._checked:
            # 1. check consistence between signature and param_map
            func_spect = inspect.getfullargspec(self.get_loss)
            func_args = set([arg for arg in func_spect.args if arg != 'self'])
            for func_arg, input_arg in self.param_map.items():
                if func_arg not in func_args:
                    raise NameError(
                        f"`{func_arg}` not in {get_func_signature(self.get_loss)}."
                    )

            # 2. only part of the param_map are passed, left are not
            for arg in func_args:
                if arg not in self.param_map:
                    self.param_map[
                        arg] = arg  # This param does not need mapping.
            self._evaluate_args = func_args
            self._reverse_param_map = {
                input_arg: func_arg
                for func_arg, input_arg in self.param_map.items()
            }

        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(
                list(pred_dict.keys()) + list(target_dict.keys())):
            not_duplicate_flag = 0
            if input_arg in self._reverse_param_map:
                mapped_arg = self._reverse_param_map[input_arg]
                not_duplicate_flag += 1
            else:
                mapped_arg = input_arg
            if input_arg in pred_dict:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
                not_duplicate_flag += 1
            if input_arg in target_dict:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
                not_duplicate_flag += 1
            if not_duplicate_flag == 3:
                duplicated.append(input_arg)

        # missing
        if not self._checked:
            check_res = _check_arg_dict_list(
                self.get_loss, [mapped_pred_dict, mapped_target_dict])
            # replace missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                # Don't delete `` in this information, nor add ``
                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                                                        f"in `{self.__class__.__name__}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(
                                     self.get_loss))
        refined_args = _build_args(self.get_loss, **mapped_pred_dict,
                                   **mapped_target_dict)

        loss = self.get_loss(**refined_args)
        self._checked = True

        return loss
Example #16
    def data_forward(self, network, x):
        """A forward pass of the model."""
        x = _build_args(network.forward, **x)
        y = self._predict_func(**x)
        return y
Example #17
    def _data_forward(self, func, x):
        """A forward pass of the model."""
        x = _build_args(func, **x)
        y = func(**x)
        return y