Example #1
    def call(self, x):
        # Put the scope and decomposition axes in front of the batch axis:
        # (batch, scopes, decomps, nodes) -> (scopes, decomps, batch, nodes)
        x_scopes_first = tf.transpose(x, (1, 2, 0, 3))

        log_weights_unnormalized = self.accumulators

        # Hard EM: route counts through the accumulators instead of using autodiff
        if not self.logspace_accumulators \
                and self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            out_scopes_first = logmatmul_hard_em_through_grads_from_accumulators(
                x_scopes_first,
                self.accumulators,
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
            return tf.transpose(out_scopes_first, (2, 0, 1, 3))

        if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
            # Soft EM: log-softmax normalization with a gradient override for EM updates
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self.accumulators, axis=2)
        elif not self.logspace_accumulators:
            # Linear-space accumulators: move to log space, then normalize
            log_weights_normalized = tf.nn.log_softmax(
                tf.math.log(log_weights_unnormalized), axis=2)
        else:
            # Log-space accumulators can be normalized directly
            log_weights_normalized = tf.nn.log_softmax(
                log_weights_unnormalized, axis=2)

        out_scopes_first = logmatmul(x_scopes_first, log_weights_normalized)

        # Transpose back to (batch, scopes, decomps, nodes)
        return tf.transpose(out_scopes_first, (2, 0, 1, 3))
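This and several of the other examples call logmatmul, a log-space matrix multiplication that is not shown on this page. Below is a minimal stand-in, assuming the usual log-sum-exp formulation; the library's own version is presumably more careful about numerical edge cases such as all--inf rows.

import tensorflow as tf

def logmatmul_sketch(log_a, log_b):
    # out[..., i, j] = logsumexp_k(log_a[..., i, k] + log_b[..., k, j]),
    # i.e. log(exp(log_a) @ exp(log_b)) computed without leaving log space.
    log_a = tf.expand_dims(log_a, axis=-1)   # [..., rows, inner, 1]
    log_b = tf.expand_dims(log_b, axis=-3)   # [..., 1, inner, cols]
    return tf.reduce_logsumexp(log_a + log_b, axis=-2)

With x_scopes_first of shape (scopes, decomps, batch, num_in) and weights of shape (scopes, decomps, num_in, num_out), broadcasting makes this produce the (scopes, decomps, batch, num_out) result that the layer then transposes back to batch-first.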
Example #2
    def call(self, x):
        log_weights_unnormalized = self.accumulators
        x_squeezed = tf.reshape(x, (-1, self._num_nodes_in))
        if not self.logspace_accumulators:

            # Hard EM: route counts through the accumulators instead of using autodiff
            if self.backprop_mode in [
                    BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED
            ]:
                if self.return_weighted_child_logits:
                    return logmultiply_hard_em(x_squeezed, self.accumulators)

                # Reshape to a single scope, decomposition and output node so the
                # generic hard-EM sum kernel can be reused for this single-output sum
                logmatmul_out = logmatmul_hard_em_through_grads_from_accumulators(
                    tf.reshape(x, (1, 1, -1, self._num_nodes_in)),
                    tf.reshape(self.accumulators,
                               (1, 1, self._num_nodes_in, 1)),
                    unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
                return tf.reshape(logmatmul_out, (-1, 1))

            # Linear-space accumulators: move to log space before normalizing
            log_weights_unnormalized = tf.math.log(log_weights_unnormalized)

        if self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self.accumulators, axis=0)
        else:
            log_weights_normalized = tf.nn.log_softmax(
                log_weights_unnormalized, axis=0)

        if self.return_weighted_child_logits:
            return tf.expand_dims(log_weights_normalized, axis=0) + x_squeezed
        else:
            return logmatmul(x_squeezed,
                             tf.expand_dims(log_weights_normalized, axis=1))
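Every call method above branches on a BackpropMode enum that is not part of the snippets. Its members can be read off from the comparisons; the sketch below is an assumption (in particular the GRADIENT member and the ordering), included only to make the branches easier to follow.

from enum import Enum

class BackpropMode(Enum):
    GRADIENT = 0             # plain automatic differentiation (assumed default)
    EM = 1                   # soft EM via a gradient override on the normalized log-weights
    HARD_EM = 2              # hard EM: counts routed to a single winning child per sum
    HARD_EM_UNWEIGHTED = 3   # hard EM where the winner is chosen ignoring the weights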
Example #3
    def weighted_sum(
        self,
        x: tf.Tensor,
        accumulators: tf.Tensor,
        logspace_accumulators: bool,
        normalize_in_forward_pass: bool,
    ) -> tf.Tensor:
        """
        Compute a weighted sum.

        Args:
            x: Input Tensor.
            accumulators: Accumulators, can be seen as unnormalized representations of weights.
            logspace_accumulators: Whether or not accumulators are represented in logspace.
            normalize_in_forward_pass: Whether weights should be normalized during forward inference.

        Returns:
            A Tensor with the weighted sums.

        Raises:
            NotImplementedError: When called with ``logspace_accumulators == True``.
        """
        if logspace_accumulators:
            raise NotImplementedError(
                "EM is only implemented for linear space accumulators")
        w = self._to_logspace_override_grad(accumulators,
                                            normalize_in_forward_pass)
        return logmatmul(x, w)
Example #4
        def _inner_fn(
            x: tf.Tensor, accumulators: tf.Tensor
        ) -> Tuple[tf.Tensor, Callable[[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]]:

            # Normalized
            weights = (
                self._to_log_weights(accumulators)
                if normalize_in_forward_pass
                else tf.math.log(accumulators)
            )

            out = logmatmul(x, weights)

            def grad(parent_counts: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
                # Determine winning child
                max_child = tf.reduce_max(x, axis=-1, keepdims=True)
                equal_to_max = tf.cast(tf.equal(max_child, x), tf.float32)

                # Mix sampling proportional to the child probabilities with
                # deterministically picking (one of) the maximum children
                if self.sample_prob is not None:
                    equal_to_max = (
                        tf.math.log(
                            self.sample_prob * tf.exp(x - max_child)
                            + (1.0 - self.sample_prob) * equal_to_max
                        )
                        + max_child
                    )
                else:
                    equal_to_max = tf.math.log(equal_to_max)

                num_in = tf.shape(x)[-1]
                equal_to_max_flat_outer = tf.reshape(
                    equal_to_max, tf.concat([[-1], [num_in]], axis=0)
                )
                # Holds the index of the winning child per sum
                winning_child_per_scope = tf.reshape(
                    tf.random.categorical(equal_to_max_flat_outer, num_samples=1),
                    tf.shape(x)[:-1],
                )

                sum_parent_counts = tf.reduce_sum(parent_counts, axis=-1, keepdims=True)

                winning_child_per_scope_one_hot = tf.one_hot(
                    winning_child_per_scope, depth=num_in, axis=-1
                )
                child_counts = winning_child_per_scope_one_hot * sum_parent_counts

                weight_counts = tf.matmul(
                    winning_child_per_scope_one_hot, parent_counts, transpose_a=True
                )
                weight_counts = tf.reduce_sum(weight_counts, axis=[0, 1], keepdims=True)
                return child_counts, weight_counts

            return out, grad
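Example #4 is the body of an inner function that returns both a forward value and a grad callable, which is exactly the shape tf.custom_gradient expects. Below is a self-contained sketch of that pattern with the scope/decomposition axes, the sample_prob mixing and the sampled tie-breaking stripped out; the name hard_em_matmul, the 2-D shapes and the argmax tie-breaking are illustration-only assumptions, not the library's implementation.

import tensorflow as tf

def hard_em_matmul(x, accumulators):
    # x: (batch, num_in) log-probabilities, accumulators: (num_in, num_out) linear counts.
    @tf.custom_gradient
    def _inner(x, accumulators):
        # Forward pass: normalize the accumulators and take a log-space weighted sum.
        weights = tf.nn.log_softmax(tf.math.log(accumulators), axis=0)
        out = tf.reduce_logsumexp(tf.expand_dims(x, -1) + weights, axis=-2)

        def grad(parent_counts):
            # "Gradients" are hard-EM counts: each parent count is routed to the
            # single highest-scoring child (argmax here; the snippet above samples).
            winner = tf.one_hot(tf.argmax(x, axis=-1), depth=tf.shape(x)[-1])
            child_counts = winner * tf.reduce_sum(parent_counts, axis=-1, keepdims=True)
            weight_counts = tf.matmul(winner, parent_counts, transpose_a=True)
            return child_counts, weight_counts

        return out, grad

    return _inner(x, accumulators)

x = tf.math.log(tf.random.uniform((8, 4)))   # 8 samples, 4 children
acc = tf.ones((4, 3))                        # accumulators for 3 sums
with tf.GradientTape() as tape:
    tape.watch(acc)
    y = tf.reduce_sum(hard_em_matmul(x, acc))
counts = tape.gradient(y, acc)               # (4, 3) hard-EM counts, not true gradients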
Example #5
    def call(self, x):
        log_weights_unnormalized = self._accumulators

        if not self.logspace_accumulators and \
                self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            return logmatmul_hard_em_through_grads_from_accumulators(
                x, self._accumulators,
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED
            )

        if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self._accumulators, axis=2)
        elif not self.logspace_accumulators:
            log_weights_normalized = tf.nn.log_softmax(tf.math.log(log_weights_unnormalized), axis=2)
        else:
            log_weights_normalized = tf.nn.log_softmax(log_weights_unnormalized, axis=2)

        return logmatmul(x, log_weights_normalized)
Example #6
    def weighted_sum(
        self,
        x: tf.Tensor,
        accumulators: tf.Tensor,
        logspace_accumulators: bool,
        normalize_in_forward_pass: bool,
    ) -> tf.Tensor:
        """
        Compute a weighted sum.

        Args:
            x: Input Tensor.
            accumulators: Accumulators, can be seen as unnormalized representations of weights.
            logspace_accumulators: Whether or not accumulators are represented in logspace.
            normalize_in_forward_pass: Whether weights should be normalized during forward inference.

        Returns:
            A Tensor with the weighted sums.
        """
        w = self._weights_in_logspace(accumulators, logspace_accumulators,
                                      normalize_in_forward_pass)
        return logmatmul(x, w)
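Example #6 delegates the normalization logic to a _weights_in_logspace helper that is not shown. Based on the branching spelled out in Examples #1, #2 and #5, a plausible sketch of what it computes follows; the standalone name, signature and default axis are assumptions.

import tensorflow as tf

def weights_in_logspace(accumulators, logspace_accumulators, normalize_in_forward_pass, axis=0):
    # Accumulators stored in linear space are moved to log space first.
    log_w = accumulators if logspace_accumulators else tf.math.log(accumulators)
    # Optionally normalize so that the weights of each sum add up to one (in log space).
    return tf.nn.log_softmax(log_w, axis=axis) if normalize_in_forward_pass else log_w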
Example #7
    def _inner_fn(child_log_prob, linear_accumulators):

        log_accumulators = tf.math.log(linear_accumulators)

        # Normalized
        weights = tf.nn.log_softmax(log_accumulators, axis=2)

        if unweighted:
            pairwise_product_backprop = tf.expand_dims(child_log_prob, axis=3)
            out = logmatmul(child_log_prob, weights)
        else:
            # Pairwise product in forward pass
            # [scopes, decomps, batch, 1, num_in]
            child_log_prob = tf.expand_dims(child_log_prob, axis=3)
            # [scopes, decomps, 1, num_out, num_in]
            weights = tf.expand_dims(tf.transpose(weights, (0, 1, 3, 2)),
                                     axis=2)

            pairwise_product = child_log_prob + weights

            # Max per sum for determining winning child + choosing the constant for numerical
            # stability
            max_per_sum = tf.stop_gradient(
                tf.reduce_max(pairwise_product, axis=-1, keepdims=True))
            pairwise_product_backprop = child_log_prob + weights

            # Perform log(sum(exp(...))) with the numerical stability trick
            out = tf.math.log(
                tf.reduce_sum(tf.exp(pairwise_product - max_per_sum),
                              axis=-1)) + tf.squeeze(max_per_sum, axis=-1)

        def grad(dy):
            # Determine winning child
            if unweighted:
                max_per_sum_backprop = tf.reduce_max(pairwise_product_backprop,
                                                     axis=-1,
                                                     keepdims=True)
                equal_to_max = tf.cast(
                    tf.equal(pairwise_product_backprop, max_per_sum_backprop),
                    tf.float32)
            else:
                equal_to_max = tf.cast(
                    tf.equal(pairwise_product_backprop, max_per_sum),
                    tf.float32)

            num_in = tf.shape(child_log_prob)[-1]
            num_out = tf.shape(out)[-1]
            equal_to_max_flat_outer = tf.reshape(
                equal_to_max, tf.concat([[-1], [num_in]], axis=0))

            # Holds the index of the winning child per sum
            num_samples = num_out if unweighted else 1
            winning_child_per_sum = tf.reshape(
                tf.random.categorical(tf.math.log(equal_to_max_flat_outer),
                                      num_samples=num_samples), tf.shape(out))

            # Pass on the counts to the edges between child and parent
            edge_counts = tf.expand_dims(dy, -1) * tf.one_hot(
                winning_child_per_sum, depth=num_in)

            # Sum over parents to get counts per child
            child_counts = tf.reduce_sum(edge_counts, axis=3)

            # Sum over batch to get counts per weight
            weight_counts = tf.reduce_sum(edge_counts, axis=2)

            return child_counts, tf.transpose(weight_counts, (0, 1, 3, 2))

        return out, grad
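One detail in the grad closures of Examples #4 and #7 that is easy to miss: taking tf.math.log of the 0/1 equal_to_max mask turns it into categorical logits, so tf.random.categorical samples uniformly among the children that tie for the maximum (log 0 = -inf gives zero probability to everything else). A tiny illustration:

import tensorflow as tf

# Children 1 and 2 tie for the maximum; child 0 does not.
equal_to_max = tf.constant([[0.0, 1.0, 1.0]])
# log(0) = -inf, so only the tied children can be drawn, each with probability 1/2.
winner = tf.random.categorical(tf.math.log(equal_to_max), num_samples=1)
print(winner)  # shape (1, 1), containing 1 or 2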