Exemplo n.º 1
0
def test_tokenizer():
    """Check that the stored tokenizer loads and has seen at least one word.

    When the file at ``constants.TOKENIZER_PICKLE`` does not exist, a warning
    is emitted and the test returns early without asserting anything.
    """
    tokenizer_path = constants.TOKENIZER_PICKLE
    if not os.path.exists(tokenizer_path):
        message = f"{tokenizer_path} not found, skipping test_tokenizer()"
        warnings.warn(message)
        return

    tokenizer = model.get_tokenizer()

    # The loader must hand back a Keras text Tokenizer ...
    assert isinstance(tokenizer, tf.keras.preprocessing.text.Tokenizer)
    # ... whose word_counts mapping is non-empty.
    assert len(tokenizer.word_counts) > 0
Exemplo n.º 2
0
def gen_model_loaders(config):
    """Build the tokenizers, the train/eval data loaders, and the model.

    Args:
        config: Configuration object; this function reads ``config.data``,
            ``config.batch_size`` and ``config.eval_size`` and forwards the
            whole object to ``M.build_model``.

    Returns:
        A ``(model, tokenizers, train_loader, eval_loader)`` tuple.
    """
    tokenizers = M.get_tokenizer()
    src_tokenizer = tokenizers.src
    tgt_tokenizer = tokenizers.tgt

    # One collate function, built from the source/target pad token ids, is
    # shared by both loaders.
    collate = PadSequence(src_tokenizer.pad_token_id,
                          tgt_tokenizer.pad_token_id)

    # NOTE(review): the trailing boolean presumably selects the training
    # split (True) versus the evaluation split (False) -- confirm against
    # IndicDataset.
    train_loader = DataLoader(
        IndicDataset(src_tokenizer, tgt_tokenizer, config.data, True),
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate,
    )
    eval_loader = DataLoader(
        IndicDataset(src_tokenizer, tgt_tokenizer, config.data, False),
        batch_size=config.eval_size,
        shuffle=False,
        collate_fn=collate,
    )

    # The model is built last because build_model takes both loaders.
    built_model = M.build_model(config, train_loader, eval_loader)

    return built_model, tokenizers, train_loader, eval_loader
Exemplo n.º 3
0
def test_model_performance():
    """Check that the stored model scores above 0.9 on the project data.

    The test needs both ``constants.MODEL_FILE`` and
    ``constants.TOKENIZER_PICKLE`` on disk. If either is missing, a warning
    naming the first missing file is emitted and the test returns early
    without asserting anything.
    """
    # Find the first missing prerequisite, checking the model file first.
    missing = None
    if not os.path.exists(constants.MODEL_FILE):
        missing = constants.MODEL_FILE
    elif not os.path.exists(constants.TOKENIZER_PICKLE):
        missing = constants.TOKENIZER_PICKLE

    if missing is not None:
        warnings.warn(
            f"{missing} not found, skipping test_model_performance()"
        )
        return

    classifier = model.get_model()
    tokenizer = model.get_tokenizer()

    texts, labels = model.get_data()

    # Turn the raw texts into integer sequences and pass them through
    # pad_sequences with maxlen=16; the labels are one-hot encoded.
    sequences = tokenizer.texts_to_sequences(texts)
    features = tf.keras.preprocessing.sequence.pad_sequences(
        sequences, maxlen=16
    )
    targets = tf.keras.utils.to_categorical(labels)

    # NOTE(review): index 1 of the evaluate() result is taken to be accuracy;
    # this depends on how the model was compiled -- confirm.
    results = classifier.evaluate(features, targets)
    accuracy = results[1]

    assert accuracy > 0.9