def logconv1x1_2d(input, filter):
    r"""
    Convolution in logspace with 1x1 filters.

    Computes :math:`\log(\text{conv}(\mathtt{input},\mathtt{filter}))` from
    :math:`\log(\mathtt{input})` and :math:`\log(\mathtt{filter})` using a
    max-shifted log-sum-exp: the per-slice maxima are subtracted before
    exponentiating and added back after the logarithm.

    Args:
        input: Input in logspace. Reduced over its last axis
            (assumed to be the channel axis, i.e. channels-last layout
            -- TODO confirm against callers).
        filter: Filter in logspace. Reduced over its second-to-last axis
            (assumed to be the input-channel axis of a
            ``[1, 1, in_channels, out_channels]`` kernel -- TODO confirm).

    Returns:
        Convolution of input and filter in logspace
    """
    # The maxima are treated as constants (stop_gradient) since they cancel
    # out analytically. `replace_infs_with_zeros` is defined elsewhere;
    # presumably it maps infinite maxima to 0 so that an all -inf slice does
    # not produce `-inf - (-inf) = NaN` in the subtraction below.
    filter_max = replace_infs_with_zeros(
        tf.stop_gradient(tf.reduce_max(filter, axis=-2, keepdims=True)))
    input_max = replace_infs_with_zeros(
        tf.stop_gradient(tf.reduce_max(input, axis=-1, keepdims=True)))

    # Shift with plain expressions rather than `-=` on the arguments:
    # augmented assignment would modify a NumPy array passed by the caller
    # in place, and is not supported on `tf.Variable` inputs.
    out = tf.math.log(
        tf.nn.convolution(
            input=tf.exp(input - input_max),
            filters=tf.exp(filter - filter_max),
            padding="SAME"))

    # Adding `input_max` back per position is only valid because the filter
    # is 1x1: every output position depends on exactly one input position.
    return out + (filter_max + input_max)
def logconv1x1_2d(input: tf.Tensor, filter: tf.Tensor) -> tf.Tensor:
    r"""
    Log-space convolution for 1x1 filters.

    Given :math:`\log(\mathtt{input})` and :math:`\log(\mathtt{filter})`,
    evaluates :math:`\log(\text{conv}(\mathtt{input},\mathtt{filter}))`.
    Maxima are subtracted before exponentiation and restored after the
    logarithm, so the linear-domain convolution operates on shifted values.

    Args:
        input: Input in logspace
        filter: Filter in logspace

    Returns:
        Convolution of input and filter in logspace
    """
    with tf.name_scope("LogConv1x1"):
        # Maxima act as constant offsets, hence the stop_gradient.
        # `replace_infs_with_zeros` is defined elsewhere; presumably it
        # guards against infinite offsets -- verify against its definition.
        filter_peak = tf.reduce_max(filter, axis=-2, keepdims=True)
        filter_max = replace_infs_with_zeros(tf.stop_gradient(filter_peak))
        input_peak = tf.reduce_max(input, axis=-1, keepdims=True)
        input_max = replace_infs_with_zeros(tf.stop_gradient(input_peak))

        # Linear-domain operands after removing the offsets.
        linear_input = tf.exp(input - input_max)
        linear_filter = tf.exp(filter - filter_max)

        conv = tf.nn.convolution(
            input=linear_input,
            filters=linear_filter,
            padding="SAME",
        )
        log_conv = tf.math.log(conv)

        # Restore the offsets removed before exponentiation.
        return log_conv + (filter_max + input_max)
def logmatmul(log_a, log_b):
    """
    Matrix product evaluated in log-space.

    The maximum along the contracted axis of each operand is subtracted
    before exponentiating, the shifted operands are multiplied with an
    ordinary matmul, and both maxima are added back after the logarithm.

    Args:
        log_a: log(a) of shape [..., batch, num_in]
        log_b: log(b) of shape [..., num_in, num_out]

    Returns:
        A matrix log(c) where log(c) = log(a @ b)
    """
    # Offsets along the contracted axis, treated as constants for gradients.
    # `replace_infs_with_zeros` is defined elsewhere; presumably it guards
    # against infinite offsets -- verify against its definition.
    peak_a = tf.reduce_max(log_a, axis=-1, keepdims=True)
    max_a = replace_infs_with_zeros(tf.stop_gradient(peak_a))
    peak_b = tf.reduce_max(log_b, axis=-2, keepdims=True)
    max_b = replace_infs_with_zeros(tf.stop_gradient(peak_b))

    # Linear-domain operands after removing the offsets.
    shifted_a = tf.exp(log_a - max_a)
    shifted_b = tf.exp(log_b - max_b)

    # log of the product, then restore the offsets.
    log_product = tf.math.log(tf.matmul(shifted_a, shifted_b))
    return log_product + max_a + max_b