Ejemplo n.º 1
0
 def testRangeInitializer(self):
   """Range-checks a seeded TruncatedNormal(mean=0, stddev=1) initializer.

   Delegates to the shared `_range_test` helper with a (12, 99, 7) shape,
   a target mean of 0 and target bounds of [-2, 2].
   """
   truncated = init_ops_v2.TruncatedNormal(mean=0, stddev=1, seed=126)
   self._range_test(truncated,
                    shape=(12, 99, 7),
                    target_mean=0.,
                    target_max=2,
                    target_min=-2)
Ejemplo n.º 2
0
  def __init__(self,
               vocabulary_size: int,
               dim: int,
               initializer: Optional[Callable[[Any], None]] = None,
               optimizer: Optional[_Optimizer] = None,
               combiner: Text = "mean",
               name: Optional[Text] = None):
    """Embedding table configuration.

    Validates the arguments and stores them as attributes of the same name.

    Args:
      vocabulary_size: Size of the table's vocabulary (number of rows). Must be
        a positive `int`; `bool` values are rejected.
      dim: The embedding dimension (width) of the table. Must be a positive
        `int`; `bool` values are rejected.
      initializer: A callable initializer taking one parameter, the shape of the
        variable that will be initialized. Will be called once per task, to
        initialize that task's shard of the embedding table. If not specified
        (or `None`), defaults to `init_ops_v2.TruncatedNormal` with mean `0.0`
        and standard deviation `1/sqrt(dim)`.
      optimizer: An optional instance of an optimizer parameters class, instance
        of one of `tf.tpu.experimental.embedding.SGD`,
        `tf.tpu.experimental.embedding.Adagrad` or
        `tf.tpu.experimental.embedding.Adam`. If set, it will override the
        global optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with
        'mean' the default. 'sqrtn' often achieves good accuracy, in particular
        with bag-of-words columns. For more information, see
        `tf.nn.embedding_lookup_sparse`.
      name: An optional string used to name the table. Useful for debugging.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dim` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
    """
    # `bool` is a subclass of `int`, so `True` would otherwise be accepted as a
    # size of 1; reject it explicitly.
    if (not isinstance(vocabulary_size, int) or
        isinstance(vocabulary_size, bool) or vocabulary_size < 1):
      raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size))

    if not isinstance(dim, int) or isinstance(dim, bool) or dim < 1:
      raise ValueError("Invalid dim {}.".format(dim))

    if initializer is None:
      # Default stddev shrinks with the table width (1/sqrt(dim)).
      initializer = init_ops_v2.TruncatedNormal(mean=0.0,
                                                stddev=1/math.sqrt(dim))
    elif not callable(initializer):
      raise ValueError("initializer must be callable if specified.")

    if combiner not in ("mean", "sum", "sqrtn"):
      raise ValueError("Invalid combiner {}".format(combiner))

    self.vocabulary_size = vocabulary_size
    self.dim = dim
    self.initializer = initializer
    self.optimizer = optimizer
    self.combiner = combiner
    self.name = name
Ejemplo n.º 3
0
 def testInvalidDataType(self):
   """Calling TruncatedNormal with an int32 dtype is expected to raise ValueError."""
   initializer = init_ops_v2.TruncatedNormal(0.0, 1.0)
   shape = [1]
   with self.assertRaises(ValueError):
     initializer(shape, dtype=dtypes.int32)
Ejemplo n.º 4
0
 def testInitializePartition(self):
   """Runs the shared `_partition_test` on a TruncatedNormal seeded with 1."""
   self._partition_test(init_ops_v2.TruncatedNormal(0.0, 1.0, seed=1))
Ejemplo n.º 5
0
 def testDuplicatedInitializer(self):
   """Runs the shared `_duplicated_test` on an unseeded TruncatedNormal(0, 1)."""
   self._duplicated_test(init_ops_v2.TruncatedNormal(0.0, 1.0))
Ejemplo n.º 6
0
 def testInitializerDifferent(self):
   """Compares TruncatedNormal initializers seeded 1 and 2 via `_identical_test`.

   The trailing `False` is presumably the "expect identical" flag; confirm
   against the `_identical_test` helper.
   """
   first, second = (init_ops_v2.TruncatedNormal(0.0, 1.0, seed=s)
                    for s in (1, 2))
   self._identical_test(first, second, False)
Ejemplo n.º 7
0
 def testInitializerIdentical(self):
   """Compares two TruncatedNormal initializers sharing seed 1 (currently skipped)."""
   self.skipTest("Not seeming to work in Eager mode")
   # `skipTest` raises, so the lines below only run once the skip is removed.
   shared_seed = 1
   first = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=shared_seed)
   second = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=shared_seed)
   self._identical_test(first, second, True)