Example #1
    def _compute_mpe_path_common(self,
                                 reducible_tensor,
                                 counts,
                                 w_tensor,
                                 latent_indicators_tensor,
                                 *input_tensors,
                                 accumulate_weights_batch=False,
                                 sample=False,
                                 sample_prob=None):
        """Common operations for computing the MPE path."""
        sample_prob = utils.maybe_first(sample_prob, self._sample_prob)
        num_samples = 1 if reducible_tensor.shape[1] != 1 else self._num_sums
        # Select the winning input per sum, either by sampling from the
        # (log-)normalized values or by taking the argmax.
        if sample:
            max_indices = self._reduce_sample_log(reducible_tensor,
                                                  sample_prob=sample_prob,
                                                  num_samples=num_samples)
        else:
            max_indices = self._reduce_argmax(reducible_tensor,
                                              num_samples=num_samples)
        # Scatter the counts coming from the parents to the winning indices;
        # every non-winning index receives zero.
        max_counts = utils.scatter_values(params=counts,
                                          indices=max_indices,
                                          num_out_cols=self._max_sum_size)
        max_counts_split = self._accumulate_and_split_to_children(
            max_counts, *input_tensors)
        # The counts destined for the weights are optionally summed over the batch axis.
        if accumulate_weights_batch:
            w_counts = tf.reduce_sum(max_counts, axis=self._batch_axis)
        else:
            w_counts = max_counts

        return self._scatter_to_input_tensors(
            (w_counts, w_tensor),  # Weights
            (max_counts, latent_indicators_tensor)) + tuple(max_counts_split)
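
The heart of this example is the argmax-and-scatter step: for every sum and batch row, the winning input is picked from the weighted (log-)values, and the counts coming from the parents are scattered to that input only. Below is a minimal standalone sketch of that step in plain TensorFlow 2; mpe_path_counts_sketch and its arguments are illustrative names, not part of the API used above.

import tensorflow as tf

def mpe_path_counts_sketch(child_log_values, log_weights, parent_counts):
    # child_log_values: [batch, num_children] log-values of the sum's children.
    # log_weights:      [num_children] log-weights of the sum.
    # parent_counts:    [batch] counts propagated down from the parents.
    # Weighted (log-)values, playing the role of `reducible_tensor` above.
    reducible = child_log_values + log_weights[tf.newaxis, :]
    # Pick the winning child per batch row (the argmax branch above).
    winner = tf.argmax(reducible, axis=1)
    # Scatter the parent counts to the winner, in the spirit of
    # utils.scatter_values: every non-winning child receives a zero count.
    one_hot = tf.one_hot(winner, depth=child_log_values.shape[1],
                         dtype=parent_counts.dtype)
    return one_hot * parent_counts[:, tf.newaxis]

# Two batch rows, one sum with three children.
child_log_values = tf.math.log(tf.constant([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]]))
log_weights = tf.math.log(tf.constant([0.2, 0.3, 0.5]))
print(mpe_path_counts_sketch(child_log_values, log_weights, tf.constant([1.0, 1.0])))
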
Example #2
    def _compute_mpe_path_common(self,
                                 reducible_tensor,
                                 counts,
                                 w_tensor,
                                 latent_indicators_tensor,
                                 *input_tensors,
                                 sample=False,
                                 sample_prob=None,
                                 accumulate_weights_batch=False):
        """Common operations for computing the MPE path.

        Args:
            reducible_tensor (Tensor): A (weighted) ``Tensor`` of (log-)values of this node.
            counts (Tensor): A ``Tensor`` that contains the accumulated counts of the parents
                             of this node.
            w_tensor (Tensor):  A ``Tensor`` containing the (log-)value of the weights.
            latent_indicators_tensor (Tensor): A ``Tensor`` containing the (log-)value of the IndicatorLeaf.
            input_tensors (list): A list of ``Tensor``s with outputs of the child nodes.
            sample (bool): Whether to sample the 'winner' of the max or not.
            sample_prob (Tensor): A scalar ``Tensor`` indicating the probability of drawing
                a sample. If a sample is drawn, the probability for each index is given by the
                (log-)normalized probability as given by ``reducible_tensor``.
            accumulate_weights_batch (bool): Whether to sum the counts destined for the
                weights over the batch axis before scattering them.
        Returns:
            A ``list`` of ``tuple``s [(MPE counts, input tensor), ...] where the first corresponds
            to the Weights of this node, the second corresponds to the IndicatorLeaf and the remaining
            tuples correspond to the nodes in ``self._values``.
        """
        sample_prob = utils.maybe_first(sample_prob, self._sample_prob)
        num_samples = (1 if reducible_tensor.shape[self._reduce_axis] != 1
                       else self._num_sums)
        if sample:
            max_indices = self._reduce_sample_log(reducible_tensor,
                                                  sample_prob=sample_prob,
                                                  num_samples=num_samples)
        else:
            max_indices = self._reduce_argmax(reducible_tensor,
                                              num_samples=num_samples)
        max_counts = utils.scatter_values(params=counts,
                                          indices=max_indices,
                                          num_out_cols=self._max_sum_size)
        max_counts_acc, max_counts_split = self._accumulate_and_split_to_children(
            max_counts, *input_tensors)
        if accumulate_weights_batch:
            max_counts = tf.reduce_sum(max_counts, axis=0, keepdims=False)
        return self._scatter_to_input_tensors(
            (max_counts, w_tensor),  # Weights
            (max_counts_acc, latent_indicators_tensor),  # IndicatorLeaf
            *[(t, v)
              for t, v in zip(max_counts_split, input_tensors)])  # Values
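
When sample=True, the docstring above says the 'winner' is drawn with probability sample_prob from the (log-)normalized values in reducible_tensor instead of being taken as the argmax. The snippet below is a rough standalone sketch of that behaviour in plain TensorFlow 2; sample_winner_sketch is a hypothetical name, and the per-row mixing of sampled and argmax indices is an assumption about how _reduce_sample_log behaves, not a copy of it.

import tensorflow as tf

def sample_winner_sketch(reducible_log_prob, sample_prob):
    # reducible_log_prob: [batch, num_children] weighted (log-)values.
    # sample_prob: scalar probability of sampling instead of taking the argmax.
    batch_size = tf.shape(reducible_log_prob)[0]
    # Categorical sample proportional to the (log-)normalized values.
    sampled = tf.random.categorical(reducible_log_prob, num_samples=1)[:, 0]
    # Deterministic MPE choice.
    argmax = tf.argmax(reducible_log_prob, axis=1)
    # Per row, keep the sampled index with probability sample_prob,
    # otherwise fall back to the argmax.
    use_sample = tf.random.uniform([batch_size]) < sample_prob
    return tf.where(use_sample, sampled, argmax)

winners = sample_winner_sketch(
    tf.math.log(tf.constant([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]])), sample_prob=0.5)
print(winners)
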
Example #3
    def _compute_mpe_path_common(
            self, reducible_log_prob, counts, w_log_prob, latent_indicator_log_prob, *child_log_prob,
            sample=False, sample_prob=None, accumulate_weights_batch=False, use_unweighted=False):
        """Common operations for computing the MPE path.

        Args:
            reducible_log_prob (Tensor): A (weighted) ``Tensor`` of (log-)values of this container.
            counts (Tensor): A ``Tensor`` that contains the accumulated counts of the parents
                             of this container.
            w_log_prob (Tensor):  A ``Tensor`` containing the (log-)value of the weights.
            latent_indicator_log_prob (Tensor): A ``Tensor`` containing the logit of the
                latent indicators.
            child_log_prob (list): A list of ``Tensor``s with outputs of the child nodes.
            sample (bool): Whether to sample the 'winner' of the max or not.
            sample_prob (Tensor): A scalar ``Tensor`` indicating the probability of drawing
                a sample instead of taking the argmax.
            accumulate_weights_batch (bool): Whether to sum the counts destined for the
                weights over the batch axis before scattering them.

        Returns:
            A ``list`` of ``tuple``s [(MPE counts, input tensor), ...] where the first corresponds
            to the Weights of this container, the second corresponds to the latent indicators and
            the remaining tuples correspond to the nodes in ``self._values``.
        """
        sample_prob = utils.maybe_first(sample_prob, self._sample_prob)
        num_samples = 1 if reducible_log_prob.shape[self._channel_axis] != 1 else self._num_channels
        if sample:
            max_indices = self._reduce_sample_log(
                reducible_log_prob, sample_prob=sample_prob, num_samples=num_samples)
        else:
            max_indices = self._reduce_argmax(reducible_log_prob, num_samples=num_samples)
        max_indices = tf.reshape(max_indices, (-1, self._compute_out_size()))
        max_counts = utils.scatter_values(
            params=counts, indices=max_indices, num_out_cols=self._max_sum_size)
        weight_counts, input_counts = self._accumulate_and_split_to_children(max_counts)

        if accumulate_weights_batch:
            weight_counts = tf.reduce_sum(weight_counts, axis=0, keepdims=False)
        return self._scatter_to_input_tensors(
            (weight_counts, w_log_prob),  # Weights
            (max_counts, latent_indicator_log_prob),  # Latent indicators
            *[(t, v) for t, v in zip(input_counts, child_log_prob)])  # Values
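
The accumulate_weights_batch flag only changes what reaches the weights: their counts are summed over the batch axis (the tf.reduce_sum over axis 0 above), while the counts forwarded to the latent indicators and the children keep their batch dimension so each sample propagates its own MPE path. A small TensorFlow 2 illustration with made-up numbers:

import tensorflow as tf

# Scattered MPE counts for one sum, shape [batch, num_inputs]; each row is
# one-hot at the winning input, as produced by the argmax/scatter step above.
max_counts = tf.constant([[0., 1., 0.],
                          [0., 0., 1.],
                          [0., 1., 0.]])

# Counts destined for the weights: summed over the batch axis, giving one
# count vector for the whole batch.
weight_counts = tf.reduce_sum(max_counts, axis=0)  # -> [0., 2., 1.]

# Counts forwarded to the children: unchanged, one row per sample.
child_counts = max_counts  # shape [3, 3]

print(weight_counts.numpy(), child_counts.shape)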