Example 1
    def testWrapperWeights(self, wrapper):
        """Checks that a wrapper cell reports the weights of the cell it wraps.

        A one-unit `SimpleRNNCell` is wrapped, placed in an `RNN` layer and the
        layer is called once. Afterwards the wrapper must expose exactly three
        trainable variables (kernel, recurrent kernel, bias), no non-trainable
        ones, and the same names as its inner `cell`.
        """
        inner = layers.SimpleRNNCell(1, name="basic_rnn_cell")
        wrapped = wrapper(inner)
        layer = layers.RNN(wrapped)
        # Call the layer once before inspecting the cell's variables.
        layer(tf.convert_to_tensor([[[1]]], dtype=tf.float32))

        scope = "rnn/" + generic_utils.to_snake_case(wrapper.__name__) + "/"
        expected = [
            scope + suffix
            for suffix in ("kernel:0", "recurrent_kernel:0", "bias:0")
        ]

        def _names(variables):
            return [variable.name for variable in variables]

        self.assertLen(wrapped.weights, 3)
        self.assertCountEqual(_names(wrapped.weights), expected)
        self.assertCountEqual(_names(wrapped.trainable_variables), expected)
        self.assertCountEqual(_names(wrapped.non_trainable_variables), [])
        self.assertCountEqual(_names(wrapped.cell.weights), expected)
Example 2
def get_custom_object_name(obj):
  """Returns the name to use for a custom loss or metric callable.

  The name is chosen in priority order:

  1. ``obj.name`` if the attribute exists (e.g. a ``Loss`` instance used as a
     metric);
  2. ``obj.__name__`` if the attribute exists (plain functions and classes);
  3. the snake_case form of ``type(obj).__name__`` (other callable instances).

  Args:
    obj: Custom loss or metric callable.

  Returns:
    Name to use, or `None` if the object was not recognized.
  """
  if hasattr(obj, 'name'):  # Accept `Loss` instance as `Metric`.
    return obj.name
  if hasattr(obj, '__name__'):  # Function (or class).
    return obj.__name__
  if hasattr(obj, '__class__'):  # Class instance.
    return generic_utils.to_snake_case(obj.__class__.__name__)
  # Unrecognized object. NOTE: every ordinary Python 3 object has `__class__`,
  # so this is only reached by objects that deliberately hide that attribute.
  return None
Example 3
 def test_snake_case(self):
     """Checks CamelCase -> snake_case conversion, including digit suffixes."""
     expectations = (
         ('SomeClass', 'some_class'),
         ('Conv2D', 'conv2d'),
         ('ConvLSTM2D', 'conv_lstm2d'),
     )
     for camel_case, snake_case in expectations:
         self.assertEqual(generic_utils.to_snake_case(camel_case), snake_case)
Example 4
 def test_snake_case(self):
     """Checks CamelCase -> snake_case conversion, including digit suffixes."""
     expectations = (
         ("SomeClass", "some_class"),
         ("Conv2D", "conv2d"),
         ("ConvLSTM2D", "conv_lstm2d"),
     )
     for camel_case, snake_case in expectations:
         self.assertEqual(generic_utils.to_snake_case(camel_case), snake_case)
