Beispiel #1
0
    def __init__(self,
                 key_dtype,
                 value_dtype,
                 default_value,
                 empty_key,
                 initial_num_buckets=None,
                 shared_name=None,
                 name="MutableDenseHashTable",
                 checkpoint=True):
        """Creates an empty `MutableDenseHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively. The value shape of the table
        is the static shape of `default_value`.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table. It
            is converted to a tensor of `value_dtype`; its static shape becomes
            the value shape of the table.
          empty_key: the key to use to represent empty buckets internally. Must
            not be used in insert or lookup operations. It is converted to a
            tensor of `key_dtype`.
          initial_num_buckets: the initial number of buckets.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions.
          name: A name for the operation (optional). Also used as the name of
            the checkpoint saveable when `checkpoint` is True.
          checkpoint: if True, the contents of the table are saved to and
            restored from checkpoints. If `shared_name` is None (not given) for
            a checkpointed table, it is shared using the table node name. Note
            that an explicit empty string does not enable node-name sharing.
        """
        self._default_value = ops.convert_to_tensor(default_value,
                                                    dtype=value_dtype)
        # Passed to the op below as the shape of every stored value.
        self._value_shape = self._default_value.get_shape()

        # The table must be shared if checkpointing is requested for multi-worker
        # training to work correctly. Use the node name if no shared_name has been
        # explicitly specified.
        use_node_name_sharing = checkpoint and shared_name is None
        empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
        # pylint: disable=protected-access
        self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
            empty_key=empty_key,
            shared_name=shared_name,
            use_node_name_sharing=use_node_name_sharing,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            initial_num_buckets=initial_num_buckets,
            name=name)
        # pylint: enable=protected-access
        # The base class receives the op name with any name-scope prefix
        # stripped (only the last "/"-separated component).
        super(MutableDenseHashTable,
              self).__init__(key_dtype, value_dtype,
                             self._table_ref.op.name.split("/")[-1])

        if checkpoint:
            # Register a saveable in the graph's SAVEABLE_OBJECTS collection so
            # the table contents take part in checkpoint save/restore.
            saveable = MutableDenseHashTable._Saveable(self, name)
            ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
Beispiel #2
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               initial_num_buckets=None,
               shared_name=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `MutableDenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively. The value shape of the table is the static
    shape of `default_value`.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`; its static shape becomes the
        value shape of the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert or lookup operations. It is converted to a tensor of
        `key_dtype`.
      initial_num_buckets: the initial number of buckets.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional). Also used as the name of the
        checkpoint saveable when `checkpoint` is True.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is None (not given) for a
        checkpointed table, it is shared using the table node name. Note that
        an explicit empty string does not enable node-name sharing.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype)
    # Passed to the op below as the shape of every stored value.
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
    # pylint: disable=protected-access
    self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
        empty_key=empty_key,
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=initial_num_buckets,
        name=name)
    # pylint: enable=protected-access
    # The base class receives the op name with any name-scope prefix stripped
    # (only the last "/"-separated component).
    super(MutableDenseHashTable, self).__init__(
        key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

    if checkpoint:
      # Register a saveable in the graph's SAVEABLE_OBJECTS collection so the
      # table contents take part in checkpoint save/restore.
      saveable = MutableDenseHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
Beispiel #3
0
    def __init__(
        self,
        key_dtype,
        value_dtype,
        default_value,
        empty_key,
        initial_num_buckets=None,
        shared_name=None,
        name="MutableDenseHashTable",
    ):
        """Creates an empty `MutableDenseHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively. The value shape of the table
        is the static shape of `default_value`.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table. It
            is converted to a tensor of `value_dtype`; its static shape becomes
            the value shape of the table.
          empty_key: the key to use to represent empty buckets internally. Must
            not be used in insert or lookup operations. It is converted to a
            tensor of `key_dtype`.
          initial_num_buckets: the initial number of buckets.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions.
          name: A name for the operation (optional).
        """
        self._default_value = ops.convert_to_tensor(default_value, dtype=value_dtype)
        # Passed to the op below as the shape of every stored value.
        self._value_shape = self._default_value.get_shape()

        empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
        # pylint: disable=protected-access
        self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
            empty_key=empty_key,
            shared_name=shared_name,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            initial_num_buckets=initial_num_buckets,
            name=name,
        )
        # pylint: enable=protected-access
        # The base class receives the op name with any name-scope prefix
        # stripped (only the last "/"-separated component).
        super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])
Beispiel #4
0
    def __init__(self,
                 key_dtype,
                 value_dtype,
                 default_value,
                 empty_key,
                 initial_num_buckets=None,
                 shared_name=None,
                 name="MutableDenseHashTable"):
        """Creates an empty `MutableDenseHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively. The value shape of the table
        is the static shape of `default_value`.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table. It
            is converted to a tensor of `value_dtype`; its static shape becomes
            the value shape of the table.
          empty_key: the key to use to represent empty buckets internally. Must
            not be used in insert or lookup operations. It is converted to a
            tensor of `key_dtype`.
          initial_num_buckets: the initial number of buckets.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions.
          name: A name for the operation (optional).
        """
        self._default_value = ops.convert_to_tensor(default_value,
                                                    dtype=value_dtype)
        # Passed to the op below as the shape of every stored value.
        self._value_shape = self._default_value.get_shape()

        empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
        # pylint: disable=protected-access
        self._table_ref = gen_data_flow_ops._mutable_dense_hash_table(
            empty_key=empty_key,
            shared_name=shared_name,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            initial_num_buckets=initial_num_buckets,
            name=name)
        # pylint: enable=protected-access
        # The base class receives the op name with any name-scope prefix
        # stripped (only the last "/"-separated component).
        super(MutableDenseHashTable,
              self).__init__(key_dtype, value_dtype,
                             self._table_ref.op.name.split("/")[-1])