示例#1
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               shared_name=None,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`. A scalar (rank-0) default
        selects a scalar-valued table op; otherwise the default's static shape
        is passed as `value_shape` to the tensor-valued table op.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional). When `checkpoint` is True the
        same value is also passed as the name of the table's saveable object.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is None, the table is shared using
        the table node name.

    Returns:
      A `MutableHashTable` object.

    Any error raised by `ops.convert_to_tensor` while converting
    `default_value` propagates to the caller; no other argument validation is
    performed here.
    """
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested. Use the node name
    # if no shared_name has been explicitly specified.
    # NOTE(review): only `None` triggers node-name sharing; an explicit empty
    # string `shared_name=""` does not.
    use_node_name_sharing = checkpoint and shared_name is None
    # pylint: disable=protected-access
    if self._value_shape.ndims == 0:
      # Scalar default: one scalar value per key.
      self._table_ref = gen_data_flow_ops._mutable_hash_table(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          name=name)
    else:
      # Non-scalar (or unknown-rank) default: values are tensors whose shape is
      # the default value's static shape.
      self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          value_shape=self._value_shape,
          name=name)
    # pylint: enable=protected-access
    # The base class receives only the last "/"-separated component of the op
    # name (presumably dropping any enclosing name-scope prefix).
    super(MutableHashTable, self).__init__(
        key_dtype, value_dtype, self._table_ref.op.name.split("/")[-1])

    if checkpoint:
      # Register the table's saveable in the SAVEABLE_OBJECTS collection,
      # presumably so that savers built from that collection include it.
      saveable = MutableHashTable.MutableHashTableSaveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
示例#2
0
    def __init__(self,
                 key_dtype,
                 value_dtype,
                 default_value,
                 shared_name=None,
                 name="MutableHashTable",
                 checkpoint=True):
        """Creates an empty `MutableHashTable` object.

        Creates a table, the type of its keys and values are specified by
        key_dtype and value_dtype, respectively.

        Args:
          key_dtype: the type of the key tensors.
          value_dtype: the type of the value tensors.
          default_value: The value to use if a key is missing in the table.
            It is converted to a tensor of `value_dtype`. A scalar (rank-0)
            default selects a scalar-valued table op; otherwise the default's
            static shape is passed as `value_shape` to the tensor-valued op.
          shared_name: If non-empty, this table will be shared under
            the given name across multiple sessions.
          name: A name for the operation (optional). When `checkpoint` is
            True the same value is also passed as the name of the table's
            saveable object.
          checkpoint: if True, the contents of the table are saved to and
            restored from checkpoints. If `shared_name` is None, the table is
            shared using the table node name.

        Returns:
          A `MutableHashTable` object.

        Any error raised by `ops.convert_to_tensor` while converting
        `default_value` propagates to the caller; no other argument
        validation is performed here.
        """
        self._default_value = ops.convert_to_tensor(default_value,
                                                    dtype=value_dtype)
        self._value_shape = self._default_value.get_shape()

        # The table must be shared if checkpointing is requested. Use the node name
        # if no shared_name has been explicitly specified.
        # NOTE(review): only `None` triggers node-name sharing; an explicit
        # empty string `shared_name=""` does not.
        use_node_name_sharing = checkpoint and shared_name is None
        # pylint: disable=protected-access
        if self._value_shape.ndims == 0:
            # Scalar default: one scalar value per key.
            self._table_ref = gen_data_flow_ops._mutable_hash_table(
                shared_name=shared_name,
                use_node_name_sharing=use_node_name_sharing,
                key_dtype=key_dtype,
                value_dtype=value_dtype,
                name=name)
        else:
            # Non-scalar (or unknown-rank) default: values are tensors whose
            # shape is the default value's static shape.
            self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
                shared_name=shared_name,
                use_node_name_sharing=use_node_name_sharing,
                key_dtype=key_dtype,
                value_dtype=value_dtype,
                value_shape=self._value_shape,
                name=name)
        # pylint: enable=protected-access
        # The base class receives only the last "/"-separated component of
        # the op name (presumably dropping any enclosing name-scope prefix).
        super(MutableHashTable,
              self).__init__(key_dtype, value_dtype,
                             self._table_ref.op.name.split("/")[-1])

        if checkpoint:
            # Register the table's saveable in the SAVEABLE_OBJECTS
            # collection, presumably so savers built from it include the table.
            saveable = MutableHashTable.MutableHashTableSaveable(self, name)
            ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
示例#3
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               shared_name=None,
               name=None):
    """Creates an empty `MutableHashTable` object.

    The key and value types of the new table are given by `key_dtype` and
    `value_dtype`.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`; a rank-0 default selects the
        scalar-valued table op, anything else the tensor-valued one.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional).

    Returns:
      A `MutableHashTable` object.
    """
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    self._value_shape = self._default_value.get_shape()

    # Keyword arguments common to both table-creation ops.
    op_kwargs = dict(shared_name=shared_name,
                     key_dtype=key_dtype,
                     value_dtype=value_dtype,
                     name=name)
    # pylint: disable=protected-access
    if self._value_shape.ndims != 0:
      # Non-scalar (or unknown-rank) default: each value is a tensor shaped
      # like the default value.
      self._table_ref = gen_data_flow_ops._mutable_hash_table_of_tensors(
          value_shape=self._value_shape, **op_kwargs)
    else:
      # Scalar default: one scalar value per key.
      self._table_ref = gen_data_flow_ops._mutable_hash_table(**op_kwargs)
    # pylint: enable=protected-access

    # Hand the base class only the last "/"-separated piece of the op name.
    table_name = self._table_ref.op.name.split("/")[-1]
    super(MutableHashTable, self).__init__(key_dtype, value_dtype, table_name)