def call(self, x):
    # Move the batch axis inward: [batch, scopes, decomps, nodes]
    # -> [scopes, decomps, batch, nodes].
    x_scopes_first = tf.transpose(x, (1, 2, 0, 3))
    log_weights_unnormalized = self.accumulators
    if not self.logspace_accumulators \
            and self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
        # Hard EM: counts are routed to winning children through a custom gradient.
        out_scopes_first = logmatmul_hard_em_through_grads_from_accumulators(
            x_scopes_first, self.accumulators,
            unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
        return tf.transpose(out_scopes_first, (2, 0, 1, 3))

    if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
        # Soft EM: normalize with a gradient that accumulates EM counts.
        log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
            self.accumulators, axis=2)
    elif not self.logspace_accumulators:
        # Linear-space accumulators with ordinary backprop: normalize in log space.
        log_weights_normalized = tf.nn.log_softmax(
            tf.math.log(log_weights_unnormalized), axis=2)
    else:
        # Accumulators already live in log space.
        log_weights_normalized = tf.nn.log_softmax(log_weights_unnormalized, axis=2)

    out_scopes_first = logmatmul(x_scopes_first, log_weights_normalized)
    # Restore the batch-first layout.
    return tf.transpose(out_scopes_first, (2, 0, 1, 3))
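# For reference, `logmatmul` (defined elsewhere in the repo) is used throughout
# as a numerically stable log-space matrix product. A minimal sketch of that
# contract, under the assumption that the last two axes are the matmul axes
# (the sketch below is illustrative, not the library's implementation):
import tensorflow as tf

def logmatmul_sketch(log_a: tf.Tensor, log_b: tf.Tensor) -> tf.Tensor:
    # out[..., i, j] = logsumexp_k(log_a[..., i, k] + log_b[..., k, j]),
    # i.e. log(exp(log_a) @ exp(log_b)) without leaving log space.
    return tf.reduce_logsumexp(
        tf.expand_dims(log_a, axis=-1) + tf.expand_dims(log_b, axis=-3), axis=-2)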
def call(self, x):
    log_weights_unnormalized = self.accumulators
    # Flatten all children of the root into a single axis: [batch, num_in].
    x_squeezed = tf.reshape(x, (-1, self._num_nodes_in))
    if not self.logspace_accumulators:
        if self.backprop_mode in [
                BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
            if self.return_weighted_child_logits:
                return logmultiply_hard_em(x_squeezed, self.accumulators)
            # Reshape to the [scopes, decomps, batch, nodes] layout the hard-EM
            # op expects, using a single scope and decomposition.
            logmatmul_out = logmatmul_hard_em_through_grads_from_accumulators(
                tf.reshape(x, (1, 1, -1, self._num_nodes_in)),
                tf.reshape(self.accumulators, (1, 1, self._num_nodes_in, 1)),
                unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED)
            return tf.reshape(logmatmul_out, (-1, 1))

        log_weights_unnormalized = tf.math.log(log_weights_unnormalized)

    if self.backprop_mode == BackpropMode.EM:
        log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
            self.accumulators, axis=0)
    else:
        log_weights_normalized = tf.nn.log_softmax(log_weights_unnormalized, axis=0)

    if self.return_weighted_child_logits:
        # One logit per child: log(w_i) + log(p_i).
        return tf.expand_dims(log_weights_normalized, axis=0) + x_squeezed
    else:
        # Single root logit: logsumexp over the weighted children.
        return logmatmul(x_squeezed, tf.expand_dims(log_weights_normalized, axis=1))
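# Hedged sanity check (not from the repo) relating the two root outputs above:
# the per-child weighted logits should logsumexp to the same value as the
# single root logit produced by the `logmatmul` path.
import tensorflow as tf

x = tf.math.log(tf.constant([[0.1, 0.2, 0.7]]))                  # [batch, num_in]
log_w = tf.nn.log_softmax(tf.math.log(tf.constant([2.0, 1.0, 1.0])), axis=0)
weighted_child_logits = tf.expand_dims(log_w, axis=0) + x        # [batch, num_in]
root_log_prob = tf.reduce_logsumexp(weighted_child_logits, axis=-1, keepdims=True)
# root_log_prob equals logmatmul(x, tf.expand_dims(log_w, axis=1)).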
def weighted_sum(
    self,
    x: tf.Tensor,
    accumulators: tf.Tensor,
    logspace_accumulators: bool,
    normalize_in_forward_pass: bool,
) -> tf.Tensor:
    """
    Compute a weighted sum.

    Args:
        x: Input Tensor
        accumulators: Accumulators, can be seen as unnormalized representations of weights.
        logspace_accumulators: Whether or not accumulators are represented in logspace.
        normalize_in_forward_pass: Whether weights should be normalized during forward inference.

    Returns:
        A Tensor with the weighted sums.

    Raises:
        NotImplementedError: When called with ``logspace_accumulators == True``.
    """
    if logspace_accumulators:
        raise NotImplementedError(
            "EM is only implemented for linear space accumulators")
    w = self._to_logspace_override_grad(accumulators, normalize_in_forward_pass)
    return logmatmul(x, w)
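# `_to_logspace_override_grad` is not shown in this section. A heavily hedged
# sketch of the contract `weighted_sum` appears to rely on: the forward pass
# yields normalized log-weights, while the backward pass hands the incoming EM
# counts straight to the accumulators instead of differentiating the softmax.
# This is an assumption for illustration; `normalize_in_forward_pass` is
# ignored here for brevity.
import tensorflow as tf

@tf.custom_gradient
def to_logspace_override_grad_sketch(accumulators: tf.Tensor):
    # Forward: normalized log-weights from linear-space accumulators.
    log_weights = tf.nn.log_softmax(tf.math.log(accumulators), axis=0)

    def grad(dy: tf.Tensor) -> tf.Tensor:
        # Backward: pass the counts through unchanged (hypothetical
        # simplification of the library's override).
        return dy

    return log_weights, grad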
def _inner_fn(
    x: tf.Tensor, accumulators: tf.Tensor
) -> Tuple[tf.Tensor, Callable[[tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]]:
    # Normalized
    weights = (
        self._to_log_weights(accumulators)
        if normalize_in_forward_pass
        else tf.math.log(accumulators)
    )
    out = logmatmul(x, weights)

    def grad(parent_counts: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        # Determine winning child
        max_child = tf.reduce_max(x, axis=-1, keepdims=True)
        equal_to_max = tf.cast(tf.equal(max_child, x), tf.float32)

        # Holds the index of the winning child per sum
        if self.sample_prob is not None:
            equal_to_max = (
                tf.math.log(
                    self.sample_prob * tf.exp(x - max_child)
                    + (1.0 - self.sample_prob) * equal_to_max
                )
                + max_child
            )
        else:
            equal_to_max = tf.math.log(equal_to_max)
        num_in = tf.shape(x)[-1]
        equal_to_max_flat_outer = tf.reshape(
            equal_to_max, tf.concat([[-1], [num_in]], axis=0)
        )
        winning_child_per_scope = tf.reshape(
            tf.random.categorical(equal_to_max_flat_outer, num_samples=1),
            tf.shape(x)[:-1],
        )
        sum_parent_counts = tf.reduce_sum(parent_counts, axis=-1, keepdims=True)
        winning_child_per_scope_one_hot = tf.one_hot(
            winning_child_per_scope, depth=num_in, axis=-1
        )
        child_counts = winning_child_per_scope_one_hot * sum_parent_counts
        weight_counts = tf.matmul(
            winning_child_per_scope_one_hot, parent_counts, transpose_a=True
        )
        weight_counts = tf.reduce_sum(weight_counts, axis=[0, 1], keepdims=True)
        return child_counts, weight_counts

    return out, grad
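# Tiny standalone demo (not from the repo) of the winner-selection trick in
# `grad` above: `tf.random.categorical` over log(equal_to_max) breaks ties
# uniformly, because non-maximal children get a logit of -inf.
import tensorflow as tf

x = tf.constant([[0.2, 0.9, 0.9]])  # children 1 and 2 tie for the max
max_child = tf.reduce_max(x, axis=-1, keepdims=True)
equal_to_max = tf.cast(tf.equal(max_child, x), tf.float32)  # [[0., 1., 1.]]
# log(0) == -inf, so only tied maxima have nonzero sampling probability.
winner = tf.random.categorical(tf.math.log(equal_to_max), num_samples=1)
# winner is 1 or 2, each with probability 0.5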
def call(self, x):
    log_weights_unnormalized = self._accumulators
    if not self.logspace_accumulators and \
            self.backprop_mode in [BackpropMode.HARD_EM, BackpropMode.HARD_EM_UNWEIGHTED]:
        # Hard EM: counts are routed to winning children through a custom gradient.
        return logmatmul_hard_em_through_grads_from_accumulators(
            x, self._accumulators,
            unweighted=self.backprop_mode == BackpropMode.HARD_EM_UNWEIGHTED
        )

    if not self.logspace_accumulators and self.backprop_mode == BackpropMode.EM:
        # Soft EM: normalize with a gradient that accumulates EM counts.
        log_weights_normalized = log_softmax_from_accumulators_with_em_grad(
            self._accumulators, axis=2)
    elif not self.logspace_accumulators:
        log_weights_normalized = tf.nn.log_softmax(
            tf.math.log(log_weights_unnormalized), axis=2)
    else:
        log_weights_normalized = tf.nn.log_softmax(log_weights_unnormalized, axis=2)

    return logmatmul(x, log_weights_normalized)
def weighted_sum(
    self,
    x: tf.Tensor,
    accumulators: tf.Tensor,
    logspace_accumulators: bool,
    normalize_in_forward_pass: bool,
) -> tf.Tensor:
    """
    Compute a weighted sum.

    Args:
        x: Input Tensor
        accumulators: Accumulators, can be seen as unnormalized representations of weights.
        logspace_accumulators: Whether or not accumulators are represented in logspace.
        normalize_in_forward_pass: Whether weights should be normalized during forward inference.

    Returns:
        A Tensor with the weighted sums.
    """
    w = self._weights_in_logspace(
        accumulators, logspace_accumulators, normalize_in_forward_pass)
    return logmatmul(x, w)
def _inner_fn(child_log_prob, linear_accumulators):
    log_accumulators = tf.math.log(linear_accumulators)

    # Normalized
    weights = tf.nn.log_softmax(log_accumulators, axis=2)

    if unweighted:
        pairwise_product_backprop = tf.expand_dims(child_log_prob, axis=3)
        out = logmatmul(child_log_prob, weights)
    else:
        # Pairwise product in forward pass
        # [scopes, decomps, batch, 1, num_in]
        child_log_prob = tf.expand_dims(child_log_prob, axis=3)
        # [scopes, decomps, 1, num_out, num_in]
        weights = tf.expand_dims(tf.transpose(weights, (0, 1, 3, 2)), axis=2)
        pairwise_product = child_log_prob + weights

        # Max per sum for determining winning child + choosing the constant for
        # numerical stability
        max_per_sum = tf.stop_gradient(
            tf.reduce_max(pairwise_product, axis=-1, keepdims=True))

        pairwise_product_backprop = child_log_prob + weights

        # Perform log(sum(exp(...))) with the numerical stability trick
        out = tf.math.log(
            tf.reduce_sum(tf.exp(pairwise_product - max_per_sum), axis=-1)
        ) + tf.squeeze(max_per_sum, axis=-1)

    def grad(dy):
        # Determine winning child
        if unweighted:
            max_per_sum_backprop = tf.reduce_max(
                pairwise_product_backprop, axis=-1, keepdims=True)
            equal_to_max = tf.cast(
                tf.equal(pairwise_product_backprop, max_per_sum_backprop), tf.float32)
        else:
            equal_to_max = tf.cast(
                tf.equal(pairwise_product_backprop, max_per_sum), tf.float32)
        num_in = tf.shape(child_log_prob)[-1]
        num_out = tf.shape(out)[-1]
        equal_to_max_flat_outer = tf.reshape(
            equal_to_max, tf.concat([[-1], [num_in]], axis=0))

        # Holds the index of the winning child per sum
        num_samples = num_out if unweighted else 1
        winning_child_per_sum = tf.reshape(
            tf.random.categorical(
                tf.math.log(equal_to_max_flat_outer), num_samples=num_samples),
            tf.shape(out))

        # Pass on the counts to the edges between child and parent
        edge_counts = tf.expand_dims(dy, -1) * tf.one_hot(
            winning_child_per_sum, depth=num_in)

        # Sum over parents to get counts per child
        child_counts = tf.reduce_sum(edge_counts, axis=3)

        # Sum over batch to get counts per weight
        weight_counts = tf.reduce_sum(edge_counts, axis=2)

        return child_counts, tf.transpose(weight_counts, (0, 1, 3, 2))

    return out, grad
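# Quick standalone check (not from the repo) of the log-sum-exp stability trick
# used in the forward pass above: subtracting the max before exponentiating
# avoids overflow without changing the result.
import tensorflow as tf

z = tf.constant([[1000.0, 1000.0]])
naive = tf.math.log(tf.reduce_sum(tf.exp(z), axis=-1))  # inf: exp overflows
m = tf.reduce_max(z, axis=-1, keepdims=True)
stable = tf.math.log(tf.reduce_sum(tf.exp(z - m), axis=-1)) + tf.squeeze(m, axis=-1)
# naive -> [inf], stable -> [1000.6931] (= 1000 + log 2)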