def _elementwise_inverse(self, z, elementwise_params):
    """Map ``z`` back through the cubic spline and return the recovered ``x``.

    ``elementwise_params`` is split along its last axis as follows:
      * ``[0, num_bins)``              -> unnormalized bin widths
      * ``[num_bins, 2*num_bins)``     -> unnormalized bin heights
      * ``[2*num_bins, 2*num_bins+1)`` -> unnormalized left-boundary derivative
      * ``[2*num_bins+1, end)``        -> unnormalized right-boundary derivative(s)

    The last-axis size must equal ``self._output_dim_multiplier()``
    (presumably ``2*num_bins + 2``; TODO confirm against that method).

    Args:
        z: Values to invert. Assumed to be a tensor broadcast-compatible with
            ``elementwise_params[..., 0]``; TODO confirm with the caller.
        elementwise_params: Raw spline parameters laid out as described above.

    Returns:
        The inverted values ``x`` produced by ``splines.cubic_spline`` with
        ``inverse=True``. The log-determinant it also returns is discarded.
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    bins = self.num_bins
    spline_kwargs = dict(
        unnormalized_widths=elementwise_params[..., :bins],
        unnormalized_heights=elementwise_params[..., bins:2 * bins],
        unnorm_derivatives_left=elementwise_params[..., 2 * bins:2 * bins + 1],
        unnorm_derivatives_right=elementwise_params[..., 2 * bins + 1:],
    )

    # Only the inverted values are needed here; the Jacobian term is dropped.
    x, _discarded_ldj = splines.cubic_spline(z, inverse=True, **spline_kwargs)
    return x
def _elementwise_forward(self, x, elementwise_params):
    """Push ``x`` through the cubic spline and return ``(z, ldj)``.

    ``elementwise_params`` is split along its last axis as follows:
      * ``[0, num_bins)``              -> unnormalized bin widths
      * ``[num_bins, 2*num_bins)``     -> unnormalized bin heights
      * ``[2*num_bins, 2*num_bins+1)`` -> unnormalized left-boundary derivative
      * ``[2*num_bins+1, end)``        -> unnormalized right-boundary derivative(s)

    The last-axis size must equal ``self._output_dim_multiplier()``
    (presumably ``2*num_bins + 2``; TODO confirm against that method).

    Args:
        x: Values to transform. Assumed to be a tensor broadcast-compatible
            with ``elementwise_params[..., 0]``; TODO confirm with the caller.
        elementwise_params: Raw spline parameters laid out as described above.

    Returns:
        A pair ``(z, ldj)``. ``z`` is the forward output of
        ``splines.cubic_spline`` with ``inverse=False``. ``ldj`` is its
        elementwise log-determinant reduced by ``sum_except_batch``
        (presumably summing over every non-batch dimension; verify).
    """
    assert elementwise_params.shape[-1] == self._output_dim_multiplier()

    bins = self.num_bins
    spline_kwargs = dict(
        unnormalized_widths=elementwise_params[..., :bins],
        unnormalized_heights=elementwise_params[..., bins:2 * bins],
        unnorm_derivatives_left=elementwise_params[..., 2 * bins:2 * bins + 1],
        unnorm_derivatives_right=elementwise_params[..., 2 * bins + 1:],
    )

    z, per_element_ldj = splines.cubic_spline(x, inverse=False, **spline_kwargs)
    # Collapse the per-element log-det terms into one value per batch entry.
    return z, sum_except_batch(per_element_ldj)