示例#1
0
    def from_config(cls, config: Config, tensorizers: Dict[str, Tensorizer]):
        """Build an instance of ``cls`` from a config and named tensorizers.

        Args:
            config: Model config. Its ``source_embedding``, ``dict_embedding``,
                ``target_embedding``, ``encoder_decoder``, ``output_layer`` and
                ``sequence_generator`` fields are read below.
            tensorizers: Mapping from input name to tensorizer. The keys
                ``"src_seq_tokens"`` and ``"trg_seq_tokens"`` are required.
                ``"dict_feat"`` is optional.

        Returns:
            A new ``cls`` built from the encoder-decoder model, the output
            layer, the source/target vocabs, the dict-feature vocab (or
            ``None``) and ``config.sequence_generator``.

        Raises:
            KeyError: if ``"src_seq_tokens"`` or ``"trg_seq_tokens"`` is
                missing from ``tensorizers``.
        """
        src_tokens = tensorizers["src_seq_tokens"]
        src_embedding_list = [
            create_module(config.source_embedding, tensorizer=src_tokens)
        ]
        # Optional gazetteer (dictionary-feature) tensorizer. It is looked up
        # once and reused below for both the extra source embedding and the
        # dictfeat vocab passed to ``cls``.
        gazetteer_tensorizer = tensorizers.get("dict_feat")
        if gazetteer_tensorizer:
            src_embedding_list.append(
                create_module(config.dict_embedding, tensorizer=gazetteer_tensorizer)
            )
        source_embedding = ScriptableEmbeddingList(src_embedding_list)

        trg_tokens = tensorizers["trg_seq_tokens"]
        target_embedding = ScriptableEmbeddingList(
            [create_module(config.target_embedding, tensorizer=trg_tokens)]
        )

        model = create_module(
            config.encoder_decoder,
            src_tokens.vocab,
            source_embedding,
            trg_tokens.vocab,
            target_embedding,
        )
        output_layer = create_module(config.output_layer, trg_tokens.vocab)

        return cls(
            model=model,
            output_layer=output_layer,
            src_vocab=src_tokens.vocab,
            trg_vocab=trg_tokens.vocab,
            dictfeat_vocab=(
                gazetteer_tensorizer.vocab if gazetteer_tensorizer else None
            ),
            generator_config=config.sequence_generator,
        )
示例#2
0
def create_src_embedding_list(config, tensorizers):
    """Build the source-side embedding list and collect the matching vocabs.

    The token embedding (built from ``config.source_embedding`` with the
    ``"src_seq_tokens"`` tensorizer) always comes first. When ``tensorizers``
    holds a truthy ``"dict_feat"`` entry, a gazetteer embedding built from
    ``config.dict_embedding`` is appended after it.

    Returns:
        A 4-tuple ``(source_embedding, source_vocab, dict_embedding,
        dict_vocab)``. ``source_embedding`` is a ``ScriptableEmbeddingList``
        over the embeddings above. ``dict_embedding`` and ``dict_vocab`` are
        ``None`` when no ``"dict_feat"`` tensorizer is available.
    """
    token_tensorizer = tensorizers["src_seq_tokens"]
    embeddings = [
        create_module(config.source_embedding, tensorizer=token_tensorizer)
    ]
    token_vocab = token_tensorizer.vocab

    gazetteer = tensorizers.get("dict_feat")
    if not gazetteer:
        # No dictionary feature: only the token embedding is wrapped.
        return ScriptableEmbeddingList(embeddings), token_vocab, None, None

    gazetteer_embedding = create_module(
        config.dict_embedding, tensorizer=gazetteer
    )
    embeddings.append(gazetteer_embedding)
    gazetteer_vocab = gazetteer.vocab
    return (
        ScriptableEmbeddingList(embeddings),
        token_vocab,
        gazetteer_embedding,
        gazetteer_vocab,
    )
示例#3
0
def create_tgt_embedding_list(config, tensorizers):
    """Build the target-side embedding list.

    Returns:
        ``(target_embedding, target_vocab)``: a ``ScriptableEmbeddingList``
        holding the single embedding built from ``config.target_embedding``,
        and the vocab of the ``"trg_seq_tokens"`` tensorizer.
    """
    target_tensorizer = tensorizers["trg_seq_tokens"]
    token_embedding = create_module(
        config.target_embedding, tensorizer=target_tensorizer
    )
    return ScriptableEmbeddingList([token_embedding]), target_tensorizer.vocab