Esempio n. 1
0
    def after_run(self, _run_context, run_values):
        """Print the best decoded sentence for every example in the batch.

        Predicted and source tokens are decoded from UTF-8 bytes and written
        back into the per-example dict. When the predictions are 2-D only
        beam 0 (column 0) is used. If ``self._unk_replace_fn`` is set, UNK
        tokens are replaced using attention scores restricted to the first
        ``source_len - 1`` source positions. The tokens are joined with
        ``self.params["delimiter"]``, cut at the first "SEQUENCE_END", BPE
        continuation markers ("@@ ") are removed, and the stripped result is
        printed.
        """
        for example in unbatch_dict(run_values.results):
            example["predicted_tokens"] = np.char.decode(
                example["predicted_tokens"].astype("S"), "utf-8")
            tokens = example["predicted_tokens"]
            if np.ndim(tokens) > 1:
                # Beam search output: keep only the first beam.
                tokens = tokens[:, 0]

            example["features.source_tokens"] = np.char.decode(
                example["features.source_tokens"].astype("S"), "utf-8")
            src_tokens = example["features.source_tokens"]
            src_len = example["features.source_len"]

            if self._unk_replace_fn is not None:
                # Slice the attention scores so that an UNK is never
                # replaced with the SEQUENCE_END token.
                scores = example["attention_scores"][:, :src_len - 1]
                tokens = self._unk_replace_fn(
                    source_tokens=src_tokens,
                    predicted_tokens=tokens,
                    attention_scores=scores)

            joined = self.params["delimiter"].join(tokens)
            head, _sep, _rest = joined.partition("SEQUENCE_END")
            print(head.replace("@@ ", "").strip())
    def after_run(self, _run_context, run_values):
        """Print each source/prediction pair and append the prediction to a file.

        Predicted and source tokens are decoded from UTF-8 bytes. Beam 0 is
        used for 2-D predictions, and UNK tokens are optionally replaced via
        ``self._unk_replace_fn``. Both token sequences are turned into text by
        ``self.extract_sentence``. The pair is printed, and the prediction is
        appended as one line to ``TEST_OUTPUT_FILE``.
        """
        fetches_batch = run_values.results
        for fetches in unbatch_dict(fetches_batch):
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens = fetches["predicted_tokens"]

            # If we're using beam search we take the first beam
            if np.ndim(predicted_tokens) > 1:
                predicted_tokens = predicted_tokens[:, 0]

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]
            source_len = fetches["features.source_len"]

            if self._unk_replace_fn is not None:
                # We slice the attention scores so that we do not
                # accidentally replace UNK with a SEQUENCE_END token
                attention_scores = fetches["attention_scores"]
                attention_scores = attention_scores[:, :source_len - 1]
                predicted_tokens = self._unk_replace_fn(
                    source_tokens=source_tokens,
                    predicted_tokens=predicted_tokens,
                    attention_scores=attention_scores)

            sent = self.extract_sentence(predicted_tokens)
            prompt = self.extract_sentence(source_tokens)

            print(prompt + '\n' + sent + '\n\n')
            # The tokens were decoded from UTF-8 above. Write them back as
            # UTF-8 explicitly instead of relying on the platform's default
            # encoding, which may not be able to represent them.
            with open(TEST_OUTPUT_FILE, 'a', encoding='utf-8') as output_file:
                output_file.write(sent + '\n')
Esempio n. 3
0
 def after_run(self, _run_context, run_values):
   """Append every example's beam-search outputs to ``self._beam_accum``.

   For each example, the fetched predicted ids, beam parent ids, scores
   and log-probs are appended, in that order, to the list of the same
   name in ``self._beam_accum``.
   """
   # (accumulator key, fetch key) pairs, in the original append order.
   field_pairs = (
       ("predicted_ids", "beam_search_output.predicted_ids"),
       ("beam_parent_ids", "beam_search_output.beam_parent_ids"),
       ("scores", "beam_search_output.scores"),
       ("log_probs", "beam_search_output.log_probs"),
   )
   for example in unbatch_dict(run_values.results):
     for accum_key, fetch_key in field_pairs:
       self._beam_accum[accum_key].append(example[fetch_key])
