def _build_clusters(self, name, layer):
        """Create the clustering variables for one weight of a wrapped layer.

        The weight called `name` is read from `layer.layer`, its centroids
        and lookup indices are obtained from `get_unique`, and (unless more
        than 1024 centroids are found) centroid, original-weight and
        pulling-index variables are registered on `layer`.

        Side effect: if `self.preserve_sparsity` is set but the weight holds
        no zero entry, the flag is switched off on `self` and a warning is
        logged.

        Args:
          name: Name of weights in layer.
          layer: Quantization wrapped keras layer.

        Returns:
          A dictionary. It contains SPARSITY_MASK when sparsity is preserved.
          When at most 1024 centroids were found it additionally contains
          CLUSTER_CENTROIDS, PULLING_INDICES, ORIGINAL_WEIGHTS, WEIGHT_NAME,
          CLUSTERING_IMPL and CENTROIDS_MASK (None when sparsity is not
          preserved); otherwise no variables are created and nothing else is
          added.
        """
        weights = getattr(layer.layer, name)

        # Sparsity can only be preserved when the weights contain zeros.
        has_no_zero = not tf.reduce_any(weights == 0)
        if self.preserve_sparsity and has_no_zero:
            self.preserve_sparsity = False
            logging.warning(
                'Input layer does not contain zero weights, so apply CQAT instead.'
            )

        centroids, lookup = get_unique(weights)
        num_centroids = tf.size(centroids)

        result = {}
        centroids_mask = None
        if self.preserve_sparsity:
            # w / w is 1 for non-zero entries; divide_no_nan gives 0 for zeros.
            sparsity_mask = tf.math.divide_no_nan(weights, weights)
            # Mask out the centroid closest to zero.
            zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
            centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
            result[SPARSITY_MASK] = sparsity_mask

        # Clustering variables are only prepared for the Keras graph when
        # clusters exist, assuming number_of_clusters never exceeds 1024.
        if num_centroids > 1024:
            return result

        centroids_init = tf.keras.initializers.Constant(
            value=K.batch_get_value([centroids])[0])
        clst_centroids_tf = layer.add_weight(
            CLUSTER_CENTROIDS,
            shape=centroids.shape,
            initializer=centroids_init,
            dtype=centroids.dtype,
            trainable=True)

        weights_init = tf.keras.initializers.Constant(
            value=K.batch_get_value([weights])[0])
        ori_weights_tf = layer.add_weight(
            ORIGINAL_WEIGHTS,
            shape=weights.shape,
            initializer=weights_init,
            dtype=weights.dtype,
            trainable=True)

        # The clustering implementation depends on the layer type.
        registry = clustering_registry.ClusteringLookupRegistry()
        clustering_impl = registry.get_clustering_impl(
            layer.layer, name)(clst_centroids_tf)

        pulling_indices = tf.dtypes.cast(
            clustering_impl.get_pulling_indices(ori_weights_tf),
            lookup.dtype)
        indices_init = tf.keras.initializers.Constant(
            value=K.batch_get_value([pulling_indices])[0])
        pulling_indices_tf = layer.add_weight(
            PULLING_INDICES,
            shape=lookup.shape,
            initializer=indices_init,
            dtype=lookup.dtype,
            trainable=False)

        result.update({
            CLUSTER_CENTROIDS: clst_centroids_tf,
            PULLING_INDICES: pulling_indices_tf,
            ORIGINAL_WEIGHTS: ori_weights_tf,
            WEIGHT_NAME: name,
            CLUSTERING_IMPL: clustering_impl,
            CENTROIDS_MASK: centroids_mask,
        })
        return result
    def build(self, input_shape):
        """Create the clustering variables for every clusterable weight.

        For each `(weight_name, weight)` pair returned by
        `self.layer.get_clusterable_weights()` this records the original
        variable and its position in `self.layer.weights`, adds a trainable
        centroid variable and a non-trainable pulling-index variable, and
        instantiates the clustering implementation. When
        `self.preserve_sparsity` is set, a sparsity mask and the index of the
        centroid closest to zero are stored as well.

        Args:
          input_shape: Shape forwarded to the parent `build`; also kept in
            `self.build_input_shape`.
        """
        super(ClusterWeights, self).build(input_shape)
        self.build_input_shape = input_shape

        for weight_name, weight in self.layer.get_clusterable_weights():
            # Keep the original weight in this wrapper; the child reference
            # will be overridden in update_clustered_weights_associations.
            # The weight_name used by the wrapper is not necessarily the
            # attribute name of the wrapped layer: for cells in
            # StackedRNNCell the names become 'kernel/0',
            # 'recurrent_kernel/0', 'kernel/1', 'recurrent_kernel/1'.
            original_weight = self.get_weight_from_layer(weight_name)
            self.original_clusterable_weights[weight_name] = original_weight
            # Setting it as an attribute tracks the variable.
            setattr(self, 'original_weight_' + weight_name, original_weight)

            # Remember where the original weight sits in layer.weights so it
            # can be restored during stripping.
            layer_position = next(
                index
                for index, candidate in enumerate(self.layer.weights)
                if candidate is original_weight)
            self.position_original_weights[layer_position] = weight_name

            # Initial cluster centroids.
            initializer_cls = (
                clustering_centroids.CentroidsInitializerFactory.
                get_centroid_initializer(self.cluster_centroids_init))
            initial_centroids = initializer_cls(
                weight, self.number_of_clusters,
                self.preserve_sparsity).get_cluster_centroids()
            self.cluster_centroids[weight_name] = self.add_weight(
                'cluster_centroids_' + weight_name,
                shape=(self.number_of_clusters, ),
                dtype=weight.dtype,
                trainable=True,
                initializer=tf.keras.initializers.Constant(
                    value=initial_centroids))

            # The registry is queried without the '/<index>' suffix for
            # stacked RNN cells and for Bidirectional wrappers.
            if isinstance(self.layer, tf.keras.layers.RNN):
                strip_index = isinstance(self.layer.cell,
                                         tf.keras.layers.StackedRNNCells)
            else:
                strip_index = isinstance(self.layer,
                                         tf.keras.layers.Bidirectional)
            if strip_index:
                lookup_name = weight_name.split('/')[0]
            else:
                lookup_name = weight_name

            # Weight clustering algorithm.
            algorithm_cls = clustering_registry.ClusteringLookupRegistry(
            ).get_clustering_impl(self.layer, lookup_name)
            self.clustering_algorithms[weight_name] = algorithm_cls(
                clusters_centroids=self.cluster_centroids[weight_name],
                cluster_gradient_aggregation=(
                    self.cluster_gradient_aggregation),
            )
            algorithm = self.clustering_algorithms[weight_name]

            # Pulling indices (weights associations).
            pulling_indices = algorithm.get_pulling_indices(weight)
            self.pulling_indices[weight_name] = self.add_weight(
                'pulling_indices_' + weight_name,
                shape=pulling_indices.shape,
                dtype=tf.int64,
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                initializer=tf.keras.initializers.Constant(
                    value=pulling_indices))

            if not self.preserve_sparsity:
                continue

            # Sparsity mask: 1.0 where the clustered weight is non-zero.
            clustered_weights = algorithm.get_clustered_weight(
                pulling_indices, original_weight)
            is_nonzero = tf.math.not_equal(clustered_weights, 0)
            self.sparsity_masks[weight_name] = tf.cast(is_nonzero,
                                                       dtype=tf.float32)
            # If the model is pruned (which we suppose), the centroid picked
            # here is approximately zero.
            self.zero_idx[weight_name] = tf.argmin(
                tf.abs(self.cluster_centroids[weight_name]), axis=-1)
