def _elementwise_forward(self, x, elementwise_params):
    """Apply the linear spline to ``x`` in the forward direction.

    Args:
        x: Values to transform.
        elementwise_params: Spline parameters; the size of the last axis
            must equal ``self._output_dim_multiplier()`` (checked with an
            ``assert``, so the check is skipped under ``python -O``).

    Returns:
        A pair ``(z, ldj)``. ``z`` is the first output of
        ``splines.linear_spline(..., inverse=False)``. ``ldj`` is its second
        output passed through ``sum_except_batch``. Presumably ``ldj`` is the
        log-det-Jacobian reduced to one value per batch element; confirm this
        against ``sum_except_batch``.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    spline_out, per_element_ldj = splines.linear_spline(
        x,
        elementwise_params,
        inverse=False,
    )
    return spline_out, sum_except_batch(per_element_ldj)
def _elementwise_inverse(self, z, elementwise_params):
    """Invert the linear spline, mapping ``z`` back to the input space.

    Args:
        z: Values to invert.
        elementwise_params: Spline parameters; the size of the last axis
            must equal ``self._output_dim_multiplier()`` (checked with an
            ``assert``, so the check is skipped under ``python -O``).

    Returns:
        The first output of ``splines.linear_spline(..., inverse=True)``.
        The second output (presumably the inverse log-det term) is discarded.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()
    inverse_result = splines.linear_spline(z, elementwise_params, inverse=True)
    recovered, _unused_ldj = inverse_result
    return recovered