def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
    """Standardize ``x`` with the given ``mean`` and ``variance``.

    The outer (batch) dims of ``x`` are determined by comparing it against the
    spec of ``mean``; they are squashed before the arithmetic and restored
    afterwards. The result is
    ``(x - mean) * rsqrt(variance + variance_epsilon)``, with ``mean`` and the
    inverse standard deviation cast to the dtype of ``x``.

    Args:
        x (Tensor): a tensor of (``[D1, D2, ..] + shape``), where ``D1``,
            ``D2``, .. are arbitrary leading batch dims (can be empty).
        mean (Tensor): a tensor of ``shape``.
        variance (Tensor): a tensor of ``shape``; its spec must equal that of
            ``mean``.
        variance_epsilon (float): a small number added to ``variance`` to
            avoid dividing by 0.

    Returns:
        Tensor: the normalized tensor, unflattened back to the outer dims of
        ``x``.
    """
    mean_spec = TensorSpec.from_tensor(mean)
    assert mean_spec == TensorSpec.from_tensor(variance), \
        "The specs of mean and variance must be equal!"

    squasher = BatchSquash(get_outer_rank(x, mean_spec))
    flat = squasher.flatten(x)

    # Epsilon is added in the dtype of ``variance``; only the final operands
    # are cast to the dtype of the input.
    eps = torch.as_tensor(variance_epsilon).to(variance.dtype)
    inv_std = torch.rsqrt(variance + eps)
    centered = flat - mean.to(flat.dtype)
    normalized = centered * inv_std.to(flat.dtype)

    return squasher.unflatten(normalized)
def forward(self, inputs, state=(), min_outer_rank=1, max_outer_rank=1):
    """Preprocess (possibly nested) inputs and validate their outer rank.

    If ``self._input_preprocessors`` is set, each preprocessor is applied to
    its matching input and the first element of its output is kept. The
    result is then passed through ``self._preprocessing_combiner``.

    Args:
        inputs (nested Tensor): inputs to the network
        state (nested Tensor): RNN state of the network; returned unchanged
        min_outer_rank (int): the minimal outer rank allowed
        max_outer_rank (int): the maximal outer rank allowed

    Returns:
        tuple:
        - the output of the preprocessing combiner
        - ``state`` as it was passed in

    Raises:
        AssertionError: if the outer rank of the combined input (relative to
            ``self._processed_input_tensor_spec``) is outside
            ``[min_outer_rank, max_outer_rank]``.
    """
    if self._input_preprocessors:
        inputs = alf.nest.map_structure(
            lambda preproc, tensor: preproc(tensor)[0],
            self._input_preprocessors, inputs)

    proc_inputs = self._preprocessing_combiner(inputs)
    outer_rank = get_outer_rank(proc_inputs,
                                self._processed_input_tensor_spec)

    def _rank_error_message():
        # Only invoked when the assertion below fails, so the size queries
        # are not paid for on the normal path.
        allowed = "Only supports {}<=outer_rank<={}! ".format(
            min_outer_rank, max_outer_rank)
        after = ("After preprocessing: inputs size {} vs. input tensor spec {}"
                 .format(proc_inputs.size(),
                         self._processed_input_tensor_spec))
        hint = ("\n Make sure that you have provided the right input "
                "preprocessors" + " and nest combiner!\n")
        input_sizes = alf.nest.map_structure(lambda tensor: tensor.size(),
                                             inputs)
        before = ("Before preprocessing: inputs size {} vs. input tensor "
                  "spec {}".format(input_sizes, self._input_tensor_spec))
        return allowed + after + hint + before

    assert min_outer_rank <= outer_rank <= max_outer_rank, \
        _rank_error_message()
    return proc_inputs, state
def average_outer_dims(tensor, spec):
    """Average a tensor over its outer (batch) dims.

    The number of outer dims is obtained from ``get_outer_rank(tensor, spec)``.

    Args:
        tensor (Tensor): a single Tensor
        spec (TensorSpec): spec describing the inner dims of ``tensor``

    Returns:
        Tensor: the average of ``tensor`` across its outer dims. If ``tensor``
        has no outer dims it is returned unchanged.
    """
    outer_dims = get_outer_rank(tensor, spec)
    if outer_dims == 0:
        # ``Tensor.mean(dim=[])`` is treated by PyTorch as a reduction over
        # *all* dims, which would collapse the inner dims into a scalar.
        # With no outer dims there is nothing to average.
        return tensor
    return tensor.mean(dim=list(range(outer_dims)))
