Example #1
0
def _test_roll(test_case, device):
    """Check ``flow.roll`` against ``torch.roll`` in both forward and backward.

    A random float32 tensor of shape (2, 3, 5, 10, 20) is rolled along dims
    (0, 2, 3, 4) by shifts drawn from ``np.random.randint(-100, 100)``, so a
    shift may be negative or larger than the size of the dimension it rolls.
    The same input data and upstream gradient are mirrored into ``flow``
    tensors; the rolled output and the input gradient must match PyTorch
    exactly (roll only permutes elements, so exact equality is expected).

    Args:
        test_case: unittest-style object providing ``assertTrue``.
        device: device string forwarded to both frameworks.
    """
    shape = (2, 3, 5, 10, 20)
    ref_input = torch.rand(
        shape, device=device, dtype=torch.float32, requires_grad=True
    )
    ref_upstream = torch.rand_like(ref_input, device=device)

    # One random shift per rolled dimension, drawn by sequential randint calls.
    dims = (0, 2, 3, 4)
    shifts = tuple(np.random.randint(-100, 100) for _ in range(len(dims)))

    # Reference result and gradient from PyTorch.
    ref_output = torch.roll(ref_input, shifts, dims)
    ref_output.backward(ref_upstream)

    # Same data, same shifts, through flow.
    flow_input = flow.tensor(
        ref_input.detach().cpu().numpy(),
        device=device,
        dtype=flow.float32,
        requires_grad=True,
    )
    flow_output = flow.roll(flow_input, shifts, dims)
    flow_upstream = flow.tensor(
        ref_upstream.cpu().numpy(), device=device, dtype=flow.float32
    )
    flow_output.backward(flow_upstream)

    outputs_match = np.array_equal(
        flow_output.numpy(), ref_output.detach().cpu().numpy()
    )
    test_case.assertTrue(outputs_match)
    grads_match = np.array_equal(
        flow_input.grad.numpy(), ref_input.grad.cpu().numpy()
    )
    test_case.assertTrue(grads_match)
Example #2
0
    def forward(self, x):
        """Run one window-attention block (W-MSA or SW-MSA) followed by the MLP.

        ``x`` is unpacked as ``(B, L, C)`` and must satisfy ``L == H * W`` with
        ``(H, W) = self.input_resolution``. When ``self.shift_size > 0`` the
        (H, W) grid is cyclically rolled by ``-shift_size`` on both spatial
        axes before window attention and rolled back by ``+shift_size``
        afterwards. Residual connections wrap both the attention branch and
        the MLP branch, each passed through ``self.drop_path``.

        Returns:
            The updated features. The attention branch is reshaped back to
            ``(B, H * W, C)`` before being added to the input.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shift = self.shift_size
        ws = self.window_size
        residual = x

        # Normalize, then lay the tokens out on the (H, W) grid.
        grid = self.norm1(x).view(B, H, W, C)

        # Cyclic shift, applied only for shifted-window blocks.
        if shift > 0:
            grid = flow.roll(grid, shifts=(-shift, -shift), dims=(1, 2))

        # Partition into windows and flatten each window's tokens.
        windows = window_partition(grid, ws)  # nW*B, ws, ws, C
        windows = windows.view(-1, ws * ws, C)  # nW*B, ws*ws, C

        # Attention within each window, passing the block's attention mask.
        windows = self.attn(windows, mask=self.attn_mask)  # nW*B, ws*ws, C

        # Merge the windows back into the spatial grid.
        windows = windows.view(-1, ws, ws, C)
        grid = window_reverse(windows, ws, H, W)  # B H' W' C

        # Undo the cyclic shift.
        if shift > 0:
            grid = flow.roll(grid, shifts=(shift, shift), dims=(1, 2))

        attn_out = grid.view(B, H * W, C)

        # Residual around attention, then residual around the MLP.
        out = residual + self.drop_path(attn_out)
        return out + self.drop_path(self.mlp(self.norm2(out)))
Example #3
0
def _roll(self, shifts, dims=None):
    """Tensor-method form of ``flow.roll``: roll ``self`` by ``shifts``.

    ``shifts`` and ``dims`` are forwarded unchanged as keyword arguments;
    ``dims=None`` is passed through as-is.
    """
    rolled = flow.roll(self, shifts=shifts, dims=dims)
    return rolled
 def test_roll_runtime_error(test_case):
     """Check that ``flow.roll`` rejects ``shifts`` and ``dims`` of different lengths.

     Two shifts ([0, 1]) are passed with only one dim ([0]). The raised
     exception's message must contain "shifts and dimensions must align".
     """
     x = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)
     # Only the call expected to fail sits inside the assertRaises block, so an
     # error while building the input cannot be mistaken for the one under test.
     with test_case.assertRaises(Exception) as context:
         flow.roll(x, [0, 1], [0])
     test_case.assertIn(
         "shifts and dimensions must align", str(context.exception))