Esempio n. 4
0
 def after_run(self, _run_context, run_values):
   """Collect per-example beam-search tensors into ``self._beam_accum``.

   The outputs named ``beam_search_output.<field>`` are appended to
   ``self._beam_accum[<field>]`` for the fields predicted_ids,
   beam_parent_ids, scores and log_probs (in that order).
   """
   fields = ("predicted_ids", "beam_parent_ids", "scores", "log_probs")
   batch = run_values.results
   for single in unbatch_dict(batch):
     for field in fields:
       self._beam_accum[field].append(single["beam_search_output." + field])
Esempio n. 5
0
    def after_run(self, _run_context, run_values):
        """Decode predictions and hand the beam-processed result to the callback.

        For each example, ``predicted_tokens`` is decoded from UTF-8 bytes and
        stored back into the example dict. The tokens are then passed through
        ``self.run_through_beam``, and its result is given to
        ``self.callback_func``.
        """
        batch = run_values.results
        examples = unbatch_dict(batch)
        for item in examples:
            raw = item["predicted_tokens"]
            item["predicted_tokens"] = np.char.decode(raw.astype("S"), "utf-8")
            beam_result = self.run_through_beam(item["predicted_tokens"])
            self.callback_func(beam_result)
Esempio n. 6
0
    def after_run(self, _run_context, run_values):
        """Print the perplexity of the top-beam prediction for each example.

        Tokens are decoded from UTF-8 bytes. Beam 0 is used for 2-D
        predictions, and UNK tokens are optionally replaced via
        ``self._unk_replace_fn``. The prediction is joined and cut at the
        first "SEQUENCE_END"; its whitespace-split length decides how many
        steps of ``fetches["predicted_ids"]`` are scored against
        ``fetches["logits"]``. Perplexity is
        ``2 ** (-(1 / N) * sum_i log2 softmax(logits[i])[id_i])``. It is
        formatted with 3 decimals, passed through ``self._postproc_fn`` when
        set, stripped and printed.
        """
        fetches_batch = run_values.results
        for fetches in unbatch_dict(fetches_batch):
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens = fetches["predicted_tokens"]

            # If we're using beam search we take the first beam
            if np.ndim(predicted_tokens) > 1:
                predicted_tokens = predicted_tokens[:, 0]

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]
            source_len = fetches["features.source_len"]

            if self._unk_replace_fn is not None:
                # We slice the attention scores so that we do not
                # accidentally replace UNK with a SEQUENCE_END token
                attention_scores = fetches["attention_scores"]
                attention_scores = attention_scores[:, :source_len - 1]
                predicted_tokens = self._unk_replace_fn(
                    source_tokens=source_tokens,
                    predicted_tokens=predicted_tokens,
                    attention_scores=attention_scores)

            predicted_tokens_str = self.params["delimiter"].join(
                predicted_tokens).split("SEQUENCE_END")[0]
            logits = fetches["logits"]
            # NOTE(review): counts whitespace-separated words, which equals the
            # token count only when the delimiter is whitespace.
            predicted_size = len(predicted_tokens_str.split())
            predicted_ids = fetches["predicted_ids"][:predicted_size]

            # Perplexity calculation
            log2_prob_sum = 0.
            count_n = 0
            for step, token_id in enumerate(predicted_ids):
                step_logits = logits[step]
                # Numerically stable log-softmax: subtracting the max before
                # exponentiating keeps np.exp from overflowing to inf (which
                # previously turned the probabilities into NaN).
                shifted = step_logits - np.max(step_logits, axis=0)
                log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=0))
                # log2 of the probability of the word given its history
                log2_prob_sum += log_probs[token_id] / np.log(2)
                count_n += 1
            # Perplexity for the sentence (1.0 keeps true division on Py2 too).
            perplexity = np.power(2, -(1.0 / max(1, count_n)) * log2_prob_sum)
            sent = "{:.3f}\n".format(perplexity)

            # Apply postproc
            if self._postproc_fn:
                sent = self._postproc_fn(sent)

            sent = sent.strip()

            print(sent)
