コード例 #1
0
ファイル: normal.py プロジェクト: skeeet/manifold-flow
    def _log_prob(self, inputs, context):
        """Log-density of `inputs` under a context-conditioned diagonal Gaussian.

        Args:
            inputs: batch of samples; everything after the first dimension must
                equal ``self._shape``.
            context: conditioning input forwarded to ``self._compute_params``,
                which returns the means and log standard deviations.

        Returns:
            Per-batch-element log-probabilities (non-batch dims reduced with
            ``various.sum_except_batch``).

        Raises:
            ValueError: if the per-sample shape does not match ``self._shape``.
        """
        event_shape = inputs.shape[1:]
        if event_shape != self._shape:
            raise ValueError("Expected input of shape {}, got {}".format(
                self._shape, event_shape))

        mu, log_sigma = self._compute_params(context)
        # The parameters must line up element-for-element with the inputs.
        assert mu.shape == inputs.shape and log_sigma.shape == inputs.shape

        # Standardize: z = (x - mu) / sigma, using 1/sigma = exp(-log_sigma).
        z = (inputs - mu) * torch.exp(-log_sigma)
        quad_term = various.sum_except_batch(z**2, num_batch_dims=1)
        log_det = various.sum_except_batch(log_sigma, num_batch_dims=1)
        return -0.5 * quad_term - log_det - self._log_z
コード例 #2
0
    def _spline(self, inputs, inverse=False):
        """Apply the module's rational-quadratic spline to `inputs`.

        The stored spline parameters are replicated over the batch with
        ``_share_across_batch``. If ``self.tails`` is None the bounded spline
        ``splines.rational_quadratic_spline`` is used; otherwise
        ``splines.unconstrained_rational_quadratic_spline`` is called with
        ``tails`` and ``tail_bound``.

        Returns:
            ``(outputs, logabsdet)`` where logabsdet is summed over the
            non-batch dimensions.
        """
        n = inputs.shape[0]
        widths = _share_across_batch(self.unnormalized_widths, n)
        heights = _share_across_batch(self.unnormalized_heights, n)
        derivatives = _share_across_batch(self.unnormalized_derivatives, n)

        if self.tails is not None:
            spline_fn = splines.unconstrained_rational_quadratic_spline
            extra = dict(tails=self.tails, tail_bound=self.tail_bound)
        else:
            spline_fn, extra = splines.rational_quadratic_spline, {}

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=widths,
            unnormalized_heights=heights,
            unnormalized_derivatives=derivatives,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            min_derivative=self.min_derivative,
            **extra)

        return outputs, various.sum_except_batch(logabsdet)
コード例 #3
0
    def _coupling_transform(self,
                            inputs,
                            transform_params,
                            inverse=False,
                            full_jacobian=False):
        """Reshape coupling parameters per element and run the piecewise CDF.

        For 4-D inputs (B, C, H, W) the parameters are rearranged from
        Bx(C*k)xHxW to BxCxHxWxk; for 2-D inputs (B, D) from Bx(D*k) to BxDxk.
        Other ranks are passed through unchanged.

        Returns:
            ``(outputs, jacobian)`` when ``full_jacobian`` is True, otherwise
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.
        """
        ndim = inputs.dim()
        if ndim == 4:
            batch, channels, height, width = inputs.shape
            # Image case: split the channel axis, then move the k parameters last.
            grouped = transform_params.reshape(batch, channels, -1, height, width)
            transform_params = grouped.permute(0, 1, 3, 4, 2)
        elif ndim == 2:
            batch, features = inputs.shape
            # Vector case: one row of k parameters per feature.
            transform_params = transform_params.reshape(batch, features, -1)

        if not full_jacobian:
            outputs, logabsdet = self._piecewise_cdf(inputs, transform_params,
                                                     inverse)
            return outputs, various.sum_except_batch(logabsdet)

        outputs, jacobian = self._piecewise_cdf(inputs,
                                                transform_params,
                                                inverse,
                                                full_jacobian=True)
        return outputs, jacobian
