def _test_roll(test_case, device):
    """Check that ``flow.roll`` matches ``torch.roll`` in forward and backward.

    A random float32 tensor of shape (2, 3, 5, 10, 20) is rolled along
    dims (0, 2, 3, 4) by random integer shifts with PyTorch (the reference)
    and with OneFlow. The outputs and the input gradients are then compared
    for exact equality.

    Args:
        test_case: object providing ``assertTrue`` (presumably a
            ``unittest.TestCase`` -- confirm against the caller).
        device: device specifier forwarded to both the torch and the flow
            tensor constructors.
    """
    input_shape = (2, 3, 5, 10, 20)
    roll_dims = (0, 2, 3, 4)

    # Reference input and the upstream gradient fed to backward().
    ref_input = torch.rand(
        input_shape, device=device, dtype=torch.float32, requires_grad=True
    )
    upstream = torch.rand_like(ref_input, device=device)

    # One shift per rolled dim, drawn from [-100, 100); the bounds exceed
    # every dimension size, so wrap-around shifts are drawn as well.
    shifts = tuple(np.random.randint(-100, 100) for _ in roll_dims)

    # PyTorch reference: forward + backward.
    ref_output = torch.roll(ref_input, shifts, roll_dims)
    ref_output.backward(upstream)

    # OneFlow under test, fed the very same input and upstream gradient.
    flow_input = flow.tensor(
        ref_input.detach().cpu().numpy(),
        device=device,
        dtype=flow.float32,
        requires_grad=True,
    )
    flow_output = flow.roll(flow_input, shifts, roll_dims)
    flow_upstream = flow.tensor(
        upstream.cpu().numpy(), device=device, dtype=flow.float32
    )
    flow_output.backward(flow_upstream)

    # roll only moves elements around, so results must match bit-for-bit.
    forward_matches = np.array_equal(
        flow_output.numpy(), ref_output.detach().cpu().numpy()
    )
    test_case.assertTrue(forward_matches)

    backward_matches = np.array_equal(
        flow_input.grad.numpy(), ref_input.grad.cpu().numpy()
    )
    test_case.assertTrue(backward_matches)
def forward(self, x):
    """Apply windowed attention (optionally cyclically shifted) and the MLP.

    Args:
        x: input whose shape unpacks as ``(B, L, C)``, where ``L`` must
            equal ``H * W`` for ``(H, W) = self.input_resolution``.

    Returns:
        The result of the two residual additions, viewed as
        ``(B, H * W, C)`` before the MLP branch is added.

    Raises:
        AssertionError: if ``L != H * W``.
    """
    height, width = self.input_resolution
    batch, seq_len, channels = x.shape
    assert seq_len == height * width, "input feature has wrong size"

    residual = x
    win = self.window_size
    shift = self.shift_size

    # Normalize, then lay the token sequence out as a 2-D grid.
    grid = self.norm1(x).view(batch, height, width, channels)

    # Cyclic shift along the two spatial dims (skipped when shift_size <= 0).
    if shift > 0:
        grid = flow.roll(grid, shifts=(-shift, -shift), dims=(1, 2))

    # Partition into windows: nW*B, window_size, window_size, C
    windows = window_partition(grid, win)
    # Flatten each window: nW*B, window_size*window_size, C
    windows = windows.view(-1, win * win, channels)

    # W-MSA / SW-MSA: nW*B, window_size*window_size, C
    attended = self.attn(windows, mask=self.attn_mask)

    # Merge the windows back into a grid: B H' W' C
    attended = attended.view(-1, win, win, channels)
    merged = window_reverse(attended, win, height, width)

    # Undo the cyclic shift applied above.
    if shift > 0:
        merged = flow.roll(merged, shifts=(shift, shift), dims=(1, 2))

    tokens = merged.view(batch, height * width, channels)

    # Residual around attention, then residual around the MLP (FFN).
    out = residual + self.drop_path(tokens)
    return out + self.drop_path(self.mlp(self.norm2(out)))
def _roll(self, shifts, dims=None):
    """Method-style wrapper that forwards to ``flow.roll``.

    Args:
        shifts: passed through as the ``shifts`` keyword of ``flow.roll``.
        dims: passed through as the ``dims`` keyword; defaults to ``None``.

    Returns:
        Whatever ``flow.roll(self, shifts=shifts, dims=dims)`` returns.
    """
    rolled = flow.roll(self, shifts=shifts, dims=dims)
    return rolled
def test_roll_runtime_error(test_case):
    """Verify ``flow.roll`` rejects mismatched ``shifts`` and ``dims`` lengths.

    Two shifts are supplied for a single dim; the call must raise, and the
    exception text must mention that shifts and dimensions must align.

    Args:
        test_case: object providing ``assertRaises`` and ``assertIn``
            (presumably a ``unittest.TestCase`` -- confirm against the caller).
    """
    with test_case.assertRaises(Exception) as context:
        x = flow.ones((2, 2), dtype=flow.float32, requires_grad=True)
        # len(shifts) == 2 but len(dims) == 1: expected to raise here.
        flow.roll(x, [0, 1], [0])

    # Checked after the ``with`` block: anything placed after the raising
    # call inside the block would never execute. assertIn (instead of
    # assertTrue(a in b)) reports both strings when the check fails.
    test_case.assertIn(
        "shifts and dimensions must align", str(context.exception)
    )