def call(self, x):
    """Apply the dense sum to ``x`` using log-space operations.

    ``x`` is permuted with ``(1, 2, 0, 3)`` so the computation runs
    "scopes first", combined with the log-normalized weights through a
    log-space matrix product, and permuted back with ``(2, 0, 1, 3)``, the
    inverse of the first permutation on the three leading axes.

    How the weights are obtained depends on the accumulator representation:

    * linear accumulators with HARD_EM / HARD_EM_UNWEIGHTED backprop: the raw
      accumulators go straight into the hard-EM log-matmul helper;
    * linear accumulators with EM backprop: the weights are log-normalized by
      a helper that supplies EM-style gradients;
    * other linear accumulators: ``log_softmax(log(accumulators))``;
    * log-space accumulators: ``log_softmax(accumulators)``.

    All normalizations run along axis 2 of the accumulators.

    Args:
        x: Rank-4 input tensor (it is transposed with a 4-element
            permutation). Presumably holds log-probabilities — TODO confirm.

    Returns:
        The layer output with the leading-axis order of ``x`` restored.
    """
    # Reorder axes so the input is laid out scopes-first for the product.
    scopes_first_in = tf.transpose(x, (1, 2, 0, 3))
    weights = self.accumulators
    linear = not self.logspace_accumulators

    if linear and self.backprop_mode in (BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED):
        # Hard EM consumes the raw (linear) accumulators; no softmax needed.
        scopes_first_out = logmatmul_hard_em_through_grads_from_accumulators(
            scopes_first_in,
            weights,
            unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED,
        )
    else:
        if not linear:
            log_weights = tf.nn.log_softmax(weights, axis=2)
        elif self.backprop_mode == BackpropMode.EM:
            log_weights = log_softmax_from_accumulators_with_em_grad(weights, axis=2)
        else:
            log_weights = tf.nn.log_softmax(tf.math.log(weights), axis=2)
        scopes_first_out = logmatmul(scopes_first_in, log_weights)

    # Undo the scopes-first permutation of the leading axes.
    return tf.transpose(scopes_first_out, (2, 0, 1, 3))
def call(self, x):
    """Compute the root sum of ``x`` using log-space operations.

    ``x`` is flattened to shape ``(-1, self._num_nodes_in)``. The accumulators
    are turned into log-normalized weights along axis 0 as follows:

    * linear accumulators with HARD_EM / HARD_EM_UNWEIGHTED backprop: a
      hard-EM path is taken that uses the raw accumulators directly;
    * linear accumulators with EM backprop: log-normalized by a helper that
      supplies EM-style gradients;
    * other linear accumulators: ``log_softmax(log(accumulators))``;
    * log-space accumulators: ``log_softmax(accumulators)`` for every
      backprop mode. The EM helper is only applied to linear accumulators,
      as in the other sum layers' ``call`` methods.

    The accumulators are presumably a vector of length ``self._num_nodes_in``
    (the hard-EM path reshapes them to ``(1, 1, self._num_nodes_in, 1)``) —
    TODO confirm.

    Args:
        x: Input tensor reshapeable to ``(-1, self._num_nodes_in)``.
            Presumably holds log-probabilities — TODO confirm.

    Returns:
        If ``self.return_weighted_child_logits`` is set, the per-child
        weighted logits (log weights added to the flattened input). Otherwise
        the log-space product of the flattened input with the weights as a
        column, which the hard-EM path reshapes to ``(-1, 1)``.
    """
    log_weights_unnormalized = self.accumulators
    x_squeezed = tf.reshape(x, (-1, self._num_nodes_in))

    if not self.logspace_accumulators:
        if self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            if self.return_weighted_child_logits:
                return logmultiply_hard_em(x_squeezed, self.accumulators)
            # The hard-EM helper expects rank-4 operands; lift both and
            # flatten the result back to a column.
            logmatmul_out = logmatmul_hard_em_through_grads_from_accumulators(
                tf.reshape(x, (1, 1, -1, self._num_nodes_in)),
                tf.reshape(self.accumulators, (1, 1, self._num_nodes_in, 1)),
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED,
            )
            return tf.reshape(logmatmul_out, (-1, 1))

        if self.backprop_mode == BackpropMode.EM:
            log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
                self.accumulators, axis=0)
        else:
            log_weights_normalized = tf.nn.log_softmax(
                tf.math.log(log_weights_unnormalized), axis=0)
    else:
        # Log-space accumulators are already unnormalized log weights; they
        # must not be routed through the linear-accumulator EM helper.
        log_weights_normalized = tf.nn.log_softmax(log_weights_unnormalized, axis=0)

    if self.return_weighted_child_logits:
        return tf.expand_dims(log_weights_normalized, axis=0) + x_squeezed
    return logmatmul(x_squeezed, tf.expand_dims(log_weights_normalized, axis=1))
def call(self, x):
    """Combine ``x`` with the log-normalized layer weights.

    The weights come from ``self._accumulators`` and are normalized along
    axis 2:

    * log-space accumulators: ``log_softmax(accumulators)``;
    * linear accumulators with HARD_EM / HARD_EM_UNWEIGHTED backprop: the raw
      accumulators go straight into the hard-EM log-matmul helper, whose
      result is returned directly;
    * linear accumulators with EM backprop: log-normalized by a helper that
      supplies EM-style gradients;
    * other linear accumulators: ``log_softmax(log(accumulators))``.

    Args:
        x: Input tensor passed unchanged to the log-space matrix product.
            Presumably holds log-probabilities — TODO confirm.

    Returns:
        The result of the (hard-EM or regular) log-space matrix product.
    """
    accumulators = self._accumulators

    if self.logspace_accumulators:
        # Accumulators already store unnormalized log weights.
        return logmatmul(x, tf.nn.log_softmax(accumulators, axis=2))

    mode = self.backprop_mode
    if mode in (BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED):
        return logmatmul_hard_em_through_grads_from_accumulators(
            x,
            accumulators,
            unweighted=mode == BackpropMode.HARD_EM_UNWEIGHTED,
        )

    if mode == BackpropMode.EM:
        log_weights = log_softmax_from_accumulators_with_em_grad(accumulators, axis=2)
    else:
        log_weights = tf.nn.log_softmax(tf.math.log(accumulators), axis=2)
    return logmatmul(x, log_weights)