def __init__(self,
             func=None,
             num_samples=1000,
             sigma=0.01,
             noise=perturbations._NORMAL,
             batched=True,
             maximize=True,
             reduction=tf.keras.losses.Reduction.SUM):
  """Initializes the Fenchel-Young loss.

  Args:
    func: the function whose argmax is to be differentiated by perturbation.
    num_samples: (int) the number of perturbed inputs.
    sigma: (float) the scale of the noise used to perturb the inputs.
    noise: (str) the noise distribution to be used to sample perturbations.
    batched: whether inputs to the func will have a leading batch dimension
      (True) or consist of a single example (False). Defaults to True.
    maximize: (bool) whether to maximize or to minimize the input function.
    reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
      loss. Default value is `SUM`. When used in custom training loops under
      the scope of `tf.distribute.Strategy`, must be set to `NONE` or `SUM`.
  """
  super().__init__(reduction=reduction, name='fenchel_young')
  self._batched = batched
  self._maximize = maximize
  self.func = func
  self.perturbed = perturbations.perturbed(func=func,
                                           num_samples=num_samples,
                                           sigma=sigma,
                                           noise=noise,
                                           batched=batched)
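# Usage sketch (not part of the source): this assumes the constructor above
# belongs to a Keras loss class named `FenchelYoungLoss` and that instances
# are invoked as loss(y_true, theta), like any tf.keras loss. The helper
# `batched_argmax` below is illustrative only.
import tensorflow as tf


def batched_argmax(t):
  # Hard argmax over the last axis, returned as one-hot vectors.
  return tf.one_hot(tf.argmax(t, axis=-1), tf.shape(t)[-1])


loss_fn = FenchelYoungLoss(
    func=batched_argmax,
    num_samples=1000,
    sigma=0.1,
    batched=True,
    maximize=True,
    reduction=tf.keras.losses.Reduction.SUM)

theta = tf.random.normal((8, 10))          # model scores, one row per example
y_true = tf.one_hot(tf.range(8) % 10, 10)  # one-hot targets

with tf.GradientTape() as tape:
  tape.watch(theta)
  loss_value = loss_fn(y_true, theta)
grad = tape.gradient(loss_value, theta)    # gradient flows through the argmax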
def test_perturbed_argmax_gradients_without_minibatch(self):
  input_tensor = tf.constant([-0.6, -0.5, 0.5])
  dim = len(input_tensor)
  eps = 1e-2
  n = 10000000
  argmax = lambda t: tf.one_hot(tf.argmax(t, 1), dim)
  soft_argmax = perturbations.perturbed(
      argmax, sigma=0.5, num_samples=n, batched=False)
  # Squared norm of the soft argmax: a smooth scalar function of the input.
  norm_argmax = lambda t: tf.reduce_sum(tf.square(soft_argmax(t)))

  # Random unit direction for a directional finite-difference check.
  w = tf.random.normal(input_tensor.shape)
  w /= tf.linalg.norm(w)
  var = tf.Variable(input_tensor)
  with tf.GradientTape() as tape:
    value = norm_argmax(var)
  grad = tape.gradient(value, var)
  grad = tf.reshape(grad, input_tensor.shape)
  value_minus = norm_argmax(input_tensor - eps * w)
  value_plus = norm_argmax(input_tensor + eps * w)
  # The directional derivative w . grad should match its central finite
  # difference approximation up to sampling noise.
  lhs = tf.reduce_sum(w * grad)
  rhs = (value_plus - value_minus) * 1. / (2 * eps)
  self.assertAllLess(tf.abs(lhs - rhs), 0.05)
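# Illustration (not part of the source): the test above performs an inline
# directional finite-difference check, comparing w . grad f(x) against
# (f(x + eps*w) - f(x - eps*w)) / (2*eps). A reusable version of the same
# pattern could look like the sketch below; `directional_gradient_check` is
# a hypothetical name, not part of the perturbations API.
import tensorflow as tf


def directional_gradient_check(f, x, eps=1e-2):
  """Returns |w . grad f(x) - central finite difference| for a random unit w."""
  w = tf.random.normal(x.shape)
  w /= tf.linalg.norm(w)
  var = tf.Variable(x)
  with tf.GradientTape() as tape:
    value = f(var)
  grad = tf.reshape(tape.gradient(value, var), x.shape)
  lhs = tf.reduce_sum(w * grad)
  rhs = (f(x + eps * w) - f(x - eps * w)) / (2 * eps)
  return tf.abs(lhs - rhs)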
def test_unbatched_rank_one_raise(self):
  with self.assertRaises(ValueError):
    input_tensor = tf.constant([-0.6, -0.5, 0.5])
    dim = len(input_tensor)
    n = 10000000
    argmax = lambda t: tf.one_hot(tf.argmax(t, 1), dim)
    # batched=True (the default) expects a leading batch dimension, so
    # calling the wrapper on a rank-one input should raise a ValueError.
    soft_argmax = perturbations.perturbed(argmax, sigma=0.5, num_samples=n)
    _ = soft_argmax(input_tensor)
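# Sketch (not part of the source): two common ways to avoid the ValueError
# exercised above, using only the perturbations.perturbed signature already
# shown in these tests (the `perturbations` import is assumed, as above).
# Either declare the input as unbatched, or add an explicit batch dimension.
import tensorflow as tf

input_tensor = tf.constant([-0.6, -0.5, 0.5])
dim = len(input_tensor)
argmax = lambda t: tf.one_hot(tf.argmax(t, axis=-1), dim)

# Option 1: treat the rank-one tensor as a single (unbatched) example.
soft_argmax = perturbations.perturbed(
    argmax, sigma=0.5, num_samples=1000, batched=False)
_ = soft_argmax(input_tensor)

# Option 2: keep batched=True (the default) and add a leading batch dimension.
soft_argmax_batched = perturbations.perturbed(
    argmax, sigma=0.5, num_samples=1000)
_ = soft_argmax_batched(tf.expand_dims(input_tensor, axis=0))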
def test_perturbed_reduce_sign_any_gradients(self):
  # We choose a point where the gradient should be above noise, that is
  # to say the distance to 0 along one direction is about sigma.
  sigma = 0.1
  input_tensor = tf.constant([[-0.6, -1.2, 0.5 * sigma],
                              [-2 * sigma, -2.4, -1.0]])
  soft_reduce_any = perturbations.perturbed(reduce_sign_any, sigma=sigma)
  with tf.GradientTape() as tape:
    tape.watch(input_tensor)
    output_tensor = soft_reduce_any(input_tensor)
  gradient = tape.gradient(output_tensor, input_tensor)
  # The two values that could flip the soft logical or are the ones that
  # should receive a strictly positive gradient.
  self.assertAllGreater(gradient[0, 2], 0.0)
  self.assertAllGreater(gradient[1, 0], 0.0)
  # The value that is more on the fence should receive more gradient than
  # any other one.
  self.assertAllLessEqual(gradient, gradient[0, 2].numpy())
def test_perturbed_reduce_sign_any(self, sigma):
  input_tensor = tf.constant([[-0.3, -1.2, 1.6],
                              [-0.4, -2.4, -1.0]])
  soft_reduce_any = perturbations.perturbed(reduce_sign_any, sigma=sigma)
  output_tensor = soft_reduce_any(input_tensor, axis=-1)
  self.assertAllClose(output_tensor, [1.0, -1.0])
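# Note (assumption): the two tests above rely on a `reduce_sign_any` helper
# defined elsewhere in the test file. Its exact implementation is not shown
# here; a minimal version consistent with the expected output [1.0, -1.0]
# (+1 if any entry along `axis` is positive, -1 otherwise) could be:
import tensorflow as tf


def reduce_sign_any(input_tensor, axis=-1):
  """Sign-based logical or: +1 where any entry along `axis` is positive, else -1."""
  any_positive = tf.reduce_any(input_tensor > 0, axis=axis)
  return 2.0 * tf.cast(any_positive, input_tensor.dtype) - 1.0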