示例#1
0
def model(label_vocab, embeds):
    """Return a 2-layer BLSTM tagger with a masked CRF from ``create_tagger_model``.

    ``embeds`` and ``label_vocab`` are passed positionally (in that order);
    the remaining hyperparameters come from module-level constants
    (``SPAN_TYPE``, ``HSZ``, ``WSZ``) plus a few fixed literals.
    """
    # The bound name is never used, so this import is kept only for its side
    # effect. NOTE(review): presumably it registers the PyTorch tagger
    # implementation with the factory — confirm.
    import baseline.pytorch.tagger.model
    tagger_options = dict(
        crf=True,
        crf_mask=True,
        span_type=SPAN_TYPE,
        hsz=HSZ,
        cfiltsz=[3],
        wsz=WSZ,
        layers=2,
        rnntype="blstm",
    )
    return create_tagger_model(embeds, label_vocab, **tagger_options)
示例#2
0
def model(label_vocab, embeds, mask):
    """Build a constrained BLSTM-CRF tagger on TensorFlow and initialise it.

    ``mask`` is forwarded as the CRF ``constraint`` argument. After the model
    is built, its loss is created and ``tf.global_variables_initializer()`` is
    run in the model's own session before the model is returned.
    """
    # ``tagger`` is never referenced; the import is kept for its side effect.
    from baseline.tf import tagger
    crf_tagger = create_tagger_model(
        embeds,
        label_vocab,
        crf=True,
        constraint=mask,
        hsz=HSZ,
        cfiltsz=[3],
        wsz=WSZ,
        layers=2,
        rnntype="blstm"
    )
    crf_tagger.create_loss()
    # Variables must be initialised after create_loss() so that any variables
    # it adds are covered as well.
    crf_tagger.sess.run(tf.global_variables_initializer())
    return crf_tagger
示例#3
0
def test_skip_mask(label_vocab, embeds, mask):
    """Without a constraint, the ``O`` -> ``S`` transition must not be -1e4.

    ``mask`` is accepted but deliberately not passed to the model, so the CRF
    is built unconstrained. The test reads the transition matrix
    (``model.A``) and checks that the entry for the ``'O'`` row and the ``S``
    column is not -1e4. NOTE(review): -1e4 is presumably the value the
    constrained CRF writes for disallowed transitions — confirm.
    """
    # Imported for its side effect only; ``tagger`` is unused.
    from baseline.tf import tagger
    unconstrained = create_tagger_model(embeds, label_vocab, crf=True,
                                        hsz=HSZ, cfiltsz=[3], wsz=WSZ,
                                        layers=2, rnntype="blstm")
    unconstrained.create_loss()
    unconstrained.sess.run(tf.global_variables_initializer())
    transitions = unconstrained.sess.run(unconstrained.A)
    src_idx, dst_idx = label_vocab['O'], label_vocab[S]
    assert transitions[src_idx, dst_idx] != -1e4
def model(label_vocab, embeds, mask):
    """Return a TF BLSTM-CRF tagger constrained by ``mask``, with variables initialised.

    The model is created via ``create_tagger_model`` with ``mask`` as the
    CRF ``constraint``. Its loss is then built, and the global variable
    initialiser is run in ``model.sess``.
    """
    # The imported name is unused; only the import's side effect matters here.
    from baseline.tf import tagger
    hyperparams = {
        "crf": True,
        "constraint": mask,
        "hsz": HSZ,
        "cfiltsz": [3],
        "wsz": WSZ,
        "layers": 2,
        "rnntype": "blstm",
    }
    built = create_tagger_model(embeds, label_vocab, **hyperparams)
    built.create_loss()
    built.sess.run(tf.global_variables_initializer())
    return built
def test_skip_mask(label_vocab, embeds, mask):
    """Check that an unconstrained CRF leaves the ``O`` -> ``S`` transition unmasked.

    No ``constraint`` is supplied (``mask`` is intentionally unused), so the
    entry of ``model.A`` at row ``label_vocab['O']`` and column
    ``label_vocab[S]`` is expected to differ from -1e4.
    """
    # Imported for its side effect only; ``tagger`` is unused.
    from baseline.tf import tagger
    options = {
        "crf": True,
        "hsz": HSZ,
        "cfiltsz": [3],
        "wsz": WSZ,
        "layers": 2,
        "rnntype": "blstm",
    }
    crf_model = create_tagger_model(embeds, label_vocab, **options)
    crf_model.create_loss()
    crf_model.sess.run(tf.global_variables_initializer())
    trans_matrix = crf_model.sess.run(crf_model.A)
    o_idx = label_vocab['O']
    s_idx = label_vocab[S]
    assert trans_matrix[o_idx, s_idx] != -1e4
示例#6
0
def create_model(labels, embeddings, **kwargs):
    """Delegate to ``create_tagger_model`` using ``BASELINE_TAGGER_MODELS``.

    ``BASELINE_TAGGER_MODELS`` is passed as the first argument. NOTE(review):
    it is presumably the table of available tagger implementations — confirm.
    ``labels``, ``embeddings`` and any extra keyword options follow it, and
    the factory's result is returned unchanged.
    """
    tagger = create_tagger_model(
        BASELINE_TAGGER_MODELS,
        labels,
        embeddings,
        **kwargs
    )
    return tagger
示例#7
0
def create_model(labels, word_embedding, char_embedding, **kwargs):
    """Build an RNN tagger from separate word and character embeddings.

    Calls ``create_tagger_model`` with ``RNNTaggerModel.create`` as the
    constructor, followed by ``labels``, ``word_embedding``,
    ``char_embedding`` and any extra keyword options, and returns the result.
    """
    return create_tagger_model(
        RNNTaggerModel.create,
        labels,
        word_embedding,
        char_embedding,
        **kwargs
    )
示例#8
0
def create_model(labels, embeddings, **kwargs):
    """Build an RNN tagger via ``create_tagger_model``.

    ``RNNTaggerModel.create`` is passed as the constructor, followed by
    ``labels``, ``embeddings`` and any extra keyword options. The created
    model is returned as-is.
    """
    return create_tagger_model(
        RNNTaggerModel.create,
        labels,
        embeddings,
        **kwargs
    )