コード例 #4
0
    def inverse(self, inputs, context=None, full_jacobian=False):
        """Invert the piecewise LogTanh map elementwise.

        Inside ``[-inv_cut_point, inv_cut_point]`` the inverse is atanh; above
        it ``exp(y / alpha) / beta``; below it ``-exp(-y / alpha) / beta``.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        right = inputs > self.inv_cut_point
        left = inputs < -self.inv_cut_point
        middle = ~(right | left)

        y_mid = inputs[middle]
        y_right = inputs[right]
        y_left = inputs[left]

        outputs = torch.zeros_like(inputs)
        outputs[middle] = 0.5 * torch.log((1 + y_mid) / (1 - y_mid))
        outputs[right] = torch.exp(y_right / self.alpha) / self.beta
        outputs[left] = -torch.exp(-y_left / self.alpha) / self.beta

        # Derivative of exp(+-y/alpha)/beta is exp(+-y/alpha)/(alpha*beta).
        log_ab = np.log(self.alpha * self.beta)
        logabsdet = torch.zeros_like(inputs)
        logabsdet[middle] = -torch.log(1 - y_mid**2)
        logabsdet[right] = -log_ab + y_right / self.alpha
        logabsdet[left] = -log_ab - y_left / self.alpha

        return outputs, various.sum_except_batch(logabsdet, num_batch_dims=1)
コード例 #5
0
 def inverse(self, inputs, context=None):
     """Inverse of tanh (atanh) together with its log-abs-determinant.

     Raises:
         transforms.InputOutsideDomain: if any entry is <= -1 or >= 1.
     """
     lowest, highest = torch.min(inputs), torch.max(inputs)
     if lowest <= -1 or highest >= 1:
         raise transforms.InputOutsideDomain()
     # atanh(y) = 0.5 * log((1 + y) / (1 - y)), with derivative 1 / (1 - y^2).
     outputs = 0.5 * torch.log((1 + inputs) / (1 - inputs))
     per_element = -torch.log(1 - inputs ** 2)
     return outputs, various.sum_except_batch(per_element, num_batch_dims=1)
コード例 #6
0
    def inverse(self, inputs, context=None):
        """Cauchy quantile function: map values in [0, 1] back to the real line.

        Raises:
            transforms.InputOutsideDomain: if any entry lies outside [0, 1].
        """
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise transforms.InputOutsideDomain()

        outputs = torch.tan(np.pi * (inputs - 0.5))
        # The forward log-derivative is -log(pi) - log(1 + x^2); the inverse
        # log-det is its negation evaluated at x = outputs.
        forward_logdet = various.sum_except_batch(
            -np.log(np.pi) - torch.log(1 + outputs ** 2))
        return outputs, -forward_logdet
コード例 #7
0
    def forward(self, inputs, context=None, full_jacobian=False):
        """Push inputs through the standard Cauchy CDF.

        Returns:
            ``(atan(x)/pi + 0.5, logabsdet)`` where logabsdet sums the Cauchy
            log-density ``-log(pi) - log(1 + x^2)`` over non-batch dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        cdf = (1 / np.pi) * torch.atan(inputs) + 0.5
        log_density = -np.log(np.pi) - torch.log(1 + inputs**2)
        return cdf, various.sum_except_batch(log_density)
