示例#1
0
def test_experiment_model_resume(csv_filename):
    """Train a model, resume training from its output directory, then predict.

    Single sequence input, single category output. Exercises saving a model,
    reloading it through ``model_resume_path`` to rerun training, and running
    prediction with the saved model.

    :param csv_filename: path the synthetic dataset is written to
        (presumably provided by a pytest fixture — confirm).
    """
    input_features = [sequence_feature(encoder='rnn', reduce_output='sum')]
    output_features = [category_feature(vocab_size=2, reduce_input='sum')]
    # Generate test data
    rel_path = generate_data(input_features, output_features, csv_filename)

    config = {
        'input_features': input_features,
        'output_features': output_features,
        'combiner': {
            'type': 'concat',
            'fc_size': 14
        },
        'training': {
            'epochs': 2
        }
    }

    _, _, _, _, output_dir = experiment_cli(config, dataset=rel_path)
    try:
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info('Experiment Directory: %s', output_dir)

        # Resume training from the directory produced by the first run.
        experiment_cli(config, dataset=rel_path, model_resume_path=output_dir)

        predict_cli(os.path.join(output_dir, 'model'), dataset=rel_path)
    finally:
        # Always remove the experiment directory, even when resuming or
        # predicting fails, so failed runs do not leak result directories.
        shutil.rmtree(output_dir, ignore_errors=True)
示例#2
0
def test_experiment_model_resume(tmpdir):
    """Check that a trained model can be resumed and then used for prediction.

    Builds a single-sequence-input / single-category-output config, trains it
    once into ``tmpdir``, runs training again with ``model_resume_path``
    pointing at the first run's output directory, and finally predicts with
    the saved ``model`` subdirectory.
    """
    seq_inputs = [sequence_feature(encoder="rnn", reduce_output="sum")]
    cat_outputs = [category_feature(vocab_size=2, reduce_input="sum")]

    # Synthetic dataset lives inside the per-test temporary directory.
    dataset_csv = os.path.join(tmpdir, "dataset.csv")
    data_path = generate_data(seq_inputs, cat_outputs, dataset_csv)

    combiner_section = {"type": "concat", "output_size": 14}
    trainer_section = {"epochs": 2}
    model_config = {
        "input_features": seq_inputs,
        "output_features": cat_outputs,
        "combiner": combiner_section,
        TRAINER: trainer_section,
    }

    # First run: only the experiment directory (5th element) is needed.
    (_, _, _, _, experiment_dir) = experiment_cli(
        model_config, dataset=data_path, output_directory=tmpdir
    )

    # Second run resumes from the artifacts written by the first one.
    experiment_cli(model_config, dataset=data_path, model_resume_path=experiment_dir)

    saved_model = os.path.join(experiment_dir, "model")
    predict_cli(saved_model, dataset=data_path)

    shutil.rmtree(experiment_dir, ignore_errors=True)