def _refresh_centroid_features(ds, pred, region_centroids, picker, scaler,
                               n_regions):
    """Rebuild the centroid-distance columns of ``ds.inputs`` from ``pred``.

    Updates ``region_centroids`` from the voxels ``ds.vx`` labelled by
    ``pred``, writes the first element returned by ``picker.pick`` into the
    last ``n_regions`` columns of ``ds.inputs``, then calls
    ``scaler.scale`` on the whole input matrix.
    """
    region_centroids.update_barycentres(ds.vx, pred)
    # pick() returns a sequence; only its first element is used.
    distances = picker.pick(ds.vx, None, None, region_centroids)[0]
    ds.inputs[:, -n_regions:] = distances
    # The return value is discarded, so scale() presumably works in place.
    # NOTE(review): it is applied to *all* columns on every call; if it
    # normalises every column, the non-centroid columns get re-scaled each
    # time -- confirm scale() is idempotent or only touches new columns.
    scaler.scale(ds.inputs)


def _mean_dice(pred, true_labels, n_regions):
    """Return the mean of ``compute_dice(pred, true_labels, n_regions)``."""
    return np.mean(compute_dice(pred, true_labels, n_regions))


def compute_centroids_estimate(ds, net_wo_centroids, net_centroids, scaler,
                               n_iter=3, n_regions=134):
    """Estimate centroid-distance features and refine them iteratively.

    An initial labelling of ``ds`` is predicted with ``net_wo_centroids``.
    The region centroids of that labelling are turned into distance
    features that overwrite the last ``n_regions`` columns of ``ds.inputs``.
    ``net_centroids`` is then run ``n_iter`` times. Each pass re-predicts
    from the refreshed inputs and recomputes the features from the new
    labelling.

    The mean Dice score of every prediction is printed and collected. This
    includes the initial prediction and the result of the last refinement
    pass.

    Args:
        ds: dataset exposing ``inputs``, ``outputs`` and ``vx``.
            ``ds.inputs`` is modified in place. Reference labels are taken
            as ``argmax(ds.outputs, axis=1)``.
        net_wo_centroids: network used for the initial prediction.
        net_centroids: network used for every refinement pass.
        scaler: object whose ``scale`` method is applied to ``ds.inputs``
            after each feature update.
        n_iter: number of refinement passes with ``net_centroids``.
        n_regions: number of regions, i.e. of centroid-distance columns at
            the end of ``ds.inputs`` (default 134, the previously
            hard-coded value).

    Returns:
        List of ``n_iter + 1`` mean Dice scores, one per prediction.
    """
    # Second positional argument of predict(); presumably a batch size --
    # TODO confirm against the network class.
    batch = 1000
    # Loop-invariant: the reference labelling never changes.
    true_labels = np.argmax(ds.outputs, axis=1)
    region_centroids = RegionCentroids(n_regions)
    picker = PickCentroidDistances(n_regions)

    pred = np.argmax(net_wo_centroids.predict(ds.inputs, batch), axis=1)
    _refresh_centroid_features(ds, pred, region_centroids, picker, scaler,
                               n_regions)
    mean_dices = [_mean_dice(pred, true_labels, n_regions)]
    print(mean_dices[-1])

    for _ in range(n_iter):
        pred = np.argmax(net_centroids.predict(ds.inputs, batch), axis=1)
        _refresh_centroid_features(ds, pred, region_centroids, picker,
                                   scaler, n_regions)
        # Score right after predicting so the last refinement is evaluated.
        mean_dices.append(_mean_dice(pred, true_labels, n_regions))
        print(mean_dices[-1])

    return mean_dices
# Create the data generator data_gen = DataGeneratorBrain() data_gen.init_from(file_list, pick_vx, pick_patch, pick_tg) # Evaluate the centroids net_wo_centroids_path = "./experiments/report_3_patches_balanced_conv/" net_wo_centroids = NetworkThreePatchesConv() net_wo_centroids.init(29, 135) net_wo_centroids.load_parameters( open_h5file(net_wo_centroids_path + "net.net")) ds_testing = DatasetBrainParcellation() ds_testing.read(data_path + "train.h5") pred_wo_centroids = np.argmax(net_wo_centroids.predict( ds_testing.inputs, 1000), axis=1) region_centroids = RegionCentroids(134) region_centroids.update_barycentres(ds_testing.vx, pred_wo_centroids) # Generate and evaluate the dataset start_time = time.clock() dices = np.zeros((n_files, 134)) errors = np.zeros((n_files, )) pred_functions = {} for atlas_id in xrange(n_files): print "Atlas: {}".format(atlas_id) ls_vx = [] ls_pred = [] brain_batches = data_gen.generate_single_atlas(atlas_id, None,