コード例 #8
0
    def forward(self, inputs, context=None, full_jacobian=False):
        """Apply tanh elementwise and return the summed log-derivative.

        The log-derivative uses the identity
        ``log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x))``.
        It stays finite for large |x|, where tanh(x) rounds to +-1 and the
        naive ``log(1 - tanh(x)**2)`` evaluates to log(0) = -inf.

        Returns:
            ``(tanh(inputs), logabsdet)`` with logabsdet summed over non-batch
            dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        import math

        if full_jacobian:
            raise NotImplementedError

        outputs = torch.tanh(inputs)
        logabsdet = 2.0 * (math.log(2.0) - inputs
                           - torch.nn.functional.softplus(-2.0 * inputs))
        logabsdet = various.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet
コード例 #9
0
ファイル: normal.py プロジェクト: skeeet/manifold-flow
 def _log_prob(self, inputs, context):
     """Standard-normal log-density summed over non-batch dimensions.

     Raises:
         ValueError: if the per-sample shape does not match ``self._shape``.
     """
     # Note: the context is ignored.
     got = inputs.shape[1:]
     if got != self._shape:
         raise ValueError("Expected input of shape {}, got {}".format(
             self._shape, got))
     squared_norm = various.sum_except_batch(inputs ** 2, num_batch_dims=1)
     return -0.5 * squared_norm - self._log_z
コード例 #10
0
    def inverse(self, inputs, context=None):
        """Inverse of the tempered sigmoid: logit(p) / temperature.

        Inputs are clamped to ``[eps, 1 - eps]`` before taking logs.

        Raises:
            transforms.InputOutsideDomain: if any entry lies outside [0, 1].
        """
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise transforms.InputOutsideDomain()

        # Keep away from exactly 0 and 1 so both logs stay finite.
        p = torch.clamp(inputs, self.eps, 1 - self.eps)
        outputs = (1 / self.temperature) * (torch.log(p) - torch.log1p(-p))

        # Negate the forward log-derivative evaluated at the recovered outputs.
        scaled = self.temperature * outputs
        forward_logdet = various.sum_except_batch(
            torch.log(self.temperature) - F.softplus(-scaled) - F.softplus(scaled))
        return outputs, -forward_logdet
コード例 #11
0
    def forward(self, inputs, context=None, full_jacobian=False):
        """Leaky-ReLU with its log-abs-determinant.

        Each negative input contributes ``self.log_negative_slope`` to the
        log-determinant; non-negative inputs contribute 0.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        outputs = F.leaky_relu(inputs, negative_slope=self.negative_slope)
        # Cast with the inputs' own dtype (device is preserved by `.to`):
        # `.type(torch.Tensor)` always produced a CPU default-float tensor,
        # which breaks CUDA and float64 inputs.
        mask = (inputs < 0).to(dtype=inputs.dtype)
        logabsdet = self.log_negative_slope * mask
        logabsdet = various.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet
