Exemple #1
0
def sample_from_mix_gaussian_1d(l, nr_mix):
    """Draw one sample per spatial location from a 1-D Gaussian mixture.

    A mixture component is picked at every location with the Gumbel-max
    trick applied to the mixture logits; the value is then drawn from the
    selected Gaussian via inverse-CDF sampling.

    Parameters
    ----------
    l
        4-D tensor whose axis 1 holds the mixture parameters in this order:
        ``nr_mix`` mixture logits, ``nr_mix`` means, ``nr_mix`` log-scales
        (axis 1 must therefore have ``3 * nr_mix`` entries for the reshape
        below to succeed). Presumably laid out as
        (batch, channel, height, width) -- confirm against the caller.
    nr_mix
        Number of mixture components.

    Returns
    -------
    Tensor with a singleton axis inserted at position 1, i.e. shape
    ``(l.shape[0], 1, l.shape[2], l.shape[3])``, clamped to [-1, 1].
    """
    # Move the parameter axis last so the parameters can be sliced per pixel.
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [1]

    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :,
          nr_mix:].contiguous().view(xs + [nr_mix * 2])  # for mean, scale

    # Sample the mixture indicator from the softmax with the Gumbel-max
    # trick. The noise tensor inherits device and dtype from the logits.
    noise = torch.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5)
    perturbed = logit_probs.detach() - torch.log(-torch.log(noise))
    _, argmax = perturbed.max(dim=3)

    # Select the parameters of the chosen component. Gathering along the
    # component axis is equivalent to multiplying by a one-hot mask and
    # summing over that axis.
    index = argmax.view(xs[:-1] + [1, 1])
    means = l[:, :, :, :, :nr_mix].gather(4, index).squeeze(4)
    log_scales = torch.clamp(
        l[:, :, :, :, nr_mix:2 * nr_mix].gather(4, index).squeeze(4),
        min=-7.)

    u = torch.empty_like(means).uniform_(1e-5, 1. - 1e-5)
    # The network emits *log* scales; Normal expects a strictly positive
    # standard deviation, so exponentiate before building the distribution.
    distribution = Normal(loc=means, scale=torch.exp(log_scales))
    x = distribution.icdf(u)
    x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
    out = x0.unsqueeze(1)
    return out
Exemple #2
0
    def pred_dist_quantile(quantiles: list, pred_params: pd.DataFrame):
        """
        Function that calculates the quantiles from the predicted response distribution.

        quantiles: list
            Which quantiles to calculate
        pred_params: pd.DataFrame
            Dataframe with predicted distributional parameters. The columns
            "location" and "scale" are read.

        Returns
        -------
        pd.DataFrame with calculated quantiles: one row per row of
        ``pred_params`` and one column per requested quantile, both labelled
        positionally (0, 1, ...).

        """
        # Convert through NumPy first: building a tensor straight from a
        # pandas Series looks elements up by index label and can fail when
        # the frame does not carry a default 0..n-1 index.
        q_gaussian = Normal(
            loc=torch.tensor(pred_params["location"].to_numpy()),
            scale=torch.tensor(pred_params["scale"].to_numpy()))

        pred_quantiles_list = [
            q_gaussian.icdf(torch.tensor(q)).detach().numpy()
            for q in quantiles
        ]

        # Each list entry is one quantile across all rows; transpose so that
        # rows correspond to observations.
        pred_quantiles = pd.DataFrame(pred_quantiles_list).T
        return pred_quantiles
