Example #1
0
def crop(
    x: torch.Tensor,
    origin: Tuple[int, int],
    size: Union[Tuple[int, int], int],
    vert_anchor: str = "top",
    horz_anchor: str = "left",
) -> torch.Tensor:
    """Cut a rectangular window out of the third and fourth dimension of ``x``.

    Args:
        x: Tensor that is indexed with four indices; the first two dimensions
            are kept whole.
        origin: ``(vert, horz)`` position the window is anchored at.
        size: Window size, resolved to ``(height, width)`` by ``_parse_size``.
        vert_anchor: ``"top"`` if the window starts at the vertical origin,
            ``"bottom"`` if it ends there.
        horz_anchor: ``"left"`` if the window starts at the horizontal origin,
            ``"right"`` if it ends there.

    Returns:
        ``x`` sliced to the selected window.
    """
    verify_str_arg(vert_anchor, "vert_anchor", ("top", "bottom"))
    verify_str_arg(horz_anchor, "horz_anchor", ("left", "right"))

    vert_origin, horz_origin = origin
    height, width = _parse_size(size)

    def window(position: int, extent: int, starts_at_origin: bool) -> slice:
        # The window either begins at the origin or ends right before it.
        if starts_at_origin:
            return slice(position, position + extent)
        return slice(position - extent, position)

    rows = window(vert_origin, height, vert_anchor == "top")
    cols = window(horz_origin, width, horz_anchor == "left")
    return x[:, :, rows, cols]
Example #2
0
def _reduce(x: torch.Tensor, reduction: str) -> torch.Tensor:
    """Apply the named reduction to ``x``.

    Args:
        x: Tensor to reduce.
        reduction: ``"mean"`` or ``"sum"`` reduce over all elements with
            ``torch.mean`` / ``torch.sum``; ``"none"`` returns ``x`` as is.

    Returns:
        The reduced tensor, or ``x`` itself for ``"none"``.
    """
    verify_str_arg(reduction, "reduction", ("mean", "sum", "none"))

    reducers = {"mean": torch.mean, "sum": torch.sum}
    reducer = reducers.get(reduction)
    # No registered reducer means reduction == "none".
    return x if reducer is None else reducer(x)
Example #3
0
    def _parse_loss_weights(loss_weights: Union[str, Sequence[float]],
                            num_losses: int) -> Sequence[float]:
        if isinstance(loss_weights, str):
            verify_str_arg(loss_weights, "loss_weights", ("sum", "mean"))
            if loss_weights == "sum":
                return [1.0] * num_losses
            else:  # loss_weights == "mean":
                return [1.0 / num_losses] * num_losses
        else:
            if len(loss_weights) == num_losses:
                return loss_weights

            raise ValueError(
                f"The length of the loss weights and the number of losses do not "
                f"match: {len(loss_weights)} != {num_losses}")
Example #4
0
def test_verify_str_arg():
    """Check rejection and pass-through behavior of ``misc.verify_str_arg``."""
    # Passing None instead of a string has to raise.
    with pytest.raises(ValueError):
        misc.verify_str_arg(None)

    # A string that is not among the valid choices has to raise.
    with pytest.raises(ValueError):
        misc.verify_str_arg("foo", valid_args=("bar", "baz"))

    # A valid string is handed back to the caller.
    accepted = "foo"
    assert misc.verify_str_arg(accepted, valid_args=("foo", "bar")) == accepted
Example #5
0
    def _parse_op_weights(
        op_weights: Union[str, Sequence[float]], num_ops: int
    ) -> Sequence[float]:
        if isinstance(op_weights, str):
            verify_str_arg(op_weights, "op_weights", ("sum", "mean"))
            if op_weights == "sum":
                return [1.0] * num_ops
            else:  # op_weights == "mean":
                return [1.0 / num_ops] * num_ops
        else:
            if len(op_weights) == num_ops:
                return op_weights

            msg = (
                f"The length of the operator weights and the number of operators do "
                f"not match: {len(op_weights)} != {num_ops}"
            )
            raise ValueError(msg)
Example #6
0
def propagate_guide(
    module: nn.Module,
    guide: torch.Tensor,
    method: str = "simple",
    allow_empty: bool = False,
) -> torch.Tensor:
    """Propagate ``guide`` through ``module``.

    Convolution modules are handled by ``_conv_guide`` and pooling modules by
    ``_pool_guide``; for every other module the guide is left untouched.

    Args:
        module: Module the guide is propagated through.
        guide: Guide to propagate.
        method: One of ``"simple"``, ``"inside"``, or ``"all"``. Only forwarded
            to the convolution handling.
        allow_empty: If ``True``, a guide without any truthy entry is returned
            instead of raising.

    Returns:
        The propagated guide.

    Raises:
        RuntimeError: If the propagated guide has no truthy entry and
            ``allow_empty`` is ``False``.
    """
    verify_str_arg(method, "method", ("simple", "inside", "all"))

    if is_conv_module(module):
        guide = _conv_guide(cast(ConvModule, module), guide, method)
    elif is_pool_module(module):
        guide = _pool_guide(cast(PoolModule, module), guide)

    # The emptiness check is skipped entirely when empty guides are allowed.
    if not allow_empty and not torch.any(guide.bool()):
        raise RuntimeError(
            f"Guide has no longer any entries after propagation through "
            f"{module.__class__.__name__}({module.extra_repr()}). If this is valid, "
            f"set allow_empty=True."
        )

    return guide