コード例 #12
0
    def forward(self, inputs, context=None, full_jacobian=False):
        """Tempered sigmoid ``sigmoid(temperature * x)`` with its log-derivative.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        scaled = self.temperature * inputs
        # d/dx sigmoid(T x) = T * sigmoid(Tx) * sigmoid(-Tx); in log space this is
        # log T - softplus(-Tx) - softplus(Tx).
        squashed = torch.sigmoid(scaled)
        log_derivative = (torch.log(self.temperature) - F.softplus(-scaled) -
                          F.softplus(scaled))
        return squashed, various.sum_except_batch(log_derivative)
コード例 #13
0
ファイル: normal.py プロジェクト: skeeet/manifold-flow
 def _log_prob(self, inputs, context):
     """Zero-mean isotropic Gaussian log-density with std ``self.std``.

     When ``self._clip`` is set, inputs are clamped to ``[-clip, clip]`` first.

     Raises:
         ValueError: if the per-sample shape does not match ``self._shape``.
     """
     # Note: the context is ignored.
     if inputs.shape[1:] != self._shape:
         raise ValueError("Expected input of shape {}, got {}".format(
             self._shape, inputs.shape[1:]))
     clip = self._clip
     x = inputs if clip is None else torch.clamp(inputs, -clip, clip)
     sq_norm = various.sum_except_batch(x ** 2, num_batch_dims=1)
     return -0.5 * sq_norm / self.std ** 2 - self._log_z
コード例 #14
0
ファイル: normal.py プロジェクト: skeeet/manifold-flow
    def _log_prob(self, inputs, context):
        """Zero-mean diagonal Gaussian log-density with learned ``log_stds``.

        Inputs are clamped to ``[-self._clip, self._clip]`` when a clip value is
        configured. The normalizer is ``self._log_z_constant + sum(log_stds)``.

        Raises:
            ValueError: if the per-sample shape does not match ``self._shape``.
        """
        # Note: the context is ignored.
        if inputs.shape[1:] != self._shape:
            raise ValueError("Expected input of shape {}, got {}".format(
                self._shape, inputs.shape[1:]))
        # Skip clipping when no bound is configured: `-None` would raise a
        # TypeError, and torch.clamp needs at least one bound.
        if self._clip is not None:
            inputs = torch.clamp(inputs, -self._clip, self._clip)

        # unsqueeze(0) adds a leading batch axis so stds broadcast over samples.
        stds = torch.exp(self.log_stds).unsqueeze(0)
        neg_energy = -0.5 * various.sum_except_batch(inputs**2 / stds**2,
                                                     num_batch_dims=1)
        log_z = self._log_z_constant + torch.sum(self.log_stds)
        return neg_energy - log_z
コード例 #15
0
 def _coupling_transform_inverse(self,
                                 inputs,
                                 transform_params,
                                 full_jacobian=False):
     """Invert the affine coupling map ``y = scale * x + shift`` elementwise.

     Args:
         inputs: transformed values y.
         transform_params: conditioner output, split by
             ``self._scale_and_shift`` into ``scale`` and ``shift``.
         full_jacobian: if True, return a Jacobian instead of log|det|.

     Returns:
         ``(outputs, logabsdet)`` with ``logabsdet = -sum(log(scale))`` per
         batch element, or ``(outputs, jacobian)`` with
         ``jacobian = various.batch_diagonal(1 / scale)``.
     """
     scale, shift = self._scale_and_shift(transform_params)
     outputs = (inputs - shift) / scale
     if full_jacobian:
         # x = (y - shift) / scale, so dx/dy = 1 / scale elementwise (positive).
         # The previous `-batch_diagonal(scale)` applied the log-space sign rule
         # to the linear Jacobian. Only the elementwise diagonal is formed here.
         # NOTE(review): assumes batch_diagonal builds per-sample diagonal
         # matrices -- confirm.
         jacobian = various.batch_diagonal(1.0 / scale)
         return outputs, jacobian
     logabsdet = -various.sum_except_batch(torch.log(scale), num_batch_dims=1)
     return outputs, logabsdet
コード例 #16
0
ファイル: conv.py プロジェクト: ramyapriya/manifold-flow
    def _lu_forward_inverse(self, inputs, inverse=False):
        """Apply the parent's channel-wise linear map at every pixel.

        BxCxHxW inputs are flattened to (B*H*W)xC rows, passed to the parent's
        ``forward`` or ``inverse``, and reshaped back. The per-pixel log-dets
        are summed per batch element.
        """
        b, c, h, w = inputs.shape
        rows = inputs.permute(0, 2, 3, 1).reshape(b * h * w, c)

        apply_map = super().inverse if inverse else super().forward
        rows_out, pixel_logabsdet = apply_map(rows)

        outputs = rows_out.reshape(b, h, w, c).permute(0, 3, 1, 2)
        per_image = pixel_logabsdet.reshape(b, h, w)
        return outputs, various.sum_except_batch(per_image)
コード例 #17
0
    def _spline(self, inputs, inverse=False):
        """Apply the module's piecewise-linear spline to `inputs`.

        ``self.unnormalized_pdf`` is replicated across the batch. Without tails
        ``splines.linear_spline`` is used; otherwise
        ``splines.unconstrained_linear_spline`` with ``tails``/``tail_bound``.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.
        """
        pdf = _share_across_batch(self.unnormalized_pdf, inputs.shape[0])

        if self.tails is not None:
            outputs, logabsdet = splines.unconstrained_linear_spline(
                inputs=inputs,
                unnormalized_pdf=pdf,
                inverse=inverse,
                tails=self.tails,
                tail_bound=self.tail_bound,
            )
        else:
            outputs, logabsdet = splines.linear_spline(
                inputs=inputs, unnormalized_pdf=pdf, inverse=inverse)

        return outputs, various.sum_except_batch(logabsdet)
コード例 #18
0
ファイル: discrete.py プロジェクト: skeeet/manifold-flow
    def _log_prob(self, inputs, context):
        """Factorized Bernoulli log-likelihood with context-dependent logits.

        Raises:
            ValueError: if the per-sample shape does not match ``self._shape``.
        """
        if inputs.shape[1:] != self._shape:
            raise ValueError("Expected input of shape {}, got {}".format(
                self._shape, inputs.shape[1:]))

        logits = self._compute_params(context)
        assert logits.shape == inputs.shape

        # log sigmoid(l) = -softplus(-l); log(1 - sigmoid(l)) = -softplus(l).
        log_p_one = -F.softplus(-logits)
        log_p_zero = -F.softplus(logits)
        elementwise = inputs * log_p_one + (1.0 - inputs) * log_p_zero
        return various.sum_except_batch(elementwise, num_batch_dims=1)
コード例 #19
0
 def _elementwise_forward(self,
                          inputs,
                          autoregressive_params,
                          full_jacobian=False):
     """Affine autoregressive step ``outputs = scale * inputs + shift``.

     The scale is ``sigmoid(raw + 2) + 1e-3``, so it lies in (1e-3, 1 + 1e-3).

     Raises:
         NotImplementedError: if ``full_jacobian`` is requested.
     """
     raw_scale, shift = self._unconstrained_scale_and_shift(
         autoregressive_params)
     # The +2 offset puts the scale near sigmoid(2) ~= 0.88 when raw_scale ~= 0.
     scale = torch.sigmoid(raw_scale + 2.0) + 1e-3
     log_scale = torch.log(scale)
     outputs = scale * inputs + shift
     if full_jacobian:
         raise NotImplementedError
     return outputs, various.sum_except_batch(log_scale, num_batch_dims=1)
コード例 #20
0
    def _elementwise(self,
                     inputs,
                     autoregressive_params,
                     inverse=False,
                     full_jacobian=False):
        """Apply the rational-quadratic spline with autoregressive parameters.

        ``autoregressive_params`` is viewed as
        (batch, features, ``self._output_dim_multiplier()``) and split into
        ``num_bins`` widths, ``num_bins`` heights and the remaining derivatives.
        If the autoregressive net exposes ``hidden_features``, widths and
        heights are divided by its square root.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
            ValueError: if ``self.tails`` is neither None nor ``"linear"``.
        """
        if full_jacobian:
            raise NotImplementedError

        batch_size, features = inputs.shape[0], inputs.shape[1]

        transform_params = autoregressive_params.view(
            batch_size, features, self._output_dim_multiplier())

        num_bins = self.num_bins
        unnormalized_widths = transform_params[..., :num_bins]
        unnormalized_heights = transform_params[..., num_bins:2 * num_bins]
        unnormalized_derivatives = transform_params[..., 2 * num_bins:]

        if hasattr(self.autoregressive_net, "hidden_features"):
            # Divide out-of-place: the slices are views of
            # `autoregressive_params`, so `/=` would mutate the caller's tensor.
            norm = np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_widths = unnormalized_widths / norm
            unnormalized_heights = unnormalized_heights / norm

        if self.tails is None:
            spline_fn = splines.rational_quadratic_spline
            spline_kwargs = {}
        elif self.tails == "linear":
            spline_fn = splines.unconstrained_rational_quadratic_spline
            spline_kwargs = {
                "tails": self.tails,
                "tail_bound": self.tail_bound
            }
        else:
            raise ValueError(
                "Unsupported tails {!r}; expected None or 'linear'".format(
                    self.tails))

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            min_derivative=self.min_derivative,
            **spline_kwargs)

        return outputs, various.sum_except_batch(logabsdet)
コード例 #21
0
    def _lu_forward_inverse(self, inputs, inverse=False, full_jacobian=False):
        """Apply the parent's channel-wise linear map at every pixel.

        Flattens BxCxHxW to (B*H*W)xC, calls the parent's ``forward`` or
        ``inverse``, restores the layout and sums per-pixel log-dets per image.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        batch, channels, height, width = inputs.shape
        pixels = inputs.permute(0, 2, 3, 1).reshape(batch * height * width,
                                                    channels)

        if not inverse:
            mapped, pixel_logdet = super().forward(pixels)
        else:
            mapped, pixel_logdet = super().inverse(pixels)

        outputs = mapped.reshape(batch, height, width,
                                 channels).permute(0, 3, 1, 2)
        per_image = pixel_logdet.reshape(batch, height, width)
        return outputs, various.sum_except_batch(per_image)
