Example #1
def test_predict_mitoses_num_locations():
    #TODO: change the model file path
    model_file = 'model/0.74172_f1_1.7319_loss_8_epoch_model.hdf5'
    model_name = 'vgg'
    #model_file = 'model/0.72938_f1_0.067459_loss_7_epoch_model.hdf5'
    #model_name = 'resnet'
    marginalization = True
    threshold = 0
    tile_overlap = 0
    batch_size = 128
    ROI = np.asarray(Image.open("data/test/1_11_01_1292_413_0_0_0.png"),
                     dtype=np.int)

    # the expected probability
    base_model = tf.keras.models.load_model(model_file, compile=False)
    probs = tf.keras.layers.Activation('sigmoid',
                                       name="sigmoid")(base_model.output)
    model = tf.keras.models.Model(inputs=base_model.input, outputs=probs)
    norm_ROI = normalize((ROI / 255).astype(dtype=np.float32), model_name)
    prob = model.predict(np.expand_dims(norm_ROI, axis=0))
    print(f"The expected probability: {prob}")

    # the predicted result
    pred_result = predict_mitoses_num_locations(
        model,
        model_name,
        threshold,
        ROI,
        tile_size=64,
        tile_overlap=tile_overlap,
        tile_channel=3,
        batch_size=batch_size,
        marginalization=marginalization)
    print(f"The predicted probability: {pred_result}")
Example #2
 def predict(self, x):
     x = preprocess_image(x)
     norm_patch_batch = normalize((np.array(x) / 255).astype(np.float32),
                                  "resnet_custom")
     out_batch = self.output_tensor.eval(
         feed_dict={self.input_tensor: norm_patch_batch}, session=self.sess)
     return post_process_result(out_batch)
Example #3
def run_inference(model,
                  sess,
                  batch_size,
                  input_dir_path,
                  output_dir_path,
                  num_parallel_calls=1,
                  prob_thres=0.5,
                  eps=64,
                  min_samples=1,
                  isWeightedAvg=False):

    input_file_paths = [str(f) for f in Path(input_dir_path).glob('*.png')]
    input_files = np.asarray(input_file_paths, dtype=np.str)

    input_file_dataset = tf.data.Dataset.from_tensor_slices(input_files)
    img_dataset = input_file_dataset.map(lambda file: get_image_tf(file),
                                         num_parallel_calls=1)
    img_dataset = img_dataset\
        .map(lambda img: normalize(img, "resnet_custom"))\
        .batch(batch_size=batch_size)
    img_iterator = img_dataset.make_one_shot_iterator()
    next_batch = img_iterator.get_next()

    prob_result = np.empty((0, 1))

    while True:
        try:
            img_batch = sess.run(next_batch)
            pred_np = model.predict(img_batch, batch_size)
            prob_result = np.concatenate((prob_result, pred_np), axis=0)
        except tf.errors.OutOfRangeError:
            print("prediction result size: {}".format(prob_result.shape))
            break

    assert prob_result.shape[0] == input_files.shape[0]
    mitosis_probs = prob_result[prob_result > prob_thres]
    input_files = input_files.reshape(-1, 1)
    mitosis_patch_files = input_files[prob_result > prob_thres]
    inference_result = []
    for i in range(mitosis_patch_files.size):
        row, col = get_location_from_file_name(mitosis_patch_files[i])
        prob = mitosis_probs[i]
        inference_result.append((row, col, prob))

    if len(inference_result) > 0:
        clustered_pred_locations = dbscan_clustering(
            inference_result,
            eps=eps,
            min_samples=min_samples,
            isWeightedAvg=isWeightedAvg)
        tuple_2_csv(inference_result,
                    os.path.join(output_dir_path, 'mitosis_locations.csv'))
        tuple_2_csv(
            clustered_pred_locations,
            os.path.join(output_dir_path, 'clustered_mitosis_locations.csv'))
    else:
        print("Do not have mitosis in {}".format(input_dir_path))
Example #4
def run_mitosis_classification(model,
                               sess,
                               batch_size,
                               input_dir_path,
                               output_dir_path,
                               augmentation_number,
                               mitosis_tile_size=64,
                               num_parallel_calls=1,
                               prefetch=32,
                               prob_thres=0.5,
                               eps=64, min_samples=1,
                               isWeightedAvg=False):

    input_file_paths = [str(f) for f in Path(input_dir_path).glob('*.png')]
    input_files = np.asarray(input_file_paths, dtype=np.str)

    input_file_dataset = tf.data.Dataset.from_tensor_slices(input_files)
    img_dataset = input_file_dataset.map(lambda file: get_image_tf(file),
                                         num_parallel_calls=1)

    if augmentation_number == 1:
      img_dataset = img_dataset\
        .map(lambda img: normalize(img, "resnet_custom"),
             num_parallel_calls=num_parallel_calls)\
        .batch(batch_size)\
        .prefetch(prefetch)
      # Make sure all the files in the dataset are fed into inference
      float_steps = len(input_file_paths) / batch_size
      int_steps = len(input_file_paths) // batch_size
      steps = math.ceil(float_steps) if float_steps > int_steps else int_steps
    else:
      img_dataset = img_dataset \
        .map(lambda img: create_augmented_batch(img, augmentation_number,
                         mitosis_tile_size),
             num_parallel_calls=num_parallel_calls) \
        .map(lambda img: normalize(img, "resnet_custom"),
             num_parallel_calls=num_parallel_calls) \
        .prefetch(prefetch)
      steps = len(input_file_paths)

    img_iterator = img_dataset.make_one_shot_iterator()
    next_batch = img_iterator.get_next()

    while True:
        try:
            pred_np = model.predict(next_batch, steps=steps)
            print("Prediction result shape: ", pred_np.shape)
            # The one-shot iterator is fully consumed after a successful
            # prediction, so stop here instead of looping again.
            break
        except tf.errors.OutOfRangeError:
            print("Please check the steps parameter. steps = {}, "
                  "batch_size = {}, input_tile_size = {}, "
                  "augmentation_number = {}"
                  .format(steps, batch_size, input_files.shape,
                          augmentation_number))
            break

    prob_result = \
      np.average(pred_np.reshape(-1, augmentation_number), axis=1)

    print("Finish the inference on {} with {} input tiles"
          .format(input_dir_path, prob_result.shape))

    assert prob_result.shape[0] == input_files.shape[0]
    mitosis_probs = prob_result[prob_result > prob_thres]
    input_files = input_files.reshape(-1, 1)
    mitosis_patch_files = input_files[prob_result > prob_thres]
    inference_result = []
    for i in range(mitosis_patch_files.size):
        row, col = get_location_from_file_name(mitosis_patch_files[i])
        prob = mitosis_probs[i]
        inference_result.append((row, col, prob))

    if len(inference_result) > 0:
        clustered_pred_locations = dbscan_clustering(
            inference_result, eps=eps, min_samples=min_samples,
            isWeightedAvg=isWeightedAvg)
        tuple_2_csv(
            inference_result,
            os.path.join(output_dir_path, 'mitosis_locations.csv'))
        tuple_2_csv(
            clustered_pred_locations,
            os.path.join(output_dir_path, 'clustered_mitosis_locations.csv'))
    else:
        print("Do not have mitosis in {}".format(input_dir_path))
Example #5
 def _predict(self, x):
     norm_patch_batch = normalize((np.array(x) / 255).astype(np.float32),
                                  "resnet_custom")
     out_batch = self.output_tensor.eval(
         feed_dict={self.input_tensor: norm_patch_batch}, session=self.sess)
     return out_batch
Example #6
def predict_mitoses_num_locations(model,
                                  model_name,
                                  threshold,
                                  ROI,
                                  tile_size=64,
                                  tile_overlap=0,
                                  tile_channel=3,
                                  batch_size=128,
                                  marginalization=False):
    """ Predict the number of mitoses with the detected mitosis locations
    for each input ROI.

  Args:
    model: model loaded from the model file.
    model_name: name of the input model, e.g. vgg, resnet.
    threshold: threshold for the output of last sigmoid layer.
    ROI: ROI in numpy array.
    ROI_size: size of ROI.
    ROI_overlap: overlap between ROIs.
    ROI_row: row number of the ROI in the input slide image. If setting
      it 0, the original coordination will be the left-upper corner of
      the input ROI.
    ROI_col: col number of the ROI in the input slide image. If setting
      it 0, the original coordination will be the left-upper corner of
      the input ROI.
    tile_size: tile size.
    tile_overlap: overlap between tiles.
    tile_channel: channel of tiles.
    batch_size: the batch_size for prediction.
    marginalization: Boolean for whether or not to use noise
      marginalization when making predictions.  If True, then
      each image will be expanded to a batch of size `batch_size` of
      augmented versions of that image, and predicted probabilities for
      each batch will be averaged to yield a single noise-marginalized
      prediction for each image.  Note: if this is True, then
      `batch_size` must be divisible by 4, or equal to 1 for a special
      debugging case of no augmentation.

  Return:
     the prediction result for the input ROI, (mitosis_num,
     mitosis_location_scores).
  """
    ROI_height, ROI_width, ROI_channel = ROI.shape

    # gen_dense_coords handles the case where a tile center point falls outside of the ROI
    tile_indices = list(
        gen_dense_coords(ROI_height, ROI_width, tile_size - tile_overlap))

    mitosis_location_scores = []
    predictions = np.empty((0, 1))

    if marginalization:
        # create tiles larger than the intended size so that we can perform random rotations and
        # random translations via cropping
        d = 72  # TODO: keep this in sync with the training augmentation code
        tiles = (element[0] for element in gen_patches(ROI,
                                                       tile_indices,
                                                       tile_size + d,
                                                       rotations=0,
                                                       translations=0,
                                                       max_shift=0,
                                                       p=1))

        # create marginalization graph
        # NOTE: averaging over sigmoid outputs vs. logits may yield slightly different results, due
        # to numerical precision
        prep_tile = tf.placeholder(
            tf.float32, shape=[tile_size + d, tile_size + d, tile_channel])
        aug_tiles = create_augmented_batch(prep_tile, batch_size,
                                           tile_size)  # create aug batch
        norm_tiles = normalize(aug_tiles,
                               model_name)  # normalize augmented tiles
        aug_preds = model(
            norm_tiles)  # make predictions on normalized and augmented batch
        pred = marginalize(aug_preds)  # average predictions

        # make predictions
        sess = tf.keras.backend.get_session()
        for tile in tiles:
            prep_tile_np = (tile / 255).astype(
                np.float32)  # convert to values in [0,1]
            pred_np, aug_preds_np = sess.run(
                (pred, aug_preds),
                feed_dict={
                    prep_tile: prep_tile_np,
                    tf.keras.backend.learning_phase(): 0
                })
            predictions = np.concatenate((predictions, pred_np), axis=0)

            print (f"The {predictions.shape[0]}th prediction: max: {np.max(aug_preds_np)}, min: "\
                   f"{np.min(aug_preds_np)}, avg: {pred_np}")

    else:
        tiles = (element[0] for element in gen_patches(ROI,
                                                       tile_indices,
                                                       tile_size,
                                                       rotations=0,
                                                       translations=0,
                                                       max_shift=0,
                                                       p=1))
        tile_batches = gen_batches(tiles, batch_size, include_partial=True)
        for tile_batch in tile_batches:
            tile_stack = np.stack(tile_batch, axis=0)
            tile_stack = normalize((tile_stack / 255).astype(dtype=np.float32),
                                   model_name)
            pred_np = model.predict(tile_stack, batch_size)
            predictions = np.concatenate((predictions, pred_np), axis=0)

    isMitoses = predictions > threshold
    for i in range(isMitoses.shape[0]):
        if isMitoses[i]:
            tile_row_index, tile_col_index = tile_indices[i]
            mitosis_location_scores.append(
                (tile_row_index, tile_col_index, np.asscalar(predictions[i])))

    mitosis_num = len(mitosis_location_scores)
    return (mitosis_num, mitosis_location_scores)
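In the marginalization branch above, marginalize(aug_preds) collapses the sigmoid outputs of an augmented batch into a single score per tile, as described in the docstring. A plausible reading of that step, offered only as an assumption about its behaviour rather than the project's actual code, is:

import tensorflow as tf

def marginalize_sketch(aug_preds):
    # Average the per-augmentation predictions into a single (1, 1) prediction
    # so it can be concatenated onto the (N, 1) predictions array above.
    return tf.reduce_mean(aug_preds, axis=0, keepdims=True)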