Esempio n. 7
0
    def after_run(self, _run_context, run_values):
        """Decode predictions, print their shape, and forward the beam result.

        For each example, ``predicted_tokens`` is decoded from UTF-8 bytes and
        stored back into the example dict. A debug line with the array's shape
        and rank is printed, and ``self.run_through_beam``'s result is passed
        to ``self.callback_func``.
        """
        for example in unbatch_dict(run_values.results):
            tokens = np.char.decode(
                example["predicted_tokens"].astype("S"), "utf-8")
            example["predicted_tokens"] = tokens
            print("Number of beams ....... ", tokens.shape, np.ndim(tokens))
            self.callback_func(self.run_through_beam(tokens))

    # # Original after_run for decode
    # def after_run(self, _run_context, run_values):
    #   fetches_batch = run_values.results
    #   for fetches in unbatch_dict(fetches_batch):
    #     # Convert to unicode
    #     fetches["predicted_tokens"] = np.char.decode(
    #         fetches["predicted_tokens"].astype("S"), "utf-8")
    #     predicted_tokens = fetches["predicted_tokens"]

    #     # If we're using beam search we take the first beam
    #     if np.ndim(predicted_tokens) > 1:
    #       predicted_tokens = predicted_tokens[:, 0]

    #     fetches["features.source_tokens"] = np.char.decode(
    #         fetches["features.source_tokens"].astype("S"), "utf-8")
    #     source_tokens = fetches["features.source_tokens"]
    #     source_len = fetches["features.source_len"]

    #     if self._unk_replace_fn is not None:
    #       # We slice the attention scores so that we do not
    #       # accidentially replace UNK with a SEQUENCE_END token
    #       attention_scores = fetches["attention_scores"]
    #       attention_scores = attention_scores[:, :source_len - 1]
    #       predicted_tokens = self._unk_replace_fn(
    #           source_tokens=source_tokens,
    #           predicted_tokens=predicted_tokens,
    #           attention_scores=attention_scores)

    #     sent = self.params["delimiter"].join(predicted_tokens).split(
    #         "SEQUENCE_END")[0]

    #     # Apply postproc
    #     if self._postproc_fn:
    #       sent = self._postproc_fn(sent)

    #     sent = sent.strip()

    #     print(sent)
    def after_run(self, _run_context, run_values):
        """Print the postprocessed top-beam prediction for each example.

        Predicted and source tokens are decoded from UTF-8 bytes. For 2-D
        predictions only beam 0 is used, and UNK tokens are optionally
        replaced via ``self._unk_replace_fn``. The prediction is joined with
        ``self.params["delimiter"]``, cut at the first "SEQUENCE_END", passed
        through ``self._postproc_fn`` when set, and stripped. When
        ``self.params["print_source"]`` is truthy, an empty line and the
        source sentence (cut at "SEQUENCE_END") are printed first.
        """
        fetches_batch = run_values.results
        for fetches in unbatch_dict(fetches_batch):
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens = fetches["predicted_tokens"]

            # If we're using beam search we take the first beam
            if np.ndim(predicted_tokens) > 1:
                predicted_tokens = predicted_tokens[:, 0]

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]
            source_len = fetches["features.source_len"]

            if self._unk_replace_fn is not None:
                # We slice the attention scores so that we do not
                # accidentally replace UNK with a SEQUENCE_END token
                attention_scores = fetches["attention_scores"]
                attention_scores = attention_scores[:, :source_len - 1]
                predicted_tokens = self._unk_replace_fn(
                    source_tokens=source_tokens,
                    predicted_tokens=predicted_tokens,
                    attention_scores=attention_scores)

            sent = self.params["delimiter"].join(predicted_tokens).split(
                "SEQUENCE_END")[0]

            # Apply postproc
            if self._postproc_fn:
                sent = self._postproc_fn(sent)

            sent = sent.strip()

            if self.params["print_source"]:
                print()
                print(self.params["delimiter"].join(source_tokens).split(
                    "SEQUENCE_END")[0])

            try:
                print(sent)
            except UnicodeEncodeError:
                # stdout cannot encode some characters (e.g. an ASCII-only
                # console), so fall back to printing the UTF-8 bytes. Catching
                # only this error keeps KeyboardInterrupt/SystemExit and
                # unrelated failures visible instead of swallowing them.
                print(sent.encode('utf-8'))
