コード例 #1
0
    def testCheckPartitioners(self):
        """Exercises snt.check_partitioners on invalid and valid inputs.

        A single partitioner dict is mutated through four states:
          1. it contains "key_c", which is not an allowed key -> KeyError
             matching "Invalid partitioner keys.*";
          2. "key_b" maps to a string rather than a callable -> TypeError
             matching "Partitioner for.*";
          3. "key_b" maps to a nested dict holding a string -> the same
             TypeError;
          4. "key_b" maps to a nested dict of partitioner callables ->
             accepted without raising.
        """
        candidate = {
            "key_a": tf.variable_axis_size_partitioner(10),
            "key_c": tf.variable_axis_size_partitioner(10)
        }
        allowed_keys = ["key_a", "key_b"]

        # State 1: an unexpected top-level key is rejected.
        with self.assertRaisesRegexp(KeyError, "Invalid partitioner keys.*"):
            snt.check_partitioners(partitioners=candidate, keys=allowed_keys)

        # State 2: remove the bad key, then give an allowed key a value that
        # is not callable.
        candidate.pop("key_c")
        candidate["key_b"] = "not a function"
        with self.assertRaisesRegexp(TypeError, "Partitioner for.*"):
            snt.check_partitioners(partitioners=candidate, keys=allowed_keys)

        # State 3: a non-callable nested one level down is rejected too.
        candidate["key_b"] = {"key_c": "not a function"}
        with self.assertRaisesRegexp(TypeError, "Partitioner for.*"):
            snt.check_partitioners(partitioners=candidate, keys=allowed_keys)

        # State 4: a nested dict of real partitioners passes validation.
        nested = {
            "key_c": tf.variable_axis_size_partitioner(10),
            "key_d": tf.variable_axis_size_partitioner(10)
        }
        candidate["key_b"] = nested
        snt.check_partitioners(partitioners=candidate, keys=allowed_keys)
コード例 #2
0
ファイル: util_test.py プロジェクト: ccchang0111/sonnet
  def testCheckPartitioners(self):
    """Checks snt.check_partitioners on bad keys, bad values and valid input.

    An unknown key must raise KeyError. A non-callable partitioner, either
    given directly or inside a nested dict, must raise TypeError. A nested
    dict of partitioner callables must be accepted.
    """
    parts = {"key_a": tf.variable_axis_size_partitioner(10),
             "key_c": tf.variable_axis_size_partitioner(10)}
    valid_keys = ["key_a", "key_b"]

    # "key_c" is not one of valid_keys.
    self.assertRaisesRegexp(KeyError, "Invalid partitioner keys.*",
                            snt.check_partitioners,
                            partitioners=parts, keys=valid_keys)

    parts.pop("key_c")
    # A bare non-callable and a non-callable nested in a dict are both
    # rejected with the same message prefix; check them in order.
    for bad_value in ("not a function", {"key_c": "not a function"}):
      parts["key_b"] = bad_value
      self.assertRaisesRegexp(TypeError, "Partitioner for.*",
                              snt.check_partitioners,
                              partitioners=parts, keys=valid_keys)

    # A nested dict whose values are all partitioners is valid.
    parts["key_b"] = {
        "key_c": tf.variable_axis_size_partitioner(10),
        "key_d": tf.variable_axis_size_partitioner(10),
    }
    snt.check_partitioners(partitioners=parts, keys=valid_keys)
コード例 #3
0
 def __init__(self,
              output_size,
              use_bias=True,
              initializers=None,
              partitioners=None,
              regularizers=None,
              custom_getter=None,
              name="embed_lin"):
     """Stores the layer configuration and validates per-variable options.

     Args:
       output_size: Stored as `self._output_size` (presumably the layer's
         output dimensionality -- confirm against the rest of the class).
       use_bias: Stored as `self._use_bias` and passed to
         `get_possible_initializer_keys`, so it determines which keys are
         accepted in `initializers`, `partitioners` and `regularizers`.
       initializers: Optional mapping validated by `snt.check_initializers`
         against `self.possible_keys`; the result is stored as
         `self._initializers`.
       partitioners: Optional mapping validated by `snt.check_partitioners`
         against `self.possible_keys`; stored as `self._partitioners`.
       regularizers: Optional mapping validated by `snt.check_regularizers`
         against `self.possible_keys`; stored as `self._regularizers`.
       custom_getter: Forwarded to the base-class constructor.
       name: Module name, forwarded to the base-class constructor.

     Raises:
       Propagates any exception raised by the `snt.check_*` helpers for keys
       outside `self.possible_keys` or for invalid entries.
     """
     # Pass the caller's custom_getter through. Hard-coding None here would
     # silently discard the argument.
     # NOTE(review): super() names AbstractGraphLayer explicitly. If this
     # __init__ belongs to a subclass of AbstractGraphLayer, the call skips
     # AbstractGraphLayer.__init__ itself -- confirm the enclosing class.
     super(AbstractGraphLayer, self).__init__(
         custom_getter=custom_getter, name=name)
     self._output_size = output_size
     self._use_bias = use_bias
     # Starts unset; presumably filled in once the input is seen -- confirm.
     self._input_shape = None
     # Keys that the initializer/partitioner/regularizer dicts may contain.
     self.possible_keys = self.get_possible_initializer_keys(
         use_bias=use_bias)
     self._initializers = snt.check_initializers(initializers,
                                                 self.possible_keys)
     self._partitioners = snt.check_partitioners(partitioners,
                                                 self.possible_keys)
     self._regularizers = snt.check_regularizers(regularizers,
                                                 self.possible_keys)