示例#1
0
def test_zip_equal():
    """Check ``misc.zip_equal`` against the built-in ``zip``.

    Equal-length inputs must pair up exactly like ``zip`` does.
    Inputs of different lengths must make the call raise ``RuntimeError``.
    """
    numbers = (1, 2)
    matching_letters = ("a", "b")
    assert tuple(misc.zip_equal(numbers, matching_letters)) == tuple(
        zip(numbers, matching_letters)
    )

    # One element too many: zip would silently truncate, zip_equal must refuse.
    surplus_letters = ("a", "b", "c")
    with pytest.raises(RuntimeError):
        misc.zip_equal(numbers, surplus_letters)
示例#2
0
 def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
     """Normalize each input with its paired module and join the results channel-wise.

     The i-th tensor in ``inputs`` is passed through the i-th module of
     ``self.norm_modules``. The pairing goes through ``misc.zip_equal``, which
     presumably rejects a mismatch between the number of inputs and modules.
     The normalized tensors are then combined by ``join_channelwise`` along
     ``self.channel_dim``.
     """
     normalized = [
         norm_module(tensor)
         for norm_module, tensor in misc.zip_equal(self.norm_modules, inputs)
     ]
     return join_channelwise(*normalized, channel_dim=self.channel_dim)
示例#3
0
def _extract_patchesnd(
    input: torch.Tensor, patch_sizes: Sequence[int], strides: Sequence[int]
) -> torch.Tensor:
    """Cut every sample of ``input`` into n-dimensional patches, keeping the batch axis.

    Args:
        input: Tensor laid out as ``(batch, channels, *spatial)``. Every dim after
            the first two is treated as spatial and gets unfolded.
        patch_sizes: Patch extent for each spatial dim.
        strides: Step between consecutive patches for each spatial dim.

    Returns:
        Tensor of shape ``(batch, num_patches, channels, *patch_sizes)``, where
        ``num_patches`` is the product of the patch counts along each spatial dim.
    """
    batch_size, num_channels = input.size()[:2]
    spatial_dims = range(2, input.dim())
    num_spatial = len(spatial_dims)

    patches = input
    for dim, patch_size, stride in zip_equal(spatial_dims, patch_sizes, strides):
        # Each unfold replaces `dim` by its patch count and appends a trailing
        # dim of length `patch_size`.
        patches = patches.unfold(dim, patch_size, stride)

    # Layout is now (B, C, n_1..n_k, p_1..p_k). Move the channel dim behind the
    # patch-count dims so each patch's (C, p_1..p_k) block is contiguous.
    window_dims = [dim + num_spatial for dim in spatial_dims]
    patches = patches.permute(0, *spatial_dims, 1, *window_dims).contiguous()
    return patches.view(batch_size, -1, num_channels, *patch_sizes)
示例#4
0
def _extract_patchesnd(
    x: torch.Tensor, patch_sizes: Sequence[int], strides: Sequence[int]
) -> torch.Tensor:
    """Cut ``x`` into n-dimensional patches and stack all of them along dim 0.

    Args:
        x: Tensor laid out as ``(batch, channels, *spatial)``. Every dim after the
            first two is treated as spatial and gets unfolded.
        patch_sizes: Patch extent for each spatial dim.
        strides: Step between consecutive patches for each spatial dim.

    Returns:
        Tensor of shape ``(batch * num_patches, channels, *patch_sizes)``. Patches
        are ordered by sample first, then by position in the patch grid.
    """
    channels = x.size()[1]
    spatial_dims = range(2, x.dim())
    num_spatial = len(spatial_dims)

    for dim, size, step in zip_equal(spatial_dims, patch_sizes, strides):
        x = x.unfold(dim, size, step)

    # (B, C, n_1..n_k, p_1..p_k) -> (B, n_1..n_k, C, p_1..p_k). contiguous() is
    # needed because view() cannot follow the strides left by permute().
    window_dims = [dim + num_spatial for dim in spatial_dims]
    x = x.permute(0, *spatial_dims, 1, *window_dims).contiguous()

    # Fold batch and patch-grid dims into one. The size is spelled out rather
    # than inferred with -1, which would be ambiguous for zero-element tensors.
    leading_sizes = x.size()[: num_spatial + 1]
    return x.view(prod(leading_sizes), channels, *patch_sizes)
示例#5
0
    def build_levels(
        edge_sizes: Sequence[int],
        num_steps: Union[Sequence[int], int],
        edge: Union[Sequence[str], str],
    ) -> Tuple[PyramidLevel, ...]:
        """Build one ``PyramidLevel`` for every entry of ``edge_sizes``.

        Args:
            edge_sizes: Edge size of each level. Its length fixes the number of levels.
            num_steps: Steps per level, or a single ``int`` reused for every level.
            edge: Edge setting per level, or a single ``str`` reused for every level.

        Returns:
            The levels, in the same order as ``edge_sizes``.
        """
        num_levels = len(edge_sizes)

        # A bare str is itself a Sequence, so scalars are detected explicitly
        # and broadcast before the element-wise pairing below.
        steps_per_level = (
            [num_steps] * num_levels if isinstance(num_steps, int) else num_steps
        )
        edge_per_level = [edge] * num_levels if isinstance(edge, str) else edge

        # zip_equal pairs the three sequences element-wise. It presumably raises
        # when their lengths differ.
        levels = [
            PyramidLevel(size, steps, level_edge)
            for size, steps, level_edge in zip_equal(
                edge_sizes, steps_per_level, edge_per_level
            )
        ]
        return tuple(levels)