Esempio n. 9
0
    def after_run(self, _run_context, run_values):
        """Pass decoded source tokens and top-beam predictions to the callback.

        Source and predicted tokens are decoded from UTF-8 bytes and stored
        back into the example dict. For 2-D predictions only beam 0
        (column 0) is forwarded to ``self.callback_func``.
        """
        fetches_batch = run_values.results

        for fetches in unbatch_dict(fetches_batch):
            fetches['features.source_tokens'] = np.char.decode(
                fetches['features.source_tokens'].astype('S'), 'utf-8')
            source_tokens = fetches['features.source_tokens']

            fetches['predicted_tokens'] = np.char.decode(
                fetches['predicted_tokens'].astype('S'), 'utf-8')
            predicted_tokens = fetches['predicted_tokens']

            # If using beam search (2-D output), take the first beam.
            # Checking the rank rather than `.shape[1] > 1` also works for
            # 1-D greedy output (where shape[1] raised IndexError) and for a
            # beam width of 1 (which was previously left un-sliced).
            if np.ndim(predicted_tokens) > 1:
                predicted_tokens = predicted_tokens[:, 0]

            self.callback_func(source_tokens, predicted_tokens)
Esempio n. 10
0
    def after_run(self, _run_context, run_values):
        """Print the postprocessed prediction of every beam for each example.

        Tokens are decoded from UTF-8 bytes. 1-D (non-beam) predictions, and
        their attention scores if fetched, get a beam axis so that both cases
        share one loop. For each beam, UNK tokens are optionally replaced via
        ``self._unk_replace_fn``. The tokens are then joined, cut at the first
        "SEQUENCE_END", postprocessed, stripped and printed.
        """
        fetches_batch = run_values.results
        for fetches in unbatch_dict(fetches_batch):
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens = fetches["predicted_tokens"]
            if 'attention_scores' in fetches:
                attention_scores = fetches['attention_scores']

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]
            source_len = fetches["features.source_len"]

            # If we're not using beam search we convert them into beam search format
            if np.ndim(predicted_tokens) == 1:
                predicted_tokens = predicted_tokens[:, np.newaxis]
                if 'attention_scores' in fetches:
                    attention_scores = fetches['attention_scores'][:, np.
                                                                   newaxis, :]

            for i in range(predicted_tokens.shape[1]):  # loop over all beams
                # Default to the raw beam tokens. Previously p_tokens was only
                # assigned inside the UNK branch, so without an
                # _unk_replace_fn the join below raised UnboundLocalError.
                p_tokens = predicted_tokens[:, i]

                if self._unk_replace_fn is not None:
                    # We slice the attention scores so that we do not
                    # accidentally replace UNK with a SEQUENCE_END token
                    a_scores_beam = attention_scores[:, i, :source_len - 1]

                    p_tokens = self._unk_replace_fn(
                        source_tokens=source_tokens,
                        predicted_tokens=p_tokens,
                        attention_scores=a_scores_beam)

                sent = self.params["delimiter"].join(p_tokens).split(
                    "SEQUENCE_END")[0]

                # Apply postproc
                if self._postproc_fn:
                    sent = self._postproc_fn(sent)

                sent = sent.strip()

                print(sent)
