def _elementwise(self, inputs, autoregressive_params, inverse=False):
    """Apply the piecewise-cubic spline conditioned on ``autoregressive_params``.

    Args:
        inputs: Tensor whose first dimension is the batch; forwarded to
            ``splines.cubic_spline``.
        autoregressive_params: Autoregressive-network output. It is reshaped
            to ``(batch_size, self.features, 2 * self.num_bins + 2)``. Along
            the last axis, the first ``num_bins`` entries are unnormalized bin
            widths and the next ``num_bins`` are unnormalized bin heights. The
            final two entries are the unnormalized left and right derivatives.
            This tensor is NOT modified.
        inverse: Forwarded to ``splines.cubic_spline``; selects the inverse
            direction.

    Returns:
        ``(outputs, logabsdet)``. ``logabsdet`` is the spline's log-abs-det
        reduced with ``utils.sum_except_batch`` (presumably summing over all
        non-batch dims).
    """
    batch_size = inputs.shape[0]
    # reshape instead of view: identical (a view) for contiguous tensors,
    # but also accepts non-contiguous parameter tensors.
    transform_params = autoregressive_params.reshape(
        batch_size, self.features, self.num_bins * 2 + 2)

    unnormalized_widths = transform_params[..., :self.num_bins]
    unnormalized_heights = transform_params[..., self.num_bins:2 * self.num_bins]
    derivatives = transform_params[..., 2 * self.num_bins:]
    # Keep a trailing singleton axis so each boundary derivative is (..., 1).
    unnorm_derivatives_left = derivatives[..., 0][..., None]
    unnorm_derivatives_right = derivatives[..., 1][..., None]

    if hasattr(self.autoregressive_net, 'hidden_features'):
        # Out-of-place division. The slices above are views into
        # ``autoregressive_params``, so ``/=`` would silently overwrite the
        # caller's tensor. It would also fail under autograd when that tensor
        # is a leaf requiring grad.
        # NOTE(review): the 1/sqrt(hidden_features) rescaling presumably keeps
        # initial bin logits small for wide networks; confirm intent.
        divisor = np.sqrt(self.autoregressive_net.hidden_features)
        unnormalized_widths = unnormalized_widths / divisor
        unnormalized_heights = unnormalized_heights / divisor

    outputs, logabsdet = splines.cubic_spline(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnorm_derivatives_left=unnorm_derivatives_left,
        unnorm_derivatives_right=unnorm_derivatives_right,
        inverse=inverse)
    return outputs, utils.sum_except_batch(logabsdet)
def _elementwise_inverse(self, inputs, autoregressive_params):
    """Invert the elementwise affine map ``y = scale * x + shift``.

    ``scale`` is derived from the unconstrained network output as
    ``sigmoid(u + 2) + 1e-3``. It is therefore strictly positive and bounded
    away from zero, which keeps the division and the log finite.

    Args:
        inputs: Tensor to invert.
        autoregressive_params: Network output, split by
            ``self._unconstrained_scale_and_shift`` into an unconstrained
            scale and a shift.

    Returns:
        ``(outputs, logabsdet)``. ``outputs = (inputs - shift) / scale``.
        ``logabsdet`` is minus the sum of ``log(scale)``, reduced with
        ``utils.sum_except_batch`` (``num_batch_dims=1``).
    """
    raw_scale, offset = self._unconstrained_scale_and_shift(
        autoregressive_params)
    positive_scale = torch.sigmoid(raw_scale + 2.) + 1e-3
    log_of_scale = torch.log(positive_scale)

    centered = inputs - offset
    outputs = centered / positive_scale

    # Dividing by ``scale`` contributes -log(scale) per element.
    summed_log_scale = utils.sum_except_batch(log_of_scale, num_batch_dims=1)
    return outputs, -summed_log_scale
def _elementwise(self, inputs, autoregressive_params, inverse=False):
    """Apply the piecewise-linear spline conditioned on ``autoregressive_params``.

    Args:
        inputs: Tensor whose first dimension is the batch; forwarded to
            ``splines.linear_spline``.
        autoregressive_params: Network output. It is viewed as
            ``(batch_size, self.features, self._output_dim_multiplier())`` and
            passed as the unnormalized per-bin pdf.
        inverse: Forwarded to ``splines.linear_spline``.

    Returns:
        ``(outputs, logabsdet)``. ``logabsdet`` is reduced with
        ``utils.sum_except_batch``.
    """
    pdf_shape = (inputs.shape[0], self.features, self._output_dim_multiplier())
    outputs, logabsdet = splines.linear_spline(
        inputs=inputs,
        unnormalized_pdf=autoregressive_params.view(*pdf_shape),
        inverse=inverse,
    )
    return outputs, utils.sum_except_batch(logabsdet)
def _elementwise(self, inputs, autoregressive_params, inverse=False):
    """Apply the rational-quadratic spline conditioned on ``autoregressive_params``.

    Args:
        inputs: Tensor of shape ``(batch_size, features, ...)``; forwarded to
            the spline function.
        autoregressive_params: Network output. It is reshaped to
            ``(batch_size, features, self._output_dim_multiplier())``. Along
            the last axis, the first ``num_bins`` entries are unnormalized bin
            widths and the next ``num_bins`` are unnormalized bin heights. The
            rest are unnormalized knot derivatives. This tensor is NOT
            modified.
        inverse: Forwarded to the spline function.

    Returns:
        ``(outputs, logabsdet)``. ``logabsdet`` is reduced with
        ``utils.sum_except_batch``.

    Raises:
        ValueError: If ``self.tails`` is neither ``None`` nor ``'linear'``.
    """
    batch_size, features = inputs.shape[0], inputs.shape[1]
    # reshape instead of view: a view when possible, and it also accepts
    # non-contiguous parameter tensors.
    transform_params = autoregressive_params.reshape(
        batch_size, features, self._output_dim_multiplier())

    unnormalized_widths = transform_params[..., :self.num_bins]
    unnormalized_heights = transform_params[..., self.num_bins:2 * self.num_bins]
    unnormalized_derivatives = transform_params[..., 2 * self.num_bins:]

    if hasattr(self.autoregressive_net, 'hidden_features'):
        # Out-of-place division. The slices above are views into
        # ``autoregressive_params``, so ``/=`` would silently overwrite the
        # caller's tensor. It would also fail under autograd when that tensor
        # is a leaf requiring grad.
        # NOTE(review): the 1/sqrt(hidden_features) rescaling presumably keeps
        # initial bin logits small for wide networks; confirm intent.
        divisor = np.sqrt(self.autoregressive_net.hidden_features)
        unnormalized_widths = unnormalized_widths / divisor
        unnormalized_heights = unnormalized_heights / divisor

    if self.tails is None:
        spline_fn = splines.rational_quadratic_spline
        spline_kwargs = {}
    elif self.tails == 'linear':
        spline_fn = splines.unconstrained_rational_quadratic_spline
        spline_kwargs = {
            'tails': self.tails,
            'tail_bound': self.tail_bound
        }
    else:
        raise ValueError(
            "Unsupported tails value {!r}; expected None or 'linear'.".format(
                self.tails))

    outputs, logabsdet = spline_fn(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=self.min_bin_width,
        min_bin_height=self.min_bin_height,
        min_derivative=self.min_derivative,
        **spline_kwargs)
    return outputs, utils.sum_except_batch(logabsdet)