def reconstruction(self):
    """Rebuild the input by round-tripping it through the NMF factorizer.

    Steps, in order:
      1. ``self.x`` is passed through ``ops.soft_quantile_normalization``
         with ``self.quantiles_prime`` as the target and a softmax of
         ``self._b_prime`` (over its last axis) as the target weights.
      2. ``self._nmf_factorizer`` is run on that result with a fixed
         ``random_seed=0``; its ``losses`` are appended to ``self.inner_kl``.
      3. The factorizer's ``reconstruction`` is passed through
         ``ops.soft_quantile_normalization`` again, this time with
         ``self.quantiles`` and ``self.target_weights``.

    Note: every call re-runs the factorizer and appends one more entry to
    ``self.inner_kl``.

    Returns:
      The output of the second ``ops.soft_quantile_normalization`` call.
    """
    source = self.x
    prime_quantiles = self.quantiles_prime
    prime_weights = tf.nn.softmax(self._b_prime, axis=-1)
    normalized_source = ops.soft_quantile_normalization(
        source,
        prime_quantiles,
        target_weights=prime_weights,
        **self._kwargs)

    # Fixed seed so repeated calls factorize deterministically.
    self._nmf_factorizer(normalized_source, random_seed=0)
    self.inner_kl.append(self._nmf_factorizer.losses)

    rebuilt = self._nmf_factorizer.reconstruction
    return ops.soft_quantile_normalization(
        rebuilt,
        self.quantiles,
        target_weights=self.target_weights,
        **self._kwargs)
def sliced_reconstruction(self, rows=None, cols=None):
    """Reconstruct a sub-block of the input.

    Args:
      rows: selector forwarded to ``self.sliced_uv``,
        ``self.sliced_quantiles`` and ``self.sliced_target_weights``.
        ``None`` by default (presumably meaning "all rows" -- confirm in
        the sliced_* helpers).
      cols: selector forwarded to ``self.sliced_uv`` only.

    Returns:
      ``ops.soft_quantile_normalization`` applied to
      ``self.sliced_uv(rows, cols)``, targeting ``self.sliced_quantiles(rows)``
      with ``self.sliced_target_weights(rows)`` as target weights.
    """
    uv_block = self.sliced_uv(rows, cols)
    row_quantiles = self.sliced_quantiles(rows)
    row_weights = self.sliced_target_weights(rows)
    return ops.soft_quantile_normalization(
        uv_block, row_quantiles, target_weights=row_weights, **self._kwargs)
def test_soft_quantile_normalization(self):
    """Soft quantile normalization keeps ranks and approaches the target."""
    values = tf.constant([1.2, 1.3, 1.5, -4.0, 1.8, 2.4, -1.0])
    # One target value per input element: 1, 2, ..., n.
    count = values.shape[0]
    targets = tf.cumsum(tf.ones(count))
    normalized = ops.soft_quantile_normalization(values, targets)

    # The ranking of the inputs must be preserved by the normalization ...
    self.assertAllEqual(tf.argsort(values), tf.argsort(normalized))
    # ... and, once sorted, the outputs must sit close to the targets.
    self.assertAllClose(tf.sort(targets), tf.sort(normalized), atol=1e-1)
def call(self, inputs):
    """Soft-quantile-normalize ``inputs`` towards a target derived from ``self.w``.

    A softmax turns ``self.w`` into non-negative weights summing to one along
    ``self._axis``. Their cumulative sum along that same axis is passed as
    the target to ``ops.soft_quantile_normalization``.

    Args:
      inputs: tensor to normalize; the operation is applied along
        ``self._axis``.

    Returns:
      The result of ``ops.soft_quantile_normalization(inputs, target_cdf,
      axis=self._axis, **self._kwargs)``.
    """
    # Normalize and accumulate along the SAME axis. With the softmax default
    # (axis=-1), a multi-dimensional `w` would be normalized along a different
    # axis than the one being summed whenever `self._axis` is not the last
    # axis, so the cumulative sum would not end at 1 and would not be a CDF.
    # For 1-D `w`, or when `self._axis` is the last axis, this is unchanged.
    weights = tf.math.softmax(self.w, axis=self._axis)
    target_cdf = tf.math.cumsum(weights, axis=self._axis)
    return ops.soft_quantile_normalization(
        inputs, target_cdf, axis=self._axis, **self._kwargs)