Esempio n. 11
0
    def after_run(self, _run_context, run_values):
        """Pass decoded source tokens and top-beam predictions to the callback.

        Predicted and source tokens are decoded from UTF-8 bytes and stored
        back into the example dict. For 2-D predictions only beam 0
        (column 0) is forwarded.
        """
        decode = np.char.decode
        for example in unbatch_dict(run_values.results):
            example["predicted_tokens"] = decode(
                example["predicted_tokens"].astype("S"), "utf-8")
            best = example["predicted_tokens"]
            # TODO: beam search top k (only the first beam is used for now)
            if np.ndim(best) > 1:
                best = best[:, 0]

            example["features.source_tokens"] = decode(
                example["features.source_tokens"].astype("S"), "utf-8")
            self.callback_func(example["features.source_tokens"], best)
Esempio n. 12
0
    def after_run(self, _run_context, run_values):
        """Decode tokens, optionally dump an attention plot, and accumulate scores.

        Predicted and source tokens are decoded from UTF-8 bytes in place.
        When ``self.params["dump_plots"]`` is set, a figure built by
        ``_create_figure`` is saved as ``<output_dir>/<idx:05d>.png`` and
        ``self._idx`` is incremented. In all cases ``_get_scores(example)`` is
        appended to ``self._attention_scores_accum``.
        """
        for example in unbatch_dict(run_values.results):
            for key in ("predicted_tokens", "features.source_tokens"):
                example[key] = np.char.decode(example[key].astype("S"), "utf-8")

            if self.params["dump_plots"]:
                plot_file = "{:05d}.png".format(self._idx)
                plot_path = os.path.join(self.params["output_dir"], plot_file)
                _create_figure(example)
                plt.savefig(plot_path)
                plt.close()
                tf.logging.info("Wrote %s", plot_path)
                self._idx += 1
            self._attention_scores_accum.append(_get_scores(example))
Esempio n. 13
0
  def after_run(self, _run_context, run_values):
    """Decode tokens, optionally save an attention plot, and collect scores.

    The predicted and source tokens of each example are decoded from UTF-8
    bytes in place. If ``self.params["dump_plots"]`` is truthy, a figure is
    written to ``<output_dir>/<idx:05d>.png`` and the counter
    ``self._idx`` advances. The example's scores (from ``_get_scores``) are
    always appended to ``self._attention_scores_accum``.
    """
    def to_unicode(byte_tokens):
      # Decode an array of byte strings to unicode.
      return np.char.decode(byte_tokens.astype("S"), "utf-8")

    for example in unbatch_dict(run_values.results):
      example["predicted_tokens"] = to_unicode(example["predicted_tokens"])
      example["features.source_tokens"] = to_unicode(
          example["features.source_tokens"])

      if self.params["dump_plots"]:
        png_name = "{:05d}.png".format(self._idx)
        out_path = os.path.join(self.params["output_dir"], png_name)
        _create_figure(example)
        plt.savefig(out_path)
        plt.close()
        tf.logging.info("Wrote %s", out_path)
        self._idx += 1

      self._attention_scores_accum.append(_get_scores(example))
