Ejemplo n.º 1
0
def train_leave_one_out(tempStore, modelPath, testOutputDir, Reference):
    """Leave-one-out training: fit one model per held-out subject and predict it.

    The loaded training stack is treated as ``len(imgs_id_train)`` consecutive
    groups of ``sliceNum`` slices (assumption taken from the slicing below --
    TODO confirm every subject contributes exactly ``sliceNum`` slices).
    For each group ``i``:

    1. a fresh model from ``get_unet_short()`` is fitted on all *other* groups,
       checkpointing the weights with the lowest ``val_loss`` and stopping
       early after 40 epochs without improvement;
    2. the loss / val_loss histories are saved to ``tempStore``;
    3. the best checkpoint is reloaded, group ``i`` is predicted, and the
       prediction is written out through ``VolumeDataTofiles``.

    Finally the list of paths returned by ``VolumeDataTofiles`` is written to
    ``<testOutputDir>/FileList.txt`` via ``WriteListtoFile``.

    Args:
        tempStore: Directory passed to ``load_train_data`` and where the
            per-subject ``<name>_loss.npy`` / ``<name>_val_loss.npy`` files
            are saved.
        modelPath: Directory for the per-subject ``<name>_weights.h5``
            checkpoints.
        testOutputDir: Output directory forwarded to ``VolumeDataTofiles``
            and location of ``FileList.txt``.
        Reference: Forwarded unchanged to ``VolumeDataTofiles``.
    """
    print('-' * 30)
    print('Loading all the data...')
    print('-' * 30)
    imgs_train, imgs_label_train, addInformation_train, imgs_id_train = \
        load_train_data(tempStore)

    imgs_train = preprocess(imgs_train).astype('float32')
    imgs_label_train = preprocess(imgs_label_train).astype(np.uint32)
    addInformation_train = preprocess(addInformation_train).astype(np.float32)

    if IfglobalNorm:
        # NOTE(review): mean/std are computed over *all* subjects, including the
        # one that is held out in each fold below, so the test subject leaks
        # into the normalisation statistics. Kept as-is to preserve results.
        # The statistics are not persisted anywhere by this function.
        mean = np.mean(imgs_train)  # mean for data centering
        std = np.std(imgs_train)  # std for data normalization
        imgs_train -= mean
        imgs_train /= std

    preImageList = []
    for i, imgId in enumerate(imgs_id_train):
        # "a_b_c_d_e.ext" -> "b_c_d": keep the 4th-last .. 2nd-last '_' fields.
        outBaseName = "_".join(os.path.basename(imgId).split("_")[-4:-1])

        # Rows [start, stop) belong to the held-out subject of this fold.
        start, stop = i * sliceNum, (i + 1) * sliceNum
        heldOut = range(start, stop)
        currentTrainImgs = np.delete(imgs_train, heldOut, axis=0)
        currentTrainLab = np.delete(imgs_label_train, heldOut, axis=0)
        currentTrainAdd = np.delete(addInformation_train, heldOut, axis=0)

        # A brand-new model per fold so no weights carry over between folds.
        # NOTE(review): the backend session is never cleared between folds;
        # if memory grows across iterations consider clearing it here.
        model = get_unet_short()
        weightName = os.path.join(modelPath, outBaseName + '_weights.h5')
        model_checkpoint = ModelCheckpoint(weightName, monitor='val_loss',
                                           save_best_only=True)
        early_stop = EarlyStopping(monitor='val_loss', min_delta=0,
                                   patience=40, verbose=0, mode='auto')

        # batch_size / epochs are passed positionally, as in the original call.
        train_history = model.fit([currentTrainImgs, currentTrainAdd],
                                  currentTrainLab, batch_size, epochs,
                                  verbose=1, shuffle=True,
                                  validation_split=0.2,
                                  callbacks=[model_checkpoint, early_stop])

        loss = train_history.history['loss']
        val_loss = train_history.history['val_loss']
        np.save(os.path.join(tempStore, outBaseName + '_loss.npy'), loss)
        np.save(os.path.join(tempStore, outBaseName + '_val_loss.npy'),
                val_loss)

        # Predict the held-out subject with the best (lowest val_loss) weights.
        currentTestImgs = imgs_train[start:stop, :, :, :]
        currentTestAdd = addInformation_train[start:stop, :, :, :]
        model.load_weights(weightName)
        imgs_label_test = model.predict([currentTestImgs, currentTestAdd],
                                        verbose=1)
        ThreeDImagePath = VolumeDataTofiles(imgs_label_test, outBaseName,
                                            testOutputDir, Reference)
        preImageList.append(ThreeDImagePath)

        print('-' * 30)
        print(str(i) + 'th is finished...')
        print('-' * 30)

    WriteListtoFile(preImageList, os.path.join(testOutputDir, 'FileList.txt'))
