def generate_data(self):
  """Build the offline-training dataset.

  Loads the text column (and, unless running label-free inference, the
  intent and slot label columns) from the preprocessed files, applies the
  input pipeline, attaches sentence lengths, and records vocab size and
  data size into ``self.config``.

  Returns:
    A ``tf.data.Dataset`` of ``(text, sen_len)`` pairs, zipped with the
    processed intent and slot label datasets when labels are available.
  """
  if self.infer_without_label:
    # Inference input: a single text column, no label columns.
    text_ds = load_textline_dataset(self.paths_after_pre_process, 1)
  else:
    # Training input: intent label, slot labels, then the text column.
    intent_label_ds, slots_label_ds, text_ds = load_textline_dataset(
        self.paths_after_pre_process, 3)

  logging.info("Loading text dataset...")
  pipeline_fn = self.get_input_pipeline(for_export=False)
  text_ds = text_ds.map(
      pipeline_fn, num_parallel_calls=self.num_parallel_calls)
  # Sentence lengths, counting non-padding (token id 0) positions.
  len_ds = text_ds.map(
      lambda t: compute_sen_lens(t, padding_token=0),
      num_parallel_calls=self.num_parallel_calls)
  text_ds = tf.data.Dataset.zip((text_ds, len_ds))

  if not self.infer_without_label:
    intent_label_ds = process_one_label_dataset(
        intent_label_ds, self.config, output_index=0)
    slots_label_ds = process_multi_label_dataset(
        slots_label_ds, self.config, output_index=1)
    data_set = tf.data.Dataset.zip(
        (text_ds, intent_label_ds, slots_label_ds))
  else:
    data_set = text_ds

  # Record dataset statistics for downstream model construction.
  self.config['data']['vocab_size'] = get_vocab_size(
      self.text_vocab_file_path)
  self.config['data']['{}_data_size'.format(self.mode)] = get_file_len(
      self.paths_after_pre_process)

  return data_set
Example 2
0
    def generate_data(self):
        """Build the offline-training dataset.

        Loads the text column (and, unless running label-free inference,
        the label column), applies the input pipeline, attaches sentence
        lengths, optionally zips in a dense-feature dataset, and records
        vocab size, optional split-token id and data size in
        ``self.config``.

        Returns:
          A ``tf.data.Dataset`` combining text (with lengths), optional
          dense features, and the label when available.

        Raises:
          ValueError: if ``self.split_token`` is set but absent from the
            vocabulary.
        """
        if self.infer_without_label:
            # Inference input: a single text column, no label column.
            text_ds = load_textline_dataset(self.paths_after_pre_process, 1)
        else:
            # Training input: label column followed by the text column.
            label_ds, text_ds = load_textline_dataset(
                self.paths_after_pre_process, 2)

        pipeline_fn = self.get_input_pipeline(for_export=False)
        text_ds = text_ds.map(
            pipeline_fn, num_parallel_calls=self.num_parallel_calls)

        # Sentence lengths, counting non-padding tokens.
        len_ds = text_ds.map(
            lambda t: compute_sen_lens(t, padding_token=utils.PAD_IDX),
            num_parallel_calls=self.num_parallel_calls)

        text_ds = tf.data.Dataset.zip((text_ds, len_ds))

        if self.use_dense:
            # Pre-computed dense features stored as a .npy file.
            dense_ds = load_dense_dataset(load_npy(self.dense_npy))

        if self.infer_without_label:
            data_set = (tf.data.Dataset.zip((text_ds, dense_ds))
                        if self.use_dense else text_ds)
        else:
            label_ds = process_one_label_dataset(label_ds, self.config)
            if self.use_dense:
                parts = (text_ds, dense_ds, label_ds)
            else:
                parts = (text_ds, label_ds)
            data_set = tf.data.Dataset.zip(parts)

        vocab_dict = load_vocab_dict(self.text_vocab_file_path)
        if self.split_token != "":
            if self.split_token not in vocab_dict:
                raise ValueError(
                    "The Model uses split token: {}, not in corpus.".format(
                        self.split_token))
            self.config['data']['split_token'] = int(
                vocab_dict[self.split_token])
        # Record dataset statistics for downstream model construction.
        self.config['data']['vocab_size'] = len(vocab_dict)
        self.config['data']['{}_data_size'.format(self.mode)] = get_file_len(
            self.paths_after_pre_process)

        return data_set
