def testWrapperWeights(self, wrapper):
    """Checks that a wrapper cell exposes exactly the weights of its inner cell.

    A SimpleRNNCell is wrapped, the wrapper is driven by an RNN layer over a
    single (1, 1, 1) input so the variables get built, and then the variable
    names reported by the wrapper (all / trainable / non-trainable) and by the
    inner cell are compared against the expected names.
    """
    wrapped = wrapper(layers.SimpleRNNCell(1, name="basic_rnn_cell"))
    layer = layers.RNN(wrapped)
    layer(tf.convert_to_tensor([[[1]]], dtype=tf.float32))

    prefix = "rnn/{}/".format(generic_utils.to_snake_case(wrapper.__name__))
    expected = [
        prefix + suffix
        for suffix in ("kernel:0", "recurrent_kernel:0", "bias:0")
    ]

    def names(variables):
        # Variable names only; order is irrelevant for assertCountEqual.
        return [v.name for v in variables]

    self.assertLen(wrapped.weights, 3)
    self.assertCountEqual(names(wrapped.weights), expected)
    self.assertCountEqual(names(wrapped.trainable_variables), expected)
    self.assertCountEqual(names(wrapped.non_trainable_variables), [])
    self.assertCountEqual(names(wrapped.cell.weights), expected)
def get_custom_object_name(obj):
    """Returns the name to use for a custom loss or metric callable.

    The name is resolved in this order:

    1. `obj.name`, if present (this lets a `Loss` instance be used as a
       `Metric`).
    2. `obj.__name__`, if present (e.g. plain functions, lambdas, builtins,
       classes).
    3. The snake_case form of the name of `obj`'s class (other instances).

    Args:
      obj: Custom loss or metric callable.

    Returns:
      Name to use, or `None` if the object was not recognized. In practice
      `None` is almost never returned: every Python object exposes
      `__class__`, so the final fallback is only reached by objects whose
      attribute lookup for `__class__` raises `AttributeError`.
    """
    if hasattr(obj, 'name'):
        # Accept `Loss` instance as `Metric`.
        return obj.name
    if hasattr(obj, '__name__'):
        # Function (or class / builtin).
        return obj.__name__
    if hasattr(obj, '__class__'):
        # Class instance.
        return generic_utils.to_snake_case(obj.__class__.__name__)
    # Unrecognized object.
    return None
def test_snake_case(self):
    """CamelCase class names convert to the expected snake_case names."""
    cases = (
        ('SomeClass', 'some_class'),
        ('Conv2D', 'conv2d'),
        ('ConvLSTM2D', 'conv_lstm2d'),
    )
    for camel, snake in cases:
        self.assertEqual(generic_utils.to_snake_case(camel), snake)
