コード例 #1
0
 def compute_metrics(self, ind_x, str_y, metric_names=list(), from_line=False):
     """Decode predicted index sequences and score them against references.

     Args:
         ind_x: Iterable of predictions. With ``from_line`` false, each
             element is one index sequence (or ``None``). With ``from_line``
             true, each element is itself an iterable of index sequences
             (or ``None`` entries) whose decoded strings are joined with a
             single space and stripped of leading/trailing spaces.
         str_y: Reference strings, paired positionally with the decoded
             predictions.
         metric_names: Names of the metrics to compute. ``"cer"`` and
             ``"wer"`` are handled; any other name is ignored.
         from_line: Selects how ``ind_x`` is interpreted (see above).

     Returns:
         dict: Always contains ``"nb_samples"`` (number of decoded
         predictions). ``"cer"`` adds the list of per-sample edit distances
         plus ``"nb_chars"``; ``"wer"`` adds the result of
         ``edit_wer_from_list`` plus ``"nb_words"``.
     """

     def decode(seq):
         # A missing prediction decodes to the empty string.
         if seq is None:
             return ""
         # Presumably collapses CTC repeats before mapping indices to
         # characters -- confirm against the helper's definition.
         return LM_ind_to_str(
             self.dataset.charset,
             self.ctc_remove_successives_identical_ind(seq),
             oov_symbol="",
         )

     if from_line:
         str_x = [
             " ".join(decode(token) for token in line).strip(" ")
             for line in ind_x
         ]
     else:
         str_x = [decode(seq) for seq in ind_x]

     metrics = {}
     for name in metric_names:
         if name == "wer":
             metrics["wer"] = edit_wer_from_list(str_y, str_x)
             metrics["nb_words"] = nb_words_from_list(str_y)
         elif name == "cer":
             # Raw per-sample edit distances; normalisation is left to the
             # caller, which gets "nb_chars" alongside.
             metrics["cer"] = [
                 editdistance.eval(ref, hyp) for ref, hyp in zip(str_y, str_x)
             ]
             metrics["nb_chars"] = nb_chars_from_list(str_y)
     metrics["nb_samples"] = len(str_x)
     return metrics
コード例 #2
0
 def compute_metrics(self,
                     x,
                     y,
                     x_len,
                     y_len,
                     loss=None,
                     metric_names=()):
     """Decode a batch of predictions and references and compute metrics.

     Args:
         x: Batch of predicted index sequences, indexable by sample.
         y: Batch of reference index sequences; ``y.shape[0]`` is used as
             the batch size.
         x_len: Per-sample number of leading entries of ``x`` to keep.
         y_len: Per-sample number of leading entries of ``y`` to keep.
         loss: Batch loss value. Must be provided when ``"loss_ctc"`` is
             requested, since it is divided by the character count.
         metric_names: Names of the metrics to compute. ``"cer"``,
             ``"wer"``, ``"pred"`` and ``"loss_ctc"`` are handled; any
             other name is ignored. The default is an immutable empty
             tuple so no object is shared between calls.

     Returns:
         dict: Always contains ``"nb_samples"`` (``len(x)``). ``"cer"``
         adds the list of per-sample edit distances plus ``"nb_chars"``;
         ``"wer"`` adds the result of ``edit_wer_from_list`` plus
         ``"nb_words"``; ``"pred"`` adds the decoded predictions wrapped in
         a one-element list; ``"loss_ctc"`` adds ``loss`` divided by the
         value of ``nb_chars_from_list`` for the references.
     """
     batch_size = y.shape[0]
     # Trim every sample to its own length before decoding.
     ind_x = [x[i][:x_len[i]] for i in range(batch_size)]
     ind_y = [y[i][:y_len[i]] for i in range(batch_size)]
     # Presumably collapses CTC repeats -- confirm against the helper.
     ind_x = [self.ctc_remove_successives_identical_ind(t) for t in ind_x]
     str_x = [
         LM_ind_to_str(self.dataset.charset, t, oov_symbol="")
         for t in ind_x
     ]
     str_y = [LM_ind_to_str(self.dataset.charset, t) for t in ind_y]
     metrics = {}
     for metric_name in metric_names:
         if metric_name == "cer":
             metrics[metric_name] = [
                 editdistance.eval(u, v) for u, v in zip(str_y, str_x)
             ]
             metrics["nb_chars"] = nb_chars_from_list(str_y)
         elif metric_name == "wer":
             metrics[metric_name] = edit_wer_from_list(str_y, str_x)
             metrics["nb_words"] = nb_words_from_list(str_y)
         elif metric_name == "pred":
             metrics["pred"] = [
                 str_x,
             ]
     if "loss_ctc" in metric_names:
         # "nb_chars" is only stored by the "cer" branch. Compute the
         # normaliser here when "cer" was not requested, so that asking for
         # "loss_ctc" on its own no longer raises KeyError.
         if "nb_chars" in metrics:
             nb_chars = metrics["nb_chars"]
         else:
             nb_chars = nb_chars_from_list(str_y)
         metrics["loss_ctc"] = loss / nb_chars
     metrics["nb_samples"] = len(x)
     return metrics
コード例 #3
0
 def compute_metrics(self,
                     ind_x,
                     str_y,
                     loss=None,
                     metric_names=list()):
     """Decode predicted index sequences and score them against references.

     Args:
         ind_x: Iterable of predicted index sequences.
         str_y: Reference strings, paired positionally with the decoded
             predictions.
         loss: Value stored unchanged under ``"loss_ctc"`` when that metric
             is requested.
         metric_names: Names of the metrics to compute. ``"cer"``,
             ``"wer"``, ``"pred"`` and ``"loss_ctc"`` are handled; any
             other name is ignored.

     Returns:
         dict: Always contains ``"nb_samples"`` (``len(str_y)``). ``"cer"``
         adds the list of per-sample edit distances plus ``"nb_chars"``;
         ``"wer"`` adds the result of ``edit_wer_from_list`` plus
         ``"nb_words"``; ``"pred"`` adds the decoded predictions wrapped in
         a one-element list; ``"loss_ctc"`` adds ``loss`` as given.
     """
     # First pass over every sequence (presumably collapsing CTC repeats --
     # confirm against the helper), second pass maps indices to text.
     collapsed = [
         self.ctc_remove_successives_identical_ind(seq) for seq in ind_x
     ]
     str_x = [
         LM_ind_to_str(self.dataset.charset, seq, oov_symbol="")
         for seq in collapsed
     ]

     metrics = {}
     for name in metric_names:
         if name == "pred":
             metrics["pred"] = [str_x]
         elif name == "wer":
             metrics["wer"] = edit_wer_from_list(str_y, str_x)
             metrics["nb_words"] = nb_words_from_list(str_y)
         elif name == "cer":
             metrics["cer"] = [
                 editdistance.eval(ref, hyp) for ref, hyp in zip(str_y, str_x)
             ]
             metrics["nb_chars"] = nb_chars_from_list(str_y)

     # Handled after the loop so it is always inserted after the other
     # requested metrics, whatever its position in metric_names.
     if "loss_ctc" in metric_names:
         metrics["loss_ctc"] = loss
     metrics["nb_samples"] = len(str_y)
     return metrics