Beispiel #1
0
    def testCheckInitializers(self):
        """Exercises snt.check_initializers on invalid and valid configs.

        Checked in sequence against the allowed keys ["key_a", "key_b"]:
          * an initializer under a key outside the allowed set -> KeyError;
          * a non-callable initializer -> TypeError;
          * a nested dict holding a non-callable -> TypeError;
          * a nested dict of callable initializers -> accepted.
        """

        def _fresh_initializer():
            # A new initializer object per call, mirroring the original
            # literal calls one-for-one.
            return tf.truncated_normal_initializer(mean=0, stddev=1)

        inits = {"key_a": _fresh_initializer(), "key_c": _fresh_initializer()}
        allowed_keys = ["key_a", "key_b"]

        # "key_c" is not among the allowed keys.
        with self.assertRaisesRegexp(KeyError, "Invalid initializer keys.*"):
            snt.check_initializers(initializers=inits, keys=allowed_keys)

        # Drop the stray key, then give "key_b" a value that is not callable.
        del inits["key_c"]
        inits["key_b"] = "not a function"
        with self.assertRaisesRegexp(TypeError, "Initializer for.*"):
            snt.check_initializers(initializers=inits, keys=allowed_keys)

        # A nested dict is also rejected when one of its values is not callable.
        inits["key_b"] = {"key_c": "not a function"}
        with self.assertRaisesRegexp(TypeError, "Initializer for.*"):
            snt.check_initializers(initializers=inits, keys=allowed_keys)

        # A nested dict whose values are all initializers must pass cleanly.
        inits["key_b"] = {"key_c": _fresh_initializer(),
                          "key_d": _fresh_initializer()}
        snt.check_initializers(initializers=inits, keys=allowed_keys)
Beispiel #2
0
  def testCheckInitializers(self):
    """Checks that snt.check_initializers rejects bad configs and accepts good ones.

    Uses the allowed keys ["key_a", "key_b"]. An unknown key must raise
    KeyError. A non-callable value, directly or inside a nested dict, must
    raise TypeError. A nested dict of initializers must be accepted.
    """
    allowed_keys = ["key_a", "key_b"]
    config = {
        "key_a": tf.truncated_normal_initializer(mean=0, stddev=1),
        "key_c": tf.truncated_normal_initializer(mean=0, stddev=1),
    }

    # "key_c" falls outside the allowed key set.
    with self.assertRaisesRegexp(KeyError, "Invalid initializer keys.*"):
      snt.check_initializers(initializers=config, keys=allowed_keys)

    del config["key_c"]

    # Both a bare non-callable and a nested dict holding one are rejected.
    non_callables = ("not a function", {"key_c": "not a function"})
    for bad_value in non_callables:
      config["key_b"] = bad_value
      with self.assertRaisesRegexp(TypeError, "Initializer for.*"):
        snt.check_initializers(initializers=config, keys=allowed_keys)

    # A nested dict made entirely of initializers is a valid configuration.
    config["key_b"] = {
        "key_c": tf.truncated_normal_initializer(mean=0, stddev=1),
        "key_d": tf.truncated_normal_initializer(mean=0, stddev=1),
    }
    snt.check_initializers(initializers=config, keys=allowed_keys)
Beispiel #3
0
 def __init__(self,
              output_size,
              use_bias=True,
              initializers=None,
              partitioners=None,
              regularizers=None,
              custom_getter=None,
              name="embed_lin"):
     """Constructs the layer and validates its per-variable configuration.

     Args:
       output_size: Output dimensionality; stored as `self._output_size`.
       use_bias: Whether the layer uses a bias. Stored as `self._use_bias`
         and passed to `get_possible_initializer_keys` to decide which
         variable keys are allowed.
       initializers: Optional dict of initializers, validated against
         `self.possible_keys` by `snt.check_initializers`.
       partitioners: Optional dict of partitioners, validated against
         `self.possible_keys` by `snt.check_partitioners`.
       regularizers: Optional dict of regularizers, validated against
         `self.possible_keys` by `snt.check_regularizers`.
       custom_getter: Optional custom getter, forwarded to the base-class
         constructor.
       name: Module name, forwarded to the base-class constructor.

     Raises:
       KeyError: From `snt.check_initializers` if `initializers` contains a
         key not in `self.possible_keys`.
       TypeError: From `snt.check_initializers` if an initializer is not
         callable. The partitioner/regularizer checks presumably behave
         similarly; confirm against the Sonnet docs.
     """
     # Forward the caller's custom_getter. The original hard-coded None here,
     # which silently discarded any getter passed to this constructor.
     # NOTE(review): super() names AbstractGraphLayer explicitly. If this
     # __init__ belongs to a subclass of AbstractGraphLayer, this skips
     # AbstractGraphLayer.__init__; confirm which class is intended.
     super(AbstractGraphLayer, self).__init__(custom_getter=custom_getter,
                                              name=name)
     self._output_size = output_size
     self._use_bias = use_bias
     # Presumably filled in once the module is connected to an input;
     # TODO confirm where it is assigned.
     self._input_shape = None
     # The set of variable keys that may be configured depends on use_bias.
     self.possible_keys = self.get_possible_initializer_keys(
         use_bias=use_bias)
     self._initializers = snt.check_initializers(initializers,
                                                 self.possible_keys)
     self._partitioners = snt.check_partitioners(partitioners,
                                                 self.possible_keys)
     self._regularizers = snt.check_regularizers(regularizers,
                                                 self.possible_keys)