Example 1
def _init_set_name(self, name, zero_based=True):
    # Default to a unique, snake-cased version of the class name.
    if not name:
        self._name = backend.unique_object_name(
            generic_utils.to_snake_case(self.__class__.__name__),
            zero_based=zero_based)
    else:
        self._name = name
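
Every snippet on this page centers on `generic_utils.to_snake_case`, which converts a CamelCase class name into a snake_case identifier for naming. A minimal sketch of that conversion, assuming a simplified regex-based implementation (the real Keras helper additionally prefixes names that start with an underscore):

import re

def to_snake_case(name):
    # "SimpleRNNCell" -> "simple_rnn_cell", "Dense" -> "dense".
    intermediate = re.sub(r'(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
    return re.sub(r'([a-z])([A-Z])', r'\1_\2', intermediate).lower()

print(to_snake_case('SimpleRNNCell'))  # simple_rnn_cell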
Example 2
def get_custom_object_name(obj):
    """Returns the name to use for a custom loss or metric callable.

    Args:
      obj: Custom loss or metric callable.

    Returns:
      Name to use, or `None` if the object was not recognized.
    """
    if hasattr(obj, 'name'):  # Accept `Loss` instance as `Metric`.
        return obj.name
    elif hasattr(obj, '__name__'):  # Function.
        return obj.__name__
    elif hasattr(obj, '__class__'):  # Class instance.
        return generic_utils.to_snake_case(obj.__class__.__name__)
    else:  # Unrecognized object.
        return None
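
Under these checks, a plain function reports its `__name__`, while an otherwise anonymous class instance falls back to its snake-cased class name; the final `else` branch is effectively unreachable, since every Python object has a `__class__`. A small illustration using the function above and two hypothetical callables:

def my_loss(y_true, y_pred):  # hypothetical custom loss function
    return 0.0

class MeanPairwiseError:  # hypothetical custom metric class
    def __call__(self, y_true, y_pred):
        return 0.0

print(get_custom_object_name(my_loss))              # my_loss
print(get_custom_object_name(MeanPairwiseError()))  # mean_pairwise_error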
Example 3
  def testWrapperWeights(self, wrapper):
    """Tests that wrapper weights contain wrapped cells weights."""
    base_cell = layers.SimpleRNNCell(1, name="basic_rnn_cell")
    rnn_cell = wrapper(base_cell)
    rnn_layer = layers.RNN(rnn_cell)
    inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
    rnn_layer(inputs)

    wrapper_name = generic_utils.to_snake_case(wrapper.__name__)
    expected_weights = ["rnn/" + wrapper_name + "/" + var for var in
                        ("kernel:0", "recurrent_kernel:0", "bias:0")]
    self.assertLen(rnn_cell.weights, 3)
    self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights)
    self.assertCountEqual([v.name for v in rnn_cell.trainable_variables],
                          expected_weights)
    self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables],
                          [])
    self.assertCountEqual([v.name for v in rnn_cell.cell.weights],
                          expected_weights)
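
The expected variable names embed the snake-cased wrapper class name, so the same test body works for any wrapper class that is passed in. For a hypothetical `ResidualWrapper`, the derivation (using the simplified `to_snake_case` sketched under Example 1) would be:

class ResidualWrapper:  # hypothetical wrapper class, for illustration only
    pass

wrapper_name = to_snake_case(ResidualWrapper.__name__)
print('rnn/' + wrapper_name + '/kernel:0')  # rnn/residual_wrapper/kernel:0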
Example 4
    def __init__(self,
                 layer,
                 pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
                 block_size=(1, 1),
                 block_pooling_type='AVG',
                 **kwargs):
        """Create a pruning wrapper for a keras layer.

    #TODO(pulkitb): Consider if begin_step should be 0 by default.

    Args:
      layer: The keras layer to be pruned.
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      block_size: (optional) The dimensions (height, weight) for the block
        sparse pattern in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
    """
        self.pruning_schedule = pruning_schedule
        self.block_size = block_size
        self.block_pooling_type = block_pooling_type

        # An instance of the Pruning class. This class contains the logic to prune
        # the weights of this layer.
        self.pruning_obj = None

        # A list of all (weight, mask, threshold) tuples for this layer.
        self.pruning_vars = []

        if block_pooling_type not in ['AVG', 'MAX']:
            raise ValueError(
                'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
                .format(block_pooling_type))

        if not isinstance(layer, tf.keras.layers.Layer):
            raise ValueError(
                'Please initialize `Prune` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))

        # TODO(pulkitb): This should be pushed up to the wrappers.py
        # Name the layer using the wrapper and underlying layer name.
        # Prune(Dense) becomes prune_dense_1
        kwargs.update({
            'name':
            '{}_{}'.format(
                generic_utils.to_snake_case(self.__class__.__name__),
                layer.name)
        })

        if isinstance(layer, prunable_layer.PrunableLayer) or hasattr(
                layer, 'get_prunable_weights'):
            # Custom layer in client code which supports pruning.
            super(PruneLowMagnitude, self).__init__(layer, **kwargs)
        elif prune_registry.PruneRegistry.supports(layer):
            # Built-in keras layers which support pruning.
            super(PruneLowMagnitude, self).__init__(
                prune_registry.PruneRegistry.make_prunable(layer), **kwargs)
        else:
            raise ValueError(
                'Please initialize `Prune` with a supported layer. Layers should '
                'either be supported by the PruneRegistry (built-in keras layers), '
                'be a `PrunableLayer` instance, or have a custom '
                '`get_prunable_weights` method. You passed: '
                '{input}'.format(input=layer.__class__))

        self._track_trackable(layer, name='layer')

        # TODO(yunluli): Work-around to handle the first layer of Sequential model
        # properly. Can remove this when it is implemented in the Wrapper base
        # class.
        #
        # Enables end-user to prune the first layer in Sequential models, while
        # passing the input shape to the original layer.
        #
        # tf.keras.Sequential(
        #   prune_low_magnitude(tf.keras.layers.Dense(2, input_shape=(3,)))
        # )
        #
        # as opposed to
        #
        # tf.keras.Sequential(
        #   prune_low_magnitude(tf.keras.layers.Dense(2), input_shape=(3,))
        # )
        #
        # Without this code, the pruning wrapper doesn't have an input
        # shape and being the first layer, this causes the model to not be
        # built. Being not built is confusing since the end-user has passed an
        # input shape.
        if not hasattr(self, '_batch_input_shape') and hasattr(
                layer, '_batch_input_shape'):
            self._batch_input_shape = self.layer._batch_input_shape
        metrics.MonitorBoolGauge('prune_low_magnitude_wrapper_usage').set(
            layer.__class__.__name__)
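
The naming block above joins the snake-cased wrapper class name with the inner layer's name, so `PruneLowMagnitude` around a `Dense` layer named `dense_1` becomes `prune_low_magnitude_dense_1`. A sketch of just that step, isolated from the wrapper (the class and layer name are stand-ins, and `to_snake_case` is the simplified version from Example 1):

class PruneLowMagnitude:  # stand-in for the real tfmot wrapper class
    pass

layer_name = 'dense_1'  # hypothetical name Keras assigned to the inner layer
name = '{}_{}'.format(to_snake_case(PruneLowMagnitude.__name__), layer_name)
print(name)  # prune_low_magnitude_dense_1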
Example 5
def populate_deserializable_objects():
  """Populates dict ALL_OBJECTS with every built-in initializer.
  """
  global LOCAL
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
    # Objects dict is already generated for the proper TF version:
    # do nothing.
    return

  LOCAL.ALL_OBJECTS = {}
  LOCAL.GENERATED_WITH_V2 = tf2.enabled()

  # Compatibility aliases (need to exist in both V1 and V2).
  LOCAL.ALL_OBJECTS['ConstantV2'] = initializers_v2.Constant
  LOCAL.ALL_OBJECTS['GlorotNormalV2'] = initializers_v2.GlorotNormal
  LOCAL.ALL_OBJECTS['GlorotUniformV2'] = initializers_v2.GlorotUniform
  LOCAL.ALL_OBJECTS['HeNormalV2'] = initializers_v2.HeNormal
  LOCAL.ALL_OBJECTS['HeUniformV2'] = initializers_v2.HeUniform
  LOCAL.ALL_OBJECTS['IdentityV2'] = initializers_v2.Identity
  LOCAL.ALL_OBJECTS['LecunNormalV2'] = initializers_v2.LecunNormal
  LOCAL.ALL_OBJECTS['LecunUniformV2'] = initializers_v2.LecunUniform
  LOCAL.ALL_OBJECTS['OnesV2'] = initializers_v2.Ones
  LOCAL.ALL_OBJECTS['OrthogonalV2'] = initializers_v2.Orthogonal
  LOCAL.ALL_OBJECTS['RandomNormalV2'] = initializers_v2.RandomNormal
  LOCAL.ALL_OBJECTS['RandomUniformV2'] = initializers_v2.RandomUniform
  LOCAL.ALL_OBJECTS['TruncatedNormalV2'] = initializers_v2.TruncatedNormal
  LOCAL.ALL_OBJECTS['VarianceScalingV2'] = initializers_v2.VarianceScaling
  LOCAL.ALL_OBJECTS['ZerosV2'] = initializers_v2.Zeros

  # Out of an abundance of caution we also include these aliases that have
  # a non-zero probability of having been included in saved configs in the past.
  LOCAL.ALL_OBJECTS['glorot_normalV2'] = initializers_v2.GlorotNormal
  LOCAL.ALL_OBJECTS['glorot_uniformV2'] = initializers_v2.GlorotUniform
  LOCAL.ALL_OBJECTS['he_normalV2'] = initializers_v2.HeNormal
  LOCAL.ALL_OBJECTS['he_uniformV2'] = initializers_v2.HeUniform
  LOCAL.ALL_OBJECTS['lecun_normalV2'] = initializers_v2.LecunNormal
  LOCAL.ALL_OBJECTS['lecun_uniformV2'] = initializers_v2.LecunUniform

  if tf2.enabled():
    # For V2, entries are generated automatically based on the content of
    # initializers_v2.py.
    v2_objs = {}
    base_cls = initializers_v2.Initializer
    generic_utils.populate_dict_with_module_objects(
        v2_objs,
        [initializers_v2],
        obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls))
    for key, value in v2_objs.items():
      LOCAL.ALL_OBJECTS[key] = value
      # Functional aliases.
      LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value
  else:
    # V1 initializers.
    v1_objs = {
        'Constant': init_ops.Constant,
        'GlorotNormal': init_ops.GlorotNormal,
        'GlorotUniform': init_ops.GlorotUniform,
        'Identity': init_ops.Identity,
        'Ones': init_ops.Ones,
        'Orthogonal': init_ops.Orthogonal,
        'VarianceScaling': init_ops.VarianceScaling,
        'Zeros': init_ops.Zeros,
        'HeNormal': initializers_v1.HeNormal,
        'HeUniform': initializers_v1.HeUniform,
        'LecunNormal': initializers_v1.LecunNormal,
        'LecunUniform': initializers_v1.LecunUniform,
        'RandomNormal': initializers_v1.RandomNormal,
        'RandomUniform': initializers_v1.RandomUniform,
        'TruncatedNormal': initializers_v1.TruncatedNormal,
    }
    for key, value in v1_objs.items():
      LOCAL.ALL_OBJECTS[key] = value
      # Functional aliases.
      LOCAL.ALL_OBJECTS[generic_utils.to_snake_case(key)] = value

  # More compatibility aliases.
  LOCAL.ALL_OBJECTS['normal'] = LOCAL.ALL_OBJECTS['random_normal']
  LOCAL.ALL_OBJECTS['uniform'] = LOCAL.ALL_OBJECTS['random_uniform']
  LOCAL.ALL_OBJECTS['one'] = LOCAL.ALL_OBJECTS['ones']
  LOCAL.ALL_OBJECTS['zero'] = LOCAL.ALL_OBJECTS['zeros']
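
The registration loops store each initializer under both its class name and a snake-cased functional alias, which is why saved configs can reference either `'GlorotUniform'` or `'glorot_uniform'`. A minimal sketch of that dual registration, using the simplified `to_snake_case` from Example 1 and strings as stand-ins for the classes:

ALL_OBJECTS = {}
for cls_name in ('GlorotUniform', 'RandomNormal', 'VarianceScaling'):
    obj = cls_name  # stand-in for the initializer class itself
    ALL_OBJECTS[cls_name] = obj
    ALL_OBJECTS[to_snake_case(cls_name)] = obj  # functional alias

print(sorted(ALL_OBJECTS))
# ['GlorotUniform', 'RandomNormal', 'VarianceScaling',
#  'glorot_uniform', 'random_normal', 'variance_scaling']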
Example 6
    def __init__(self,
                 layer,
                 pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
                 block_size=(1, 1),
                 block_pooling_type='AVG',
                 **kwargs):
        """Create a pruning wrapper for a keras layer.

    #TODO(pulkitb): Consider if begin_step should be 0 by default.

    Args:
      layer: The keras layer to be pruned.
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      block_size: (optional) The dimensions (height, weight) for the block
        sparse pattern in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
      **kwargs: Additional keyword arguments to be passed to the keras layer.
    """
        self.pruning_schedule = pruning_schedule
        self.block_size = block_size
        self.block_pooling_type = block_pooling_type

        # An instance of the Pruning class. This class contains the logic to prune
        # the weights of this layer.
        self.pruning_obj = None

        # A list of all (weight, mask, threshold) tuples for this layer.
        self.pruning_vars = []

        if block_pooling_type not in ['AVG', 'MAX']:
            raise ValueError(
                'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
                .format(block_pooling_type))

        if not isinstance(layer, Layer):
            raise ValueError(
                'Please initialize `Prune` layer with a '
                '`Layer` instance. You passed: {input}'.format(input=layer))

        # TODO(pulkitb): This should be pushed up to the wrappers.py
        # Name the layer using the wrapper and underlying layer name.
        # Prune(Dense) becomes prune_dense_1
        kwargs.update({
            'name':
            '{}_{}'.format(
                generic_utils.to_snake_case(self.__class__.__name__),
                layer.name)
        })

        if isinstance(layer, prunable_layer.PrunableLayer):
            # Custom layer in client code which supports pruning.
            super(PruneLowMagnitude, self).__init__(layer, **kwargs)
        elif prune_registry.PruneRegistry.supports(layer):
            # Built-in keras layers which support pruning.
            super(PruneLowMagnitude, self).__init__(
                prune_registry.PruneRegistry.make_prunable(layer), **kwargs)
        else:
            raise ValueError(
                'Please initialize `Prune` with a supported layer. Layers should '
                'either be a `PrunableLayer` instance, or should be supported by the '
                'PruneRegistry. You passed: {input}'.format(
                    input=layer.__class__))

        self._track_trackable(layer, name='layer')

        # TODO(yunluli): Work-around to handle the first layer of Sequential model
        # properly. Can remove this when it is implemented in the Wrapper base
        # class.
        # The _batch_input_shape attribute on the first layer is what allows a
        # Sequential model to be built. This change pulls that attribute up into
        # the wrapper when we wrap the whole model, preserving the model's
        # 'built' state.
        if not hasattr(self, '_batch_input_shape') and hasattr(
                layer, '_batch_input_shape'):
            self._batch_input_shape = self.layer._batch_input_shape
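
The `_batch_input_shape` work-around is what lets the first usage pattern described in Example 4's comments build the model immediately. A hypothetical end-to-end usage, assuming `tensorflow` and `tensorflow_model_optimization` are installed:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([
    # input_shape goes to the inner Dense layer; the wrapper pulls
    # _batch_input_shape up so the Sequential model can build itself.
    tfmot.sparsity.keras.prune_low_magnitude(
        tf.keras.layers.Dense(2, input_shape=(3,))),
])
print(model.built)  # expected True, since the input shape is known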