def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub):
        """A Predict request lacking a modelSessionId is rejected by the server.

        The RPC must fail with FAILED_PRECONDITION and explain that the
        model-session-id is missing.
        """
        with pytest.raises(grpc.RpcError) as excinfo:
            empty_request = inference_pb2.PredictRequest()
            grpc_stub.Predict(empty_request)

        rpc_error = excinfo.value
        assert rpc_error.code() == grpc.StatusCode.FAILED_PRECONDITION
        assert "model-session-id has not been provided" in rpc_error.details()
# Example #2
    def predict(self, feature_image, roi, axistags=None):
        """
        Run a remote prediction for one tile and return the ROI-cropped result.

        The feature image is reordered to ``self.input_axes`` and sent to the
        server through ``self.tiktorchClient.Predict.future``. While waiting for
        the reply, the current ``Request`` is suspended; a done-callback on the
        gRPC future wakes it up again. The returned tensor is interpreted in
        ``self.output_axes`` order, with a leading "c" axis added if the model
        output has none. A complementary ``1 - p`` channel is appended when the
        output has a single channel. Finally the ``roi`` is cut out and the
        result is reordered back to ``axistags``.

        :param numpy.ndarray feature_image: classifier input
        :param numpy.ndarray roi: ROI within feature_image. Row 0 holds the
            start and row 1 the stop coordinates, with one column per axis in
            ``axistags`` order (the columns are reordered below).
        :param vigra.AxisTags axistags: axistags of feature_image.
            NOTE(review): despite the ``None`` default this argument is
            required; ``axistags.index`` is called unconditionally below.
        :return: probabilities with axes ordered like ``axistags``, or the
            integer ``0`` if anything in the remote-call section (request
            lookup, tensor conversion, the Predict call itself) raised.
        """
        # NOTE(review): ``assert`` is stripped under ``python -O``; this is a
        # debugging aid, not input validation.
        assert isinstance(roi, numpy.ndarray)
        logger.debug("predict tile shape: %s (axistags: %r)", feature_image.shape, axistags)

        # translate roi axes todo: remove with tczyx standard
        # output_axis_order = self._model_conf.output_axis_order
        output_axis_order = self.output_axes
        # Always work with a channel axis. If the model output has none, prepend
        # "c" and remember to add a matching singleton axis to the result later.
        if "c" not in output_axis_order:
            output_axis_order = "c" + output_axis_order
            c_was_not_in_output_axis_order = True
        else:
            c_was_not_in_output_axis_order = False
        # Reorder the roi columns from axistags order into output_axis_order so
        # the roi can be applied directly to the model output.
        roi = roi[:, [axistags.index(a) for a in output_axis_order]]

        # Bring the input into the axis order the model expects.
        reordered_feature_image = reorder_axes(feature_image, from_axes_tags=axistags, to_axes_tags=self.input_axes)

        try:
            # NOTE(review): assumes predict always runs inside a Request, so that
            # ``_current_request()`` is not None here; confirm with callers.
            current_rq = Request._current_request()
            resp = self.tiktorchClient.Predict.future(
                inference_pb2.PredictRequest(
                    tensor=converters.numpy_to_pb_tensor(reordered_feature_image, axistags=self.input_axes),
                    modelSessionId=self.__session.id,
                )
            )
            # Register the wake-up before suspending: the current Request yields
            # here and is resumed by the callback once the gRPC future completes.
            resp.add_done_callback(lambda o: current_rq._wake_up())
            current_rq._suspend()
            resp = resp.result()
            result = converters.pb_tensor_to_numpy(resp.tensor)
        except Exception:
            # NOTE(review): every failure is logged and swallowed; callers
            # receive the integer 0 instead of an ndarray.
            logger.exception("Predict call failed")
            return 0

        logger.debug(f"Obtained a predicted block of shape {result.shape}")
        # Add the leading singleton channel axis that was prepended to
        # output_axis_order above.
        if c_was_not_in_output_axis_order:
            result = result[None, ...]

        # Make two channels out of single-channel predictions: (p, 1 - p).
        channel_axis = output_axis_order.find("c")
        if result.shape[channel_axis] == 1:
            result = numpy.concatenate((result, 1 - result), axis=channel_axis)
            logger.debug(f"Changed shape of predicted block to {result.shape} by adding '1-p' channel")

        # NOTE(review): this is the shape *before* cropping to roi (presumably
        # still including the halo), despite the name ``shape_wo_halo``.
        shape_wo_halo = result.shape
        result = result[roiToSlice(*roi)]
        logger.debug(
            f"Selected roi (start: {roi[0]}, stop: {roi[1]}) from result without halo {shape_wo_halo}. Now"
            f" result has shape: ({result.shape})."
        )

        # Return the cropped prediction in the caller's axis order.
        return reorder_axes(result, from_axes_tags=output_axis_order, to_axes_tags=axistags)
# Example #3
    def test_call_predict(self, grpc_stub, pybio_dummy_model_bytes):
        """Predict on a dummy model session returns the input incremented by one.

        The model session is closed in a ``finally`` clause so that a failing
        conversion or ``Predict`` call does not leave the session open on the
        server.
        """
        arr = np.arange(32 * 32).reshape(1, 1, 32, 32)
        # The test expects the dummy model to add 1 to every element.
        expected = arr + 1

        model = grpc_stub.CreateModelSession(valid_model_request(pybio_dummy_model_bytes))
        try:
            input_tensor = converters.numpy_to_pb_tensor(arr)
            res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensor=input_tensor))
        finally:
            grpc_stub.CloseModelSession(model)

        assert_array_equal(expected, converters.pb_tensor_to_numpy(res.tensor))
# Example #4
 def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
     """Predict with a modelSessionId the server does not know is rejected.

     The RPC must fail with FAILED_PRECONDITION and name the unknown id in
     its details.
     """
     with pytest.raises(grpc.RpcError) as e:
         grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId="myid1"))
     assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
     assert "model-session with id myid1 doesn't exist" in e.value.details()