def init_shard_fn(shard_index):
  if not init_from_fn:
    logging.log_if(
        logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
        shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
    return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
  arg_spec = tf_inspect.getfullargspec(initial_value)
  if ("shard_info" not in arg_spec.args and
      "shard_info" not in arg_spec.kwonlyargs):
    # `initial_value` is a callable that doesn't accept `shard_info`, so the
    # full value must be built and then sliced per shard.
    logging.log_if(
        logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
        shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
    full_value = initial_value()
    return full_value[offsets[shard_index]:offsets[shard_index + 1]]
  else:
    # Memory-efficient way of initializing a sharded variable. It requires
    # the initializer callable to accept a `shard_info` namedtuple.
    component_shape = (offsets[shard_index + 1] -
                       offsets[shard_index],) + shape[1:]
    offsets_all_axes = (offsets[shard_index],) + (0,) * len(shape[1:])
    return initial_value(
        shard_info=trackable.ShardInfo(
            shape=tensor_shape.as_shape(component_shape),
            offset=offsets_all_axes))
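
Note that `init_shard_fn` closes over `initial_value`, `offsets`, `shape`, and the warning constants from the enclosing variable-creation routine, so the snippet is not runnable on its own. Below is a minimal, self-contained sketch of the memory-efficient `shard_info` path, using NumPy and a hypothetical `ShardInfo` namedtuple as a stand-in for the real `trackable.ShardInfo`:

import collections
import numpy as np

# Hypothetical stand-in for trackable.ShardInfo: carries the shard's shape
# and its offset into the full value.
ShardInfo = collections.namedtuple("ShardInfo", ["shape", "offset"])

def sharded_ones(shard_info=None):
  # An initializer that accepts `shard_info` builds only the requested
  # shard; it never materializes the full (1_000_000, 4) value.
  full_shape = (1_000_000, 4)
  if shard_info is None:
    return np.ones(full_shape)       # slow path: full value, sliced later
  return np.ones(shard_info.shape)   # fast path: one shard only

# Offsets partition axis 0 into two shards of 500_000 rows each.
offsets = [0, 500_000, 1_000_000]
shard_0 = sharded_ones(
    shard_info=ShardInfo(shape=(offsets[1] - offsets[0], 4),
                         offset=(offsets[0], 0)))
print(shard_0.shape)  # (500000, 4)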
Example #2
def init_shard_fn(shard_index):
    if not init_from_fn:
        logging.log_if(
            logging.WARN, _INEFFICIENT_INIT_WARNING % name,
            shard_index == 0
            and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
    partition_shape = (offsets[shard_index + 1] -
                       offsets[shard_index],) + shape[1:]
    partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
    arg_spec = tf_inspect.getfullargspec(initial_value)
    if ("shard_info" not in arg_spec.args
            and "shard_info" not in arg_spec.kwonlyargs):
        try:
            value = initial_value(partition_shape=partition_shape,
                                  partition_offset=partition_offset)
        except (TypeError, ValueError):
            # TypeError: the initializer doesn't accept kwargs.
            # ValueError: the initializer doesn't accept the partition kwargs.
            # In both cases, create the full value and slice it below.
            value = initial_value()

        if value.shape == partition_shape:
            # The initializer supports partitioning: `value` is already the
            # partition value.
            return value
        else:
            # The initializer doesn't support partitioning: `value` is the
            # full value and must be sliced to get the partition value.
            logging.log_if(
                logging.WARN, _INEFFICIENT_INIT_WARNING % name,
                shard_index == 0 and
                shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
            return value[offsets[shard_index]:offsets[shard_index + 1]]
    else:
        # For compatibility with `CheckpointInitialValueCallable`.
        return initial_value(shard_info=trackable.ShardInfo(
            shape=tensor_shape.as_shape(partition_shape),
            offset=partition_offset))
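
The try/except probe above is what lets partition-aware initializers skip building the full value. A self-contained sketch of the same dispatch, with two toy NumPy initializers (both hypothetical, standing in for real TensorFlow initializers):

import numpy as np

def partition_aware_init(partition_shape, partition_offset):
    # Accepts the partition kwargs and builds only the requested slice.
    return np.zeros(partition_shape)

def legacy_init():
    # Rejects the partition kwargs; always builds the full (8, 3) value.
    return np.zeros((8, 3))

def init_partition(initializer, offsets, i, trailing_dims=(3,)):
    partition_shape = (offsets[i + 1] - offsets[i],) + trailing_dims
    partition_offset = (offsets[i],) + (0,) * len(trailing_dims)
    try:
        value = initializer(partition_shape=partition_shape,
                            partition_offset=partition_offset)
    except TypeError:
        # Initializer doesn't accept the kwargs: fall back to the full value.
        value = initializer()
    if value.shape == partition_shape:
        return value  # initializer produced the partition directly
    return value[offsets[i]:offsets[i + 1]]  # slice the full value

offsets = [0, 4, 8]
print(init_partition(partition_aware_init, offsets, 0).shape)  # (4, 3)
print(init_partition(legacy_init, offsets, 1).shape)           # (4, 3)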