Example 5
def populate_deserializable_objects():
    """Fills `LOCAL.ALL_OBJECTS` with every built-in initializer.

    The table is cached on `LOCAL` together with the value of
    `tf.__internal__.tf2.enabled()` it was built under. It is rebuilt only
    when it is empty or when that flag differs from the cached one.
    """
    global LOCAL
    if not hasattr(LOCAL, "ALL_OBJECTS"):
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    use_v2 = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == use_v2:
        # Already built for the current TF version: nothing to do.
        return

    # `objects` aliases the freshly installed dict, so every write below lands
    # directly in LOCAL.ALL_OBJECTS.
    objects = LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = use_v2

    # Compatibility aliases "<ClassName>V2" (must exist in both V1 and V2).
    for class_name in (
        "Constant",
        "GlorotNormal",
        "GlorotUniform",
        "HeNormal",
        "HeUniform",
        "Identity",
        "LecunNormal",
        "LecunUniform",
        "Ones",
        "Orthogonal",
        "RandomNormal",
        "RandomUniform",
        "TruncatedNormal",
        "VarianceScaling",
        "Zeros",
    ):
        objects[class_name + "V2"] = getattr(initializers_v2, class_name)

    # Extra aliases kept because they may appear in previously saved configs.
    for alias, class_name in (
        ("glorot_normalV2", "GlorotNormal"),
        ("glorot_uniformV2", "GlorotUniform"),
        ("he_normalV2", "HeNormal"),
        ("he_uniformV2", "HeUniform"),
        ("lecun_normalV2", "LecunNormal"),
        ("lecun_uniformV2", "LecunUniform"),
    ):
        objects[alias] = getattr(initializers_v2, class_name)

    if use_v2:
        # V2: collect every Initializer subclass defined in initializers_v2.
        base_cls = initializers_v2.Initializer
        builtins_by_name = {}
        generic_utils.populate_dict_with_module_objects(
            builtins_by_name,
            [initializers_v2],
            obj_filter=lambda x: inspect.isclass(x) and issubclass(x, base_cls),
        )
    else:
        # V1 initializers.
        builtins_by_name = {
            "Constant": tf.compat.v1.constant_initializer,
            "GlorotNormal": tf.compat.v1.glorot_normal_initializer,
            "GlorotUniform": tf.compat.v1.glorot_uniform_initializer,
            "Identity": tf.compat.v1.initializers.identity,
            "Ones": tf.compat.v1.ones_initializer,
            "Orthogonal": tf.compat.v1.orthogonal_initializer,
            "VarianceScaling": tf.compat.v1.variance_scaling_initializer,
            "Zeros": tf.compat.v1.zeros_initializer,
            "HeNormal": initializers_v1.HeNormal,
            "HeUniform": initializers_v1.HeUniform,
            "LecunNormal": initializers_v1.LecunNormal,
            "LecunUniform": initializers_v1.LecunUniform,
            "RandomNormal": initializers_v1.RandomNormal,
            "RandomUniform": initializers_v1.RandomUniform,
            "TruncatedNormal": initializers_v1.TruncatedNormal,
        }

    # Register each entry under its class name and its snake_case
    # ("functional") alias.
    for key, value in builtins_by_name.items():
        objects[key] = value
        objects[generic_utils.to_snake_case(key)] = value

    # Short-name compatibility aliases pointing at already-registered entries.
    for alias, target in (
        ("normal", "random_normal"),
        ("uniform", "random_uniform"),
        ("one", "ones"),
        ("zero", "zeros"),
    ):
        objects[alias] = objects[target]
Example 6
    def __init__(self, seed=None):
        """Configures variance scaling with fixed hyperparameters.

        Uses scale 2.0, `'fan_in'` mode and a `'uniform'` distribution; only
        `seed` is exposed to the caller. (These settings look like the He
        uniform scheme; confirm against the enclosing class name.)
        """
        fixed = dict(scale=2.0, mode='fan_in', distribution='uniform')
        super().__init__(seed=seed, **fixed)

    def get_config(self):
        """Returns the serializable config, which contains only `seed`.

        The other variance-scaling arguments are fixed by this class's
        constructor, so they are not serialized.
        """
        config = {'seed': self.seed}
        return config


# Populate all initializers with their string names: every Initializer
# subclass visible in this module is registered under its class name and its
# snake_case alias (this table is what `deserialize` passes as
# `module_objects`).
_ALL_INITIALIZERS = {}
for name, obj in inspect.getmembers(sys.modules[__name__]):
    if (not inspect.isclass(obj)
            or not issubclass(obj, initializers_v2.Initializer)):
        continue
    _ALL_INITIALIZERS[name] = obj
    alternative_name = generic_utils.to_snake_case(name)
    _ALL_INITIALIZERS[alternative_name] = obj


def serialize(initializer):
    """Returns the serialized form of `initializer`.

    Thin wrapper around `generic_utils.serialize_keras_object`.
    """
    serialized = generic_utils.serialize_keras_object(initializer)
    return serialized


def deserialize(config, custom_objects=None):
    """Return an `Initializer` object from its config.

    Delegates to `generic_utils.deserialize_keras_object`. It passes this
    module's `_ALL_INITIALIZERS` table as `module_objects`, forwards the
    caller's `custom_objects` unchanged, and labels the object kind as
    'initializer'.
    """
    lookup_options = {
        'module_objects': _ALL_INITIALIZERS,
        'custom_objects': custom_objects,
        'printable_module_name': 'initializer',
    }
    return generic_utils.deserialize_keras_object(config, **lookup_options)