Exemplo n.º 1
0
def _get_reader(config,
                skip_labels=False,
                bert_max_length=None,
                reader_max_length=150,
                read_first=None):
    indexers = {}
    for embedder_config in config.embedder.models:
        if embedder_config.name == 'elmo':
            indexers[embedder_config.name] = ELMoTokenCharactersIndexer()
        elif embedder_config.name.endswith('bert'):
            bert_path = os.path.join(config.data.pretrained_models_dir,
                                     embedder_config.name)
            indexers[
                embedder_config.name] = PretrainedTransformerMismatchedIndexer(
                    model_name=bert_path,
                    tokenizer_kwargs={'do_lower_case': False},
                    max_length=bert_max_length)
        elif embedder_config.name == 'char_bilstm':
            indexers[embedder_config.name] = TokenCharactersIndexer()
        else:
            assert False, 'Unknown embedder {}'.format(embedder_config.name)

    return UDDatasetReader(indexers,
                           skip_labels=skip_labels,
                           max_length=reader_max_length,
                           read_first=read_first)
Exemplo n.º 2
0
 def __init__(self,
              token_indexers: Dict[str, TokenIndexer] = None,
              lazy=False) -> None:
     """Initialise the reader and pick the token indexers it will use.

     Args:
         token_indexers: Mapping from indexer name to ``TokenIndexer``. If it
             is ``None`` or otherwise falsy (e.g. an empty dict), a default
             pair is used instead: ``"tokens"`` -> ``SingleIdTokenIndexer``
             with ``lowercase_tokens=True``, and ``"chars"`` ->
             ``TokenCharactersIndexer``.
         lazy: Passed unchanged to the parent class constructor.
     """
     super().__init__(lazy=lazy)
     if token_indexers:
         chosen = token_indexers
     else:
         chosen = {
             "tokens": SingleIdTokenIndexer(lowercase_tokens=True),
             "chars": TokenCharactersIndexer(),
         }
     self.token_indexers = chosen
Exemplo n.º 3
0
 def __init__(self, token_indexers: Dict[str, TokenIndexer] = None,
              token_character_indexers: Dict[str, TokenCharactersIndexer] = None) -> None:
     """Set up word-level indexers, character-level indexers and an alphabet.

     Args:
         token_indexers: Mapping from name to ``TokenIndexer``. If ``None`` or
             empty, defaults to ``{"tokens": SingleIdTokenIndexer()}``.
         token_character_indexers: Mapping from name to
             ``TokenCharactersIndexer``. If ``None`` or empty, defaults to
             ``{"token_characters": TokenCharactersIndexer()}``.
     """
     # The parent is always built with lazy=False; this constructor does not
     # expose a ``lazy`` argument.
     super().__init__(lazy=False)
     self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
     # Character-level indexers are kept in a separate mapping from
     # self.token_indexers.
     self.token_character_indexers = token_character_indexers or {"token_characters": TokenCharactersIndexer()}
     # following the PyTorch tutorial, turn everything to plain ASCII
     # all_letters = ASCII letters plus space and " .,;'". Presumably it is the
     # set of characters kept after ASCII conversion; the code that uses it is
     # not shown here — NOTE(review): confirm against the rest of the class.
     self.all_letters = string.ascii_letters + " .,;'"