예제 #1
0
 def init_shard_fn(shard_index):
     """Returns the initial value for the shard at position `shard_index`.

     The shard spans `offsets[shard_index]` (inclusive) to
     `offsets[shard_index + 1]` (exclusive) along axis 0, and the full extent
     of every remaining axis of `shape`.

     How the value is produced:

     * `init_from_fn` is false (presumably `initial_value` is a concrete
       value rather than a callable -- confirm with the enclosing code):
       `initial_value` itself is sliced.
     * `initial_value` is a callable without a `shard_info` parameter: it is
       called to build the full value, which is then sliced.
     * `initial_value` accepts `shard_info` (positional or keyword-only): it
       is called with a `trackable.ShardInfo` describing only this shard, so
       the full value is never built.

     The first two paths call `logging.log_if` with
     `_INEFFICIENT_INIT_WARNING`. The condition is true only for shard 0 when
     `shape` has more than `_LARGE_VARIABLE_NUM_ELEMENTS` elements.

     Args:
       shard_index: Index of the shard; `offsets[shard_index]` and
         `offsets[shard_index + 1]` bound it along axis 0.

     Returns:
       The initial value of this shard.
     """

     def _warn_if_large_first_shard():
         # Only shard 0 of a large variable triggers the warning, so a single
         # variable does not produce one message per shard.
         logging.log_if(
             logging.WARNING, _INEFFICIENT_INIT_WARNING % name,
             shard_index == 0
             and shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)

     if init_from_fn:
         fn_spec = tf_inspect.getfullargspec(initial_value)
         takes_shard_info = ("shard_info" in fn_spec.args
                             or "shard_info" in fn_spec.kwonlyargs)
         if takes_shard_info:
             # Memory-efficient path: the callable builds just this shard.
             shard_start = offsets[shard_index]
             shard_shape = (offsets[shard_index + 1] - shard_start, ) + shape[1:]
             shard_offset = (shard_start, ) + (0, ) * len(shape[1:])
             return initial_value(shard_info=trackable.ShardInfo(
                 shape=tensor_shape.as_shape(shard_shape),
                 offset=shard_offset))
         # The callable can only produce the whole value; build it, then slice.
         _warn_if_large_first_shard()
         full_value = initial_value()
     else:
         _warn_if_large_first_shard()
         full_value = initial_value
     return full_value[offsets[shard_index]:offsets[shard_index + 1]]
예제 #2
0
def warning_once(msg, *args):
    """Logs a warning only the first time it is reached from a call site.

    Each call site of ``warning_once`` is identified by the value returned
    from the absl logger's ``findCaller()`` and keeps its own counter.
    ``warning_once`` calls placed at different locations in the code each
    emit their message once, independently of one another.

    The implementation resembles the ``log_every_n()`` function in absl's
    ``logging`` module. It calls ``findCaller()`` directly instead of through
    an extra helper frame, which keeps the call stack one frame shorter so
    that the caller of ``warning_once`` is the one recorded.

    Args:
        msg: str, the message to be logged.
        *args: The args to be substituted into the msg.
    """
    caller = logging.get_absl_logger().findCaller()
    # NOTE: `_get_next_log_count_per_token` is a private absl helper (the one
    # `log_every_n` uses); it may change across absl releases.
    count = logging._get_next_log_count_per_token(caller)
    # The first call from a given call site receives count 0; log only then.
    logging.log_if(logging.WARNING, msg, count == 0, *args)