def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
    """Normalize ``x`` by ``mean`` and ``variance`` over arbitrary batch dims.

    ``mean`` and ``variance`` must share one tensor spec, and that spec is
    expected to describe the trailing (inner) dims of ``x``.

    Args:
        x (Tensor): a tensor of ``[D1, D2, ..] + shape``, where the leading
            ``D1``, ``D2``, .. are batch dims (there may be none).
        mean (Tensor): a tensor of ``shape``.
        variance (Tensor): a tensor of ``shape``.
        variance_epsilon (float): small constant added to ``variance`` so the
            reciprocal square root never divides by 0.

    Returns:
        Tensor: ``(x - mean) * rsqrt(variance + variance_epsilon)``, restored
        to the batch layout of ``x`` via ``BatchSquash.unflatten``.
    """
    inner_spec = TensorSpec.from_tensor(mean)
    assert inner_spec == TensorSpec.from_tensor(variance), \
        "The specs of mean and variance must be equal!"

    # Fold every leading batch dim into one so mean/variance broadcast
    # against the inner dims.
    squasher = BatchSquash(get_outer_rank(x, inner_spec))
    flat_x = squasher.flatten(x)

    # The epsilon is cast to variance's dtype before the add; the resulting
    # scale and the mean are cast to x's dtype before touching x.
    eps = torch.as_tensor(variance_epsilon).to(variance.dtype)
    inv_std = torch.rsqrt(variance + eps).to(flat_x.dtype)
    centered = flat_x - mean.to(flat_x.dtype)
    return squasher.unflatten(centered * inv_std)
def _extract_spec(obj):
    """Build the spec of ``obj``, taking its shape from dim ``from_dim`` on.

    ``from_dim`` is not a parameter of this function; it is resolved from the
    enclosing scope at call time.

    Args:
        obj (torch.Tensor | td.Distribution): the value to describe.

    Returns:
        TensorSpec for a tensor, DistributionSpec for a distribution.

    Raises:
        ValueError: if ``obj`` is neither a ``torch.Tensor`` nor a
            ``td.Distribution``.
    """
    if not isinstance(obj, (torch.Tensor, td.Distribution)):
        raise ValueError("Unsupported value type: %s" % type(obj))
    if isinstance(obj, torch.Tensor):
        return TensorSpec.from_tensor(obj, from_dim)
    return DistributionSpec.from_distribution(obj, from_dim)
def _summarize_all(path, t, m2, m):
    """Write batch min/max of ``t`` and its moment statistics under ``path``.

    Args:
        path (str): summary name prefix; a ``"."`` separator is appended to
            it when it is non-empty.
        t (Tensor): a tensor of ``[D1, D2, ..] + shape``; its min and max are
            reduced over the leading batch dims.
        m2 (Tensor | None): second moment of ``t``.
        m (Tensor | None): mean of ``t``. At least one of ``m`` and ``m2``
            must be given; it determines which dims of ``t`` are inner dims.
    """
    if path:
        path += "."
    # ``m2 or m`` would call ``bool()`` on a tensor, which raises for any
    # tensor with more than one element, so test for ``None`` explicitly.
    ref = m2 if m2 is not None else m
    assert ref is not None, "At least one of m and m2 must be provided!"
    # ``_reduce_along_batch_dims`` derives the inner spec from this tensor
    # itself (via TensorSpec.from_tensor), so pass the tensor, not a spec.
    _summary(path + "tensor.batch_min",
             _reduce_along_batch_dims(t, ref, torch.min))
    _summary(path + "tensor.batch_max",
             _reduce_along_batch_dims(t, ref, torch.max))
    if m is not None:
        _summary(path + "mean", m)
        if m2 is not None:
            # Var[X] = E[X^2] - E[X]^2. Uses plain tensor ops; the former
            # ``math_ops.square`` was a TensorFlow leftover in a torch module.
            _summary(path + "var", m2 - m * m)
    elif m2 is not None:
        _summary(path + "second_moment", m2)
def __call__(self, nested):
    """Combine all elements of ``nested`` using ``self._combine_flat``.

    Args:
        nested (nest): a nested structure whose elements are each either a
            ``Tensor`` or a ``TensorSpec``. Only the first flattened element
            is inspected to decide which mode applies.

    Returns:
        Tensor or TensorSpec: for tensor inputs, the combined tensor; for
        spec inputs, the ``TensorSpec`` of the combined result.
    """
    leaves = nest.flatten(nested)
    assert len(leaves) > 0, "The nest is empty!"
    specs_given = isinstance(leaves[0], TensorSpec)
    if not specs_given:
        return self._combine_flat(leaves)
    # Specs cannot be combined directly: materialize a zero tensor with a
    # leading dim of size 1 for each spec, combine those, then read the spec
    # back from the result while skipping that leading dim.
    dummies = nest.map_structure(
        lambda s: s.zeros(outer_dims=(1, )), leaves)
    return TensorSpec.from_tensor(self._combine_flat(dummies), from_dim=1)
def _reduce_along_batch_dims(x, mean, op):
    """Reduce ``x`` with ``op`` over all of its leading batch dims at once.

    Args:
        x (Tensor): a tensor of ``[D1, D2, ..] + shape``; the leading batch
            dims may be empty.
        mean (Tensor | TensorSpec): a tensor of ``shape``, or an existing
            ``TensorSpec``, describing the inner dims of ``x``.
        op (Callable): a reduction called as ``op(x, dim=0)``, e.g.
            ``torch.min`` / ``torch.max``. If it returns a tuple such as
            ``(values, indices)``, only the first item is kept; a plain
            tensor result is returned as is.

    Returns:
        Tensor: the reduction of ``x`` with all batch dims folded away
        (presumably of ``shape``; depends on BatchSquash.flatten).
    """
    # Accept a precomputed spec as well as a tensor, so callers that already
    # hold a TensorSpec do not need to rebuild one.
    if isinstance(mean, TensorSpec):
        spec = mean
    else:
        spec = TensorSpec.from_tensor(mean)
    bs = alf.layers.BatchSquash(get_outer_rank(x, spec))
    # Fold every batch dim into dim 0 so a single ``op`` call reduces them.
    x = bs.flatten(x)
    result = op(x, dim=0)
    # torch.min/torch.max with ``dim`` return a (values, indices) named
    # tuple; indexing a plain tensor result with [0] would instead silently
    # pick its first row.
    if isinstance(result, tuple):
        result = result[0]
    return result