Exemple #3
0
def sample_truncated_normal_perturbations(
    X: Tensor,
    n_discrete_points: int,
    sigma: float,
    bounds: Tensor,
    qmc: bool = True,
) -> Tensor:
    r"""Draw perturbed copies of the points in `X`.

    The points are first normalized with respect to `bounds`. Gaussian noise
    with standard deviation `sigma` is then added per coordinate, with the
    noise truncated so that the perturbed point stays within [0,1]^d, and
    the result is mapped back to the original scale.

    Args:
        X: A `n x d`-dim tensor of starting points.
        n_discrete_points: The number of points to sample.
        sigma: The standard deviation of the additive gaussian noise for
            perturbing the points.
        bounds: A `2 x d`-dim tensor containing the bounds.
        qmc: If True, the base uniform samples come from a Sobol sequence;
            otherwise they come from `torch.rand`.

    Returns:
        A `n_discrete_points x d`-dim tensor containing the sampled points.
    """
    X = normalize(X, bounds=bounds)
    num_dims = X.shape[1]
    # With several starting points, pick (with replacement) which one each
    # sample is centred on; a single starting point simply broadcasts.
    if X.shape[0] > 1:
        picked = torch.randint(X.shape[0], (n_discrete_points, ),
                               device=X.device)
        X = X[picked]
    # Base uniform samples in the unit cube.
    if not qmc:
        u = torch.rand((n_discrete_points, num_dims),
                       dtype=X.dtype,
                       device=X.device)
    else:
        unit_cube = torch.zeros(2, num_dims, dtype=X.dtype, device=X.device)
        unit_cube[1] = 1
        sobol = draw_sobol_samples(bounds=unit_cube,
                                   n=n_discrete_points,
                                   q=1)
        u = sobol.squeeze(1)
    # Admissible perturbation interval [-X, 1 - X], expressed as z-scores.
    z_lower = (-X) / sigma
    z_upper = (1 - X) / sigma
    std_normal = Normal(0, 1)
    cdf_lower = std_normal.cdf(z_lower)
    # Inverse-transform sampling restricted to the admissible CDF range.
    cdf_width = std_normal.cdf(z_upper) - cdf_lower
    perturbation = std_normal.icdf(cdf_lower + u * cdf_width) * sigma
    # Clip whatever still falls outside the unit cube, then undo the
    # normalization.
    shifted = (X + perturbation).clamp(0.0, 1.0)
    return unnormalize(shifted, bounds=bounds)
