Example #1
0
    def lookup(self, keys, name=None):
        """Looks up `keys` in a table, outputs the corresponding values.

        The `default_value` is used for keys not present in the table.

        Args:
          keys: Keys to look up. Can be a tensor of any shape. Must match the
            table's key_dtype.
          name: A name for the operation (optional).

        Returns:
          A tensor whose shape is the shape of `keys` followed by the table's
            value shape, using the table's value type.

        Raises:
          TypeError: when `keys` do not match the table's key dtype.
        """
        if keys.dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, keys.dtype))

        with ops.name_scope(name, "%s_lookup_table_find" % self._name,
                            [self._table_ref, keys]) as name:
            # pylint: disable=protected-access
            values = gen_data_flow_ops._lookup_table_find(self._table_ref,
                                                          keys,
                                                          self._default_value,
                                                          name=name)
            # pylint: enable=protected-access

        # Each key yields one value of shape `self._value_shape`, so the static
        # output shape is the FULL key shape followed by the value shape.
        # Using only the first key dimension would give a wrong (truncated)
        # rank for keys of rank >= 2. `concatenate` also yields an unknown
        # shape for unknown-rank keys, which makes `set_shape` a no-op.
        values.set_shape(keys.get_shape().concatenate(self._value_shape))
        return values
Example #2
0
    def lookup(self, keys, name=None):
        """Fetches the table values associated with `keys`.

        Keys that are absent from the table map to the table's
        `default_value`.

        Args:
          keys: Tensor of keys to search for; `keys.dtype` must equal the
            table's key dtype.
          name: Optional op name. When omitted,
            "<table name>_lookup_table_find" is used.

        Returns:
          The output of the table-find op for `keys`.

        Raises:
          TypeError: if `keys.dtype` differs from the table's key dtype.
        """
        op_name = ("%s_lookup_table_find" % self._name) if name is None else name

        actual_dtype = keys.dtype
        if actual_dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, actual_dtype))

        # pylint: disable=protected-access
        find = gen_data_flow_ops._lookup_table_find
        # pylint: enable=protected-access
        return find(self._table_ref, keys, self._default_value, name=op_name)
Example #3
0
  def lookup(self, keys, name=None):
    """Looks up `keys` in the table and returns the matching values.

    The table's `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up; `keys.dtype` must equal the table's key dtype.
      name: Optional name for the op. Defaults to
        "<table name>_lookup_table_find".

    Returns:
      The operation that looks up the keys.

    Raises:
      TypeError: when `keys.dtype` doesn't match the table's key dtype.
    """
    if name is not None:
      op_name = name
    else:
      op_name = "%s_lookup_table_find" % self._name

    actual_dtype = keys.dtype
    if actual_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
          self._key_dtype, actual_dtype))

    # pylint: disable=protected-access
    find_op = gen_data_flow_ops._lookup_table_find
    # pylint: enable=protected-access
    return find_op(self._table_ref, keys, self._default_value, name=op_name)
