def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
    """Build the distribution from a mean and exactly one matrix parameter.

    Args:
        loc: mean tensor with at least one dimension; its last dimension
            becomes the event shape.
        covariance_matrix: optional tensor with at least two dimensions.
        precision_matrix: optional tensor with at least two dimensions.
        scale_tril: optional tensor with at least two dimensions.
        validate_args: forwarded unchanged to the parent constructor.

    Raises:
        ValueError: if ``loc`` has no dimensions, if the number of supplied
            matrix arguments is not exactly one, or if the supplied matrix
            has fewer than two dimensions.
    """
    if loc.dim() < 1:
        raise ValueError("loc must be at least one-dimensional.")
    if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
        raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")

    # Broadcast the leading (batch) dimensions of the supplied matrix against
    # those of ``loc`` and store the expanded view under the matching name.
    if scale_tril is not None:
        if scale_tril.dim() < 2:
            raise ValueError("scale_tril matrix must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
        self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
    elif covariance_matrix is not None:
        if covariance_matrix.dim() < 2:
            raise ValueError("covariance_matrix must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
        self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
    else:
        if precision_matrix.dim() < 2:
            raise ValueError("precision_matrix must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
        self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
    self.loc = loc.expand(batch_shape + (-1,))

    event_shape = self.loc.shape[-1:]
    super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    # The lower-triangular factor is kept in its original (un-expanded) shape.
    if scale_tril is not None:
        self._unbroadcasted_scale_tril = scale_tril
    elif covariance_matrix is not None:
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor and exists in every torch release that also
        # provides torch.broadcast_shapes (used above).
        self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
    else:  # precision_matrix is not None
        self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Validate cdist arguments and return an empty tensor of the output shape.

    Args:
        x1: first input; needs at least two dimensions and a floating dtype.
        x2: second input; needs at least two dimensions, a floating dtype and
            the same last-dimension size as ``x1``.
        p: norm order; must be non-negative.
        compute_mode: integer mode; must be 0, 1 or 2.

    Returns:
        ``x1.new_empty(...)`` whose shape is the broadcast of the two inputs'
        leading dimensions followed by ``[x1.size(-2), x2.size(-2)]``.

    Note:
        Each requirement is passed to ``check`` together with a lazily built
        message; ``check`` is defined outside this block (presumably it raises
        when the condition is false - confirm against its definition).
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # These two messages were plain strings, so the "{x1.dtype}" placeholder
    # was emitted literally; they are now f-strings like the others.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )

    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
def broadcast_except(*tensors: Tensor, dim=-1):
    """Broadcast tensors against each other on every dimension except ``dim``.

    The common shape is computed from each tensor with ``dim`` removed (via
    ``select(dim, 0)``). Each tensor, after being passed through ``pad_dims``
    with ``ndim=len(shape) + 1``, is expanded to that common shape while its
    own size along ``dim`` is kept.

    Returns:
        A list with one expanded tensor per input, in input order.
    """
    reduced_shapes = [t.select(dim, 0).shape for t in tensors]
    shape = broadcast_shapes(*reduced_shapes)

    expanded = []
    for padded in pad_dims(*tensors, ndim=len(shape) + 1):
        # Position in ``shape`` where this tensor's own ``dim`` size is spliced in.
        split = padded.ndim + dim + 1
        expanded.append(
            padded.expand(*shape[:split], padded.shape[dim], *shape[split:])
        )
    return expanded
def __init__(self,
             df: Union[torch.Tensor, Number],
             covariance_matrix: torch.Tensor = None,
             precision_matrix: torch.Tensor = None,
             scale_tril: torch.Tensor = None,
             validate_args=None):
    """Build the distribution from ``df`` and exactly one matrix parameter.

    Args:
        df: a Python number or a tensor. A number is converted to a tensor
            with the matrix parameter's dtype and device; a tensor is
            broadcast against the matrix parameter's leading dimensions.
        covariance_matrix: optional tensor with at least two dimensions.
        precision_matrix: optional tensor with at least two dimensions.
        scale_tril: optional tensor with at least two dimensions.
        validate_args: forwarded unchanged to the parent constructor.

    Raises:
        AssertionError: if the number of supplied matrix arguments is not
            exactly one.
        ValueError: if the supplied matrix has fewer than two dimensions, or
            if any ``df`` value is ``<= event_shape[-1] - 1``.
    """
    # Kept as an assert so the exception type seen by existing callers does
    # not change.
    assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
        "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

    # Remember which argument was supplied so the error below names it;
    # previously the message always said "scale_tril".
    param_name, param = next(
        (label, value)
        for label, value in (
            ("covariance_matrix", covariance_matrix),
            ("precision_matrix", precision_matrix),
            ("scale_tril", scale_tril),
        )
        if value is not None
    )

    if param.dim() < 2:
        raise ValueError(f"{param_name} must be at least two-dimensional, with optional leading batch dimensions")

    if isinstance(df, Number):
        batch_shape = torch.Size(param.shape[:-2])
        self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
    else:
        batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
        self.df = df.expand(batch_shape)
    event_shape = param.shape[-2:]

    if self.df.le(event_shape[-1] - 1).any():
        raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")

    if scale_tril is not None:
        self.scale_tril = param.expand(batch_shape + (-1, -1))
    elif covariance_matrix is not None:
        self.covariance_matrix = param.expand(batch_shape + (-1, -1))
    elif precision_matrix is not None:
        self.precision_matrix = param.expand(batch_shape + (-1, -1))

    # NOTE(review): this item assignment mutates whatever dict
    # ``self.arg_constraints`` resolves to; if that is a class-level dict it
    # is shared by all instances - confirm against the class definition.
    self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
    if self.df.lt(event_shape[-1]).any():
        warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")

    super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
    self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

    # The lower-triangular factor is kept in its original (un-expanded) shape.
    if scale_tril is not None:
        self._unbroadcasted_scale_tril = scale_tril
    elif covariance_matrix is not None:
        self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
    else:  # precision_matrix is not None
        self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

    # Chi2 distribution is needed for Bartlett decomposition sampling
    self._dist_chi2 = torch.distributions.chi2.Chi2(
        df=(
            self.df.unsqueeze(-1)
            - torch.arange(
                self._event_shape[-1],
                dtype=self._unbroadcasted_scale_tril.dtype,
                device=self._unbroadcasted_scale_tril.device,
            ).expand(batch_shape + (-1,))
        )
    )
def broadcastable(*shapes) -> bool:
    r"""Tell whether all of `shapes` can be broadcast to one common shape.

    The answer is obtained by attempting `torch.broadcast_shapes` and
    reporting whether it raised a `RuntimeError`.

    Example:
        >>> x = torch.rand(3, 2, 1)
        >>> y = torch.rand(1, 2, 3)
        >>> z = torch.rand(2, 2, 2)
        >>> broadcastable(x.shape, y.shape)
        True
        >>> broadcastable(y.shape, z.shape)
        False
    """
    try:
        torch.broadcast_shapes(*shapes)
    except RuntimeError:
        return False
    return True
def _dense_unflatten(self, flat_samples: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Split a flattened sample into one shaped tensor per entry of ``self._dense_shapes``.

    Consecutive slices of the last dimension of ``flat_samples`` are taken in
    the iteration order of ``self._dense_shapes``; each slice has
    ``(batch_shape + event_shape).numel()`` elements and is reshaped to
    ``broadcast(sample_shape, batch_shape) + event_shape``, where
    ``sample_shape`` is every dimension of ``flat_samples`` but the last.

    Returns:
        Dict keyed like ``self._dense_shapes`` holding the reshaped slices.
    """
    sample_shape = flat_samples.shape[:-1]
    unflattened = {}
    start = 0
    for key, (batch_shape, event_shape) in self._dense_shapes.items():
        stop = start + (batch_shape + event_shape).numel()
        piece = flat_samples[..., start:stop]
        start = stop
        # Assumes sample shapes are left of batch shapes.
        target_shape = torch.broadcast_shapes(sample_shape, batch_shape) + event_shape
        unflattened[key] = piece.reshape(target_shape)
    return unflattened
def broadcast_gather(input, dim, index, sparse_grad=False, index_ndim=1):
    """
    Gather along ``dim`` after broadcasting the leading dimensions.

    input: Size(batch_shape..., N, event_shape...)
    index: Size(batch_shape..., index_shape...)
    ->: Size(batch_shape..., index_shape..., event_shape...)

    The last ``index_ndim`` dimensions of ``index`` are flattened for the
    gather and restored in the result.
    """
    trailing_index_shape = index.shape[-index_ndim:]
    flat_index = index.flatten(-index_ndim)

    # Broadcast what precedes ``dim`` in ``input`` with the leading dims of the index.
    batch_shape = broadcast_shapes(input.shape[:dim], flat_index.shape[:-1])
    source = input.expand(*batch_shape, *input.shape[dim:])
    flat_index = flat_index.expand(*batch_shape, flat_index.shape[-1])

    if source.ndim > flat_index.ndim:
        # Append singleton dims, then expand over the trailing dims of ``source``
        # so the index has the same rank as the gathered tensor.
        missing = source.ndim - flat_index.ndim
        gather_index = flat_index.reshape(*flat_index.shape, *(missing * (1,)))
        gather_index = gather_index.expand(
            *flat_index.shape, *source.shape[flat_index.ndim:]
        )
    else:
        gather_index = flat_index

    gathered = torch.gather(source, dim, gather_index, sparse_grad=sparse_grad)
    event_shape = source.shape[dim % source.ndim + 1:]
    return gathered.reshape(
        *flat_index.shape[:-1], *trailing_index_shape, *event_shape
    )
def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs):
    """Combine an observed and an unobserved sample of ``fn`` according to ``obs_mask``.

    Two auxiliary sites, ``{name}_observed`` (with ``obs=obs``) and
    ``{name}_unobserved``, are sampled under ``poutine.mask`` with
    ``obs_mask`` and ``~obs_mask`` respectively. Their values are merged with
    ``torch.where`` and the result is passed to ``deterministic`` under ``name``.

    Raises:
        ValueError: when ``torch.where`` fails with a size-mismatch
            ``RuntimeError``; any other ``RuntimeError`` is re-raised as is.
    """
    # Split into two auxiliary sample sites.
    with poutine.mask(mask=obs_mask):
        observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs)
    with poutine.mask(mask=~obs_mask):
        unobserved = sample(f"{name}_unobserved", fn, *args, **kwargs)

    # Interleave observed and unobserved events; the mask gets one trailing
    # singleton dimension per event dimension of ``fn``.
    event_padding = (1, ) * fn.event_dim
    batch_mask = obs_mask.reshape(obs_mask.shape + event_padding)
    try:
        value = torch.where(batch_mask, observed, unobserved)
    except RuntimeError as e:
        if "must match the size of tensor" not in str(e):
            raise
        full_shape = torch.broadcast_shapes(observed.shape, unobserved.shape)
        batch_shape = full_shape[:len(full_shape) - fn.event_dim]
        raise ValueError(
            f"Invalid obs_mask shape {tuple(obs_mask.shape)}; should be "
            f"broadcastable to batch_shape = {tuple(batch_shape)}") from e
    return deterministic(name, value)
def _unpack_latent(self, latent):
    """
    Walk a packed latent tensor and yield one slice per stochastic site::

        (site, unconstrained_value)

    Sites come from ``self.prototype_trace.iter_stochastic_nodes()``. Each
    slice of the last dimension is viewed with the site's stored
    unconstrained shape broadcast against the leading dimensions of
    ``latent``. Outside of tracing, an assertion checks that the whole last
    dimension was consumed.
    """
    # Leading dims come from plates outside of _setup_prototype, e.g. parallel particles.
    batch_shape = latent.shape[:-1]
    offset = 0
    for name, site in self.prototype_trace.iter_stochastic_nodes():
        constrained_shape = site["value"].shape
        stored_shape = self._unconstrained_shapes[name]
        numel = _product(stored_shape)
        event_dim = (
            site["fn"].event_dim + len(stored_shape) - len(constrained_shape)
        )
        view_shape = torch.broadcast_shapes(
            stored_shape, batch_shape + (1, ) * event_dim
        )
        chunk = latent[..., offset:offset + numel]
        yield site, chunk.view(view_shape)
        offset += numel
    if not torch._C._get_tracing_state():
        assert offset == latent.size(-1)
def __init__(self, base_dist: TorchDistribution, skewness, validate_args=None):
    """Store a broadcast base distribution together with its skewness weights.

    Args:
        base_dist: distribution whose ``event_shape`` must equal the last
            dimension of ``skewness``.
        skewness: tensor of weights; a ``UserWarning`` is issued when the
            absolute values summed over the last dimension exceed 1.0 anywhere.
        validate_args: forwarded unchanged to the parent constructor.

    Raises:
        AssertionError: if ``base_dist.event_shape`` differs from
            ``skewness.shape[-1:]``.
        ValueError: if validation is enabled and ``base_dist.mean`` is on a
            different device than ``skewness``.
    """
    event_shape = skewness.shape[-1:]
    assert (
        base_dist.event_shape == event_shape
    ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

    total_weight = skewness.abs().sum(-1)
    if (total_weight > 1.0).any():
        warnings.warn("Total skewness weight shouldn't exceed one.", UserWarning)

    batch_shape = broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
    self.skewness = skewness.broadcast_to(batch_shape + event_shape)
    self.base_dist = base_dist.expand(batch_shape)
    super().__init__(batch_shape, event_shape, validate_args=validate_args)

    # The device check is only performed when argument validation is on.
    if self._validate_args and base_dist.mean.device != skewness.device:
        base_name = base_dist.__class__.__name__
        raise ValueError(
            f"base_density: {base_name} and SineSkewed "
            f"must be on same device.")
def broadcast_shape(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Infers, if possible, the broadcast output shape of two operands a and b.

    The work is delegated to ``torch.broadcast_shapes``; when that attribute
    is missing, an equivalent right-aligned, dimension-by-dimension
    comparison is done by hand.

    Parameters
    ----------
    shape_a : Tuple[int,...]
        Shape of first operand
    shape_b : Tuple[int,...]
        Shape of second operand

    Raises
    ------
    ValueError
        If the two shapes cannot be broadcast.
    TypeError
        If ``torch.broadcast_shapes`` raises ``TypeError`` or ``NameError``.

    Examples
    --------
    >>> import heat as ht
    >>> ht.core.stride_tricks.broadcast_shape((5,4),(4,))
    (5, 4)
    >>> ht.core.stride_tricks.broadcast_shape((1,100,1),(10,1,5))
    (10, 100, 5)
    >>> ht.core.stride_tricks.broadcast_shape((8,1,6,1),(7,1,5,))
    (8, 7, 6, 5)
    >>> ht.core.stride_tricks.broadcast_shape((2,1),(8,4,3))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "heat/core/stride_tricks.py", line 42, in broadcast_shape
        "operands could not be broadcast, input shapes {} {}".format(shape_a, shape_b)
    ValueError: operands could not be broadcast, input shapes (2, 1) (8, 4, 3)
    """
    try:
        broadcasted = torch.broadcast_shapes(shape_a, shape_b)
    except AttributeError:  # torch < 1.8
        # Compare from the trailing dimension, padding the shorter shape with 1.
        trailing_first = []
        pairs = itertools.zip_longest(shape_a[::-1], shape_b[::-1], fillvalue=1)
        for a, b in pairs:
            if a == 0 and b == 1 or b == 0 and a == 1:
                # A zero-sized dimension wins over a singleton.
                trailing_first.append(0)
            elif a == 1 or b == 1 or a == b:
                trailing_first.append(max(a, b))
            else:
                raise ValueError(
                    "operands could not be broadcast, input shapes {} {}".format(
                        shape_a, shape_b))
        return tuple(trailing_first[::-1])
    except TypeError:
        raise TypeError(
            "operand 1 must be tuple of ints, not {}".format(type(shape_a)))
    except NameError:
        raise TypeError(
            "operands must be tuples of ints, not {} and {}".format(
                shape_a, shape_b))
    except RuntimeError:
        raise ValueError(
            "operands could not be broadcast, input shapes {} {}".format(
                shape_a, shape_b))

    return tuple(broadcasted)
def __init__(
    self,
    height: torch.Tensor,
    width: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    target_locs: torch.Tensor,
    background: torch.Tensor,
    gain: torch.Tensor,
    offset_samples: torch.Tensor,
    offset_logits: torch.Tensor,
    P: int,
    m: torch.Tensor = None,
    alpha: torch.Tensor = None,
    use_pykeops: bool = True,
    validate_args=None,
):
    """Store the image-model parameters and derive batch and event shapes.

    When ``alpha`` is given, the per-spot parameters get an extra axis at
    position -2 (``target_locs`` at -3), ``height`` is multiplied by
    ``alpha[..., None]``, and ``alpha.shape[-1]`` is prepended to the event
    shape. Otherwise the parameters are stored as passed and the event shape
    is ``(P, P)``. The shape annotations in the comments below are those of
    the original author.

    Args:
        height, width, x, y: spot parameters; their shapes are broadcast
            together to obtain the batch shape.
        target_locs: target locations; all but its last dimension take part
            in the batch-shape broadcast.
        background: background tensor; stored with two trailing singleton dims.
        gain: gain tensor; ``self.rate`` is ``1 / self.gain``.
        offset_samples, offset_logits: stored unchanged.
        P: side length used for the ``(P, P)`` event shape.
        m: optional tensor; included in the batch-shape broadcast when given.
        alpha: optional tensor selecting the second layout described above.
        use_pykeops: when true, ``self.device_pykeops`` is set to ``"GPU"`` if
            ``target_locs`` lives on a ``cuda`` device, else ``"CPU"``.
        validate_args: forwarded unchanged to the parent constructor.
    """
    # shapes for cosmos and crosstalk models
    self.height = height  # (N, F, C, K) or (N, F, Q, K)
    self.width = width  # (N, F, C, K) or (N, F, Q, K)
    self.x = x
    self.y = y
    self.target_locs = target_locs  # (N, F, C, 2)
    self.m = m  # (N, F, C, K) or (N, F, Q, K)
    self.background = background[..., None, None]  # (N, F, C, P, P)
    if alpha is not None:
        C = alpha.shape[-1]
        self.gain = gain[..., None, None, None]  # (1, K, P, P)
        self.height = (
            self.height.unsqueeze(-2) * alpha[..., None]
        )  # (N, F, Q, C, K)
        self.width = self.width.unsqueeze(-2)  # (N, F, Q, 1, K)
        self.x = self.x.unsqueeze(-2)  # (N, F, Q, 1, K)
        self.y = self.y.unsqueeze(-2)  # (N, F, Q, 1, K)
        # ``m`` is optional (default None, and guarded below when computing
        # the batch shape), so only reshape it when it was actually given.
        if self.m is not None:
            self.m = self.m.unsqueeze(-2)  # (N, F, Q, 1, K)
        self.target_locs = self.target_locs.unsqueeze(-3)  # (N, F, 1, C, 2)
    else:
        self.gain = gain[..., None, None]  # (1, P, P)
    self.alpha = alpha
    self.rate = 1 / self.gain
    self.offset_samples = offset_samples
    self.offset_logits = offset_logits
    self.P = P
    self.use_pykeops = use_pykeops
    if self.use_pykeops:
        device = self.target_locs.device.type
        self.device_pykeops = "GPU" if device == "cuda" else "CPU"

    # calculate batch shape (from the arguments as passed, not the
    # unsqueezed attributes stored above)
    batch_shape = torch.broadcast_shapes(
        height.shape, width.shape, x.shape, y.shape
    )  # (N, F, C, K) or (N, F, Q, K)
    if m is not None:
        batch_shape = torch.broadcast_shapes(
            batch_shape, m.shape
        )  # (N, F, C, K) or (N, F, Q, K)
    event_shape = torch.Size([P, P])  # (P, P)
    bg_shape = background.shape  # (N, F, C)
    target_shape = target_locs.shape[:-1]  # (N, F, C)
    # remove K dim
    batch_shape = batch_shape[:-1]  # (N, F, C) or (N, F, Q)
    if alpha is not None:
        # remove Q dim
        batch_shape = batch_shape[:-1]  # (N, F)
        # add C dim
        event_shape = (C, ) + event_shape  # (C, P, P)
        # remove C dim
        bg_shape = bg_shape[:-1]  # (N, F)
        target_shape = target_shape[:-1]  # (N, F)
    batch_shape = torch.broadcast_shapes(
        batch_shape, bg_shape, target_shape
    )  # (N, F, C) or (N, F)
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
def broadcast_left(*tensors, ndim):
    """Broadcast the first ``ndim`` dimensions of every tensor to a common shape.

    The common leading shape is computed immediately; the returned object is
    a lazy generator yielding each tensor expanded to that leading shape with
    its remaining dimensions untouched.
    """
    leading_shapes = [t.shape[:ndim] for t in tensors]
    common = broadcast_shapes(*leading_shapes)
    return (t.expand(*common, *t.shape[ndim:]) for t in tensors)
def forward_shape(self, shape):
    """Return ``shape`` broadcast with the shape of ``self.exponent``.

    An exponent without a ``shape`` attribute (e.g. a Python number) is
    treated as having the empty shape ``()``.
    """
    exponent_shape = getattr(self.exponent, "shape", ())
    return torch.broadcast_shapes(shape, exponent_shape)
def inverse_shape(self, shape):
    """Return ``shape`` broadcast with the shape of ``self.exponent``.

    An exponent without a ``shape`` attribute (e.g. a Python number) is
    treated as having the empty shape ``()``.
    """
    exponent_shape = getattr(self.exponent, "shape", ())
    return torch.broadcast_shapes(shape, exponent_shape)
def inverse_shape(self, shape):
    """Return ``shape`` broadcast with the shapes of ``self.loc`` and ``self.scale``.

    A parameter without a ``shape`` attribute (e.g. a Python number) is
    treated as having the empty shape ``()``.
    """
    param_shapes = [
        getattr(param, "shape", ()) for param in (self.loc, self.scale)
    ]
    return torch.broadcast_shapes(shape, *param_shapes)
def get_deltas(self, save_params=None):
    """Build one ``dist.Delta`` per site in ``self._sorted_sites``.

    For each site an auxiliary value is produced according to
    ``self.conditionals[name]`` (a callable, ``"delta"``, ``"normal"`` or
    ``"mvn"``), upstream contributions listed in ``self.dependencies`` are
    added, the result is shifted by the site's ``loc``, reshaped, and mapped
    through ``biject_to(site["fn"].support)``.

    Args:
        save_params: optional container of site names; when given, sites
            whose name is not in it are skipped entirely.

    Returns:
        Dict mapping site name to ``dist.Delta(value, log_density, event_dim)``.

    Raises:
        ValueError: if a site's conditional is none of the supported kinds.
    """
    deltas = {}
    aux_values = {}
    # Densities are only accumulated when the ambient mask is not literally False.
    compute_density = poutine.get_mask() is not False
    for name, site in self._sorted_sites:
        # NOTE(review): a skipped site never gets an ``aux_values`` entry, so a
        # later site that lists it in ``self.dependencies`` would fail on the
        # ``aux_values[upstream]`` lookup below - confirm callers avoid this.
        if save_params is not None and name not in save_params:
            continue

        # Sample zero-mean blockwise independent Delta/Normal/MVN.
        log_density = 0.0
        loc = deep_getattr(self.locs, name)
        zero = torch.zeros_like(loc)
        conditional = self.conditionals[name]
        if callable(conditional):
            aux_value = deep_getattr(self.conds, name)()
        elif conditional == "delta":
            aux_value = zero
        elif conditional == "normal":
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            aux_value = aux_value * scale
            if compute_density:
                log_density = (-scale.log()).expand_as(aux_value)
        elif conditional == "mvn":
            # This overparametrizes by learning (scale,scale_tril),
            # enabling faster learning of the more-global scale parameter.
            aux_value = pyro.sample(
                name + "_aux",
                dist.Normal(zero, 1).to_event(1),
                infer={"is_auxiliary": True},
            )
            scale = deep_getattr(self.scales, name)
            scale_tril = deep_getattr(self.scale_trils, name)
            aux_value = aux_value @ scale_tril.T * scale
            if compute_density:
                log_density = (
                    -scale_tril.diagonal(dim1=-2, dim2=-1).log()
                    - scale.log()).expand_as(aux_value)
        else:
            raise ValueError(
                f"Unsupported conditional type: {conditional}")

        # Accumulate upstream dependencies.
        # Note: by accumulating upstream dependencies before updating the
        # aux_values dict, we encode a block-sparse structure of the
        # precision matrix; if we had instead accumulated after updating
        # aux_values, we would encode a block-sparse structure of the
        # covariance matrix.
        # Note: these shear transforms have no effect on the Jacobian
        # determinant, and can therefore be excluded from the log_density
        # computation below, even for nonlinear dep().
        deps = deep_getattr(self.deps, name)
        for upstream in self.dependencies.get(name, {}):
            dep = deep_getattr(deps, upstream)
            aux_value = aux_value + dep(aux_values[upstream])
        aux_values[name] = aux_value

        # Shift by loc and reshape.
        batch_shape = torch.broadcast_shapes(aux_value.shape[:-1],
                                             self._batch_shapes[name])
        unconstrained = (
            aux_value + loc).reshape(batch_shape + self._unconstrained_event_shapes[name])
        # Collapse the per-element log density to one value per batch entry,
        # unless it is still the initial zero.
        if not is_identically_zero(log_density):
            log_density = log_density.reshape(batch_shape + (-1, )).sum(-1)

        # Transform to constrained space.
        transform = biject_to(site["fn"].support)
        value = transform(unconstrained)
        if compute_density and conditional != "delta":
            assert transform.codomain.event_dim == site["fn"].event_dim
            log_density = log_density + transform.inv.log_abs_det_jacobian(
                value, unconstrained)

        # Create a reparametrized Delta distribution.
        deltas[name] = dist.Delta(value, log_density, site["fn"].event_dim)

    return deltas