Example #1
File: appcom.py Project: tboquet/scheduler
def predict(model_db, data, custom_objects=None):
    """A function to predict given a model document, input data and
    optional custom objects"""
    from keras.models import model_from_config
    from databasesetup import get_models
    import numpy as np

    if custom_objects is None:
        # model_from_config expects a mapping of custom objects
        custom_objects = {}
    # get the models collection
    models = get_models()

    # check if the predict function is already compiled
    if model_db['hashed_mod'] in COMPILED_MODELS:
        pred_function = COMPILED_MODELS[model_db['hashed_mod']]
        model_json = model_db['keras_model']
        # old Keras serializations use 'name', newer ones 'class_name'
        model_name = model_json.get('class_name', model_json.get('name'))
    else:
        # get the model in the DB
        # model_db = models.find_one({'hashed_mod': hashed_mod})
        model_json = model_db['keras_model']

        # drop the optimizer: it is not needed for inference
        model_json.pop('optimizer', None)
        # load model
        model = model_from_config(model_json, custom_objects=custom_objects)
        model_name = model_json.get('class_name', model_json.get('name'))

        # load the weights
        model.load_weights(model_db['params_dump'])

        # build the prediction function
        pred_function = build_predict_func(model)
        COMPILED_MODELS[model_db['hashed_mod']] = pred_function

    # predict according to the input/output type
    if model_name == 'Graph':
        input_order = model_json.get('input_order')
        pred = pred_function([np.array(data[n]) for n in input_order])
    elif model_name == 'Sequential':
        # unpack data
        X = data['X']
        pred = pred_function([X])
    else:
        raise NotImplementedError('This type of model is not supported')

    return pred
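
For reference, a minimal, hypothetical call sketch (not part of the project): the document layout is inferred from the fields predict reads above, and model_str and X_new are assumed to exist.

import json

# Hypothetical model document; the field names mirror the lookups
# performed in predict(): 'hashed_mod', 'keras_model', 'params_dump'.
model_db = {
    'hashed_mod': 'abc123',                     # assumed cache key
    'keras_model': json.loads(model_str),       # config from model.to_json()
    'params_dump': '/parameters_h5/abc123.h5',  # assumed weights path
}
pred = predict(model_db, {'X': X_new})          # Sequential-style input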
Example #2
File: appcom.py Project: tboquet/scheduler
def fit(model_str,
        data_s,
        nb_train,
        nb_test,
        offset,
        nb_epoch=10,
        batch_size=32,
        custom_objects=None,
        callbacks=None,
        cuts_shutd=True):
    """A function to train models given datasets,a serialized model,
    custom objects, callbacks, nb_epochs and batch size

    Args:
        model_str(str): the model dumped with the `to_json` method
        data(str): the path of the dataset to loads
        nb_train(int): the number of train datapoints to take
        nb_test(int): the number of test datapoints to take
        offset(int): how many datapoints to burn
        weights(numpy.array): an array of weights of the size of
            the training set.
        nb_epoch(int, optionnal): the number of epochs to train the model
        batch_size(int, optionnal): the batch size of the mini batches
        callbacks(list): a list of callbacks used in the fit

    Returns:
        the unique id of the model"""

    from utils import sliced
    from databasesetup import get_models
    from datetime import datetime
    import hashlib
    import json
    import numpy as np

    if custom_objects is None:
        custom_objects = {}
    if callbacks is None:
        callbacks = []
    # keep the serialized model as-is; it is stored verbatim in the db
    model_json = model_str

    # get the models collection
    models = get_models()

    # load data
    data_path = data_s.pop('data_path')
    elec = data_s.pop('elec')
    cell = data_s.pop('cell')
    if data_path != 'test':
        root = data_s.pop('root')

        reader = dr.DataReader(data_path, root)
        reader.get_files_info()
        data = reader.commit(data_s)
    else:
        cuts_shutd = False
        data = data_s.copy()
        data['current'] = []
        data['bins'] = []
        root = None
        data_s = {key: data_s['data'][key].tolist() for key in data_s['data']}

    # slice data
    beg, endt, endv = sliced(data['data'], nb_train, nb_test, offset)
    data_t = {n: data['data'][n][beg:endt] for n in data['data']}
    data_val = {n: data['data'][n][endt:endv] for n in data['data']}
    current_t = data['current'][beg:endt]
    current_val = data['current'][endt:endv]
    bins_t = data['bins'][beg:endt]
    bins_val = data['bins'][endt:endv]

    # cut the data at shutdowns
    if cuts_shutd:
        datas = dr.make_datasets(data_t, current_t, bins_t)
        datas_val = dr.make_datasets(data_val, current_val, bins_val)
        diff = len(datas) - len(datas_val)
        if diff > 0:
            datas_val += [datas_val[-1] for i in range(diff)]
    else:
        datas = [data_t]
        datas_val = [data_val]

    # TODO: implement cut dataset to match batch sizes
    if 'stateful' in model_json:
        pass
    # get a unique descriptor of the data
    first = next(iter(data_t))
    un_data = data_t[first].mean()

    # create the hash from the stringified json, data descriptor and
    # batch size
    m = hashlib.md5()
    m.update((json.dumps(model_str) + str(un_data) +
              str(batch_size)).encode('utf-8'))
    hexdi = m.hexdigest()

    params_dump = "/parameters_h5/" + hexdi + '.h5'

    # update the full json
    full_json = {'keras_model': model_json,
                 'datetime': datetime.now(),
                 'hashed_mod': hexdi,
                 'data_id': str(un_data),
                 'params_dump': params_dump,
                 'batch_size': batch_size,
                 'trained': 0,
                 'cell': cell,
                 'elec': elec,
                 'data_path': data_path,
                 'root': root,
                 'data_s': data_s}
    mod_id = models.insert_one(full_json).inserted_id

    try:
        loss, val_loss, model = train_model(model_str, custom_objects, datas,
                                            datas_val, batch_size, nb_epoch,
                                            callbacks)
        upres = models.update_one({"_id": mod_id}, {'$set': {
            'train_loss': loss,
            'min_tloss': np.min(loss),
            'valid_loss': val_loss,
            'min_vloss': np.min(val_loss),
            'iter_stopped': nb_epoch * len(datas),
            'trained': 1,
            'date_finished_trained': datetime.now()
        }})

        model.save_weights(params_dump, overwrite=True)

    except MemoryError:
        # drop the partially trained entry and re-raise
        models.delete_one({'hashed_mod': hexdi})
        raise

    except Exception:
        # flag the model as errored and re-raise
        upres = models.update_one({"_id": mod_id}, {'$set': {'error': 1}})
        raise
    return hexdi
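
A hypothetical invocation sketch (not part of the project): the keys of data_s mirror the pops performed in fit above, the 'test' value for data_path exercises the in-memory branch, and model, X and y are assumed to exist.

data_s = {
    'data_path': 'test',       # takes the in-memory branch above
    'elec': 'E1',              # assumed electrode identifier
    'cell': 'C1',              # assumed cell identifier
    'data': {'X': X, 'y': y},  # assumed numpy arrays
}
hexdi = fit(model.to_json(), data_s, nb_train=800, nb_test=200, offset=0)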
Example #3
File: app.py Project: tboquet/scheduler
# the listing is truncated above; these imports are reconstructed from
# the calls below
import logging

from flask import Flask
from databasesetup import get_models

try:
    import simplejson as json
except ImportError:
    import json


app = Flask(__name__)


logger = logging.getLogger(__name__)


BASE_URL = '/pred/api/v1.0/'


# get models from the db
modelsdb = get_models()


def clean_db_response(list_of_mod):
    """Clean a list of responses from the db

    Args:
        list_of_mod(list): a list of db responses

    Returns:
        a dictionary mapping the id of the model to pertinent information"""
    return {m['hashed_mod']: (m['min_vloss'], m['cell'], m['elec'])
            for m in list_of_mod}

# Get models
@app.route(BASE_URL + 'models', methods=['GET'])
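
The listing cuts off at the decorator. A minimal, hypothetical handler consistent with clean_db_response above might look like this (the 'trained' flag is borrowed from the fit example; everything else is an assumption):

def get_models_route():
    # hypothetical sketch: list finished models with their best
    # validation loss, cell and electrode
    mods = modelsdb.find({'trained': 1})
    return json.dumps(clean_db_response(list(mods)))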