Example #1
    def call(self, x):
        # Move the batch axis behind the scope/decomposition axes so the
        # weights can be applied per scope and decomposition.
        x_scopes_first = tf.transpose(x, (1, 2, 0, 3))

        log_weights_unnormalized = self.accumulators

        # Hard EM: route gradients through the winning child via a
        # custom-gradient log-matmul on the raw accumulators.
        if not self.logspace_accumulators \
                and self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            out_scopes_first = logmatmul_hard_em_through_grads_from_accumulators(
                x_scopes_first,
                self.accumulators,
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
            # Restore the batch-first layout.
            return tf.transpose(out_scopes_first, (2, 0, 1, 3))

        # Soft EM: normalize with a custom gradient that produces EM updates.
        if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self.accumulators, axis=2)
        # Linear-space accumulators: take logs before normalizing.
        elif not self.logspace_accumulators:
            log_weights_normalized = tf.nn.log_softmax(
                tf.math.log(log_weights_unnormalized), axis=2)
        # Log-space accumulators: normalize directly.
        else:
            log_weights_normalized = tf.nn.log_softmax(
                log_weights_unnormalized, axis=2)

        # Weighted sum of children in log space, then back to batch-first.
        out_scopes_first = logmatmul(x_scopes_first, log_weights_normalized)

        return tf.transpose(out_scopes_first, (2, 0, 1, 3))
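The helpers logmatmul, logmatmul_hard_em_through_grads_from_accumulators and log_softmax_from_accumulators_with_em_grad are only called in these examples, never defined. For reference, the sketch below shows the standard way a log-space matrix product like logmatmul is implemented: a logsumexp over the shared axis, which computes log(exp(a) @ exp(b)) without leaving log space. The name logmatmul_sketch is hypothetical; the library's actual implementation may differ in detail.

    import tensorflow as tf

    def logmatmul_sketch(a, b):
        # a: [..., m, k], b: [..., k, n], both holding log-space values.
        # Pair every row of a with every column of b, then reduce the shared
        # k axis with logsumexp, i.e. log(exp(a) @ exp(b)).
        return tf.reduce_logsumexp(
            tf.expand_dims(a, axis=-1) + tf.expand_dims(b, axis=-3), axis=-2)

    # exp(logmatmul_sketch(log(A), log(B))) should equal A @ B:
    a = tf.math.log(tf.constant([[1.0, 2.0]]))    # A = [[1, 2]]
    b = tf.math.log(tf.constant([[3.0], [4.0]]))  # B = [[3], [4]]
    print(tf.exp(logmatmul_sketch(a, b)))         # ~[[11.]]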
Example #2
    def call(self, x):
        log_weights_unnormalized = self.accumulators
        # Flatten the input to [batch, num_nodes_in].
        x_squeezed = tf.reshape(x, (-1, self._num_nodes_in))

        if not self.logspace_accumulators:

            # Hard EM: route gradients through the winning child.
            if self.backprop_mode in [
                    BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED
            ]:
                if self.return_weighted_child_logits:
                    return logmultiply_hard_em(x_squeezed, self.accumulators)

                # Reshape to the 4D layout the log-matmul helper expects,
                # then squeeze the result back to [batch, 1].
                logmatmul_out = logmatmul_hard_em_through_grads_from_accumulators(
                    tf.reshape(x, (1, 1, -1, self._num_nodes_in)),
                    tf.reshape(self.accumulators,
                               (1, 1, self._num_nodes_in, 1)),
                    unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
                return tf.reshape(logmatmul_out, (-1, 1))

            # Linear-space accumulators: take logs before normalizing.
            log_weights_unnormalized = tf.math.log(log_weights_unnormalized)

        # Soft EM: normalize with a custom gradient that produces EM updates.
        if self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self.accumulators, axis=0)
        else:
            log_weights_normalized = tf.nn.log_softmax(
                log_weights_unnormalized, axis=0)

        if self.return_weighted_child_logits:
            # Per-child weighted logits: broadcast-add the log-weights.
            return tf.expand_dims(log_weights_normalized, axis=0) + x_squeezed
        else:
            # Single root value: log-space dot product with the weights.
            return logmatmul(x_squeezed,
                             tf.expand_dims(log_weights_normalized, axis=1))
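log_softmax_from_accumulators_with_em_grad is likewise not shown. A plausible shape for it, assuming it is built on tf.custom_gradient, is sketched below: the forward pass returns normalized log-weights computed from linear-space accumulators, while the backward pass replaces the analytic log-softmax gradient with an EM-style signal. Both the pass-through gradient and the fixed axis here are assumptions for illustration, not the library's actual code.

    import tensorflow as tf

    @tf.custom_gradient
    def log_softmax_em_sketch(accumulators):
        # Forward: log(acc / sum(acc)) along axis 0, matching the
        # tf.nn.log_softmax(tf.math.log(...)) branches above.
        log_weights = tf.nn.log_softmax(tf.math.log(accumulators), axis=0)

        def grad(dy):
            # Hypothetical EM-style backward pass: hand the incoming signal
            # straight to the accumulators so that an optimizer step adds
            # expected counts. The real gradient may be more involved.
            return dy

        return log_weights, grad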
Example #3
    def call(self, x):
        log_weights_unnormalized = self._accumulators

        # Hard EM: route gradients through the winning child via a
        # custom-gradient log-matmul on the raw accumulators.
        if not self.logspace_accumulators and \
                self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            return logmatmul_hard_em_through_grads_from_accumulators(
                x, self._accumulators,
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED
            )

        # Soft EM: normalize with a custom gradient that produces EM updates.
        if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self._accumulators, axis=2)
        # Linear-space accumulators: take logs before normalizing.
        elif not self.logspace_accumulators:
            log_weights_normalized = tf.nn.log_softmax(
                tf.math.log(log_weights_unnormalized), axis=2)
        # Log-space accumulators: normalize directly.
        else:
            log_weights_normalized = tf.nn.log_softmax(
                log_weights_unnormalized, axis=2)

        # Weighted sum of children in log space.
        return logmatmul(x, log_weights_normalized)
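All three examples normalize linear-space accumulators as tf.nn.log_softmax(tf.math.log(w), axis=...). The short check below verifies the identity this relies on: the log-softmax of log-weights equals the log of the explicitly normalized weights, since log_softmax(log w) = log w - log(sum w). The toy values are only for illustration.

    import tensorflow as tf

    w = tf.constant([[1.0, 3.0], [2.0, 4.0]])  # toy accumulators
    via_log_softmax = tf.nn.log_softmax(tf.math.log(w), axis=0)
    via_direct = tf.math.log(w / tf.reduce_sum(w, axis=0, keepdims=True))
    tf.debugging.assert_near(via_log_softmax, via_direct)  # passes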