Exemplo n.º 1
0
    def create_resource(self):
        """Builds the mutable hash table op and returns its handle.

        The op is chosen from the default value's static shape. A scalar default
        (`ndims == 0`) uses the scalar-valued table op. Any other `ndims` value
        uses the tensor-valued table op, which receives that shape as
        `value_shape`.

        Side effect: `self._table_name` is set to None when executing eagerly.
        Otherwise it is set to the last "/"-separated component of the op name.

        Returns:
          The handle returned by the selected `gen_lookup_ops` op.
        """
        # Checkpointed tables have to be shared for multi-worker training to
        # behave correctly. When no explicit shared_name was given, share by
        # node name instead.
        node_name_sharing = self._checkpoint and self._shared_name is None
        value_shape = self._default_value.get_shape()
        op_kwargs = dict(
            shared_name=self._shared_name,
            use_node_name_sharing=node_name_sharing,
            key_dtype=self._key_dtype,
            value_dtype=self._value_dtype,
            name=self._name)
        if value_shape.ndims == 0:
            handle = gen_lookup_ops.mutable_hash_table_v2(**op_kwargs)
        else:
            handle = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
                value_shape=value_shape, **op_kwargs)

        # The op name is only consulted outside eager execution.
        self._table_name = (
            None if context.executing_eagerly()
            else handle.op.name.split("/")[-1])
        return handle
Exemplo n.º 2
0
  def create_resource(self):
    """Builds the mutable hash table op and returns its handle.

    The op is chosen from the default value's static shape. A scalar default
    (`ndims == 0`) uses the scalar-valued table op. Any other `ndims` value
    uses the tensor-valued table op, which receives that shape as
    `value_shape`.

    Side effect: `self._table_name` is set to None when executing eagerly.
    Otherwise it is set to the last "/"-separated component of the op name.

    Returns:
      The handle returned by the selected `gen_lookup_ops` op.
    """
    # Checkpointed tables have to be shared for multi-worker training to
    # behave correctly. When no explicit shared_name was given, share by node
    # name instead.
    node_name_sharing = self._checkpoint and self._shared_name is None
    value_shape = self._default_value.get_shape()
    op_kwargs = dict(
        shared_name=self._shared_name,
        use_node_name_sharing=node_name_sharing,
        key_dtype=self._key_dtype,
        value_dtype=self._value_dtype,
        name=self._name)
    if value_shape.ndims == 0:
      handle = gen_lookup_ops.mutable_hash_table_v2(**op_kwargs)
    else:
      handle = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          value_shape=value_shape, **op_kwargs)

    # The op name is only consulted outside eager execution.
    self._table_name = (
        None if context.executing_eagerly()
        else handle.op.name.split("/")[-1])
    return handle
Exemplo n.º 3
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               shared_name=None,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    The underlying op depends on the static shape of `default_value`:
    - A scalar default creates a scalar-valued table.
    - Any other shape creates a table whose values are tensors of that shape.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Raises:
      This constructor performs no validation of its own. Any error raised
      while converting `default_value` to a tensor of `value_dtype` (by
      `ops.convert_to_tensor`) propagates to the caller.
    """
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    # The static shape of the default value selects the table op below. For
    # tensor-valued tables, it is also the per-key value shape.
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    if self._value_shape.ndims == 0:
      self._table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          name=name)
    else:
      self._table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          value_shape=self._value_shape,
          name=name)

    # The base class receives only the last "/"-separated component of the
    # created op's name, not the full op name.
    table_name = self._table_ref.op.name.split("/")[-1]
    super(MutableHashTable, self).__init__(key_dtype, value_dtype, table_name)

    if checkpoint:
      # Register a saveable so the table contents take part in checkpoint
      # save/restore, as promised by the `checkpoint` argument.
      saveable = MutableHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
Exemplo n.º 4
0
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               shared_name=None,
               name="MutableHashTable",
               checkpoint=True):
    """Creates an empty `MutableHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    The underlying op depends on the static shape of `default_value`:
    - A scalar default creates a scalar-valued table.
    - Any other shape creates a table whose values are tensors of that shape.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of `value_dtype`.
      shared_name: If non-empty, this table will be shared under
        the given name across multiple sessions.
      name: A name for the operation (optional).
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. If `shared_name` is empty for a checkpointed table, it
        is shared using the table node name.

    Raises:
      This constructor performs no validation of its own. Any error raised
      while converting `default_value` to a tensor of `value_dtype` (by
      `ops.convert_to_tensor`) propagates to the caller.
    """
    self._default_value = ops.convert_to_tensor(default_value,
                                                dtype=value_dtype)
    # The static shape of the default value selects the table op below. For
    # tensor-valued tables, it is also the per-key value shape.
    self._value_shape = self._default_value.get_shape()

    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = checkpoint and shared_name is None
    if self._value_shape.ndims == 0:
      self._table_ref = gen_lookup_ops.mutable_hash_table_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          name=name)
    else:
      self._table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
          shared_name=shared_name,
          use_node_name_sharing=use_node_name_sharing,
          key_dtype=key_dtype,
          value_dtype=value_dtype,
          value_shape=self._value_shape,
          name=name)

    # The base class receives only the last "/"-separated component of the
    # created op's name, not the full op name.
    table_name = self._table_ref.op.name.split("/")[-1]
    super(MutableHashTable, self).__init__(key_dtype, value_dtype, table_name)

    if checkpoint:
      # Register a saveable so the table contents take part in checkpoint
      # save/restore, as promised by the `checkpoint` argument.
      saveable = MutableHashTable._Saveable(self, name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
Exemplo n.º 5
0
 def __init__(self, name, table_ref=None):
     """Holds a mutable hash table handle and registers it for saving.

     Args:
       name: Used as the op name when a new table is created, and passed to
         `CheckpointedOp.CustomSaveable`.
       table_ref: Optional existing table handle. When None, a new
         `mutable_hash_table_v2` with string keys and float32 values is
         created.
     """
     self.table_ref = (
         table_ref if table_ref is not None
         else gen_lookup_ops.mutable_hash_table_v2(
             key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name))
     self._name = name
     # Saveable registration is skipped entirely in eager mode.
     if context.executing_eagerly():
         return
     self._saveable = CheckpointedOp.CustomSaveable(self, name)
     ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
                               self._saveable)
Exemplo n.º 6
0
 def __init__(self, name, table_ref=None):
   """Holds a mutable hash table handle and registers it for saving.

   Args:
     name: Used as the op name when a new table is created, and passed to
       `CheckpointedOp.CustomSaveable`.
     table_ref: Optional existing table handle. When None, a new
       `mutable_hash_table_v2` with string keys and float32 values is created.
   """
   self.table_ref = (
       table_ref if table_ref is not None
       else gen_lookup_ops.mutable_hash_table_v2(
           key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name))
   self._name = name
   # Saveable registration is skipped entirely in eager mode.
   if context.executing_eagerly():
     return
   self._saveable = CheckpointedOp.CustomSaveable(self, name)
   ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
                             self._saveable)