def _init_set_name(self, name, zero_based=True): if not name: self._name = backend.unique_object_name( generic_utils.to_snake_case(self.__class__.__name__), zero_based=zero_based) else: self._name = name
def get_custom_object_name(obj):
    """Returns the name to use for a custom loss or metric callable.

    The lookup order is: the object's `name` attribute, then its `__name__`
    attribute, then the snake-cased name of its class.

    Args:
        obj: Custom loss or metric callable.

    Returns:
        Name to use, or `None` if the object was not recognized.
    """
    # A `name` attribute wins; this also accepts a `Loss` instance as `Metric`.
    if hasattr(obj, 'name'):
        return obj.name

    # Plain functions expose `__name__`.
    if hasattr(obj, '__name__'):
        return obj.__name__

    # Class instances fall back to their class name in snake case.
    if hasattr(obj, '__class__'):
        return generic_utils.to_snake_case(obj.__class__.__name__)

    # Unrecognized object.
    return None
def testWrapperWeights(self, wrapper):
    """Checks that a wrapped cell reports the wrapped cell's weights.

    Args:
        wrapper: Cell wrapper class under test; it is called with a
            `SimpleRNNCell` and its `__name__` is used to build the expected
            variable names.
    """
    wrapped_cell = wrapper(layers.SimpleRNNCell(1, name="basic_rnn_cell"))
    rnn_layer = layers.RNN(wrapped_cell)

    # Run the layer once on a single-element input before inspecting weights.
    rnn_layer(ops.convert_to_tensor([[[1]]], dtype=dtypes.float32))

    prefix = "rnn/" + generic_utils.to_snake_case(wrapper.__name__) + "/"
    expected_weights = [
        prefix + suffix
        for suffix in ("kernel:0", "recurrent_kernel:0", "bias:0")
    ]

    def names_of(variables):
        # Collect the `.name` of every variable in `variables`.
        return [v.name for v in variables]

    self.assertLen(wrapped_cell.weights, 3)
    self.assertCountEqual(names_of(wrapped_cell.weights), expected_weights)
    self.assertCountEqual(
        names_of(wrapped_cell.trainable_variables), expected_weights)
    self.assertCountEqual(
        names_of(wrapped_cell.non_trainable_variables), [])
    self.assertCountEqual(
        names_of(wrapped_cell.cell.weights), expected_weights)
