Example #1
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
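    # Fill the BoundingBox proto from the textproto string in --bounding_box.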
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request, with_membrane=FLAGS.with_membrane)
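    # 'fake' and 'with_membrane' appear to be options specific to this FFN fork;
    # a fake pass is executed first, followed by the real segmentation run.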
    print('>>>>>>>>>>>>>>>>> FAKE RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane,
               fake=True)
    print('>>>>>>>>>>>>>>>>> REAL RUN')
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x),
               with_membrane=FLAGS.with_membrane)

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Example #2
def main(unused_argv):
    print('log_level', FLAGS.verbose)
    if FLAGS.verbose:
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
    else:
        logger = logging.getLogger()
        logger.setLevel(logging.WARNING)

    start_time = time.time()
    request = inference_flags.request_from_flags()
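    # Rank 0 divides the bounding box into overlapping subvolumes and splits the
    # resulting list into one chunk per MPI rank; chunks are scattered below.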
    if mpi_rank == 0:
        if not gfile.exists(request.segmentation_output_dir):
            gfile.makedirs(request.segmentation_output_dir)
        if FLAGS.output_dir is None:
            root_output_dir = request.segmentation_output_dir
        else:
            root_output_dir = FLAGS.output_dir

        bbox = bounding_box_pb2.BoundingBox()
        text_format.Parse(FLAGS.bounding_box, bbox)

        subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
        overlap = np.array([int(i) for i in FLAGS.overlap])
        sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
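        # With --resume, keep only subvolumes that have not produced output yet.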
        if FLAGS.resume:
            sub_bboxes = find_unfinished(sub_bboxes, root_output_dir)

        sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    else:
        sub_bboxes = None
        root_output_dir = None

    sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
    root_output_dir = mpi_comm.bcast(root_output_dir, 0)

    for sub_bbox in sub_bboxes:
        out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
            sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2],
            sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
        segmentation_output_dir = os.path.join(root_output_dir, out_name)
        request.segmentation_output_dir = segmentation_output_dir
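        # Round-robin GPU assignment: rank i uses GPU (i mod num_gpu); an empty
        # string is passed when no GPUs are requested.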
        if FLAGS.num_gpu > 0:
            use_gpu = str(mpi_rank % FLAGS.num_gpu)
        else:
            use_gpu = ''
        runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
        cube_start_time = (time.time() - start_time) / 60
        runner.start(request)
        runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
        cube_finish_time = (time.time() - start_time) / 60
        print('%s finished in %s min' %
              (out_name, cube_finish_time - cube_start_time))
        runner.stop_executor()

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters_%d.txt' % mpi_rank)
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
Example #3
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # start_pos = tuple([int(i) for i in FLAGS.start_pos])
    runner = inference.Runner()

    corner = (bbox.start.z, bbox.start.y, bbox.start.x)
    subvol_size = (bbox.size.z, bbox.size.y, bbox.size.x)
    start_pos = tuple([int(i) for i in FLAGS.start_pos])

    seg_path = storage.segmentation_path(request.segmentation_output_dir,
                                         corner)
    prob_path = storage.object_prob_path(request.segmentation_output_dir,
                                         corner)

    runner.start(request)
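    # Build a canvas for the subvolume and flood-fill a single object from start_pos.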
    canvas, alignment = runner.make_canvas(corner, subvol_size)
    num_iter = canvas.segment_at(start_pos)

    print('>>', num_iter)

    # Crop to the region visited during flood filling, padded by half the model's
    # field of view; then keep above-threshold voxels that are not yet claimed,
    # assign them a fresh segment id, and store quantized object probabilities.
    sel = [
        slice(max(s, 0), e + 1)
        for s, e in zip(canvas._min_pos - canvas._pred_size // 2,
                        canvas._max_pos + canvas._pred_size // 2)
    ]
    mask = canvas.seed[sel] >= canvas.options.segment_threshold
    raw_segmented_voxels = np.sum(mask)

    mask &= canvas.segmentation[sel] <= 0
    actual_segmented_voxels = np.sum(mask)
    canvas._max_id += 1
    canvas.segmentation[sel][mask] = canvas._max_id
    canvas.seg_prob[sel][mask] = storage.quantize_probability(
        expit(canvas.seed[sel][mask]))

    runner.save_segmentation(canvas, alignment, seg_path, prob_path)

    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.Exists(counter_path):
        runner.counters.dump(counter_path)
Example #4
def main(unused_argv):
  logger = logging.getLogger()
  if FLAGS.verbose:
    logger.setLevel(logging.DEBUG)
  else:
    logger.setLevel(logging.WARNING)
  start_time = time.time()
  # mpi version
  request = inference_flags.request_from_flags()
  if mpi_rank == 0:
    if not gfile.Exists(request.segmentation_output_dir):
      gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    subvolume_size = np.array([int(i) for i in FLAGS.subvolume_size])
    overlap = np.array([int(i) for i in FLAGS.overlap])
    sub_bboxes = divide_bounding_box(bbox, subvolume_size, overlap)
    sub_bboxes = np.array_split(np.array(sub_bboxes), mpi_size)
    root_output_dir = request.segmentation_output_dir
  else:
    sub_bboxes = None
    root_output_dir = None
  
  sub_bboxes = mpi_comm.scatter(sub_bboxes, 0)
  root_output_dir = mpi_comm.bcast(root_output_dir, 0)
  print('rank %d: %d sub-boxes' % (mpi_rank, len(sub_bboxes)))
  print(sub_bboxes)
  
  for sub_bbox in sub_bboxes:
    out_name = 'seg-%d_%d_%d_%d_%d_%d' % (
      sub_bbox.start[0], sub_bbox.start[1], sub_bbox.start[2], 
      sub_bbox.size[0], sub_bbox.size[1], sub_bbox.size[2])
    segmentation_output_dir = os.path.join(root_output_dir, out_name)
    request.segmentation_output_dir = segmentation_output_dir
    if FLAGS.num_gpu > 0:
      use_gpu = str(mpi_rank % FLAGS.num_gpu)
    else:
      use_gpu = ''
    runner = inference.Runner(use_cpu=FLAGS.use_cpu, use_gpu=use_gpu)
    runner.start(request)
    runner.run(sub_bbox.start[::-1], sub_bbox.size[::-1])
    runner.stop_executor()
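  # Wait for every rank to finish its subvolumes before exiting.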
  mpi_comm.barrier()
  sys.exit()
Example #5
def main(unused_argv):
    request = inference_flags.request_from_flags()

    if not gfile.exists(request.segmentation_output_dir):
        gfile.makedirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    runner = inference.Runner()
    runner.start(request)
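    # Corner and size are passed in (z, y, x) order.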
    runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
               (bbox.size.z, bbox.size.y, bbox.size.x))

    counter_path = os.path.join(request.segmentation_output_dir,
                                'counters.txt')
    if not gfile.exists(counter_path):
        runner.counters.dump(counter_path)
Example #6
def main(unused_argv):
    request = inference_flags.request_from_flags()
    if not gfile.Exists(request.segmentation_output_dir):
        gfile.MakeDirs(request.segmentation_output_dir)

    bbox = bounding_box_pb2.BoundingBox()
    text_format.Parse(FLAGS.bounding_box, bbox)

    # Training
    import os
    batch_size = 16
    max_steps = 3000  #10*250/batch_size #250
    hdf_dir = os.path.split(request.image.hdf5)[0]
    load_ckpt_path = request.model_checkpoint_path
    # New checkpoints go to a sibling '<model_dir>_topup_<dataset>' directory.
    save_ckpt_path = (os.path.split(load_ckpt_path)[0] + '_topup_' +
                      os.path.split(os.path.split(hdf_dir)[0])[1])
    # import ipdb;ipdb.set_trace()
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # SET UP TRAIN MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP TRAIN MODEL')

            # Assemble fine-tuning ("topup") arguments for train_functional.
            global TA  # declare before assignment so the module-level TA is updated
            TA = train_functional.TrainArgs(
                train_coords=os.path.join(hdf_dir, 'tf_record_file'),
                data_volumes='jk:' +
                os.path.join(hdf_dir, 'grayscale_maps.h5') + ':raw',
                label_volumes='jk:' + os.path.join(hdf_dir, 'groundtruth.h5') +
                ':stack',
                train_dir=save_ckpt_path,
                model_name=request.model_name,
                model_args=request.model_args,
                image_mean=request.image_mean,
                image_stddev=request.image_stddev,
                max_steps=max_steps,
                optimizer='adam',
                load_from_ckpt=load_ckpt_path,
                batch_size=batch_size)
            model_class = import_symbol(TA.model_name)
            seed = int(time.time() + TA.task * 3600 * 24)
            logging.info('Random seed: %r', seed)
            random.seed(seed)
            (eval_tracker, model, secs, load_data_ops, summary_writer,
             merge_summaries_op) = build_train_graph(
                 model_class, TA, save_ckpt=False,
                 with_membrane=TA.with_membrane, **json.loads(TA.model_args))

            # SET UP INFERENCE MODEL
            print('>>>>>>>>>>>>>>>>>>>>>>SET UP INFERENCE MODEL')
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES PRE-INFERENCE' %
                  len(tf.trainable_variables()))
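            # 'topup', 'reuse' and 'tag' look like fork-specific start() options for
            # sharing the variables that were just built for the training graph.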
            runner = inference.Runner()
            runner.start(request,
                         batch_size=1,
                         topup={'train_dir': FLAGS.train_dir},
                         reuse=tf.AUTO_REUSE,
                         tag='_inference')  #TAKES SESSION
            print('>>>>>>>>>>>>>>>>>>>>>>COUNTED %s VARIABLES POST-INFERENCE' %
                  len(tf.trainable_variables()))

            # START TRAINING
            print('>>>>>>>>>>>>>>>>>>>>>>START TOPUP TRAINING')
            sess = train_functional.train_ffn(TA, eval_tracker, model,
                                              runner.session, load_data_ops,
                                              summary_writer,
                                              merge_summaries_op)

            # saver.save(sess, "/tmp/model.ckpt")

            # START INFERENCE
            print('>>>>>>>>>>>>>>>>>>>>>>START INFERENCE')
            # saver.restore(sess, "/tmp/model.ckpt")
            runner.run((bbox.start.z, bbox.start.y, bbox.start.x),
                       (bbox.size.z, bbox.size.y, bbox.size.x))

            counter_path = os.path.join(request.segmentation_output_dir,
                                        'counters.txt')
            if not gfile.Exists(counter_path):
                runner.counters.dump(counter_path)

            sess.close()