Exemple #1
0
    def __init__(self,
                 key_dtype,
                 value_dtype,
                 default_value,
                 empty_key,
                 initial_num_buckets=None,
                 shared_name=None,
                 name="MutableDenseHashTable",
                 checkpoint=True):
        """Creates an empty `MutableDenseHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table.
            Its static shape is used as the table's `value_shape`.
          empty_key: the key to use to represent empty buckets internally. Must
            not be used in insert or lookup operations.
          initial_num_buckets: the initial number of buckets.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions.
          name: A name for the operation (optional). It is also passed to the
            checkpoint saveable when `checkpoint` is True.
          checkpoint: if True, the contents of the table are saved to and
            restored from checkpoints. If `shared_name` is None (the default)
            for a checkpointed table, it is shared using the table node name.

        Raises:
          TypeError, ValueError: may propagate from `ops.convert_to_tensor` if
            `default_value` or `empty_key` cannot be converted to the
            requested dtype.
        """
        self._default_value = ops.convert_to_tensor(default_value,
                                                    dtype=value_dtype)
        # Every stored value must have the same static shape as the default.
        self._value_shape = self._default_value.get_shape()

        # The table must be shared if checkpointing is requested for multi-worker
        # training to work correctly. Use the node name if no shared_name has been
        # explicitly specified.
        use_node_name_sharing = checkpoint and shared_name is None
        empty_key = ops.convert_to_tensor(empty_key, dtype=key_dtype)
        self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
            empty_key=empty_key,
            shared_name=shared_name,
            use_node_name_sharing=use_node_name_sharing,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            initial_num_buckets=initial_num_buckets,
            name=name)
        # Register with the base class under the last component of the op
        # name, i.e. without any enclosing name-scope prefix.
        super(MutableDenseHashTable,
              self).__init__(key_dtype, value_dtype,
                             self._table_ref.op.name.split("/")[-1])

        if checkpoint:
            saveable = MutableDenseHashTable._Saveable(self, name)
            ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
