def train_model(train_files, model_save_path):
    """Main function to train the model.

    Splits ``train_files`` into a training set (90%) and a validation set
    (10%), loads every file with ``np.load``, builds the network with
    ``get_model_cnn()`` and trains it with ``fit_generator`` using
    checkpointing, early stopping and learning-rate reduction, all driven by
    validation accuracy (``val_acc``).

    Inputs:
    - train_files: files we will use to train (and validate) our model.
    - model_save_path: path to save the model. During training
      ``ModelCheckpoint`` writes the best model (by ``val_acc``) here; after
      training the final model is saved to the same path.
    """

    # Split the data between training and validation; test_size=0.1 puts 10%
    # of the files in the validation set. train_test_split shuffles randomly
    # and no random_state is passed, so the split differs between runs
    # (pass e.g. random_state=1337 for a reproducible split).
    train, val = train_test_split(train_files, test_size=0.1)

    # Load all our train data, keyed by file path.
    train_dict = {k: np.load(k) for k in train}

    # Load all our validation data, keyed by file path.
    val_dict = {k: np.load(k) for k in val}
    print("Validating: " + str(val_dict))

    # Architecture, as described by the original author (get_model_cnn is
    # defined elsewhere): a "Base-CNN" of 3 repeated sets of two Conv1D
    # layers + 1-D max-pooling + spatial dropout, followed by two Conv1D,
    # 1-D global max-pooling, dropout and dense layers, ending in dropout.
    # The Base-CNN is wrapped time-distributed and followed by Conv1D,
    # spatial dropout, Conv1D, dropout and a final Conv1D that outputs the
    # multiclass sleep labels.
    model = get_model_cnn()

    # Training callbacks.
    # Save the model whenever validation accuracy reaches a new best
    # (checked every epoch).
    checkpoint = ModelCheckpoint(model_save_path,
                                 monitor='val_acc',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='max')
    # Stop training once validation accuracy has not improved for 20 epochs,
    # to avoid overfitting.
    early = EarlyStopping(monitor="val_acc",
                          mode="max",
                          patience=20,
                          verbose=1)

    # Reduce the learning rate each time validation accuracy plateaus for
    # 5 epochs.
    redonplat = ReduceLROnPlateau(monitor="val_acc",
                                  mode="max",
                                  patience=5,
                                  verbose=2)
    # A callback only takes effect if it is listed here, so early stopping
    # must be registered alongside the other two.
    callbacks_list = [checkpoint, early, redonplat]

    model.fit_generator(gen(train_dict, aug=False),
                        validation_data=gen(val_dict),
                        epochs=25,
                        verbose=2,
                        steps_per_epoch=1000,
                        validation_steps=300,
                        callbacks=callbacks_list)

    # And finally we save our model!
    # NOTE(review): this overwrites the best-val_acc checkpoint written by
    # ModelCheckpoint at the same path with the last-epoch model; confirm
    # that is intended.
    model.save(model_save_path)
示例#2
0
# Streamlit demo: pick an MLP or CNN in the sidebar, train it on MNIST for
# 20 single-epoch rounds and record the validation accuracy of each round.
# `progress_bar` is defined outside this snippet.

# Empty placeholder element; presumably filled later with a plotly chart
# (the code that fills it is not visible here).
plotly_chart_2 = st.empty()

# Sidebar dropdown choosing the architecture; "None" skips training below.
option = st.sidebar.selectbox("Which architecture do you want to use ?",
                              ("None", "MLP", "CNN"))

# NOTE(review): the two adjacent string literals are concatenated with no
# separator, so the image markdown directly follows the heading text —
# confirm a newline was not intended.
markdown = st.sidebar.markdown(
    "#### Classify Mnist Digits and Display their representation"
    "![](https://warwick.ac.uk/fac/cross_fac/complexity/study/msc_and_phd/co902/2013_2014/resources/mnisttrain.png)"
)

# Inputs are flattened only for the MLP (presumably the CNN expects
# image-shaped input — confirm against get_data).
x_train, y_train, x_test, y_test = get_data(flatten=(option == "MLP"))

# NOTE(review): the else-branch also runs when option == "None", so a CNN is
# built even though it is never trained in that case.
if option == "MLP":
    model, model_aux = get_model_mlp()
else:
    model, model_aux = get_model_cnn()

if option in ("MLP", "CNN"):

    # Validation accuracy after each training round.
    accuracies = []

    # 20 rounds of one epoch each; the progress bar advances by 5 per round,
    # reaching 100 after the last one.
    for i in range(20):
        progress_bar.progress(5 * (i + 1))

        # NOTE(review): `nb_epoch` is the legacy Keras 1 keyword (Keras 2
        # uses `epochs`) — confirm against the installed Keras version.
        history = model.fit(x_train,
                            y_train,
                            validation_data=(x_test, y_test),
                            nb_epoch=1)

        # One epoch was run, so the history list holds a single entry.
        # "val_acc" is the key used by older Keras versions (newer ones use
        # "val_accuracy").
        acc = float(history.history["val_acc"][0])
        accuracies.append(acc)
示例#3
0
def _subject_id(path):
    """Return the subject key of a file: the first 5 characters of its file name."""
    return path.split("/")[-1][:5]


# Sorted, de-duplicated subject keys across all files (`files` is defined
# earlier in the script, outside this snippet).
ids = sorted({_subject_id(x) for x in files})
# Split by subject rather than by file, so no subject's files appear in both
# the train/validation pool and the test set.
train_ids, test_ids = train_test_split(ids, test_size=0.15, random_state=1338)

