from typing import Tuple

import torch as tr
from torch import Tensor
from torch.nn import functional as func

# `unconstrained_rqs`, `searchsorted`, the DEFAULT_* constants and the small
# stabilising constant `δ` are assumed to be defined elsewhere in this module.


def backward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    log_det = tr.zeros(z.shape[0]).to(dtype=z.dtype, device=z.device)
    lower, upper = z[:, :self.dim // 2], z[:, self.dim // 2:]
    # Invert the lower half first, conditioned on the (untouched) upper half:
    out = self.f2.__call__(upper).reshape(-1, self.dim // 2, 3 * self.K - 1)
    W, H, D = out.split(self.K, dim=2)
    W, H = tr.softmax(W, dim=2), tr.softmax(H, dim=2)
    W, H = 2 * self.B * W, 2 * self.B * H
    D = func.softplus(D)
    lower, ld = unconstrained_rqs(lower, W, H, D, inverse=True,
                                  tail_bound=self.B)
    log_det += tr.sum(ld, dim=1)
    # Then invert the upper half, conditioned on the recovered lower half:
    out = self.f1.__call__(lower).reshape(-1, self.dim // 2, 3 * self.K - 1)
    W, H, D = out.split(self.K, dim=2)
    W, H = tr.softmax(W, dim=2), tr.softmax(H, dim=2)
    W, H = 2 * self.B * W, 2 * self.B * H
    D = func.softplus(D)
    upper, ld = unconstrained_rqs(upper, W, H, D, inverse=True,
                                  tail_bound=self.B)
    log_det += tr.sum(ld, dim=1)
    return tr.cat([lower, upper], dim=1), log_det
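
# A sketch of the matching forward pass, inferred from the inversion order in
# `backward` above (f1 conditions on `lower` to transform `upper`, then f2
# conditions on the new `upper` to transform `lower`). This is an assumption
# about the coupling layer's other direction, not the author's exact code.
def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    log_det = tr.zeros(z.shape[0]).to(dtype=z.dtype, device=z.device)
    lower, upper = z[:, :self.dim // 2], z[:, self.dim // 2:]
    out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1)
    W, H, D = out.split(self.K, dim=2)
    W, H = 2 * self.B * tr.softmax(W, dim=2), 2 * self.B * tr.softmax(H, dim=2)
    D = func.softplus(D)
    upper, ld = unconstrained_rqs(upper, W, H, D, inverse=False,
                                  tail_bound=self.B)
    log_det += tr.sum(ld, dim=1)
    out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1)
    W, H, D = out.split(self.K, dim=2)
    W, H = 2 * self.B * tr.softmax(W, dim=2), 2 * self.B * tr.softmax(H, dim=2)
    D = func.softplus(D)
    lower, ld = unconstrained_rqs(lower, W, H, D, inverse=False,
                                  tail_bound=self.B)
    log_det += tr.sum(ld, dim=1)
    return tr.cat([lower, upper], dim=1), log_det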
def forward_(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    z_out = tr.zeros_like(z).to(dtype=z.dtype, device=z.device)
    log_det = tr.zeros(z_out.shape[0]).to(dtype=z.dtype, device=z.device)
    # Autoregressive pass: dimension i goes through a spline whose
    # parameters are predicted from the input prefix z[:, :i].
    for i in range(self.dim):
        if i == 0:
            # The first dimension has no prefix; use learned parameters.
            init_param = self.init_param.expand(z.shape[0], 3 * self.K - 1)
            W, H, D = init_param.split(self.K, dim=1)
        else:
            out = self.layers[i - 1].__call__(z[:, :i])
            W, H, D = out.split(self.K, dim=1)
        W, H = tr.softmax(W, dim=1), tr.softmax(H, dim=1)
        W, H = 2 * self.B * W, 2 * self.B * H
        D = func.softplus(D)
        z_out[:, i], ld = unconstrained_rqs(z[:, i], W, H, D, inverse=False,
                                            tail_bound=self.B)
        log_det += ld
    return z_out, log_det
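
# A sketch of the corresponding inverse pass, assuming the standard
# autoregressive recipe: each dimension is recovered sequentially,
# conditioning on the dimensions recovered so far. This mirrors `forward_`
# above but is an assumption, not the author's code.
def backward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    x = tr.zeros_like(z).to(dtype=z.dtype, device=z.device)
    log_det = tr.zeros(z.shape[0]).to(dtype=z.dtype, device=z.device)
    for i in range(self.dim):
        if i == 0:
            init_param = self.init_param.expand(z.shape[0], 3 * self.K - 1)
            W, H, D = init_param.split(self.K, dim=1)
        else:
            out = self.layers[i - 1](x[:, :i])
            W, H, D = out.split(self.K, dim=1)
        W, H = 2 * self.B * tr.softmax(W, dim=1), 2 * self.B * tr.softmax(H, dim=1)
        D = func.softplus(D)
        x[:, i], ld = unconstrained_rqs(z[:, i], W, H, D, inverse=True,
                                        tail_bound=self.B)
        log_det += ld
    return x, log_det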
def forward_(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    u, w, b = self.u, self.w, self.b
    # Create u_hat from u such that u_hat · w >= -1, which keeps the
    # planar transform invertible
    uw = tr.dot(u, w)
    u_hat = u + (func.softplus(uw) - 1 - uw) * tr.transpose(
        w, 0, -1) / tr.sum(w**2)
    # m(uw) == softplus(uw) - 1 == log(1 + exp(uw)) - 1
    # Equation 21 - Transform z
    zw__b = tr.mv(z, vec=w) + b  # z @ w + b
    fz = z + (u_hat.view(1, -1) * tr.tanh(zw__b).view(-1, 1))
    # Compute the Jacobian using the fact that
    # d tanh(x)/dx == 1 - tanh(x)**2
    ψ = (-tr.tanh(zw__b)**2 + 1).view(-1, 1) * w.view(1, -1)
    ψu = tr.mv(ψ, vec=u_hat)  # ψ @ u_hat
    # Return the transformed output along with the log-determinant of J:
    # log|det J| == log|1 + u_hat · ψ(z)|; δ guards against log(0)
    logabs_detjacobian = tr.log(tr.abs(ψu + 1) + δ)
    return fz, logabs_detjacobian
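
# A standalone numerical check of the planar-flow log-determinant against
# autograd on a 2-D example. The free tensors u, w, b below stand in for the
# module's parameters; names and values are illustrative only.
def _check_planar_logdet() -> None:
    tr.manual_seed(0)
    dim = 2
    u, w, b = tr.randn(dim), tr.randn(dim), tr.randn(())
    z0 = tr.randn(dim)
    uw = tr.dot(u, w)
    u_hat = u + (func.softplus(uw) - 1 - uw) * w / tr.sum(w**2)

    def f(z_: Tensor) -> Tensor:
        return z_ + u_hat * tr.tanh(tr.dot(z_, w) + b)

    jac = tr.autograd.functional.jacobian(f, z0)
    _, logabsdet_autograd = tr.slogdet(jac)
    ψu = tr.dot((1 - tr.tanh(tr.dot(z0, w) + b)**2) * w, u_hat)
    # The two values should agree up to the δ stabiliser:
    print(logabsdet_autograd, tr.log(tr.abs(ψu + 1)))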
def μ_log_σ(self, x: Tensor) -> Tuple[Tensor, Tensor]:
    μ, log_σ = super(GaussianSampleTrim, self).μ_log_σ(x)
    # Squash μ into the bounded range (±self.max_abs_5μ_div_2 / 2):
    μ = (tr.sigmoid(μ) - 0.5) * self.max_abs_5μ_div_2
    # Pass log_σ through a softplus shifted by self.min_log_σ,
    # keeping the result strictly positive:
    log_σ = func.softplus(log_σ - self.min_log_σ)
    return μ, log_σ
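
# Quick numerical illustration of the squashing above, with a made-up bound
# of 4.0 standing in for self.max_abs_5μ_div_2: even extreme raw values of μ
# land strictly inside (-2.0, +2.0).
def _demo_μ_squash() -> None:
    μ_raw = tr.tensor([-100.0, 0.0, 100.0])
    print((tr.sigmoid(μ_raw) - 0.5) * 4.0)  # ≈ tensor([-2., 0., 2.])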
def rqs(inputs: Tensor,
        unnormalized_widths: Tensor,
        unnormalized_heights: Tensor,
        unnormalized_derivatives: Tensor,
        inverse: bool = False,
        left: float = 0.,
        right: float = 1.,
        bottom: float = 0.,
        top: float = 1.,
        min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
        min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
        min_derivative: float = DEFAULT_MIN_DERIVATIVE
        ) -> Tuple[Tensor, Tensor]:
    if tr.min(inputs) < left or tr.max(inputs) > right:
        raise ValueError("Input outside domain")

    num_bins = unnormalized_widths.shape[-1]
    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    # Normalise bin widths to the [left, right] interval, enforcing a minimum:
    widths = func.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = tr.cumsum(widths, dim=-1)
    cumwidths = func.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Knot derivatives must stay positive:
    derivatives = min_derivative + func.softplus(unnormalized_derivatives)

    # Same normalisation for bin heights on the [bottom, top] interval:
    heights = func.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = tr.cumsum(heights, dim=-1)
    cumheights = func.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin each input falls into (on the y-axis when inverting):
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]

    Δ = heights / widths  # slope of each bin
    input_Δ = Δ.gather(-1, bin_idx)[..., 0]
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)
    input_derivatives_plus_one = input_derivatives_plus_one[..., 0]
    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Invert the rational-quadratic map by solving a quadratic in θ
        # (Durkan et al. 2019, "Neural Spline Flows"):
        a = ((inputs - input_cumheights) *
             (input_derivatives + input_derivatives_plus_one - 2 * input_Δ) +
             input_heights * (input_Δ - input_derivatives))
        b = (input_heights * input_derivatives - (inputs - input_cumheights) *
             (input_derivatives + input_derivatives_plus_one - 2 * input_Δ))
        c = -input_Δ * (inputs - input_cumheights)

        discriminant = b**2 - 4 * a * c
        if not (discriminant >= 0).all():
            raise AssertionError('Negative discriminant in spline inversion')

        root = (2 * c) / (-b - tr.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        θ_1minθ = root * (-root + 1)
        denominator = input_Δ + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_Δ) *
            θ_1minθ)
        derivative_numerator = (input_Δ**2) * (
            input_derivatives_plus_one * root**2 + 2 * input_Δ * θ_1minθ +
            input_derivatives * (-root + 1)**2)
        logabsdet = tr.log(derivative_numerator) - 2 * tr.log(denominator)
        return outputs, -logabsdet

    # Forward evaluation of the rational-quadratic spline:
    θ = (inputs - input_cumwidths) / input_bin_widths
    θ_1minθ = θ * (-θ + 1)
    numerator = input_heights * (input_Δ * θ**2 + input_derivatives * θ_1minθ)
    denominator = input_Δ + (
        (input_derivatives + input_derivatives_plus_one - 2 * input_Δ) *
        θ_1minθ)
    outputs = input_cumheights + numerator / denominator
    derivative_numerator = (input_Δ**2) * (
        input_derivatives_plus_one * θ**2 + 2 * input_Δ * θ_1minθ +
        input_derivatives * (-θ + 1)**2)
    logabsdet = tr.log(derivative_numerator) - 2 * tr.log(denominator)
    return outputs, logabsdet
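
# A round-trip sanity check for rqs. `searchsorted` is assumed to be the
# usual bin-location helper found in neural-spline-flow implementations; a
# minimal version is sketched here in case it is not in scope. Note that rqs
# expects K + 1 derivative entries (one per knot) for K bins.
def _searchsorted_sketch(bin_locations: Tensor, inputs: Tensor,
                         eps: float = 1e-6) -> Tensor:
    bin_locations = bin_locations.clone()
    bin_locations[..., -1] += eps
    return tr.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def _check_rqs_roundtrip(K: int = 10, n: int = 8) -> None:
    x = tr.rand(n)                        # inside the default [0, 1] domain
    W, H = tr.randn(n, K), tr.randn(n, K)
    D = tr.randn(n, K + 1)                # one derivative per knot
    y, ld = rqs(x, W, H, D, inverse=False)
    x_back, ld_inv = rqs(y, W, H, D, inverse=True)
    print(tr.allclose(x, x_back, atol=1e-5))    # inverse recovers the input
    print(tr.allclose(ld, -ld_inv, atol=1e-5))  # log-dets negate each other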