예제 #1
0
def train():
    """Train a 3D U-Net (``unet_model_3d``) on the train/validation data.

    Loads the training and validation arrays, builds and compiles the model
    from the module-level ``config`` dict, fits it with CSV logging,
    best-``val_loss`` checkpointing, learning-rate reduction on plateau and
    early stopping, then saves the model as it is after the last epoch.

    Files written under ``outputs/`` (the directory is created if missing):
      * ``<image_type>_model_train.csv`` - per-epoch metrics; rows are
        appended across runs (``append=True``).
      * ``weights.h5`` - overwritten each time ``val_loss`` improves.
      * ``<image_type>_model_last`` - the final model.

    Returns:
        The object returned by ``model.fit`` (the training history).
    """
    import os  # local import: only used to make sure outputs/ exists

    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_gtruth_train = load_train_data()

    print('-'*30)
    print('Loading and preprocessing validation data...')
    print('-'*30)
    imgs_val, imgs_gtruth_val = load_validatation_data()

    # NOTE(review): train_and_predict() permutes the same loader output with
    # np.transpose(..., (0, 4, 1, 2, 3)) before feeding a model built from the
    # same config["input_shape"]; this function feeds the arrays unchanged.
    # One of the two is probably using the wrong axis order - confirm which
    # layout unet_model_3d expects.

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = unet_model_3d(input_shape=config["input_shape"],
                          depth=config["depth"],
                          pool_size=config["pool_size"],
                          n_labels=config["n_labels"],
                          initial_learning_rate=config["initial_learning_rate"],
                          deconvolution=config["deconvolution"])
    model.summary()

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
    print('training starting..')

    # CSVLogger / ModelCheckpoint open their files only once fit() starts;
    # create the directory up front so a missing outputs/ does not abort the
    # run after the data has been loaded and the model built.
    if not os.path.isdir('outputs'):
        os.makedirs('outputs')

    log_filename = 'outputs/' + image_type + '_model_train.csv'
    checkpoint_filepath = 'outputs/' + 'weights.h5'
    callbacks_list = [
        # append=True: repeated runs accumulate rows in the same CSV file.
        callbacks.CSVLogger(log_filename, separator=',', append=True),
        # save_best_only: the file is only rewritten when val_loss improves.
        callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss',
                                  verbose=1, save_best_only=True, mode='min'),
        ReduceLROnPlateau(factor=config["learning_rate_drop"],
                          patience=config["patience"], verbose=True),
        EarlyStopping(verbose=True, patience=config["early_stop"]),
    ]

    # `epochs` replaces the Keras 1 keyword `nb_epoch`, which Keras 2 only
    # accepts with a deprecation warning and tf.keras rejects outright.
    hist = model.fit(imgs_train, imgs_gtruth_train,
                     batch_size=config["batch_size"],
                     epochs=config["n_epochs"],
                     verbose=1,
                     validation_data=(imgs_val, imgs_gtruth_val),
                     shuffle=True,
                     callbacks=callbacks_list)

    # Saved without a file extension. NOTE(review): the on-disk format depends
    # on the Keras version in use (standalone Keras 2 writes HDF5; TF2
    # tf.keras writes a SavedModel directory) - confirm against the loader.
    model_name = 'outputs/' + image_type + '_model_last'
    model.save(model_name)
    return hist


def train_and_predict():
    """Train an ``isensee2017_model`` with a Dice loss on train/validation data.

    Despite its name this function only trains and saves the model; it does
    not run any prediction.

    The loaded image and ground-truth arrays are permuted with
    ``np.transpose(..., (0, 4, 1, 2, 3))``, moving the last axis to position 1
    (presumably channels-last -> channels-first; confirm against
    ``load_train_data``). The model is built from the module-level ``config``
    dict and fitted with CSV logging, best-``val_loss`` checkpointing,
    learning-rate reduction on plateau and early stopping.

    Files written under ``outputs/`` (the directory is created if missing):
      * ``<image_type>_model_train.csv`` - per-epoch metrics; rows are
        appended across runs (``append=True``).
      * ``weights.h5`` - overwritten each time ``val_loss`` improves.
      * ``<image_type>_model_last`` - the final model.

    Returns:
        The object returned by ``model.fit`` (the training history).
    """
    import os  # local import: only used to make sure outputs/ exists

    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_gtruth_train = load_train_data()
    # Move the last axis to position 1 to match config["input_shape"].
    imgs_train = np.transpose(imgs_train, (0, 4, 1, 2, 3))
    imgs_gtruth_train = np.transpose(imgs_gtruth_train, (0, 4, 1, 2, 3))

    print('-'*30)
    print('Loading and preprocessing validation data...')
    print('-'*30)
    imgs_val, imgs_gtruth_val = load_validatation_data()
    imgs_val = np.transpose(imgs_val, (0, 4, 1, 2, 3))
    imgs_gtruth_val = np.transpose(imgs_gtruth_val, (0, 4, 1, 2, 3))

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = isensee2017_model(input_shape=config["input_shape"],
                              n_labels=config["n_labels"],
                              initial_learning_rate=config["initial_learning_rate"],
                              n_base_filters=config["n_base_filters"],
                              loss_function=dice_coef_loss)
    model.summary()

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
    print('training starting..')

    # CSVLogger / ModelCheckpoint open their files only once fit() starts;
    # create the directory up front so a missing outputs/ does not abort the
    # run after the data has been loaded and the model built.
    if not os.path.isdir('outputs'):
        os.makedirs('outputs')

    log_filename = 'outputs/' + image_type + '_model_train.csv'
    checkpoint_filepath = 'outputs/' + 'weights.h5'
    callbacks_list = [
        # append=True: repeated runs accumulate rows in the same CSV file.
        callbacks.CSVLogger(log_filename, separator=',', append=True),
        # save_best_only: the file is only rewritten when val_loss improves.
        callbacks.ModelCheckpoint(checkpoint_filepath, monitor='val_loss',
                                  verbose=1, save_best_only=True, mode='min'),
        ReduceLROnPlateau(factor=config["learning_rate_drop"],
                          patience=config["patience"], verbose=True),
        EarlyStopping(verbose=True, patience=config["early_stop"]),
    ]

    # `epochs` replaces the Keras 1 keyword `nb_epoch`, which Keras 2 only
    # accepts with a deprecation warning and tf.keras rejects outright.
    hist = model.fit(imgs_train, imgs_gtruth_train,
                     batch_size=config["batch_size"],
                     epochs=config["n_epochs"],
                     verbose=1,
                     validation_data=(imgs_val, imgs_gtruth_val),
                     shuffle=True,
                     callbacks=callbacks_list)

    # Saved without a file extension. NOTE(review): the on-disk format depends
    # on the Keras version in use (standalone Keras 2 writes HDF5; TF2
    # tf.keras writes a SavedModel directory) - confirm against the loader.
    model_name = 'outputs/' + image_type + '_model_last'
    model.save(model_name)
    return hist