def compute_centroids_estimate(ds, net_wo_centroids, net_centroids, scaler, n_iter=3):
    """Iteratively refine the centroid-distance features of a dataset.

    A first segmentation from the network without centroid features is used to
    estimate the region barycentres; the centroid-aware network is then applied
    n_iter times, the distance features being recomputed after each pass.
    """
    # Initial prediction from the network that does not use centroid features
    pred = np.argmax(net_wo_centroids.predict(ds.inputs, 1000), axis=1)
    r = RegionCentroids(134)
    r.update_barycentres(ds.vx, pred)
    # Replace the last 134 input features by the distances to the estimated centroids
    p = PickCentroidDistances(134)
    distances = p.pick(ds.vx, None, None, r)[0]
    ds.inputs[:, -134:] = distances
    scaler.scale(ds.inputs)

    # Refine the prediction with the centroid-aware network
    for i in range(n_iter):
        # Dice overlap of the current prediction against the ground truth
        d = compute_dice(pred, np.argmax(ds.outputs, axis=1), 134)
        print np.mean(d)

        pred = np.argmax(net_centroids.predict(ds.inputs, 1000), axis=1)
        r.update_barycentres(ds.vx, pred)
        distances = p.pick(ds.vx, None, None, r)[0]
        ds.inputs[:, -134:] = distances
        scaler.scale(ds.inputs)
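
A hypothetical calling sketch for the function above (not from the original source): it assumes the project classes used in Example #2 (DatasetBrainParcellation, NetworkThreePatchesConv, open_h5file); the file paths are placeholders, and scaler stands for whatever feature scaler was fitted on the training data, which is not shown on this page.

# Hypothetical usage sketch; the identifiers below are assumptions, not the original script
ds = DatasetBrainParcellation()
ds.read("./datasets/test.h5")  # placeholder path

net_wo_centroids = NetworkThreePatchesConv()  # trained without centroid features
net_wo_centroids.init(29, 135)
net_wo_centroids.load_parameters(open_h5file("./experiments/no_centroids/net.net"))  # placeholder path

net_centroids = NetworkThreePatchesConv()  # trained with centroid-distance features
net_centroids.init(29, 135)
net_centroids.load_parameters(open_h5file("./experiments/centroids/net.net"))  # placeholder path

scaler = ...  # pre-fitted feature scaler (class not shown in these examples)

compute_centroids_estimate(ds, net_wo_centroids, net_centroids, scaler, n_iter=3)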
Example #2
    # Create the data generator
    data_gen = DataGeneratorBrain()
    data_gen.init_from(file_list, pick_vx, pick_patch, pick_tg)

    # Estimate the region centroids using the network trained without centroid features
    net_wo_centroids_path = "./experiments/report_3_patches_balanced_conv/"
    net_wo_centroids = NetworkThreePatchesConv()
    net_wo_centroids.init(29, 135)
    net_wo_centroids.load_parameters(
        open_h5file(net_wo_centroids_path + "net.net"))
    ds_testing = DatasetBrainParcellation()
    ds_testing.read(data_path + "train.h5")
    # Initial prediction without centroid-distance features
    pred_wo_centroids = np.argmax(
        net_wo_centroids.predict(ds_testing.inputs, 1000), axis=1)
    # Barycentre of each of the 134 regions, estimated from these predictions
    region_centroids = RegionCentroids(134)
    region_centroids.update_barycentres(ds_testing.vx, pred_wo_centroids)

    # Generate and evaluate the dataset
    start_time = time.clock()
    dices = np.zeros((n_files, 134))
    errors = np.zeros((n_files, ))

    pred_functions = {}
    for atlas_id in xrange(n_files):

        print "Atlas: {}".format(atlas_id)

        ls_vx = []
        ls_pred = []
        brain_batches = data_gen.generate_single_atlas(atlas_id, None,