def __init__(
      self,
      key_dtype=dtypes.int64,
      value_dtype=dtypes.float32,
      dim=1,
      devices=None,
      partitioner=default_partition_fn,
      shared_name=None,
      name="DynamicEmbedding_Variable",
      initializer=None,
      trainable=True,
      checkpoint=True,
      init_size=0,
      kv_creator=None,
      restrict_policy=None,
      bp_v2=False,
  ):
    """Creates an empty `Variable` object.

        Creates a group of tables placed on devices specified by `devices`,
        and the device placement mechanism of TensorFlow will be ignored,
        the type of its keys and values are specified by key_dtype
        and value_dtype, respectively.
        The environment variables 'TF_HASHTABLE_INIT_SIZE' can be used to set the
        inital size of each tables, which can help reduce rehash times.
        The default initial table size is 8,192

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          dim: the length of the value array for each key.
            On GPUs, `dim` should be less than or equal to 200.
          devices: the list of devices holding the tables.
            One table will be created on each device. By default, `devices` is
            ['/GPU:0'] when a GPU is available, otherwise ['/CPU:0'].
          partitioner: partition function for keys;
            returns the partition index for each key.

          Example partition function:
          ```python
          def default_partition_fn(keys, shard_num):
            return tf.cast(keys % shard_num, dtype=tf.int32)
          ```
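
          A custom `partitioner` only needs to map each key to a shard index
          in `[0, shard_num)`. As an illustrative sketch for integer keys
          (hypothetical, not part of the library), a hashing variant could be:
          ```python
          import tensorflow as tf

          def hash_partition_fn(keys, shard_num):
            # Mix the high and low bits of each key before taking the
            # modulus, so clustered key ranges spread across shards.
            mixed = tf.bitwise.bitwise_xor(keys,
                                           tf.bitwise.right_shift(keys, 16))
            return tf.cast(mixed % shard_num, dtype=tf.int32)
          ```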
          shared_name: Not used.
          name: A name for the operation (optional).
          initializer: The value to use if a key is missing in the hash table,
            which can be a Python number, a numpy array, or a `tf.initializer`
            instance. If initializer is `None` (the default), `0` will be used.
          trainable: Bool. If True, the variable will be treated as trainable.
            Default is True.
          checkpoint: if True, the contents of the SparseVariable are
            saved to and restored from checkpoints.
            If `shared_name` is empty for a checkpointed table,
            it is shared using the table node name.
          init_size: initial size of the Variable; the initial size of each
            hash table will be int(init_size / N), where N is the number of
            devices.
          restrict_policy: a restrict policy specifying the rule used to
            restrict the size of the variable. If the variable is updated by an
            optimizer during training, the sparse slot variables in the
            optimizer are restricted as well. See the sketch below.
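
          As a sketch, using `TimestampRestrictPolicy` (one of the policies
          shipped with the library, assuming it is available in your version):
          ```python
          import tensorflow_recommenders_addons as tfra

          # The policy class (not an instance) is passed; the variable
          # instantiates it internally.
          var = tfra.dynamic_embedding.get_variable(
              name="emb",
              dim=8,
              restrict_policy=tfra.dynamic_embedding.TimestampRestrictPolicy)
          ```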
          bp_v2: By default, with `bp_v2=False`, the optimizer updates
            dynamic embedding values by *setting* (key, value) after
            `optimizer.apply_gradients`. If one key is used by multiple workers
            at the same time, only one update is kept and the others are
            overwritten. With `bp_v2=True`, the optimizer updates
            parameters by *adding delta* instead of *setting*, which avoids the
            race condition among workers during backpropagation in
            large-scale distributed asynchronous training.
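
          A minimal sketch of enabling it (assuming `get_variable` forwards
          `bp_v2` to this constructor):
          ```python
          import tensorflow as tf
          import tensorflow_recommenders_addons as tfra

          # With bp_v2=True the optimizer applies deltas, so concurrent
          # asynchronous workers no longer overwrite each other's updates.
          emb = tfra.dynamic_embedding.get_variable(
              name="emb_bpv2",
              key_dtype=tf.int64,
              value_dtype=tf.float32,
              dim=8,
              bp_v2=True)
          ```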

        Returns:
          A `Variable` object.
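
        Example (a minimal usage sketch; names are illustrative):
        ```python
        import tensorflow as tf
        import tensorflow_recommenders_addons as tfra

        var = tfra.dynamic_embedding.Variable(
            key_dtype=tf.int64,
            value_dtype=tf.float32,
            dim=4,
            name="user_embedding")
        ids = tf.constant([1, 2, 3], dtype=tf.int64)
        vals = var.lookup(ids)  # missing keys fall back to `initializer`
        ```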
    """
    self.key_dtype = key_dtype
    self.value_dtype = value_dtype
    self.dim = dim
    self.bp_v2 = bp_v2

    # Prefer the first visible GPU; fall back to CPU when none is present.
    def _get_default_devices():
      gpu_list = [
          x.name
          for x in device_lib.list_local_devices()
          if x.device_type == "GPU"
      ]
      return gpu_list[0:1] or [
          "/CPU:0",
      ]

    devices_ = devices or _get_default_devices()
    self.devices = (devices_ if isinstance(devices_, list) else [
        devices_,
    ])
    self.partition_fn = partitioner
    self.name = name
    self.shared_name = shared_name or "shared_name.{}".format(name)

    self.initializer = None

    self.trainable = trainable
    self.checkpoint = checkpoint

    self._tables = data_structures.ListWrapper([])
    self._track_trackable(self._tables,
                          'tables_of_{}'.format(self.name),
                          overwrite=True)
    self.size_ops = []
    self._trainable_store = {}
    self.kv_creator = kv_creator if kv_creator else de.CuckooHashTableCreator()

    self.shard_num = len(self.devices)

    self.init_size = int(init_size)

    if restrict_policy is not None:
      if not issubclass(restrict_policy, de.RestrictPolicy):
        raise TypeError('restrict_policy must be subclass of RestrictPolicy.')
      self._restrict_policy = restrict_policy(self)
    else:
      self._restrict_policy = None

    # Supported (key_dtype, value_dtype) combinations; the GPU backend
    # narrows this list further below.
    valid_dtype_list = [
        [dtypes.int64, dtypes.float32],
        [dtypes.int64, dtypes.half],
        [dtypes.int64, dtypes.int32],
        [dtypes.int64, dtypes.int8],
        [dtypes.int64, dtypes.int64],
        [dtypes.int64, dtypes.float64],
        [dtypes.int64, dtypes.string],
        [dtypes.int32, dtypes.float32],
        [dtypes.int32, dtypes.int32],
        [dtypes.int32, dtypes.float64],
        [dtypes.string, dtypes.float32],
        [dtypes.string, dtypes.half],
        [dtypes.string, dtypes.int32],
        [dtypes.string, dtypes.int8],
        [dtypes.string, dtypes.int64],
        [dtypes.string, dtypes.float64],
        [dtypes.string, dtypes.bool],
    ]
    if "GPU" in self.devices[0].upper():
      valid_dtype_list = [
          [dtypes.int64, dtypes.float32],
          [dtypes.int64, dtypes.half],
          [dtypes.int64, dtypes.int32],
          [dtypes.int64, dtypes.int8],
          [dtypes.int64, dtypes.int64],
          [dtypes.int32, dtypes.float32],
      ]
    if is_macos() and is_arm64():
      if value_dtype == dtypes.half:
        raise TypeError("""
          float16 value dtype is not supported on macOS with ARM64 architecture. Please try another type.
          """)
    if [key_dtype, value_dtype] not in valid_dtype_list:
      raise TypeError(
          "key-value dtype ({}-{}) is not supported! The valid dtypes are \n{}\n"
          .format(key_dtype, value_dtype, valid_dtype_list))

    _initializer = initializer
    if _initializer is None:
      _initializer = init_ops.zeros_initializer(dtype=self.value_dtype)
    static_default_value = self._convert_anything_to_init(_initializer, dim)
    scope_name = self.name.split("/")[-1]
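    # Create one table shard per device, pinning each table explicitly and
    # ignoring any inherited colocation constraints.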
    with ops.name_scope(scope_name, "DynamicEmbedding_Variable"):
      with ops.colocate_with(None, ignore_existing=True):
        for idx in range(len(self.devices)):
          with ops.device(self.devices[idx]):
            if not issubclass(self.kv_creator.__class__, de.KVCreator):
              raise TypeError(
                  "kv_creator should be an instance of de.KVCreator, but got "
                  + str(type(self.kv_creator)))
            mht = self.kv_creator.create(
                key_dtype=self.key_dtype,
                value_dtype=self.value_dtype,
                default_value=static_default_value,
                name=self._make_name(idx),
                checkpoint=self.checkpoint,
                init_size=int(self.init_size / self.shard_num),
            )
            self._tables.append(mht)