def test_save_load():
    """Round-trip a fitted OneHotEncoder through ``save``/``load_from_file``.

    Fits an encoder on a small in-memory dataset, saves it to disk, loads it
    into a fresh ``OneHotEncoder([], [])`` and asserts that both encoders
    produce identical encodings for the same data.
    """
    # Imported locally so this test stays self-contained.
    import os
    from tempfile import TemporaryDirectory

    encoder = OneHotEncoder(['animal', 'color'], ['weight'],
                            max_levels_default=100)

    data = [{
        'animal': 'cat',
        'color': 'blue',
        'weight': 6.0
    }, {
        'animal': 'cat',
        'color': 'red',
        'weight': 3.0
    }, {
        'animal': 'dog',
        'color': 'yellow',
        'weight': 5.5
    }, {
        'animal': 'fish',
        'color': 'blue',
        'weight': 7.0
    }, {
        'animal': 'cat',
        'color': 'magenta',
        'weight': 2.0
    }, {
        'animal': 'mouse',
        'color': 'purple',
        'weight': 0.0
    }, {
        'animal': 'mouse',
        'color': 'black',
        'weight': 99.9
    }]

    encoder.load_from_data_stream(data)

    encoded_data = encoder.encode_data(data)

    # Write into a private temporary directory that this test owns for its
    # whole lifetime. Everything save() writes there is removed when the
    # context exits, and no other process can claim the path in between
    # (unlike reusing the name of an already-deleted NamedTemporaryFile).
    with TemporaryDirectory() as tmp_dir:
        filename = os.path.join(tmp_dir, 'encoder')
        encoder.save(filename)

        encoder_from_file = OneHotEncoder([], [])
        encoder_from_file.load_from_file(filename)

        # Encode while the file still exists, in case loading is lazy.
        encoded_data_from_file = encoder_from_file.encode_data(data)

    assert encoded_data == encoded_data_from_file


def get_encoder(pipeline_params, write=True, read_from_file=False):
    """Return a OneHotEncoder, either loaded from disk or freshly built.

    The encoder path is ``file_names['encoder']``.

    If that path exists and ``read_from_file`` is truthy, a new
    ``OneHotEncoder([], [])`` is populated via ``load_from_file`` and
    returned. ``pipeline_params`` is not used in that case.

    Otherwise the encoder is built with
    ``get_encoder_from_stream(stream_data(pipeline_params))``. When ``write``
    is truthy, it is then saved to the same path with ``save`` before being
    returned.
    """
    path = file_names['encoder']

    # Operand order kept as-is: the existence check runs first.
    use_cached = os.path.exists(path) and read_from_file
    if use_cached:
        print('Reading encoder from : %s' % path)
        cached = OneHotEncoder([], [])
        cached.load_from_file(path)
        return cached

    print('Building encoder')
    fresh = get_encoder_from_stream(stream_data(pipeline_params))

    if write:
        print('Writing encoder to: %s' % path)
        fresh.save(path)

    return fresh