Example #1
def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
    """Creates ParitionedVariables based on `num_hosts` for `table`."""
    # TODO(shizhiw): automatically place embedding lookup elsewhere?
    if vocabulary_size < num_hosts:
        raise ValueError(
            '`vocabulary_size`({}) is smaller than `num_hosts`({}). '
            'As TPU embedding is not optimized for small tables, '
            'please consider other ways for this embedding '
            'lookup.'.format(vocabulary_size, num_hosts))

    slicing = [num_hosts, 1]

    # TODO(shizhiw): deprecated, use tf.get_variable()?
    return partitioned_variables.create_partitioned_variables(
        name=name,
        slicing=slicing,
        shape=(vocabulary_size, embedding_dimension),
        dtype=dtypes.float32,
        initializer=initializer,
        collections=collections,
        trainable=False)
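For orientation, a minimal sketch of the same row-wise sharding through the public TF 1.x API; the variable name, table size, host count, and initializer below are hypothetical. tf.compat.v1.get_variable with a fixed_size_partitioner along axis 0 corresponds to the slicing = [num_hosts, 1] above.

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

num_hosts = 4  # hypothetical host count
embedding_table = tf.get_variable(
    "embedding_table",  # hypothetical name
    shape=[1024, 128],  # hypothetical (vocabulary_size, embedding_dimension)
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.02),
    # Split only the vocabulary axis into num_hosts shards,
    # mirroring slicing = [num_hosts, 1] in the helper above.
    partitioner=tf.fixed_size_partitioner(num_hosts, axis=0),
    trainable=False)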
Example #2
 def testRandomInitializer(self):
   # Sanity check that the slices use different seeds when using a random
   # initializer function.
   with self.cached_session():
     var0, var1 = partitioned_variables.create_partitioned_variables(
         [20, 12], [1, 2], init_ops.random_uniform_initializer())
     variables.global_variables_initializer().run()
     val0, val1 = var0.eval().flatten(), var1.eval().flatten()
     self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
   # Negative test that proves that slices have the same values if
   # the random initializer uses a seed.
   with self.cached_session():
     var0, var1 = partitioned_variables.create_partitioned_variables(
         [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
     variables.global_variables_initializer().run()
     val0, val1 = var0.eval().flatten(), var1.eval().flatten()
     self.assertAllClose(val0, val1)
Example #3
 def testRandomInitializer(self):
   # Sanity check that the slices use different seeds when using a random
   # initializer function.
   with self.test_session():
     var0, var1 = partitioned_variables.create_partitioned_variables(
         [20, 12], [1, 2], init_ops.random_uniform_initializer())
     variables.global_variables_initializer().run()
     val0, val1 = var0.eval().flatten(), var1.eval().flatten()
     self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
   # Negative test that proves that slices have the same values if
   # the random initializer uses a seed.
   with self.test_session():
     var0, var1 = partitioned_variables.create_partitioned_variables(
         [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
     variables.global_variables_initializer().run()
     val0, val1 = var0.eval().flatten(), var1.eval().flatten()
     self.assertAllClose(val0, val1)
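Both variants above hinge on op-level seeding: each partition gets its own initializer op, so without an explicit seed the slices draw from different random streams, while a fixed op seed makes them identical. A minimal sketch of the same contrast with public TF 1.x ops (shapes and the seed mirror the tests; the rest is illustrative):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.random_uniform([20, 6])            # no seed: distinct streams
    b = tf.random_uniform([20, 6])
    c = tf.random_uniform([20, 6], seed=201)  # same op seed: same stream
    d = tf.random_uniform([20, 6], seed=201)
    av, bv, cv, dv = sess.run([a, b, c, d])

assert np.linalg.norm(av - bv) > 1e-6  # unseeded ops differ
assert np.allclose(cv, dv)             # seeded ops agree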
Example #4
 def testVecConstantInit(self):
   with self.cached_session():
     rnd_par = constant_op.constant([1, 2, 3, 4])
     vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 0).eval()
     rnd = rnd_par.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.int32] * 4, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
Example #5
 def testVecConstantInit(self):
   with self.test_session():
     rnd_par = constant_op.constant([1, 2, 3, 4])
     vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 0).eval()
     rnd = rnd_par.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.int32] * 4, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
Example #6
 def testDegenerate(self):
   with self.test_session():
     rnd = variables.Variable(random_ops.random_uniform([10, 43]))
     vs = partitioned_variables.create_partitioned_variables(
         rnd.get_shape(), [1, 1], rnd.initialized_value())
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 0).eval()
     rnd = rnd.eval()
     self.assertAllClose(rnd, val)
     self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
Example #7
 def testDegenerate(self):
     with self.test_session():
         rnd = variables.Variable(random_ops.random_uniform([10, 43]))
         vs = partitioned_variables.create_partitioned_variables(
             rnd.get_shape(), [1, 1], rnd.initialized_value())
         variables.global_variables_initializer().run()
         val = array_ops.concat(vs, 0).eval()
         rnd = rnd.eval()
         self.assertAllClose(rnd, val)
         self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
Example #8
 def testConstantInit(self):
   with self.cached_session():
     rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
     vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                             rnd_par)
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 1).eval()
     rnd = rnd_par.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.int32] * 2, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
Example #9
 def testConstantInit(self):
   with self.test_session():
     rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
     vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                             rnd_par)
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 1).eval()
     rnd = rnd_par.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.int32] * 2, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
Example #10
  def _random_weights(self, size=50, num_shards=1):
    assert size > 0
    assert num_shards > 0
    assert num_shards <= size

    embedding_weights = partitioned_variables.create_partitioned_variables(
        shape=[size],
        slicing=[num_shards],
        initializer=init_ops.truncated_normal_initializer(
            mean=0.0, stddev=1.0, dtype=dtypes.float32))
    for w in embedding_weights:
      w.initializer.run()
    return embedding_weights
Example #11
    def _random_weights(self, size=50, num_shards=1):
        assert size > 0
        assert num_shards > 0
        assert num_shards <= size

        embedding_weights = partitioned_variables.create_partitioned_variables(
            shape=[size],
            slicing=[num_shards],
            initializer=init_ops.truncated_normal_initializer(
                mean=0.0, stddev=1.0, dtype=dtypes.float32))
        for w in embedding_weights:
            w.initializer.run()
        return embedding_weights
Example #12
def _create_embedding_lookup(input_tensor, vocab_size, dimension,
                             weight_collections, initializer, combiner,
                             trainable, name):
  """Creates embedding variable and does a lookup.

  Args:
    input_tensor: A tensor which should contain sparse ids to look up.
    vocab_size: An integer specifying the vocabulary size.
    dimension: An integer specifying the embedding vector dimension.
    weight_collections: List of graph collections to which weights are added.
    initializer: A variable initializer function to be used in embedding
      variable initialization.
    combiner: A string specifying how to reduce if the sparse column is
      multivalent. Currently "mean", "sqrtn" and "sum" are supported:
        * "sum": do not normalize features in the column
        * "mean": do l1 normalization on features in the column
        * "sqrtn": do l2 normalization on features in the column
      For more information: `tf.embedding_lookup_sparse`.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    name: A string specifying the name of the embedding variable.

  Returns:
    A Tensor with shape [batch_size, dimension] and embedding Variable.

  Raises:
    ValueError: If initializer is None or not callable.
  """
  slicing = _max_size_embedding_partitioner()(vocab_size, dimension)
  logging.info("Slicing=%s for name=%s, vocab_size=%d, embed_dim=%d",
               str(slicing), name, vocab_size, dimension)
  if not initializer:
    raise ValueError("initializer must be defined.")
  if not callable(initializer):
    raise ValueError("initializer must be callable.")
  embeddings = partitioned_variables.create_partitioned_variables(
      shape=[vocab_size, dimension],
      slicing=slicing,
      initializer=initializer,
      dtype=dtypes.float32,
      collections=weight_collections,
      name=name,
      reuse=False,
      trainable=trainable)

  return contrib_embedding_ops.safe_embedding_lookup_sparse(
      embeddings,
      input_tensor,
      default_id=0,
      combiner=combiner,
      name=name), embeddings
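_max_size_embedding_partitioner is defined elsewhere in the module. A rough sketch of the kind of slicing it implies (the byte budget and the exact rule here are assumptions, not the real implementation): shard the vocabulary axis so each shard stays under a size limit.

def max_size_slicing(vocab_size, dimension,
                     max_shard_bytes=64 << 20, bytes_per_value=4):
    # Hypothetical rule: cap each shard at max_shard_bytes of float32 rows.
    row_bytes = dimension * bytes_per_value
    rows_per_shard = max(1, max_shard_bytes // row_bytes)
    num_shards = -(-vocab_size // rows_per_shard)  # ceiling division
    return [num_shards, 1]  # shard rows only, as the code above expects

max_size_slicing(10000000, 64)  # -> [39, 1] under the defaults above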
Example #14
 def testSliceSizeOne(self):
   with self.cached_session():
     rnd = variables.Variable(random_ops.random_uniform([10, 43]))
     vs = partitioned_variables.create_partitioned_variables(
         rnd.get_shape(), [10, 1], rnd.initialized_value())
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 0).eval()
     rnd = rnd.eval()
     self.assertAllClose(rnd, val)
     self._TestSaveSpec(vs, [
         "10 43 0,1:0,43", "10 43 1,1:0,43", "10 43 2,1:0,43",
         "10 43 3,1:0,43", "10 43 4,1:0,43", "10 43 5,1:0,43",
         "10 43 6,1:0,43", "10 43 7,1:0,43", "10 43 8,1:0,43", "10 43 9,1:0,43"
     ])
Example #15
 def testIotaInitializer(self):
   self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
   self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
                       _IotaInitializer([4, 2]))
   with self.cached_session():
     vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
                                                             _IotaInitializer)
     variables.global_variables_initializer().run()
     slice0 = _IotaInitializer([5, 5])
     slice1 = _IotaInitializer([4, 5])
     slice2 = _IotaInitializer([4, 5])
     val = array_ops.concat(vs, 0).eval()
     self.assertAllClose(slice0 + slice1 + slice2, val)
     self._TestSaveSpec(vs, ["13 5 0,5:0,5", "13 5 5,4:0,5", "13 5 9,4:0,5"])
Example #16
 def testIotaInitializer(self):
   self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
   self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
                       _IotaInitializer([4, 2]))
   with self.test_session():
     vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
                                                             _IotaInitializer)
     variables.global_variables_initializer().run()
     slice0 = _IotaInitializer([5, 5])
     slice1 = _IotaInitializer([4, 5])
     slice2 = _IotaInitializer([4, 5])
     val = array_ops.concat(vs, 0).eval()
     self.assertAllClose(slice0 + slice1 + slice2, val)
     self._TestSaveSpec(vs, ["13 5 0,5:0,5", "13 5 5,4:0,5", "13 5 9,4:0,5"])
Example #17
 def _testNameHelper(self, use_resource=False):
     with self.cached_session():
         rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
         with variable_scope.variable_scope("hi",
                                            use_resource=use_resource):
             vs1 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par)
             vs2 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par)
         variables.global_variables_initializer().run()
         var1_name = vs1[0]._save_slice_info.full_name
         var2_name = vs2[0]._save_slice_info.full_name
         self.assertEqual("hi/PartitionedVariable", var1_name)
         self.assertEqual("hi/PartitionedVariable_1", var2_name)
         self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
         self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
         self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
         self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
     # Test same variable.
     with self.cached_session():
         rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
         with variable_scope.variable_scope(
                 "hola", use_resource=use_resource) as vs:
             vs1 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
         with variable_scope.variable_scope(vs,
                                            reuse=True,
                                            use_resource=use_resource):
             vs2 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
         variables.global_variables_initializer().run()
         var1_name = vs1[0]._save_slice_info.full_name
         var2_name = vs2[0]._save_slice_info.full_name
         self.assertEqual("hola/PartitionedVariable", var1_name)
         self.assertEqual("hola/PartitionedVariable", var2_name)
         self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
         self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
         self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
         self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
     # Test name_scope
     with self.cached_session():
         rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
         with ops.name_scope("ola"):
             vs1 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par)
             vs2 = partitioned_variables.create_partitioned_variables(
                 [2, 4], [1, 2], rnd_par)
         variables.global_variables_initializer().run()
         var1_name = vs1[0]._save_slice_info.full_name
         var2_name = vs2[0]._save_slice_info.full_name
         # Currently, the name scope 'ola' has no effect.
         self.assertEqual("PartitionedVariable", var1_name)
         self.assertEqual("PartitionedVariable_1", var2_name)
         self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
         self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
         self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
         self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
Example #18
  def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    embedding_weights = partitioned_variables.create_partitioned_variables(
        shape=[vocab_size, embed_dim],
        slicing=[num_shards, 1],
        initializer=init_ops.truncated_normal_initializer(
            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
    for w in embedding_weights:
      w.initializer.run()
    embedding_weights = [w.eval() for w in embedding_weights]
    return embedding_weights
Example #19
  def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    embedding_weights = partitioned_variables.create_partitioned_variables(
        shape=[vocab_size, embed_dim],
        slicing=[num_shards, 1],
        initializer=init_ops.truncated_normal_initializer(
            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
    for w in embedding_weights:
      w.initializer.run()
    embedding_weights = [w.eval() for w in embedding_weights]
    return embedding_weights
Example #20
 def testLargePartitionedVariables(self):
   save_path = os.path.join(self.get_temp_dir(), "large_variable")
   var_name = "my_var"
   # Save a large partitioned variable.
   with session.Session("", graph=ops.Graph()) as sess:
     with ops.device("/cpu:0"):
       # Create a partitioned variable which is larger than int32 size but
       # split into smaller sized variables.
       init = lambda shape, dtype, partition_info: constant_op.constant(
           True, dtype, shape)
       partitioned_var = partitioned_variables.create_partitioned_variables(
           [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
       variables.global_variables_initializer().run()
       save = saver.Saver(partitioned_var)
       val = save.save(sess, save_path)
       self.assertEqual(save_path, val)
Example #21
 def testLargePartitionedVariables(self):
     save_path = os.path.join(self.get_temp_dir(), "large_variable")
     var_name = "my_var"
     # Save a large partitioned variable.
     with session.Session("", graph=ops.Graph()) as sess:
         with ops.device("/cpu:0"):
             # Create a partitioned variable which is larger than int32 size but
             # split into smaller sized variables.
             init = lambda shape, dtype, partition_info: constant_op.constant(
                 True, dtype, shape)
             partitioned_var = partitioned_variables.create_partitioned_variables(
                 [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
             variables.global_variables_initializer().run()
             save = saver.Saver(partitioned_var)
             val = save.save(sess, save_path)
             self.assertEqual(save_path, val)
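Spelling out the arithmetic behind the comment (assuming the point is the int32 element-count limit): the full variable holds 2**31 elements, one more than int32 can represent, while each of the four partitions holds only 2**29.

INT32_MAX = (1 << 31) - 1
total_elements = 1 << 31             # full bool variable
per_partition = total_elements // 4  # 2**29 elements per shard
assert total_elements > INT32_MAX
assert per_partition < INT32_MAX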
Example #22
 def testRandomInitValue(self):
   with self.test_session():
     rnd = variables.Variable(random_ops.random_uniform([200, 40]))
     vs = partitioned_variables.create_partitioned_variables(
         rnd.get_shape(), [1, 10], rnd.initialized_value())
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 1).eval()
     rnd = rnd.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.float32] * 10, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, [
         "200 40 0,200:0,4", "200 40 0,200:4,4", "200 40 0,200:8,4",
         "200 40 0,200:12,4", "200 40 0,200:16,4", "200 40 0,200:20,4",
         "200 40 0,200:24,4", "200 40 0,200:28,4", "200 40 0,200:32,4",
         "200 40 0,200:36,4"
     ])
Example #23
 def _testNameHelper(self, use_resource=False):
   with self.test_session():
     rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
     with variable_scope.variable_scope("hi", use_resource=use_resource):
       vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                                rnd_par)
       vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                                rnd_par)
     variables.global_variables_initializer().run()
     var1_name = vs1[0]._save_slice_info.full_name
     var2_name = vs2[0]._save_slice_info.full_name
     self.assertEqual("hi/PartitionedVariable", var1_name)
     self.assertEqual("hi/PartitionedVariable_1", var2_name)
     self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
     self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
     self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
     self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
   # Test same variable.
   with self.test_session():
     rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
     with variable_scope.variable_scope(
         "hola", use_resource=use_resource) as vs:
       vs1 = partitioned_variables.create_partitioned_variables(
           [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
     with variable_scope.variable_scope(
         vs, reuse=True, use_resource=use_resource):
       vs2 = partitioned_variables.create_partitioned_variables(
           [2, 4], [1, 2], rnd_par, dtype=dtypes.int32)
     variables.global_variables_initializer().run()
     var1_name = vs1[0]._save_slice_info.full_name
     var2_name = vs2[0]._save_slice_info.full_name
     self.assertEqual("hola/PartitionedVariable", var1_name)
     self.assertEqual("hola/PartitionedVariable", var2_name)
     self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
     self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
     self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
     self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
   # Test name_scope
   with self.test_session():
     rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
     with ops.name_scope("ola"):
       vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                                rnd_par)
       vs2 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                                rnd_par)
     variables.global_variables_initializer().run()
     var1_name = vs1[0]._save_slice_info.full_name
     var2_name = vs2[0]._save_slice_info.full_name
     # Currently, the name scope 'ola' has no effect.
     self.assertEqual("PartitionedVariable", var1_name)
     self.assertEqual("PartitionedVariable_1", var2_name)
     self.assertEqual(var1_name + "/part_0:0", vs1[0].name)
     self.assertEqual(var1_name + "/part_1:0", vs1[1].name)
     self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
     self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
Example #24
 def testRandomInitValue(self):
   with self.cached_session():
     rnd = variables.Variable(random_ops.random_uniform([200, 40]))
     vs = partitioned_variables.create_partitioned_variables(
         rnd.get_shape(), [1, 10], rnd.initialized_value())
     variables.global_variables_initializer().run()
     val = array_ops.concat(vs, 1).eval()
     rnd = rnd.eval()
     self.assertAllClose(rnd, val)
     self.assertEqual([dtypes.float32] * 10, [v.dtype.base_dtype for v in vs])
     self._TestSaveSpec(vs, [
         "200 40 0,200:0,4", "200 40 0,200:4,4", "200 40 0,200:8,4",
         "200 40 0,200:12,4", "200 40 0,200:16,4", "200 40 0,200:20,4",
         "200 40 0,200:24,4", "200 40 0,200:28,4", "200 40 0,200:32,4",
         "200 40 0,200:36,4"
     ])
Example #25
 def testRandomInitUnevenPartitions(self):
   with self.test_session():
     rnd = variables.Variable(
         random_ops.random_uniform(
             [20, 43], dtype=dtypes.float64))
     var_lists = [
         partitioned_variables.create_partitioned_variables(
             rnd.get_shape(), [1, i], rnd.initialized_value())
         for i in xrange(1, 10)
     ]
     variables.global_variables_initializer().run()
     rnd_val = rnd.eval()
      # Only check the slice save specs for the first 5 partitionings.
     save_specs = [
         # One slice
         ["20 43 0,20:0,43"],
         # Two slices
         ["20 43 0,20:0,22", "20 43 0,20:22,21"],
         # Three slices
         ["20 43 0,20:0,15", "20 43 0,20:15,14", "20 43 0,20:29,14"],
         # Four slices
         [
             "20 43 0,20:0,11", "20 43 0,20:11,11", "20 43 0,20:22,11",
             "20 43 0,20:33,10"
         ],
         # Five slices
         [
             "20 43 0,20:0,9", "20 43 0,20:9,9", "20 43 0,20:18,9",
             "20 43 0,20:27,8", "20 43 0,20:35,8"
         ]
     ]
     for i, vs in enumerate(var_lists):
       var_val = array_ops.concat(vs, 1).eval()
       self.assertAllClose(rnd_val, var_val)
       self.assertEqual([dtypes.float64] * len(vs),
                        [v.dtype.base_dtype for v in vs])
       if i < len(save_specs):
         self._TestSaveSpec(vs, save_specs[i])
Example #26
 def testRandomInitUnevenPartitions(self):
   with self.test_session():
     rnd = variables.Variable(
         random_ops.random_uniform(
             [20, 43], dtype=dtypes.float64))
     var_lists = [
         partitioned_variables.create_partitioned_variables(
             rnd.get_shape(), [1, i], rnd.initialized_value())
         for i in xrange(1, 10)
     ]
     variables.global_variables_initializer().run()
     rnd_val = rnd.eval()
      # Only check the slice save specs for the first 5 partitionings.
     save_specs = [
         # One slice
         ["20 43 0,20:0,43"],
         # Two slices
         ["20 43 0,20:0,22", "20 43 0,20:22,21"],
         # Three slices
         ["20 43 0,20:0,15", "20 43 0,20:15,14", "20 43 0,20:29,14"],
         # Four slices
         [
             "20 43 0,20:0,11", "20 43 0,20:11,11", "20 43 0,20:22,11",
             "20 43 0,20:33,10"
         ],
         # Five slices
         [
             "20 43 0,20:0,9", "20 43 0,20:9,9", "20 43 0,20:18,9",
             "20 43 0,20:27,8", "20 43 0,20:35,8"
         ]
     ]
     for i, vs in enumerate(var_lists):
       var_val = array_ops.concat(vs, 1).eval()
       self.assertAllClose(rnd_val, var_val)
       self.assertEqual([dtypes.float64] * len(vs),
                        [v.dtype.base_dtype for v in vs])
       if i < len(save_specs):
         self._TestSaveSpec(vs, save_specs[i])
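The save specs above follow a simple uneven-split rule: the remainder is spread across the leading partitions, one extra column each. A pure-Python sketch of that rule, checked against the specs in the test:

def partition_sizes(dim_size, num_parts):
    # 43 columns into 5 parts -> [9, 9, 9, 8, 8]
    base, rem = divmod(dim_size, num_parts)
    return [base + 1 if i < rem else base for i in range(num_parts)]

assert partition_sizes(43, 2) == [22, 21]
assert partition_sizes(43, 3) == [15, 14, 14]
assert partition_sizes(43, 4) == [11, 11, 11, 10]
assert partition_sizes(43, 5) == [9, 9, 9, 8, 8]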
Example #27
def _create_partitioned_variables(name,
                                  num_hosts,
                                  vocabulary_size,
                                  embedding_dimension,
                                  initializer,
                                  collections=None):  # pylint: disable=redefined-outer-name
  """Creates ParitionedVariables based on `num_hosts` for `table`."""
  # TODO(shizhiw): automatically place embedding lookup elsewhere?
  if vocabulary_size < num_hosts:
    raise ValueError('`vocabulary_size`({}) is smaller than `num_hosts`({}). '
                     'As TPU embedding is not optimized for small tables, '
                     'please consider other ways for this embedding '
                     'lookup.'.format(vocabulary_size, num_hosts))

  slicing = [num_hosts, 1]

  # TODO(shizhiw): deprecated, use tf.get_variable()?
  return partitioned_variables.create_partitioned_variables(
      name=name,
      slicing=slicing,
      shape=(vocabulary_size, embedding_dimension),
      dtype=dtypes.float32,
      initializer=initializer,
      collections=collections,
      trainable=False)
Example #28
 def testSomeErrors(self):
     with self.test_session():
         rnd = variables.Variable(random_ops.random_uniform([10, 43]))
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10], [1, 1], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 20], [1], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 43], [1], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 43], [1, 2, 3], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 43], [11, 1], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 43], [20, 1], rnd.initialized_value())
         with self.assertRaises(ValueError):
             partitioned_variables.create_partitioned_variables(
                 [10, 43], [1, 50], rnd.initialized_value())
Example #29
 def testSomeErrors(self):
   with self.test_session():
     rnd = variables.Variable(random_ops.random_uniform([10, 43]))
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10], [1, 1], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 20], [1], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 43], [1], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 43], [1, 2, 3], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 43], [11, 1], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 43], [20, 1], rnd.initialized_value())
     with self.assertRaises(ValueError):
       partitioned_variables.create_partitioned_variables(
           [10, 43], [1, 50], rnd.initialized_value())
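For reference, a pure-Python sketch of the two invariants these calls violate (a hypothetical helper, not the library's actual validation code): slicing must have one entry per dimension, and no dimension can be split into more parts than its size.

def validate_slicing(shape, slicing):
    # One partition count per dimension.
    if len(slicing) != len(shape):
        raise ValueError("slicing rank %d != shape rank %d"
                         % (len(slicing), len(shape)))
    # A dimension of size d supports at most d partitions.
    for dim, parts in zip(shape, slicing):
        if parts > dim:
            raise ValueError("cannot split a dimension of size %d "
                             "into %d parts" % (dim, parts))

validate_slicing([10, 43], [10, 1])  # OK
# validate_slicing([10, 43], [11, 1]) raises ValueError, like the test.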