def _ml_pmi(self, x, y, y_distribution):
    """Compute ``log p(y|x) - log p(y)`` with a learned conditional.

    ``self._model`` encodes ``x``; two heads on the (batch-squashed) encoding
    give a location offset and a softplus-activated scale multiplier. These
    modify the mean and stddev of ``y_distribution`` to form a
    ``DiagMultivariateNormal`` for ``y`` given ``x``.

    Args:
        x (Tensor): input whose outer rank is measured against
            ``self._x_spec``.
        y (Tensor): value at which both log-probabilities are evaluated.
        y_distribution: distribution providing ``mean``, ``stddev`` and
            ``log_prob`` for ``y``.

    Returns:
        Tensor: conditional log-prob of ``y`` minus the detached log-prob
        under ``y_distribution``.
    """
    outer_rank = get_outer_rank(x, self._x_spec)
    encoded = self._model(x)[0]

    squasher = BatchSquash(outer_rank)
    flat_encoded = squasher.flatten(encoded)

    # Heads run on the squashed batch; outputs are restored to the outer dims.
    loc_offset = squasher.unflatten(self._delta_loc_layer(flat_encoded))
    scale_factor = squasher.unflatten(
        F.softplus(self._delta_scale_layer(flat_encoded)))

    conditional = DiagMultivariateNormal(
        loc=y_distribution.mean + loc_offset,
        scale=y_distribution.stddev * scale_factor)

    # The marginal term is detached so gradients only flow through the
    # conditional distribution.
    return conditional.log_prob(y) - y_distribution.log_prob(y).detach()
def scale_to_spec(tensor, spec: BoundedTensorSpec):
    """Affinely map a batch of values into the bounds of ``spec``.

    The outer dims of ``tensor`` (relative to ``spec``) are squashed, the
    values are transformed as ``means + magnitudes * tensor`` using the
    output of ``spec_means_and_magnitudes(spec)``, and the outer dims are
    restored.

    Args:
        tensor: A tensor with values in the range of [-1, 1].
        spec: (BoundedTensorSpec) to use for scaling the input tensor.

    Returns:
        A batch scaled to the given spec bounds.
    """
    squasher = BatchSquash(get_outer_rank(tensor, spec))
    flat = squasher.flatten(tensor)
    means, magnitudes = spec_means_and_magnitudes(spec)
    return squasher.unflatten(means + magnitudes * flat)
def _expand_to_replica(self, inputs, spec):
    """Broadcast inputs of shape [B, ...] to [B, n, ...] for n replicas.

    Expansion happens only when the inputs have exactly one outer dim
    relative to ``spec`` and ``self._num_replicas`` (n) is greater than 1;
    otherwise the inputs are returned untouched.

    Args:
        inputs (Tensor): the input tensor to be expanded
        spec (TensorSpec): the spec of the unexpanded inputs. It is used to
            determine whether the inputs are already expanded, in which case
            they are returned without any further processing.

    Returns:
        Tensor: the expanded inputs or the original inputs.
    """
    rank = get_outer_rank(inputs, spec)
    if rank != 1 or not self._num_replicas > 1:
        return inputs

    inner_shape = inputs.shape[1:]
    # Insert a replica axis after the batch axis and broadcast along it.
    return inputs.unsqueeze(1).expand(-1, self._num_replicas, *inner_shape)
def _reduce_along_batch_dims(x, mean, op):
    """Reduce ``x`` over its batch dims with ``op``.

    The batch dims are those beyond the spec of ``mean``; they are squashed
    into a single leading dim which ``op`` then reduces.

    Args:
        x (Tensor): tensor to reduce.
        mean (Tensor): tensor whose spec describes the inner dims of ``x``.
        op (Callable): called as ``op(x, dim=0)``; the first element of its
            result is returned (presumably a ``(values, indices)`` pair as
            from ``torch.min``/``torch.max`` -- verify against callers).

    Returns:
        The first element of ``op``'s result.
    """
    outer_rank = get_outer_rank(x, TensorSpec.from_tensor(mean))
    flattened = alf.layers.BatchSquash(outer_rank).flatten(x)
    reduced = op(flattened, dim=0)
    return reduced[0]
