def loss_func(self, x, y, loss=None):
    """ Backward compatibility loss function to be used by the model """
    if loss is None:
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.SUM)
    h1_output = tf.argmax(self.h1(x), axis=1)
    h2_output = self(x)
    h1_diff = h1_output - y
    # Here we determine which datapoints were correctly labeled by h1
    h1_correct = (h1_diff == 0)
    # Here we pull those datapoints which were correctly labeled by h1
    _, x_support = tf.dynamic_partition(
        x, tf.dtypes.cast(h1_correct, tf.int32), 2)
    # Here we pull the ground truth labels for those datapoints which were
    # correctly labeled by h1.
    _, y_support = tf.dynamic_partition(
        y, tf.dtypes.cast(h1_correct, tf.int32), 2)
    # Here we pull the outputs of h2 on datapoints which were correctly labeled
    # by h1, and use these to calculate the new-error dissonance.
    _, h2_support_output = tf.dynamic_partition(
        h2_output, tf.dtypes.cast(h1_correct, tf.int32), 2)
    new_error_dissonance = self.dissonance(h2_support_output, y_support, loss)
    # We calculate the new-error loss as the base loss plus the weighted dissonance.
    new_error_loss = loss(y, h2_output) + self.lambda_c * new_error_dissonance
    return tf.reduce_sum(new_error_loss)
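# Minimal standalone sketch (not part of the library code above; assumes
# TensorFlow 2.x, tensors made up) of the tf.dynamic_partition idiom used in
# loss_func: the correctness mask is cast to int32, partition 0 collects the
# rows h1 got wrong, and partition 1 is the "support" set h1 classified correctly.
import tensorflow as tf

y_true = tf.constant([2, 0, 1, 1])
h1_pred = tf.constant([2, 1, 1, 0])  # h1 is correct on rows 0 and 2
h1_correct = tf.cast(h1_pred == y_true, tf.int32)

x = tf.constant([[0.1, 0.9],
                 [0.8, 0.2],
                 [0.3, 0.7],
                 [0.6, 0.4]])
# Partition 0 = misclassified rows, partition 1 = the support set.
_, x_support = tf.dynamic_partition(x, h1_correct, 2)
# x_support -> [[0.1, 0.9], [0.3, 0.7]]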
def __call__(self, x, y):
    h1_output = tf.argmax(self.h1(x), axis=1)
    h2_output = self.h2(x)
    h1_diff = h1_output - tf.argmax(y, axis=1)
    h1_correct = (h1_diff == 0)
    _, x_support = tf.dynamic_partition(
        x, tf.dtypes.cast(h1_correct, tf.int32), 2)
    _, y_support = tf.dynamic_partition(
        y, tf.dtypes.cast(h1_correct, tf.int32), 2)
    h2_support_output = self.h2(x_support)
    strict_imitation_dissonance = self.dissonance(h2_support_output, y_support)
    strict_imitation_loss = (
        self.nll_loss(y, h2_output)
        + self.lambda_c * strict_imitation_dissonance)
    return tf.reduce_sum(strict_imitation_loss)
def __call__(self, x, y):
    h1_output = tf.argmax(self.h1(x), axis=1)
    h2_output = self.h2(x)
    h1_diff = h1_output - y
    h1_correct = (h1_diff == 0)
    _, x_support = tf.dynamic_partition(
        x, tf.dtypes.cast(h1_correct, tf.int32), 2)
    _, y_support = tf.dynamic_partition(
        y, tf.dtypes.cast(h1_correct, tf.int32), 2)
    h2_support_output = self.h2(x_support)
    dissonance = self.dissonance(h2_support_output, y_support)
    new_error_loss = self.nll_loss(y, h2_output) + self.lambda_c * dissonance
    return new_error_loss
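# Hypothetical usage sketch (not part of the code above): one way a
# compatibility loss like the __call__ defined above might be wired into a
# custom TF2 training step. `model`, `optimizer`, and `loss_fn` are stand-in
# names; in practice loss_fn would be an instance of the loss class and model
# would be h2.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3, activation="softmax")])
optimizer = tf.keras.optimizers.Adam()
# Stand-in loss; the real loss_fn(x, y) would return base loss + lambda_c * dissonance.
loss_fn = lambda x, y: tf.reduce_sum(
    tf.keras.losses.sparse_categorical_crossentropy(y, model(x)))

@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        loss_value = loss_fn(x_batch, y_batch)
    grads = tape.gradient(loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_value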
def nll_loss(self, target_labels, model_output):
    # Pick the model output probabilities corresponding to the ground truth labels
    _, model_outputs_for_targets = tf.dynamic_partition(
        model_output, tf.dtypes.cast(target_labels, tf.int32), 2)
    # We make sure to clip the probability values so that they do not
    # produce NaNs once we take the logarithm.
    model_outputs_for_targets = tf.clip_by_value(
        model_outputs_for_targets,
        clip_value_min=self.clip_value_min,
        clip_value_max=self.clip_value_max)
    loss = -1 * tf.reduce_mean(tf.math.log(model_outputs_for_targets))
    return loss
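# Standalone illustration (TensorFlow 2.x assumed, values made up) of why the
# clipping in nll_loss matters: log(0) is -inf and would poison the mean and
# its gradients, so probabilities are clamped to a small positive range before
# the logarithm is taken.
import tensorflow as tf

probs = tf.constant([0.0, 1e-9, 0.5, 1.0])
tf.math.log(probs)                               # contains -inf
safe = tf.clip_by_value(probs, 1e-7, 1.0 - 1e-7)
nll = -tf.reduce_mean(tf.math.log(safe))         # finite negative log-likelihood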
def _sample_n(self, n, seed=None):
    if self._use_static_graph:
        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object, and maintaining
        # random seed management that is consistent with the non-static code
        # path.
        samples = []
        cat_samples = self.cat.sample(n, seed=seed)
        stream = SeedStream(seed, salt='Mixture')
        for c in range(self.num_components):
            samples.append(self.components[c].sample(n, seed=stream()))
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,          # [n, B]
            depth=self._num_components,   # == k
            on_value=npdt(1),
            off_value=npdt(0))            # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
        return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

    n = tf.convert_to_tensor(n, name='n')
    static_n = tf.get_static_value(n)
    n = int(static_n) if static_n is not None else n
    cat_samples = self.cat.sample(n, seed=seed)

    static_samples_shape = cat_samples.shape
    if tensorshape_util.is_fully_defined(static_samples_shape):
        samples_shape = tensorshape_util.as_list(static_samples_shape)
        samples_size = tensorshape_util.num_elements(static_samples_shape)
    else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
    static_batch_shape = self.batch_shape
    if tensorshape_util.is_fully_defined(static_batch_shape):
        batch_shape = tensorshape_util.as_list(static_batch_shape)
        batch_size = tensorshape_util.num_elements(static_batch_shape)
    else:
        batch_shape = tf.shape(cat_samples)[1:]
        batch_size = tf.reduce_prod(batch_shape)
    static_event_shape = self.event_shape
    if tensorshape_util.is_fully_defined(static_event_shape):
        event_shape = np.array(
            tensorshape_util.as_list(static_event_shape), dtype=np.int32)
    else:
        event_shape = None

    # Get indices into the raw cat sampling tensor. We will
    # need these to stitch sample values back out after sampling
    # within the component partitions.
    samples_raw_indices = tf.reshape(
        tf.range(0, samples_size), samples_shape)

    # Partition the raw indices so that we can use
    # dynamic_stitch later to reconstruct the samples from the
    # known partitions.
    partitioned_samples_indices = tf.dynamic_partition(
        data=samples_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    # Copy the batch indices n times, as we will need to know
    # these to pull out the appropriate rows within the
    # component partitions.
    batch_raw_indices = tf.reshape(
        tf.tile(tf.range(0, batch_size), [n]), samples_shape)

    # Explanation of the dynamic partitioning below:
    #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
    # Suppose partitions are:
    #   [1 1 0 0 1 1]
    # After partitioning, batch indices are cut as:
    #   [batch_indices[x] for x in 2, 3]
    #   [batch_indices[x] for x in 0, 1, 4, 5]
    # i.e.
    #   [1 1] and [0 0 0 0]
    # Now we sample n=2 from part 0 and n=4 from part 1.
    # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
    # and for part 1 we want samples from batch entries 0, 0, 0, 0
    # (samples 0, 1, 2, 3).
    partitioned_batch_indices = tf.dynamic_partition(
        data=batch_raw_indices,
        partitions=cat_samples,
        num_partitions=self.num_components)
    samples_class = [None for _ in range(self.num_components)]

    stream = SeedStream(seed, salt='Mixture')

    for c in range(self.num_components):
        n_class = tf.size(partitioned_samples_indices[c])
        samples_class_c = self.components[c].sample(n_class, seed=stream())

        if event_shape is None:
            batch_ndims = prefer_static.rank_from_shape(batch_shape)
            event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c,
            tf.concat([[n_class * batch_size], event_shape], 0))
        samples_class_c = tf.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name='samples_class_c_gather')
        samples_class[c] = samples_class_c

    # Stitch back together the samples across the components.
    lhs_flat_ret = tf.dynamic_stitch(
        indices=partitioned_samples_indices, data=samples_class)
    # Reshape back to proper sample, batch, and event shape.
    ret = tf.reshape(
        lhs_flat_ret, tf.concat([samples_shape, event_shape], 0))
    tensorshape_util.set_shape(
        ret,
        tensorshape_util.concatenate(static_samples_shape, self.event_shape))
    return ret
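# Standalone sketch (not part of _sample_n; values made up) of the
# partition/stitch round trip used above: per-draw indices are partitioned by
# the sampled component, per-component values are produced, and
# tf.dynamic_stitch reassembles them in the original draw order.
import tensorflow as tf

cat_samples = tf.constant([1, 0, 1, 1, 0])    # component chosen for each draw
raw_indices = tf.range(tf.size(cat_samples))
part_indices = tf.dynamic_partition(raw_indices, cat_samples, num_partitions=2)
# Stand-ins for components[c].sample(n_class): component c just emits the value c.
part_values = [tf.fill(tf.shape(idx), float(c))
               for c, idx in enumerate(part_indices)]
stitched = tf.dynamic_stitch(part_indices, part_values)
# stitched -> [1., 0., 1., 1., 0.], lining up with cat_samples position by position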