def test_splatting_type_values(self):
    """Check the numeric result of every splatting type on a 1x2 image.

    The left pixel (value 1, importance 1) is displaced by +1 along the
    width axis onto the right pixel (value 2, importance 2), so the right
    output pixel mixes both contributions according to the splatting type.
    """
    values = torch.tensor([1, 2], dtype=torch.float32).reshape([1, 1, 1, 2])
    motion = torch.zeros([1, 2, 1, 2], dtype=torch.float32)
    motion[0, 0, 0, 0] = 1  # shift only the left pixel by one column
    weights = torch.tensor([1, 2], dtype=torch.float32).reshape([1, 1, 1, 2])

    # (splatting type, extra positional args, expected value at the right pixel)
    expectations = [
        ("summation", (), 3),
        ("average", (), 1.5),
        ("linear", (weights,), 5.0 / 3.0),
        ("softmax", (weights,), (exp(1) + 2 * exp(2)) / (exp(1) + exp(2))),
    ]
    for kind, extra, expected in expectations:
        result = splatting.splatting_function(kind, values, motion, *extra)
        assert result[0, 0, 0, 1] == pytest.approx(expected)
def test_splatting_type_names(self):
    """Every supported splatting type name is accepted; an unknown one raises."""
    frame = torch.zeros(1, 1, 3, 3)
    flow = torch.zeros(1, 2, 3, 3)
    importance_metric = torch.ones_like(frame)

    supported = (
        ("summation", False),
        ("average", False),
        ("linear", True),
        ("softmax", True),
    )
    for name, needs_metric in supported:
        args = (frame, flow, importance_metric) if needs_metric else (frame, flow)
        splatting.splatting_function(name, *args)

    with pytest.raises(NotImplementedError):
        splatting.splatting_function("something_else", frame, flow, importance_metric)
def test_flow_one(self):
    """A flow of 1 in both components moves the top-left pixel to (1, 1)."""
    source = torch.zeros(1, 1, 3, 3)
    source[0, :, 0, 0] = 1
    displacement = torch.zeros(1, 2, 3, 3)
    displacement[0, :, 0, 0] = 1
    expected = torch.zeros(1, 1, 3, 3)
    expected[0, :, 1, 1] = 1
    assert torch.equal(
        splatting.splatting_function("summation", source, displacement), expected
    )
def render_forward(src_ims, src_dms, Rcam, tcam, K_src, K_dst):
    """Forward-warp source images into a target view by depth-weighted splatting.

    Each source pixel is back-projected with its depth through ``K_src``,
    moved by the rigid transform ``(Rcam, tcam)``, re-projected with
    ``K_dst`` and splatted (summation splatting) at the resulting position.
    Colours are weighted by an exponential, softmax-like importance derived
    from the inverse target depth, so nearer points dominate. The weighted
    colour sum is then normalised by the splatted weights.

    Args:
        src_ims: Source images indexed as (batch, height, width, channels).
        src_dms: Source depth maps, (batch, height, width).
        Rcam: Rotation applied to back-projected points; presumably (3, 3).
        tcam: Translation added after rotation; presumably (3,).
        K_src: Source intrinsics; presumably a single (3, 3) matrix.
        K_dst: Target intrinsics; presumably a single (3, 3) matrix.

    Returns:
        The rendered image as (height, width, channels). The splatted
        contributions of all batch entries are summed into this one image.

    Raises:
        AssertionError: If the input ranks or spatial sizes do not match.
    """
    assert len(src_ims.shape) == 4
    assert len(src_dms.shape) == 3
    assert src_ims.shape[1:3] == src_dms.shape[1:3], (src_ims.shape, src_dms.shape)

    # Put every operand on the images' device so poses, intrinsics and
    # depths may be supplied on another device without a matmul crash.
    device = src_ims.device
    src_dms = src_dms.to(device=device)
    K_src = K_src.to(device=device)
    K_dst = K_dst.to(device=device)
    R = Rcam.to(device=device)[None]
    t = tcam.to(device=device)[None][..., None]

    # Homogeneous pixel grid: coord[0, v, u] == (u, v, 1).
    height, width = src_ims.shape[1:3]
    grid = np.stack(np.meshgrid(np.arange(width), np.arange(height)), -1)
    grid = np.concatenate((grid, np.ones_like(grid)[:, :, [0]]), -1)  # z=1
    coord = torch.as_tensor(
        grid.astype(np.float32), dtype=K_src.dtype, device=device
    )[None]  # (1, h, w, 3)

    # One viewing ray per pixel, scaled by depth afterwards. This avoids
    # building a (b, h, w, 3, 3) matrix that ``D * K_inv @ coord`` would
    # materialise, because ``*`` and ``@`` associate left-to-right.
    rays = K_src.inverse()[None, None, None] @ coord[..., None]  # (1, h, w, 3, 1)
    cam_points = src_dms[:, :, :, None, None] * rays  # (b, h, w, 3, 1)
    points = K_dst[None, None, None] @ (R[:, None, None] @ cam_points + t[:, None, None])
    points = points.squeeze(-1)  # (b, h, w, 3)

    new_z = points[..., [2]].clone().permute(0, 3, 1, 2)  # (b, 1, h, w)
    points = points / torch.clamp(points[..., [2]], 1e-8, None)
    # Pixel displacement from source to target position; the z row is dropped.
    flow = (points - coord).permute(0, 3, 1, 2)[:, :2]  # (b, 2, h, w)

    # NOTE(review): new_z <= 0 (behind the target camera) gives non-positive
    # or infinite importance here; confirm callers never hit that case.
    alpha = 0.5
    importance = alpha / new_z
    importance_min = importance.amin((1, 2, 3), keepdim=True)
    importance_max = importance.amax((1, 2, 3), keepdim=True)
    # Rescale to [-10, 0] per batch entry before exp, giving weights in (0, 1].
    importance = (
        (importance - importance_min) / (importance_max - importance_min + 1e-6) * 10 - 10
    ).exp()

    # Splat weighted colours and the weights together in one pass, then
    # divide the sums to get a weighted average per target pixel.
    src_ims = src_ims.permute(0, 3, 1, 2)
    input_data = torch.cat([importance * src_ims, importance], 1)
    output_data = splatting_function("summation", input_data, flow)
    num = output_data[:, :-1].sum(dim=0, keepdim=True)
    den = output_data[:, -1:].sum(dim=0, keepdim=True)
    rendered = num / (den + 1e-7)
    return rendered.permute(0, 2, 3, 1)[0]
