def _elementwise_inverse(self, z, elementwise_params):
    """Map ``z`` back through the elementwise quadratic spline.

    Args:
        z: Values to invert.
        elementwise_params: Unnormalized spline parameters. The last axis is
            split at ``self.num_bins``: the leading ``num_bins`` entries are
            the width logits and all remaining entries are the height logits
            (presumably ``num_bins + 1`` of them -- confirm against
            ``_output_dim_multiplier``).

    Returns:
        The inverted values ``x``. The second value returned by
        ``splines.quadratic_spline`` is discarded.
    """
    # Shape sanity check on the parameter axis (skipped under ``python -O``).
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    split_at = self.num_bins
    width_logits = elementwise_params[..., :split_at]
    height_logits = elementwise_params[..., split_at:]

    x, _discarded = splines.quadratic_spline(
        z,
        unnormalized_widths=width_logits,
        unnormalized_heights=height_logits,
        inverse=True,
    )
    return x
def _elementwise_forward(self, x, elementwise_params):
    """Push ``x`` through the elementwise quadratic spline.

    Args:
        x: Values to transform.
        elementwise_params: Unnormalized spline parameters. The last axis is
            split at ``self.num_bins``: the leading ``num_bins`` entries are
            the width logits and all remaining entries are the height logits
            (presumably ``num_bins + 1`` of them -- confirm against
            ``_output_dim_multiplier``).

    Returns:
        Tuple ``(z, ldj)``, where ``z`` is the transformed values and ``ldj``
        is the spline's second output reduced by ``sum_except_batch``
        (presumably a per-sample log-det-Jacobian -- verify).
    """
    # Shape sanity check on the parameter axis (skipped under ``python -O``).
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    split_at = self.num_bins
    width_logits = elementwise_params[..., :split_at]
    height_logits = elementwise_params[..., split_at:]

    z, per_element_ldj = splines.quadratic_spline(
        x,
        unnormalized_widths=width_logits,
        unnormalized_heights=height_logits,
        inverse=False,
    )
    return z, sum_except_batch(per_element_ldj)