def test_save_restore_different_partitions(self):
  fname = os.path.join(self.get_temp_dir(), 'checkpoint')
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
      variables_lib.Variable([2]),
      variables_lib.Variable([3])
  ]
  s = sharded_variable.ShardedVariable(variables, name='s')
  cp = util.Checkpoint(s=s)
  cp.write(fname)

  variables2 = [variables_lib.Variable([0, 0, 0, 0])]
  s2 = sharded_variable.ShardedVariable(variables2, name='s')

  # Restore from 4 partitions into 1.
  cp2 = util.Checkpoint(s=s2)
  cp2.restore(fname)
  self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

  self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
  cp2.write(fname)

  # Restore 1 partition into 4.
  cp.restore(fname)
  self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
  self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
  self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
  self.assertEqual(self.evaluate(cp.s.variables[3]), [20])

def test_validation_errors(self):
  with self.assertRaisesRegex(ValueError, 'Expected a list of '):
    sharded_variable.ShardedVariable(
        [variables_lib.Variable([0]), 'not-a-variable'])

  with self.assertRaisesRegex(ValueError, 'must have the same dtype'):
    sharded_variable.ShardedVariable([
        variables_lib.Variable([0], dtype='int64'),
        variables_lib.Variable([1], dtype='int32')
    ])

  with self.assertRaisesRegex(ValueError, 'the same shapes except'):
    sharded_variable.ShardedVariable([
        variables_lib.Variable(array_ops.ones((5, 10))),
        variables_lib.Variable(array_ops.ones((5, 20)))
    ])

  with self.assertRaisesRegex(ValueError, '`SaveSliceInfo` should not'):
    v = variables_lib.Variable([0])
    v._set_save_slice_info(
        variables_lib.Variable.SaveSliceInfo(
            full_name='s', full_shape=[2], var_offset=[0], var_shape=[1]))
    sharded_variable.ShardedVariable([v])

def test_delayed_restore(self):
  fname = os.path.join(self.get_temp_dir(), 'checkpoint')
  model = tracking.AutoTrackable()
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
      variables_lib.Variable([2]),
      variables_lib.Variable([3])
  ]
  model.s = sharded_variable.ShardedVariable(variables)
  cp = util.Checkpoint(model=model)
  cp.write(fname)

  model2 = tracking.AutoTrackable()
  cp2 = util.Checkpoint(model=model2)
  cp2.restore(fname)
  variables2 = [
      variables_lib.Variable([0]),
      variables_lib.Variable([0]),
      variables_lib.Variable([0]),
      variables_lib.Variable([0])
  ]
  model2.s = sharded_variable.ShardedVariable(variables2)
  self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0])
  self.assertAllEqual(self.evaluate(model2.s.variables[1]), [1])
  self.assertAllEqual(self.evaluate(model2.s.variables[2]), [2])
  self.assertAllEqual(self.evaluate(model2.s.variables[3]), [3])

def test_as_function_input(self):
  variables1 = [
      variables_lib.Variable([1]),
      variables_lib.Variable([1]),
  ]
  s = sharded_variable.ShardedVariable(variables1)
  variables2 = [
      variables_lib.Variable([2]),
      variables_lib.Variable([2]),
  ]
  s2 = sharded_variable.ShardedVariable(variables2)

  trace_count = [0]

  @def_function.function
  def func(sharded_var):
    trace_count[0] = trace_count[0] + 1
    sharded_var.assign([0, 0])

  func(s)
  self.assertAllEqual(ops.convert_to_tensor(s), [0, 0])
  self.assertEqual(trace_count[0], 1)
  # A structurally identical ShardedVariable should reuse the traced
  # function instead of triggering a retrace.
  func(s2)
  self.assertAllEqual(ops.convert_to_tensor(s2), [0, 0])
  self.assertEqual(trace_count[0], 1)

def __init__(self):
  super().__init__()
  variables1 = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
  ]
  variables2 = [
      variables_lib.Variable([2], trainable=False),
      variables_lib.Variable([3], trainable=False),
  ]
  self.w = sharded_variable.ShardedVariable(variables1)
  self.b = sharded_variable.ShardedVariable(variables2)

def test_operator_overload(self):
  v1 = [
      variables_lib.Variable([1.]),
      variables_lib.Variable([2.]),
  ]
  sv1 = sharded_variable.ShardedVariable(v1)

  v2 = [
      variables_lib.Variable([1.]),
      variables_lib.Variable([2.]),
  ]
  sv2 = sharded_variable.ShardedVariable(v2)

  equal = sv1 == sv2
  self.assertAllEqual(equal, [True, True])
  self.assertAllEqual(sv1 + sv2, [2.0, 4.0])

def test_save_graph_def(self):
  root = tracking.AutoTrackable()
  v1 = variables_lib.Variable([3.])
  v2 = variables_lib.Variable([2.])
  root.v = sharded_variable.ShardedVariable([v1, v2])
  root.train = def_function.function(
      lambda x: embedding_ops.embedding_lookup_v2(root.v.variables, x))
  # TODO(b/144057383): Remove the necessity of root.serve once the saving
  # context is made part of the tf.function cache key.
  root.serve = def_function.function(
      lambda x: embedding_ops.embedding_lookup_v2(root.v.variables[0], x),
      input_signature=[tensor_spec.TensorSpec([2], dtypes.int32, name='x')])

  # Trace and use root.train.
  self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save.save(root, save_dir, root.serve)
  self.assertAllEqual([3., 2.],
                      _load_and_run(save_dir, {'x': [0, 1]})['output_0'])

  # Continue using root.train for training.
  self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

def test_convert_to_tensor(self):
  v0 = variables_lib.Variable([[0, 0]])
  v1 = variables_lib.Variable([[1, 1], [2, 2]])
  v2 = variables_lib.Variable([[3, 3]])
  s = sharded_variable.ShardedVariable([v0, v1, v2])
  t = ops.convert_to_tensor(s)
  self.assertAllEqual(t, [[0, 0], [1, 1], [2, 2], [3, 3]])

def __init__(self):
  super().__init__()
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
  ]
  self.w = sharded_variable.ShardedVariable(variables)

def test_assign(self):
  v0 = variables_lib.Variable([[0, 0]])
  v1 = variables_lib.Variable([[1, 1], [2, 2]])
  v2 = variables_lib.Variable([[3, 3]])
  s = sharded_variable.ShardedVariable([v0, v1, v2])
  s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
  self.assertAllEqual(self.evaluate(s.variables[0]), [[4, 4]])
  self.assertAllEqual(self.evaluate(s.variables[1]), [[5, 5], [6, 6]])
  self.assertAllEqual(self.evaluate(s.variables[2]), [[7, 7]])

def test_numpy(self):
  v1 = [
      variables_lib.Variable([1.]),
      variables_lib.Variable([2.]),
  ]
  sv1 = sharded_variable.ShardedVariable(v1)
  sv1_np = sv1.numpy()
  self.assertIsInstance(sv1_np, np.ndarray)
  self.assertAllEqual(sv1_np, np.array([1., 2.]))

def test_assign_sub(self):
  v0 = variables_lib.Variable([[0, 0]])
  v1 = variables_lib.Variable([[1, 1], [2, 2]])
  v2 = variables_lib.Variable([[3, 3]])
  s = sharded_variable.ShardedVariable([v0, v1, v2])
  s.assign_sub([[0, 0], [1, 1], [1, 1], [3, 3]])
  self.assertAllEqual(self.evaluate(s.variables[0]), [[0, 0]])
  self.assertAllEqual(self.evaluate(s.variables[1]), [[0, 0], [1, 1]])
  self.assertAllEqual(self.evaluate(s.variables[2]), [[0, 0]])

def test_sharded_variable_simple(self):
  v0 = variables_lib.Variable([0])
  v1 = variables_lib.Variable([1])
  s = sharded_variable.ShardedVariable([v0, v1], name='s')
  self.assertEqual(s.variables[0], v0)
  self.assertEqual(s.variables[1], v1)
  self.assertEqual(s.shape.as_list(), [2])
  self.assertEqual(s.dtype, v0.dtype)
  self.assertEqual(s.name, 's')

def test_shards_have_container_set(self):
  v1 = [
      variables_lib.Variable([1.]),
      variables_lib.Variable([2.]),
  ]
  sv1 = sharded_variable.ShardedVariable(v1)
  for v in sv1.variables:
    self.assertTrue(hasattr(v, '_sharded_container'))
    # `_sharded_container` is callable (a weak reference); calling it
    # resolves to the owning ShardedVariable.
    self.assertIs(v._sharded_container(), sv1)

def test_assign_add(self):
  v0 = variables_lib.Variable([[0, 0]])
  v1 = variables_lib.Variable([[1, 1], [2, 2]])
  v2 = variables_lib.Variable([[3, 3]])
  s = sharded_variable.ShardedVariable([v0, v1, v2])
  ret = s.assign_add([[1, 1], [1, 1], [2, 2], [2, 2]])
  self.assertAllEqual(self.evaluate(s.variables[0]), [[1, 1]])
  self.assertAllEqual(self.evaluate(s.variables[1]), [[2, 2], [4, 4]])
  self.assertAllEqual(self.evaluate(s.variables[2]), [[5, 5]])
  self.assertIs(ret, s)

def sharded_variable_creator(next_creator, **kwargs):
  if "shape" not in kwargs or kwargs["shape"] is None:
    raise ValueError("shape must be explicitly specified when creating "
                     "sharded variables")
  init_fn = kwargs.get("initial_value", None)
  # We intentionally don't allow non-callable initial_value to ensure the
  # value is created on PS but not client. If the value were created on the
  # client, it would need to be sent to PS for variable initialization,
  # which is inefficient and can potentially hit the 2GB limit on protobuf
  # serialization.
  if init_fn is None or not callable(init_fn):
    raise ValueError("initial_value must be specified as a callable when "
                     "creating sharded variables")

  # Use "div" partition strategy to partition the variable.
  # (`self` is captured from the enclosing scope; this creator is defined
  # inside a strategy method.)
  full_shape = kwargs["shape"]
  num_shards = min(self._num_ps, full_shape[0])
  offsets = []
  base = full_shape[0] // num_shards
  extra = full_shape[0] % num_shards
  for i in range(num_shards):
    if i == 0:
      offsets.append(0)
    else:
      prev_shard_size = base + (1 if i - 1 < extra else 0)
      offsets.append(offsets[i - 1] + prev_shard_size)

  # Note: The way we initialize sharded variables is suboptimal, as it
  # needs to create the full value tensor separately on each PS which the
  # variable is going to be placed on. The full value could be very large
  # and consume a lot of memory. The ideal way is to only create what's
  # needed on the shard, however that's not practical because:
  # 1. Initializers don't have sharded behavior support, even though some
  #    initializers (e.g, uniform) can be used directly.
  # 2. tf.Variable signature requires "initial_value" to be either a value
  #    or a callable without arguments, meaning it is not straightforward
  #    to make the sharded component from it.
  def init_shard_fn(shard_index):
    full_value = init_fn()
    if shard_index < num_shards - 1:
      return full_value[offsets[shard_index]:offsets[shard_index + 1]]
    else:
      return full_value[offsets[shard_index]:]

  var_list = []
  for i in range(num_shards):
    kwargs["shape"] = None
    # The lambda is invoked by `next_creator` within this loop iteration,
    # so capturing the loop variable `i` is safe here.
    kwargs["initial_value"] = lambda: init_shard_fn(i)
    var_list.append(next_creator(**kwargs))

  return sharded_variable.ShardedVariable(var_list)

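# The "div" strategy above gives each of the first `extra` shards one extra
# row. A dependency-free sketch of the same offset arithmetic (the helper
# name is hypothetical, not part of the library):
def div_partition_offsets(num_rows, num_shards):
  """Returns the start offset of each shard under 'div' partitioning."""
  base, extra = divmod(num_rows, num_shards)
  offsets = [0]
  for i in range(1, num_shards):
    offsets.append(offsets[-1] + base + (1 if i - 1 < extra else 0))
  return offsets

# 10 rows over 4 shards -> shard sizes [3, 3, 2, 2].
assert div_partition_offsets(10, 4) == [0, 3, 6, 8]
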
def test_flatten(self):
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
  ]
  s = sharded_variable.ShardedVariable(variables)

  got = nest.flatten(s)
  self.assertIs(s, got[0])

  got = nest.flatten(s, expand_composites=True)
  self.assertAllEqual(variables, got)

def test_load_raises_error(self):
  root = tracking.AutoTrackable()
  v1 = variables_lib.Variable([3.])
  v2 = variables_lib.Variable([2.])
  root.v = sharded_variable.ShardedVariable([v1, v2])
  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save.save(root, save_dir)

  with self.assertRaisesRegex(
      ValueError, 'Loading a saved_model containing ShardedVariable'):
    load.load(save_dir)

def test_load_raises_error(self):
  root = tracking.AutoTrackable()
  v1 = variables_lib.Variable([3.])
  v2 = variables_lib.Variable([2.])
  root.v = sharded_variable.ShardedVariable([v1, v2])
  save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
  save.save(root, save_dir)

  with self.assertRaisesWithLiteralMatch(
      ValueError, 'Loading `ShardedVariable` is not supported'):
    load.load(save_dir)

def test_embedding_with_sharded_variable(self):
  layer = keras.layers.Embedding(input_dim=5, output_dim=2)
  v = [
      tf.Variable([[1., 2.], [3., 4.]]),
      tf.Variable([[5., 6.], [7., 8.]]),
      tf.Variable([[9., 10.]])
  ]
  model = keras.models.Sequential([layer])
  layer.embeddings = sharded_variable.ShardedVariable(v)
  model.run_eagerly = testing_utils.should_run_eagerly()
  outputs = model.predict(np.array([[0, 2, 4]], dtype='int32'))
  self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]])

def sharded_variable_creator(next_creator, **kwargs):
  v1_value = kwargs['initial_value']()[0:1]
  v2_value = kwargs['initial_value']()[1:]

  kwargs['initial_value'] = v1_value
  kwargs['shape'] = (1,)
  v1 = next_creator(**kwargs)

  kwargs['initial_value'] = v2_value
  kwargs['shape'] = (1,)
  v2 = next_creator(**kwargs)

  return sharded_variable.ShardedVariable([v1, v2])

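# A creator like the one above is typically installed with a
# `variable_creator_scope`, so variables built inside the scope come back
# sharded. A minimal sketch, assuming the two-shard creator above and a
# callable `initial_value` of length 2:
with variable_scope.variable_creator_scope(sharded_variable_creator):
  v = variables_lib.Variable(lambda: constant_op.constant([1., 2.]))
# `v` is a ShardedVariable with two shards, each of shape (1,).
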
def test_embedding_lookup(self):
  v = [
      variables_lib.Variable([[1., 2.], [3., 4.]]),
      variables_lib.Variable([[5., 6.], [7., 8.]]),
      variables_lib.Variable([[9., 10.]])
  ]
  sv = sharded_variable.ShardedVariable(v)

  @def_function.function
  def lookup():
    ids = constant_op.constant([0, 3, 4])
    return embedding_ops.embedding_lookup_v2(sv, ids)

  @def_function.function
  def sparse_lookup():
    sp_ids = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
        values=[0, 3, 4, 1],
        dense_shape=[3, 3])
    return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

  @def_function.function
  def safe_sparse_lookup():
    sp_ids = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
        values=[0, -1, 4, 1],
        dense_shape=[3, 3])
    sp_weights = sparse_tensor.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
        values=[1., 1., -1., 1.],
        dense_shape=[3, 3])
    return embedding_ops.safe_embedding_lookup_sparse_v2(
        sv, sp_ids, sp_weights)

  # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
  # ShardedVariable is converted to a tensor in safe_sparse_lookup.
  for func in [lookup, sparse_lookup]:
    num_gather_ops = 0
    for op in func.get_concrete_function().graph.get_operations():
      if op.type == 'ResourceGather':
        num_gather_ops += 1
    self.assertEqual(
        num_gather_ops, len(v), 'Number of ResourceGather ops does not match'
        ' expected, possibly due to ShardedVariable accidentally being'
        ' converted to tensor in embedding_lookup ops.')

  self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
  self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
  self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])

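# For reference: because these shard sizes ([2, 2, 1]) follow the "div"
# layout, a lookup on the ShardedVariable agrees with a plain gather on the
# concatenated full tensor. An eager-mode sketch reusing `sv` from the test
# above (both expressions evaluate to [[1., 2.], [7., 8.], [9., 10.]]):
full = array_ops.concat([shard.read_value() for shard in sv.variables],
                        axis=0)
ids = constant_op.constant([0, 3, 4])
expected = array_ops.gather(full, ids)
actual = embedding_ops.embedding_lookup_v2(sv, ids)
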
def test_save_restore_4_to_2_partitions(self):
  fname = os.path.join(self.get_temp_dir(), 'checkpoint')
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
      variables_lib.Variable([2]),
      variables_lib.Variable([3])
  ]
  s = sharded_variable.ShardedVariable(variables, name='s')
  cp = util.Checkpoint(s=s)
  cp.write(fname)

  variables2 = [
      variables_lib.Variable([0, 0]),
      variables_lib.Variable([0, 0])
  ]
  s2 = sharded_variable.ShardedVariable(variables2, name='s')
  cp2 = util.Checkpoint(s=s2)
  cp2.restore(fname)
  # Assert that weights from the 4 partitions were loaded into 2 partitions.
  # Restoring across partitionings works because each shard is saved with
  # slice metadata describing its offset within the full variable.
  self.assertLen(cp2.s.variables, 2)
  self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1])
  self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3])

def test_control_dep_on_assign(self):
  v0 = variables_lib.Variable([[0, 0]])
  v1 = variables_lib.Variable([[1, 1], [2, 2]])
  v2 = variables_lib.Variable([[3, 3]])
  s = sharded_variable.ShardedVariable([v0, v1, v2])

  @def_function.function
  def func():
    ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
    with ops.control_dependencies([ret]):
      a = array_ops.ones((1, 1))
    with ops.control_dependencies([control_flow_ops.group(ret)]):
      b = array_ops.ones((1, 1))
    return a, b

  func()

def test_sparse_read(self):
  v = variables_lib.Variable(array_ops.zeros((30, 1)))
  indices = constant_op.constant([0, 10, 12, 21, 22])

  v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
  v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
  v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
  sv = sharded_variable.ShardedVariable([v0, v1, v2])

  self.assertAllEqual(v.sparse_read(indices), sv.sparse_read(indices))

  @def_function.function
  def func():
    return v.sparse_read(indices), sv.sparse_read(indices)

  got, expect = func()
  self.assertAllEqual(got, expect)

def test_save_restore(self):
  fname = os.path.join(self.get_temp_dir(), 'checkpoint')
  variables = [
      variables_lib.Variable([0]),
      variables_lib.Variable([1]),
      variables_lib.Variable([2]),
      variables_lib.Variable([3])
  ]
  s = sharded_variable.ShardedVariable(variables, name='s')
  cp = util.Checkpoint(s=s)
  self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
  cp.write(fname)

  self.evaluate(cp.s.variables[0].assign([4]))
  self.assertEqual(self.evaluate(cp.s.variables[0]), [4])

  cp.restore(fname)
  # Tests that the original weights are restored.
  self.assertEqual(self.evaluate(cp.s.variables[0]), [0])

def test_scatter_add_uneven_partition(self):
  v = variables_lib.Variable(array_ops.zeros((32, 1)))
  sparse_delta = indexed_slices.IndexedSlices(
      values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
      indices=constant_op.constant([0, 10, 11, 12, 30, 31]))

  v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
  v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
  v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
  sv = sharded_variable.ShardedVariable([v0, v1, v2])

  v.scatter_add(sparse_delta)
  sv.scatter_add(sparse_delta)
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

  @def_function.function
  def func():
    v.scatter_add(sparse_delta)
    sv.scatter_add(sparse_delta)

  func()
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

def test_batch_scatter_update(self):
  v = variables_lib.Variable(array_ops.zeros((32, 1)))
  sparse_delta = indexed_slices.IndexedSlices(
      values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
      indices=constant_op.constant([10, 11, 12, 13, 14, 15]))

  v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
  v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
  v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
  sv = sharded_variable.ShardedVariable([v0, v1, v2])

  v.batch_scatter_update(sparse_delta)
  sv.batch_scatter_update(sparse_delta)
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

  @def_function.function
  def func():
    v.batch_scatter_update(sparse_delta)
    sv.batch_scatter_update(sparse_delta)

  func()
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

def test_scatter_ops_even_partition(self, op):
  v = variables_lib.Variable(array_ops.zeros((30, 1)))
  sparse_delta = indexed_slices.IndexedSlices(
      values=constant_op.constant([[0.], [1.], [2.], [3.], [4.]]),
      indices=constant_op.constant([0, 10, 12, 21, 22]))

  v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
  v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
  v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
  sv = sharded_variable.ShardedVariable([v0, v1, v2])

  getattr(v, op)(sparse_delta, name='scatter_v')
  getattr(sv, op)(sparse_delta, name='scatter_sv')
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

  @def_function.function
  def func():
    getattr(v, op)(sparse_delta, name='scatter_v')
    getattr(sv, op)(sparse_delta, name='scatter_sv')

  func()
  self.assertAllEqual(v, ops.convert_to_tensor(sv))

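# Under the hood, a sharded scatter has to route each global row index to
# the shard that owns it and rebase the index by that shard's offset. A
# standalone sketch of that routing (a hypothetical helper, not the
# library's actual implementation):
import tensorflow as tf

def scatter_add_sharded(shards, indices, values):
  """Applies a scatter_add across shards, rebasing global row indices."""
  offset = 0
  for shard in shards:
    size = shard.shape[0]
    # Select the updates whose rows fall inside this shard.
    mask = (indices >= offset) & (indices < offset + size)
    shard.scatter_add(
        tf.IndexedSlices(
            values=tf.boolean_mask(values, mask),
            indices=tf.boolean_mask(indices, mask) - offset))
    offset += size
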
def _create_variable(self, next_creator, **kwargs):
  """Implements StrategyExtendedV2._create_variable.

  Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
  created if satisfying all the following criteria:
    1. `self._variable_partitioner` results in more than one partition on
       the first axis.
    2. variable's rank is greater than 0.
    3. variable is not colocated with another variable.
  Otherwise a `Variable` will be created.

  Args:
    next_creator: See `variable_scope.variable_creator_scope`; the next
      creator in the chain.
    **kwargs: Passed through to the next creator.

  Returns:
    A `Variable` or `ShardedVariable`.
  """
  var_creator = self._create_var_creator(next_creator, **kwargs)
  if "colocate_with" in kwargs:  # Never partition colocated_with variables.
    colocate_with = kwargs["colocate_with"]
    # Clear the variable scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.device(None):
      with ops.colocate_with(colocate_with):
        var = var_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) that colocates with %s",
            var.name, var.shape, kwargs["colocate_with"].name)
        return var

  if self._variable_partitioner is None:
    return self._create_variable_round_robin(var_creator, **kwargs)

  name = kwargs.get("name", None)
  dtype = kwargs.get("dtype", None)
  shape = kwargs.get("shape", None)
  initial_value = kwargs.get("initial_value", None)
  if initial_value is None:
    # If we are loading, next_creator will return an UninitializedVariable.
    v = next_creator(**kwargs)
    if not isinstance(v, resource_variable_ops.UninitializedVariable):
      raise ValueError(
          "It looks like you are using `ParameterServerStrategy` with a "
          "`variable_partitioner`, and trying to create a variable without "
          "specifying `initial_value`. This is not allowed. Please specify "
          "the `initial_value`.")
    elif shape is None or dtype is None:
      raise ValueError(
          "It looks like you are trying to load a `SavedModel` using "
          "`tf.saved_model.load` within a `ParameterServerStrategy` scope, "
          "but the `SavedModel` is missing shape or dtype information.")
    else:

      def initializer(shape, dtype, **kwargs):
        if "partition_shape" in kwargs:
          shape = kwargs["partition_shape"]
        return array_ops.zeros(shape, dtype)

      initial_value = functools.partial(initializer, shape=shape, dtype=dtype)

  # Two cases where initial_value can be a callable:
  #   1. initial_value is passed as a callable, e.g, an `initializer` class.
  #   2. restoring from checkpoint, initial_value is a
  #      "CheckpointInitialValueCallable".
  init_from_fn = callable(initial_value)

  if init_from_fn and (shape is None or dtype is None):
    init_from_fn = False
    initial_value = initial_value()
  if not init_from_fn:
    # The initial_value is created on the coordinator; it will need to be
    # sent to ps for variable initialization, which can be inefficient and
    # can potentially hit the 2GB limit on protobuf serialization.
    initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
    dtype = initial_value.dtype
    shape = initial_value.shape
  else:
    shape = tensor_shape.as_shape(shape)

  if shape.rank == 0:  # Skip partitioning rank-0 variable.
    return self._create_variable_round_robin(var_creator, **kwargs)

  num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
  if not num_partitions or num_partitions[0] == 0 or any(
      v != 1 for v in num_partitions[1:]):
    raise ValueError(
        "variable_partitioner must return a list/tuple whose elements are 1"
        " besides the first element (non-zero), got: %r" % num_partitions)

  if num_partitions[0] == 1:  # no partition
    return self._create_variable_round_robin(var_creator, **kwargs)

  # Use "div" partition strategy to partition the variable.
  num_partitions = min(num_partitions[0], shape[0])
  base = shape[0] // num_partitions
  extra = shape[0] % num_partitions
  # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
  # offsets: [0, 3, 6, 8, 10]
  offsets = []
  for i in range(num_partitions):
    if i == 0:
      offsets.append(0)
    else:
      prev_shard_size = base + (1 if i - 1 < extra else 0)
      offsets.append(offsets[i - 1] + prev_shard_size)
  offsets.append(shape[0])

  def init_shard_fn(shard_index):
    if not init_from_fn:
      logging.log_if(
          logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
          shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
      return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    partition_shape = (offsets[shard_index + 1] -
                       offsets[shard_index],) + shape[1:]
    partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
    arg_spec = tf_inspect.getfullargspec(initial_value)
    if ("shard_info" not in arg_spec.args and
        "shard_info" not in arg_spec.kwonlyargs):
      try:
        value = initial_value(
            partition_shape=partition_shape,
            partition_offset=partition_offset)
      except (TypeError, ValueError):
        # TypeError: Initializer doesn't accept kwargs
        # ValueError: Initializer doesn't accept partition kwargs
        # In both cases we go ahead creating the full value and then slice.
        value = initial_value()

      if value.shape == partition_shape:
        # Initializer supports partition: value is the partition value.
        return value
      else:
        # Initializer doesn't support partition: value is the full value
        # and needs to be sliced to get the partition value.
        logging.log_if(
            logging.WARN, _INEFFICIENT_INIT_WARNING % name,
            shard_index == 0 and
            shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
        return value[offsets[shard_index]:offsets[shard_index + 1]]
    else:
      # For compatibility with `CheckpointInitialValueCallable`.
      return initial_value(
          shard_info=trackable.ShardInfo(
              shape=tensor_shape.as_shape(partition_shape),
              offset=partition_offset))

  var_list = []
  for i in range(num_partitions):
    kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
    # The lambda is invoked by the creator within this loop iteration, so
    # capturing the loop variable `i` is safe here.
    kwargs["initial_value"] = lambda: init_shard_fn(i)
    if name is not None:
      kwargs["name"] = "{}/part_{}".format(name, i)
    var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

  return sharded_variable.ShardedVariable(var_list)

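# For context, a minimal end-to-end usage sketch of the partitioned creation
# path above via the public API (assumes TF >= 2.6 and a pre-configured
# `cluster_resolver`; not part of the code in this file):
import tensorflow as tf

strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=(
        tf.distribute.experimental.partitioners.FixedShardsPartitioner(
            num_shards=2)))
with strategy.scope():
  # Passing `shape`/`dtype` alongside a callable `initial_value` lets each
  # shard be initialized on its parameter server instead of being computed
  # on the coordinator and serialized over.
  v = tf.Variable(
      initial_value=lambda: tf.random.uniform((6, 4)),
      shape=(6, 4),
      dtype=tf.float32)
# `v` is a ShardedVariable with two (3, 4) shards, placed round-robin on ps.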