def train_step(self, inputs, y_distribution=None, state=None):
    """Perform training on one batch of inputs.

    For estimator type ``'ML'`` the work is delegated to ``self._ml_step``.
    Otherwise the model scores the paired ``(x, y)`` batch and a second batch
    whose ``y`` comes from the marginal, and the scores are combined
    according to the estimator type.

    Args:
        inputs (tuple(nested Tensor, nested Tensor)): tuple of ``x`` and ``y``
        y_distribution (nested td.Distribution): distribution for the marginal
            distribution of ``y``. If None, will use the sampling method
            ``sampler`` provided at constructor to generate the samples for
            the marginal distribution of :math:`Y`.
        state: not used

    Returns:
        AlgStep:
        - outputs (Tensor): shape is ``[batch_size]``, its mean is the
          estimated MI for estimator 'KL', 'DV' and 'KLD', and Jensen-Shannon
          divergence for estimator 'JSD'
        - state: not used
        - info (LossInfo): ``info.loss`` is the loss
    """
    x, y = inputs
    if self._type == 'ML':
        return self._ml_step(x, y, y_distribution)

    squasher = BatchSquash(get_outer_rank(x, self._x_spec))
    x = squasher.flatten(x)
    y = squasher.flatten(y)

    # Build the batch whose ``y`` follows the marginal distribution.
    if y_distribution is not None:
        marginal_x = x
        marginal_y = squasher.flatten(y_distribution.sample())
    else:
        marginal_x, marginal_y = self._sampler(x, y)

    joint_score = self._model([x, y])[0]
    marginal_score = self._model([marginal_x, marginal_y])[0]

    estimator = self._type
    if estimator in ('DV', 'KLD'):
        # The score is capped at 20 before exponentiation.
        ratio = torch.min(marginal_score, torch.tensor(20.)).exp()

    if estimator == 'DV':
        batch_mean = ratio.mean().detach()
        if self._mean_averager:
            self._mean_averager.update(batch_mean)
            loss_mean = self._mean_averager.get().detach()
        else:
            loss_mean = batch_mean
        # estimated MI = reduce_mean(mi)
        # ratio/mean-1 does not contribute to the final estimated MI, since
        # mean(ratio/mean-1) = 0. It is added so that an estimate of the
        # variance of the MI estimator is available.
        mi = joint_score - (batch_mean.log() + ratio / batch_mean - 1)
        loss = ratio / loss_mean - joint_score
    elif estimator == 'KLD':
        mi = joint_score - ratio + 1
        loss = -mi
    elif estimator == 'JSD':
        mi = (-F.softplus(-joint_score) - F.softplus(marginal_score) +
              math.log(4))
        loss = -mi

    output = squasher.unflatten(mi)
    loss = squasher.unflatten(loss)
    return AlgStep(output=output, state=(), info=LossInfo(loss, extra=()))
def _reduce_along_batch_dims(x, spec, op):
    """Reduce ``x`` over its batch dims with ``op``.

    The batch dims are those beyond ``spec``; they are squashed into a single
    leading dim which ``op`` then reduces.

    Args:
        x (Tensor): tensor to reduce.
        spec (TensorSpec): spec describing the inner dims of ``x``.
        op (Callable): called as ``op(x, dim=0)``; the first element of its
            result is returned (presumably a ``(values, indices)`` pair as
            from ``torch.min``/``torch.max`` -- verify against callers).

    Returns:
        The first element of ``op``'s result.
    """
    outer_rank = get_outer_rank(x, spec)
    flattened = alf.layers.BatchSquash(outer_rank).flatten(x)
    reduced = op(flattened, dim=0)
    return reduced[0]
def _preprocess(self, tensor):
    """Embed a singly-batched tensor with ``self._embedding_net``.

    Args:
        tensor (Tensor): input with exactly one outer (batch) dim relative to
            ``self._input_tensor_spec``.

    Returns:
        The embedding net's output as-is when the input spec is discrete,
        otherwise the first element of that output.

    Raises:
        AssertionError: if the outer rank of ``tensor`` is not 1.
    """
    assert get_outer_rank(tensor, self._input_tensor_spec) == 1, \
        "Only supports one outer rank (batch dim)!"

    embedded = self._embedding_net(tensor)
    if self._input_tensor_spec.is_discrete:
        return embedded
    # EncodingNetwork returns a pair; keep only its first element.
    return embedded[0]