Exemple #2
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               initial_num_buckets=None,
               shared_name=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `MutableDenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. Its
        static shape is used as the table's `value_shape`.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert or lookup operations.
      initial_num_buckets: the initial number of buckets.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional). It is also passed to the
        checkpoint saveable when `checkpoint` is True.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is None (the default) for a
        checkpointed table, it is shared using the table node name.

    Raises:
      TypeError, ValueError: may propagate from `ops.convert_to_tensor` if
        `default_value` or `empty_key` cannot be converted to the requested
        dtype.
    """
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    # Every stored value must have the same static shape as the default.
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    empty_key = ops.convert_to_tensor(
        empty_key, dtype=key_dtype, name="empty_key")
    self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=empty_key,
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=initial_num_buckets,
        name=name)
    # Register with the base class under the last component of the op name,
    # i.e. without any enclosing name-scope prefix.
    super(MutableDenseHashTable, self).__init__(
        key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

    if checkpoint:
      saveable = MutableDenseHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
 def create_resource(self):
     """Builds the dense hash table op from the stored settings.

     Returns:
       The handle produced by `gen_lookup_ops.mutable_dense_hash_table_v2`.

     Side effect: sets `self._table_name` to None when executing eagerly,
     otherwise to the last "/"-separated component of the op's name.
     """
     # Checkpointed tables have to be shared for multi-worker training to
     # behave; without an explicit shared_name, fall back to node-name sharing.
     share_by_node_name = self._checkpoint and self._shared_name is None
     op_kwargs = dict(
         empty_key=self._empty_key,
         deleted_key=self._deleted_key,
         shared_name=self._shared_name,
         use_node_name_sharing=share_by_node_name,
         value_dtype=self._value_dtype,
         value_shape=self._value_shape,
         initial_num_buckets=self._initial_num_buckets,
         name=self._name,
     )
     handle = gen_lookup_ops.mutable_dense_hash_table_v2(**op_kwargs)
     self._table_name = (
         None if context.executing_eagerly()
         else handle.op.name.split("/")[-1])
     return handle
Exemple #4
0
 def create_resource(self):
   """Creates the `mutable_dense_hash_table_v2` op and returns its handle.

   Also records `self._table_name`: None under eager execution, otherwise the
   op name with any leading name-scope path stripped.
   """
   # Multi-worker training needs checkpointed tables to be shared, so when no
   # shared_name was supplied we ask the kernel to share by node name.
   node_name_sharing = self._checkpoint and self._shared_name is None
   resource_handle = gen_lookup_ops.mutable_dense_hash_table_v2(
       empty_key=self._empty_key,
       deleted_key=self._deleted_key,
       shared_name=self._shared_name,
       use_node_name_sharing=node_name_sharing,
       value_dtype=self._value_dtype,
       value_shape=self._value_shape,
       initial_num_buckets=self._initial_num_buckets,
       name=self._name)
   if context.executing_eagerly():
     self._table_name = None
     return resource_handle
   # Keep only the text after the final "/" (the whole name if there is none).
   self._table_name = resource_handle.op.name.rpartition("/")[2]
   return resource_handle
Exemple #5
0
    def __init__(self,
                 key_dtype,
                 value_dtype,
                 default_value,
                 empty_key,
                 deleted_key,
                 initial_num_buckets=None,
                 shared_name=None,
                 name="MutableDenseHashTable",
                 checkpoint=True):
        """Creates an empty `MutableDenseHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table.
            Its static shape is used as the table's `value_shape`.
          empty_key: the key to use to represent empty buckets internally. Must
            not be used in insert, remove or lookup operations.
          deleted_key: the key to use to represent deleted buckets internally.
            Must not be used in insert, remove or lookup operations and must be
            different from `empty_key` (this constructor does not check it).
          initial_num_buckets: the initial number of buckets.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions. When executing eagerly
            and `shared_name` is None, a unique `"table_<uid>"` name is
            generated instead.
          name: A name for the operation (optional). It is also passed to the
            checkpoint saveable when `checkpoint` is True.
          checkpoint: if True, the contents of the table are saved to and
            restored from checkpoints. If `shared_name` is None (the default)
            for a checkpointed table, it is shared using the table node name.

        Raises:
          TypeError, ValueError: may propagate from `ops.convert_to_tensor` if
            `default_value`, `empty_key` or `deleted_key` cannot be converted
            to the requested dtype.
        """
        self._default_value = ops.convert_to_tensor(default_value,
                                                    dtype=value_dtype,
                                                    name="default_value")
        # Every stored value must have the same static shape as the default.
        self._value_shape = self._default_value.get_shape()

        # The table must be shared if checkpointing is requested for multi-worker
        # training to work correctly. Use the node name if no shared_name has been
        # explicitly specified. Note: this is computed from the caller's
        # shared_name, before the eager-mode replacement below.
        use_node_name_sharing = checkpoint and shared_name is None
        empty_key = ops.convert_to_tensor(empty_key,
                                          dtype=key_dtype,
                                          name="empty_key")
        deleted_key = ops.convert_to_tensor(deleted_key,
                                            dtype=key_dtype,
                                            name="deleted_key")
        executing_eagerly = context.executing_eagerly()
        if executing_eagerly and shared_name is None:
            # TODO(allenl): This will leak memory due to kernel caching by the
            # shared_name attribute value (but is better than the alternative of
            # sharing everything by default when executing eagerly; hopefully creating
            # tables in a loop is uncommon).
            shared_name = "table_%d" % (ops.uid(), )
        self._table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
            empty_key=empty_key,
            deleted_key=deleted_key,
            shared_name=shared_name,
            use_node_name_sharing=use_node_name_sharing,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            initial_num_buckets=initial_num_buckets,
            name=name)
        # In graph mode, name the table after the last component of its op
        # name (no name-scope prefix); the eager path does not read `.op`.
        if executing_eagerly:
            op_name = None
        else:
            op_name = self._table_ref.op.name.split("/")[-1]
        super(MutableDenseHashTable, self).__init__(key_dtype, value_dtype,
                                                    op_name)

        if checkpoint:
            saveable = MutableDenseHashTable._Saveable(self, name)
            ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)