Example #1
    def inverse(self, inputs, context=None):
        # Inverse of tanh: atanh(y) = 0.5 * log((1 + y) / (1 - y)),
        # defined only on the open interval (-1, 1).
        if torch.min(inputs) <= -1 or torch.max(inputs) >= 1:
            raise transforms.InputOutsideDomain()
        outputs = 0.5 * torch.log((1 + inputs) / (1 - inputs))
        # d/dy atanh(y) = 1 / (1 - y^2), so log |det J| = -log(1 - y^2).
        logabsdet = -torch.log(1 - inputs ** 2)
        logabsdet = utils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet
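The log-determinant can be sanity-checked against autograd. A minimal sketch in plain PyTorch, independent of the library's `transforms` and `utils` helpers:

import torch

y = torch.linspace(-0.9, 0.9, 5, requires_grad=True)
x = 0.5 * torch.log((1 + y) / (1 - y))           # atanh(y)
analytic = -torch.log(1 - y ** 2)                # claimed log |dx/dy|

# Autograd derivative of the elementwise map.
autograd = torch.log(torch.autograd.grad(x.sum(), y)[0].abs())
print(torch.allclose(analytic, autograd))        # True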
Example #2
    def _spline(self, inputs, inverse=False):
        batch_size = inputs.shape[0]

        unnormalized_widths = _share_across_batch(self.unnormalized_widths, batch_size)
        unnormalized_heights = _share_across_batch(self.unnormalized_heights, batch_size)
        unnormalized_derivatives = _share_across_batch(self.unnormalized_derivatives, batch_size)

        if self.tails is None:
            spline_fn = splines.rational_quadratic_spline
            spline_kwargs = {}
        else:
            spline_fn = splines.unconstrained_rational_quadratic_spline
            spline_kwargs = {
                'tails': self.tails,
                'tail_bound': self.tail_bound
            }

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            min_derivative=self.min_derivative,
            **spline_kwargs
        )

        return outputs, utils.sum_except_batch(logabsdet)
Example #3
    def _elementwise_inverse(self, inputs, autoregressive_params):
        unconstrained_scale, shift = self._unconstrained_scale_and_shift(autoregressive_params)
        # Sigmoid keeps the scale in (0, 1); the constant bounds it away from zero.
        scale = torch.sigmoid(unconstrained_scale + 2.) + 1e-3
        log_scale = torch.log(scale)
        outputs = (inputs - shift) / scale
        logabsdet = -utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet
Example #4
    def _log_prob(self, inputs, context):
        # Note: the context is ignored.
        if inputs.shape[1:] != self._shape:
            raise ValueError('Expected input of shape {}, got {}'.format(
                self._shape, inputs.shape[1:]))
        neg_energy = -0.5 * utils.sum_except_batch(inputs**2, num_batch_dims=1)
        return neg_energy - self._log_z
Example #5
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        batch_size = inputs.shape[0]

        transform_params = autoregressive_params.view(batch_size,
                                                      self.features,
                                                      self.num_bins * 2 + 2)

        unnormalized_widths = transform_params[..., :self.num_bins]
        unnormalized_heights = transform_params[..., self.num_bins:2*self.num_bins]
        derivatives = transform_params[..., 2*self.num_bins:]
        unnorm_derivatives_left = derivatives[..., 0][..., None]
        unnorm_derivatives_right = derivatives[..., 1][..., None]

        if hasattr(self.autoregressive_net, 'hidden_features'):
            unnormalized_widths /= np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_heights /= np.sqrt(self.autoregressive_net.hidden_features)

        outputs, logabsdet = splines.cubic_spline(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnorm_derivatives_left=unnorm_derivatives_left,
            unnorm_derivatives_right=unnorm_derivatives_right,
            inverse=inverse
        )
        return outputs, utils.sum_except_batch(logabsdet)
Example #6
    def inverse(self, inputs, context=None):
        outputs = F.leaky_relu(inputs,
                               negative_slope=(1 / self.negative_slope))
        # Only the negative half-line was scaled in the forward pass, so only
        # those entries contribute -log(negative_slope) to the log-det.
        mask = (inputs < 0).type(torch.Tensor)
        logabsdet = -self.log_negative_slope * mask
        logabsdet = utils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet
Example #7
    def _log_prob(self, inputs, context):
        if inputs.shape[1:] != self._shape:
            raise ValueError('Expected input of shape {}, got {}'.format(
                self._shape, inputs.shape[1:]))

        # Compute parameters.
        means, log_stds = self._compute_params(context)
        assert means.shape == inputs.shape and log_stds.shape == inputs.shape

        # Compute log prob.
        norm_inputs = (inputs - means) * torch.exp(-log_stds)
        log_prob = -0.5 * utils.sum_except_batch(norm_inputs**2,
                                                 num_batch_dims=1)
        log_prob -= utils.sum_except_batch(log_stds, num_batch_dims=1)
        log_prob -= self._log_z
        return log_prob
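Assuming `self._log_z` holds the Gaussian normalizing constant `0.5 * D * log(2 * pi)` for `D` event dimensions (a property of the surrounding class, not shown here), this matches `torch.distributions.Normal`. A quick check:

import math
import torch

means, log_stds = torch.randn(4, 3), torch.randn(4, 3)
x = torch.randn(4, 3)
log_z = 0.5 * 3 * math.log(2 * math.pi)          # assumed value of self._log_z

norm_inputs = (x - means) * torch.exp(-log_stds)
manual = -0.5 * (norm_inputs ** 2).sum(-1) - log_stds.sum(-1) - log_z

reference = torch.distributions.Normal(means, log_stds.exp()).log_prob(x).sum(-1)
print(torch.allclose(manual, reference))         # True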
Example #8
    def forward(self, inputs, context=None):
        inputs = self.temperature * inputs
        outputs = torch.sigmoid(inputs)
        # log sigmoid'(z) = -softplus(-z) - softplus(z); the chain rule
        # contributes an extra log(temperature).
        logabsdet = utils.sum_except_batch(
            torch.log(self.temperature) - F.softplus(-inputs) - F.softplus(inputs)
        )
        return outputs, logabsdet
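The softplus identity used for the log-derivative is easy to verify numerically; a minimal sketch in plain PyTorch:

import torch
import torch.nn.functional as F

z = torch.randn(1000)
direct = torch.log(torch.sigmoid(z) * (1 - torch.sigmoid(z)))   # log sigmoid'(z)
stable = -F.softplus(-z) - F.softplus(z)                        # stable rewrite
print(torch.allclose(direct, stable, atol=1e-6))                # True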
Example #9
    def inverse(self, inputs, context=None):
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise transforms.InputOutsideDomain()

        outputs = torch.tan(np.pi * (inputs - 0.5))
        logabsdet = -utils.sum_except_batch(-np.log(np.pi) -
                                            torch.log(1 + outputs**2))
        return outputs, logabsdet
Example #10
    def _elementwise_forward(self, inputs, autoregressive_params):
        unconstrained_scale, shift = self._unconstrained_scale_and_shift(
            autoregressive_params)
        # scale = torch.sigmoid(unconstrained_scale + 2.0) + self._epsilon
        scale = F.softplus(unconstrained_scale) + self._epsilon
        log_scale = torch.log(scale)
        outputs = scale * inputs + shift
        logabsdet = utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet
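Both parameterizations keep the scale strictly positive: softplus grows linearly for large inputs, while the sigmoid variant saturates near 1. A quick illustration (the epsilon value is an assumption):

import torch
import torch.nn.functional as F

u = torch.tensor([-5.0, 0.0, 5.0])
eps = 1e-3                                   # assumed value of self._epsilon
print(F.softplus(u) + eps)                   # tensor([0.0077, 0.6941, 5.0077])
print(torch.sigmoid(u + 2.0) + eps)          # tensor([0.0484, 0.8818, 1.0001])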
Example #11
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        batch_size = inputs.shape[0]

        unnormalized_pdf = autoregressive_params.view(
            batch_size, self.features, self._output_dim_multiplier())

        outputs, logabsdet = splines.linear_spline(
            inputs=inputs, unnormalized_pdf=unnormalized_pdf, inverse=inverse)

        return outputs, utils.sum_except_batch(logabsdet)
Example #12
    def inverse(self, inputs, context=None):
        if torch.min(inputs) < 0 or torch.max(inputs) > 1:
            raise transforms.InputOutsideDomain()

        inputs = torch.clamp(inputs, self.eps, 1 - self.eps)

        outputs = (1 / self.temperature) * (torch.log(inputs) - torch.log1p(-inputs))
        logabsdet = - utils.sum_except_batch(
            torch.log(self.temperature) - F.softplus(
                -self.temperature * outputs) - F.softplus(self.temperature * outputs)
        )
        return outputs, logabsdet
Example #13
    def _lu_forward_inverse(self, inputs, inverse=False):
        b, c, h, w = inputs.shape
        inputs = inputs.permute(0, 2, 3, 1).reshape(b * h * w, c)

        if inverse:
            outputs, logabsdet = super().inverse(inputs)
        else:
            outputs, logabsdet = super().forward(inputs)

        outputs = outputs.reshape(b, h, w, c).permute(0, 3, 1, 2)
        logabsdet = logabsdet.reshape(b, h, w)

        return outputs, utils.sum_except_batch(logabsdet)
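The permute/reshape pair flattens the image batch so the underlying linear transform can act per pixel across channels; reshaping back restores the layout exactly. A standalone sketch of just that round trip:

import torch

b, c, h, w = 2, 3, 4, 5
x = torch.randn(b, c, h, w)

flat = x.permute(0, 2, 3, 1).reshape(b * h * w, c)   # one row per pixel
back = flat.reshape(b, h, w, c).permute(0, 3, 1, 2)  # undo the flattening
print(torch.equal(x, back))                          # True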
Example #14
    def _coupling_transform(self, inputs, transform_params, inverse=False):
        if inputs.dim() == 4:
            b, c, h, w = inputs.shape
            # For images, reshape transform_params from Bx(C*?)xHxW to BxCxHxWx?
            transform_params = transform_params.reshape(b, c, -1, h, w).permute(0, 1, 3, 4, 2)
        elif inputs.dim() == 2:
            b, d = inputs.shape
            # For 2D data, reshape transform_params from Bx(D*?) to BxDx?
            transform_params = transform_params.reshape(b, d, -1)

        outputs, logabsdet = self._piecewise_cdf(inputs, transform_params, inverse)

        return outputs, utils.sum_except_batch(logabsdet)
Example #15
    def _log_prob(self, inputs, context):
        if inputs.shape[1:] != self._shape:
            raise ValueError('Expected input of shape {}, got {}'.format(
                self._shape, inputs.shape[1:]))

        # Compute parameters.
        logits = self._compute_params(context)
        assert logits.shape == inputs.shape

        # Compute log prob.
        log_prob = -inputs * F.softplus(-logits) - (
            1.0 - inputs) * F.softplus(logits)
        log_prob = utils.sum_except_batch(log_prob, num_batch_dims=1)
        return log_prob
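This is the numerically stable form of the Bernoulli log-likelihood x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l)), rewritten with softplus. It can be checked against `torch.distributions`:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
x = torch.bernoulli(torch.sigmoid(logits))

manual = (-x * F.softplus(-logits) - (1.0 - x) * F.softplus(logits)).sum(-1)
reference = torch.distributions.Bernoulli(logits=logits).log_prob(x).sum(-1)
print(torch.allclose(manual, reference))     # True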
Example #16
    def forward(self, inputs, context=None):
        mask_right = (inputs > self.cut_point)
        mask_left = (inputs < -self.cut_point)
        mask_middle = ~(mask_right | mask_left)

        outputs = torch.zeros_like(inputs)
        outputs[mask_middle] = torch.tanh(inputs[mask_middle])
        outputs[mask_right] = self.alpha * torch.log(self.beta * inputs[mask_right])
        outputs[mask_left] = self.alpha * -torch.log(-self.beta * inputs[mask_left])

        logabsdet = torch.zeros_like(inputs)
        logabsdet[mask_middle] = torch.log(1 - outputs[mask_middle] ** 2)
        logabsdet[mask_right] = torch.log(self.alpha / inputs[mask_right])
        logabsdet[mask_left] = torch.log(-self.alpha / inputs[mask_left])
        logabsdet = utils.sum_except_batch(logabsdet, num_batch_dims=1)

        return outputs, logabsdet
Example #17
    def inverse(self, inputs, context=None):

        mask_right = (inputs > self.inv_cut_point)
        mask_left = (inputs < -self.inv_cut_point)
        mask_middle = ~(mask_right | mask_left)

        outputs = torch.zeros_like(inputs)
        outputs[mask_middle] = 0.5 * torch.log((1 + inputs[mask_middle])
                                               / (1 - inputs[mask_middle]))
        outputs[mask_right] = torch.exp(inputs[mask_right] / self.alpha) / self.beta
        outputs[mask_left] = -torch.exp(-inputs[mask_left] / self.alpha) / self.beta

        logabsdet = torch.zeros_like(inputs)
        logabsdet[mask_middle] = -torch.log(1 - inputs[mask_middle] ** 2)
        logabsdet[mask_right] = -np.log(self.alpha * self.beta) + inputs[mask_right] / self.alpha
        logabsdet[mask_left] = -np.log(self.alpha * self.beta) - inputs[mask_left] / self.alpha
        logabsdet = utils.sum_except_batch(logabsdet, num_batch_dims=1)

        return outputs, logabsdet
Example #18
    def _elementwise(self, inputs, autoregressive_params, inverse=False):
        batch_size, features = inputs.shape[0], inputs.shape[1]

        transform_params = autoregressive_params.view(
            batch_size,
            features,
            self._output_dim_multiplier()
        )

        unnormalized_widths = transform_params[..., :self.num_bins]
        unnormalized_heights = transform_params[..., self.num_bins:2*self.num_bins]
        unnormalized_derivatives = transform_params[..., 2*self.num_bins:]

        if hasattr(self.autoregressive_net, 'hidden_features'):
            unnormalized_widths /= np.sqrt(self.autoregressive_net.hidden_features)
            unnormalized_heights /= np.sqrt(self.autoregressive_net.hidden_features)

        if self.tails is None:
            spline_fn = splines.rational_quadratic_spline
            spline_kwargs = {}
        elif self.tails == 'linear':
            spline_fn = splines.unconstrained_rational_quadratic_spline
            spline_kwargs = {
                'tails': self.tails,
                'tail_bound': self.tail_bound
            }
        else:
            raise ValueError

        outputs, logabsdet = spline_fn(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            min_bin_width=self.min_bin_width,
            min_bin_height=self.min_bin_height,
            min_derivative=self.min_derivative,
            **spline_kwargs
        )

        return outputs, utils.sum_except_batch(logabsdet)
Example #19
    def _spline(self, inputs, inverse=False):
        batch_size = inputs.shape[0]

        unnormalized_pdf = _share_across_batch(self.unnormalized_pdf,
                                               batch_size)

        if self.tails is None:
            outputs, logabsdet = splines.linear_spline(
                inputs=inputs,
                unnormalized_pdf=unnormalized_pdf,
                inverse=inverse)
        else:
            outputs, logabsdet = splines.unconstrained_linear_spline(
                inputs=inputs,
                unnormalized_pdf=unnormalized_pdf,
                inverse=inverse,
                tails=self.tails,
                tail_bound=self.tail_bound)

        return outputs, utils.sum_except_batch(logabsdet)
Example #20
    def forward(self, x, context=None):
        x2_size = self.input_dim - self.split_dim

        x1, x2 = x.split([self.split_dim, x2_size], dim=self.event_dim)
        nn_input = torch.cat(
            (x1, context), dim=self.event_dim) if self.context_dim != 0 else x1

        nn_out = torch.utils.checkpoint.checkpoint(
            self.nn, nn_input, preserve_rng_state=False)
        unnormalized_widths, unnormalized_heights, unnormalized_derivatives = nn_out.reshape(
            nn_input.shape[:2] + (-1, self._output_dim_multiplier())
        ).split([self.num_bins, self.num_bins, self.num_bins + 1], dim=self.event_dim)

        # `inverse` is not passed to the spline; it defaults to False.
        y2, ldj = torch.utils.checkpoint.checkpoint(unconstrained_rational_quadratic_spline,
                                                    x2,
                                                    unnormalized_widths,
                                                    unnormalized_heights,
                                                    unnormalized_derivatives,
                                                    preserve_rng_state=False)

        ldj = sum_except_batch(ldj, num_dims=2)

        y1 = x1
        return torch.cat([y1, y2], dim=self.event_dim), ldj
Example #21
    def log_prob(self, x, context):
        dist = self.cond_dist(context)
        return sum_except_batch(dist.log_prob(x), num_dims=2)
Example #22
    def _coupling_transform_inverse(self, inputs, transform_params):
        scale, shift = self._scale_and_shift(transform_params)
        log_scale = torch.log(scale)
        outputs = (inputs - shift) / scale
        logabsdet = -utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet
Example #23
    def _coupling_transform_forward(self, inputs, transform_params):
        scale, shift = self._scale_and_shift(transform_params)
        log_scale = torch.log(scale)
        outputs = inputs * scale + shift
        logabsdet = utils.sum_except_batch(log_scale, num_batch_dims=1)
        return outputs, logabsdet
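The forward and inverse affine coupling maps from Examples #22 and #23 invert each other, and their log-determinants cancel. A standalone round-trip check with stand-ins for `_scale_and_shift`:

import torch

inputs = torch.randn(4, 3)
scale = torch.rand(4, 3) + 0.5               # stand-in: any positive scale
shift = torch.randn(4, 3)

forward = inputs * scale + shift
recovered = (forward - shift) / scale
print(torch.allclose(inputs, recovered))     # True

logabsdet_fwd = torch.log(scale).sum(-1)     # as in Example #23
logabsdet_inv = -torch.log(scale).sum(-1)    # as in Example #22
print(torch.allclose(logabsdet_fwd + logabsdet_inv, torch.zeros(4)))  # True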
Example #24
def evaluate_test_samples(args, simulator, filename, model=None, ood=False, n_save_reco=100):
    """ Likelihood evaluation """

    logger.info(
        "Evaluating %s samples according to %s, %s likelihood evaluation, saving in %s",
        "the ground truth" if model is None else "a trained model",
        "ood" if ood else "test",
        "with" if not args.skiplikelihood else "without",
        filename,
    )

    # Prepare
    x, _ = simulator.load_dataset(
        train=False, numpy=True, ood=ood, dataset_dir=create_filename("dataset", None, args), true_param_id=args.trueparam, joint_score=False, limit_samplesize=args.evaluate,
    )
    parameter_grid = [None] if simulator.parameter_dim() is None else simulator.eval_parameter_grid(resolution=args.gridresolution)

    log_probs = []
    x_recos = None
    reco_error = None

    # Evaluate
    for i, params in enumerate(parameter_grid):
        logger.debug("Evaluating grid point %s / %s", i + 1, len(parameter_grid))
        if model is None:
            params_ = None if params is None else np.asarray([params for _ in x])
            log_prob = simulator.log_density(x, parameters=params_)

        else:
            log_prob = []
            reco_error_ = []
            x_recos_ = []
            n_batches = (args.evaluate - 1) // args.evalbatchsize + 1
            for j in range(n_batches):
                x_ = torch.tensor(x[j * args.evalbatchsize : (j + 1) * args.evalbatchsize], dtype=torch.float)
                if params is None:
                    params_ = None
                else:
                    params_ = np.asarray([params for _ in x_])
                    params_ = torch.tensor(params_, dtype=torch.float)

                if args.algorithm == "flow":
                    x_reco, log_prob_, _ = model(x_, context=params_)
                elif args.algorithm in ["pie", "slice"]:
                    x_reco, log_prob_, _ = model(x_, context=params_, mode=args.algorithm if not args.skiplikelihood else "projection")
                else:
                    x_reco, log_prob_, _ = model(x_, context=params_, mode="mf" if not args.skiplikelihood else "projection")

                if not args.skiplikelihood:
                    log_prob.append(log_prob_.detach().numpy())
                reco_error_.append((sum_except_batch((x_ - x_reco) ** 2) ** 0.5).detach().numpy())
                x_recos_.append(x_reco.detach().numpy())

            if not args.skiplikelihood:
                log_prob = np.concatenate(log_prob, axis=0)
            if reco_error is None:
                reco_error = np.concatenate(reco_error_, axis=0)
            if x_recos is None:
                x_recos = np.concatenate(x_recos_, axis=0)

        if not args.skiplikelihood:
            log_probs.append(log_prob)

    # Save results
    if len(log_probs) > 0:
        if simulator.parameter_dim() is None:
            log_probs = log_probs[0]

        np.save(create_filename("results", filename.format("log_likelihood"), args), log_probs)

    # x_recos stays None when evaluating the ground truth (model is None).
    if x_recos is not None and len(x_recos) > 0:
        np.save(create_filename("results", filename.format("x_reco"), args), x_recos[:n_save_reco])

    if reco_error is not None:
        np.save(create_filename("results", filename.format("reco_error"), args), reco_error)

    if parameter_grid is not None:
        np.save(create_filename("results", "parameter_grid_test", args), parameter_grid)
Example #25
    def log_prob(self, value, context):
        # Note: the context is ignored.
        return utils.sum_except_batch(super().log_prob(value))
Example #26
    def log_prob(self, x, context=None):
        log_base = -0.5 * math.log(2 * math.pi)
        log_inner = -0.5 * x ** 2
        return sum_except_batch(log_base + log_inner, num_dims=2)
Example #27
    def sample_with_log_prob(self, context):
        dist = self.cond_dist(context)
        z = dist.rsample()
        log_prob = dist.log_prob(z)
        log_prob = sum_except_batch(log_prob, num_dims=2)
        return z, log_prob
Example #28
    def forward(self, inputs, context=None):
        # CDF of the standard Cauchy distribution, mapping R onto (0, 1).
        outputs = (1 / np.pi) * torch.atan(inputs) + 0.5
        logabsdet = utils.sum_except_batch(
            -np.log(np.pi) - torch.log(1 + inputs ** 2)
        )
        return outputs, logabsdet
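As in Example #1, the log-determinant (here the log-density of a standard Cauchy) can be checked against autograd:

import numpy as np
import torch

x = torch.randn(5, requires_grad=True)
y = (1 / np.pi) * torch.atan(x) + 0.5
analytic = -np.log(np.pi) - torch.log(1 + x ** 2)

autograd = torch.log(torch.autograd.grad(y.sum(), x)[0])
print(torch.allclose(analytic, autograd))    # True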
Example #29
    def forward(self, inputs, context=None):
        outputs = torch.tanh(inputs)
        # d/dx tanh(x) = 1 - tanh(x)^2
        logabsdet = torch.log(1 - outputs ** 2)
        logabsdet = utils.sum_except_batch(logabsdet, num_batch_dims=1)
        return outputs, logabsdet