Example #1
0
def test_save_load_model(compiled_model):
    """Round-trip ``compiled_model`` through ``save`` and ``engine.load_model``.

    Also checks that saving a second time into the same directory raises
    ``FileExistsError``.
    """
    tmpdir = '.tmpdir'
    # An earlier crashed run may have left the directory behind; remove it so
    # the first save below is not rejected with FileExistsError.
    shutil.rmtree(tmpdir, ignore_errors=True)
    try:
        compiled_model.save(tmpdir)
        assert engine.load_model(tmpdir)
        with pytest.raises(FileExistsError):
            compiled_model.save(tmpdir)
    finally:
        # Clean up even when an assertion above fails; ignore_errors keeps a
        # failed first save from masking the original exception.
        shutil.rmtree(tmpdir, ignore_errors=True)
Example #2
0
def test_dssm(train, test):
    """End-to-end test of the DSSM model.

    Pre-processes ``train`` and fits a ``DSSMModel`` on it. It then saves the
    preprocessor and the model into a temporary directory and reloads both
    through ``engine``. Finally it checks that the reloaded model yields
    non-empty ``np.float32`` predictions on the pre-processed ``test`` data.
    """
    tmpdir = '.tmpdir'
    # Remove leftovers from an earlier crashed run so the saves below start
    # from a clean directory.
    shutil.rmtree(tmpdir, ignore_errors=True)
    try:
        # Pre-processing.
        dssm_preprocessor = preprocessor.DSSMPreprocessor()
        processed_train = dssm_preprocessor.fit_transform(train, stage='train')
        # The dimension of the DSSM model is the length of tri-letters.
        input_shapes = processed_train.context['input_shapes']
        generator = generators.PointGenerator(processed_train, stage='train')

        # Build, compile and fit the DSSM model.
        dssm_model = models.DSSMModel()
        dssm_model.params['input_shapes'] = input_shapes
        dssm_model.guess_and_fill_missing_params()
        dssm_model.build()
        dssm_model.compile()
        dssm_model.fit_generator(generator)

        # Persist preprocessor and model into the same directory.
        dssm_preprocessor.save(tmpdir)
        dssm_model.save(tmpdir)

        # Reload both and predict on the test split.
        loaded_preprocessor = engine.load_preprocessor(tmpdir)
        processed_test = loaded_preprocessor.fit_transform(test, stage='test')
        generator = generators.PointGenerator(processed_test, stage='test')
        X, _ = generator[0]
        loaded_model = engine.load_model(tmpdir)
        predictions = loaded_model.predict([X.text_left, X.text_right])
        assert len(predictions) > 0
        assert isinstance(predictions[0][0], np.float32)
    finally:
        # Clean up even when an assertion or a save/load above fails.
        shutil.rmtree(tmpdir, ignore_errors=True)
Example #3
0
 def test_predict(self, config):
     """Load the trained network and check ``predict`` returns 5 results.

     The results for the query "glacier caves" are also written to
     ``config.outputs``, one tab-separated line (first three fields) per result.
     """
     p_dir = config.paths['processed_dir']
     model = engine.load_model(p_dir, config.net_name)
     results = predict(config, model, "glacier caves", 5)
     assert len(results) == 5
     # Explicit UTF-8: result fields may contain non-ASCII text, which the
     # platform-default encoding is not guaranteed to represent.
     with open(config.outputs, 'w', encoding='utf-8') as f:
         f.writelines(f"{r[0]}\t{r[1]}\t{r[2]}\n" for r in results)
Example #4
0
def test_save_load_model(compiled_model):
    """Round-trip the model from ``compiled_model`` through save and load.

    ``compiled_model`` is unpacked as ``(model, input_dtypes)``; only the model
    is used here. A second save into the same directory must raise
    ``FileExistsError``.
    """
    model, _input_dtypes = compiled_model
    tmpdir = '.tmpdir'
    # Remove leftovers from an earlier crashed run so the first save succeeds.
    shutil.rmtree(tmpdir, ignore_errors=True)
    try:
        model.save(tmpdir)
        assert engine.load_model(tmpdir)
        with pytest.raises(FileExistsError):
            model.save(tmpdir)
    finally:
        # Always clean up, even when an assertion above fails.
        shutil.rmtree(tmpdir, ignore_errors=True)
Example #5
0
def predict(ctx, num_largest):
    """Interactively query a trained model from the command line.

    Loads the network named by ``ctx.obj['CONFIG']`` from its processed-data
    directory. It then repeatedly prompts for a query and prints each result
    that ``model_predict`` returns for it, requesting ``num_largest`` results.
    The loop stops when the query is empty or equals ``'exit'``.

    Raises:
        NotImplementedError: if the configured model type is not ``dssm``
            (compared case-insensitively).
    """
    logger.info('Loading model...')

    config = ctx.obj['CONFIG']
    pr_dir = config.paths['processed_dir']
    model_type = config.model['type']
    net_name = config.net_name

    if model_type.lower() != 'dssm':
        raise NotImplementedError(f"Model type {model_type} not implemented")
    model = engine.load_model(pr_dir, net_name)

    prompt_text = "What do you want to search?"
    while True:
        query = click.prompt(prompt_text, type=str)
        if not query or query == 'exit':
            break
        # Reuse the already-resolved config instead of re-reading ctx.obj.
        for res in model_predict(config, model, query, num_largest):
            print(res)