Example 3
0
    def test_process_one_label_dataset(self):
        """process_one_label_dataset maps each label line to the expected id.

        Writes a small label file, runs it through the one-label pipeline,
        and checks the argmax of every emitted one-hot vector.
        """
        import os

        label = ["O", "O", "O", "I-MISC"]
        # NamedTemporaryFile replaces the deprecated tempfile.mktemp,
        # which is race-prone: another process could claim the generated
        # name before we open it. delete=False keeps the file on disk so
        # TextLineDataset can read it after the handle is closed.
        with tempfile.NamedTemporaryFile(
                mode='w', suffix='label_file_for_unitest.txt',
                encoding='utf-8', delete=False) as fobj:
            for token in label:
                fobj.write(token)
                fobj.write('\n')
            label_filepath = fobj.name

        try:
            label_ds = tf.data.TextLineDataset(label_filepath)
            # Expected class index of each line after the label lookup;
            # true_res values depend on the label map in self.config.
            true_res = [0, 0, 0, 8]
            label_ds = process_one_label_dataset(label_ds, self.config)

            iterator = label_ds.make_initializable_iterator()
            label_res = iterator.get_next()

            with tf.Session() as sess:
                sess.run(iterator.initializer)
                for i in range(len(label)):
                    self.assertEqual(
                        np.argmax(sess.run(label_res)), true_res[i])
        finally:
            # Clean up the temp label file (mktemp left it behind before).
            os.remove(label_filepath)
Example 4
0
    def generate_data(self):
        """Build the offline-training dataset for sentence-pair input.

        Loads the left and right text columns (and, unless running
        label-free inference, the label column), applies the input
        pipeline to both sides, and records vocab size and data size in
        ``self.config``.

        Returns:
          A pair ``(data_set, len_ds)`` where ``data_set`` zips the two
          text datasets (with the processed label when available) and
          ``len_ds`` zips the left/right sentence-length datasets.
        """
        if self.infer_without_label:
            # Inference input: left and right text columns only.
            text_ds_left, text_ds_right = load_textline_dataset(
                self.paths_after_pre_process, 2)
        else:
            # Training input: label column, then the two text columns.
            label, text_ds_left, text_ds_right = load_textline_dataset(
                self.paths_after_pre_process, 3)

        pipeline_fn = self.get_input_pipeline(for_export=False)
        parallel = self.num_parallel_calls
        text_ds_left = text_ds_left.map(
            pipeline_fn, num_parallel_calls=parallel)
        text_ds_right = text_ds_right.map(
            pipeline_fn, num_parallel_calls=parallel)

        def _lens(ds):
            # Sentence lengths, counting non-padding (token id 0) positions.
            return ds.map(
                lambda t: compute_sen_lens(t, padding_token=0),
                num_parallel_calls=parallel)

        text_ds_left_right = tf.data.Dataset.zip(
            (text_ds_left, text_ds_right))
        text_len_left_right = tf.data.Dataset.zip(
            (_lens(text_ds_left), _lens(text_ds_right)))

        if self.infer_without_label:
            data_set_left_right = text_ds_left_right
        else:
            label_ds = process_one_label_dataset(label, self.config)
            data_set_left_right = tf.data.Dataset.zip(
                (text_ds_left_right, label_ds))

        # Record dataset statistics for downstream model construction.
        vocab_dict = load_vocab_dict(self.text_vocab_file_path)
        self.config['data']['vocab_size'] = len(vocab_dict)
        self.config['data']['{}_data_size'.format(self.mode)] = get_file_len(
            self.paths_after_pre_process)

        return data_set_left_right, text_len_left_right