コード例 #1
0
ファイル: cifar10.py プロジェクト: johndpope/odin-ai
                                              inc_test=False, seed=1234)
# Carve the validation subset out of the full training arrays first; only
# afterwards are X_train / y_train narrowed to the training indices, so
# idx_valid is applied to the un-narrowed arrays.
X_valid, y_valid = X_train[idx_valid], y_train[idx_valid]
X_train, y_train = X_train[idx_train], y_train[idx_train]
# Report the X and y array shapes of every split.
print("#Train:", X_train.shape, y_train.shape)
print("#Valid:", X_valid.shape, y_valid.shape)
print("#Test:", X_test.shape, y_test.shape)
# ====== training ====== #
print('Start training ...')
task = training.MainLoop(
    batch_size=128,
    seed=1234,
    shuffle_level=2,
    allow_rollback=True,
)
task.set_checkpoint(MODEL_PATH, model)
# Callbacks: a NaN detector, plus generalization-loss early stopping that
# tracks `ce` on the 'valid' task (threshold=5, patience=3).
task.set_callbacks([
    training.NaNDetector(),
    training.EarlyStopGeneralizationLoss(
        'valid', ce, threshold=5, patience=3,
    ),
])
# Training task: f_train over the training split for NB_EPOCH epochs.
task.set_train_task(
    func=f_train, data=(X_train, y_train), epoch=NB_EPOCH, name='train',
)
# Validation task scheduled by training.Timer(percentage=0.6)
# (presumably a fraction of each epoch -- confirm against odin's Timer).
task.set_valid_task(
    func=f_test, data=(X_valid, y_valid),
    freq=training.Timer(percentage=0.6), name='valid',
)
# Evaluation task on the test split, reusing f_test.
task.set_eval_task(func=f_test, data=(X_test, y_test), name='eval')
task.run()
# ===========================================================================
# External validation
コード例 #2
0
ファイル: cifar10_cnn.py プロジェクト: liqin123/odin
print("Build scoring function ...")
# Scoring function mapping inputs [X, y_true] to outputs [cost_pred, cost_eval].
f_score = K.function([X, y_true], [cost_pred, cost_eval])

# ===========================================================================
# Create trainer
# ===========================================================================
print("Create trainer ...")
trainer = training.MainLoop(
    batch_size=32,
    seed=12082518,
    shuffle_level=2,
)
# Save `f` under the 'cifar10.ai' model path (override=True).
trainer.set_save(
    utils.get_modelpath('cifar10.ai', override=True),
    f,
)
# Main task: f_train over the learning split for 25 epochs.
trainer.set_task(
    f_train, [X_learn, y_learn],
    epoch=25, p=1, name='Train',
)
# Subtask named 'Valid' scores the TEST split with f_score (freq=1).
trainer.set_subtask(
    f_score, [X_test, y_test],
    freq=1, name='Valid',
)
trainer.set_callback([
    training.ProgressMonitor(
        name='Train', format='Results: {:.4f}'),
    training.ProgressMonitor(
        name='Valid', format='Results: {:.4f},{:.4f}'),
    # NOTE: early stopping watches the 'Valid' subtask, which is actually
    # fed the test split. Model selection on test data is not a sound
    # procedure; the original author kept it only for testing purposes.
    # get_value averages the second element of each record -- presumably
    # cost_eval from f_score (described by the author as cross-entropy).
    training.EarlyStopGeneralizationLoss(
        name='Valid',
        threshold=5,
        patience=3,
        get_value=lambda x: np.mean([second for _, second in x]),
    ),
    training.History(),
])
trainer.run()

# ===========================================================================
# Evaluation and visualization
# ===========================================================================
# Print the epoch-level records held by the 'History' callback for each task.
trainer['History'].print_epoch('Train')
trainer['History'].print_epoch('Valid')