Exemple #4
0
class MQF2Distribution(Distribution):
    r"""
    Distribution class for the model MQF2 proposed in the paper
    ``Multivariate Quantile Function Forecaster``
    by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus
    Parameters
    ----------
    picnn
        A SequentialNet instance of a
        partially input convex neural network (picnn)
    hidden_state
        hidden_state obtained by unrolling the RNN encoder
        shape = (batch_size, context_length, hidden_size) in training
        shape = (batch_size, hidden_size) in inference
    prediction_length
        Length of the prediction horizon
    is_energy_score
        If True, use energy score as objective function
        otherwise use maximum likelihood as
        objective function (normalizing flows)
    es_num_samples
        Number of samples drawn to approximate the energy score
    beta
        Hyperparameter of the energy score (power of the two terms)
    threshold_input
        Clamping threshold of the (scaled) input when maximum
        likelihood is used as objective function
        this is used to make the forecaster more robust
        to outliers in training samples
    validate_args
        Sets whether validation is enabled or disabled
        For more details, refer to the descriptions in
        torch.distributions.distribution.Distribution
    """
    def __init__(
        self,
        picnn: torch.nn.Module,
        hidden_state: torch.Tensor,
        prediction_length: int,
        is_energy_score: bool = True,
        es_num_samples: int = 50,
        beta: float = 1.0,
        threshold_input: float = 100.0,
        validate_args: bool = False,
    ) -> None:

        self.picnn = picnn
        self.hidden_state = hidden_state
        self.prediction_length = prediction_length
        self.is_energy_score = is_energy_score
        self.es_num_samples = es_num_samples
        self.beta = beta
        self.threshold_input = threshold_input

        # batch_shape is a property derived from hidden_state, so
        # hidden_state must be assigned before this call.
        super().__init__(batch_shape=self.batch_shape,
                         validate_args=validate_args)

        self.context_length = self.hidden_state.shape[-2] if len(
            self.hidden_state.shape) > 2 else 1
        self.numel_batch = self.get_numel(self.batch_shape)

        # mean zero and std one
        mu = torch.tensor(0,
                          dtype=hidden_state.dtype,
                          device=hidden_state.device)
        sigma = torch.ones_like(mu)
        self.standard_normal = Normal(mu, sigma)

    def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
        """
        Auxiliary function for loss computation
        Unfolds the observations by sliding a window of size prediction_length
        over the observations z
        Then, reshapes the observations into a 2-dimensional tensor for
        further computation
        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)
        Returns
        -------
        Tensor
            Unfolded time series with shape
            (batch_size * context_length, prediction_length)
        """

        z = z.unfold(dimension=-1, size=self.prediction_length, step=1)
        z = z.reshape(-1, z.shape[-1])

        return z

    def loss(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the training objective for the observations z:
        the energy score if is_energy_score is True,
        otherwise the negative of log_prob
        """
        if self.is_energy_score:
            return self.energy_score(z)
        else:
            return -self.log_prob(z)

    def log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the log likelihood  log(g(z)) + logdet(dg(z)/dz),
        where g is the gradient of the picnn
        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)
        Returns
        -------
        loss
            Tensor of shape (batch_size * context_length,)
        """

        # Clamp before windowing to limit the influence of outliers.
        z = torch.clamp(z, min=-self.threshold_input, max=self.threshold_input)
        z = self.stack_sliding_view(z)

        loss = self.picnn.logp(
            z, self.hidden_state.reshape(-1, self.hidden_state.shape[-1]))

        return loss

    def energy_score(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the (approximated) energy score sum_i ES(g,z_i),
        where ES(g,z_i) =
        -1/(2*es_num_samples^2) * sum_{w,w'} ||w-w'||_2^beta
        + 1/es_num_samples * sum_{w''} ||w''-z_i||_2^beta,
        w's are samples drawn from the
        quantile function g(., h_i) (gradient of picnn),
        h_i is the hidden state associated with z_i,
        and es_num_samples is the number of samples drawn
        for each of w, w', w'' in energy score approximation
        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)
        Returns
        -------
        loss
            Tensor of shape (batch_size * context_length,)
        """

        es_num_samples = self.es_num_samples
        beta = self.beta

        z = self.stack_sliding_view(z)
        reshaped_hidden_state = self.hidden_state.reshape(
            -1, self.hidden_state.shape[-1])

        loss = self.picnn.energy_score(z,
                                       reshaped_hidden_state,
                                       es_num_samples=es_num_samples,
                                       beta=beta)

        return loss

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates the sample paths
        Parameters
        ----------
        sample_shape
            Shape of the samples
        Returns
        -------
        sample_paths
            Tensor of shape (*sample_shape, batch_size, prediction_length)
        """

        sample_shape = torch.Size(sample_shape)
        numel_batch = self.numel_batch
        prediction_length = self.prediction_length

        num_samples_per_batch = MQF2Distribution.get_numel(sample_shape)
        num_samples = num_samples_per_batch * numel_batch

        # Each hidden state is repeated consecutively, so the flat sample
        # axis is ordered batch-major: (batch, sample).
        hidden_state_repeat = self.hidden_state.repeat_interleave(
            repeats=num_samples_per_batch, dim=0)

        alpha = torch.rand(
            (num_samples, prediction_length),
            dtype=self.hidden_state.dtype,
            device=self.hidden_state.device,
            layout=self.hidden_state.layout,
        ).clamp(
            min=1e-4, max=1 - 1e-4
        )  # prevent numerical issues by preventing to sample beyond 0.01% and 99.99% percentiles

        samples = self.quantile(alpha, hidden_state_repeat).reshape(
            (numel_batch, ) + sample_shape + (prediction_length, ))
        # Move the batch axis behind all sample axes. For a 1-D sample_shape
        # this equals transpose(0, 1); it also handles empty and
        # multi-dimensional sample shapes.
        return samples.movedim(0, len(sample_shape))

    def quantile(self,
                 alpha: torch.Tensor,
                 hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates the predicted paths associated with the quantile levels alpha
        Parameters
        ----------
        alpha
            quantile levels,
            shape = (batch_shape, prediction_length)
        hidden_state
            hidden_state, shape = (batch_shape, hidden_size)
            defaults to the hidden state stored on the distribution
        Returns
        -------
        results
            predicted paths of shape = (batch_shape, prediction_length)
        """

        if hidden_state is None:
            hidden_state = self.hidden_state

        normal_quantile = self.standard_normal.icdf(alpha)

        # In the energy score approach, we directly draw samples from picnn
        # In the MLE (Normalizing flows) approach, we need to invert the picnn
        # (go backward through the flow) to draw samples
        if self.is_energy_score:
            result = self.picnn(normal_quantile, context=hidden_state)
        else:
            result = self.picnn.reverse(normal_quantile, context=hidden_state)

        return result

    @staticmethod
    def get_numel(tensor_shape: torch.Size) -> int:
        # Auxiliary function
        # compute number of elements specified in a torch.Size()
        # torch.Size.numel() always yields an int and gives 1 for an empty
        # shape; going through torch.prod on an empty tensor yields the
        # float 1.0 instead.
        return torch.Size(tensor_shape).numel()

    @property
    def batch_shape(self) -> torch.Size:
        # last dimension is the hidden state size
        return self.hidden_state.shape[:-1]

    @property
    def event_shape(self) -> Tuple:
        return (self.prediction_length, )

    @property
    def event_dim(self) -> int:
        return 1
def interpolate(create_model_fn, idx_list):
    """Save image grids interpolating between pairs of test images.

    Command-line arguments select the model directory, the slice of
    ``idx_list`` to process, the device, the batch size, the number of
    interpolation steps per row and the numeric precision. For every pair
    ``(i1, i2)`` both test images are mapped to latent intervals with
    ``model.forward_transform``, a point inside each interval is picked with
    one shared uniform draw ``u``, the points are moved to Gaussian space,
    blended with norm-preserving weights, mapped back through
    ``model.inverse_transform`` and written as one PNG row per pair into the
    model's log folder.

    Parameters
    ----------
    create_model_fn
        Callable that builds the model from the pickled training arguments.
    idx_list
        Sequence of ``(i1, i2)`` index pairs into ``data.test``.
    """

    parser = argparse.ArgumentParser()

    # Training args
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=None)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--row_length', type=int, default=9)
    # NOTE(security): type=eval evaluates the raw command-line string as
    # Python code; only run this script with trusted arguments.
    parser.add_argument('--double', type=eval, default=False)
    parser.add_argument('--clamp', type=eval, default=False)

    eval_args = parser.parse_args()

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # NOTE(security): unpickling can execute arbitrary code; the args file
    # must come from a trusted training run.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    torch.manual_seed(0)

    # One uniform draw shared by all images, so every pair is dequantized
    # with the same noise.
    u = torch.rand(3, 32, 32).to(eval_args.device)
    if eval_args.double: u = u.double()

    ###############
    ## Load data ##
    ###############

    data = CategoricalCIFAR10()

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    model.load_state_dict(weights, strict=False)
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double: model = model.double()

    ############################
    ## Perform interpolations ##
    ############################

    gaussian = Normal(0, 1)

    idxs = idx_list[eval_args.start:eval_args.end]

    # Loop invariants: file-name suffix and the interpolation weights.
    double_str = '_double' if eval_args.double else ''
    # Weight pairs (a, b) normalised so that a**2 + b**2 == 1, which keeps
    # the blend of two standard-normal latents at unit variance.
    ws = [(w / (math.sqrt(w**2 + (1 - w)**2)),
           (1 - w) / (math.sqrt(w**2 + (1 - w)**2)))
          for w in np.linspace(0, 1, eval_args.row_length)]

    with torch.no_grad():
        data1, data2 = [], []
        batch_idxs = []
        for n, (i1, i2) in enumerate(idxs):

            data1.append(data.test[i1][0].unsqueeze(0))
            data2.append(data.test[i2][0].unsqueeze(0))
            batch_idxs.append((i1, i2))

            # Flush on a full batch or on the last (possibly partial) batch.
            if (n + 1) % eval_args.batch_size == 0 or (n + 1) == len(idxs):

                data1 = torch.cat(data1, dim=0)
                data2 = torch.cat(data2, dim=0)

                # Use the actual batch length so the reported range is
                # right for a final partial batch as well.
                print("Matching pairs", (n + 1) - len(batch_idxs), "-",
                      n + 1, "/", len(idxs))

                if eval_args.double:
                    data1 = data1.double()
                    data2 = data2.double()

                z_lower1, z_upper1 = model.forward_transform(
                    data1.to(eval_args.device))
                z_lower2, z_upper2 = model.forward_transform(
                    data2.to(eval_args.device))

                # Pick a point inside each latent interval.
                z1 = z_lower1 + (z_upper1 - z_lower1) * u
                z2 = z_lower2 + (z_upper2 - z_lower2) * u

                # Move latent to Gaussian space
                g1 = gaussian.icdf(z1)
                g2 = gaussian.icdf(z2)
                # icdf yields +/-inf at exactly 0 or 1; replace with large
                # finite values so the weighted sum below stays defined.
                g1[g1 == -math.inf] = -1e9
                g1[g1 == math.inf] = 1e9
                g2[g2 == -math.inf] = -1e9
                g2[g2 == math.inf] = 1e9

                # Interpolation in Gaussian space:
                zw = torch.cat(
                    [gaussian.cdf(w[0] * g1 + w[1] * g2) for w in ws], dim=0)
                xw = model.inverse_transform(
                    zw, clamp=eval_args.clamp).cpu().float() / 255
                # The cat above stacks step-major, so split the leading axis
                # into (step, pair).
                xw = xw.reshape(eval_args.row_length, len(batch_idxs),
                                *xw.shape[1:])
                for i, (i1, i2) in enumerate(batch_idxs):
                    vutils.save_image(xw[:, i],
                                      '{}/i_{}_{}_l_{}{}.png'.format(
                                          model_log, i1, i2,
                                          eval_args.row_length, double_str),
                                      nrow=eval_args.row_length,
                                      padding=2)
                print("Stored interpolations")

                data1, data2 = [], []
                batch_idxs = []