Example #1
def main(args):
    # Build mask parameters DataFrame
    df_params = pd.DataFrame({"mask_name" : args.mask_name,\
                              "slice_axis" : args.slice_axis,\
                              "n_patches" : args.n_patches,\
                              "overlap"  : args.overlap, \
                              "rotation" : args.rotation})
    #     print(df_params)

    mpl.use(args.mpl_agg)

    data_io.show_header()
    if not os.path.exists(args.seg_path): os.makedirs(args.seg_path)
    if args.run_seg:

        # Understand input data format
        if os.path.isdir(args.ct_fpath):
            tiff_input = True
        elif args.ct_fpath.split('.')[-1] in ("hdf5", "h5"):
            tiff_input = False
            if args.ct_data_tag == "":
                raise ArgumentTypeError("dataset-name required for hdf5")
        else:
            raise ArgumentTypeError(
                "input file type not recognized. must be tiff folder or hdf5 file"
            )

        ct_dfile = data_io.DataFile(args.ct_fpath, \
                                    tiff = tiff_input,\
                                    data_tag = args.ct_data_tag, \
                                    VERBOSITY = args.rw_verbosity)
        ct_dfile.show_stats()
        chunk_shape = ct_dfile.chunk_shape
        if args.stats_only:
            print("\nSet stats_only = False and start over to run program.")
            sys.exit()

        # Load model from model repo
        model_filename = os.path.join(args.model_path,
                                      args.model_name + '.hdf5')
        print("\nStarting segmentation mode ...")
        segmenter = Segmenter(model_filename=model_filename)

        print("Reading CT volume into memory...")
        dd = ct_dfile.read_full()

        if args.preprocess:
            print("\tPreprocessing volume...")
            if not os.path.exists("preprocessor.py"):
                input(
                    "Looked for preprocessor.py, but not found! Please create one and press enter. Or press CTRL+C to exit"
                )

            from preprocessor import preprocessor
            dd = preprocessor(dd)

        for idx, row in df_params.iterrows():  # iterate over masks

            # assign arguments from df_params for this mask
            slice_axis = row["slice_axis"]
            max_patches = row["n_patches"]
            segfile_tag = row["mask_name"]
            overlap = row["overlap"]
            rotation = row["rotation"]

            # define DataFile object for mask
            seg_fname = os.path.join(args.seg_path, segfile_tag)
            if not args.tiff_output: seg_fname = seg_fname + ".hdf5"
            seg_dfile = data_io.DataFile(seg_fname, \
                                         data_tag = "SEG",\
                                         tiff = args.tiff_output, \
                                         d_shape = ct_dfile.d_shape, \
                                         d_type = np.uint8, \
                                         chunk_shape = chunk_shape,\
                                         VERBOSITY = args.rw_verbosity)
            seg_dfile.create_new(overwrite=args.overwrite_OK)

            t0 = time.time()

            print("\nWorking on %s\n" % segfile_tag)
            ch = process_data(dd, segmenter, \
                              slice_axis = slice_axis, \
                              rot_angle = rotation, \
                              max_patches = max_patches, \
                              overlap = overlap, \
                              nprocs = args.nprocs, \
                              arr_split = args.arr_split,\
                              arr_split_infer = args.arr_split_infer,\
                              crops = args.crops)

            seg_dfile.write_full(ch)
            t1 = time.time()
            total_time = (t1 - t0) / 60.0
            print(
                "\nDONE on %s\nTotal time for generating %s mask: %.2f minutes"
                % (time.ctime(), segfile_tag, total_time))
            del slice_axis
            del max_patches
            del segfile_tag
            del rotation
            del ch

    if args.run_ensemble:
        print("\nStarting ensemble mode ...\n")

        t0 = time.time()
        # get the d_shape of one of the masks
        temp_fname = os.path.join(args.seg_path, df_params.loc[0, "mask_name"])
        if not args.tiff_output: temp_fname = temp_fname + ".hdf5"
        temp_ds = data_io.DataFile(temp_fname,
                                   data_tag="SEG",
                                   tiff=args.tiff_output,
                                   VERBOSITY=0)
        mask_shape = temp_ds.d_shape
        chunk_shape = temp_ds.chunk_shape
        if not args.run_seg: temp_ds.show_stats()
        del temp_ds
        del temp_fname

        if args.stats_only:
            print("\nSet stats_only = False and start over to run program.")
            sys.exit()

        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output: vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname, \
                                      tiff = args.tiff_output,\
                                      data_tag = "SEG",\
                                      d_shape = mask_shape, \
                                      d_type = np.uint8, \
                                      chunk_shape = chunk_shape,\
                                      VERBOSITY = args.rw_verbosity)
        vote_dfile.create_new(overwrite=args.overwrite_OK)

        slice_start = 0
        n_masks = len(df_params)
        pbar = tqdm(total=mask_shape[0])
        while slice_start < mask_shape[0]:
            ch = [0] * len(df_params)
            for idx, row in df_params.iterrows():
                seg_fname = os.path.join(args.seg_path, row["mask_name"])
                if not args.tiff_output: seg_fname = seg_fname + ".hdf5"


                seg_dfile = data_io.DataFile(seg_fname, \
                                             tiff = args.tiff_output, \
                                             data_tag = "SEG", \
                                             VERBOSITY = args.rw_verbosity)
                if mask_shape != seg_dfile.d_shape:
                    raise ValueError("Shape of all masks must be same")

                ch[idx], s = seg_dfile.read_chunk(axis = 0, \
                                                  slice_start = slice_start, \
                                                  max_GB = args.mem_thres/(n_masks))
            ch = np.asarray(ch)
            ch = np.median(ch, axis=0).astype(np.uint8)

            vote_dfile.write_chunk(ch, axis=0, s=s)
            del ch
            slice_start = s.stop
            pbar.update(s.stop - s.start)
        pbar.close()

        t1 = time.time()
        total_time = (t1 - t0) / 60.0
        print("\nDONE on %s\nTotal time for ensemble mask %s : %.2f minutes" %
              (time.ctime(), args.vote_maskname, total_time))

        if args.remove_masks:
            print("Intermediate masks will be removed.")
            for idx, row in df_params.iterrows():  # iterate over masks
                seg_fname = os.path.join(args.seg_path, row["mask_name"])
                if not args.tiff_output:
                    seg_fname = seg_fname + ".hdf5"
                    os.remove(seg_fname)
                else:
                    rmtree(seg_fname)

    if args.morpho_filt:
        print("\nApplying morphological operations on ensemble vote...")

        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output: vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname, \
                                      tiff = args.tiff_output,\
                                      data_tag = "SEG",\
                                      VERBOSITY = args.rw_verbosity)

        from ct_segnet.morpho import morpho_filter
        vol = vote_dfile.read_full()
        vol = morpho_filter(vol, radius = args.radius, \
                            ops = args.ops, \
                            crops = args.crops, \
                            invert_mask = args.invert_mask)
        vote_dfile.write_full(vol)

    return
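
A note on the ensemble logic in main() above: reducing the stack of candidate masks with a voxel-wise median is, for binary (0/1) masks and an odd number of them, a majority vote. A minimal sketch of that reduction on made-up numpy arrays:

import numpy as np

# Three hypothetical binary masks (e.g., produced with different slice axes).
# Shapes must match, as main() enforces before voting.
mask_a = np.array([[0, 1], [1, 1]], dtype=np.uint8)
mask_b = np.array([[0, 1], [0, 1]], dtype=np.uint8)
mask_c = np.array([[1, 1], [0, 0]], dtype=np.uint8)

# Stack along a new leading axis and take the voxel-wise median,
# exactly as the ensemble loop does chunk by chunk.
stack = np.asarray([mask_a, mask_b, mask_c])
vote = np.median(stack, axis=0).astype(np.uint8)

print(vote)  # [[0 1]
             #  [0 1]] -- each voxel keeps the label held by 2 of the 3 masks
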
Example #2
def main(args, args_summary):

    mpl.use('Agg')

    # Convert tuple-valued inputs from the argparser
    args.kern_size = _make_tuples(args.kern_size)
    args.kern_size_upconv = _make_tuples(args.kern_size_upconv)
    args.pool_size = _make_tuples(args.pool_size)

    data_io.show_header()

    ############### LOAD OR REBUILD fCNN MODEL ####################
    model_file = os.path.join(args.model_path, args.model_name + ".hdf5")
    model_history = os.path.join(args.model_path, 'history', args.model_name)
    if not os.path.exists(model_history): os.makedirs(model_history)

    if args.rebuild:
        if args.config_model is None:
            raise Exception("Model config file must be provided if rebuilding")

        img_shape = (args.model_size, args.model_size, 1)
        segmenter = build_Unet(img_shape, \
                               n_depth = args.n_depth,\
                               n_pools = args.n_pools,\
                               activation = args.activation,\
                               batch_norm = args.is_batch_norm,\
                               kern_size = args.kern_size,\
                               kern_size_upconv = args.kern_size_upconv,\
                               pool_size = args.pool_size,\
                               dropout_level = args.dropout_level,\
                               loss = args.loss_def,\
                               stdinput = args.stdinput)

        if args.initialize_from is not None:
            print("\nInitializing weights from file:\n%s" %
                  args.initialize_from)
            segmenter.load_weights(args.initialize_from)

        segmenter.save(model_file)
        df_prev = None

    elif os.path.exists(os.path.join(model_history, args.model_name + ".csv")):
        segmenter = load_model(model_file, custom_objects=custom_objects_dict)

        df_prev = pd.read_csv(
            os.path.join(model_history, args.model_name + ".csv"))
    else:
        raise ValueError(
            "Model history not available to retrain. Rebuild required.")

    ###### LOAD TRAINING AND TESTING DATA #############
    train_path = os.path.join(args.data_path, args.train_fname)
    X = data_io.DataFile(train_path, tiff=False, data_tag='X', VERBOSITY=0)
    Y = data_io.DataFile(train_path, tiff=False, data_tag='Y', VERBOSITY=0)

    # used to evaluate accuracy metrics during training
    test_path = os.path.join(args.data_path, args.test_fname)
    Xtest = data_io.DataFile(test_path, tiff=False, data_tag='X', VERBOSITY=0)
    Ytest = data_io.DataFile(test_path, tiff=False, data_tag='Y', VERBOSITY=0)

    print("\nTotal test data shape: " + str(Ytest.d_shape))
    print("Total training data shape: " + str(Y.d_shape))

    #     save_datasnaps(data_generator(X, Y, args.n_save), model_history)

    ##### DO SOME LOGGING #####################

    rw = "w+" if args.rebuild else "a+"
    logfile = os.path.join(model_history, "README_" + args.model_name + ".txt")
    with open(logfile, rw) as f:
        if args.rebuild:
            f.write("\n")
            with redirect_stdout(f):
                segmenter.summary()
        f.write("\nNew Entry\n")
        f.write("Training Started: " + time.ctime() + "\n")
        if not args.rebuild: args_summary = args_summary.split("Defaults")[0]
        f.write(args_summary)
        f.write("\nTotal train data size: %s" % (str(Y.d_shape)))
        f.write("\nTotal test data size: %s" % (str(Ytest.d_shape)))
        f.write("\n")

    ######### START TRAINING ##################
    model_paths = {'name' : args.model_name,\
                   'history' : model_history,\
                   'file' : model_file}

    logger = Logger(Xtest,
                    Ytest,
                    model_paths,
                    args.autosave_freq,
                    df_prev=df_prev,
                    n_test=args.n_test,
                    check_norm=args.check_norm,
                    augmenter_todo=None)
    n_train = min(args.nmax_train, X.d_shape[0])
    t0 = time.time()

    try:

        if args.fit_generator:
            steps_per_epoch = int(
                (1 - args.validation_split) * n_train // args.batch_size)
            validation_steps = int(args.validation_split * n_train //
                                   args.batch_size)

            # data generator will work in 'inplace' mode (faster) and generate data of size = batch_size
            dg = data_generator(X, Y, args.batch_size, \
                                check_norm = args.check_norm, \
                                augmenter_todo = args.augmenter_todo, \
                                min_SNR = args.min_SNR, inplace = True, nprocs = 1)

            hist = segmenter.fit_generator(dg,\
                                           steps_per_epoch = steps_per_epoch,\
                                           validation_data = data_generator(X, Y, args.batch_size),\
                                           validation_steps = validation_steps,\
                                           epochs = args.n_epochs,\
                                           verbose = 2, \
                                           callbacks = [logger])
        else:
            dg = data_generator(X, Y, n_train, \
                                check_norm = args.check_norm, \
                                augmenter_todo = args.augmenter_todo, \
                                min_SNR = args.min_SNR, inplace = False)
            x_train, y_train = next(dg)
            hist = segmenter.fit(x = x_train, y = y_train,\
                                 verbose = 2, initial_epoch = 0, validation_split = args.validation_split,\
                                 epochs = args.n_epochs, batch_size = args.batch_size, callbacks = [logger])

        print("\nModel training complete")
        with open(logfile, "a+") as f:
            f.write("\nTraining Completed: " + time.ctime() + "\n")
            hours = (time.time() - t0) / (60 * 60.0)
            f.write("\nTotal time: %.4f hours\n" % (hours))
            f.write("\nEnd of Entry")
            print("\n Total time: %.4f hours" % (hours))
        segmenter.save(model_file)

        ########### EVALUATE MODEL AND SAVE SOME TEST IMAGES ############
        print("\nSaving some test images...")
        model_results = os.path.join(args.model_path, 'history',
                                     args.model_name, 'testing')
        dg = data_generator(Xtest, Ytest, args.n_save, \
                            augmenter_todo = args.augmenter_todo, \
                            min_SNR = args.min_SNR, \
                            inplace = False, nprocs = 1)

        save_results(dg, model_results, segmenter)

    # Log keyboard interrupt exception
    except KeyboardInterrupt:
        print("\nModel training interrupted")
        with open(logfile, "a+") as f:
            f.write("\nTraining Interrupted after %i epochs: " % logger.i +
                    time.ctime() + "\n")
            f.write("\nEnd of Entry")

    return
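
A note on the fit_generator branch above: steps_per_epoch and validation_steps simply partition the n_train samples between training and validation batches. A worked example of that arithmetic, with hypothetical numbers:

# Hypothetical values, for illustration only.
n_train = 10000
batch_size = 32
validation_split = 0.2

steps_per_epoch = int((1 - validation_split) * n_train // batch_size)
validation_steps = int(validation_split * n_train // batch_size)
print(steps_per_epoch, validation_steps)  # 250 62
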
Example #3
def main(args):

    data_io.show_header()
    print("\nLet's make some training data.\n")

    # Understand input data format
    ct_istiff = data_io._istiff(args.ct_fpath, args.ct_data_tag)
    seg_istiff = data_io._istiff(args.seg_path, args.seg_data_tag)

    dfile_recon = data_io.DataFile(args.ct_fpath,
                                   data_tag=args.ct_data_tag,
                                   tiff=ct_istiff,
                                   VERBOSITY=0)
    dfile_seg = data_io.DataFile(args.seg_path,
                                 data_tag=args.seg_data_tag,
                                 tiff=seg_istiff,
                                 VERBOSITY=0)

    if args.output_fpath == "":
        args.output_fpath = os.path.split(args.seg_path)[0]
    args.output_fname = args.output_fname.split('.')[0] + ".hdf5"
    output_fname = os.path.join(args.output_fpath, args.output_fname)

    print("Grayscale CT data:\n%s\n" % dfile_recon.fname)
    dfile_recon.show_stats()
    print("Manually segmented mask:\n%s\n" % dfile_seg.fname)
    dfile_seg.show_stats()

    if np.prod(dfile_seg.d_shape) != np.prod(dfile_recon.d_shape):
        raise ValueError(
            "CT data and segmentation mask must have exactly the same shape / size")

    # Decide domain extents
    crop_shape = data_io.get_cropped_shape(args.crops, dfile_seg.d_shape)

    # Decide patching / slicing strategy
    df_params = pd.DataFrame({
        "slice axis": args.slice_axis,
        "max patches": args.n_patches
    })
    skip_fac = args.skip_fac
    # Estimate the shapes of hdf5 data files to be written
    n_images = 0
    for idx, row in df_params.iterrows():
        _len = np.ceil(crop_shape[row["slice axis"]] / skip_fac)
        _len = _len * np.prod(row["max patches"])
        print("Calculated length of set %i: %i" % (idx + 1, _len))
        n_images = _len + n_images
    write_dshape = (int(n_images), args.model_size, args.model_size)

    # Create DataFile objects for writing the training pairs (X: grayscale, Y: mask)
    w_dfile_seg = data_io.DataFile(output_fname,
                                   data_tag='Y',
                                   tiff=False,
                                   chunked_slice_size=None,
                                   d_shape=write_dshape,
                                   d_type=np.uint8)
    w_dfile_recon = data_io.DataFile(output_fname,
                                     data_tag='X',
                                     tiff=False,
                                     chunked_slice_size=None,
                                     d_shape=write_dshape,
                                     d_type=np.float32)
    w_dfile_seg.create_new(overwrite=False)
    w_dfile_recon.create_new(overwrite=False)

    # Work on seg data
    print("Working on the segmentation map...")
    d_seg = dfile_seg.read_full()
    d_seg = d_seg[slice(*args.crops[0]),
                  slice(*args.crops[1]),
                  slice(*args.crops[2])]
    gc.collect()
    slice_start = 0
    pbar = tqdm(total=n_images)
    for idx, row in df_params.iterrows():
        p = process_data(d_seg,
                         skip_fac=skip_fac,
                         nprocs=args.nprocs,
                         patch_size=args.model_size,
                         n_patches=row['max patches'],
                         axis=row['slice axis'])
        slice_end = p.shape[0] + slice_start
        s = slice(slice_start, slice_end)
        w_dfile_seg.write_chunk(p, axis=0, s=s)
        slice_start = slice_end
        del p
        pbar.update(s.stop - s.start)
        gc.collect()
    pbar.close()

    del d_seg
    gc.collect()

    # Work on recon data
    print("Working on the grayscale CT volume...")
    slice_start = 0
    d_recon = dfile_recon.read_full()
    d_recon = d_recon[slice(*args.crops[0]),
                      slice(*args.crops[1]),
                      slice(*args.crops[2])]
    pbar = tqdm(total=n_images)
    for idx, row in df_params.iterrows():
        p = process_data(d_recon,
                         skip_fac=skip_fac,
                         nprocs=args.nprocs,
                         patch_size=args.model_size,
                         n_patches=row['max patches'],
                         axis=row['slice axis'])
        slice_end = p.shape[0] + slice_start
        s = slice(slice_start, slice_end)
        w_dfile_recon.write_chunk(p, axis=0, s=s)
        slice_start = slice_end
        del p
        pbar.update(s.stop - s.start)
        gc.collect()
    pbar.close()
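
A note on the write-shape estimate in main() above: for each slicing strategy, the image count is the number of kept slices (ceil of the cropped axis length over skip_fac) times the patches extracted per slice. A minimal check of that arithmetic with hypothetical numbers:

import numpy as np

# Hypothetical cropped volume and patching plan, for illustration only.
crop_shape = (600, 512, 512)
skip_fac = 4
strategies = [{"slice axis": 0, "max patches": (2, 2)},
              {"slice axis": 1, "max patches": (1, 1)}]

n_images = 0
for row in strategies:
    n_slices = np.ceil(crop_shape[row["slice axis"]] / skip_fac)
    n_images += n_slices * np.prod(row["max patches"])

# ceil(600/4) * 4 + ceil(512/4) * 1 = 600 + 128 = 728 images
write_dshape = (int(n_images), 64, 64)
print(write_dshape)  # (728, 64, 64)
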
Example #4
def main(args):
    data_io.show_header()
    # Understand input data format
    if os.path.isdir(args.input_fname):
        tiff_input = True
        if args.dataset_name == "": args.dataset_name = "data"
    elif args.input_fname.split('.')[-1] in ("hdf5", "h5"):
        tiff_input = False
        if args.dataset_name == "": raise ArgumentTypeError("dataset-name required for hdf5")
    else:
        raise ArgumentTypeError("input file type not recognized. must be tiff folder or hdf5 file")
    input_fname = args.input_fname
    
    # set up output file name / path / chunks parameter
    if args.output_fpath == "":
        args.output_fpath = os.path.split(args.input_fname)[0]
    args.output_fname = args.output_fname.split('.')[0] + ".hdf5"
    output_fname = os.path.join(args.output_fpath, args.output_fname)
    if isinstance(args.chunk_param, (int, float)):
        chunk_size = args.chunk_param/1e3 # convert to GB
        chunk_shape = None
    elif isinstance(args.chunk_param, tuple):
        chunk_shape = args.chunk_param
        chunk_size = None
    chunked_slice_size = args.chunked_slice_size
    
#     print("Type chunk_param" + str(type(args.chunk_param)))
#     print("Type chunked_slice_size" + str(type(args.chunked_slice_size)))

#     print("Overwrite OK is %s"%args.overwrite_OK)
#     print("Stats only is %s"%args.stats_only)
#     print("Delete is %s"%args.delete)
#     sys.exit()
        
    # Define DataFile instances - quit here if stats_only requested
    r_dfile = data_io.DataFile(input_fname, tiff = tiff_input, \
                               data_tag = args.dataset_name, \
                               VERBOSITY = args.verbosity)

    print("Input data stats:")
    r_dfile.show_stats()
    if args.stats_only:
        sys.exit()
    
    w_shape = r_dfile.d_shape # future implementation must allow resampling dataset
    w_dtype = r_dfile.d_type # future implementation must allow changing dtype (with renormalization)
    
    w_dfile = data_io.DataFile(output_fname, tiff = False, \
                               data_tag = args.dataset_name, \
                               VERBOSITY = args.verbosity, \
                               d_shape = w_shape, d_type = w_dtype, \
                               chunk_shape = chunk_shape, \
                               chunk_size = chunk_size, \
                               chunked_slice_size = chunked_slice_size)
    
    print("\nChunking scheme estimated as: %s"%str(w_dfile.chunk_shape))
    input("\nHDF5 file will be saved to the following location.\n%s\nPress any key to continue."%output_fname)
    w_dfile.create_new(overwrite = args.overwrite_OK)
    
    t0 = time.time()
    slice_start = 0
    print("\n")
    pbar = tqdm(total = r_dfile.d_shape[0])
    while slice_start < r_dfile.d_shape[0]:
        dd, s = r_dfile.read_chunk(axis = 0, slice_start = slice_start, \
                                   max_GB = args.mem_thres, \
                                   chunk_shape = w_dfile.chunk_shape)
        w_dfile.write_chunk(dd, axis = 0, s = s)
        slice_start = s.stop
        pbar.update(s.stop - s.start)
    pbar.close()
    total_time = (time.time() - t0)/60.0 # minutes
    print("\nTotal time: %.2f minutes"%(total_time))
    
    if args.delete:
        input("Delete old file? Press any key")
        if tiff_input:
            rmtree(input_fname)
        else:
            os.remove(input_fname)
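
data_io.DataFile hides the chunked-copy mechanics used above; a rough plain-h5py analogue of the same read-a-slab / write-a-slab pattern (the paths, dataset name, and memory budget below are made up) might look like:

import h5py
import numpy as np

src_path, dst_path, dataset = "input.hdf5", "output.hdf5", "data"
max_gb = 1.0  # hypothetical memory budget per slab

with h5py.File(src_path, "r") as src, h5py.File(dst_path, "w") as dst:
    d = src[dataset]
    out = dst.create_dataset(dataset, shape=d.shape, dtype=d.dtype,
                             chunks=True)
    # Size slabs along axis 0 to stay under the memory budget.
    bytes_per_slice = np.prod(d.shape[1:]) * d.dtype.itemsize
    step = max(1, int(max_gb * 1e9 // bytes_per_slice))
    for start in range(0, d.shape[0], step):
        s = slice(start, min(start + step, d.shape[0]))
        out[s] = d[s]
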
Example #5
def main(args):
    df_params = pd.DataFrame({"mask_name" : args.mask_name,\
                              "slice_axis" : args.slice_axis,\
                              "n_patches" : args.n_patches,\
                              "overlap"  : args.overlap})
    mpl.use(args.mpl_agg)
#     print(df_params)

    data_io.show_header()
    if not os.path.exists(args.seg_path): os.makedirs(args.seg_path)
    if args.run_seg:
        # Define DataFile object for the input CT data
        ct_dfile = data_io.DataFile(args.ct_fpath, \
                                    tiff = False,\
                                    data_tag = args.ct_data_tag, \
                                    VERBOSITY = args.rw_verbosity)
        ct_dfile.show_stats()
        if args.stats_only:
            print("\nSet stats_only = False and start over to run program.")
            sys.exit()
        
        chunk_shape = ct_dfile.chunk_shape
        # Load model from model repo
        model_filename = os.path.join(args.model_path, args.model_name + '.hdf5')
        print("\nStarting segmentation mode ...")
        segmenter = Segmenter(model_filename = model_filename)
        
        for idx, row in df_params.iterrows(): # iterate over masks

            # assign arguments from df_params for this mask
            slice_axis = row["slice_axis"]
            max_patches = row["n_patches"]
            segfile_tag = row["mask_name"]
            overlap = row["overlap"]
            
            # define DataFile object for mask
            seg_fname = os.path.join(args.seg_path, segfile_tag + ".hdf5")
            seg_dfile = data_io.DataFile(seg_fname, \
                                         data_tag = "SEG",\
                                         tiff = False, \
                                         d_shape = ct_dfile.d_shape, \
                                         d_type = np.uint8, \
                                         chunk_shape = chunk_shape,\
                                         VERBOSITY = args.rw_verbosity)            
            seg_dfile.create_new(overwrite = args.overwrite_OK)

            t0 = time.time()
            slice_start = 0
            print("\nWorking on %s\n"%segfile_tag)
            pbar = tqdm(total = seg_dfile.d_shape[slice_axis])
            while slice_start < seg_dfile.d_shape[slice_axis]:
                ch, s = ct_dfile.read_chunk(axis = slice_axis, \
                                            slice_start = slice_start, \
                                            max_GB = args.mem_thres)
                ch = segmenter.seg_chunk(ch, \
                                         max_patches = max_patches, \
                                         overlap = overlap,\
                                         nprocs = args.nprocs, \
                                         arr_split = args.arr_split)
                seg_dfile.write_chunk(ch, axis = slice_axis, s = s)
                del ch
                slice_start = s.stop
                pbar.update(s.stop - s.start)
            pbar.close()
            t1 = time.time()
            total_time = (t1 - t0) / 60.0
            print("\nDONE on %s\nTotal time for generating %s mask: %.2f minutes"%(time.ctime(), segfile_tag, total_time))
            del slice_axis
            del max_patches
            del segfile_tag
        
    if args.run_ensemble:
        print("\nStarting ensemble mode ...\n")

        t0 = time.time()
        # get the d_shape of one of the masks
        temp_fname = os.path.join(args.seg_path, df_params.loc[0,"mask_name"] + ".hdf5")
        temp_ds = data_io.DataFile(temp_fname, data_tag = "SEG", tiff = False, VERBOSITY = 0)
        mask_shape = temp_ds.d_shape
        chunk_shape = temp_ds.chunk_shape
        if not args.run_seg: temp_ds.show_stats()
        del temp_ds
        del temp_fname

        if args.stats_only:
            print("\nSet stats_only = False and start over to run program.")
            sys.exit()
        
        vote_fname = os.path.join(args.seg_path, args.vote_maskname)
        if not args.tiff_output: vote_fname = vote_fname + ".hdf5"
        vote_dfile = data_io.DataFile(vote_fname, \
                                      tiff = args.tiff_output,\
                                      data_tag = "SEG", \
                                      d_shape = mask_shape, \
                                      d_type = np.uint8, \
                                      chunk_shape = chunk_shape,\
                                      VERBOSITY = args.rw_verbosity)            
        vote_dfile.create_new(overwrite = args.overwrite_OK)
        
        slice_start = 0
        n_masks = len(df_params)
        pbar = tqdm(total = mask_shape[0])
        while slice_start < mask_shape[0]:
            ch = [0]*len(df_params)
            for idx, row in df_params.iterrows():
                seg_fname = os.path.join(args.seg_path, row["mask_name"]) + ".hdf5"
                
                seg_dfile = data_io.DataFile(seg_fname, \
                                             tiff = False, \
                                             data_tag = "SEG", \
                                             VERBOSITY = args.rw_verbosity)        
                if mask_shape != seg_dfile.d_shape:
                    raise ValueError("Shape of all masks must be same")
                    
                ch[idx], s = seg_dfile.read_chunk(axis = 0, \
                                                  slice_start = slice_start, \
                                                  max_GB = args.mem_thres/(n_masks))
            ch = np.asarray(ch)
            ch = np.median(ch, axis = 0).astype(np.uint8)
            vote_dfile.write_chunk(ch, axis = 0, s = s)
            del ch
            slice_start = s.stop            
            pbar.update(s.stop - s.start)
        pbar.close()

        t1 = time.time()
        total_time = (t1 - t0) / 60.0
        print("\nDONE on %s\nTotal time for ensemble mask %s : %.2f minutes"%(time.ctime(), args.vote_maskname, total_time))


    if args.remove_masks:
        print("Intermediate masks will be removed.")
        for idx, row in df_params.iterrows(): # iterate over masks
            seg_fname = os.path.join(args.seg_path, row["mask_name"]) + ".hdf5"
            os.remove(seg_fname)
        
    return
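
A note on the while-loops that appear throughout these examples: read_chunk returns both the data and the slice object it covered, and the caller resumes at s.stop, so no slab is read twice or skipped. A toy stand-in for that contract, with a plain numpy array in place of data_io.DataFile (read_chunk_toy is hypothetical):

import numpy as np

def read_chunk_toy(arr, slice_start, max_rows):
    # Toy analogue of DataFile.read_chunk: return a slab along axis 0
    # together with the slice object describing what was read.
    s = slice(slice_start, min(slice_start + max_rows, arr.shape[0]))
    return arr[s], s

vol = np.arange(10 * 4).reshape(10, 4)
slice_start = 0
while slice_start < vol.shape[0]:
    ch, s = read_chunk_toy(vol, slice_start, max_rows=3)
    # ... process / write ch here ...
    slice_start = s.stop  # resume exactly where the last read ended
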