Exemplo n.º 1
0
 def test_min_step_input_shape(self):
     # A mapping with `min`/`step` keys fed through the shape union is expected
     # to come back as an equivalent raw ParametrizedInputShape node.
     payload = dict(min=[1, 2, 3], step=[0, 1, 3])
     expected = raw_nodes.ParametrizedInputShape(**payload)
     shape_field = fields.Union(
         [fields.ExplicitShape(), fields.Nested(schema.ParametrizedInputShape())],
         required=True,
     )
     actual = shape_field.deserialize(payload)
     assert actual == expected
Exemplo n.º 2
0
class InputTensor(_TensorBase):
    """Input tensor schema: a required `shape` plus optional `preprocessing` on top of `_TensorBase`.

    The schema-level validator below additionally constrains the batch axis ("b"),
    if present, to a minimal size of 1 and a step of 0.
    """

    shape = fields.Union(
        [
            fields.ExplicitShape(
                bioimageio_description=
                "Exact shape with same length as `axes`, e.g. `shape: [1, 512, 512, 1]`"
            ),
            fields.Nested(
                ParametrizedInputShape(),
                bioimageio_description=
                "A sequence of valid shapes given by `shape = min + k * step for k in {0, 1, ...}`.",
            ),
        ],
        required=True,
        bioimageio_description="Specification of input tensor shape.",
    )
    preprocessing = fields.List(
        fields.Nested(Preprocessing()),
        bioimageio_description=
        "Description of how this input should be preprocessed.")
    processing_name = "preprocessing"

    @validates_schema
    def zero_batch_step_and_one_batch_size(self, data, **kwargs):
        """Require batch size 1 and batch step 0 in the batch dimension, if there is one.

        Args:
            data: The deserialized tensor data; reads the "axes" and "shape" entries.

        Raises:
            ValidationError: if "axes" or "shape" is missing, the shape type is
                unknown, the batch index lies outside the shape/step, or the batch
                step is non-zero or the (minimal) batch size is not 1.
        """
        axes = data.get("axes")
        shape = data.get("shape")
        if axes is None or shape is None:
            raise ValidationError(
                "Failed to validate batch_step=0 and batch_size=1 due to other validation errors"
            )

        # `axes` is used as a string of axis letters (str.find); -1 means no batch axis.
        bidx = axes.find("b")
        if bidx == -1:
            return

        if isinstance(shape, raw_nodes.ParametrizedInputShape):
            step = shape.step
            shape = shape.min
        elif isinstance(shape, list):
            # An explicit shape is fixed, i.e. it has a zero step in every dimension.
            step = [0] * len(shape)
        else:
            raise ValidationError(f"Unknown shape type {type(shape)}")

        # An axes/shape length mismatch would otherwise surface as an IndexError
        # from inside the validator rather than as a proper validation error.
        if bidx >= len(shape) or bidx >= len(step):
            raise ValidationError(
                f"Batch axis index {bidx} (from axes {axes!r}) is out of range for shape {shape} "
                f"with step {step}."
            )

        if step[bidx] != 0:
            raise ValidationError(
                "Input shape step has to be zero in the batch dimension (the batch dimension can always be "
                "increased, but `step` should specify how to increase the minimal shape to find the largest "
                "single batch shape)")

        if shape[bidx] != 1:
            raise ValidationError(
                "Input shape has to be 1 in the batch dimension b.")
Exemplo n.º 3
0
 def test_output_shape(self):
     # A mapping with reference_tensor/scale/offset fed through the shape union
     # is expected to come back as an equivalent raw ImplicitOutputShape node.
     payload = dict(reference_tensor="in1", scale=[1, 2, 3], offset=[0, 1, 3])
     expected = raw_nodes.ImplicitOutputShape(**payload)
     shape_field = fields.Union(
         [fields.ExplicitShape(), fields.Nested(schema.ImplicitOutputShape())],
         required=True,
     )
     actual = shape_field.deserialize(payload)
     assert actual == expected