def test_importance_metric_type_and_shape(self):
    """Importance metrics must be absent or present and shaped as the type requires.

    "summation" and "average" must reject a metric. "linear" and "softmax"
    must require one, and must reject metrics whose batch, channel, height
    or width size does not match the (1, 1, 3, 3) frame.
    """
    frame = torch.ones([1, 1, 3, 3])
    flow = torch.zeros([1, 2, 3, 3])
    importance_metric = frame.new_ones([1, 1, 3, 3])
    # Each shape differs from the valid metric in exactly one dimension.
    bad_metrics = [
        frame.new_ones(shape)
        for shape in ([2, 1, 3, 3], [1, 2, 3, 3], [1, 1, 2, 3], [1, 1, 3, 2])
    ]

    for kind in ("summation", "average"):
        splatting.splatting_function(kind, frame, flow)
        with pytest.raises(AssertionError):
            splatting.splatting_function(kind, frame, flow, importance_metric)

    for kind in ("linear", "softmax"):
        splatting.splatting_function(kind, frame, flow, importance_metric)
        with pytest.raises(AssertionError):
            splatting.splatting_function(kind, frame, flow)
        for bad in bad_metrics:
            with pytest.raises(AssertionError):
                splatting.splatting_function(kind, frame, flow, bad)
def run_test_backward(method, batch_size, spatial_size, flow_init, repetitions=10):
    """Time one backward-pass implementation and print its mean duration.

    Args:
        method: Backward path to time. ``"splatting_cpu"`` and
            ``"splatting_cuda"`` call the extension's backward kernels
            directly. ``"splatting_function"``,
            ``"splatting_function_summation"``,
            ``"splatting_module_summation"`` and ``"splatting_module_softmax"``
            build an autograd graph once and time ``backward`` on its sum.
        batch_size: Batch size of the all-ones 3-channel frame.
        spatial_size: Height and width of the square frame and flow.
        flow_init: ``"zeros"`` or ``"ones"``; the fill value of the flow.
        repetitions: Number of timed calls averaged into the printed time.

    Raises:
        NotImplementedError: For an unknown ``flow_init`` or ``method``.
    """
    frame = torch.ones(batch_size, 3, spatial_size, spatial_size)
    flow_shape = (batch_size, 2, spatial_size, spatial_size)
    if flow_init == "zeros":
        flow = torch.zeros(*flow_shape)
    elif flow_init == "ones":
        flow = torch.ones(*flow_shape)
    else:
        raise NotImplementedError

    autograd_methods = (
        "splatting_function",
        "splatting_function_summation",
        "splatting_module_summation",
        "splatting_module_softmax",
    )

    if method == "splatting_cpu":
        import splatting.cpu

        # grad_output, grad_frame, grad_flow
        grads = (torch.zeros_like(frame), torch.zeros_like(frame), torch.zeros_like(flow))

        def test_fn():
            splatting.cpu.splatting_backward_cpu(frame, flow, *grads)

    elif method == "splatting_cuda":
        import splatting.cuda

        frame = frame.cuda()
        flow = flow.cuda()
        # grad_output, grad_frame, grad_flow
        grads = (torch.zeros_like(frame), torch.zeros_like(frame), torch.zeros_like(flow))

        def test_fn():
            splatting.cuda.splatting_backward_cuda(frame, flow, *grads)
            # Block until the kernel finishes so the timing covers the work.
            torch.cuda.synchronize()

    elif method in autograd_methods:
        import splatting

        frame.requires_grad_(True)
        flow.requires_grad_(True)
        if method == "splatting_function":
            splatted = splatting.SummationSplattingFunction.apply(frame, flow)
        elif method == "splatting_function_summation":
            splatted = splatting.splatting_function("summation", frame, flow)
        elif method == "splatting_module_summation":
            splatted = splatting.Splatting("summation")(frame, flow)
        else:
            importance_metric = frame.new_ones(
                [frame.shape[0], 1, frame.shape[2], frame.shape[3]]
            )
            splatted = splatting.Splatting("softmax")(frame, flow, importance_metric)
        output = splatted.sum()

        def test_fn():
            # retain_graph allows the same graph to be back-propagated repeatedly.
            output.backward(retain_graph=True)

    else:
        raise NotImplementedError(f"method {method}")

    mean_time = timeit.timeit(test_fn, number=repetitions) / repetitions
    print(
        f"backward \tbatch_size={batch_size}\tspatial_size={spatial_size}\t"
        f"flow_init={flow_init}\tex_time={mean_time}"
    )
# Runs one summation splat of `frame` by `flow`; the result is discarded.
# NOTE(review): `splatting`, `frame` and `flow` are not bound by this
# definition. Presumably this is the timed body nested inside a forward-pass
# benchmark whose enclosing function is outside this view — confirm. If it
# really sits at module level, pytest will collect it as a test, and it will
# fail unless those names exist as module globals.
def test_fn(): splatting.splatting_function("summation", frame, flow)