# Example #3
# 0
    def build(self, input_shape):
        """Create clustering variables and the weight-restore list.

        For every `(weight_name, weight)` pair returned by
        `self.layer.get_clusterable_weights()` this adds a centroid variable,
        a pulling-index variable and a copy of the original weight, stores
        the clustering implementation and, when `self.preserve_sparsity` is
        set, a sparsity mask. Afterwards `self.restore` is extended with one
        `(name, full_name, value_or_updater)` tuple per entry of
        `self.layer.weights`, in the same order.

        Args:
          input_shape: Shape forwarded to the parent `build`.
        """
        super(ClusterWeights, self).build(input_shape)

        clusterable_weights = self.layer.get_clusterable_weights()

        # Maps the automatically assigned TF variable name (e.g.
        # 'dense/kernel:0', after `self._weight_name`) to the human readable
        # name (e.g. as in Dense(10).kernel).
        attribute_name_by_variable = {}

        for weight_name, weight in clusterable_weights:
            # Remember the position of this variable in the underlying layer
            # so that the original order of the weights can be rebuilt later;
            # an incorrect order results in incorrect OPs weights
            # configurations. The last variable with a matching name wins and
            # position 0 is used when none matches.
            original_index = 0
            for position, layer_weight in enumerate(self.layer.weights):
                if layer_weight.name == weight.name:
                    original_index = position
            self.gone_variables.append(original_index)

            # Again, not sure if this is needed. Leaving for now.
            variable_key = self._weight_name(weight.name)
            attribute_name_by_variable[variable_key] = weight_name

            # The factory returns a class, which is instantiated right away
            # to build the initial cluster centroids for this tensor.
            initializer_cls = (
                clustering_centroids.CentroidsInitializerFactory.
                get_centroid_initializer(self.cluster_centroids_init))
            centroid_initializer = initializer_cls(
                weight, self.number_of_clusters, self.preserve_sparsity)
            cluster_centroids = centroid_initializer.get_cluster_centroids()

            # k.batch_get_value turns the tensor into a concrete initial
            # value. Each weight gets its own set of cluster centroids.
            centroids_value = k.batch_get_value([cluster_centroids])[0]
            self.cluster_centroids_tf[weight_name] = self.add_weight(
                'cluster_centroids_tf_{}'.format(weight_name),
                shape=(self.number_of_clusters, ),
                dtype=weight.dtype,
                trainable=True,
                initializer=initializers.Constant(value=centroids_value))

            # Look-ups are vectorised; the registry picks the implementation
            # for this layer and weight.
            registry = clustering_registry.ClusteringLookupRegistry()
            impl_cls = registry.get_clustering_impl(self.layer, weight_name)
            self.clustering_impl[weight_name] = impl_cls(
                self.cluster_centroids_tf[weight_name])

            # Nearest-centroid indices are computed once here and stored;
            # they are used for look-ups into self.cluster_centroids_tf.
            pulling_indices = self.clustering_impl[
                weight_name].get_pulling_indices(weight)
            indices_value = k.batch_get_value([pulling_indices])[0]
            self.pulling_indices_tf[weight_name] = self.add_weight(
                'pulling_indices_tf_{}'.format(weight_name),
                shape=pulling_indices.shape,
                dtype=tf.int32,
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                initializer=initializers.Constant(value=indices_value))

            if self.preserve_sparsity:
                # Mask is 1.0 where the clustered weight is non-zero; it is
                # kept for use during training.
                clustered = self.clustering_impl[
                    weight_name].get_clustered_weight(pulling_indices)
                self.sparsity_masks[weight_name] = tf.cast(
                    tf.math.not_equal(clustered, 0), dtype=tf.float32)

            # Keep a copy of the original weight so it can be updated later.
            weight_value = k.batch_get_value([weight])[0]
            self.ori_weights_vars_tf[weight_name] = self.add_weight(
                'ori_weights_vars_tf_{}'.format(weight_name),
                shape=weight.shape,
                dtype=weight.dtype,
                trainable=True,
                initializer=initializers.Constant(value=weight_value))

        # Currying: the returned callable can be triggered at any time in the
        # future and yields the latest version of the clustered weights.
        def get_updater(for_weight_name):
            def fn():
                indices = self.pulling_indices_tf[for_weight_name]
                clustered = self.clustering_impl[
                    for_weight_name].get_clustered_weight(indices)
                if not self.preserve_sparsity:
                    return clustered
                # Apply the stored sparsity mask to the clustered weights.
                return tf.math.multiply(
                    clustered, self.sparsity_masks[for_weight_name])

            return fn

        # Store, in layer order, each weight name together with the way to
        # restore it: an updater for clustered weights, the variable itself
        # otherwise.
        for position, layer_weight in enumerate(self.layer.weights):
            name = self._weight_name(layer_weight.name)
            full_name = '{}/{}'.format(self.layer.name, name)
            # Again, not sure if the name mapping is needed.
            source = (get_updater(attribute_name_by_variable[name])
                      if position in self.gone_variables else layer_weight)
            self.restore.append((name, full_name, source))
  def _build_clusters(self, name, layer):
    """Extract the cluster centroids and cluster indices from the pretrained clustered model.

    Reads the weight called `name` from `layer.layer`, obtains its centroids
    and lookup indices from `get_unique`, and registers three variables on
    `layer`: the centroids, a copy of the original weights and the pulling
    indices.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.

    Returns:
      A dictionary with the centroid variable ('cluster_centroids_tf'), the
      pulling-index variable ('pulling_indices_tf'), the copy of the original
      weights ('ori_weights_vars_tf'), the weight name ('weight_name'), the
      clustering implementation ('clst_impl') and the last variable in
      `layer.weights` whose name contains 'kernel' ('set_kernel_weight').

    Raises:
      ValueError: If no variable in `layer.weights` has 'kernel' in its name.
    """
    weights = getattr(layer.layer, name)
    centroids, lookup = get_unique(weights)

    # Prepare trainable variables for the Keras graph
    clst_centroids_tf = layer.add_weight(
        'cluster_centroids_tf',
        shape=centroids.shape,
        initializer=tf.keras.initializers.Constant(
            value=K.batch_get_value([centroids])[0]),
        dtype=centroids.dtype,
        trainable=True)

    ori_weights_tf = layer.add_weight(
        'ori_weights_vars_tf',
        shape=weights.shape,
        initializer=tf.keras.initializers.Constant(
            value=K.batch_get_value([weights])[0]),
        dtype=weights.dtype,
        trainable=True)

    # Get clustering implementation according to layer type
    clustering_impl_cls = clustering_registry.ClusteringLookupRegistry().\
        get_clustering_impl(layer.layer, name)
    clustering_impl = clustering_impl_cls(clst_centroids_tf)

    pulling_indices = tf.dtypes.cast(
        clustering_impl.get_pulling_indices(ori_weights_tf),
        lookup.dtype
    )

    pulling_indices_tf = layer.add_weight(
        'pulling_indices_tf',
        shape=lookup.shape,
        initializer=tf.keras.initializers.Constant(
            value=K.batch_get_value([pulling_indices])[0]),
        dtype=lookup.dtype,
        trainable=False)

    # Pick the last variable whose name mentions 'kernel'. Start from None so
    # that a layer without such a variable is reported explicitly instead of
    # failing with UnboundLocalError when the result is assembled.
    kernel = None
    for v in layer.weights:
      if 'kernel' in v.name:
        kernel = v
    if kernel is None:
      raise ValueError(
          "No weight with 'kernel' in its name was found in layer "
          '{}.'.format(layer.name))

    result = {
        'cluster_centroids_tf': clst_centroids_tf,
        'pulling_indices_tf': pulling_indices_tf,
        'ori_weights_vars_tf': ori_weights_tf,
        'weight_name': name,
        'clst_impl': clustering_impl,
        'set_kernel_weight': kernel,
    }

    return result
    def build(self, input_shape):
        """Create clustering variables and the weight-restore list.

        For every `(weight_name, weight)` pair returned by
        `self.layer.get_clusterable_weights()` this records the position of
        the variable in `self.layer.weights`, adds a centroid variable and a
        pulling-index variable, stores the clustering implementation and
        appends the pair to `self.clustered_vars`. Afterwards `self.restore`
        is extended with one `(name, value_or_updater)` tuple per entry of
        `self.layer.weights`, in the same order.

        Args:
          input_shape: Shape forwarded to the parent `build`.

        Raises:
          ValueError: If a clusterable weight is not one of the variables in
            `self.layer.weights`.
        """
        super(ClusterWeights, self).build(input_shape)

        clusterable_weights = self.layer.get_clusterable_weights()

        # Map automatically assigned TF variable name (e.g. 'dense/kernel:0') to provided human readable name
        # (e.g. as in Dense(10).kernel)
        clusterable_weights_to_variables = {}

        for weight_name, weight in clusterable_weights:
            # Remember the position of this variable in the underlying layer so that the original order of the
            # weights can be restored later. Incorrect order results in the incorrect OPs weights configurations.

            # The variable is located by identity rather than with list.index(): list.index() falls back to
            # `==`, which for TF2 variables is an element-wise comparison and can therefore raise or match a
            # different variable that happens to hold equal values.
            original_index = next(
                (index
                 for index, layer_weight in enumerate(self.layer.weights)
                 if layer_weight is weight), None)
            if original_index is None:
                # Same exception type list.index() raises for a missing item.
                raise ValueError(
                    'Clusterable weight {} was not found in the weights of '
                    'layer {}.'.format(weight.name, self.layer.name))
            self.gone_variables.append(original_index)

            # Again, not sure if this is needed. Leaving for now.
            clusterable_weights_to_variables[self._weight_name(
                weight.name)] = weight_name

            # Build initial cluster centroids for a given tensor. Factory returns a class and we init an object immediately
            centroid_initializer = clustering_centroids.CentroidsInitializerFactory.get_centroid_initializer(
                self.cluster_centroids_init)(weight, self.number_of_clusters)

            cluster_centroids = centroid_initializer.get_cluster_centroids()

            # Use k.batch_get_value since we need to initialize the variables with an initial value taken from a Tensor object
            # For each weight there is a different set of cluster centroids
            # NOTE(review): every clusterable weight requests the same variable name here; confirm this is
            # intended for layers with more than one clusterable weight.
            self.cluster_centroids_tf[weight_name] = self.add_weight(
                'cluster_centroids_tf',
                shape=(self.number_of_clusters, ),
                dtype=weight.dtype,
                trainable=True,
                initializer=initializers.Constant(
                    value=k.batch_get_value([cluster_centroids])[0]))

            # There are vectorised implementations of look-ups, we use a new one for different number of dimensions.
            clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
            ).get_clustering_impl(self.layer, weight_name)
            self.clustering_impl[weight_name] = clustering_impl_cls(
                self.cluster_centroids_tf[weight_name])

            # We find the nearest cluster centroids and store them so that ops can build their weights upon it
            # These indices are calculated once and stored forever. We use to make look-ups from self.cluster_centroids_tf
            pulling_indices = self.clustering_impl[
                weight_name].get_pulling_indices(weight)
            self.pulling_indices_tf[weight_name] = self.add_weight(
                'pulling_indices_tf',
                shape=pulling_indices.shape,
                dtype=tf.int32,
                trainable=False,
                initializer=initializers.Constant(
                    value=k.batch_get_value([pulling_indices])[0]))

            # We store these pairs to easily update this variables later on
            self.clustered_vars.append((weight_name, weight))

        # We use currying here to get an updater which can be triggered at any time in future and it would return
        # the latest version of clustered weights
        def get_updater(for_weight_name):
            def fn():
                return self.clustering_impl[
                    for_weight_name].get_clustered_weight(
                        self.pulling_indices_tf[for_weight_name])

            return fn

        # This will allow us to restore the order of weights later
        # This loop stores pairs of weight names and how to restore them

        for ct, weight in enumerate(self.layer.weights):
            name = self._weight_name(weight.name)
            if ct in self.gone_variables:
                # Again, not sure if this is needed
                weight_name = clusterable_weights_to_variables[name]
                self.restore.append((name, get_updater(weight_name)))
            else:
                self.restore.append((name, weight))
    def build(self, input_shape):
        """Create the clustering variables for every clusterable weight.

        For each `(weight_name, weight)` pair returned by
        `self.layer.get_clusterable_weights()` this records the original
        variable and its position in `self.layer.weights`, adds a trainable
        centroid variable and a non-trainable pulling-index variable, and
        instantiates the clustering implementation. When
        `self.preserve_sparsity` is set, a sparsity mask is stored as well.

        Args:
          input_shape: Shape forwarded to the parent `build`; also kept in
            `self.build_input_shape`.
        """
        super(ClusterWeights, self).build(input_shape)
        self.build_input_shape = input_shape

        # For every clusterable weights, create the clustering logic
        for weight_name, weight in self.layer.get_clusterable_weights():
            # Store the original weight in this wrapper
            # The child reference will be overridden in
            # update_clustered_weights_associations
            original_weight = getattr(self.layer, weight_name)
            self.original_clusterable_weights[weight_name] = original_weight
            setattr(self, 'original_weight_' + weight_name,
                    original_weight)  # Track the variable
            # Store the position in layer.weights of original_weight to restore during
            # stripping
            # The lookup is by identity; next() raises StopIteration if the
            # variable is not present in layer.weights.
            position_original_weight = next(
                i for i, w in enumerate(self.layer.weights)
                if w is original_weight)
            self.position_original_weights[
                position_original_weight] = weight_name

            # Init the cluster centroids
            # The factory returns an initializer class, which is instantiated
            # and queried for the initial centroid values right away.
            cluster_centroids = (
                clustering_centroids.CentroidsInitializerFactory.
                get_centroid_initializer(self.cluster_centroids_init)(
                    weight, self.number_of_clusters,
                    self.preserve_sparsity).get_cluster_centroids())
            self.cluster_centroids[weight_name] = self.add_weight(
                '{}{}'.format('cluster_centroids_', weight_name),
                shape=(self.number_of_clusters, ),
                dtype=weight.dtype,
                trainable=True,
                initializer=tf.keras.initializers.Constant(
                    value=cluster_centroids))

            # Init the weight clustering algorithm
            # The registry picks the implementation class for this layer and
            # weight name; it is constructed around the centroid variable.
            self.clustering_algorithms[weight_name] = (
                clustering_registry.ClusteringLookupRegistry(
                ).get_clustering_impl(self.layer, weight_name)(
                    clusters_centroids=self.cluster_centroids[weight_name],
                    cluster_gradient_aggregation=self.
                    cluster_gradient_aggregation,
                ))

            # Init the pulling_indices (weights associations)
            pulling_indices = (self.clustering_algorithms[weight_name].
                               get_pulling_indices(weight))
            self.pulling_indices[weight_name] = self.add_weight(
                '{}{}'.format('pulling_indices_', weight_name),
                shape=pulling_indices.shape,
                dtype=tf.int64,
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                initializer=tf.keras.initializers.Constant(
                    value=pulling_indices))

            if self.preserve_sparsity:
                # Init the sparsity mask
                # The mask is 1.0 where the clustered weight is non-zero and
                # 0.0 elsewhere.
                clustered_weights = (self.clustering_algorithms[weight_name].
                                     get_clustered_weight(
                                         pulling_indices, original_weight))
                self.sparsity_masks[weight_name] = (tf.cast(tf.math.not_equal(
                    clustered_weights, 0),
                                                            dtype=tf.float32))