Ejemplo n.º 2
0
def train_and_predict(tempStore, modelPath):
    """Train a model on the stored training set, then predict the test set.

    Steps:

    1. Load and preprocess the training data from ``tempStore``. If
       ``IfglobalNorm`` is set, normalise the images with the global mean/std
       and save those statistics as ``imgs_mean.npy`` / ``imgs_std.npy`` in
       ``tempStore``.
    2. Fit ``get_unet_short()``, checkpointing the lowest-``val_loss`` weights
       to ``<modelPath>/weights.h5`` and stopping early after 40 epochs without
       improvement. Save the loss histories as ``loss.npy`` / ``val_loss.npy``
       in ``tempStore``.
    3. Load and preprocess the test data and apply the *training* statistics
       when ``IfglobalNorm`` is set. Reload the best checkpoint, predict, and
       save the result as ``<tempStore>/imgs_label_test.npy``.

    Args:
        tempStore: Directory holding the input data and receiving all ``.npy``
            outputs.
        modelPath: Directory for the ``weights.h5`` checkpoint.
    """
    weightPath = os.path.join(modelPath, 'weights.h5')

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)
    imgs_train, imgs_label_train, _ = load_train_data(tempStore)
    imgs_train = preprocess(imgs_train).astype('float32')
    imgs_label_train = preprocess(imgs_label_train).astype(np.uint32)

    if IfglobalNorm:
        mean = np.mean(imgs_train)  # mean for data centering
        std = np.std(imgs_train)  # std for data normalization
        imgs_train -= mean
        imgs_train /= std
        # Save the training statistics right away so they are available even
        # if the (long) training or prediction below is interrupted.
        np.save(os.path.join(tempStore, 'imgs_mean.npy'), mean)
        np.save(os.path.join(tempStore, 'imgs_std.npy'), std)

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)
    model = get_unet_short()

    model_checkpoint = ModelCheckpoint(weightPath,
                                       monitor='val_loss',
                                       save_best_only=True)
    early_stop = EarlyStopping(monitor='val_loss',
                               min_delta=0,
                               patience=40,
                               verbose=0,
                               mode='auto')

    print('-' * 30)
    print('Fitting model...')
    print('-' * 30)
    # batch_size / epochs are passed positionally, as in the original call.
    train_history = model.fit([imgs_train], imgs_label_train, batch_size,
                              epochs, verbose=1, shuffle=True,
                              validation_split=0.2,
                              callbacks=[model_checkpoint, early_stop])

    loss = train_history.history['loss']
    val_loss = train_history.history['val_loss']
    np.save(os.path.join(tempStore, 'loss.npy'), loss)
    np.save(os.path.join(tempStore, 'val_loss.npy'), val_loss)

    print('-' * 30)
    print('Loading and preprocessing test data...')
    print('-' * 30)
    imgs_test, _ = load_test_data(tempStore)
    imgs_test = preprocess(imgs_test).astype('float32')

    if IfglobalNorm:
        # Normalise the test data with the *training* statistics.
        imgs_test -= mean
        imgs_test /= std

    print('-' * 30)
    print('Loading saved weights...')
    print('-' * 30)
    # Best (lowest val_loss) weights written by the checkpoint callback.
    model.load_weights(weightPath)

    print('-' * 30)
    print('Predicting masks on test data...')
    print('-' * 30)
    imgs_label_test = model.predict([imgs_test], verbose=1)
    np.save(os.path.join(tempStore, 'imgs_label_test.npy'), imgs_label_test)