Example 1
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request, with_membrane=FLAGS.with_membrane)
    print('>>>>>>>>>>>>>>>>> FAKE RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane,
               fake=True)
    print('>>>>>>>>>>>>>>>>> REAL RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane)

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
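
Note: several of these examples parse FLAGS.bounding_box into a BoundingBox proto with text_format. A minimal sketch of such a flag value, assuming the bounding-box proto layout and import path from the ffn repo (the coordinates below are placeholders):

from google.protobuf import text_format
from ffn.utils import bounding_box_pb2

# Hypothetical flag value; the proto stores x/y/z vectors, while the runners
# above pass coordinates to runner.run() in z/y/x order.
bounding_box_text = 'start { x: 0 y: 0 z: 0 } size { x: 250 y: 250 z: 250 }'

bbox = bounding_box_pb2.BoundingBox()
text_format.Parse(bounding_box_text, bbox)
print(bbox.start.x, bbox.size.z)  # 0 250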
Example 2
def main(unused_argv):
    print('log_level', FLAGS.verbose)
    if FLAGS.verbose:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
    else:
        logger = logging.getLogger()
        logger.setLevel(logging.WARNING)

    start_time = time.time()
    request = inference_flags.request_from_flags()
    if mpi_rank == 0:
        if not gfile.exists(request.segmentation_output_dir):
            gfile.makedirs(request.segmentation_output_dir)
        if FLAGS.output_dir is None:
            root_output_dir = request.segmentation_output_dir
        else:
            root_output_dir = FLAGS.output_dir

        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)

        subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
        overlap = np.array([int(i) for i in FLAGS.overlap])
        sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
        if FLAGS.resume:
            sub_bboxes = find_unfinished(sub_bboxes, root_output_dir)

        sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    else:
        sub_bboxes = None
        root_output_dir = None

    sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
    root_output_dir = mpi_comm.bcast(root_output_dir, 0)

    for sub_bbox in sub_bboxes:
        out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
            sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2],
            sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
        segmentation_output_dir = os.path.join(root_output_dir, out_name)
        request.segmentation_output_dir = segmentation_output_dir
        if FLAGS.num_gpu > 0:
            use_gpu = str(mpi_rank % FLAGS.num_gpu)
        else:
            use_gpu = ''
        runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
        cube_start_time = (time.time() - start_time) / 60
        runner.start(request)
        runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
        cube_finish_time = (time.time() - start_time) / 60
        print('%s finished in %s min' %
              (out_name, cube_finish_time - cube_start_time))
        runner.stop_executor()

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters_%d.txt' % mpi_rank)
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
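
divide_bounding_box and find_unfinished are helpers defined elsewhere in this project and are not shown. A rough sketch of what divide_bounding_box presumably does, tiling the requested box into overlapping sub-volumes; SubBox and the function name are stand-ins, not the project's actual types:

import collections
import numpy as np

SubBox = collections.namedtuple('SubBox', ['start', 'size'])

def divide_bounding_box_sketch(bbox, subvolume_size, overlap):
    """Tile a BoundingBox proto into overlapping sub-boxes (x, y, z order)."""
    start = np.array([bbox.start.x, bbox.start.y, bbox.start.z])
    size = np.array([bbox.size.x, bbox.size.y, bbox.size.z])
    step = subvolume_size - overlap
    sub_bboxes = []
    for z in range(start[2], start[2] + size[2], step[2]):
        for y in range(start[1], start[1] + size[1], step[1]):
            for x in range(start[0], start[0] + size[0], step[0]):
                corner = np.array([x, y, z])
                # Clip the last tile so it does not run past the volume.
                sub_size = np.minimum(subvolume_size, start + size - corner)
                sub_bboxes.append(SubBox(start=corner, size=sub_size))
    return sub_bboxes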
Example 3
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # start_pos = tuple([int(i) for i in FLAGS.start_pos])
    runner = inference.Runner()

    corner = (bbox.start.z, bbox.start.y, bbox.start.x)
    subvol_size = (bbox.size.z, bbox.size.y, bbox.size.x)
    start_pos = tuple([int(i) for i in FLAGS.start_pos])

    seg_path = storage.segmentation_path(request.segmentation_output_dir,
                                         corner)
    prob_path = storage.object_prob_path(request.segmentation_output_dir,
                                         corner)

    runner.start(request)
    canvas, alignment = runner.make_canvas(corner, subvol_size)
    num_iter = canvas.segment_at(start_pos)

    print('>>', num_iter)

    sel = [
        slice(max(s, 0), e + 1)
        for s, e in zip(canvas._min_pos - canvas._pred_size // 2,
                        canvas._max_pos + canvas._pred_size // 2)
    ]
    mask = canvas.seed[sel] >= canvas.options.segment_threshold
    raw_segmented_voxels = np.sum(mask)

    mask &= canvas.segmentation[sel] <= 0
    actual_segmented_voxels = np.sum(mask)
    canvas._max_id += 1
    canvas.segmentation[sel][mask] = canvas._max_id
    canvas.seg_prob[sel][mask] = storage.quantize_probability(
        expit(canvas.seed[sel][mask]))

    runner.save_segmentation(canvas, alignment, seg_path, prob_path)

    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
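
The expit call in Example 3 suggests canvas.seed is stored in logit space; storage.quantize_probability then packs the resulting probabilities into bytes for the object-probability volume. A small illustration of that conversion (quantize_probability_sketch below is a hypothetical stand-in, not the ffn implementation):

import numpy as np
from scipy.special import expit  # logistic sigmoid, 1 / (1 + exp(-x))

def quantize_probability_sketch(prob):
    # Hypothetical: map probabilities in [0, 1] to uint8, keeping 0 free
    # as a sentinel value.
    return np.clip(np.round(prob * 255), 1, 255).astype(np.uint8)

logits = np.array([-2.0, 0.0, 3.0])        # values like canvas.seed
probs = expit(logits)                      # -> [0.12, 0.5, 0.95]
print(quantize_probability_sketch(probs))  # -> [ 30 128 243]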
Example 4
def main(unused_argv):
  logger = logging.getLogger()
  if FLAGS.verbose:
    logger.setLevel(logging.DEBUG)
  else:
    logger.setLevel(logging.WARNING)
  start_time = time.time()
  # mpi version
  request = inference_flags.request_from_flags()
  if mpi_rank == 0:
    if not gfile.Exists(request.segmentation_output_dir):
      gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
    overlap = np.array([int(i) for i in FLAGS.overlap])
    sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
    sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    root_output_dir = request.segmentation_output_dir
  else:
    sub_bboxes = None
    root_output_dir = None
  
  sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
  root_output_dir = mpi_comm.bcast(root_output_dir, 0)
  print('rank %d: %d sub-boxes' % (mpi_rank, len(sub_bboxes)))
  print(sub_bboxes)
  
  for sub_bbox in sub_bboxes:
    out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
      sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2], 
      sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
    segmentation_output_dir = os.path.join(root_output_dir, out_name)
    request.segmentation_output_dir = segmentation_output_dir
    if FLAGS.num_gpu > 0:
      use_gpu = str(mpi_rank % FLAGS.num_gpu)
    else:
      use_gpu = ''
    runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
    runner.start(request)
    runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
    runner.stop_executor()
  mpi_comm.barrier()
  sys.exit()
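
Examples 2 and 4 rely on module-level mpi_comm, mpi_rank and mpi_size globals that are not shown. A minimal setup sketch, assuming they come from mpi4py:

from mpi4py import MPI

mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()

# Rank 0 builds one list of sub-boxes per rank (np.array_split above);
# scatter then hands each rank its chunk and bcast shares the output dir:
#   sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
#   root_output_dir = mpi_comm.bcast(root_output_dir, 0)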
Example 5
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.exists(request.segmentation_output_dir):
        gfile.makedirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request)
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
Example 6
def main(unused_argv):

    request = inference_pb2.InferenceRequest()
    with open(FLAGS.parameter_file, mode='r') as f:
        text_list = f.readlines()
    text = ' '.join(text_list)
    text_format.Parse(text, request)

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)
    runner = inference.Runner()
    runner.start(request)
    #  runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
    #             (bbox.size.z, bbox.size.y, bbox.size.x))
    runner.run((0, 0, 0),
               (int(FLAGS.image_size_z), int(FLAGS.image_size_y),
                int(FLAGS.image_size_x)))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
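
Reading the parameter file with readlines() and re-joining works, but the whole file can equally be handed to text_format.Parse in one call. A small equivalent sketch (load_request is an illustrative name, not part of the project):

from google.protobuf import text_format
from ffn.inference import inference_pb2

def load_request(parameter_file):
    """Read a text-format InferenceRequest from `parameter_file`."""
    request = inference_pb2.InferenceRequest()
    with open(parameter_file, mode='r') as f:
        text_format.Parse(f.read(), request)
    return request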
Example 7
def main(idx,
         move_threshold=0.7,
         segment_threshold=0.6,
         validate=False,
         seed='15,15,17',
         rotate=False):
    """Apply the FFN routines using fGRUs."""
    SEED = np.array([int(x) for x in seed.split(',')])
    rdirs(SEED, MEM_STR)
    model_shape = (SHAPE * PATH_EXTENT)
    mpath = MEM_STR % (
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4),
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
    if idx == 0:
        # 1. select a volume
        if not validate:
            if np.all(PATH_EXTENT == 1):
                path = PATH_STR % (
                    pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4),
                    pad_zeros(SEED[2], 4), pad_zeros(SEED[0], 4),
                    pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
                vol = np.fromfile(path, dtype='uint8').reshape(SHAPE)
            else:
                vol = np.zeros((np.array(SHAPE) * PATH_EXTENT))
                for z in range(PATH_EXTENT[0]):
                    for y in range(PATH_EXTENT[1]):
                        for x in range(PATH_EXTENT[2]):
                            path = PATH_STR % (
                                pad_zeros(SEED[0] + x, 4),
                                pad_zeros(SEED[1] + y, 4),
                                pad_zeros(SEED[2] + z, 4),
                                pad_zeros(SEED[0] + x, 4),
                                pad_zeros(SEED[1] + y, 4),
                                pad_zeros(SEED[2] + z, 4))
                            v = np.fromfile(path, dtype='uint8').reshape(SHAPE)
                            vol[z * SHAPE[0]:z * SHAPE[0] + SHAPE[0],
                                y * SHAPE[1]:y * SHAPE[1] + SHAPE[1],
                                x * SHAPE[2]:x * SHAPE[2] + SHAPE[2]] = v
        else:
            data = np.load(
                '/media/data_cifs/connectomics/datasets/berson_0.npz')
            vol = data['volume'][:model_shape[0]]
            SEED = [99, 99, 99]
            mpath = MEM_STR % (
                pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4),
                pad_zeros(SEED[2], 4), pad_zeros(SEED[0], 4),
                pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
            rdirs(SEED, MEM_STR)
        print('seed: %s' % SEED)
        print('mpath: %s' % mpath)
        vol = vol.astype(np.float32) / 255.

        # 2. Predict its membranes
        membranes = fgru.main(test=vol,
                              evaluate=True,
                              adabn=True,
                              test_input_shape=np.concatenate(
                                  (model_shape, [1])).tolist(),
                              test_label_shape=np.concatenate(
                                  (model_shape, [12])).tolist(),
                              checkpoint=MEMBRANE_CKPT)

        # 3. Concat the volume w/ membranes and pass to FFN
        if MEMBRANE_TYPE == 'probability':
            print('Membrane: %s' % MEMBRANE_TYPE)
            proc_membrane = (
                membranes[0, :, :, :, :3].mean(-1)).transpose(FFN_TRANSPOSE)
        elif MEMBRANE_TYPE == 'threshold':
            print('Membrane: %s' % MEMBRANE_TYPE)
            proc_membrane = (membranes[0, :, :, :, :3].mean(-1) >
                             0.5).astype(int).transpose(FFN_TRANSPOSE)
        else:
            raise NotImplementedError
        vol = vol.transpose(FFN_TRANSPOSE)  # ).astype(np.uint8)
        membranes = np.round(np.stack((vol, proc_membrane), axis=-1) *
                             255).astype(np.float32)  # np.uint8)
        if rotate:
            membranes = np.rot90(membranes, k=1, axes=(1, 2))
        np.save(mpath, membranes)
        print('Saved membrane volume to %s' % mpath)
    mpath = '%s.npy' % mpath

    # 4. Start FFN
    ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_in_berson3x_w_inf_memb_r0/model.ckpt-1212476'  # model.ckpt-933785
    model = 'htd_cnn_3l_in'
    # ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_berson3x_w_inf_memb_r0/model.ckpt-924330'
    # model = 'htd_cnn_3l'
    # ckpt_path = '/media/data_cifs/connectomics/ffn_ckpts/wide_fov/htd_cnn_3l_trainablestat_berson3x_w_inf_memb_r0/model.ckpt-1261571'
    # model = 'htd_cnn_3l_trainablestat'

    if validate:
        SEED = [99, 99, 99]
    deltas = '[14, 14, 3]'  # '[27, 27, 6]'
    seg_dir = 'ding_segmentations/x%s/y%s/z%s/v%s/' % (
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4),
        idx)
    print('Saving segmentations to: %s' % seg_dir)
    if idx == 0:
        seed_policy = 'PolicyMembrane'  # 'PolicyPeaks'
    else:
        seed_policy = 'ShufflePolicyPeaks'
    seed_policy = 'ShufflePolicyPeaks'  # 'PolicyMembrane'
    config = '''image {hdf5: "%s:raw" }
        image_mean: 128
        image_stddev: 33
        seed_policy: "%s"
        model_checkpoint_path: "%s"
        model_name: "%s.ConvStack3DFFNModel"
        model_args: "{\\"depth\\": 12, \\"fov_size\\": [57, 57, 13], \\"deltas\\": %s}"
        segmentation_output_dir: "%s"
        inference_options {
            init_activation: 0.95
            pad_value: 0.05
            move_threshold: %s
            min_boundary_dist { x: 1 y: 1 z: 1}
            segment_threshold: %s
            min_segment_size: 1000
        }''' % (
        '/media/data_cifs/connectomics/datasets/third_party/wide_fov/berson_w_inf_memb/train/grayscale_maps.h5',
        seed_policy, ckpt_path, model, deltas, seg_dir, move_threshold,
        segment_threshold)

    req = inference_pb2.InferenceRequest()
    _ = text_format.Parse(config, req)
    runner = inference.Runner()
    runner.start(req, tag='_inference')
    runner.run((0, 0, 0), (model_shape[0], model_shape[1], model_shape[2]))
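
The model_args line inside the config string above hand-escapes the embedded JSON. An equivalent way to produce that fragment is to let json.dumps handle the quoting; a sketch using the same values, not a change to the example:

import json

model_args = {'depth': 12, 'fov_size': [57, 57, 13], 'deltas': [14, 14, 3]}
# The inner dumps serializes the dict to JSON; the outer dumps wraps it in
# quotes and backslash-escapes the inner ones, as text_format expects.
model_args_line = 'model_args: %s' % json.dumps(json.dumps(model_args))
print(model_args_line)
# model_args: "{\"depth\": 12, \"fov_size\": [57, 57, 13], \"deltas\": [14, 14, 3]}"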
Example 8
def main(idx, validate=False, seed='15,15,17'):
    """Apply the FFN routines using fGRUs."""
    SEED = np.array([int(x) for x in seed.split(',')])
    rdirs(SEED, MEM_STR)
    model_shape = (SHAPE * PATH_EXTENT)
    mpath = MEM_STR % (
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4),
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
    if idx == 0:
        # 1. select a volume
        if not validate:
            if np.all(PATH_EXTENT == 1):
                path = PATH_STR % (
                    pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4),
                    pad_zeros(SEED[2], 4), pad_zeros(SEED[0], 4),
                    pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
                vol = np.fromfile(path, dtype='uint8').reshape(SHAPE)
            else:
                vol = np.zeros((np.array(SHAPE) * PATH_EXTENT))
                for z in range(PATH_EXTENT[0]):
                    for y in range(PATH_EXTENT[1]):
                        for x in range(PATH_EXTENT[2]):
                            path = PATH_STR % (
                                pad_zeros(SEED[0] + x, 4),
                                pad_zeros(SEED[1] + y, 4),
                                pad_zeros(SEED[2] + z, 4),
                                pad_zeros(SEED[0] + x, 4),
                                pad_zeros(SEED[1] + y, 4),
                                pad_zeros(SEED[2] + z, 4))
                            v = np.fromfile(path, dtype='uint8').reshape(SHAPE)
                            vol[z * SHAPE[0]:z * SHAPE[0] + SHAPE[0],
                                y * SHAPE[1]:y * SHAPE[1] + SHAPE[1],
                                x * SHAPE[2]:x * SHAPE[2] + SHAPE[2]] = v
        else:
            data = np.load(
                '/gpfs/data/tserre/data/connectomics/datasets/berson_0.npz')
            vol = data['volume'][:model_shape[0]]
            SEED = [99, 99, 99]
            mpath = MEM_STR % (
                pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4),
                pad_zeros(SEED[2], 4), pad_zeros(SEED[0], 4),
                pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4))
            rdirs(SEED, MEM_STR)
        print('seed: %s' % SEED)
        print('mpath: %s' % mpath)
        vol = vol.astype(np.float32) / 255.

        # 2. Predict its membranes
        membranes = fgru.main(test=vol,
                              evaluate=True,
                              adabn=True,
                              test_input_shape=np.concatenate(
                                  (model_shape, [1])).tolist(),
                              test_label_shape=np.concatenate(
                                  (model_shape, [12])).tolist(),
                              checkpoint=MEMBRANE_CKPT)

        # 3. Concat the volume w/ membranes and pass to FFN
        if MEMBRANE_TYPE == 'probability':
            print('Membrane: %s' % MEMBRANE_TYPE)
            proc_membrane = (
                membranes[0, :, :, :, :3].mean(-1)).transpose(FFN_TRANSPOSE)
        elif MEMBRANE_TYPE == 'threshold':
            print('Membrane: %s' % MEMBRANE_TYPE)
            proc_membrane = (membranes[0, :, :, :, :3].mean(-1) >
                             0.5).astype(int).transpose(FFN_TRANSPOSE)
        else:
            raise NotImplementedError
        vol = vol.transpose(FFN_TRANSPOSE) * 255.
        membranes = np.stack((vol, proc_membrane), axis=-1)
        np.save(mpath, membranes)
        print('Saved membrane volume to %s' % mpath)

    # 4. Start FFN
    if validate:
        SEED = [99, 99, 99]
    seg_dir = 'ding_segmentations/x%s/y%s/z%s/v%s/' % (
        pad_zeros(SEED[0], 4), pad_zeros(SEED[1], 4), pad_zeros(SEED[2], 4),
        idx)
    print('Saving segmentations to: %s' % seg_dir)
    if idx == 0:
        seed_policy = 'PolicyMembrane'  # 'PolicyPeaks'
    else:
        seed_policy = 'PolicyMembraneShuffle'  # 'PolicyPeaks'  # 'ShufflePolicyPeaks'
    config = '''image {hdf5: "%s"}
        image_mean: 128
        image_stddev: 33
        seed_policy: "%s"
        model_checkpoint_path: "/gpfs/data/tserre/data/connectomics/checkpoints/feedback_hgru_v5_3l_notemp_f_berson2x_w_memb_r0/model.ckpt-44450"
        model_name: "feedback_hgru_v5_3l_notemp_f.ConvStack3DFFNModel"
        model_args: "{\\"depth\\": 12, \\"fov_size\\": [57, 57, 13], \\"deltas\\": [8, 8, 3]}"
        segmentation_output_dir: "%s"
        inference_options {
            init_activation: 0.95
            pad_value: 0.05
            move_threshold: 0.5
            min_boundary_dist { x: 1 y: 1 z: 1}
            segment_threshold: 0.6
            min_segment_size: 4096
        }''' % (mpath, seed_policy, seg_dir)

    req = inference_pb2.InferenceRequest()
    _ = text_format.Parse(config, req)
    runner = inference.Runner()
    runner.start(req, tag='_inference')
    runner.run((0, 0, 0), (model_shape[0], model_shape[1], model_shape[2]))
Example 9
def main(unused_argv):
    move_threshold = FLAGS.move_threshold
    fov_size = dict(zip(["z", "y", "x"], [int(i) for i in FLAGS.fov_size]))
    deltas = dict(zip(["z", "y", "x"], [int(i) for i in FLAGS.deltas]))
    min_boundary_dist = dict(zip(["z", "y", "x"],
                                 [int(i) for i in FLAGS.min_boundary_dist]))
    model_uid = "lr{learning_rate}depth{depth}fov{fov}" \
        .format(learning_rate=FLAGS.lr,
                depth=FLAGS.depth,
                fov=max(fov_size.values()),
                )
    segmentation_output_dir = os.path.join(
        os.getcwd(),
        FLAGS.out_dir + model_uid + "mt" + str(move_threshold) + "policy" +
        FLAGS.seed_policy
    )
    model_checkpoint_path = "{train_dir}/{model_uid}/model.ckpt-{ckpt_id}"\
        .format(train_dir=FLAGS.train_dir,
                model_uid=model_uid,
                ckpt_id=FLAGS.ckpt_id)
    if not gfile.Exists(segmentation_output_dir):
        gfile.MakeDirs(segmentation_output_dir)
    else:
        logging.warning(
            "segmentation_output_dir {} already exists; this may cause inference to "
            "terminate without running.".format(segmentation_output_dir))

    with tempfile.TemporaryDirectory(dir=segmentation_output_dir) as tmpdir:

        # Create a temporary local copy of the HDF5 image, because simultaneous
        # access to HDF5 files is not allowed (it is only safe for small files).

        temp_image = copy_file_to_tempdir(FLAGS.image, tmpdir)

        inference_config = InferenceConfig(
            image=temp_image,
            fov_size=fov_size,
            deltas=deltas,
            depth=FLAGS.depth,
            image_mean=FLAGS.image_mean,
            image_stddev=FLAGS.image_stddev,
            model_checkpoint_path=model_checkpoint_path,
            model_name=FLAGS.model_name,
            segmentation_output_dir=segmentation_output_dir,
            move_threshold=move_threshold,
            min_segment_size=FLAGS.min_segment_size,
            segment_threshold=FLAGS.segment_threshold,
            min_boundary_dist=min_boundary_dist,
            seed_policy=FLAGS.seed_policy
        )
        config = inference_config.to_string()
        logging.info(config)
        req = inference_pb2.InferenceRequest()
        _ = text_format.Parse(config, req)


        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)
        runner = inference.Runner()
        runner.start(req)

        start_zyx = (bbox.start.z, bbox.start.y, bbox.start.x)
        size_zyx = (bbox.size.z, bbox.size.y, bbox.size.x)
        logging.info("Running; start at {} size {}.".format(start_zyx, size_zyx))

        # Segmentation is attempted from all valid starting points provided by the seed
        # policy by calling runner.canvas.segment_all().
        runner.run(start_zyx,
                   size_zyx,
                   allow_overlapping_segmentation=True,
                   # reset_seed_per_segment=False,
                   # keep_history=True  # this only keeps seed history; not that useful
                   )
        logging.info("Finished running.")

        counter_path = os.path.join(inference_config.segmentation_output_dir, 'counters.txt')
        if not gfile.Exists(counter_path):
            runner.counters.dump(counter_path)

        runner.stop_executor()
        del runner
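
InferenceConfig and copy_file_to_tempdir are project-specific helpers not shown here. The copy helper presumably just drops a private copy of the HDF5 file into the temporary directory so the run reads its own file; a sketch under that assumption:

import os
import shutil

def copy_file_to_tempdir_sketch(path, tmpdir):
    """Copy `path` into `tmpdir` and return the new location."""
    dest = os.path.join(tmpdir, os.path.basename(path))
    shutil.copyfile(path, dest)
    return dest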
Example 10
def main(unused_argv):
    request = inference_flags.request_from_flags()
    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # Training
    import os
    batch_size = 16
    max_steps = 3000  #10*250/batch_size #250
    hdf_dir = os.path.split(request.image.hdf5)[0]
    load_ckpt_path = request.model_checkpoint_path
    save_ckpt_path = (os.path.split(load_ckpt_path)[0] + '_topup_' +
                      os.path.split(os.path.split(hdf_dir)[0])[1])
    # import ipdb;ipdb.set_trace()
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # SET UP TRAIN MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP TRAIN MODEL')

            global TA
            TA = train_functional.TrainArgs(
                train_coords=os.path.join(hdf_dir, 'tf_record_file'),
                data_volumes='jk:' +
                os.path.join(hdf_dir, 'grayscale_maps.h5') + ':raw',
                label_volumes='jk:' + os.path.join(hdf_dir, 'groundtruth.h5') +
                ':stack',
                train_dir=save_ckpt_path,
                model_name=request.model_name,
                model_args=request.model_args,
                image_mean=request.image_mean,
                image_stddev=request.image_stddev,
                max_steps=max_steps,
                optimizer='adam',
                load_from_ckpt=load_ckpt_path,
                batch_size=batch_size)
            model_class = import_symbol(TA.model_name)
            seed = int(time.time() + TA.task * 3600 * 24)
            logging.info('Random seed: %r', seed)
            random.seed(seed)
            (eval_tracker, model, secs, load_data_ops, summary_writer,
             merge_summaries_op) = build_train_graph(
                 model_class, TA, save_ckpt=False,
                 with_membrane=TA.with_membrane, **json.loads(TA.model_args))

            # SET UP INFERENCE MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP INFERENCE MODEL')
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES PRE-INFERENCE' %
                  len(tf.trainable_variables()))
            runner = inference.Runner()
            runner.start(request,
                         batch_size=1,
                         topup={'train_dir': FLAGS.train_dir},
                         reuse=tf.AUTO_REUSE,
                         tag='_inference')  #TAKES SESSION
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES POST-INFERENCE' %
                  len(tf.trainable_variables()))

            # START TRAINING
            print('>>>>>>>>>>>>>>>>>>>>>>START TOPUP TRAINING')
            sess = train_functional.train_ffn(TA, eval_tracker, model,
                                              runner.session, load_data_ops,
                                              summary_writer,
                                              merge_summaries_op)

            # saver.save(sess, "/tmp/model.ckpt")

            # START INFERENCE
            print('>>>>>>>>>>>>>>>>>>>>>>START INFERENCE')
            # saver.restore(sess, "/tmp/model.ckpt")
            runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
                       (bbox.size.z, bbox.size.y, bbox.size.x))

            counter_path = os.path.join(request.segmentation_output_dir,
                                        'counters.txt')
            if not gfile.Exists(counter_path):
                runner.counters.dump(counter_path)

            sess.close()