# Set lookups keep the partition of `files` linear instead of scanning the
# id lists once per file.
train_id_set = set(train_ids)
test_id_set = set(test_ids)
train_val = [x for x in files if _subject_id(x) in train_id_set]
test = [x for x in files if _subject_id(x) in test_id_set]

# The validation split is done per file, not per subject, so validation files
# may share subjects with training files.
train, val = train_test_split(train_val, test_size=0.1, random_state=1337)

# encoding="bytes" only matters for pickled Python 2 data; presumably the
# files were written under Python 2 — confirm.
train_dict = {k: np.load(k, encoding="bytes") for k in train}
test_dict = {k: np.load(k, encoding="bytes") for k in test}
val_dict = {k: np.load(k, encoding="bytes") for k in val}

model = get_model_cnn()

# `file_path` is defined earlier in the script, outside this snippet.
model.load_weights(file_path)

# Debug output: the names stored in the first validation file (the
# `.keys()` call presumes np.load returned an .npz archive). Raises
# IndexError if the validation set is empty.
print(list(val_dict.values())[0].keys())


def pred(X):
    #print("Shape of X: %s" % X.shape)
示例#4
0
文件: test.py 项目: oujieww/oqmrc2018
    valid_json_path = '../ai_challenger_oqmrc_validationset_20180816/ai_challenger_oqmrc_validationset.json'

    eval_data = read_json(valid_json_path, True)
    # eval_data = eval_data.sample(frac=1)
    eval_data['label'] = eval_data.loc[:, ['alternatives', 'answer']].apply(
        idx_label, axis=1)

    eval_inputs, eval_y = create_data(eval_data,
                                      ttv,
                                      passage_path='../log/eval_passage.npy',
                                      query_path='../log/eval_query.npy',
                                      answer_path='../log/eval_answer.npy',
                                      label_path='../log/eval_label.npy')

    model = models.get_model_cnn(eval_inputs, ttv.word_vec)

    # model = models.test_embde(train_inputs[0],
    # ttv.word_vec)
    # model.predict(train_inputs[0])
    # model.summary()
    plot_model(model, to_file='../log/model.png', show_shapes=True)
    model_path = '../log/model-0.79.h5'
    if os.path.exists(model_path):
        model.load_weights(model_path)
        print('Load weighs from', model_path)
    # checkpoint = ModelCheckpoint(
    # filepath='../log/model-{val_loss:.2f}.h5',
    # monitor='val_loss',
    # save_best_only=True,
    # save_weights_only=True,verbose=1,period=2)
示例#5
0
        idx, 'alternatives'].map(lambda x: x + '|无法确定' + '|无法确定')
    idx = test_data['alternatives'].map(lambda x: len(x.split('|'))) == 2
    test_data.loc[idx, 'alternatives'] = test_data.loc[
        idx, 'alternatives'].map(lambda x: x + '|无法确定')

    # eval_data = eval_data.sample(frac=1)
    # test_data['label'] = test_data.loc[:,['alternatives','answer']].apply(idx_label, axis=1)

    test_inputs, test_y = create_data(test_data,
                                      ttv,
                                      passage_path='../log/test_passage.npy',
                                      query_path='../log/test_query.npy',
                                      answer_path='../log/test_answer.npy',
                                      label_path=None)

    model = models.get_model_cnn(test_inputs, ttv.word_vec)

    # model = models.test_embde(train_inputs[0],
    # ttv.word_vec)
    # model.predict(train_inputs[0])
    # model.summary()
    plot_model(model, to_file='../log/model.png', show_shapes=True)
    model_path = '../log/model-0.79.h5'
    if os.path.exists(model_path):
        model.load_weights(model_path)
        print('Load weighs from', model_path)
    # checkpoint = ModelCheckpoint(
    # filepath='../log/model-{val_loss:.2f}.h5',
    # monitor='val_loss',
    # save_best_only=True,
    # save_weights_only=True,verbose=1,period=2)
示例#6
0
    train_data['label'] = train_data.loc[:,['shuffle_a','answer']].apply(idx_label, axis=1)
    
    train_len = int(train_data.shape[0] * 0.8)
        
    train_inputs, train_y = create_data(train_data[:train_len], 
                                ttv,passage_path='../log/train_passage.npy',
                                query_path='../log/train_query.npy',
                                answer_path='../log/train_answer.npy',
                                label_path='../log/train_label.npy')
    valid_inputs, valid_y = create_data(train_data[train_len:], 
                                ttv,passage_path='../log/valid_passage.npy',
                                query_path='../log/valid_query.npy',
                                answer_path='../log/valid_answer.npy',
                                label_path='../log/valid_label.npy')

    model = models.get_model_cnn(train_inputs, ttv.word_vec)
    
    # model = models.test_embde(train_inputs[0], 
                             # ttv.word_vec)
    # model.predict(train_inputs[0])
    # model.summary()
    plot_model(model, to_file='../log/model.png', show_shapes=True)
    model_path = '../log/model-0.79.h5'
    if os.path.exists(model_path):
        model.load_weights(model_path)
        print('Load weighs from', model_path)
    checkpoint = ModelCheckpoint(
                    filepath='../log/model-{val_loss:.2f}.h5', 
                    monitor='val_loss', 
                    save_best_only=True, 
                    save_weights_only=True,verbose=1,period=2)