def loss_func(self, x, y, loss=None):
        """Backward compatibility loss function to be used by the model.

        Adds the loss of this model's output on the whole batch to a
        dissonance term, weighted by ``self.lambda_c``, that is computed only
        on the datapoints which ``self.h1`` labels correctly.

        Args:
            x: Batch of inputs; passed to both ``self.h1`` and ``self``.
            y: Ground truth labels. They are subtracted directly from
                ``argmax(self.h1(x), axis=1)``, so presumably integer class
                indices -- TODO confirm against the caller.
            loss: Callable invoked as ``loss(y_true, y_pred)``. When ``None``
                a ``SparseCategoricalCrossentropy`` instance is created.

        Returns:
            ``tf.reduce_sum`` of the combined loss.
        """
        if loss is None:
            # NOTE(review): the constructor arguments were lost from the
            # source this block was recovered from. `from_logits=True` is an
            # assumption -- confirm whether the model emits logits or
            # probabilities before relying on this default.
            loss = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True)
        h1_output = tf.argmax(self.h1(x), axis=1)
        h2_output = self(x)

        # Partition index per datapoint: 1 where h1 predicted the ground
        # truth label, 0 elsewhere. Built once and reused for every partition.
        h1_correct = tf.dtypes.cast((h1_output - y) == 0, tf.int32)

        # Ground truth labels of the datapoints correctly labeled by h1.
        _, y_support = tf.dynamic_partition(y, h1_correct, 2)

        # Outputs of h2 on the datapoints correctly labeled by h1; these are
        # used to calculate the new error dissonance.
        _, h2_support_output = tf.dynamic_partition(h2_output, h1_correct, 2)
        new_error_dissonance = self.dissonance(
            h2_support_output, y_support, loss)

        # Loss on the full batch plus the weighted dissonance term.
        new_error_loss = (
            loss(y, h2_output) + self.lambda_c * new_error_dissonance)

        return tf.reduce_sum(new_error_loss)
    def __call__(self, x, y):
        """Return the summed strict-imitation loss for one batch.

        The result is ``self.nll_loss`` of h2's output on the whole batch plus
        ``self.lambda_c`` times ``self.dissonance`` evaluated on the subset of
        datapoints that h1 classifies correctly, reduced with
        ``tf.reduce_sum``.

        Args:
            x: Batch of inputs fed to ``self.h1`` and ``self.h2``.
            y: Ground truth labels; ``argmax`` is taken along axis 1, so
                presumably one-hot encoded -- TODO confirm.
        """
        h1_predictions = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)
        true_classes = tf.argmax(y, axis=1)

        # Partition index per datapoint: 1 where h1 matched the ground truth.
        support_mask = tf.dtypes.cast(
            (h1_predictions - true_classes) == 0, tf.int32)

        # Keep only partition 1, i.e. the datapoints h1 got right.
        x_support = tf.dynamic_partition(x, support_mask, 2)[1]
        y_support = tf.dynamic_partition(y, support_mask, 2)[1]

        # h2 is re-evaluated on the support inputs only.
        dissonance = self.dissonance(self.h2(x_support), y_support)
        total_loss = self.nll_loss(y, h2_output) + self.lambda_c * dissonance

        return tf.reduce_sum(total_loss)
    def __call__(self, x, y):
        """Return the new-error loss for one batch.

        The result is ``self.nll_loss`` of h2's output on the whole batch plus
        ``self.lambda_c`` times ``self.dissonance`` evaluated on the subset of
        datapoints that h1 classifies correctly. No further reduction is
        applied here.

        Args:
            x: Batch of inputs fed to ``self.h1`` and ``self.h2``.
            y: Ground truth labels; subtracted directly from the argmax of
                h1's output, so presumably integer class indices -- TODO
                confirm.
        """
        h1_predictions = tf.argmax(self.h1(x), axis=1)
        h2_output = self.h2(x)

        # Partition index per datapoint: 1 where h1 matched the ground truth.
        support_mask = tf.dtypes.cast((h1_predictions - y) == 0, tf.int32)

        # Keep only partition 1, i.e. the datapoints h1 got right.
        x_support = tf.dynamic_partition(x, support_mask, 2)[1]
        y_support = tf.dynamic_partition(y, support_mask, 2)[1]

        # h2 is re-evaluated on the support inputs only.
        support_dissonance = self.dissonance(self.h2(x_support), y_support)

        return self.nll_loss(y, h2_output) + self.lambda_c * support_dissonance
    def nll_loss(self, target_labels, model_output):
        """Return the negative mean log of the outputs selected by the labels.

        Args:
            target_labels: Ground truth labels. They are cast to int32 and
                used as partition indices with two partitions, so every entry
                must be 0 or 1 -- presumably a one-hot encoding shaped like
                ``model_output``; TODO confirm.
            model_output: Model output probabilities.

        Returns:
            ``-1 * reduce_mean(log(selected outputs))`` after clipping.
        """
        # Pick the model output probabilities corresponding to the ground
        # truth labels (partition 1).
        _, model_outputs_for_targets = tf.dynamic_partition(
            model_output, tf.dtypes.cast(target_labels, tf.int32), 2)
        # We make sure to clip the probability values so that they do not
        # result in NaNs once we take the logarithm.
        # NOTE(review): the arguments of this call were lost from the source
        # this block was recovered from. The clip bounds are assumed to live
        # on the instance as `clip_value_min` / `clip_value_max` -- confirm
        # against the class constructor.
        model_outputs_for_targets = tf.clip_by_value(
            model_outputs_for_targets,
            clip_value_min=self.clip_value_min,
            clip_value_max=self.clip_value_max)
        loss = -1 * tf.reduce_mean(tf.math.log(model_outputs_for_targets))

        return loss
# Example #5
    def _sample_n(self, n, seed=None):
        """Draw `n` samples from the mixture.

        A component index is sampled for every draw from the categorical
        distribution, then values are sampled from the components and
        combined according to those indices. Two code paths exist:

        * `self._use_static_graph`: every component is sampled `n` times and
          the results are combined with a one-hot mask.
        * otherwise: each component is sampled only as many times as it was
          selected, and the pieces are stitched back together.

        NOTE(review): several call arguments, the `else:` branches of the
        static/dynamic shape selection and the receiver of the categorical
        `sample` call were missing from the source this block was recovered
        from. They were reconstructed from the surrounding names and
        comments; confirm against the upstream implementation.

        Args:
            n: Number of samples to draw.
            seed: Seed forwarded to the categorical sampler and used to build
                the `SeedStream` that seeds the component samplers.

        Returns:
            A tensor reshaped to the categorical sample shape followed by the
            event shape.
        """
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            cat_samples = self._cat.sample(n, seed=seed)
            stream = SeedStream(seed, salt='Mixture')

            samples = [
                self.components[c].sample(n, seed=stream())
                for c in range(self.num_components)
            ]
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self._cat.sample(n, seed=seed)

        # Prefer statically known shapes; fall back to dynamic shape tensors.
        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            # Resolved lazily from the first component sample in the loop below.
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(
            tf.range(0, samples_size), samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(
            tf.tile(tf.range(0, batch_size), [n]), samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
        # Suppose partitions are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]
        #     [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [1 1] and [0 0 0 0]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 0, 0, 0
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None] * self.num_components

        stream = SeedStream(seed, salt='Mixture')

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            samples_class_c = self.components[c].sample(n_class, seed=stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(
                samples_class_c, lookup_partitioned_batch_indices)
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(
            indices=partitioned_samples_indices, data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        return ret