def _conv_laplace_3d(tensor):
    """
    Apply the 7-point 3D Laplace stencil to ``tensor`` using 'VALID' convolution.

    The stencil is reshaped to [3, 3, 3, 1, 1]: three spatial axes plus
    input/output channel axes of size 1. It has -6 at the centre, 1 at each of
    the six face neighbours and 0 everywhere else:

        plane 0: [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
        plane 1: [[0, 1, 0], [1, -6, 1], [0, 1, 0]]
        plane 2: [[0, 0, 0], [0, 1, 0], [0, 0, 0]]

    No padding is applied here. According to the original note, padding is
    done explicitly in laplace() before this helper is called.

    :param tensor: tensor whose last axis is treated as channels; each channel
        is convolved separately because the stencil has a single in/out channel
    :return: per-channel Laplace responses concatenated along the last axis
        (with 'VALID' padding, so presumably spatially smaller than the input)
    """
    outer_plane = [[0., 0., 0.], [0., 1., 0.], [0., 0., 0.]]
    centre_plane = [[0., 1., 0.], [1., -6., 1.], [0., 1., 0.]]
    stencil = math.to_float([outer_plane, centre_plane, outer_plane]).reshape((3, 3, 3, 1, 1))

    num_channels = tensor.shape[-1]
    if num_channels == 1:
        return math.conv(tensor, stencil, padding='VALID')
    # The stencil only has one input channel, so slice channels out one by one.
    per_channel = [
        math.conv(tensor[..., c:c + 1], stencil, padding='VALID')
        for c in range(num_channels)
    ]
    return math.concat(per_channel, -1)
def _conv_laplace_2d(tensor):
    """
    Apply the 5-point 2D Laplace stencil to ``tensor`` using 'VALID' convolution.

    The float32 stencil [[0, 1, 0], [1, -4, 1], [0, 1, 0]] is reshaped to
    [3, 3, 1, 1]: two spatial axes plus input/output channel axes of size 1.
    No padding is applied here; presumably the caller pads the input
    beforehand (TODO confirm).

    :param tensor: tensor whose last axis is treated as channels; each channel
        is convolved separately because the stencil has a single in/out channel
    :return: per-channel Laplace responses concatenated along the last axis
    """
    stencil = np.array(
        [[0., 1., 0.],
         [1., -4., 1.],
         [0., 1., 0.]],
        dtype=np.float32,
    ).reshape((3, 3, 1, 1))

    num_channels = tensor.shape[-1]
    if num_channels == 1:
        return math.conv(tensor, stencil, padding='VALID')
    # The stencil only has one input channel, so slice channels out one by one.
    per_channel = [
        math.conv(tensor[..., c:c + 1], stencil, padding='VALID')
        for c in range(num_channels)
    ]
    return math.concat(per_channel, -1)
def blur(field, radius, cutoff=None, kernel="1/1+x"):
    """
    Warning: This function can cause NaN in the gradients, reason unknown.

    Runs a radially symmetric blur kernel over the given tensor.

    The axes ``field.shape[1:-1]`` are treated as spatial axes. The first axis
    (presumably batch) and the last axis (channels) are excluded from the
    kernel. The normalized kernel is convolved with each channel of ``field``
    separately, and the results are concatenated back along the last axis.
    This means multi-channel fields are supported.

    :param field: tensor
    :param radius: weight function curve scale (must be non-zero; the
        weights divide by it)
    :param cutoff: kernel half-width in cells; the kernel spans
        ``2 * int(cutoff) + 1`` cells along every spatial axis. Defaults to
        ``round(3 * radius)``, limited to the smallest spatial size.
    :param kernel: Type of blur kernel (str, case-insensitive). Must be in ('1/1+x', 'gauss')
    :return: blurred tensor
    :raises ValueError: if ``kernel`` is not a known kernel name
    """
    spatial_shape = field.shape[1:-1]
    if cutoff is None:
        cutoff = min(int(round(radius * 3)), *spatial_shape)
    # Convert once so that both range bounds are integers even for a float cutoff.
    half_width = int(cutoff)

    xyz = np.meshgrid(*[range(-half_width, half_width + 1) for _ in spatial_shape])
    # Euclidean distance of every kernel cell from the kernel centre.
    d = math.to_float(np.sqrt(np.sum([x**2 for x in xyz], axis=0)))

    kernel_name = kernel.lower()
    if kernel_name == "1/1+x":
        weights = math.to_float(1) / (d / radius + 1)
    elif kernel_name == "gauss":
        # NOTE(review): despite its name, this is exp(-d / (2 * radius)), not a
        # Gaussian exp(-d**2 / (2 * radius**2)). It is kept unchanged for
        # backward compatibility; confirm the intent.
        weights = math.exp(-d / radius / 2)
    else:
        raise ValueError("Unknown kernel: %s" % kernel)
    weights /= math.sum(weights)
    # Append single input/output channel axes to obtain a conv kernel.
    weights = math.reshape(weights, list(weights.shape) + [1, 1])

    num_channels = field.shape[-1]
    if num_channels == 1:
        return math.conv(field, weights)
    # The kernel has exactly one input channel, so blur each channel on its own.
    return math.concat([math.conv(field[..., i:i + 1], weights) for i in range(num_channels)], -1)