示例#1
0
def train12Net(X, Y, windowSize=24, scaleFactor=2, stepSize=32, batchSize=16, nbEpoch=10):
    """Train the 12-net, resuming from a saved checkpoint if one exists.

    The weight file name encodes ``windowSize``, ``scaleFactor`` and
    ``stepSize`` (with the defaults: ``12_trained_model_w24_scale2_step32.h5``)
    and lives in the current working directory. When that file is present its
    weights are loaded before training; after training the weights are saved
    back to the same file, overwriting any previous version.

    Args:
        X: Training inputs, passed unchanged to ``model.fit``.
        Y: Training targets, passed unchanged to ``model.fit``.
        windowSize: Input window size for ``model_architecture.setUp12net``;
            also part of the weight file name.
        scaleFactor: Only used to build the weight file name.
        stepSize: Only used to build the weight file name.
        batchSize: Mini-batch size for ``model.fit``.
        nbEpoch: Number of training epochs for ``model.fit``.

    Returns:
        Whatever ``model.fit`` returns (presumably a Keras ``History``).
    """
    modelFileName = '12_trained_model_w{0}_scale{1}_step{2}.h5'.format(
        windowSize, scaleFactor, stepSize)
    # One path for the existence check, the load and the save, instead of
    # checking a hand-concatenated cwd + '/' string but loading the bare name.
    modelPath = os.path.join(os.getcwd(), modelFileName)

    # Load model architecture
    model = model_architecture.setUp12net(windowSize)
    if os.path.exists(modelPath):
        model.load_weights(modelPath)
        print("Loaded model: " + modelFileName)
    else:
        print("No model stored, creating new")

    print("======Training....======")
    # NOTE(review): `nb_epoch` is the Keras 1 keyword; Keras 2 renamed it to
    # `epochs` and newer releases reject the old name. Confirm against the
    # installed Keras version before switching.
    history = model.fit(X, Y, batch_size=batchSize, nb_epoch=nbEpoch, verbose=1)
    print("Finished training!")
    model.save_weights(modelPath, overwrite=True)
    print("saved model to: " + modelFileName)
    return history
# NOTE(review): orphan `else:` -- its matching `if` is not present in this file
# (presumably a `len(sys.argv)` check like the one used for the test set
# further down), so this fragment does not parse as-is. TODO restore the `if`.
else:
    # Load and preprocess every image listed in 'annotations_short.txt',
    # using windowSize x windowSize as the size argument.
    imdb = il.loadAndPreProcessIms('annotations_short.txt', scaleFactor,
                                   (windowSize, windowSize))

print("Image database: {0} images".format(len(imdb)))
# Convert the image database into network inputs X, labels Y and a third
# output W (meaning not visible here) using stride stepSize and windowSize.
[X, Y, W] = il.getCNNFormat(imdb, stepSize, windowSize)
# NOTE(review): start_time is not assigned before this line in the visible
# code; it must come from earlier code -- confirm what interval this reports.
print("finished preprocessing in {0}".format(time.time() - start_time))

print("=========================================")
#================================================
# 12 net
#================================================

print("\n\n============== 12Net ====================")
model12 = model_architecture.setUp12net(windowSize)
print("Loading model from: " + model12FileName)
model12.load_weights(model12FileName)

model48 = model_architecture.setUp48net(windowSize48)
print("Loading model from: " + model48FileName)
model48.load_weights(model48FileName)

# Get best predictions

start_time = time.time()
predictions_12 = model12.predict(X, batch_size=16, verbose=1)
# Get top 10%.
# Assumes one score per window, i.e. predictions of shape (N, 1) -- confirm.
# np.squeeze turns a single-window (1, 1) result into a 0-d array with no
# shape[0]; atleast_1d keeps that case working and is a no-op otherwise.
targets = np.atleast_1d(np.squeeze(predictions_12))
nb_top_targets = int(math.ceil(targets.shape[0] * 0.1))
# argsort is ascending, so the last nb_top_targets indices hold the highest
# scores (for 1-D targets).
p_idx = np.argsort(targets)[-nb_top_targets:]
# Restart the timer here so "finished preprocessing" below measures this
# section's loading/preprocessing, not time since an earlier start_time.
start_time = time.time()
if len(sys.argv) > 3:
    # A third command-line argument (sys.argv[3]) names a single image to use.
    imdb = [il.loadAndPreProcessSingle(str(sys.argv[3]), scaleFactor, (windowSize, windowSize))]
else:
    imdb = il.loadAndPreProcessIms('annotations_test_short.txt', scaleFactor, (windowSize, windowSize))

print("Image database: {0} images".format(len(imdb)))
# Convert the image database into network inputs X, labels Y and a third
# output W (meaning not visible here) using stride stepSize and windowSize.
[X, Y, W] = il.getCNNFormat(imdb, stepSize, windowSize)
print("finished preprocessing in {0}".format(time.time() - start_time))

print("=========================================")
#================================================
# 12 net
#================================================

print("\n\n============== 12Net ====================")
model12 = model_architecture.setUp12net(windowSize)
print("Loading model from: " + model12FileName)
model12.load_weights(model12FileName)

model48 = model_architecture.setUp48net(windowSize48)
print("Loading model from: " + model48FileName)
model48.load_weights(model48FileName)

# Get best predictions
start_time = time.time()

predictions_12 = model12.predict(X, batch_size=16, verbose=1)
# Get top 10%.
# Assumes one score per window, i.e. predictions of shape (N, 1) -- confirm.
# np.squeeze turns a single-window (1, 1) result into a 0-d array with no
# shape[0]; atleast_1d keeps that case working and is a no-op otherwise.
targets = np.atleast_1d(np.squeeze(predictions_12))
nb_top_targets = int(math.ceil(targets.shape[0] * 0.1))
# argsort is ascending, so the last nb_top_targets indices hold the highest
# scores (for 1-D targets).
p_idx = np.argsort(targets)[-nb_top_targets:]