Ejemplo n.º 1
0
def main(unused_argv):
    """Run FFN inference from a text-proto InferenceRequest parameter file.

    Reads the request from FLAGS.parameter_file, ensures the segmentation
    output directory exists, runs inference over the full volume given by
    the image-size flags, and dumps runtime counters next to the results.
    """
    request = inference_pb2.InferenceRequest()
    # Read the whole text proto in one call. The previous
    # ' '.join(f.readlines()) inserted a stray space after every retained
    # newline; f.read() preserves the file verbatim.
    with open(FLAGS.parameter_file, mode='r') as f:
        text = f.read()
    text_format.Parse(text, request)

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)
    runner = inference.Runner()
    runner.start(request)
    # Segment the full volume from the origin; extents are (z, y, x) flags.
    runner.run((0, 0, 0), (int(FLAGS.image_size_z), int(
        FLAGS.image_size_y), int(FLAGS.image_size_x)))

    # Dump counters once; never clobber an existing counters file.
    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Ejemplo n.º 2
0
def main(idx,
         move_threshold=0.7,
         segment_threshold=0.6,
         validate=False,
         seed='15,15,17',
         rotate=False):
    """Apply the FFN routines using fGRUs.

    Predicts membranes with an fGRU model (only on the first pass,
    ``idx == 0``), saves the image+membrane stack to disk, then launches
    FFN segmentation over it via a templated text-proto config.

    Args:
        idx: Run index. 0 triggers membrane computation and saving; the
            index is also embedded in the segmentation output path.
        move_threshold: FFN move threshold substituted into the config.
        segment_threshold: FFN segment threshold substituted into the config.
        validate: If True, use the held-out berson volume instead of the
            raw volume addressed by ``seed`` (seed is forced to 99,99,99).
        seed: Comma-separated volume coordinate string; order appears to be
            x,y,z given how seg_dir is built below — TODO confirm.
        rotate: If True, rotate the membrane stack 90° in the (1, 2) plane
            before saving.
    """
    SEED = np.array([int(x) for x in seed.split(',')])
    rdirs(SEED, MEM_STR)  # project helper; presumably creates output dirs for this seed — verify
    model_shape = (SHAPE * PATH_EXTENT)  # full working volume shape in voxels
    # Membrane output path; MEM_STR takes the seed twice (dir + filename).
    mpath = MEM_STR % (pad_zeros(SEED[0], 4), pad_zeros(
        SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
            SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
    if idx == 0:
        # 1. select a volume
        if not validate:
            if np.all(PATH_EXTENT == 1):
                # Single raw uint8 volume addressed directly by the seed.
                path = PATH_STR % (pad_zeros(SEED[0], 4), pad_zeros(
                    SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
                        SEED[0], 4), pad_zeros(SEED[1],
                                               4), pad_zeros(SEED[2], 4))
                vol = np.fromfile(path, dtype='uint8').reshape(SHAPE)
            else:
                # Tile PATH_EXTENT sub-volumes into one larger volume.
                vol = np.zeros((np.array(SHAPE) * PATH_EXTENT))
                for z in range(PATH_EXTENT[0]):
                    for y in range(PATH_EXTENT[1]):
                        for x in range(PATH_EXTENT[2]):
                            path = PATH_STR % (pad_zeros(
                                SEED[0] + x, 4), pad_zeros(
                                    SEED[1] + y, 4), pad_zeros(SEED[2] + z, 4),
                                               pad_zeros(SEED[0] + x, 4),
                                               pad_zeros(SEED[1] + y, 4),
                                               pad_zeros(SEED[2] + z, 4))
                            v = np.fromfile(path, dtype='uint8').reshape(SHAPE)
                            vol[z * SHAPE[0]:z * SHAPE[0] + SHAPE[0],
                                y * SHAPE[1]:y * SHAPE[1] + SHAPE[1],
                                x * SHAPE[2]:x * SHAPE[2] + SHAPE[2]] = v
        else:
            # Validation path: fixed berson volume, sentinel seed 99,99,99.
            data = np.load(
                '/media/data_cifs/connectomics/datasets/berson_0.npz')
            vol = data['volume'][:model_shape[0]]
            SEED = [99, 99, 99]
            mpath = MEM_STR % (pad_zeros(SEED[0], 4), pad_zeros(
                SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
                    SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
            rdirs(SEED, MEM_STR)
        print('seed: %s' % SEED)
        print('mpath: %s' % mpath)
        # Normalize uint8 image data to [0, 1] floats for the fGRU model.
        vol = vol.astype(np.float32) / 255.

        # 2. Predict its membranes
        membranes = fgru.main(test=vol,
                              evaluate=True,
                              adabn=True,
                              test_input_shape=np.concatenate(
                                  (model_shape, [1])).tolist(),
                              test_label_shape=np.concatenate(
                                  (model_shape, [12])).tolist(),
                              checkpoint=MEMBRANE_CKPT)

        # 3. Concat the volume w/ membranes and pass to FFN
        if MEMBRANE_TYPE == 'probability':
            print 'Membrane: %s' % MEMBRANE_TYPE
            # Mean over the first three membrane channels, kept continuous.
            proc_membrane = (
                membranes[0, :, :, :, :3].mean(-1)).transpose(FFN_TRANSPOSE)
        elif MEMBRANE_TYPE == 'threshold':
            print 'Membrane: %s' % MEMBRANE_TYPE
            # Same mean, binarized at 0.5.
            proc_membrane = (membranes[0, :, :, :, :3].mean(-1) >
                             0.5).astype(int).transpose(FFN_TRANSPOSE)
        else:
            raise NotImplementedError
        vol = vol.transpose(FFN_TRANSPOSE)  # ).astype(np.uint8)
        # Rescale back to [0, 255] and stack image + membrane as channels.
        membranes = np.round(np.stack((vol, proc_membrane), axis=-1) *
                             255).astype(np.float32)  # np.uint8)
        if rotate:
            membranes = np.rot90(membranes, k=1, axes=(1, 2))
        np.save(mpath, membranes)
        print 'Saved membrane volume to %s' % mpath
    mpath = '%s.npy' % mpath  # np.save appends .npy; match it for reuse

    # 4. Start FFN
    ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_in_berson3x_w_inf_memb_r0/model.ckpt-1212476'  # model.ckpt-933785
    model = 'htd_cnn_3l_in'
    # ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_berson3x_w_inf_memb_r0/model.ckpt-924330'
    # model = 'htd_cnn_3l'
    # ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_trainablestat_berson3x_w_inf_memb_r0/model.ckpt-1261571'
    # model = 'htd_cnn_3l_trainablestat'

    if validate:
        SEED = [99, 99, 99]
    deltas = '[14, 14, 3]'  # '[27, 27, 6]'
    seg_dir = 'ding_segmentations/x%s/y%s/z%s/v%s/' % (pad_zeros(
        SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4), idx)
    print 'Saving segmentations to: %s' % seg_dir
    if idx == 0:
        seed_policy = 'PolicyMembrane'  # 'PolicyPeaks'
    else:
        seed_policy = 'ShufflePolicyPeaks'
    # NOTE(review): the assignment below unconditionally overrides the
    # idx-based choice above, making that if/else dead code — confirm intent.
    seed_policy = 'ShufflePolicyPeaks'  # 'PolicyMembrane'
    # Text-proto InferenceRequest template; %-substituted below. Note the
    # image source is a fixed HDF5 path, not the mpath computed above.
    config = '''image {hdf5: "%s:raw" }
        image_mean: 128
        image_stddev: 33
        seed_policy: "%s"
        model_checkpoint_path: "%s"
        model_name: "%s.ConvStack3DFFNModel"
        model_args: "{\\"depth\\": 12, \\"fov_size\\": [57, 57, 13], \\"deltas\\": %s}"
        segmentation_output_dir: "%s"
        inference_options {
            init_activation: 0.95
            pad_value: 0.05
            move_threshold: %s
            min_boundary_dist { x: 1 y: 1 z: 1}
            segment_threshold: %s
            min_segment_size: 1000
        }''' % (
        '/media/data_cifs/connectomics/datasets/third_party/wide_fov/berson_w_inf_memb/train/grayscale_maps.h5',
        seed_policy, ckpt_path, model, deltas, seg_dir, move_threshold,
        segment_threshold)

    req = inference_pb2.InferenceRequest()
    _ = text_format.Parse(config, req)
    runner = inference.Runner()
    runner.start(req, tag='_inference')
    # Segment the whole model-shaped volume starting at the origin.
    runner.run((0, 0, 0), (model_shape[0], model_shape[1], model_shape[2]))
Ejemplo n.º 3
0
def main(idx, validate=False, seed='15,15,17'):
    """Apply the FFN routines using fGRUs.

    Variant of the pipeline above with fixed thresholds: predicts membranes
    with an fGRU model (only when ``idx == 0``), saves the image+membrane
    stack, then runs FFN segmentation over that saved .npy via a templated
    text-proto config.

    Args:
        idx: Run index. 0 triggers membrane computation and saving; it is
            also embedded in the segmentation output path.
        validate: If True, use the held-out berson volume instead of the
            raw volume addressed by ``seed`` (seed is forced to 99,99,99).
        seed: Comma-separated volume coordinate string; order appears to be
            x,y,z given how seg_dir is built below — TODO confirm.
    """
    SEED = np.array([int(x) for x in seed.split(',')])
    rdirs(SEED, MEM_STR)  # project helper; presumably creates output dirs for this seed — verify
    model_shape = (SHAPE * PATH_EXTENT)  # full working volume shape in voxels
    # Membrane output path; MEM_STR takes the seed twice (dir + filename).
    mpath = MEM_STR % (pad_zeros(SEED[0], 4), pad_zeros(
        SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
            SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
    if idx == 0:
        # 1. select a volume
        if not validate:
            if np.all(PATH_EXTENT == 1):
                # Single raw uint8 volume addressed directly by the seed.
                path = PATH_STR % (pad_zeros(SEED[0], 4), pad_zeros(
                    SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
                        SEED[0], 4), pad_zeros(SEED[1],
                                               4), pad_zeros(SEED[2], 4))
                vol = np.fromfile(path, dtype='uint8').reshape(SHAPE)
            else:
                # Tile PATH_EXTENT sub-volumes into one larger volume.
                vol = np.zeros((np.array(SHAPE) * PATH_EXTENT))
                for z in range(PATH_EXTENT[0]):
                    for y in range(PATH_EXTENT[1]):
                        for x in range(PATH_EXTENT[2]):
                            path = PATH_STR % (pad_zeros(
                                SEED[0] + x, 4), pad_zeros(
                                    SEED[1] + y, 4), pad_zeros(SEED[2] + z, 4),
                                               pad_zeros(SEED[0] + x, 4),
                                               pad_zeros(SEED[1] + y, 4),
                                               pad_zeros(SEED[2] + z, 4))
                            v = np.fromfile(path, dtype='uint8').reshape(SHAPE)
                            vol[z * SHAPE[0]:z * SHAPE[0] + SHAPE[0],
                                y * SHAPE[1]:y * SHAPE[1] + SHAPE[1],
                                x * SHAPE[2]:x * SHAPE[2] + SHAPE[2]] = v
        else:
            # Validation path: fixed berson volume, sentinel seed 99,99,99.
            data = np.load(
                '/gpfs/data/tserre/data/connectomics/datasets/berson_0.npz')
            vol = data['volume'][:model_shape[0]]
            SEED = [99, 99, 99]
            mpath = MEM_STR % (pad_zeros(SEED[0], 4), pad_zeros(
                SEED[1], 4), pad_zeros(SEED[2], 4), pad_zeros(
                    SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
            rdirs(SEED, MEM_STR)
        print('seed: %s' % SEED)
        print('mpath: %s' % mpath)
        # Normalize uint8 image data to [0, 1] floats for the fGRU model.
        vol = vol.astype(np.float32) / 255.

        # 2. Predict its membranes
        membranes = fgru.main(test=vol,
                              evaluate=True,
                              adabn=True,
                              test_input_shape=np.concatenate(
                                  (model_shape, [1])).tolist(),
                              test_label_shape=np.concatenate(
                                  (model_shape, [12])).tolist(),
                              checkpoint=MEMBRANE_CKPT)

        # 3. Concat the volume w/ membranes and pass to FFN
        if MEMBRANE_TYPE == 'probability':
            print 'Membrane: %s' % MEMBRANE_TYPE
            # Mean over the first three membrane channels, kept continuous.
            proc_membrane = (
                membranes[0, :, :, :, :3].mean(-1)).transpose(FFN_TRANSPOSE)
        elif MEMBRANE_TYPE == 'threshold':
            print 'Membrane: %s' % MEMBRANE_TYPE
            # Same mean, binarized at 0.5.
            proc_membrane = (membranes[0, :, :, :, :3].mean(-1) >
                             0.5).astype(int).transpose(FFN_TRANSPOSE)
        else:
            raise NotImplementedError
        # Rescale image back to [0, 255]; membrane channel stays in its
        # native range here (unlike the variant that rounds the whole stack).
        vol = vol.transpose(FFN_TRANSPOSE) * 255.
        membranes = np.stack((vol, proc_membrane), axis=-1)
        np.save(mpath, membranes)
        print 'Saved membrane volume to %s' % mpath

    # 4. Start FFN
    if validate:
        SEED = [99, 99, 99]
    seg_dir = 'ding_segmentations/x%s/y%s/z%s/v%s/' % (pad_zeros(
        SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4), idx)
    print 'Saving segmentations to: %s' % seg_dir
    if idx == 0:
        seed_policy = 'PolicyMembrane'  # 'PolicyPeaks'
    else:
        seed_policy = 'PolicyMembraneShuffle'  # 'PolicyPeaks'  # 'ShufflePolicyPeaks'
    # Text-proto InferenceRequest template; %-substituted below. NOTE(review):
    # mpath is interpolated without the '.npy' suffix appended in the other
    # variant — confirm the hdf5/npy loader resolves this path correctly.
    config = '''image {hdf5: "%s"}
        image_mean: 128
        image_stddev: 33
        seed_policy: "%s"
        model_checkpoint_path: "/gpfs/data/tserre/data/connectomics/checkpoints/feedback_hgru_v5_3l_notemp_f_berson2x_w_memb_r0/model.ckpt-44450"
        model_name: "feedback_hgru_v5_3l_notemp_f.ConvStack3DFFNModel"
        model_args: "{\\"depth\\": 12, \\"fov_size\\": [57, 57, 13], \\"deltas\\": [8, 8, 3]}"
        segmentation_output_dir: "%s"
        inference_options {
            init_activation: 0.95
            pad_value: 0.05
            move_threshold: 0.5
            min_boundary_dist { x: 1 y: 1 z: 1}
            segment_threshold: 0.6
            min_segment_size: 4096
        }''' % (mpath, seed_policy, seg_dir)

    req = inference_pb2.InferenceRequest()
    _ = text_format.Parse(config, req)
    runner = inference.Runner()
    runner.start(req, tag='_inference')
    # Segment the whole model-shaped volume starting at the origin.
    runner.run((0, 0, 0), (model_shape[0], model_shape[1], model_shape[2]))
Ejemplo n.º 4
0
def main(unused_argv):
    """Configure and run FFN inference entirely from command-line flags.

    Builds an InferenceConfig from FLAGS, stages a private copy of the
    HDF5 image in a temporary directory (HDF5 does not allow simultaneous
    access), runs the FFN runner over the flag-specified bounding box, and
    dumps runtime counters into the segmentation output directory.
    """
    mt = FLAGS.move_threshold
    axes = ("z", "y", "x")
    fov = {a: int(v) for a, v in zip(axes, FLAGS.fov_size)}
    step = {a: int(v) for a, v in zip(axes, FLAGS.deltas)}
    boundary = {a: int(v) for a, v in zip(axes, FLAGS.min_boundary_dist)}

    # Unique id ties this segmentation output back to its training run.
    model_uid = "lr{learning_rate}depth{depth}fov{fov}".format(
        learning_rate=FLAGS.lr,
        depth=FLAGS.depth,
        fov=max(fov.values()),
    )
    out_dir = os.path.join(
        os.getcwd(),
        FLAGS.out_dir + model_uid + "mt" + str(mt) + "policy" +
        FLAGS.seed_policy)
    ckpt = "{train_dir}/{model_uid}/model.ckpt-{ckpt_id}".format(
        train_dir=FLAGS.train_dir,
        model_uid=model_uid,
        ckpt_id=FLAGS.ckpt_id)

    # An existing output directory can make inference exit immediately.
    if gfile.Exists(out_dir):
        logging.warning(
            "segmentation_output_dir {} already exists; this may cause inference to "
            "terminate without running.".format(out_dir))
    else:
        gfile.MakeDirs(out_dir)

    with tempfile.TemporaryDirectory(dir=out_dir) as tmpdir:
        # Simultaneous access to one HDF5 file is not allowed, so each run
        # works on its own temporary copy (only sensible for small files).
        temp_image = copy_file_to_tempdir(FLAGS.image, tmpdir)

        inference_config = InferenceConfig(
            image=temp_image,
            fov_size=fov,
            deltas=step,
            depth=FLAGS.depth,
            image_mean=FLAGS.image_mean,
            image_stddev=FLAGS.image_stddev,
            model_checkpoint_path=ckpt,
            model_name=FLAGS.model_name,
            segmentation_output_dir=out_dir,
            move_threshold=mt,
            min_segment_size=FLAGS.min_segment_size,
            segment_threshold=FLAGS.segment_threshold,
            min_boundary_dist=boundary,
            seed_policy=FLAGS.seed_policy
        )
        config = inference_config.to_string()
        logging.info(config)
        req = inference_pb2.InferenceRequest()
        _ = text_format.Parse(config, req)

        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)
        runner = inference.Runner()
        runner.start(req)

        start_zyx = (bbox.start.z, bbox.start.y, bbox.start.x)
        size_zyx = (bbox.size.z, bbox.size.y, bbox.size.x)
        logging.info("Running; start at {} size {}.".format(start_zyx, size_zyx))

        # Segmentation is attempted from every valid seed the policy yields
        # (runner.canvas.segment_all() under the hood).
        runner.run(start_zyx,
                   size_zyx,
                   allow_overlapping_segmentation=True,
                   # reset_seed_per_segment=False,
                   # keep_history=True  # this only keeps seed history; not that useful
                   )
        logging.info("Finished running.")

        # Dump counters once; never clobber an existing counters file.
        counters_file = os.path.join(
            inference_config.segmentation_output_dir, 'counters.txt')
        if not gfile.Exists(counters_file):
            runner.counters.dump(counters_file)

        runner.stop_executor()
        del runner