Example #4
0
  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor whose shape is the shape of `keys` followed by the table's
        value shape, using the table's value type.

    Raises:
      TypeError: when `keys` do not match the table's key dtype.
    """
    if keys.dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(name, "%s_lookup_table_find" % self._name,
                        [self._table_ref, keys]) as name:
      # pylint: disable=protected-access
      values = gen_data_flow_ops._lookup_table_find(
          self._table_ref, keys, self._default_value, name=name)
      # pylint: enable=protected-access

    # Each key yields one value of shape `self._value_shape`, so the static
    # output shape is the FULL key shape followed by the value shape. Using
    # only the first key dimension would give a wrong (truncated) rank for
    # keys of rank >= 2. `concatenate` also yields an unknown shape for
    # unknown-rank keys, which makes `set_shape` a no-op.
    values.set_shape(keys.get_shape().concatenate(self._value_shape))
    return values
Example #5
0
  def lookup(self, keys, name=None):
    """Returns the table values for the given `keys` tensor.

    Keys missing from the table are mapped to the table's `default_value`.

    Args:
      keys: The tensor of keys; `keys.dtype` must equal the table's key dtype.
      name: Optional name for the op. When None,
        "<table name>_lookup_table_find" is used.

    Returns:
      The operation that looks up the keys.

    Raises:
      TypeError: when `keys.dtype` doesn't match the table's key dtype.
    """
    op_name = name
    if op_name is None:
      op_name = "%s_lookup_table_find" % self._name

    given_dtype = keys.dtype
    if given_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
          self._key_dtype, given_dtype))

    # pylint: disable=protected-access
    table_find = gen_data_flow_ops._lookup_table_find
    # pylint: enable=protected-access
    return table_find(self._table_ref, keys, self._default_value,
                      name=op_name)
Example #6
0
  def lookup(self, keys, name=None):
    """Looks up `keys` in the table and returns the corresponding values.

    The table's `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up; `keys.dtype` must equal the table's key dtype.
      name: Optional name for the op; defaults to
        "<table name>_lookup_table_find".

    Returns:
      The operation that looks up the keys.

    Raises:
      TypeError: when `keys.dtype` doesn't match the table's key dtype.
    """
    if name is None:
      resolved_name = "%s_lookup_table_find" % self._name
    else:
      resolved_name = name

    received = keys.dtype
    if received != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
          self._key_dtype, received))

    # pylint: disable=protected-access
    finder = gen_data_flow_ops._lookup_table_find
    # pylint: enable=protected-access
    return finder(self._table_ref, keys, self._default_value,
                  name=resolved_name)
Example #7
0
    def lookup(self, keys, name=None):
        """Maps `keys` to their table values.

        Keys absent from the table map to the table's `default_value`.

        Args:
          keys: Keys to look up, either an `ops.SparseTensor` or a dense
            `Tensor`. `keys.dtype` must equal the table's key dtype.
          name: Optional op name; defaults to
            "<table name>_lookup_table_find".

        Returns:
          For sparse `keys`, an `ops.SparseTensor` that reuses the indices and
          shape of `keys` with the looked-up values. Otherwise the dense
          result, whose static shape is set to the shape of `keys`.

        Raises:
          TypeError: when `keys.dtype` doesn't match the table's key dtype.
        """
        if name is None:
            name = "%s_lookup_table_find" % self._name

        is_sparse = isinstance(keys, ops.SparseTensor)
        # Sparse inputs are looked up through their values tensor only.
        key_tensor = keys.values if is_sparse else keys

        if keys.dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, keys.dtype))

        # pylint: disable=protected-access
        values = gen_data_flow_ops._lookup_table_find(
            self._table_ref, key_tensor, self._default_value, name=name)
        # pylint: enable=protected-access
        values.set_shape(key_tensor.get_shape())

        if not is_sparse:
            return values
        return ops.SparseTensor(keys.indices, values, keys.shape)
Example #8
0
    def lookup(self, keys, name=None):
        """Maps `keys` to their table values.

        Keys absent from the table map to the table's `default_value`.

        Args:
          keys: Keys to look up, either a `sparse_tensor.SparseTensor` or a
            dense tensor (`Output`). `keys.dtype` must equal the table's key
            dtype.
          name: Optional op name; defaults to
            "<table name>_lookup_table_find".

        Returns:
          For sparse `keys`, a `sparse_tensor.SparseTensor` that reuses the
          indices and shape of `keys` with the looked-up values. Otherwise the
          dense result, whose static shape is set to the shape of `keys`.

        Raises:
          TypeError: when `keys.dtype` doesn't match the table's key dtype.
        """
        op_name = name
        if op_name is None:
            op_name = "%s_lookup_table_find" % self._name

        sparse_input = isinstance(keys, sparse_tensor.SparseTensor)
        # Sparse inputs are looked up through their values tensor only.
        key_tensor = keys.values if sparse_input else keys

        if keys.dtype != self._key_dtype:
            raise TypeError(
                "Signature mismatch. Keys must be dtype %s, got %s." %
                (self._key_dtype, keys.dtype))

        # pylint: disable=protected-access
        values = gen_data_flow_ops._lookup_table_find(
            self._table_ref, key_tensor, self._default_value, name=op_name)
        # pylint: enable=protected-access
        values.set_shape(key_tensor.get_shape())

        if not sparse_input:
            return values
        return sparse_tensor.SparseTensor(keys.indices, values, keys.shape)
Example #9
0
 def lookup(self, keys, default):
   """Runs the table-find op on `self.table_ref` for `keys`.

   Args:
     keys: Keys to look up.
     default: Default value forwarded as the find op's third argument.

   Returns:
     The result of `gen_data_flow_ops._lookup_table_find`.
   """
   find = gen_data_flow_ops._lookup_table_find  # pylint: disable=protected-access
   return find(self.table_ref, keys, default)
Example #10
0
 def lookup(self, keys, default):
     """Looks `keys` up in the table referenced by `self.table_ref`.

     Args:
       keys: Keys to look up.
       default: Default value passed through to the find op.

     Returns:
       Whatever `gen_data_flow_ops._lookup_table_find` returns.
     """
     # pylint: disable=protected-access
     table_find = gen_data_flow_ops._lookup_table_find
     # pylint: enable=protected-access
     return table_find(self.table_ref, keys, default)