def sample_from_mix_gaussian_1d(l, nr_mix):
    """Draw one sample per spatial location from a 1-D mixture of Gaussians.

    Parameters
    ----------
    l : torch.Tensor
        Mixture parameters laid out as (B, C, H, W) with C == 3 * nr_mix:
        channels [0, nr_mix) are mixture logits, [nr_mix, 2*nr_mix) are the
        component means and [2*nr_mix, 3*nr_mix) are the component
        log-scales (log standard deviations).
    nr_mix : int
        Number of mixture components.

    Returns
    -------
    torch.Tensor
        Samples of shape (B, 1, H, W), clamped to [-1, 1].
    """
    # (B, C, H, W) -> (B, H, W, C) so parameters live on the last axis.
    l = l.permute(0, 2, 3, 1)
    ls = [int(y) for y in l.size()]
    xs = ls[:-1] + [1]

    # Unpack parameters: logits, then (means, log_scales) per component.
    logit_probs = l[:, :, :, :nr_mix]
    l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])

    # Gumbel-max trick: argmax(logits + Gumbel noise) samples the mixture
    # indicator from softmax(logits). Noise is kept away from 0/1 so the
    # double log stays finite.
    noise = torch.empty_like(logit_probs).uniform_(1e-5, 1. - 1e-5)
    gumbel = logit_probs.detach() - torch.log(-torch.log(noise))
    _, argmax = gumbel.max(dim=3)
    one_hot = torch.nn.functional.one_hot(argmax, nr_mix).to(l.dtype)
    sel = one_hot.view(xs[:-1] + [1, nr_mix])

    # Select the parameters of the chosen component.
    means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
    log_scales = torch.clamp(
        torch.sum(l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)

    # Inverse-CDF sampling. The distribution takes a *standard deviation*,
    # so the log-scale must be exponentiated (a raw log-scale is usually
    # negative and therefore not a valid scale).
    u = torch.empty_like(means).uniform_(1e-5, 1. - 1e-5)
    distribution = Normal(loc=means, scale=torch.exp(log_scales))
    x = distribution.icdf(u)

    x0 = torch.clamp(x[:, :, :, 0], min=-1., max=1.)
    return x0.unsqueeze(1)
def pred_dist_quantile(quantiles: list, pred_params: pd.DataFrame):
    """
    Function that calculates the quantiles from the predicted response distribution.

    Parameters
    ----------
    quantiles: list
        Which quantile levels to calculate (e.g. [0.05, 0.5, 0.95]).
    pred_params: pd.DataFrame
        Dataframe with predicted distributional parameters; must contain the
        columns "location" and "scale" of a Gaussian.

    Returns
    -------
    pd.DataFrame with calculated quantiles: one row per row of ``pred_params``
    (in positional order, with a fresh 0..n-1 index) and one column per
    requested quantile (columns labelled 0..len(quantiles)-1).
    """
    # Convert positionally via numpy: passing a Series straight to
    # torch.tensor goes through its label-based __getitem__, which breaks for
    # DataFrames with a non-default index. Casting to float also keeps
    # integer-valued columns usable.
    loc = torch.as_tensor(pred_params["location"].to_numpy(dtype=float))
    scale = torch.as_tensor(pred_params["scale"].to_numpy(dtype=float))
    qGaussian = Normal(loc=loc, scale=scale)

    pred_quantiles_list = []
    for level in quantiles:
        # Build the level in the parameters' dtype so erfinv inside icdf is
        # evaluated in float64 rather than the float32 default.
        q = qGaussian.icdf(torch.tensor(level, dtype=loc.dtype))
        pred_quantiles_list.append(q.detach().cpu().numpy())

    # Rows of the list are quantiles; transpose so rows are observations.
    pred_quantiles = pd.DataFrame(pred_quantiles_list).T
    return pred_quantiles
def sample_truncated_normal_perturbations(
    X: Tensor,
    n_discrete_points: int,
    sigma: float,
    bounds: Tensor,
    qmc: bool = True,
) -> Tensor:
    r"""Draw perturbed copies of the starting points `X`.

    The points are first mapped into the unit cube defined by `bounds`.
    There, each point receives an additive N(0, sigma^2 I) perturbation that
    is truncated (via inverse-transform sampling) so the result stays inside
    [0, 1]^d. The perturbed points are then mapped back to the original
    bounds.

    Args:
        X: A `n x d`-dim tensor of starting points.
        n_discrete_points: How many perturbed points to produce.
        sigma: Standard deviation of the Gaussian noise, measured in the
            normalized (unit-cube) space.
        bounds: A `2 x d`-dim tensor with lower and upper bounds.
        qmc: If True, drive the sampling with Sobol points; otherwise use
            i.i.d. uniform draws.

    Returns:
        A `n_discrete_points x d`-dim tensor of perturbed points.
    """
    X_unit = normalize(X, bounds=bounds)
    dim = X_unit.shape[1]

    # With several starting points, pick one center (uniformly, with
    # replacement) for every requested sample.
    if X_unit.shape[0] > 1:
        picks = torch.randint(
            X_unit.shape[0], (n_discrete_points, ), device=X_unit.device)
        X_unit = X_unit[picks]

    # Uniform variates that feed the inverse-CDF transform below.
    if not qmc:
        uniforms = torch.rand(
            (n_discrete_points, dim), dtype=X_unit.dtype, device=X_unit.device)
    else:
        unit_cube = torch.zeros(2, dim, dtype=X_unit.dtype, device=X_unit.device)
        unit_cube[1] = 1
        uniforms = draw_sobol_samples(
            bounds=unit_cube, n=n_discrete_points, q=1).squeeze(1)

    std_normal = Normal(0, 1)
    # CDF values at the standardized truncation limits: the perturbation
    # must lie in [-X, 1 - X] so that X + perturbation stays in [0, 1].
    lower_cdf = std_normal.cdf(-X_unit / sigma)
    upper_cdf = std_normal.cdf((1 - X_unit) / sigma)
    delta = std_normal.icdf(lower_cdf + uniforms * (upper_cdf - lower_cdf)) * sigma

    # Clip anything that numerical error pushed outside the unit cube.
    return unnormalize((X_unit + delta).clamp(0.0, 1.0), bounds=bounds)
class MQF2Distribution(Distribution):
    r"""
    Distribution class for the model MQF2 proposed in the paper
    ``Multivariate Quantile Function Forecaster``
    by Kan, Aubet, Januschowski, Park, Benidis, Ruthotto, Gasthaus

    Parameters
    ----------
    picnn
        A SequentialNet instance of a
        partially input convex neural network (picnn)
    hidden_state
        hidden_state obtained by unrolling the RNN encoder
        shape = (batch_size, context_length, hidden_size) in training
        shape = (batch_size, hidden_size) in inference
    prediction_length
        Length of the prediction horizon
    is_energy_score
        If True, use energy score as objective function
        otherwise use maximum likelihood as
        objective function (normalizing flows)
    es_num_samples
        Number of samples drawn to approximate the energy score
    beta
        Hyperparameter of the energy score (power of the two terms)
    threshold_input
        Clamping threshold of the (scaled) input when maximum
        likelihood is used as objective function
        this is used to make the forecaster more robust
        to outliers in training samples
    validate_args
        Sets whether validation is enabled or disabled
        For more details, refer to the descriptions in
        torch.distributions.distribution.Distribution
    """

    def __init__(
        self,
        picnn: torch.nn.Module,
        hidden_state: torch.Tensor,
        prediction_length: int,
        is_energy_score: bool = True,
        es_num_samples: int = 50,
        beta: float = 1.0,
        threshold_input: float = 100.0,
        validate_args: bool = False,
    ) -> None:
        self.picnn = picnn
        self.hidden_state = hidden_state
        self.prediction_length = prediction_length
        self.is_energy_score = is_energy_score
        self.es_num_samples = es_num_samples
        self.beta = beta
        self.threshold_input = threshold_input

        # batch_shape is a property derived from hidden_state, so
        # hidden_state must be assigned before calling the base constructor.
        super().__init__(batch_shape=self.batch_shape,
                         validate_args=validate_args)

        self.context_length = (self.hidden_state.shape[-2]
                               if len(self.hidden_state.shape) > 2 else 1)
        self.numel_batch = self.get_numel(self.batch_shape)

        # Standard normal (mean zero, std one) on hidden_state's dtype/device.
        mu = torch.tensor(0,
                          dtype=hidden_state.dtype,
                          device=hidden_state.device)
        sigma = torch.ones_like(mu)
        self.standard_normal = Normal(mu, sigma)

    def stack_sliding_view(self, z: torch.Tensor) -> torch.Tensor:
        """
        Auxiliary function for loss computation

        Unfolds the observations by sliding a window of size prediction_length
        over the observations z
        Then, reshapes the observations into a 2-dimensional tensor for
        further computation

        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)

        Returns
        -------
        Tensor
            Unfolded time series with shape
            (batch_size * context_length, prediction_length)
        """
        z = z.unfold(dimension=-1, size=self.prediction_length, step=1)
        z = z.reshape(-1, z.shape[-1])
        return z

    def loss(self, z: torch.Tensor) -> torch.Tensor:
        # Energy score, or negative log-likelihood, depending on the objective.
        if self.is_energy_score:
            return self.energy_score(z)
        else:
            return -self.log_prob(z)

    def log_prob(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the log likelihood  log(g(z)) + logdet(dg(z)/dz),
        where g is the gradient of the picnn

        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)

        Returns
        -------
        loss
            Tensor of shape (batch_size * context_length,)
        """
        # Clamp the input to limit the influence of outliers.
        z = torch.clamp(z, min=-self.threshold_input, max=self.threshold_input)
        z = self.stack_sliding_view(z)

        loss = self.picnn.logp(
            z, self.hidden_state.reshape(-1, self.hidden_state.shape[-1]))

        return loss

    def energy_score(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the (approximated) energy score sum_i ES(g,z_i),
        where ES(g,z_i) =
        -1/(2*es_num_samples^2) * sum_{w,w'} ||w-w'||_2^beta
        + 1/es_num_samples * sum_{w''} ||w''-z_i||_2^beta,
        w's are samples drawn from the
        quantile function g(., h_i) (gradient of picnn),
        h_i is the hidden state associated with z_i,
        and es_num_samples is the number of samples drawn
        for each of w, w', w'' in energy score approximation

        Parameters
        ----------
        z
            A batch of time series with shape
            (batch_size, context_length + prediction_length - 1)

        Returns
        -------
        loss
            Tensor of shape (batch_size * context_length,)
        """
        es_num_samples = self.es_num_samples
        beta = self.beta

        z = self.stack_sliding_view(z)
        reshaped_hidden_state = self.hidden_state.reshape(
            -1, self.hidden_state.shape[-1])

        loss = self.picnn.energy_score(z,
                                       reshaped_hidden_state,
                                       es_num_samples=es_num_samples,
                                       beta=beta)

        return loss

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates the sample paths

        Parameters
        ----------
        sample_shape
            Shape of the samples

        Returns
        -------
        sample_paths
            Tensor of shape (*sample_shape, numel_batch, prediction_length),
            following the torch convention sample_shape + batch + event.
            For the default empty sample_shape this is
            (numel_batch, prediction_length).

        Notes
        -----
        The reshape below assumes a 2-D hidden_state (the inference shape
        (batch_size, hidden_size)): hidden_state is repeated along dim 0 only.
        """
        sample_shape = torch.Size(sample_shape)
        numel_batch = self.numel_batch
        prediction_length = self.prediction_length

        num_samples_per_batch = MQF2Distribution.get_numel(sample_shape)
        num_samples = num_samples_per_batch * numel_batch

        # Each batch row is repeated consecutively num_samples_per_batch times.
        hidden_state_repeat = self.hidden_state.repeat_interleave(
            repeats=num_samples_per_batch, dim=0)

        alpha = torch.rand(
            (num_samples, prediction_length),
            dtype=self.hidden_state.dtype,
            device=self.hidden_state.device,
            layout=self.hidden_state.layout,
        ).clamp(
            min=1e-4, max=1 - 1e-4
        )  # prevent numerical issues by not sampling beyond the 0.01% and 99.99% percentiles

        samples = self.quantile(alpha, hidden_state_repeat).reshape(
            (numel_batch, ) + sample_shape + (prediction_length, ))
        # Move the batch axis behind the sample axes. This is identical to
        # transpose(0, 1) for a 1-D sample_shape and a no-op for an empty one.
        return samples.movedim(0, len(sample_shape))

    def quantile(self,
                 alpha: torch.Tensor,
                 hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates the predicted paths associated with the quantile levels alpha

        Parameters
        ----------
        alpha
            quantile levels,
            shape = (batch_shape, prediction_length)
        hidden_state
            hidden_state, shape = (batch_shape, hidden_size)

        Returns
        -------
        results
            predicted paths of shape = (batch_shape, prediction_length)
        """
        if hidden_state is None:
            hidden_state = self.hidden_state

        normal_quantile = self.standard_normal.icdf(alpha)

        # In the energy score approach, we directly draw samples from picnn
        # In the MLE (Normalizing flows) approach, we need to invert the picnn
        # (go backward through the flow) to draw samples
        if self.is_energy_score:
            result = self.picnn(normal_quantile, context=hidden_state)
        else:
            result = self.picnn.reverse(normal_quantile, context=hidden_state)

        return result

    @staticmethod
    def get_numel(tensor_shape: torch.Size) -> int:
        # Auxiliary function
        # Number of elements described by a shape, always as a Python int
        # (an empty shape gives 1). torch.prod over torch.tensor(()) would
        # yield the float 1.0 instead, which breaks repeat_interleave.
        return torch.Size(tensor_shape).numel()

    @property
    def batch_shape(self) -> torch.Size:
        # last dimension is the hidden state size
        return self.hidden_state.shape[:-1]

    @property
    def event_shape(self) -> Tuple:
        return (self.prediction_length, )

    @property
    def event_dim(self) -> int:
        return 1
def _parse_bool_flag(value):
    """Parse a boolean command-line value without using ``eval``.

    Accepts true/false, yes/no and 1/0 (case-insensitive). This covers the
    ``True``/``False`` literals the flags previously accepted through
    ``type=eval``.
    """
    lowered = str(value).strip().lower()
    if lowered in ("true", "yes", "1"):
        return True
    if lowered in ("false", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True/False), got {!r}".format(value))


def _to_gaussian_space(gaussian, z):
    """Apply the standard-normal inverse CDF, replacing +/-inf by +/-1e9."""
    g = gaussian.icdf(z)
    g[g == -math.inf] = -1e9
    g[g == math.inf] = 1e9
    return g


def interpolate(create_model_fn, idx_list):
    """Render latent-space interpolations between pairs of test images.

    Command-line arguments select the trained model and the slice of
    ``idx_list`` to process. Each pair is mapped through
    ``model.forward_transform``, and a point is picked inside the returned
    [z_lower, z_upper] interval using a fixed seeded uniform offset. That
    point is moved to Gaussian space with the normal inverse CDF and combined
    with norm-preserving weights (w0**2 + w1**2 == 1). The combination is
    mapped back through the normal CDF and ``model.inverse_transform``. One
    PNG grid per pair is saved into the model's log folder.

    Parameters
    ----------
    create_model_fn
        Callable building the model from the pickled training ``args``.
    idx_list
        Sequence of ``(i1, i2)`` index pairs into ``data.test``.
    """
    parser = argparse.ArgumentParser()

    # Training args
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=None)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--row_length', type=int, default=9)
    # Strict boolean parsing instead of type=eval, which would execute
    # arbitrary expressions passed on the command line.
    parser.add_argument('--double', type=_parse_bool_flag, default=False)
    parser.add_argument('--clamp', type=_parse_bool_flag, default=False)

    eval_args = parser.parse_args()

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --model_path at trusted log folders.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    # Fixed (seeded) uniform offset used to pick a point inside each
    # [z_lower, z_upper] interval. The (3, 32, 32) shape presumably matches
    # the CategoricalCIFAR10 images -- confirm if the dataset changes.
    torch.manual_seed(0)
    u = torch.rand(3, 32, 32).to(eval_args.device)
    if eval_args.double:
        u = u.double()

    ###############
    ## Load data ##
    ###############

    data = CategoricalCIFAR10()

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights. strict=False silently ignores missing or
    # unexpected keys in the checkpoint.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    model.load_state_dict(weights, strict=False)
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double:
        model = model.double()

    ############################
    ## Perform interpolations ##
    ############################

    gaussian = Normal(0, 1)
    idxs = idx_list[eval_args.start:eval_args.end]
    double_str = '_double' if eval_args.double else ''

    # Loop-invariant interpolation weights, normalized so that
    # w0**2 + w1**2 == 1: mixing two standard-normal latents this way keeps
    # unit variance.
    ws = [(w / (math.sqrt(w**2 + (1 - w)**2)),
           (1 - w) / (math.sqrt(w**2 + (1 - w)**2)))
          for w in np.linspace(0, 1, eval_args.row_length)]

    with torch.no_grad():
        data1, data2 = [], []
        batch_idxs = []
        for n, (i1, i2) in enumerate(idxs):
            data1.append(data.test[i1][0].unsqueeze(0))
            data2.append(data.test[i2][0].unsqueeze(0))
            batch_idxs.append((i1, i2))

            # Flush when the batch is full or at the last pair.
            if (n + 1) % eval_args.batch_size == 0 or (n + 1) == len(idxs):
                data1 = torch.cat(data1, dim=0)
                data2 = torch.cat(data2, dim=0)

                # The current batch covers pairs [n + 1 - len(batch_idxs), n + 1);
                # using len(batch_idxs) keeps the start right for a final
                # partial batch.
                print("Matching pairs", (n + 1) - len(batch_idxs), "-", n + 1,
                      "/", len(idxs))
                if eval_args.double:
                    data1 = data1.double()
                    data2 = data2.double()

                z_lower1, z_upper1 = model.forward_transform(
                    data1.to(eval_args.device))
                z_lower2, z_upper2 = model.forward_transform(
                    data2.to(eval_args.device))

                z1 = z_lower1 + (z_upper1 - z_lower1) * u
                z2 = z_lower2 + (z_upper2 - z_lower2) * u

                # Move latent to Gaussian space (infinities capped at +/-1e9).
                g1 = _to_gaussian_space(gaussian, z1)
                g2 = _to_gaussian_space(gaussian, z2)

                # Interpolation in Gaussian space; rows are stacked
                # weight-major, giving (row_length * batch, ...).
                zw = torch.cat(
                    [gaussian.cdf(w[0] * g1 + w[1] * g2) for w in ws], dim=0)
                xw = model.inverse_transform(
                    zw, clamp=eval_args.clamp).cpu().float() / 255
                xw = xw.reshape(eval_args.row_length, len(batch_idxs),
                                *xw.shape[1:])

                for i, (i1, i2) in enumerate(batch_idxs):
                    vutils.save_image(xw[:, i],
                                      '{}/i_{}_{}_l_{}{}.png'.format(
                                          model_log, i1, i2,
                                          eval_args.row_length, double_str),
                                      nrow=eval_args.row_length,
                                      padding=2)

                print("Stored interpolations")

                data1, data2 = [], []
                batch_idxs = []