Esempio n. 14
0
  def after_run(self, _run_context, run_values):
    """Print the vocabulary ids of the top-beam prediction for each example.

    Tokens are decoded from UTF-8 bytes. Beam 0 is used for 2-D predictions,
    and UNK tokens are optionally replaced via ``self._unk_replace_fn``.
    Every predicted token except "SEQUENCE_END" is then looked up in the
    dictionary loaded from ``../autoencoder/processed_data/inv_vocab.json``,
    and the resulting list is printed.
    """
    # The vocabulary is read at most once per call (on the first example)
    # instead of being re-opened and re-parsed for every example. It is not
    # read at all for an empty batch.
    # NOTE(review): the path is relative to the current working directory.
    vocab = None
    fetches_batch = run_values.results
    for fetches in unbatch_dict(fetches_batch):
      # Convert to unicode
      fetches["predicted_tokens"] = np.char.decode(
          fetches["predicted_tokens"].astype("S"), "utf-8")
      predicted_tokens = fetches["predicted_tokens"]

      # If we're using beam search we take the first beam
      if np.ndim(predicted_tokens) > 1:
        predicted_tokens = predicted_tokens[:, 0]

      fetches["features.source_tokens"] = np.char.decode(
          fetches["features.source_tokens"].astype("S"), "utf-8")
      source_tokens = fetches["features.source_tokens"]
      source_len = fetches["features.source_len"]

      if self._unk_replace_fn is not None:
        # We slice the attention scores so that we do not
        # accidentally replace UNK with a SEQUENCE_END token
        attention_scores = fetches["attention_scores"]
        attention_scores = attention_scores[:, :source_len - 1]
        predicted_tokens = self._unk_replace_fn(
            source_tokens=source_tokens,
            predicted_tokens=predicted_tokens,
            attention_scores=attention_scores)

      # NOTE(review): `sent` is never used below. It is kept only so that any
      # side effects of self._postproc_fn still happen.
      sent = self.params["delimiter"].join(predicted_tokens).split(
          "SEQUENCE_END")[0]

      # Apply postproc
      if self._postproc_fn:
        sent = self._postproc_fn(sent)

      # Look up token ids
      if vocab is None:
        with open('../autoencoder/processed_data/inv_vocab.json', 'r') as fp:
          vocab = json.load(fp)

      ids = [vocab[token] for token in predicted_tokens
             if token != 'SEQUENCE_END']
      print(ids)
Esempio n. 15
0
  def after_run(self, _run_context, run_values):
    """Print the postprocessed top-beam prediction for every example.

    Predicted and source tokens are decoded from UTF-8 bytes and stored back
    into the example dict. For 2-D predictions only beam 0 is used. UNK
    tokens are optionally replaced via ``self._unk_replace_fn``. The text is
    joined with ``self.params["delimiter"]``, cut at the first
    "SEQUENCE_END", passed through ``self._postproc_fn`` when set, stripped
    and printed.
    """
    for example in unbatch_dict(run_values.results):
      example["predicted_tokens"] = np.char.decode(
          example["predicted_tokens"].astype("S"), "utf-8")
      hyp = example["predicted_tokens"]
      if np.ndim(hyp) > 1:
        hyp = hyp[:, 0]  # beam search: use the first beam

      example["features.source_tokens"] = np.char.decode(
          example["features.source_tokens"].astype("S"), "utf-8")
      src = example["features.source_tokens"]
      src_len = example["features.source_len"]

      if self._unk_replace_fn is not None:
        # Slice so that an UNK is never replaced with SEQUENCE_END.
        hyp = self._unk_replace_fn(
            source_tokens=src,
            predicted_tokens=hyp,
            attention_scores=example["attention_scores"][:, :src_len - 1])

      text = self.params["delimiter"].join(hyp).partition("SEQUENCE_END")[0]
      if self._postproc_fn:
        text = self._postproc_fn(text)
      print(text.strip())