def test_snake_case(self):
    """Verifies `to_snake_case` on plain, digit-suffixed and acronym names."""
    expected_by_input = {
        "SomeClass": "some_class",
        "Conv2D": "conv2d",
        "ConvLSTM2D": "conv_lstm2d",
    }
    # Dict preserves insertion order, so assertions run in the same order.
    for source_name, wanted in expected_by_input.items():
        converted = generic_utils.to_snake_case(source_name)
        self.assertEqual(converted, wanted)
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The mapping from serialized names to initializers is cached on
    ``LOCAL.ALL_OBJECTS``, alongside the value of
    ``tf.__internal__.tf2.enabled()`` it was generated under
    (``LOCAL.GENERATED_WITH_V2``). It is rebuilt only when it is empty or was
    generated under a different TF2 setting.

    The new mapping is assembled in a local dict and published to ``LOCAL``
    only once it is complete. This way an exception raised part-way through
    cannot leave a partially populated dict behind that a later call would
    accept as a valid cache.
    """
    # Read the flag once so the cache check, the V1/V2 branch and the
    # recorded generation flag all agree.
    v2_enabled = tf.__internal__.tf2.enabled()

    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    all_objects = {}

    # Compatibility aliases (need to exist in both V1 and V2).
    all_objects["ConstantV2"] = initializers_v2.Constant
    all_objects["GlorotNormalV2"] = initializers_v2.GlorotNormal
    all_objects["GlorotUniformV2"] = initializers_v2.GlorotUniform
    all_objects["HeNormalV2"] = initializers_v2.HeNormal
    all_objects["HeUniformV2"] = initializers_v2.HeUniform
    all_objects["IdentityV2"] = initializers_v2.Identity
    all_objects["LecunNormalV2"] = initializers_v2.LecunNormal
    all_objects["LecunUniformV2"] = initializers_v2.LecunUniform
    all_objects["OnesV2"] = initializers_v2.Ones
    all_objects["OrthogonalV2"] = initializers_v2.Orthogonal
    all_objects["RandomNormalV2"] = initializers_v2.RandomNormal
    all_objects["RandomUniformV2"] = initializers_v2.RandomUniform
    all_objects["TruncatedNormalV2"] = initializers_v2.TruncatedNormal
    all_objects["VarianceScalingV2"] = initializers_v2.VarianceScaling
    all_objects["ZerosV2"] = initializers_v2.Zeros

    # Out of an abundance of caution we also include these aliases that have
    # a non-zero probability of having been included in saved configs in the
    # past.
    all_objects["glorot_normalV2"] = initializers_v2.GlorotNormal
    all_objects["glorot_uniformV2"] = initializers_v2.GlorotUniform
    all_objects["he_normalV2"] = initializers_v2.HeNormal
    all_objects["he_uniformV2"] = initializers_v2.HeUniform
    all_objects["lecun_normalV2"] = initializers_v2.LecunNormal
    all_objects["lecun_uniformV2"] = initializers_v2.LecunUniform

    if v2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        builtin_objs = {}
        base_cls = initializers_v2.Initializer
        generic_utils.populate_dict_with_module_objects(
            builtin_objs,
            [initializers_v2],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(
                x, base_cls),
        )
    else:
        # V1 initializers.
        builtin_objs = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    for key, value in builtin_objs.items():
        all_objects[key] = value
        # Functional aliases.
        all_objects[generic_utils.to_snake_case(key)] = value

    # More compatibility aliases.
    all_objects["normal"] = all_objects["random_normal"]
    all_objects["uniform"] = all_objects["random_uniform"]
    all_objects["one"] = all_objects["ones"]
    all_objects["zero"] = all_objects["zeros"]

    # Publish only after the mapping is fully built.
    LOCAL.ALL_OBJECTS = all_objects
    LOCAL.GENERATED_WITH_V2 = v2_enabled
# NOTE(review): `__init__` and `get_config` below use `self` and zero-argument
# `super()`, so they appear to be methods of a class whose header is outside
# this view — confirm their placement in the real file.
def __init__(self, seed=None):
    """Creates the initializer; only `seed` is configurable.

    The parent constructor is always called with `scale=2.`,
    `mode='fan_in'` and `distribution='uniform'`.
    """
    super().__init__(scale=2., mode='fan_in', distribution='uniform',
                     seed=seed)


def get_config(self):
    """Returns the serializable config, which holds only `seed`.

    The other parent arguments are fixed by `__init__`, so they are not
    recorded here.
    """
    return {'seed': self.seed}


# Populate all initializers with their string names.
# Every member of this module that is a class subclassing
# `initializers_v2.Initializer` is registered twice: under its own name and
# under the snake_case form returned by `generic_utils.to_snake_case`.
# Presumably only classes already defined when this loop runs are picked up.
# NOTE(review): the loop leaves `name`, `obj` and `alternative_name` bound as
# module-level globals.
_ALL_INITIALIZERS = {}
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if inspect.isclass(obj) and issubclass(obj, initializers_v2.Initializer):
        _ALL_INITIALIZERS[name] = obj
        alternative_name = generic_utils.to_snake_case(name)
        _ALL_INITIALIZERS[alternative_name] = obj


def serialize(initializer):
    """Serializes `initializer` via `generic_utils.serialize_keras_object`."""
    return generic_utils.serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
    """Return an `Initializer` object from its config.

    Args:
      config: Serialized initializer, passed straight to
        `generic_utils.deserialize_keras_object` (presumably a name or a
        config dict — TODO confirm).
      custom_objects: Optional extra name-to-object mapping, forwarded to
        `generic_utils.deserialize_keras_object`.

    Returns:
      Whatever `generic_utils.deserialize_keras_object` returns, with
      built-in names looked up in `_ALL_INITIALIZERS`.
    """
    return generic_utils.deserialize_keras_object(
        config,
        module_objects=_ALL_INITIALIZERS,
        custom_objects=custom_objects,
        # Presumably used to label this kind of object in error messages.
        printable_module_name='initializer')