def create_resource(self):
    """Builds the mutable hash table op and returns the resulting table handle.

    A default value whose static rank is 0 selects
    `gen_lookup_ops.mutable_hash_table_v2`. Any other rank, including a rank
    of None, selects `gen_lookup_ops.mutable_hash_table_of_tensors_v2`. That op
    additionally receives the default value's shape as `value_shape`.

    Side effect: `self._table_name` is set. When executing eagerly it is
    None. Otherwise it is the last "/"-separated component of the created
    op's name.

    Returns:
      Whatever the selected `gen_lookup_ops` constructor returns.
    """
    # Checkpointed tables must be shared for multi-worker training to work.
    # When no shared_name was given explicitly (None), share by node name.
    share_by_node_name = self._checkpoint and self._shared_name is None

    op_kwargs = dict(
        shared_name=self._shared_name,
        use_node_name_sharing=share_by_node_name,
        key_dtype=self._key_dtype,
        value_dtype=self._value_dtype,
        name=self._name,
    )

    default_shape = self._default_value.get_shape()
    if default_shape.ndims != 0:
        # Non-scalar default: each stored value is a tensor shaped like it.
        op_kwargs["value_shape"] = default_shape
        table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(**op_kwargs)
    else:
        table_ref = gen_lookup_ops.mutable_hash_table_v2(**op_kwargs)

    # Strip any leading "scope/" prefixes from the op name in graph mode.
    self._table_name = (
        None
        if context.executing_eagerly()
        else table_ref.op.name.rpartition("/")[2]
    )
    return table_ref
def __init__(self,
             key_dtype,
             value_dtype,
             default_value,
             shared_name=None,
             name="MutableHashTable",
             checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    The types of the table's keys and values are given by `key_dtype` and
    `value_dtype`. If `default_value` has static rank 0, a scalar-valued
    table op is created. Otherwise a table op whose values have the shape of
    `default_value` is created.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`.
      shared_name: If non-empty, this table will be shared under the given
        name across multiple sessions.
      name: A name for the operation (optional). When `checkpoint` is True it
        is also the name handed to the table's checkpoint saveable.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints, via a saveable added to the
        `SAVEABLE_OBJECTS` graph collection. If `shared_name` is None for a
        checkpointed table, it is shared using the table node name. An
        explicit empty string does not trigger node-name sharing.
    """
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    # Cached once and reused below instead of calling get_shape() repeatedly.
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified (i.e. it is None).
    use_node_name_sharing = checkpoint and shared_name is None
    if self._value_shape.ndims == 0:
        # Scalar default: one scalar value per key.
        self._table_ref = gen_lookup_ops.mutable_hash_table_v2(
            shared_name=shared_name,
            use_node_name_sharing=use_node_name_sharing,
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            name=name)
    else:
        # ndims != 0 (this branch is also taken when ndims is None): values
        # are tensors shaped like the default value.
        self._table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
            shared_name=shared_name,
            use_node_name_sharing=use_node_name_sharing,
            key_dtype=key_dtype,
            value_dtype=value_dtype,
            value_shape=self._value_shape,
            name=name)

    # The base class receives the op name with any "scope/" prefixes removed.
    super(MutableHashTable, self).__init__(
        key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

    if checkpoint:
        # NOTE(review): the saveable is built from `name` as passed in, not
        # from the op name. It presumably feeds checkpoint keys, so keep this
        # as-is to avoid breaking existing checkpoints (TODO confirm).
        saveable = MutableHashTable._Saveable(self, name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)