def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
             name="MutableHashTable", checkpoint=True):
  """Creates an empty `MutableHashTable` object.

  Creates a table, the type of its keys and values are specified by key_dtype
  and value_dtype, respectively.

  Args:
    key_dtype: the type of the key tensors.
    value_dtype: the type of the value tensors.
    default_value: The value to use if a key is missing in the table. It is
      converted to a tensor of `value_dtype`; if it is a scalar the
      scalar-valued table op is created, otherwise the tensor-valued op is
      created with `default_value`'s shape as the value shape.
    shared_name: If non-empty, this table will be shared under the given name
      across multiple sessions.
    name: A name for the operation (optional). When `checkpoint` is True it is
      also passed to the table's checkpoint saveable, so it must not be None.
    checkpoint: if True, the contents of the table are saved to and restored
      from checkpoints. If `shared_name` is empty, the table is shared using
      the table node name.

  Returns:
    A `MutableHashTable` object.

  Raises:
    ValueError: If checkpoint is True and no name was specified.
  """
  # Enforce the documented contract before any op is added to the graph, so a
  # misconfigured table neither leaves a stray op behind nor registers a
  # saveable without a name.
  if checkpoint and name is None:
    raise ValueError("name must be specified when checkpoint is True.")

  self._default_value = ops.convert_to_tensor(default_value,
                                              dtype=value_dtype)
  self._value_shape = self._default_value.get_shape()

  # The table must be shared if checkpointing is requested. Use the node name
  # if no shared_name has been explicitly specified.
  use_node_name_sharing = checkpoint and shared_name is None
  # pylint: disable=protected-access
  if self._value_shape.ndims == 0:
    # Scalar default value: use the scalar-valued table op.
    self._table_ref = gen_data_flow_ops._mutable_hash_table(
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        name=name)
  else:
    # Non-scalar default value: stored values take the default's shape.
    self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        value_shape=self._value_shape,
        name=name)
  # pylint: enable=protected-access
  # The base class receives only the last "/"-separated component of the op
  # name, i.e. the name without any enclosing name scopes.
  super(MutableHashTable, self).__init__(
      key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

  if checkpoint:
    saveable = MutableHashTable.MutableHashTableSaveable(self, name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
             name="MutableHashTable", checkpoint=True):
  """Creates an empty `MutableHashTable` object.

  Creates a table, the type of its keys and values are specified by key_dtype
  and value_dtype, respectively.

  Args:
    key_dtype: the type of the key tensors.
    value_dtype: the type of the value tensors.
    default_value: The value to use if a key is missing in the table. It is
      converted to a tensor of `value_dtype`; if it is a scalar the
      scalar-valued table op is created, otherwise the tensor-valued op is
      created with `default_value`'s shape as the value shape.
    shared_name: If non-empty, this table will be shared under the given name
      across multiple sessions.
    name: A name for the operation (optional). When `checkpoint` is True it is
      also passed to the table's checkpoint saveable, so it must not be None.
    checkpoint: if True, the contents of the table are saved to and restored
      from checkpoints. If `shared_name` is empty, the table is shared using
      the table node name.

  Returns:
    A `MutableHashTable` object.

  Raises:
    ValueError: If checkpoint is True and no name was specified.
  """
  # Enforce the documented contract before any op is added to the graph, so a
  # misconfigured table neither leaves a stray op behind nor registers a
  # saveable without a name.
  if checkpoint and name is None:
    raise ValueError("name must be specified when checkpoint is True.")

  self._default_value = ops.convert_to_tensor(default_value,
                                              dtype=value_dtype)
  self._value_shape = self._default_value.get_shape()

  # The table must be shared if checkpointing is requested. Use the node name
  # if no shared_name has been explicitly specified.
  use_node_name_sharing = checkpoint and shared_name is None
  # pylint: disable=protected-access
  if self._value_shape.ndims == 0:
    # Scalar default value: use the scalar-valued table op.
    self._table_ref = gen_data_flow_ops._mutable_hash_table(
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        name=name)
  else:
    # Non-scalar default value: stored values take the default's shape.
    self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
        shared_name=shared_name,
        use_node_name_sharing=use_node_name_sharing,
        key_dtype=key_dtype,
        value_dtype=value_dtype,
        value_shape=self._value_shape,
        name=name)
  # pylint: enable=protected-access
  # The base class receives only the last "/"-separated component of the op
  # name, i.e. the name without any enclosing name scopes.
  super(MutableHashTable, self).__init__(
      key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

  if checkpoint:
    saveable = MutableHashTable.MutableHashTableSaveable(self, name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
             name=None):
  """Creates an empty `MutableHashTable` object.

  The key and value types of the new table are given by `key_dtype` and
  `value_dtype` respectively.

  Args:
    key_dtype: the type of the key tensors.
    value_dtype: the type of the value tensors.
    default_value: The value returned for keys missing from the table. A
      scalar selects the scalar-valued table op; any other shape selects the
      tensor-valued op, using that shape as the value shape.
    shared_name: If non-empty, this table will be shared under the given name
      across multiple sessions.
    name: A name for the operation (optional).

  Returns:
    A `MutableHashTable` object.
  """
  self._default_value = ops.convert_to_tensor(default_value,
                                              dtype=value_dtype)
  self._value_shape = self._default_value.get_shape()

  # Arguments common to both table-creation ops.
  common_args = {
      "shared_name": shared_name,
      "key_dtype": key_dtype,
      "value_dtype": value_dtype,
      "name": name,
  }
  # pylint: disable=protected-access
  if self._value_shape.ndims != 0:
    table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
        value_shape=self._value_shape, **common_args)
  else:
    table_ref = gen_data_flow_ops._mutable_hash_table(**common_args)
  # pylint: enable=protected-access
  self._table_ref = table_ref

  # Hand the base class the op name stripped of any enclosing name scopes.
  op_basename = table_ref.op.name.split("/")[-1]
  super(MutableHashTable, self).__init__(key_dtype, value_dtype, op_basename)