Code example #1
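Runs FFN inference over the bounding box given by --bounding_box. A fake (dry) run is executed before the real run, the with_membrane flag is forwarded to the runner, and counters are dumped to counters.txt once finished.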
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request, with_membrane=FLAGS.with_membrane)
    print('>>>>>>>>>>>>>>>>> FAKE RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane,
               fake=True)
    print('>>>>>>>>>>>>>>>>> REAL RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane)

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Code example #2
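MPI-parallel variant: rank 0 creates the output directory, splits the bounding box into overlapping subvolumes (filtering out already-finished ones when --resume is set), and scatters the pieces across ranks; each rank then runs inference on its share in a per-cube output directory, pinning itself to a GPU by rank when --num_gpu > 0. The mpi_rank, mpi_size and mpi_comm globals are presumably set up elsewhere in the module (e.g. via mpi4py).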
def main(unused_argv):
    print('log_level', FLAGS.verbose)
    if FLAGS.verbose:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
    else:
        logger = logging.getLogger()
        logger.setLevel(logging.WARNING)

    start_time = time.time()
    request = inference_flags.request_from_flags()
    if mpi_rank == 0:
        if not gfile.exists(request.segmentation_output_dir):
            gfile.makedirs(request.segmentation_output_dir)
        if FLAGS.output_dir is None:
            root_output_dir = request.segmentation_output_dir
        else:
            root_output_dir = FLAGS.output_dir

        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)

        subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
        overlap = np.array([int(i) for i in FLAGS.overlap])
        sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
        if FLAGS.resume:
            sub_bboxes = find_unfinished(sub_bboxes, root_output_dir)

        sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    else:
        sub_bboxes = None
        root_output_dir = None

    sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
    root_output_dir = mpi_comm.bcast(root_output_dir, 0)

    for sub_bbox in sub_bboxes:
        out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
            sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2],
            sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
        segmentation_output_dir = os.path.join(root_output_dir, out_name)
        request.segmentation_output_dir = segmentation_output_dir
        if FLAGS.num_gpu > 0:
            use_gpu = str(mpi_rank % FLAGS.num_gpu)
        else:
            use_gpu = ''
        runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
        cube_start_time = (time.time() - start_time) / 60
        runner.start(request)
        runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
        cube_finish_time = (time.time() - start_time) / 60
        print('%s finished in %s min' %
              (out_name, cube_finish_time - cube_start_time))
        runner.stop_executor()

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters_%d.txt' % mpi_rank)
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
Code example #3
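Returns a generator that walks each HDF5 volume in overlapping chunks of chunk_shape, padding the volume borders so that the chunk grid covers it completely, and yields (center, chunk) pairs while skipping chunks whose variance is at or below var_threshold.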
def h5_sequential_chunk_generator_v2(data_volumes,
                                     chunk_shape=(64, 64, 32),
                                     overlap=(0, 0, 0),
                                     bounding_box=None,
                                     var_threshold=10,
                                     data_axes='zyx'):
    '''Sequentially generate chunks (with overlap) from volumes.'''
    image_volume_map = {}
    for vol in data_volumes.split(','):
        volname, path, dataset = vol.split(':')
        image_volume_map[volname] = h5py.File(path, 'r')[dataset]

    chunk_shape = np.array(chunk_shape)
    chunk_offset = chunk_shape // 2
    overlap = np.array(overlap)
    if data_axes == 'zyx':
        chunk_shape = chunk_shape[::-1]
        overlap = overlap[::-1]

    if bounding_box:
        sample_bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(bounding_box, sample_bbox)
        sample_start = geom_utils.ToNumpy3Vector(sample_bbox.start)
        sample_size = geom_utils.ToNumpy3Vector(sample_bbox.size)
        print(sample_start, sample_size)
    else:
        sample_start = None
        sample_size = None

    def gen():
        for key, val in image_volume_map.items():
            data_shape = np.array(val.shape)
            step_shape = chunk_shape - overlap
            step_counts = (data_shape - 1) // step_shape + 1
            pad_start = overlap // 2
            # pad zeros at the end so the volume divides evenly into steps
            pad_end = step_counts * step_shape + pad_start - data_shape
            grid_zyx = [
                np.arange(j) * i + d // 2
                for i, j, d in zip(step_shape, step_counts, chunk_shape)
            ]
            grid = np.array(np.meshgrid(*grid_zyx)).T.reshape(-1, 3)

            for i in range(grid.shape[0]):
                center = grid[i]
                image = _load_from_numpylike_with_pad(center, val, pad_start,
                                                      pad_end, chunk_shape,
                                                      sample_start,
                                                      sample_size)

                if image is not None:
                    if np.var(image[...]) > var_threshold:
                        yield (center, image)
                    else:
                        logging.info('skipped chunk %s', str(center))

    return gen
Code example #4
File: run_inference_at.py  Project: ravescovi/ffn
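Segments a single object: a canvas is created for the bounding box, segmentation is attempted from --start_pos, the resulting object mask is written into the segmentation and probability arrays and saved, after which a full inference run is performed over the same bounding box.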
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # start_pos = tuple([int(i) for i in FLAGS.start_pos])
    runner = inference.Runner()

    corner = (bbox.start.z, bbox.start.y, bbox.start.x)
    subvol_size = (bbox.size.z, bbox.size.y, bbox.size.x)
    start_pos = tuple([int(i) for i in FLAGS.start_pos])

    seg_path = storage.segmentation_path(request.segmentation_output_dir,
                                         corner)
    prob_path = storage.object_prob_path(request.segmentation_output_dir,
                                         corner)

    runner.start(request)
    canvas, alignment = runner.make_canvas(corner, subvol_size)
    num_iter = canvas.segment_at(start_pos)

    print('>>', num_iter)

    sel = [
        slice(max(s, 0), e + 1)
        for s, e in zip(canvas._min_pos -
                        canvas._pred_size // 2, canvas._max_pos +
                        canvas._pred_size // 2)
    ]
    mask = canvas.seed[sel] >= canvas.options.segment_threshold
    raw_segmented_voxels = np.sum(mask)

    mask &= canvas.segmentation[sel] <= 0
    actual_segmented_voxels = np.sum(mask)
    canvas._max_id += 1
    canvas.segmentation[sel][mask] = canvas._max_id
    canvas.seg_prob[sel][mask] = storage.quantize_probability(
        expit(canvas.seed[sel][mask]))

    runner.save_segmentation(canvas, alignment, seg_path, prob_path)

    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Code example #5
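A simpler MPI-parallel driver than example #2, without resume support or per-cube timing: rank 0 divides the bounding box into subvolumes and scatters them, each rank segments its subvolumes into per-cube output directories, and all ranks synchronize on a barrier before exiting. As above, mpi_rank, mpi_size and mpi_comm are presumably module-level mpi4py objects.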
def main(unused_argv):
  logger = logging.getLogger()
  if FLAGS.verbose:
    logger.setLevel(logging.DEBUG)
  else:
    logger.setLevel(logging.WARNING)
  start_time = time.time()
  # mpi version
  request = inference_flags.request_from_flags()
  if mpi_rank == 0:
    if not gfile.Exists(request.segmentation_output_dir):
      gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
    overlap = np.array([int(i) for i in FLAGS.overlap])
    sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
    sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    root_output_dir = request.segmentation_output_dir
  else:
    sub_bboxes = None
    root_output_dir = None
  
  sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
  root_output_dir = mpi_comm.bcast(root_output_dir, 0)
  print('rank %d, bbox: %s' % (mpi_rank, len(sub_bboxes)))
  print(sub_bboxes)
  
  for sub_bbox in sub_bboxes:
    out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
      sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2], 
      sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
    segmentation_output_dir = os.path.join(root_output_dir, out_name)
    request.segmentation_output_dir = segmentation_output_dir
    if FLAGS.num_gpu > 0:
      use_gpu = str(mpi_rank % FLAGS.num_gpu)
    else:
      use_gpu = ''
    runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
    runner.start(request)
    runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
    runner.stop_executor()
  mpi_comm.barrier()
  sys.exit()
Code example #6
File: run_inference.py  Project: JennyZhen95/ffn
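The minimal single-process driver: build the request from flags, parse the bounding box, and run inference over it once.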
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.exists(request.segmentation_output_dir):
        gfile.makedirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request)
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
Code example #7
File: run_inference.py  Project: malei-pku/ffn-tracer
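Builds the inference request programmatically from flags via an InferenceConfig helper, copies the HDF5 image into a temporary directory (since concurrent access to the same HDF5 file is not supported), and then runs segmentation from every valid starting point supplied by the seed policy.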
def main(unused_argv):
    move_threshold = FLAGS.move_threshold
    fov_size = dict(zip(["z", "y", "x"], [int(i) for i in FLAGS.fov_size]))
    deltas = dict(zip(["z", "y", "x"], [int(i) for i in FLAGS.deltas]))
    min_boundary_dist = dict(zip(["z", "y", "x"],
                                 [int(i) for i in FLAGS.min_boundary_dist]))
    model_uid = "lr{learning_rate}depth{depth}fov{fov}" \
        .format(learning_rate=FLAGS.lr,
                depth=FLAGS.depth,
                fov=max(fov_size.values()),
                )
    segmentation_output_dir = os.path.join(
        os.getcwd(),
        FLAGS.out_dir + model_uid + "mt" + str(move_threshold) + "policy" +
        FLAGS.seed_policy
    )
    model_checkpoint_path = "{train_dir}/{model_uid}/model.ckpt-{ckpt_id}"\
        .format(train_dir=FLAGS.train_dir,
                model_uid=model_uid,
                ckpt_id=FLAGS.ckpt_id)
    if not gfile.Exists(segmentation_output_dir):
        gfile.MakeDirs(segmentation_output_dir)
    else:
        logging.warning(
            "segmentation_output_dir {} already exists; this may cause inference to "
            "terminate without running.".format(segmentation_output_dir))

    with tempfile.TemporaryDirectory(dir=segmentation_output_dir) as tmpdir:

        # Create a temporary local copy of the HDF5 image, because simultaneous access
        # to HDF5 files is not allowed (and is only recommended for small files).

        temp_image = copy_file_to_tempdir(FLAGS.image, tmpdir)

        inference_config = InferenceConfig(
            image=temp_image,
            fov_size=fov_size,
            deltas=deltas,
            depth=FLAGS.depth,
            image_mean=FLAGS.image_mean,
            image_stddev=FLAGS.image_stddev,
            model_checkpoint_path=model_checkpoint_path,
            model_name=FLAGS.model_name,
            segmentation_output_dir=segmentation_output_dir,
            move_threshold=move_threshold,
            min_segment_size=FLAGS.min_segment_size,
            segment_threshold=FLAGS.segment_threshold,
            min_boundary_dist=min_boundary_dist,
            seed_policy=FLAGS.seed_policy
        )
        config = inference_config.to_string()
        logging.info(config)
        req = inference_pb2.InferenceRequest()
        _ = text_format.Parse(config, req)


        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)
        runner = inference.Runner()
        runner.start(req)

        start_zyx = (bbox.start.z, bbox.start.y, bbox.start.x)
        size_zyx = (bbox.size.z, bbox.size.y, bbox.size.x)
        logging.info("Running; start at {} size {}.".format(start_zyx, size_zyx))

        # Segmentation is attempted from all valid starting points provided by the seed
        # policy by calling runner.canvas.segment_all().
        runner.run(start_zyx,
                   size_zyx,
                   allow_overlapping_segmentation=True,
                   # reset_seed_per_segment=False,
                   # keep_history=True  # this only keeps seed history; not that useful
                   )
        logging.info("Finished running.")

        counter_path = os.path.join(inference_config.segmentation_output_dir, 'counters.txt')
        if not gfile.Exists(counter_path):
            runner.counters.dump(counter_path)

        runner.stop_executor()
        del runner
Code example #8
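A "top-up" workflow: the model is briefly fine-tuned on data located next to the inference volume, reusing the inference runner's TensorFlow session, and inference is then run with the updated weights. The topup and reuse arguments to runner.start appear to be extensions specific to this fork.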
def main(unused_argv):
    request = inference_flags.request_from_flags()
    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # Training
    import os
    batch_size = 16
    max_steps = 3000  #10*250/batch_size #250
    hdf_dir = os.path.split(request.image.hdf5)[0]
    load_ckpt_path = request.model_checkpoint_path
    save_ckpt_path = os.path.split(
        load_ckpt_path)[0] + '_topup_' + os.path.split(
            os.path.split(hdf_dir)[0])[1]
    # import ipdb;ipdb.set_trace()
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # SET UP TRAIN MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP TRAIN MODEL')

            # 'global TA' must appear before the first assignment to TA,
            # otherwise Python raises a SyntaxError.
            global TA
            TA = train_functional.TrainArgs(
                train_coords=os.path.join(hdf_dir, 'tf_record_file'),
                data_volumes='jk:' +
                os.path.join(hdf_dir, 'grayscale_maps.h5') + ':raw',
                label_volumes='jk:' + os.path.join(hdf_dir, 'groundtruth.h5') +
                ':stack',
                train_dir=save_ckpt_path,
                model_name=request.model_name,
                model_args=request.model_args,
                image_mean=request.image_mean,
                image_stddev=request.image_stddev,
                max_steps=max_steps,
                optimizer='adam',
                load_from_ckpt=load_ckpt_path,
                batch_size=batch_size)
            model_class = import_symbol(TA.model_name)
            seed = int(time.time() + TA.task * 3600 * 24)
            logging.info('Random seed: %r', seed)
            random.seed(seed)
            eval_tracker, model, secs, load_data_ops, summary_writer, merge_summaries_op = \
                        build_train_graph(model_class, TA,
                                          save_ckpt=False, with_membrane=TA.with_membrane, **json.loads(TA.model_args))

            # SET UP INFERENCE MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP INFERENCE MODEL')
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES PRE-INFERENCE' %
                  len(tf.trainable_variables()))
            runner = inference.Runner()
            runner.start(request,
                         batch_size=1,
                         topup={'train_dir': FLAGS.train_dir},
                         reuse=tf.AUTO_REUSE,
                         tag='_inference')  #TAKES SESSION
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES POST-INFERENCE' %
                  len(tf.trainable_variables()))

            # START TRAINING
            print('>>>>>>>>>>>>>>>>>>>>>>START TOPUP TRAINING')
            sess = train_functional.train_ffn(TA, eval_tracker, model,
                                              runner.session, load_data_ops,
                                              summary_writer,
                                              merge_summaries_op)

            # saver.save(sess, "/tmp/model.ckpt")

            # START INFERENCE
            print('>>>>>>>>>>>>>>>>>>>>>>START INFERENCE')
            # saver.restore(sess, "/tmp/model.ckpt")
            runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
                       (bbox.size.z, bbox.size.y, bbox.size.x))

            counter_path = os.path.join(request.segmentation_output_dir,
                                        'counters.txt')
            if not gfile.Exists(counter_path):
                runner.counters.dump(counter_path)

            sess.close()