コード例 #22
0
    def _elementwise(self,
                     inputs,
                     autoregressive_params,
                     inverse=False,
                     full_jacobian=False):
        """Apply a piecewise-linear spline parameterized autoregressively.

        ``autoregressive_params`` is viewed as
        (batch, ``self.features``, ``self._output_dim_multiplier()``) and used
        as the unnormalized pdf of ``splines.linear_spline``.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        pdf_shape = (inputs.shape[0], self.features,
                     self._output_dim_multiplier())
        unnormalized_pdf = autoregressive_params.view(*pdf_shape)

        if full_jacobian:
            raise NotImplementedError

        outputs, logabsdet = splines.linear_spline(
            inputs=inputs, unnormalized_pdf=unnormalized_pdf, inverse=inverse)
        return outputs, various.sum_except_batch(logabsdet)
コード例 #23
0
    def forward(self, inputs, context=None):
        """Piecewise LogTanh: tanh inside ``[-cut_point, cut_point]``.

        Above the cut point the map is ``alpha * log(beta * x)``; below it is
        ``-alpha * log(-beta * x)``.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.
        """
        above = inputs > self.cut_point
        below = inputs < -self.cut_point
        inside = ~(above | below)

        x_hi = inputs[above]
        x_lo = inputs[below]
        y_in = torch.tanh(inputs[inside])

        outputs = torch.zeros_like(inputs)
        outputs[inside] = y_in
        outputs[above] = self.alpha * torch.log(self.beta * x_hi)
        outputs[below] = self.alpha * -torch.log(-self.beta * x_lo)

        # d tanh = 1 - tanh^2; the log tails have derivative alpha / x.
        logabsdet = torch.zeros_like(inputs)
        logabsdet[inside] = torch.log(1 - y_in ** 2)
        logabsdet[above] = torch.log(self.alpha / x_hi)
        logabsdet[below] = torch.log(-self.alpha / x_lo)

        return outputs, various.sum_except_batch(logabsdet, num_batch_dims=1)
コード例 #24
0
    def _elementwise(self,
                     inputs,
                     autoregressive_params,
                     inverse=False,
                     full_jacobian=False):
        """Apply a cubic spline whose parameters come from the autoregressive net.

        ``autoregressive_params`` is viewed as
        (batch, ``self.features``, ``2 * num_bins + 2``) and split into
        ``num_bins`` widths, ``num_bins`` heights, and a left and a right
        boundary derivative. If the autoregressive net exposes
        ``hidden_features``, widths and heights are divided by its square root.

        Returns:
            ``(outputs, logabsdet)`` with logabsdet summed over non-batch dims.

        Raises:
            NotImplementedError: if ``full_jacobian`` is requested.
        """
        if full_jacobian:
            raise NotImplementedError

        batch_size = inputs.shape[0]
        num_bins = self.num_bins

        transform_params = autoregressive_params.view(batch_size,
                                                      self.features,
                                                      num_bins * 2 + 2)

        unnormalized_widths = transform_params[..., :num_bins]
        unnormalized_heights = transform_params[..., num_bins:2 * num_bins]
        derivatives = transform_params[..., 2 * num_bins:]
        # Keep a trailing singleton axis on each boundary derivative.
        unnorm_derivatives_left = derivatives[..., 0][..., None]
        unnorm_derivatives_right = derivatives[..., 1][..., None]

        if hasattr(self.autoregressive_net, "hidden_features"):
            # Divide out-of-place: the slices are views of
            # `autoregressive_params`, so `/=` would mutate the caller's tensor.
            norm = np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_widths = unnormalized_widths / norm
            unnormalized_heights = unnormalized_heights / norm

        outputs, logabsdet = splines.cubic_spline(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnorm_derivatives_left=unnorm_derivatives_left,
            unnorm_derivatives_right=unnorm_derivatives_right,
            inverse=inverse,
        )
        return outputs, various.sum_except_batch(logabsdet)
コード例 #25
0
 def forward(self, inputs, context=None):
     """Tempered sigmoid ``sigmoid(temperature * x)`` with summed log-derivative."""
     z = self.temperature * inputs
     squashed = torch.sigmoid(z)
     # log(T * sigmoid(z) * sigmoid(-z)) = log T - softplus(-z) - softplus(z).
     log_jac = torch.log(self.temperature) - F.softplus(-z) - F.softplus(z)
     return squashed, various.sum_except_batch(log_jac)
コード例 #26
0
 def forward(self, inputs, context=None):
     """Standard Cauchy CDF ``atan(x)/pi + 0.5`` with summed log-density."""
     cdf = (1 / np.pi) * torch.atan(inputs) + 0.5
     log_pdf = -np.log(np.pi) - torch.log(1 + inputs ** 2)
     return cdf, various.sum_except_batch(log_pdf)
コード例 #27
0
 def forward(self, inputs, context=None):
     """Apply tanh elementwise and return the summed log-derivative.

     Uses ``log(1 - tanh(x)^2) = 2 * (log 2 - x - softplus(-2x))``, which stays
     finite for large |x|. There tanh(x) rounds to +-1, and the naive
     ``log(1 - tanh(x)**2)`` becomes log(0) = -inf.
     """
     import math

     outputs = torch.tanh(inputs)
     logabsdet = 2.0 * (math.log(2.0) - inputs
                        - torch.nn.functional.softplus(-2.0 * inputs))
     logabsdet = various.sum_except_batch(logabsdet, num_batch_dims=1)
     return outputs, logabsdet
コード例 #28
0
 def forward(self, inputs, context=None):
     """Leaky-ReLU with its log-abs-determinant.

     Each negative input adds ``self.log_negative_slope`` to the log-det;
     logabsdet is summed over non-batch dims.
     """
     outputs = F.leaky_relu(inputs, negative_slope=self.negative_slope)
     # Cast with the inputs' own dtype (device is preserved by `.to`):
     # `.type(torch.Tensor)` always produced a CPU default-float tensor,
     # which breaks CUDA and float64 inputs.
     mask = (inputs < 0).to(dtype=inputs.dtype)
     logabsdet = self.log_negative_slope * mask
     logabsdet = various.sum_except_batch(logabsdet, num_batch_dims=1)
     return outputs, logabsdet
コード例 #29
0
 def log_prob(self, value, context):
     """Return the parent's ``log_prob(value)`` summed over non-batch dims.

     ``context`` is accepted for interface compatibility but is not used.
     """
     per_element = super().log_prob(value)
     return various.sum_except_batch(per_element)