def test_save_restore_different_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    cp.write(fname)
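    # The checkpoint stores 's' as a single logical variable, so it can be
    # restored below into a different number of partitions.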

    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')

    # Restore from 4 partitions into 1.
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
    cp2.write(fname)

    # Restore 1 partition into 4.
    cp.restore(fname)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
    self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
    self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
    self.assertEqual(self.evaluate(cp.s.variables[3]), [20])
    def test_validation_errors(self):
        with self.assertRaisesRegex(ValueError, 'Expected a list of '):
            sharded_variable.ShardedVariable(
                [variables_lib.Variable([0]), 'not-a-variable'])

        with self.assertRaisesRegex(ValueError, 'must have the same dtype'):
            sharded_variable.ShardedVariable([
                variables_lib.Variable([0], dtype='int64'),
                variables_lib.Variable([1], dtype='int32')
            ])

        with self.assertRaisesRegex(ValueError, 'the same shapes except'):
            sharded_variable.ShardedVariable([
                variables_lib.Variable(array_ops.ones((5, 10))),
                variables_lib.Variable(array_ops.ones((5, 20)))
            ])

        with self.assertRaisesRegex(ValueError, '`SaveSliceInfo` should not'):
            v = variables_lib.Variable([0])
            v._set_save_slice_info(
                variables_lib.Variable.SaveSliceInfo(full_name='s',
                                                     full_shape=[2],
                                                     var_offset=[0],
                                                     var_shape=[1]))
            sharded_variable.ShardedVariable([v])
  def test_delayed_restore(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = tracking.AutoTrackable()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)

    model2 = tracking.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0])
    ]
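    # `restore` ran before `model2.s` existed; attaching the ShardedVariable
    # below triggers the delayed restoration checked by the assertions.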
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [1])
    self.assertAllEqual(self.evaluate(model2.s.variables[2]), [2])
    self.assertAllEqual(self.evaluate(model2.s.variables[3]), [3])
  def test_as_function_input(self):
    variables1 = [
        variables_lib.Variable([1]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables1)
    variables2 = [
        variables_lib.Variable([2]),
        variables_lib.Variable([2]),
    ]
    s2 = sharded_variable.ShardedVariable(variables2)

    trace_count = [0]

    @def_function.function
    def func(sharded_var):
      trace_count[0] = trace_count[0] + 1
      sharded_var.assign([0, 0])

    func(s)
    self.assertAllEqual(ops.convert_to_tensor(s), [0, 0])
    self.assertEqual(trace_count[0], 1)
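    # `s2` has the same structure as `s` (two shards of shape [1]), so calling
    # `func` with it should not trigger a retrace.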
    func(s2)
    self.assertAllEqual(ops.convert_to_tensor(s2), [0, 0])
    self.assertEqual(trace_count[0], 1)
  def __init__(self):
    super().__init__()
    variables1 = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
    ]
    variables2 = [
        variables_lib.Variable([2], trainable=False),
        variables_lib.Variable([3], trainable=False),
    ]
    self.w = sharded_variable.ShardedVariable(variables1)
    self.b = sharded_variable.ShardedVariable(variables2)
  def test_operator_overload(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)

    v2 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv2 = sharded_variable.ShardedVariable(v2)

    equal = sv1 == sv2
    self.assertAllEqual(equal, [True, True])
    self.assertAllEqual(sv1 + sv2, [2.0, 4.0])
    def test_save_graph_def(self):
        root = tracking.AutoTrackable()
        v1 = variables_lib.Variable([3.])
        v2 = variables_lib.Variable([2.])
        root.v = sharded_variable.ShardedVariable([v1, v2])
        root.train = def_function.function(
            lambda x: embedding_ops.embedding_lookup_v2(root.v.variables, x))
        # TODO(b/144057383): Remove the necessity of root.serve once the saving
        # context is added to the tf.function cache key.
        root.serve = def_function.function(
            lambda x: embedding_ops.embedding_lookup_v2(
                root.v.variables[0], x),
            input_signature=[
                tensor_spec.TensorSpec([2], dtypes.int32, name='x')
            ])

        # Trace and use root.train
        self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

        save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
        save.save(root, save_dir, root.serve)
        self.assertAllEqual([3., 2.],
                            _load_and_run(save_dir, {'x': [0, 1]})['output_0'])

        # Continue using root.train for training
        self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())
  def test_convert_to_tensor(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    t = ops.convert_to_tensor(s)
    self.assertAllEqual(t, [[0, 0], [1, 1], [2, 2], [3, 3]])
  def __init__(self):
    super().__init__()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
    ]
    self.w = sharded_variable.ShardedVariable(variables)
  def test_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[5, 5], [6, 6]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[7, 7]])
  def test_numpy(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    sv1_np = sv1.numpy()
    self.assertIsInstance(sv1_np, np.ndarray)
    self.assertAllEqual(sv1_np, np.array([1., 2.]))
  def test_assign_sub(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    s.assign_sub([[0, 0], [1, 1], [1, 1], [3, 3]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[0, 0]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[0, 0], [1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[0, 0]])
  def test_sharded_variable_simple(self):
    v0 = variables_lib.Variable([0])
    v1 = variables_lib.Variable([1])
    s = sharded_variable.ShardedVariable([v0, v1], name='s')
    self.assertEqual(s.variables[0], v0)
    self.assertEqual(s.variables[1], v1)
    self.assertEqual(s.shape.as_list(), [2])
    self.assertEqual(s.dtype, v0.dtype)
    self.assertEqual(s.name, 's')
  def test_shards_have_container_set(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    for v in sv1.variables:
      self.assertTrue(hasattr(v, '_sharded_container'))
      self.assertIs(v._sharded_container(), sv1)
  def test_assign_add(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign_add([[1, 1], [1, 1], [2, 2], [2, 2]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[2, 2], [4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[5, 5]])
    self.assertIs(ret, s)
    def sharded_variable_creator(next_creator, **kwargs):
      if "shape" not in kwargs or kwargs["shape"] is None:
        raise ValueError("shape must be explicitly specified when creating "
                         "sharded variables")
      init_fn = kwargs.get("initial_value", None)
      # We intentionally don't allow a non-callable initial_value, to ensure
      # the value is created on the PS rather than on the client. If the value
      # were created on the client, it would need to be sent to the PS for
      # variable initialization, which is inefficient and can potentially hit
      # the 2GB limit on protobuf serialization.
      if init_fn is None or not callable(init_fn):
        raise ValueError("initial_value must be specified as a callable when "
                         "creating sharded variables")

      # Use "div" partition strategy to partition the variable.
      full_shape = kwargs["shape"]
      if self._num_ps < full_shape[0]:
        num_shards = self._num_ps
      else:
        num_shards = full_shape[0]
      offsets = []
      base = full_shape[0] // num_shards
      extra = full_shape[0] % num_shards
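      # For example, with num_shards=4 and full_shape[0]=10: base=2, extra=2,
      # so the shard sizes are [3, 3, 2, 2] and offsets are [0, 3, 6, 8].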
      for i in range(num_shards):
        if i == 0:
          offsets.append(0)
        else:
          prev_shard_size = base + (1 if i - 1 < extra else 0)
          offsets.append(offsets[i - 1] + prev_shard_size)

      # Note: The way we initialize sharded variables is suboptimal, as it
      # needs to create the full value tensor separately on each PS which the
      # variable is going to be placed on. The full value could be very large
      # and consume a lot of memory. The ideal way is to only create what's
      # needed on each shard; however, that's not practical because:
      #  1. Initializers don't have support for sharded behavior, even though
      #     some initializers (e.g., uniform) could be used directly.
      #  2. The tf.Variable signature requires "initial_value" to be either a
      #     value or a callable without arguments, so it is not straightforward
      #     to derive the sharded component from it.
      def init_shard_fn(shard_index):
        full_value = init_fn()
        if shard_index < num_shards - 1:
          return full_value[offsets[shard_index]:offsets[shard_index + 1]]
        else:
          return full_value[offsets[shard_index]:]

      var_list = []
      for i in range(num_shards):
        kwargs["shape"] = None
        kwargs["initial_value"] = lambda: init_shard_fn(i)
        var_list.append(next_creator(**kwargs))

      result = sharded_variable.ShardedVariable(var_list)
      return result
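    # Usage sketch (the surrounding test wiring is not shown here): installing
    # this creator with `variable_scope.variable_creator_scope(
    #     sharded_variable_creator)` makes plain variable creation inside that
    # scope return `ShardedVariable`s partitioned across `self._num_ps`
    # parameter servers.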
  def test_flatten(self):
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables)

    got = nest.flatten(s)
    self.assertIs(s, got[0])

    got = nest.flatten(s, expand_composites=True)
    self.assertAllEqual(variables, got)
  def test_load_raises_error(self):
    root = tracking.AutoTrackable()
    v1 = variables_lib.Variable([3.])
    v2 = variables_lib.Variable([2.])
    root.v = sharded_variable.ShardedVariable([v1, v2])

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save.save(root, save_dir)

    with self.assertRaisesRegex(
        ValueError, 'Loading a saved_model containing ShardedVariable'):
      load.load(save_dir)
    def test_load_raises_error(self):
        root = tracking.AutoTrackable()
        v1 = variables_lib.Variable([3.])
        v2 = variables_lib.Variable([2.])
        root.v = sharded_variable.ShardedVariable([v1, v2])

        save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
        save.save(root, save_dir)

        with self.assertRaisesWithLiteralMatch(
                ValueError, 'Loading `ShardedVariable` is not supported'):
            load.load(save_dir)
  def test_embedding_with_sharded_variable(self):
    layer = keras.layers.Embedding(input_dim=5, output_dim=2)
    v = [
        tf.Variable([[1., 2.], [3., 4.]]),
        tf.Variable([[5., 6.], [7., 8.]]),
        tf.Variable([[9., 10.]])
    ]
    model = keras.models.Sequential([layer])
    layer.embeddings = sharded_variable.ShardedVariable(v)
    model.run_eagerly = testing_utils.should_run_eagerly()
    outputs = model.predict(np.array([[0, 2, 4]], dtype='int32'))
    self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]])
        def sharded_variable_creator(next_creator, **kwargs):
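            # Split a length-2 initial value into two shards of shape (1,).
            # Note that the initial_value callable is invoked twice, once per
            # shard slice.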
            v1_value = kwargs['initial_value']()[0:1]
            v2_value = kwargs['initial_value']()[1:]

            kwargs['initial_value'] = v1_value
            kwargs['shape'] = (1, )
            v1 = next_creator(**kwargs)

            kwargs['initial_value'] = v2_value
            kwargs['shape'] = (1, )
            v2 = next_creator(**kwargs)

            return sharded_variable.ShardedVariable([v1, v2])
    def test_embedding_lookup(self):
        v = [
            variables_lib.Variable([[1., 2.], [3., 4.]]),
            variables_lib.Variable([[5., 6.], [7., 8.]]),
            variables_lib.Variable([[9., 10.]])
        ]
        sv = sharded_variable.ShardedVariable(v)
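        # 5 embedding rows sharded as [2, 2, 1]: id 3 lives in the second shard
        # and id 4 in the third.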

        @def_function.function
        def lookup():
            ids = constant_op.constant([0, 3, 4])
            return embedding_ops.embedding_lookup_v2(sv, ids)

        @def_function.function
        def sparse_lookup():
            sp_ids = sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1],
                                                         [1, 0], [2, 2]],
                                                values=[0, 3, 4, 1],
                                                dense_shape=[3, 3])
            return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

        @def_function.function
        def safe_sparse_lookup():
            sp_ids = sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1],
                                                         [1, 0], [2, 2]],
                                                values=[0, -1, 4, 1],
                                                dense_shape=[3, 3])
            sp_weights = sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1],
                                                             [1, 0], [2, 2]],
                                                    values=[1., 1., -1., 1.],
                                                    dense_shape=[3, 3])
            return embedding_ops.safe_embedding_lookup_sparse_v2(
                sv, sp_ids, sp_weights)

        # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
        # ShardedVariable is converted to a tensor in safe_sparse_lookup.
        for func in [lookup, sparse_lookup]:
            num_gather_ops = 0
            for op in func.get_concrete_function().graph.get_operations():
                if op.type == 'ResourceGather':
                    num_gather_ops += 1
            self.assertEqual(
                num_gather_ops, len(v),
                'Number of ResourceGather ops does not match the expected'
                ' count, possibly due to ShardedVariable accidentally being'
                ' converted to a tensor in embedding_lookup ops.')

        self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
        self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
        self.assertAllClose(safe_sparse_lookup(),
                            [[1., 2.], [0., 0.], [3., 4.]])
  def test_save_restore_4_to_2_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')
    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [
        variables_lib.Variable([0, 0]),
        variables_lib.Variable([0, 0])
    ]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    # Assert that the weights saved from 4 partitions were restored into 2.
    self.assertLen(cp2.s.variables, 2)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3])
  def test_control_dep_on_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])

    @def_function.function
    def func():
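      # Both the raw return value of `assign` and a `group`-ed version of it
      # should be usable as control dependencies.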
      ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
      with ops.control_dependencies([ret]):
        a = array_ops.ones((1, 1))
      with ops.control_dependencies([control_flow_ops.group(ret)]):
        b = array_ops.ones((1, 1))
      return a, b

    func()
  def test_sparse_read(self):
    v = variables_lib.Variable(array_ops.zeros((30, 1)))
    indices = constant_op.constant([0, 10, 12, 21, 22])

    v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    self.assertAllEqual(v.sparse_read(indices), sv.sparse_read(indices))

    @def_function.function
    def func():
      return v.sparse_read(indices), sv.sparse_read(indices)

    got, expect = func()
    self.assertAllEqual(got, expect)
  def test_save_restore(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
    cp.write(fname)

    self.evaluate(cp.s.variables[0].assign([4]))
    self.assertEqual(self.evaluate(cp.s.variables[0]), [4])

    cp.restore(fname)
    # Tests that the original weights are restored.
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
    def test_scatter_add_uneven_partition(self):
        v = variables_lib.Variable(array_ops.zeros((32, 1)))
        sparse_delta = indexed_slices.IndexedSlices(
            values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
            indices=constant_op.constant([0, 10, 11, 12, 30, 31]))

        v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
        v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
        v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
        sv = sharded_variable.ShardedVariable([v0, v1, v2])
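        # 32 rows partitioned unevenly as [11, 11, 10]; the sharded results
        # below must match the dense baseline `v`.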

        v.scatter_add(sparse_delta)
        sv.scatter_add(sparse_delta)
        self.assertAllEqual(v, ops.convert_to_tensor(sv))

        @def_function.function
        def func():
            v.scatter_add(sparse_delta)
            sv.scatter_add(sparse_delta)

        func()
        self.assertAllEqual(v, ops.convert_to_tensor(sv))
  def test_batch_scatter_update(self):
    v = variables_lib.Variable(array_ops.zeros((32, 1)))
    sparse_delta = ops.IndexedSlices(
        values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([10, 11, 12, 13, 14, 15]))

    v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    v.batch_scatter_update(sparse_delta)
    sv.batch_scatter_update(sparse_delta)
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      v.batch_scatter_update(sparse_delta)
      sv.batch_scatter_update(sparse_delta)

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))
    def test_scatter_ops_even_partition(self, op):
        v = variables_lib.Variable(array_ops.zeros((30, 1)))
        sparse_delta = ops.IndexedSlices(
            values=constant_op.constant([[0.], [1.], [2.], [3.], [4.]]),
            indices=constant_op.constant([0, 10, 12, 21, 22]))

        v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
        v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
        v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
        sv = sharded_variable.ShardedVariable([v0, v1, v2])

        getattr(v, op)(sparse_delta, name='scatter_v')
        getattr(sv, op)(sparse_delta, name='scatter_sv')
        self.assertAllEqual(v, ops.convert_to_tensor(sv))

        @def_function.function
        def func():
            getattr(v, op)(sparse_delta, name='scatter_v')
            getattr(sv, op)(sparse_delta, name='scatter_sv')

        func()
        self.assertAllEqual(v, ops.convert_to_tensor(sv))
    def _create_variable(self, next_creator, **kwargs):
        """Implements StrategyExtendedV2._create_variable.

        Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` is
        created if all of the following criteria are satisfied:
          1. `self._variable_partitioner` results in more than one partition on
             the first axis.
          2. The variable's rank is greater than 0.
          3. The variable is not colocated with another variable.
        Otherwise a `Variable` is created.

        Args:
          next_creator: See `variable_scope.variable_creator_scope`; the next
            creator in the chain.
          **kwargs: Passed through to the next creator.

        Returns:
          A `Variable` or `ShardedVariable`.
        """

        var_creator = self._create_var_creator(next_creator, **kwargs)
        if "colocate_with" in kwargs:  # Never partition colocated_with variables.
            colocate_with = kwargs["colocate_with"]
            # Clear the variable scope to avoid possible conflicts between device
            # scope and colocation scope.
            with ops.device(None):
                with ops.colocate_with(colocate_with):
                    var = var_creator(**kwargs)
                    logging.debug(
                        "Creating variable (name:%s, shape:%r) that colocates with %s",
                        var.name, var.shape, kwargs["colocate_with"].name)
                    return var

        if self._variable_partitioner is None:
            return self._create_variable_round_robin(var_creator, **kwargs)

        name = kwargs.get("name", None)
        dtype = kwargs.get("dtype", None)
        shape = kwargs.get("shape", None)
        initial_value = kwargs.get("initial_value", None)
        if initial_value is None:
            # If we are loading, next_creator will return an UninitializedVariable
            v = next_creator(**kwargs)
            if not isinstance(v, resource_variable_ops.UninitializedVariable):
                raise ValueError(
                    "It looks like you are using `ParameterServerStrategy` with a "
                    "`variable_partitioner`, and trying to create a variable without "
                    "specifying `initial_value`. This is not allowed. Please specify the "
                    "`initial_value`.")
            elif shape is None or dtype is None:
                raise ValueError(
                    "It looks like you are trying to load a `SavedModel` using "
                    "`tf.saved_model.load` within a `ParameterServerStrategy` scope, "
                    "but the `SavedModel` is missing shape or dtype information."
                )
            else:

                def initializer(shape, dtype, **kwargs):
                    if "partition_shape" in kwargs:
                        shape = kwargs["partition_shape"]
                    return array_ops.zeros(shape, dtype)

                initial_value = functools.partial(initializer,
                                                  shape=shape,
                                                  dtype=dtype)

        # Two cases where initial_value can be a callable:
        #   1. initial_value is passed as a callable, e.g., an `initializer`
        #      class.
        #   2. When restoring from a checkpoint, initial_value is a
        #      "CheckpointInitialValueCallable".
        init_from_fn = callable(initial_value)

        if init_from_fn and (shape is None or dtype is None):
            init_from_fn = False
            initial_value = initial_value()
        if not init_from_fn:
            # The initial_value is created on the coordinator, so it will need
            # to be sent to the PS for variable initialization, which can be
            # inefficient and can potentially hit the 2GB limit on protobuf
            # serialization.
            initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
            dtype = initial_value.dtype
            shape = initial_value.shape
        else:
            shape = tensor_shape.as_shape(shape)

        if shape.rank == 0:  # Skip partitioning rank-0 variable.
            return self._create_variable_round_robin(var_creator, **kwargs)

        num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
        if not num_partitions or num_partitions[0] == 0 or any(
                v != 1 for v in num_partitions[1:]):
            raise ValueError(
                "variable_partitioner must return a list/tuple whose elements are 1"
                " besides the first element (non-zero), got: %r" %
                num_partitions)

        if num_partitions[0] == 1:  # no partition
            return self._create_variable_round_robin(var_creator, **kwargs)

        # Use "div" partition strategy to partition the variable.
        num_partitions = min(num_partitions[0], shape[0])
        base = shape[0] // num_partitions
        extra = shape[0] % num_partitions
        # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
        # offsets: [0, 3, 6, 8, 10]
        offsets = []
        for i in range(num_partitions):
            if i == 0:
                offsets.append(0)
            else:
                prev_shard_size = base + (1 if i - 1 < extra else 0)
                offsets.append(offsets[i - 1] + prev_shard_size)
        offsets.append(shape[0])

        def init_shard_fn(shard_index):
            if not init_from_fn:
                logging.log_if(
                    logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                    shard_index == 0
                    and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
                return initial_value[offsets[shard_index]:offsets[shard_index +
                                                                  1]]
            partition_shape = (offsets[shard_index + 1] -
                               offsets[shard_index], ) + shape[1:]
            partition_offset = (
                offsets[shard_index], ) + (0, ) * len(shape[1:])
            arg_spec = tf_inspect.getfullargspec(initial_value)
            if ("shard_info" not in arg_spec.args
                    and "shard_info" not in arg_spec.kwonlyargs):
                try:
                    value = initial_value(partition_shape=partition_shape,
                                          partition_offset=partition_offset)
                except (TypeError, ValueError):
                    # TypeError: Initializer doesn't accept kwargs
                    # ValueError: Initializer doesn't accept partition kwargs
                    # In both cases we go ahead creating the full value and then slice.
                    value = initial_value()

                if value.shape == partition_shape:
                    # Initializer supports partition: value is the partition value.
                    return value
                else:
                    # Initializer doesn't support partition: value is the full value
                    # and needs to be sliced to get the partition value.
                    logging.log_if(
                        logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                        shard_index == 0 and
                        shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
                    return value[offsets[shard_index]:offsets[shard_index + 1]]
            else:
                # For compatibility with `CheckpointInitialValueCallable`.
                return initial_value(shard_info=trackable.ShardInfo(
                    shape=tensor_shape.as_shape(partition_shape),
                    offset=partition_offset))

        var_list = []
        for i in range(num_partitions):
            kwargs["shape"] = (offsets[i + 1] - offsets[i], ) + shape[1:]
            kwargs["initial_value"] = lambda: init_shard_fn(i)
            if name is not None:
                kwargs["name"] = "{}/part_{}".format(name, i)
            var_list.append(
                self._create_variable_round_robin(var_creator, **kwargs))

        result = sharded_variable.ShardedVariable(var_list)
        return result
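    # A minimal usage sketch (assumption: a `cluster_resolver` for the cluster
    # is available; constructing it is not shown). With a
    # `variable_partitioner`, variables created under `strategy.scope()` go
    # through `_create_variable` above and come back as `ShardedVariable`s:
    #
    #   strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
    #       cluster_resolver,
    #       variable_partitioner=sharded_variable.FixedShardsPartitioner(
    #           num_shards=2))
    #   with strategy.scope():
    #     v = variables_lib.Variable(
    #         initial_value=lambda: array_ops.ones((10, 4)),
    #         shape=(10, 4), dtype=dtypes.float32)
    #   # `v` is a ShardedVariable with two (5, 4) shards.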