def __init__(self,
             layer,
             pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
             block_size=(1, 1),
             block_pooling_type='AVG',
             **kwargs):
    """Create a pruning wrapper for a keras layer.

    #TODO(pulkitb): Consider if begin_step should be 0 by default.

    Args:
        layer: The keras layer to be pruned.
        pruning_schedule: A `PruningSchedule` object that controls pruning
            rate throughout training.
        block_size: (optional) The dimensions of the block sparse pattern in
            rank-2 weight tensors.
        block_pooling_type: (optional) The function to use to pool weights in
            the block. Must be 'AVG' or 'MAX'.
        **kwargs: Additional keyword arguments to be passed to the keras
            layer. Any `name` entry is overwritten with a name derived from
            the wrapper class and `layer.name`.

    Raises:
        ValueError: If `block_pooling_type` is not 'AVG' or 'MAX', if `layer`
            is not a `tf.keras.layers.Layer`, or if `layer` is neither a
            `PrunableLayer`, nor exposes `get_prunable_weights`, nor is
            supported by the `PruneRegistry`.
    """
    self.pruning_schedule = pruning_schedule
    self.block_size = block_size
    self.block_pooling_type = block_pooling_type

    # Placeholder for the Pruning instance that holds the logic to prune the
    # weights of this layer.
    self.pruning_obj = None
    # All (weight, mask, threshold) tuples for this layer.
    self.pruning_vars = []

    if block_pooling_type not in ['AVG', 'MAX']:
        raise ValueError(
            'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
            .format(block_pooling_type))

    if not isinstance(layer, tf.keras.layers.Layer):
        raise ValueError(
            'Please initialize `Prune` layer with a '
            '`Layer` instance. You passed: {input}'.format(input=layer))

    # TODO(pulkitb): This should be pushed up to the wrappers.py
    # Name the layer using the wrapper and underlying layer name.
    # Prune(Dense) becomes prune_dense_1
    wrapper_prefix = generic_utils.to_snake_case(self.__class__.__name__)
    kwargs['name'] = '{}_{}'.format(wrapper_prefix, layer.name)

    # Decide which layer object is handed to the wrapper base class.
    if isinstance(layer, prunable_layer.PrunableLayer) or hasattr(
            layer, 'get_prunable_weights'):
        # Custom layer in client code which supports pruning.
        layer_to_wrap = layer
    elif prune_registry.PruneRegistry.supports(layer):
        # Built-in keras layers which support pruning.
        layer_to_wrap = prune_registry.PruneRegistry.make_prunable(layer)
    else:
        raise ValueError(
            'Please initialize `Prune` with a supported layer. Layers should '
            'either be supported by the PruneRegistry (built-in keras layers) or '
            'should be a `PrunableLayer` instance, or should has a customer '
            'defined `get_prunable_weights` method. You passed: '
            '{input}'.format(input=layer.__class__))
    super(PruneLowMagnitude, self).__init__(layer_to_wrap, **kwargs)

    self._track_trackable(layer, name='layer')

    # TODO(yunluli): Work-around to handle the first layer of Sequential model
    # properly. Can remove this when it is implemented in the Wrapper base
    # class.
    #
    # Enables end-user to prune the first layer in Sequential models, while
    # passing the input shape to the original layer.
    #
    # tf.keras.Sequential(
    #   prune_low_magnitude(tf.keras.layers.Dense(2, input_shape=(3,)))
    # )
    #
    # as opposed to
    #
    # tf.keras.Sequential(
    #   prune_low_magnitude(tf.keras.layers.Dense(2), input_shape=(3,))
    # )
    #
    # Without this code, the pruning wrapper doesn't have an input
    # shape and being the first layer, this causes the model to not be
    # built. Being not built is confusing since the end-user has passed an
    # input shape.
    if not hasattr(self, '_batch_input_shape') and hasattr(
            layer, '_batch_input_shape'):
        self._batch_input_shape = self.layer._batch_input_shape

    # Record the wrapped layer's class name on the usage gauge.
    usage_gauge = metrics.MonitorBoolGauge('prune_low_magnitude_wrapper_usage')
    usage_gauge.set(layer.__class__.__name__)
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The dict lives on `LOCAL` together with `GENERATED_WITH_V2`, which records
    the value of `tf2.enabled()` at generation time. If the dict is already
    populated for the current `tf2.enabled()` value, nothing is done;
    otherwise it is rebuilt from scratch.
    """
    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = tf2.enabled()
    registry = LOCAL.ALL_OBJECTS

    registry.update({
        # Compatibility aliases (need to exist in both V1 and V2).
        'ConstantV2': initializers_v2.Constant,
        'GlorotNormalV2': initializers_v2.GlorotNormal,
        'GlorotUniformV2': initializers_v2.GlorotUniform,
        'HeNormalV2': initializers_v2.HeNormal,
        'HeUniformV2': initializers_v2.HeUniform,
        'IdentityV2': initializers_v2.Identity,
        'LecunNormalV2': initializers_v2.LecunNormal,
        'LecunUniformV2': initializers_v2.LecunUniform,
        'OnesV2': initializers_v2.Ones,
        'OrthogonalV2': initializers_v2.Orthogonal,
        'RandomNormalV2': initializers_v2.RandomNormal,
        'RandomUniformV2': initializers_v2.RandomUniform,
        'TruncatedNormalV2': initializers_v2.TruncatedNormal,
        'VarianceScalingV2': initializers_v2.VarianceScaling,
        'ZerosV2': initializers_v2.Zeros,
        # Out of an abundance of caution we also include these aliases that
        # have a non-zero probability of having been included in saved
        # configs in the past.
        'glorot_normalV2': initializers_v2.GlorotNormal,
        'glorot_uniformV2': initializers_v2.GlorotUniform,
        'he_normalV2': initializers_v2.HeNormal,
        'he_uniformV2': initializers_v2.HeUniform,
        'lecun_normalV2': initializers_v2.LecunNormal,
        'lecun_uniformV2': initializers_v2.LecunUniform,
    })

    if tf2.enabled():
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        base_cls = initializers_v2.Initializer

        def _is_initializer_class(candidate):
            # Keep only classes deriving from the V2 `Initializer` base.
            return inspect.isclass(candidate) and issubclass(
                candidate, base_cls)

        class_objs = {}
        generic_utils.populate_dict_with_module_objects(
            class_objs, [initializers_v2], obj_filter=_is_initializer_class)
    else:
        # V1 initializers.
        class_objs = {
            'Constant': init_ops.Constant,
            'GlorotNormal': init_ops.GlorotNormal,
            'GlorotUniform': init_ops.GlorotUniform,
            'Identity': init_ops.Identity,
            'Ones': init_ops.Ones,
            'Orthogonal': init_ops.Orthogonal,
            'VarianceScaling': init_ops.VarianceScaling,
            'Zeros': init_ops.Zeros,
            'HeNormal': initializers_v1.HeNormal,
            'HeUniform': initializers_v1.HeUniform,
            'LecunNormal': initializers_v1.LecunNormal,
            'LecunUniform': initializers_v1.LecunUniform,
            'RandomNormal': initializers_v1.RandomNormal,
            'RandomUniform': initializers_v1.RandomUniform,
            'TruncatedNormal': initializers_v1.TruncatedNormal,
        }

    for class_name, initializer_cls in class_objs.items():
        registry[class_name] = initializer_cls
        # Functional aliases.
        registry[generic_utils.to_snake_case(class_name)] = initializer_cls

    # More compatibility aliases, each pointing at an entry registered above.
    for alias, canonical in (
            ('normal', 'random_normal'),
            ('uniform', 'random_uniform'),
            ('one', 'ones'),
            ('zero', 'zeros'),
    ):
        registry[alias] = registry[canonical]
def __init__(self,
             layer,
             pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
             block_size=(1, 1),
             block_pooling_type='AVG',
             **kwargs):
    """Create a pruning wrapper for a keras layer.

    #TODO(pulkitb): Consider if begin_step should be 0 by default.

    Args:
        layer: The keras layer to be pruned.
        pruning_schedule: A `PruningSchedule` object that controls pruning
            rate throughout training.
        block_size: (optional) The dimensions of the block sparse pattern in
            rank-2 weight tensors.
        block_pooling_type: (optional) The function to use to pool weights in
            the block. Must be 'AVG' or 'MAX'.
        **kwargs: Additional keyword arguments to be passed to the keras
            layer. Any `name` entry is overwritten with a name derived from
            the wrapper class and `layer.name`.

    Raises:
        ValueError: If `block_pooling_type` is not 'AVG' or 'MAX', if `layer`
            is not a `Layer` instance, or if `layer` is neither a
            `PrunableLayer` nor supported by the `PruneRegistry`.
    """
    self.pruning_schedule = pruning_schedule
    self.block_size = block_size
    self.block_pooling_type = block_pooling_type

    # Placeholder for the Pruning instance that holds the logic to prune the
    # weights of this layer.
    self.pruning_obj = None
    # All (weight, mask, threshold) tuples for this layer.
    self.pruning_vars = []

    if block_pooling_type not in ['AVG', 'MAX']:
        raise ValueError(
            'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
            .format(block_pooling_type))

    if not isinstance(layer, Layer):
        raise ValueError(
            'Please initialize `Prune` layer with a '
            '`Layer` instance. You passed: {input}'.format(input=layer))

    # TODO(pulkitb): This should be pushed up to the wrappers.py
    # Name the layer using the wrapper and underlying layer name.
    # Prune(Dense) becomes prune_dense_1
    wrapper_prefix = generic_utils.to_snake_case(self.__class__.__name__)
    kwargs['name'] = '{}_{}'.format(wrapper_prefix, layer.name)

    # Decide which layer object is handed to the wrapper base class.
    if isinstance(layer, prunable_layer.PrunableLayer):
        # Custom layer in client code which supports pruning.
        layer_to_wrap = layer
    elif prune_registry.PruneRegistry.supports(layer):
        # Built-in keras layers which support pruning.
        layer_to_wrap = prune_registry.PruneRegistry.make_prunable(layer)
    else:
        raise ValueError(
            'Please initialize `Prune` with a supported layer. Layers should '
            'either be a `PrunableLayer` instance, or should be supported by the '
            'PruneRegistry. You passed: {input}'.format(
                input=layer.__class__))
    super(PruneLowMagnitude, self).__init__(layer_to_wrap, **kwargs)

    self._track_trackable(layer, name='layer')

    # TODO(yunluli): Work-around to handle the first layer of Sequential model
    # properly. Can remove this when it is implemented in the Wrapper base
    # class.
    # The _batch_input_shape attribute in the first layer makes a Sequential
    # model to be built. This change makes sure that when we apply the wrapper
    # to the whole model, this attribute is pulled into the wrapper to preserve
    # the 'built' state of the model.
    if not hasattr(self, '_batch_input_shape') and hasattr(
            layer, '_batch_input_shape'):
        self._batch_input_shape = self.layer._batch_input_shape