Example #7
0
def image_to_edge_size(image_size: Tuple[int, int],
                       edge: str = "short") -> int:
    """Extract the size of one edge from ``image_size``.

    Args:
        image_size: Image size as ``(height, width)``.
        edge: ``"short"`` / ``"long"`` select the smaller / larger entry,
            ``"vert"`` / ``"horz"`` select the first / second entry.

    Returns:
        Size of the requested edge.
    """
    edge = verify_str_arg(edge, "edge", ("short", "long", "vert", "horz"))

    if edge == "short":
        return min(image_size)
    if edge == "long":
        return max(image_size)

    # Only the absolute edges remain: "vert" is stored first, "horz" second.
    index = 0 if edge == "vert" else 1
    return image_size[index]
Example #8
0
def edge_to_image_size(edge_size: int,
                       aspect_ratio: float,
                       edge: str = "short") -> Tuple[int, int]:
    """Compute the image size from the size of one edge and the aspect ratio.

    Args:
        edge_size: Size of the edge selected by ``edge``.
        aspect_ratio: Ratio of the horizontal to the vertical edge, i.e.
            ``width / height``.
        edge: Edge that ``edge_size`` refers to. ``"vert"`` and ``"horz"``
            name the edge directly, ``"short"`` and ``"long"`` are resolved
            with the help of ``aspect_ratio``.

    Returns:
        Image size as ``(height, width)``.
    """
    edge = verify_str_arg(edge, "edge", ("short", "long", "vert", "horz"))

    if edge in ("vert", "horz"):
        edge_is_vert = edge == "vert"
    else:
        # The short edge is the vertical one unless the image is taller than
        # wide (aspect_ratio < 1.0); for the long edge it is the other way round.
        edge_is_vert = (edge == "short") ^ (aspect_ratio < 1.0)

    if edge_is_vert:
        # height is given: width = height * (width / height)
        return edge_size, round(edge_size * aspect_ratio)

    # width is given: height = width / (width / height). The "horz" branch
    # previously multiplied here, which contradicted every other branch.
    return round(edge_size / aspect_ratio), edge_size
Example #9
0
 def __init__(self, edge_size: int, num_steps: int, edge: str) -> None:
     """Store the edge size, the number of steps, and the validated edge.

     ``edge`` is passed through ``verify_str_arg`` with ``"short"`` and
     ``"long"`` as the valid choices; the value it returns is stored.
     """
     valid_edges = ("short", "long")
     self.edge_size, self.num_steps = edge_size, num_steps
     self.edge = verify_str_arg(edge, "edge", valid_edges)
Example #10
0
def _resize_canvas(
    transform_matrix: torch.Tensor,
    image_size: Tuple[int, int],
    method: str = "same",
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Adjust a motif transformation and its canvas size according to ``method``.

    With ``method="same"`` both inputs are returned unchanged. Otherwise the
    transformation is re-centered, the four image corners are mapped through
    it, a bounding box size is derived from the mapped corners (``"full"`` and
    ``"valid"`` use different helpers), and finally a scaling by
    ``image_size / bounding_box_size`` and a translation built from the image
    center are multiplied in front of the transformation.

    Args:
        transform_matrix: Matrix that is multiplied from the left onto
            homogeneous ``(horz, vert, 1)`` column vectors.
            NOTE(review): presumably 3x3 -- confirm against the callers.
        image_size: Canvas size as ``(height, width)``.
        method: One of ``"same"``, ``"full"``, or ``"valid"``.

    Returns:
        The adjusted transformation matrix together with the canvas size:
        ``image_size`` for ``"same"``, the bounding box size otherwise.
    """
    verify_str_arg(method, "method", ("same", "full", "valid"))

    if method == "same":
        return transform_matrix, image_size

    def recenter(matrix: torch.Tensor,
                 size: Tuple[int, int]) -> torch.Tensor:
        # Map the image center through the matrix and prepend a translation
        # matrix built from the mapped center with inverse=True.
        center = _calculate_image_center(size)
        center_column = torch.tensor((*center[::-1], 1.0)).unsqueeze(1)
        mapped = torch.mm(matrix, center_column)
        mapped_center = cast(Tuple[float, float],
                             mapped[:-1, 0].tolist()[::-1])

        undo_shift = _create_motif_translation_matrix(mapped_center,
                                                      inverse=True)
        return torch.mm(undo_shift, matrix)

    def project_corners(matrix: torch.Tensor,
                        size: Tuple[int, int]) -> torch.Tensor:
        # Map the four image corners and drop the homogeneous row.
        rows, cols = size
        # TODO: do this without transpose
        corners = torch.tensor((
            (0.0, 0.0, 1.0),
            (cols, 0.0, 1.0),
            (0.0, rows, 1.0),
            (cols, rows, 1.0),
        )).t()
        projected = torch.mm(matrix, corners)
        return projected[:-1, :]

    def rescale_and_shift(
        matrix: torch.Tensor,
        size: Tuple[int, int],
        box_size: Tuple[int, int],
    ) -> torch.Tensor:
        # Prepend the scaling (image size over bounding box size, applied
        # around the image center) and the translation built from the center.
        rows, cols = size
        center = _calculate_image_center(size)

        box_rows, box_cols = box_size
        scaling = _create_motif_scaling_matrix((rows / box_rows,
                                                cols / box_cols))
        scaling = _transform_around_point(center, scaling)

        shift = _create_motif_translation_matrix(center)

        combined = torch.chain_matmul(scaling, shift, matrix)
        return cast(torch.Tensor, combined)

    centered_matrix = recenter(transform_matrix, image_size)
    mapped_corners = project_corners(centered_matrix, image_size)

    # Anything but "full" is "valid" at this point.
    measure_bounding_box = (_calculate_full_bounding_box_size
                            if method == "full" else
                            _calculate_valid_bounding_box_size)
    bounding_box_size = measure_bounding_box(mapped_corners)

    resized_matrix = rescale_and_shift(centered_matrix, image_size,
                                       bounding_box_size)
    return resized_matrix, bounding_box_size