Code example #1
File: nd.py  Project: VemburajYadav/PhiFlow
def blur(field, radius, cutoff=None, kernel="1/1+x"):
    """
Warning: This function can cause NaN in the gradients, reason unknown.

Runs a blur kernel over the given tensor.
    :param field: tensor
    :param radius: weight function curve scale
    :param cutoff: kernel size
    :param kernel: Type of blur kernel (str). Must be in ('1/1+x', 'gauss')
    :return:
    """
    if cutoff is None:
        # Default kernel size: roughly 3 * radius, clipped to the spatial extent of the field
        cutoff = min(int(round(radius * 3)), *field.shape[1:-1])

    # Distance of every kernel cell from the kernel center, over all spatial dimensions
    xyz = np.meshgrid(
        *[range(-int(cutoff), (cutoff) + 1) for _ in field.shape[1:-1]])
    d = math.to_float(np.sqrt(np.sum([x**2 for x in xyz], axis=0)))
    if kernel == "1/1+x":
        weights = math.to_float(1) / (d / radius + 1)
    elif kernel.lower() == "gauss":
        weights = math.exp(-d / radius / 2)
    else:
        raise ValueError("Unknown kernel: %s" % kernel)
    # Normalize the kernel to sum to 1, then add singleton in/out channel dimensions for conv
    weights /= math.sum(weights)
    weights = math.reshape(weights, list(weights.shape) + [1, 1])
    return math.conv(field, weights)
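Usage sketch (not part of the project source): a minimal call pattern, assuming blur and the PhiFlow math backend are importable from this nd.py module and that fields use the (batch, *spatial, channels) layout implied by field.shape[1:-1]. The variable names are illustrative only.

import numpy as np

# Hypothetical input: one batch of a 64x64 scalar field with a single channel.
field = np.random.rand(1, 64, 64, 1).astype(np.float32)

# Default "1/1+x" kernel; cutoff falls back to roughly 3 * radius.
smoothed = blur(field, radius=2.0)

# Gaussian-like kernel with an explicit cutoff of 5 cells per direction.
smoothed_gauss = blur(field, radius=2.0, cutoff=5, kernel="gauss")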
Code example #2
File: nd.py  Project: VemburajYadav/PhiFlow
def frequency_loss(tensor, frequency_falloff=100, reduce_batches=True):
    """
    Instead of minimizing each entry of the tensor, minimize the frequencies of the tensor, emphasizing lower frequencies over higher ones.

    :param reduce_batches: whether to reduce the batch dimension of the loss by adding the losses along the first dimension
    :param tensor: typically actual - target
    :param frequency_falloff: large values put more emphasis on lower frequencies, 1.0 weights all frequencies equally.
    :return: scalar loss value
    """
    if struct.isstruct(tensor):
        # For structs, sum the loss over all contained tensors
        all_tensors = struct.flatten(tensor)
        return sum(
            frequency_loss(tensor, frequency_falloff, reduce_batches)
            for tensor in all_tensors)
    # Squared magnitude of the Fourier transform of the difference tensor
    diff_fft = abs_square(math.fft(tensor))
    # Frequency magnitudes per spatial dimension; weights decay with frequency
    k = fftfreq(tensor.shape[1:-1], mode='absolute')
    weights = math.exp(-0.5 * k**2 * frequency_falloff**2)
    return l1_loss(diff_fft * weights, reduce_batches=reduce_batches)
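Usage sketch (not part of the project source): how the loss might be called on a prediction/target pair, assuming frequency_loss is importable from this nd.py module and the tensors follow the (batch, *spatial, channels) layout. The names actual and target are illustrative.

import numpy as np

# Hypothetical prediction and target: one batch of a 32x32 scalar field.
actual = np.random.rand(1, 32, 32, 1).astype(np.float32)
target = np.zeros_like(actual)

# Emphasize low-frequency errors; the weights are exp(-0.5 * k**2 * falloff**2).
loss = frequency_loss(actual - target, frequency_falloff=100)

# Per the docstring, frequency_falloff=1.0 weights all frequencies equally.
loss_flat = frequency_loss(actual - target, frequency_falloff=1.0, reduce_batches=True)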