def train(batchNum=500, batchSize=200000, learningRate=0.001, ImagePatchWidth=20,
          ImagePatchStep=4, labelOptionNum=100, labelMode='NUM'):
    """Fit a conv_model-based skflow estimator on the 'train' split and return it.

    Args:
        batchNum: number of training steps handed to the estimator (``steps``).
        batchSize: mini-batch size handed to the estimator.
        learningRate: optimizer learning rate handed to the estimator.
        ImagePatchWidth: patch width forwarded to ds.read_data_sets.
        ImagePatchStep: patch stride forwarded to ds.read_data_sets.
        labelOptionNum: number of label options; also used as ``n_classes``.
        labelMode: label mode string forwarded to ds.read_data_sets.

    Returns:
        The fitted skflow.TensorFlowEstimator.
    """
    training_set = ds.read_data_sets(
        ImagePatchWidth, ImagePatchStep, labelOptionNum, 'train', labelMode)

    estimator = skflow.TensorFlowEstimator(
        model_fn=conv_model,
        n_classes=labelOptionNum,
        batch_size=batchSize,
        steps=batchNum,
        learning_rate=learningRate,
    )

    # fit() is given one integer class id per row: the column index of the
    # largest value in each label row.
    class_ids = np.argmax(training_set.labels, axis=1)
    log_dir = gv.__DIR__ + gv.tensorflow_log_dir
    estimator.fit(training_set.images, class_ids, logdir=log_dir)
    return estimator
def _ratio(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when denominator is zero.

    Dividing 0 by 0 with np.true_divide also yields NaN, but emits a
    RuntimeWarning; images that lack a label class would trigger it every time.
    """
    if not denominator:
        return np.nan
    return np.true_divide(numerator, denominator)


def _label_accuracy(predicted, expected):
    """Return the four per-image metrics stored in one row of test()'s accuracy array.

    Args:
        predicted: class ids returned by classifier.predict. Assumed to have the
            same shape as ``expected`` -- TODO confirm.
        expected: ground-truth class ids (argmax of the dataset's label rows).

    Returns:
        A 4-tuple:
        (total accuracy,
         accuracy over entries whose true label is 0 ("negative"),
         accuracy over entries whose true label is > 0 ("positive"),
         mean absolute difference between predicted and true label).
        Any ratio whose denominator is zero is NaN.
    """
    predicted = np.asarray(predicted)
    expected = np.asarray(expected)
    correct = predicted == expected
    is_negative = expected == 0
    is_positive = expected > 0
    return (
        _ratio(np.sum(correct), expected.size),
        _ratio(np.sum(correct & is_negative), np.sum(is_negative)),
        _ratio(np.sum(correct & is_positive), np.sum(is_positive)),
        _ratio(np.sum(np.absolute(np.subtract(predicted, expected))), expected.size),
    )


def test(classifier, ImagePatchWidth=20, ImagePatchStep=4, labelOptionNum=100, labelMode='PRO'):
    """Label every image listed by getinfo() with ``classifier`` and score it.

    For each image file:
      * build its 'test' dataset with ds.read_data_sets,
      * predict with ``classifier``,
      * save the predictions, reshaped to (testDS.ylength, testDS.xlength),
        to gv.__DIR__ + gv.cnn__image_dir + image_file,
      * record four accuracy metrics (see _label_accuracy).
    The accuracy matrix is then written to gv.cnn__accuracy_filename.

    Args:
        classifier: fitted estimator exposing ``predict`` (e.g. from train()).
        ImagePatchWidth: patch width forwarded to ds.read_data_sets.
        ImagePatchStep: patch stride forwarded to ds.read_data_sets.
        labelOptionNum: number of label options forwarded to ds.read_data_sets.
        labelMode: label mode string forwarded to ds.read_data_sets.

    Returns:
        [result, accuracy] where ``result`` is a (num_images, 1) array holding
        the sum of the predictions for each image, and ``accuracy`` is a
        (num_images, 4) array with one row of _label_accuracy() metrics per image.
    """
    # getinfo() yields three values; only the image file list is needed here.
    image_files, _bubble_num, _bubble_regions = getinfo()
    accuracy_filename = gv.cnn__accuracy_filename

    result = np.zeros((len(image_files), 1))
    accuracy = np.zeros((len(image_files), 4))

    progress_bar = progress.progress(0, len(image_files), prefix_info='Labeling ')
    for i, image_file in enumerate(image_files):
        testDS = ds.read_data_sets(ImagePatchWidth, ImagePatchStep, labelOptionNum,
                                   'test', labelMode, imageName=image_file)
        y = classifier.predict(testDS.images)
        result[i] = np.sum(y)

        # Save the labeled result as an image.
        io.imsave(gv.__DIR__ + gv.cnn__image_dir + image_file,
                  np.reshape(y, (testDS.ylength, testDS.xlength)))

        # Ground truth as one class id per row of testDS.labels.
        expected = np.argmax(testDS.labels, axis=1)
        accuracy[i] = _label_accuracy(y, expected)

        progress_bar.setCurrentIteration(i + 1)
        progress_bar.setInfo(suffix_info=image_file)
        progress_bar.printProgress()

    # NOTE(review): tofile(sep=" ") flattens the (N, 4) matrix into a single run
    # of space-separated values, so readers must reshape it themselves. Unlike
    # the image output, this path is not prefixed with gv.__DIR__ -- confirm.
    # NOTE(review): ``result`` is only returned, never written to disk; if it
    # should be, gv.cnn__result_filename looks like the intended target -- confirm.
    accuracy.tofile(accuracy_filename, sep=" ")
    return [result, accuracy]