Example 1
def normalize_along_batch_dims(x, mean, variance, variance_epsilon):
    """Normalizes a tensor by ``mean`` and ``variance``, which are expected to have
    the same tensor spec with the inner dims of ``x``.

    Args:
        x (Tensor): a tensor of (``[D1, D2, ..] + shape``), where ``D1``, ``D2``, ..
            are arbitrary leading batch dims (can be empty).
        mean (Tensor): a tensor of ``shape``
        variance (Tensor): a tensor of ``shape``
        variance_epsilon (float): A small float number to avoid dividing by 0.
    Returns:
        Normalized tensor.
    """
    spec = TensorSpec.from_tensor(mean)
    assert spec == TensorSpec.from_tensor(variance), \
        "The specs of mean and variance must be equal!"

    bs = BatchSquash(get_outer_rank(x, spec))
    x = bs.flatten(x)

    variance_epsilon = torch.as_tensor(variance_epsilon).to(variance.dtype)
    inv = torch.rsqrt(variance + variance_epsilon)
    x = (x - mean.to(x.dtype)) * inv.to(x.dtype)

    x = bs.unflatten(x)
    return x
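A minimal usage sketch, assuming the ``alf`` imports used by the snippet above (shapes are hypothetical):

import torch

mean = torch.zeros(3)
variance = torch.ones(3)
x = torch.randn(4, 5, 3)  # outer batch dims [4, 5] over the spec shape [3]
out = normalize_along_batch_dims(x, mean, variance, variance_epsilon=1e-6)
assert out.shape == x.shape  # normalization preserves the input shape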
Example 2
    def forward(self, inputs, state=(), min_outer_rank=1, max_outer_rank=1):
        """Preprocessing nested inputs.

        Args:
            inputs (nested Tensor): inputs to the network
            state (nested Tensor): RNN state of the network
            min_outer_rank (int): the minimal outer rank allowed
            max_outer_rank (int): the maximal outer rank allowed
        Returns:
            tuple:
            - proc_inputs (Tensor): the input tensor after preprocessing and
              combining
            - state (nested Tensor): the RNN state, passed through unchanged
        """
        if self._input_preprocessors:
            inputs = alf.nest.map_structure(
                lambda preproc, tensor: preproc(tensor)[0],
                self._input_preprocessors, inputs)

        proc_inputs = self._preprocessing_combiner(inputs)
        outer_rank = get_outer_rank(proc_inputs,
                                    self._processed_input_tensor_spec)
        assert min_outer_rank <= outer_rank <= max_outer_rank, \
            ("Only supports {}<=outer_rank<={}! ".format(min_outer_rank, max_outer_rank)
            + "After preprocessing: inputs size {} vs. input tensor spec {}".format(
                proc_inputs.size(), self._processed_input_tensor_spec)
            + "\n Make sure that you have provided the right input preprocessors"
            + " and nest combiner!\n"
            + "Before preprocessing: inputs size {} vs. input tensor spec {}".format(
                alf.nest.map_structure(lambda tensor: tensor.size(), inputs),
                self._input_tensor_spec))
        return proc_inputs, state
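A hedged construction sketch of a network that exercises this ``forward``: per-element ``input_preprocessors`` run first, then ``preprocessing_combiner`` merges the nest. It assumes ``alf.networks.EncodingNetwork`` and ``alf.nest.utils.NestConcat``; exact keyword names may vary across alf versions:

import torch
import alf
from alf.nest.utils import NestConcat

specs = (alf.TensorSpec((4,)), alf.TensorSpec((6,)))
net = alf.networks.EncodingNetwork(
    input_tensor_spec=specs,
    preprocessing_combiner=NestConcat(),  # concatenates the nest into [B, 10]
    fc_layer_params=(32,))
output, state = net((torch.randn(8, 4), torch.randn(8, 6)))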
Example 3
def average_outer_dims(tensor, spec):
    """
    Args:
        tensor (Tensor): a single Tensor
        spec (TensorSpec):

    Returns:
        the average tensor across outer dims
    """
    outer_dims = get_outer_rank(tensor, spec)
    return tensor.mean(dim=list(range(outer_dims)))
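A hedged usage sketch (hypothetical shapes; assumes ``alf.TensorSpec``):

import torch
import alf

spec = alf.TensorSpec((3,))        # the inner shape
t = torch.randn(4, 5, 3)           # outer dims [4, 5]
avg = average_outer_dims(t, spec)  # mean over dims 0 and 1 -> shape [3]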
Example 4
    def _ml_pmi(self, x, y, y_distribution):
        """Estimate the pointwise mutual information ``log p(y|x) - log p(y)``.

        ``p(y|x)`` is modeled as a diagonal Gaussian whose location and scale
        are learned corrections to the marginal ``y_distribution``.
        """
        num_outer_dims = get_outer_rank(x, self._x_spec)
        hidden = self._model(x)[0]
        batch_squash = BatchSquash(num_outer_dims)
        hidden = batch_squash.flatten(hidden)
        delta_loc = self._delta_loc_layer(hidden)
        delta_scale = F.softplus(self._delta_scale_layer(hidden))
        delta_loc = batch_squash.unflatten(delta_loc)
        delta_scale = batch_squash.unflatten(delta_scale)
        y_given_x_dist = DiagMultivariateNormal(
            loc=y_distribution.mean + delta_loc,
            scale=y_distribution.stddev * delta_scale)

        # Detach the marginal's log-prob so gradients flow only through p(y|x).
        pmi = y_given_x_dist.log_prob(y) - y_distribution.log_prob(y).detach()
        return pmi
Example 5
def scale_to_spec(tensor, spec: BoundedTensorSpec):
    """Shapes and scales a batch into the given spec bounds.

    Args:
        tensor (Tensor): a tensor with values in the range of ``[-1, 1]``.
        spec (BoundedTensorSpec): the spec to use for scaling the input tensor.

    Returns:
        A batch scaled to the given spec bounds.
    """
    bs = BatchSquash(get_outer_rank(tensor, spec))
    tensor = bs.flatten(tensor)
    means, magnitudes = spec_means_and_magnitudes(spec)
    tensor = means + magnitudes * tensor
    tensor = bs.unflatten(tensor)
    return tensor
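A hedged usage sketch, e.g. mapping squashed actions into a bounded action spec (assumes ``alf.BoundedTensorSpec`` accepts ``minimum``/``maximum`` keywords):

import torch
import alf

action_spec = alf.BoundedTensorSpec((2,), minimum=0., maximum=10.)
raw = torch.rand(8, 2) * 2 - 1             # a batch of values in [-1, 1]
actions = scale_to_spec(raw, action_spec)  # linearly mapped into [0, 10]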
Example 6
    def _expand_to_replica(self, inputs, spec):
        """Expand ``inputs`` of shape ``[B, ...]`` to ``[B, n, ...]`` if
        ``n > 1``, where ``n`` is the number of replicas. When ``n = 1``, the
        unexpanded inputs are returned.

        Args:
            inputs (Tensor): the input tensor to be expanded
            spec (TensorSpec): the spec of the unexpanded inputs. It is used to
                determine whether the inputs are already expanded. If so, the
                inputs are returned without any further processing.
        Returns:
            Tensor: the expanded inputs or the original inputs.
        """
        outer_rank = get_outer_rank(inputs, spec)
        if outer_rank == 1 and self._num_replicas > 1:
            return inputs.unsqueeze(1).expand(-1, self._num_replicas,
                                              *inputs.shape[1:])
        else:
            return inputs
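The expansion step in isolation, as a minimal sketch with ``num_replicas`` assumed to be 3:

import torch

x = torch.randn(4, 16)                                 # [B, ...]
expanded = x.unsqueeze(1).expand(-1, 3, *x.shape[1:])  # [B, 3, ...]
assert expanded.shape == (4, 3, 16)
# expand() returns a view; the replicas share storage rather than copying.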
Example 7
def _reduce_along_batch_dims(x, mean, op):
    spec = TensorSpec.from_tensor(mean)
    bs = alf.layers.BatchSquash(get_outer_rank(x, spec))
    x = bs.flatten(x)
    # ops like torch.min/torch.max return (values, indices); keep the values
    x = op(x, dim=0)[0]
    return x
Example 8
    def train_step(self, inputs, y_distribution=None, state=None):
        """Perform training on one batch of inputs.

        Args:
            inputs (tuple(nested Tensor, nested Tensor)): tuple of ``x`` and ``y``
            y_distribution (nested td.Distribution): distribution
                for the marginal distribution of ``y``. If None, will use the
                sampling method ``sampler`` provided at constructor to generate
                the samples for the marginal distribution of :math:`Y`.
            state: not used
        Returns:
            AlgStep:
            - outputs (Tensor): shape is ``[batch_size]``; its mean is the
              estimated MI for estimators 'DV' and 'KLD', and the estimated
              Jensen-Shannon divergence for estimator 'JSD'
            - state: not used
            - info (LossInfo): ``info.loss`` is the loss
        """
        x, y = inputs

        if self._type == 'ML':
            return self._ml_step(x, y, y_distribution)

        num_outer_dims = get_outer_rank(x, self._x_spec)
        batch_squash = BatchSquash(num_outer_dims)
        x = batch_squash.flatten(x)
        y = batch_squash.flatten(y)
        if y_distribution is None:
            x1, y1 = self._sampler(x, y)
        else:
            x1 = x
            y1 = y_distribution.sample()
            y1 = batch_squash.flatten(y1)

        log_ratio = self._model([x, y])[0]
        t1 = self._model([x1, y1])[0]

        if self._type == 'DV':
            ratio = torch.min(t1, torch.tensor(20.)).exp()
            mean = ratio.mean().detach()
            if self._mean_averager:
                self._mean_averager.update(mean)
                unbiased_mean = self._mean_averager.get().detach()
            else:
                unbiased_mean = mean
            # estimated MI = reduce_mean(mi)
            # ratio/mean-1 does not contribute to the final estimated MI, since
            # mean(ratio/mean-1) = 0. We add it so that we can have an estimate
            # of the variance of the MI estimator.
            mi = log_ratio - (mean.log() + ratio / mean - 1)
            loss = ratio / unbiased_mean - log_ratio
        elif self._type == 'KLD':
            ratio = torch.min(t1, torch.tensor(20.)).exp()
            mi = log_ratio - ratio + 1
            loss = -mi
        elif self._type == 'JSD':
            mi = -F.softplus(-log_ratio) - F.softplus(t1) + math.log(4)
            loss = -mi
        mi = batch_squash.unflatten(mi)
        loss = batch_squash.unflatten(loss)

        return AlgStep(output=mi, state=(), info=LossInfo(loss, extra=()))
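A hedged usage sketch, assuming an estimator object exposing this ``train_step`` (names are illustrative):

# x, y: paired batches whose specs match the estimator's x/y specs
alg_step = estimator.train_step((x, y))
mi_estimate = alg_step.output.mean()  # scalar MI (or JSD) estimate
loss = alg_step.info.loss             # minimize this to train the estimator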
Example 9
def _reduce_along_batch_dims(x, spec, op):
    bs = alf.layers.BatchSquash(get_outer_rank(x, spec))
    x = bs.flatten(x)
    x = op(x, dim=0)[0]
    return x
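A hedged usage sketch for this variant (hypothetical shapes; the ``[0]`` indexing relies on ops like ``torch.min`` returning ``(values, indices)``):

import torch
import alf

spec = alf.TensorSpec((3,))
x = torch.randn(4, 5, 3)                              # outer dims [4, 5]
x_min = _reduce_along_batch_dims(x, spec, torch.min)  # min over all 20 rows -> [3]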
Example 10
    def _preprocess(self, tensor):
        assert get_outer_rank(tensor, self._input_tensor_spec) == 1, \
            "Only supports one outer rank (batch dim)!"
        ret = self._embedding_net(tensor)
        # For continuous inputs, EncodingNetwork returns an (output, state)
        # pair; keep only the output.
        return (ret if self._input_tensor_spec.is_discrete else ret[0])