Example no. 1
    def test_kwargs(self):
        stride = (2, )
        x = make_conv_module(nn.Conv2d, stride=stride)

        actual = meta.conv_module_meta(x, stride=stride)["stride"]
        desired = stride
        assert actual == desired
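
Examples 1 and 4 call a make_conv_module helper that this listing does not show. A minimal sketch that is consistent with both call sites (the helper's signature and defaults are assumptions):

from torch import nn

def make_conv_module(conv_module_cls, **kwargs):
    # Build a 1-channel-in, 1-channel-out convolution; the convolution
    # hyper-parameters come in through the keyword arguments.
    kwargs.setdefault("kernel_size", 1)
    return conv_module_cls(in_channels=1, out_channels=1, **kwargs)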
Example no. 2
    def test_conv_module_meta_kwargs(self):
        stride = (2, )

        x = nn.Conv1d(1, 1, 1, stride=stride)

        actual = meta.conv_module_meta(x, stride=stride)["stride"]
        desired = stride
        self.assertEqual(actual, desired)
Example no. 3
    def test_pool_module_meta_kwargs(self):
        kernel_size = (2, )

        x = nn.MaxPool1d(kernel_size=kernel_size)

        actual = meta.pool_module_meta(x,
                                       kernel_size=kernel_size)["kernel_size"]
        desired = kernel_size
        self.assertEqual(actual, desired)
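
For contrast with conv_module_meta, here is a minimal sketch of what pool_module_meta could look like, assuming it reads the pooling hyper-parameters off the module and lets explicit keyword arguments win (the attribute set is an assumption):

from torch import nn

def pool_module_meta(module: nn.Module, **kwargs) -> dict:
    # Pooling modules carry kernel_size, stride, and padding; keyword
    # arguments override the values read from the module.
    meta = {attr: getattr(module, attr)
            for attr in ("kernel_size", "stride", "padding")}
    meta.update(kwargs)
    return meta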
Example no. 4
    def test_main(self):
        conv_module_meta = {
            "kernel_size": (2, ),
            "stride": (3, ),
            "padding": (4, ),
            "dilation": (5, ),
        }
        x = make_conv_module(nn.Conv1d, **conv_module_meta)

        actual = meta.conv_module_meta(x)
        desired = conv_module_meta
        assert actual == desired
Example no. 5
    def test_conv_module_meta(self):
        conv_module_meta = {
            "kernel_size": (2, ),
            "stride": (3, ),
            "padding": (4, ),
            "dilation": (5, ),
        }

        x = nn.Conv1d(1, 1, **conv_module_meta)

        actual = meta.conv_module_meta(x)
        desired = conv_module_meta
        self.assertDictEqual(actual, desired)
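
Taken together, Examples 1 to 5 pin down the contract of meta.conv_module_meta: it returns a dict with the keys kernel_size, stride, padding, and dilation read from the module, and explicit keyword arguments take precedence. A minimal sketch that satisfies these tests (the library's real implementation may differ):

from torch import nn

def conv_module_meta(module: nn.Module, **kwargs) -> dict:
    # Read the convolution hyper-parameters off the module and let
    # explicit keyword arguments override them.
    meta = {attr: getattr(module, attr)
            for attr in ("kernel_size", "stride", "padding", "dilation")}
    meta.update(kwargs)
    return meta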
Example no. 6
import torch
import torch.nn.functional as F

# ConvModule and conv_module_meta are defined alongside this function.


def _conv_guide(module: ConvModule, guide: torch.Tensor,
                method: str) -> torch.Tensor:
    # TODO: deal with convolution that doesn't preserve the output shape
    if method == "simple":
        # Reuse the guide unchanged; assumes the module preserves the shape.
        return guide

    # Unfold the guide into one column per output position, each covering
    # that position's receptive field (kernel_size, stride, padding, and
    # dilation are taken from the module). Boolean columns keep the ~ below
    # a logical NOT; on PyTorch >= 1.2, ~ on a uint8 tensor is a bitwise
    # NOT and would corrupt the mask.
    meta = conv_module_meta(module)
    guide_unfolded = F.unfold(guide, **meta).bool()

    if method == "inside":
        # Erode: keep only positions whose receptive field lies entirely
        # inside the guide, i.e. zero out every column containing a zero.
        mask = ~torch.all(guide_unfolded, 1, keepdim=True)
        val = False
    else:
        # Dilate: keep every position whose receptive field touches the
        # guide, i.e. fill every column containing a nonzero entry.
        mask = torch.any(guide_unfolded, 1, keepdim=True)
        val = True

    mask, _ = torch.broadcast_tensors(mask, guide_unfolded)
    guide_unfolded[mask] = val

    # Fold the columns back to the guide's spatial size; overlapping
    # receptive fields are summed, so clamp back to a binary {0, 1} range.
    guide_folded = F.fold(guide_unfolded.float(), guide.size()[2:], **meta)
    return torch.clamp(guide_folded, 0.0, 1.0)
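
The unfold/fold round trip above propagates a binary guide through a convolution: each unfolded column corresponds to one output position's receptive field, which is switched wholly off ("inside", an erosion) or wholly on (any other method, a dilation) before being folded back. A usage sketch under assumed shapes (F.unfold only supports 4-D input, hence a 2-D convolution; the module, guide, and the "all" method name are illustrative):

import torch
from torch import nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # shape-preserving conv
guide = torch.zeros(1, 1, 8, 8)
guide[:, :, 2:6, 2:6] = 1.0  # binary region of interest

eroded = _conv_guide(conv, guide, "inside")  # only fully covered fields
dilated = _conv_guide(conv, guide, "all")    # any overlap with the guide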