def safe_divide(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Divide a and b safely, set the output to zero where the divisor b is zero.

    The division is computed against a divisor whose zero entries have been
    replaced by one, so no inf/nan is ever produced in the masked-out
    positions. This keeps both the forward result and the autograd gradients
    finite: the gradient w.r.t. ``a`` and ``b`` is zero wherever ``b == 0``.

    Parameters
    ----------
    a : torch.Tensor
        Named dividend. Its names define the names and dimension order of
        the output.
    b : torch.Tensor
        Named divisor. It is aligned to ``a`` with ``align_as``, so its
        names must be a subset of ``a``'s names.

    Returns
    -------
    torch.Tensor:
        The element-wise quotient ``a / b`` (broadcast over the aligned
        dimensions), with zeros wherever ``b`` is zero, named like ``a``.
    """
    assert_named(a)
    assert_named(b)
    b = b.align_as(a)

    # After align_as, both tensors have a's dimension order (b may have
    # size-1 dims), so positional broadcasting on the unnamed views matches
    # name-based broadcasting. Unnamed views are used because several ops
    # (e.g. torch.where / masked_fill) do not support named tensors.
    numerator = a.rename(None)
    denominator = b.rename(None)
    is_zero = denominator == 0

    # Dividing by the raw divisor and masking afterwards leaves inf/nan in the
    # discarded entries; autograd then multiplies a zero upstream gradient by
    # those values and yields NaN gradients. Substituting 1 avoids that.
    quotient = numerator / denominator.masked_fill(is_zero, 1)
    result = quotient.masked_fill(is_zero, 0)
    return result.refine_names(*a.names)