Beispiel #1
0
    def lookup(self, keys, name=None):
        """Find the values stored under `keys`.

        Keys that are not present in the table resolve to `default_value`.

        Args:
          keys: Tensor of keys to look up. Its dtype must equal the table's
            key dtype.
          name: Optional name for the created operation.

        Returns:
          A tensor of the table's value type. If the static rank of `keys` is
          known and non-zero, the result's static shape is set to the leading
          dimension of `keys` followed by the table's value shape.

        Raises:
          TypeError: if `keys.dtype` is not the table's key dtype.
        """
        wanted = self._key_dtype
        if keys.dtype != wanted:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (wanted, keys.dtype))

        table = self._table_ref
        fallback_name = "%s_lookup_table_find" % self._name
        with ops.name_scope(name, fallback_name, [table, keys]) as scope_name:
            # Place the find op on the same device as the table resource.
            with ops.colocate_with(table):
                # pylint: disable=protected-access
                values = gen_lookup_ops._lookup_table_find_v2(
                    table, keys, self._default_value, name=scope_name)
                # pylint: enable=protected-access

        key_shape = keys.get_shape()
        # `ndims` is None for unknown rank and 0 for scalars; both skip this.
        if key_shape.ndims:
            # NOTE(review): only the leading dimension of `keys` is carried
            # into the static shape; presumably `keys` is laid out as
            # [num_keys] (+ a per-key shape) -- TODO confirm for higher ranks.
            result_shape = tensor_shape.TensorShape([key_shape.dims[0]])
            values.set_shape(result_shape.concatenate(self._value_shape))
        return values
Beispiel #2
0
  def lookup(self, keys, name=None):
    """Find the values stored under `keys`.

    Keys that are not present in the table resolve to `default_value`.

    Args:
      keys: Tensor of keys to look up. Its dtype must equal the table's key
        dtype.
      name: Optional name for the created operation.

    Returns:
      A tensor of the table's value type. If the static rank of `keys` is
      known and non-zero, the result's static shape is set to the leading
      dimension of `keys` followed by the table's value shape.

    Raises:
      TypeError: if `keys.dtype` is not the table's key dtype.
    """
    wanted = self._key_dtype
    if keys.dtype != wanted:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (wanted, keys.dtype))

    table = self._table_ref
    fallback_name = "%s_lookup_table_find" % self._name
    with ops.name_scope(name, fallback_name, [table, keys]) as scope_name:
      # Place the find op on the same device as the table resource.
      with ops.colocate_with(table):
        # pylint: disable=protected-access
        values = gen_lookup_ops._lookup_table_find_v2(
            table, keys, self._default_value, name=scope_name)
        # pylint: enable=protected-access

    key_shape = keys.get_shape()
    # `ndims` is None for unknown rank and 0 for scalars; both skip this.
    if key_shape.ndims:
      # NOTE(review): only the leading dimension of `keys` is carried into
      # the static shape; presumably `keys` is laid out as [num_keys]
      # (+ a per-key shape) -- TODO confirm for higher ranks.
      result_shape = tensor_shape.TensorShape([key_shape.dims[0]])
      values.set_shape(result_shape.concatenate(self._value_shape))
    return values
Beispiel #3
0
    def lookup(self, keys, name=None):
        """Looks up `keys` in a table, outputs the corresponding values.

        The `default_value` is used for keys not present in the table.

        Args:
          keys: Keys to look up. May be either a `SparseTensor` or dense
            `Tensor`. For a `SparseTensor`, only its `values` are looked up.
          name: A name for the operation (optional).

        Returns:
          A `SparseTensor` with the same indices and dense shape as `keys` if
          keys are sparse, otherwise a dense `Tensor` with the static shape of
          `keys`.

        Raises:
          TypeError: when the base dtype of `keys` does not match the table's
            key dtype.
        """
        key_tensor = keys
        if isinstance(keys, sparse_tensor.SparseTensor):
            # Look up only the stored values; the indices and dense shape are
            # reattached to the result below.
            key_tensor = keys.values

        # Compare base dtypes so reference-typed keys (e.g. tensors backed by
        # a variable) are accepted rather than rejected as a mismatch.
        if keys.dtype.base_dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, keys.dtype))

        with ops.name_scope(name, "%s_Lookup" % self._name,
                            (self._table_ref, key_tensor,
                             self._default_value)) as scope:
            # pylint: disable=protected-access
            values = gen_lookup_ops._lookup_table_find_v2(
                self._table_ref, key_tensor, self._default_value, name=scope)
            # pylint: enable=protected-access

        # The result has one value per key, with the keys' static shape.
        values.set_shape(key_tensor.get_shape())
        if isinstance(keys, sparse_tensor.SparseTensor):
            return sparse_tensor.SparseTensor(keys.indices, values,
                                              keys.dense_shape)
        return values
Beispiel #4
0
  def lookup(self, keys, name=None):
    """Return the table values associated with `keys`.

    Keys missing from the table resolve to the table's `default_value`.

    Args:
      keys: A dense `Tensor` or a `SparseTensor` of keys. For a `SparseTensor`
        only its `values` are looked up.
      name: Optional name for the operation.

    Returns:
      A `SparseTensor` sharing the indices and dense shape of `keys` when
      `keys` is sparse; otherwise a dense `Tensor` with the static shape of
      `keys`.

    Raises:
      TypeError: if the base dtype of `keys` differs from the table's key
        dtype.
    """
    is_sparse = isinstance(keys, sparse_tensor.SparseTensor)
    key_tensor = keys.values if is_sparse else keys

    # `base_dtype` strips any reference qualifier before comparing.
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    scope_inputs = (self._table_ref, key_tensor, self._default_value)
    with ops.name_scope(name, "%s_Lookup" % self._name, scope_inputs) as scope:
      # pylint: disable=protected-access
      values = gen_lookup_ops._lookup_table_find_v2(
          self._table_ref, key_tensor, self._default_value, name=scope)
      # pylint: enable=protected-access

    values.set_shape(key_tensor.get_shape())
    if not is_sparse:
      return values
    return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
Beispiel #5
0
 def lookup(self, keys, default):
     """Run the table-find op on `self.table_ref` for `keys`.

     `default` is passed through as the op's default value.
     """
     # pylint: disable=protected-access
     find_op = gen_lookup_ops._lookup_table_find_v2
     return find_op(self.table_ref, keys, default)
 def lookup(self, keys, default):
   """Look up `keys` in `self.table_ref`, passing `default` to the find op."""
   table = self.table_ref
   # pylint: disable=protected-access
   return gen_lookup_ops._lookup_table_find_v2(table, keys, default)