Example #1
0
def get_tensorizers(add_dict_feat=False, add_contextual_feat=False):
    """Build and initialize tensorizers over the TEST_FILE_NAME data source.

    Returns a dict with "src_seq_tokens" and "trg_seq_tokens" TokenTensorizers
    (both with BOS/EOS tokens enabled), and optionally a "dict_feat"
    GazetteerTensorizer when add_dict_feat is True.

    NOTE(review): add_contextual_feat is accepted but never used in this body —
    confirm whether callers rely on it elsewhere.
    """
    schema = {"source_sequence": str, "dict_feat": Gazetteer, "target_sequence": str}
    data_source = TSVDataSource.from_config(
        TSVDataSource.Config(
            train_filename=TEST_FILE_NAME,
            field_names=["source_sequence", "dict_feat", "target_sequence"],
        ),
        schema,
    )

    def make_token_tensorizer(column):
        # Source and target sides share identical BOS/EOS configuration.
        return TokenTensorizer.from_config(
            TokenTensorizer.Config(
                column=column, add_eos_token=True, add_bos_token=True
            )
        )

    tensorizers = {
        "src_seq_tokens": make_token_tensorizer("source_sequence"),
        "trg_seq_tokens": make_token_tensorizer("target_sequence"),
    }
    initialize_tensorizers(tensorizers, data_source.train)

    if add_dict_feat:
        gazetteer = GazetteerTensorizer.from_config(
            GazetteerTensorizer.Config(
                text_column="source_sequence", dict_column="dict_feat"
            )
        )
        tensorizers["dict_feat"] = gazetteer
        # Initialize the gazetteer separately; the token tensorizers above
        # were already initialized with their own pass over the data.
        initialize_tensorizers({"dict_feat": gazetteer}, data_source.train)
    return tensorizers
Example #2
0
 def test_read_data_source_with_utf8_issues(self):
     """Smoke test: a TSV file containing UTF-8 errors can be read end to end."""
     source = TSVDataSource.from_config(
         TSVDataSource.Config(
             train_filename=tests_module.test_file("test_utf8_errors.tsv"),
             field_names=["label", "text"],
         ),
         {"text": str, "label": str},
     )
     # Force a full iteration over the training split; any decoding
     # failure would raise here and fail the test.
     list(source.train)
 def _get_tensorizers(self):
     """Build src/tgt TokenTensorizers initialized from the seq2seq unit-test TSV."""
     schema = {"source_sequence": str, "target_sequence": str}
     source_config = TSVDataSource.Config(
         train_filename=tests_module.test_file("compositional_seq2seq_unit.tsv"),
         field_names=["source_sequence", "target_sequence"],
     )
     data_source = TSVDataSource.from_config(source_config, schema)

     def build(column):
         # Both sides use the same BOS/EOS configuration.
         return TokenTensorizer.from_config(
             TokenTensorizer.Config(
                 column=column, add_eos_token=True, add_bos_token=True
             )
         )

     tensorizers = {
         "src_seq_tokens": build("source_sequence"),
         "trg_seq_tokens": build("target_sequence"),
     }
     initialize_tensorizers(tensorizers, data_source.train)
     return tensorizers