Esempio n. 16
0
    def after_run(self, _run_context, run_values):
        """Decode every beam of each example, record attention, and emit results.

        For each example, ``self.sample_cnt`` is incremented and predicted and
        source tokens are decoded from UTF-8 bytes. For every beam (columns of
        a 2-D prediction, or the single 1-D sequence):

        - UNK tokens are optionally replaced via ``self._unk_replace_fn``.
        - The sentence is joined, cut at "SEQUENCE_END", postprocessed and
          stripped.
        - A dict with the source words, predicted words and the cropped
          attention matrix is appended to ``self.attn_scores_list``.

        If ``self._save_pred_path`` is set, the source line plus all beam
        sentences are buffered in ``self.infer_outs`` and flushed with
        ``self.write_buffer_to_disk()`` every 100 samples. Otherwise they are
        printed.
        """
        # Deep-copy the results, presumably so the in-place unicode decoding
        # below cannot leak back into run_values.results.
        fetches_batch = copy.deepcopy(run_values.results)

        for fetches in unbatch_dict(fetches_batch):
            self.sample_cnt += 1
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens_list = fetches["predicted_tokens"]

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]
            source_len = fetches["features.source_len"]

            source_sent = self.params["delimiter"].join(source_tokens)
            # These depend only on the source, so compute them once per example
            # rather than once per beam.
            actual_source_sent = source_sent.split("SEQUENCE_END")[0]
            actual_source_len = source_len - 1
            beam_search_sents = []

            # 2-D predictions hold one beam per column.
            is_beam = predicted_tokens_list.ndim > 1
            beam_width = np.shape(predicted_tokens_list)[1] if is_beam else 1

            for i in range(beam_width):
                if is_beam:
                    predicted_tokens = predicted_tokens_list[:, i]
                else:
                    predicted_tokens = predicted_tokens_list

                # The attention scores are always needed for the dump below,
                # so fetch them regardless of UNK replacement. Previously they
                # were only assigned inside the UNK branch, which raised a
                # NameError when self._unk_replace_fn was None.
                beam_attn_key = \
                    "beam_search_output.original_outputs.attention_scores"
                if beam_attn_key in fetches:
                    attention_scores = fetches[beam_attn_key][:, i, :]
                else:
                    attention_scores = fetches["attention_scores"]
                # We slice the attention scores so that we do not
                # accidentally replace UNK with a SEQUENCE_END token
                attention_scores = attention_scores[:, :source_len - 1]

                if self._unk_replace_fn is not None:
                    predicted_tokens = self._unk_replace_fn(
                        source_tokens=source_tokens,
                        predicted_tokens=predicted_tokens,
                        attention_scores=attention_scores)

                pred_sent = self.params["delimiter"].join(
                    predicted_tokens).split("SEQUENCE_END")[0]

                # Apply postproc
                if self._postproc_fn:
                    pred_sent = self._postproc_fn(pred_sent)
                pred_sent = pred_sent.strip()
                pred_len = len(pred_sent.split(self.params["delimiter"]))

                dump_attention_scores = attention_scores[0:pred_len,
                                                         0:actual_source_len]
                self.attn_scores_list.append({
                    "source_sent": actual_source_sent.split(" "),
                    "pred_sent": pred_sent.split(" "),
                    "attn_score": dump_attention_scores,
                })
                beam_search_sents.append(pred_sent)

            pred_sents_str = "\n".join(beam_search_sents)
            if self._save_pred_path is not None:
                infer_out = source_sent + "\n" + pred_sents_str + "\n\n"
                self.infer_outs.append(infer_out)
                if self.sample_cnt % 100 == 0:
                    self.write_buffer_to_disk()
            else:
                print(source_sent + "\n" + pred_sents_str + "\n\n")
    def after_run(self, _run_context, run_values):
        """Rescore beams with a character-bigram bonus and report top-k prefixes.

        For each example, ``predicted_tokens`` is decoded from UTF-8 bytes. The
        example's beam-search outputs (predicted ids, parent ids, scores,
        log-probs) replace the single-element lists in ``self._beam_accum``.

        When the predictions are 2-D (beam search), every (step, beam) score
        is increased in place by ``5 * bigram_score``, where ``bigram_score``
        is ``ev.bigram_dict[(prev_char, cur_char)] / ev.unigram_dict[prev_char]``
        ("^" stands for SEQUENCE_END). It falls back to 0 when the pair or
        character is missing or the unigram count is zero.

        Then, for each prefix length ``1 .. seq_len - 1``, the first
        ``min(beam_width, self.top_k)`` beams are traced back through their
        parents and duplicate token sequences are dropped. Up to 10
        ``(tokens, score)`` pairs are kept, sorted by score descending. On
        IndexError the state is dumped and the prediction becomes ``[]``.

        Finally the decoded source tokens and the predictions are passed to
        ``self.callback_func``.
        """
        fetches_batch = run_values.results
        for fetches in unbatch_dict(fetches_batch):
            # Convert to unicode
            fetches["predicted_tokens"] = np.char.decode(
                fetches["predicted_tokens"].astype("S"), "utf-8")
            predicted_tokens = fetches["predicted_tokens"]

            # Keep only the current example's beam-search outputs (replaced,
            # not appended, on every iteration).
            self._beam_accum["predicted_ids"] = [
                fetches["beam_search_output.predicted_ids"]
            ]
            self._beam_accum["beam_parent_ids"] = [
                fetches["beam_search_output.beam_parent_ids"]
            ]
            self._beam_accum["scores"] = [fetches["beam_search_output.scores"]]
            self._beam_accum["log_probs"] = [
                fetches["beam_search_output.log_probs"]
            ]

            def beam_search_traceback(i, cur_id):
                # Return the tokens of beam `cur_id` for steps 0..i-1 by
                # following the parent ids backwards from step i-1.
                if i == 0:
                    return np.array([])
                cur_prediction = predicted_tokens[i - 1:i, cur_id]
                parent_id = self._beam_accum["beam_parent_ids"][0][
                    i - 1][cur_id]
                return np.append(beam_search_traceback(i - 1, parent_id),
                                 cur_prediction)

            # With beam search, rescore and collect the top-k prefixes.
            if np.ndim(predicted_tokens) > 1:
                try:
                    beam_search_predicted_tokens = []
                    seq_len = predicted_tokens.shape[0]
                    beam_width = predicted_tokens.shape[1]

                    for length in range(seq_len, 0, -1):
                        for k in range(0, beam_width):
                            parent_id = self._beam_accum["beam_parent_ids"][0][
                                length - 1][k]

                            bigram_score = 0
                            char_cur = predicted_tokens[length - 1, k]
                            # NOTE(review): for length == 1 this reads row -1,
                            # i.e. the LAST step, as the "previous" token.
                            char_prev = predicted_tokens[length - 2, parent_id]
                            if char_cur == "SEQUENCE_END":
                                char_cur = "^"
                            else:
                                char_cur = char_cur[0]
                            if char_prev == "SEQUENCE_END":
                                char_prev = "^"
                            else:
                                char_prev = char_prev[len(char_prev) - 1]
                            try:
                                bigram_score = (ev.bigram_dict[
                                    (char_prev, char_cur)]) / float(
                                        ev.unigram_dict[char_prev])
                            except (ZeroDivisionError, KeyError):
                                # A tuple is required: `A or B` evaluates to A
                                # alone, so KeyError used to escape and crash
                                # the hook on any unseen bigram/unigram.
                                bigram_score = 0
                            self._beam_accum["scores"][0][
                                length - 1][k] = self._beam_accum["scores"][0][
                                    length - 1][k] + bigram_score * 5

                    for length in range(1, seq_len):
                        prediction_per_len = []
                        for k in range(0, min(beam_width, self.top_k)):
                            pred_tokens_k = beam_search_traceback(length, k)
                            prob_pred_token_k = self._beam_accum["scores"][0][
                                length - 1][k]
                            if not _arreq_in_list(pred_tokens_k,
                                                  prediction_per_len):
                                prediction_per_len.append(
                                    (pred_tokens_k, prob_pred_token_k))
                        prediction_per_len = sorted(prediction_per_len,
                                                    key=lambda x: x[1],
                                                    reverse=True)[:10]
                        beam_search_predicted_tokens.append(prediction_per_len)
                    predicted_tokens = beam_search_predicted_tokens
                except IndexError:
                    logging.exception("")
                    print(self._beam_accum)
                    print(predicted_tokens)
                    # Capture the rank before resetting, so the debug line
                    # reports the failing tokens rather than the empty list.
                    tokens_ndim = np.ndim(predicted_tokens)
                    predicted_tokens = []
                    print("parents dim",
                          np.ndim(self._beam_accum["beam_parent_ids"]))
                    print("predicted tokends dim", tokens_ndim)

            fetches["features.source_tokens"] = np.char.decode(
                fetches["features.source_tokens"].astype("S"), "utf-8")
            source_tokens = fetches["features.source_tokens"]

            self.callback_func(source_tokens, predicted_tokens)