def linear_sampler(data, offset):
    """Differentiable Temporal-wise Frame Sampling, which is essentially a
    linear interpolation process.

    It gets the feature map which has been split into several groups and
    shifts them by different offsets according to their groups. Then it
    computes the weighted sum along the temporal dimension.

    Args:
        data (torch.Tensor): Split data for certain group in shape
            [N, num_segments, C, H, W]. It does not have to be contiguous.
        offset (torch.Tensor): Data offsets for this group data in shape
            [N, num_segments].

    Returns:
        torch.Tensor: Interpolated data in shape [N, num_segments, C, H, W].

    Raises:
        ImportError: If ``tin_shift`` cannot be imported from ``mmcv.ops``.
    """
    # [N, num_segments, C, H, W]
    n, t, c, h, w = data.shape

    # offset0, offset1: [N, num_segments]
    # The two integer offsets that bracket each fractional offset.
    offset0 = torch.floor(offset).int()
    offset1 = offset0 + 1

    # data, data0, data1: [N, num_segments, C, H * W]
    # ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    # (e.g. permuted or transposed tensors) are accepted; ``view`` would raise
    # before ``contiguous()`` ever had a chance to run.
    data = data.reshape(n, t, c, h * w).contiguous()

    try:
        from mmcv.ops import tin_shift
    except ImportError as err:
        # ``ModuleNotFoundError`` is a subclass of ``ImportError``; chain the
        # original error so the real cause stays visible in the traceback.
        raise ImportError('Failed to import `tin_shift` from `mmcv.ops`. You '
                          'will be unable to use TIN. ') from err

    data0 = tin_shift(data, offset0)
    data1 = tin_shift(data, offset1)

    # weight0, weight1: [N, num_segments]
    # Linear interpolation weights: the closer the fractional offset is to
    # ``offset0``, the larger ``weight0``.
    weight0 = 1 - (offset - offset0.float())
    weight1 = 1 - weight0

    # weight0, weight1:
    # [N, num_segments] -> [N, num_segments, C // num_segments] -> [N, C]
    group_size = offset.shape[1]
    channels_per_group = c // group_size
    weight0 = weight0[:, :, None].repeat(1, 1, channels_per_group)
    weight0 = weight0.view(weight0.size(0), -1)
    weight1 = weight1[:, :, None].repeat(1, 1, channels_per_group)
    weight1 = weight1.view(weight1.size(0), -1)

    # weight0, weight1: [N, C] -> [N, 1, C, 1]
    # Singleton dims broadcast over the segment and spatial axes.
    weight0 = weight0[:, None, :, None]
    weight1 = weight1[:, None, :, None]

    # output: [N, num_segments, C, H * W] -> [N, num_segments, C, H, W]
    output = weight0 * data0 + weight1 * data1
    output = output.view(n, t, c, h, w)

    return output
def _test_tinshift_assert(dtype):
    """Check that ``tin_shift`` raises ``ValueError`` on mismatched shapes.

    The test is skipped when ``tin_shift`` cannot be imported from
    ``mmcv.ops``.

    Args:
        dtype (torch.dtype): Data type the input features are cast to before
            being passed to the op.
    """
    try:
        from mmcv.ops import tin_shift
    except ImportError:
        # ``ImportError`` also covers ``ModuleNotFoundError``, so an extension
        # that exists but fails to load skips the test as well.
        pytest.skip('TINShift op is not successfully compiled')

    inputs = [torch.rand(2, 3, 4, 2), torch.rand(2, 3, 4, 2)]
    shifts = [torch.rand(2, 3), torch.rand(2, 5)]

    for x, shift in zip(inputs, shifts):
        # Honour the requested dtype so each parametrization is exercised.
        x = x.to(device='cuda', dtype=dtype)
        shift = shift.cuda()

        # A ValueError should be raised if ops get inputs with wrong shapes.
        with pytest.raises(ValueError):
            tin_shift(x, shift)
def _test_tinshift_allclose(dtype):
    """Compare ``tin_shift`` outputs and input gradients with reference data.

    For every ``(shift, output, grad)`` triple drawn from the module-level
    ``shifts``, ``outputs`` and ``grads``, the op is run forward and backward
    on the module-level ``inputs`` and both results are compared with the
    references using a relative tolerance of 1e-3.

    The test is skipped when ``tin_shift`` cannot be imported from
    ``mmcv.ops``.

    Args:
        dtype (torch.dtype): Data type of the input tensor given to the op.
    """
    try:
        from mmcv.ops import tin_shift
    except ImportError:
        # ``ImportError`` also covers ``ModuleNotFoundError``, so an extension
        # that exists but fails to load skips the test as well.
        pytest.skip('TINShift op is not successfully compiled')

    # The input is identical for every case, so convert it only once.
    np_input = np.array(inputs)

    for shift, output, grad in zip(shifts, outputs, grads):
        np_shift = np.array(shift)
        np_output = np.array(output)
        np_grad = np.array(grad)

        # A fresh leaf tensor per case keeps gradients from accumulating.
        x = torch.tensor(
            np_input, dtype=dtype, device='cuda', requires_grad=True)
        shift_tensor = torch.tensor(np_shift, device='cuda').int()

        result = tin_shift(x, shift_tensor)
        result.backward(torch.ones_like(result))

        # Cast back to float32 so lower-precision dtypes compare cleanly.
        assert np.allclose(
            result.detach().float().cpu().numpy(), np_output, rtol=1e-3)
        assert np.allclose(
            x.grad.detach().float().cpu().numpy(), np_grad, rtol=1e-3)