Example 1
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized to a 2D identity matrix.

        Args:
          shape: Shape of the tensor. It should have exactly rank 2.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise
            (via `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__,
                         kwargs,
                         support_partition=False)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        if len(shape) != 2:
            raise ValueError(
                "Identity matrix initializer can only be used for 2D matrices. "
                f"Received: shape={shape} of rank {len(shape)}.")
        layout = kwargs.pop("layout", None)
        if layout:
            return utils.call_with_layout(self._generate_init_val,
                                          layout,
                                          shape=shape,
                                          dtype=dtype)
        return self._generate_init_val(shape, dtype)
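This snippet appears to come from the Keras `Identity` initializer. Assuming that, a minimal usage sketch of the public API (the `layout` path needs a DTensor mesh and is omitted here):

    import tensorflow as tf

    # Identity initializers only accept rank-2 shapes; anything else raises
    # the ValueError shown above.
    init = tf.keras.initializers.Identity(gain=1.0)
    values = init(shape=(3, 3))   # 3x3 identity matrix, dtype float32
    # init(shape=(3,)) would raise ValueError: the shape must have rank 2.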
Example 2
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise (via
            `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        partition_offset = kwargs.get(_PARTITION_OFFSET, None)
        if partition_offset is None:
            # We skip the reuse warning for partitioned variables, since the
            # same initializer will be called multiple times, once per partition.
            self._warn_reuse()
        nonce = hash(partition_offset) if partition_offset else None
        layout = kwargs.pop("layout", None)
        if layout:
            _ensure_keras_seeded()
            return utils.call_with_layout(
                self._generate_init_val,
                layout,
                shape=shape,
                dtype=dtype,
                nonce=nonce,
            )
        return self._generate_init_val(shape=shape, dtype=dtype, nonce=nonce)
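The `_PARTITION_SHAPE` / `_PARTITION_OFFSET` handling above is what lets sharded (partitioned) variables reuse a single initializer, one call per shard. A hedged sketch of that calling convention, assuming the private constants map to the `partition_shape` and `partition_offset` keyword names and using `tf.keras.initializers.RandomNormal` as a stand-in for whichever class this snippet belongs to:

    import tensorflow as tf

    init = tf.keras.initializers.RandomNormal(seed=42)
    # Hypothetical partitioned call: initialize only a (4, 4) shard of an
    # (8, 4) variable, starting at row 4.
    shard = init(shape=(8, 4), partition_shape=(4, 4), partition_offset=(4, 0))
    print(shard.shape)  # (4, 4)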
Example 3
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized to an orthogonal matrix.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise
            (via `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__,
                         kwargs,
                         support_partition=False)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        # Check the shape
        if len(shape) < 2:
            raise ValueError("The tensor to initialize must be "
                             "at least two-dimensional. Received: "
                             f"shape={shape} of rank {len(shape)}.")
        self._warn_reuse()
        layout = kwargs.pop("layout", None)
        if layout:
            _ensure_keras_seeded()
            return utils.call_with_layout(self._generate_init_val,
                                          layout,
                                          shape=shape,
                                          dtype=dtype)
        return self._generate_init_val(shape, dtype)
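This one matches the Keras `Orthogonal` initializer. Assuming so, a minimal usage sketch without a layout:

    import tensorflow as tf

    init = tf.keras.initializers.Orthogonal(gain=1.0, seed=0)
    w = init(shape=(4, 4))
    # Columns are orthonormal: w^T w is (approximately) the identity matrix.
    print(tf.linalg.matmul(w, w, transpose_a=True))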
Example 4
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
            are supported. If not specified, `tf.keras.backend.floatx()` is
            used, which defaults to `float32` unless you configured it otherwise
            (via `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _get_dtype(dtype)
        if not dtype.is_numpy_compatible or dtype == tf.string:
            raise ValueError(
                f"Expected numeric or boolean dtype, got {dtype}.")
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        layout = kwargs.pop("layout", None)
        if layout:
            return utils.call_with_layout(tf.ones,
                                          layout,
                                          shape=shape,
                                          dtype=dtype)
        return tf.ones(shape, dtype)
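The plain `tf.ones` call marks this as the `Ones` initializer which, unlike the float-only initializers above, accepts any numeric or boolean dtype. A small sketch:

    import tensorflow as tf

    init = tf.keras.initializers.Ones()
    print(init(shape=(2, 3), dtype=tf.int32))   # integer dtypes are fine
    # init(shape=(2,), dtype=tf.string) would raise the ValueError shown above.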
Example 5
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point and integer
            types are supported. If not specified, `tf.keras.backend.floatx()`
            is used, which defaults to `float32` unless you configured it
            otherwise (via `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _get_dtype(dtype)
        if not dtype.is_floating and not dtype.is_integer:
            raise ValueError(f'Expected float or integer dtype, got {dtype}.')
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        layout = kwargs.pop('layout', None)
        if layout:
            self._random_generator._force_generator = True
            _ensure_keras_seeded()
            return utils.call_with_layout(
                self._random_generator.random_uniform, layout, shape,
                self.minval, self.maxval, dtype)
        return self._random_generator.random_uniform(shape, self.minval,
                                                     self.maxval, dtype)
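The `random_uniform` call with `self.minval` / `self.maxval` points at the `RandomUniform` initializer; assuming that, a minimal sketch of the public call path (no DTensor layout):

    import tensorflow as tf

    # Float usage: defaults to floatx(), normally float32.
    init = tf.keras.initializers.RandomUniform(minval=-0.05, maxval=0.05, seed=1)
    w = init(shape=(2, 2))

    # Integer dtypes are also allowed by this initializer.
    int_init = tf.keras.initializers.RandomUniform(minval=0, maxval=10, seed=1)
    idx = int_init(shape=(3,), dtype=tf.int32)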
Example 6
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized as specified by the initializer.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise (via
            `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        partition_offset = kwargs.get(_PARTITION_OFFSET, None)
        nonce = hash(partition_offset) if partition_offset else None
        layout = kwargs.pop('layout', None)
        if layout:
            self._random_generator._rng_type = self._random_generator.RNG_STATEFUL
            _ensure_keras_seeded()
            return utils.call_with_layout(self._generate_init_val,
                                          layout,
                                          shape=shape,
                                          dtype=dtype,
                                          nonce=nonce)
        return self._generate_init_val(shape=shape, dtype=dtype, nonce=nonce)
Example 7
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized to `self.value`.

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. If not specified,
            `tf.keras.backend.floatx()` is used, which defaults to `float32`
            unless you configured it otherwise
            (via `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        layout = kwargs.pop('layout', None)
        if layout:
            return utils.call_with_layout(tf.constant,
                                          layout,
                                          self.value,
                                          shape=shape,
                                          dtype=dtype)
        return tf.constant(self.value, dtype=_get_dtype(dtype), shape=shape)
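This snippet matches the `Constant` initializer, which simply broadcasts `self.value` to the requested shape and dtype. A small sketch:

    import tensorflow as tf

    init = tf.keras.initializers.Constant(value=3)
    print(init(shape=(2, 2)))                 # [[3., 3.], [3., 3.]]
    print(init(shape=(2,), dtype=tf.int32))   # [3, 3]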
Example 8
def _create_dvariable(layout_map, object_path, variable):
  """Create a new variable instead of using the LazyInitVariable.

  We choose to do this because, even though the LazyInitVariable might behave
  like a normal tf.Variable/DVariable, it is not future proof against new
  changes to the variable class. It would also fail instance type checks in
  Python, which could affect users' code when they filter variables by type.

  Args:
    layout_map: a LayoutMap which contains the variable_object_path (string) ->
      Layout.
    object_path: string, the object attribute path for the variable.
    variable: LazyInitVariable which will be replaced by the newly created
      tf.Variable.
  Returns:
    A new tf.Variable with correct layout information.
  """
  # TODO(b/228209108): Revisit this in future and see if we can just reuse the
  # LazyInitVariable rather than creating a new tf.Variable instance.
  layout = layout_map[object_path]
  if layout is None:
    variable_rank = variable.shape.rank
    layout = dtensor.Layout.replicated(
        mesh=layout_map.get_default_mesh(),
        rank=variable_rank)
  init_val = variable._initial_value  # pylint: disable=protected-access
  if callable(init_val):
    with lazy_variable.disable_init_variable_creator():
      init_val = utils.call_with_layout(init_val, layout)
  else:
    # The init value is probably already created as a tensor; just copy it to
    # the mesh and give it a proper layout.
    init_val = dtensor.copy_to_mesh(init_val, layout)
  # Use the original variable name for the new DVariable creation. TF appends a
  # ":0" suffix to the name, which we strip off below.
  variable_name = variable.name
  if variable_name.endswith(':0'):
    variable_name = variable_name[:-2]
  new_variable = dtensor.DVariable(init_val,
                                   trainable=variable.trainable,
                                   name=variable_name)
  return new_variable
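The name handling at the end mirrors how TensorFlow names variables: `tf.Variable` appends a `:0` suffix, which the new DVariable should not carry. A small standalone check of that behavior:

    import tensorflow as tf

    v = tf.Variable(1.0, name="kernel")
    print(v.name)                 # "kernel:0"
    clean = v.name[:-2] if v.name.endswith(":0") else v.name
    print(clean)                  # "kernel"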
Example 9
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized to random normal values (truncated).

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise (via
            `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        partition_offset = kwargs.get(_PARTITION_OFFSET, None)
        if partition_offset is None:
            # We skip the reuse warning for partitioned variables, since the
            # same initializer will be called multiple times, once per partition.
            self._warn_reuse()
        nonce = hash(partition_offset) if partition_offset else None
        layout = kwargs.pop("layout", None)
        if layout:
            # TODO(scottzhu): Remove this once the forward compat period above
            # is expired.
            self._random_generator._rng_type = (
                self._random_generator.RNG_STATEFUL)
            _ensure_keras_seeded()
            return utils.call_with_layout(
                self._random_generator.truncated_normal,
                layout,
                shape,
                self.mean,
                self.stddev,
                dtype,
                nonce,
            )
        return self._random_generator.truncated_normal(shape, self.mean,
                                                       self.stddev, dtype,
                                                       nonce)
Example 10
    def __call__(self, shape, dtype=None, **kwargs):
        """Returns a tensor object initialized to random normal values (truncated).

        Args:
          shape: Shape of the tensor.
          dtype: Optional dtype of the tensor. Only floating point types are
            supported. If not specified, `tf.keras.backend.floatx()` is used,
            which defaults to `float32` unless you configured it otherwise (via
            `tf.keras.backend.set_floatx(float_dtype)`).
          **kwargs: Additional keyword arguments.
        """
        _validate_kwargs(self.__class__.__name__, kwargs)
        dtype = _assert_float_dtype(_get_dtype(dtype))
        if _PARTITION_SHAPE in kwargs:
            shape = kwargs[_PARTITION_SHAPE]
        layout = kwargs.pop('layout', None)
        if layout:
            self._random_generator._force_generator = True
            _ensure_keras_seeded()
            return utils.call_with_layout(
                self._random_generator.truncated_normal, layout, shape,
                self.mean, self.stddev, dtype)
        return self._random_generator.truncated_normal(shape, self.mean,
                                                       self.stddev, dtype)
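Both truncated-normal variants above reduce to the same public behavior; a minimal sketch using the exported initializer:

    import tensorflow as tf

    init = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=7)
    w = init(shape=(3, 3))
    # Values more than two standard deviations from the mean are re-drawn,
    # so everything lies in roughly [-0.1, 0.1].
    print(w)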
Example 11
    def __call__(self, shape, dtype=None, **kwargs):
        # DTensor-aware override: pop the optional `layout` kwarg, make sure
        # the Keras global RNG is seeded before generating values on a mesh,
        # then delegate to the base TruncatedNormal initializer.
        layout = kwargs.pop('layout', None)
        if layout:
            _ensure_keras_seeded()
        fn = super(TruncatedNormal, self).__call__
        return utils.call_with_layout(fn, layout, shape=shape, dtype=dtype)
Example 12
    def __call__(self, shape, dtype=None, **kwargs):
        # DTensor-aware override: pop the optional `layout` kwarg and delegate
        # to the base Constant initializer via `call_with_layout`.
        layout = kwargs.pop('layout', None)
        fn = super(Constant, self).__call__
        return utils.call_with_layout(fn, layout, shape=shape, dtype=dtype)