Exemplo n.º 4
0
class OutputTensor(_TensorBase):
    """Output tensor schema: a required `shape`, optional `halo` and `postprocessing` on top of `_TensorBase`."""

    shape = fields.Union(
        [
            fields.ExplicitShape(),
            fields.Nested(
                ImplicitOutputShape(),
                bioimageio_description=
                "In reference to the shape of an input tensor, the shape of the output "
                "tensor is `shape = shape(input_tensor) * scale + 2 * offset`.",
            ),
        ],
        required=True,
        bioimageio_description="Specification of output tensor shape.",
    )
    halo = fields.List(
        fields.Integer(),
        bioimageio_description=
        "The halo to crop from the output tensor (for example to crop away boundary effects or "
        "for tiling). The halo should be cropped from both sides, i.e. `shape_after_crop = shape - 2 * halo`. The "
        "`halo` is not cropped by the bioimage.io model, but is left to be cropped by the consumer software. Use "
        "`shape:offset` if the model output itself is cropped and input and output shapes not fixed.",
    )
    postprocessing = fields.List(
        fields.Nested(Postprocessing()),
        bioimageio_description=
        "Description of how this output should be postprocessed.",
    )
    processing_name = "postprocessing"

    @validates_schema
    def matching_halo_length(self, data, **kwargs):
        """Check that `halo`, when given, has exactly one entry per dimension of `shape`.

        Args:
            data: The deserialized tensor data; reads the "shape" and "halo" entries.

        Raises:
            ValidationError: if `halo` is given but `shape` is missing or their
                lengths differ.
            NotImplementedError: if `shape` is neither a list nor a
                `raw_nodes.ImplicitOutputShape`.
        """
        shape = data.get("shape")
        halo = data.get("halo")
        if halo is None:
            return

        # Must be checked before the type dispatch: None is neither a list nor an
        # ImplicitOutputShape and would otherwise fall through to NotImplementedError.
        if shape is None:
            raise ValidationError(
                f"halo {halo} has to have same length as shape {shape}!")

        if not isinstance(shape, (list, raw_nodes.ImplicitOutputShape)):
            raise NotImplementedError(type(shape))

        # NOTE(review): relies on ImplicitOutputShape supporting len(), as the
        # original code did — confirm against raw_nodes.
        if len(halo) != len(shape):
            raise ValidationError(
                f"halo {halo} has to have same length as shape {shape}!")
Exemplo n.º 5
0
class OutputTensor(_TensorBase):
    """Output tensor schema: a required `shape`, optional `halo` and `postprocessing` on top of `_TensorBase`."""

    shape = fields.Union(
        [
            fields.ExplicitShape(),
            fields.Nested(
                ImplicitOutputShape(),
                bioimageio_description="In reference to the shape of an input tensor, the shape of the output "
                "tensor is `shape = shape(input_tensor) * scale + 2 * offset`.",
            ),
        ],
        required=True,
        bioimageio_description="Specification of output tensor shape.",
    )
    halo = fields.List(
        fields.Integer(),
        # Passed as a callable so that `get_ref_url` is evaluated lazily.
        bioimageio_description=lambda: "Hint to describe the potentially corrupted edge region of the output tensor, due to "
        "boundary effects. "
        "The `halo` is not cropped by the bioimage.io model, but is left to be cropped by the consumer software. "
        "An example implementation of prediction with tiling, accounting for the halo can be found [here]("
        f"{get_ref_url('function', '_predict_with_tiling_impl', 'https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/prediction.py')}). "
        "Use `shape:offset` if the model output itself is cropped and input and output shapes not fixed. ",
    )
    postprocessing = fields.List(
        fields.Nested(Postprocessing()),
        bioimageio_description="Description of how this output should be postprocessed.",
    )
    processing_name = "postprocessing"

    @validates_schema
    def matching_halo_length(self, data, **kwargs):
        """Check that `halo`, when given, has exactly one entry per dimension of `shape`.

        Args:
            data: The deserialized tensor data; reads the "shape" and "halo" entries.

        Raises:
            ValidationError: if `halo` is given but `shape` is missing or their
                lengths differ.
            NotImplementedError: if `shape` is neither a list nor a
                `raw_nodes.ImplicitOutputShape`.
        """
        # `.get` instead of `data["shape"]`: a missing shape (e.g. partial loading)
        # should be reported as a validation error, not crash with KeyError.
        shape = data.get("shape")
        halo = data.get("halo")
        if halo is None:
            return

        if shape is None:
            raise ValidationError(f"halo {halo} has to have same length as shape {shape}!")

        if not isinstance(shape, (list, raw_nodes.ImplicitOutputShape)):
            raise NotImplementedError(type(shape))

        # NOTE(review): relies on ImplicitOutputShape supporting len(), as the
        # original code did — confirm against raw_nodes.
        if len(halo) != len(shape):
            raise ValidationError(f"halo {halo} has to have same length as shape {shape}!")
Exemplo n.º 6
0
 def test_explicit_input_shape(self):
     # A plain list of ints is an explicit shape and should deserialize unchanged.
     shape = [1, 2, 3]
     deserialized = fields.ExplicitShape().deserialize(shape)
     assert shape == deserialized