def _create_or_get_trainable(trainable_name):
    """Return the trainable object backing an embedding lookup.

    When ``context.executing_eagerly()`` or ``ops.inside_function()`` is true,
    the ``de.shadow_ops.ShadowVariable`` registered in
    ``params._trainable_store`` under ``trainable_name`` is reused. It is
    created (inside ``ops.init_scope()``) and registered only when no entry
    exists yet. Otherwise a brand-new ``de.TrainableWrapper`` is built on every
    call and stored under ``trainable_name``, replacing any previous entry.

    NOTE(review): ``params``, ``ids``, ``max_norm``, ``initial_value`` and
    ``collections`` are free variables here; this helper looks like it was
    lifted out of an enclosing lookup function that defines them -- confirm.

    Args:
      trainable_name: Key in ``params._trainable_store``. If ``None``, a
        graph-unique name based on ``_ANONYMOUS_TRAINABLE_STORE_KEY`` is
        generated; this is not allowed under eager execution.

    Returns:
      The ``ShadowVariable`` or ``TrainableWrapper`` stored under the
      (possibly generated) name.

    Raises:
      ValueError: If ``trainable_name`` is ``None`` while executing eagerly.
    """
    if trainable_name is None:
        if context.executing_eagerly():
            raise ValueError(
                'Must provide a name for embedding_lookup when using eager execution.'
            )
        trainable_name = ops.get_default_graph().unique_name(
            _ANONYMOUS_TRAINABLE_STORE_KEY)

    if context.executing_eagerly() or ops.inside_function():
        # Reuse one ShadowVariable per name. init_scope lifts its creation out
        # of any function graph currently being built.
        with ops.init_scope():
            shadow_var = params._trainable_store.get(trainable_name, None)
            if shadow_var is None:
                shadow_var = de.shadow_ops.ShadowVariable(
                    params,
                    name=trainable_name,
                    max_norm=max_norm,
                    trainable=params.trainable,
                    model_mode=ModelMode.CURRENT_SETTING,
                )
                params._trainable_store[trainable_name] = shadow_var
        return shadow_var

    # Plain graph mode: always build a fresh wrapper for these ids.
    graph_wrapper = de.TrainableWrapper(
        params,
        ids,
        max_norm=max_norm,
        initial_value=initial_value,
        dtype=params.value_dtype,
        trainable=params.trainable,
        collections=collections,
        model_mode=ModelMode.CURRENT_SETTING,
        name=trainable_name,
    )
    params._trainable_store[trainable_name] = graph_wrapper
    return graph_wrapper
 def _create_trainable(trainable_name):
   return de.TrainableWrapper(params,
                              ids,
                              max_norm=max_norm,
                              initial_value=initial_value,
                              dtype=params.value_dtype,
                              trainable=params.trainable,
                              collections=collections,
                              model_mode=ModelMode.CURRENT_SETTING,
                              name=trainable_name)
def embedding_lookup(
        params,
        ids,
        partition_strategy=None,  # pylint: disable=unused-argument
        name=None,
        validate_indices=None,  # pylint: disable=unused-argument
        max_norm=None,
        return_trainable=False):
    """Provides a dynamic version of embedding_lookup, similar to
    `tf.nn.embedding_lookup`.

    Each call builds a new `de.TrainableWrapper` over `params` and `ids` and
    returns an `identity` of it as the looked-up embeddings.

    Shape handling:
      * If the rank of `ids` is known, `ids` is used unchanged and the result
        gets the static shape `ids.shape + [params.dim]`. The embedding
        dimension is a trailing dimension.
      * If the rank of `ids` is unknown, `ids` is flattened to 1-D before the
        lookup and the result is NOT reshaped back. Its static shape is left
        unknown (presumably `[num_ids, params.dim]` -- TODO confirm against
        `TrainableWrapper`).

    Args:
      params: A dynamic_embedding.Variable instance, or a list/tuple holding
        exactly one such instance.
      ids: A tensor, or a value convertible to one, of any shape. Its dtype
        must equal `params.key_dtype`.
      partition_strategy: Not used; kept for API compatibility with
        `tf.nn.embedding_lookup`.
      name: A name for the operation (optional). Defaults to
        "embedding_lookup".
      validate_indices: Not used; kept for API compatibility with
        `tf.nn.embedding_lookup`.
      max_norm: If not `None`, each embedding is clipped if its l2-norm is
        larger than this value.
      return_trainable: Optional. If True, also return the TrainableWrapper.

    Returns:
      If `return_trainable` is False, the embeddings tensor described above,
      containing the values from `params` for the keys in `ids`.
      Otherwise a tuple `(embeddings, trainable_wrap)`, where `trainable_wrap`
      is the TrainableWrapper used for the lookup, suitable for filling an
      optimizer's `var_list`.

    Raises:
      ValueError: If `params` is a list/tuple that does not hold exactly one
        element.
      TypeError: If `params` is not a `Variable`, or if `ids` does not have
        dtype `params.key_dtype`.
    """
    if isinstance(params, (list, tuple)):
        # An empty sequence used to fail with IndexError on params[0].
        if len(params) != 1:
            raise ValueError("Only one params is allowed.")
        params = params[0]
    if not isinstance(params, Variable):
        raise TypeError("params should be a Variable instance.")

    scope = variable_scope.get_variable_scope()
    full_name = scope.name + "/" if scope.name else ""
    full_name += (name + "/") if name else "embedding_lookup/"
    with ops.name_scope(full_name):
        ids = ops.convert_to_tensor(ids, name="ids")
        # Check after conversion, so that non-tensor `ids` (lists, arrays)
        # are compared by their converted dtype instead of failing on
        # a missing `.dtype` attribute.
        if params.key_dtype != ids.dtype:
            raise TypeError(
                "params.key_dtype should be same with ids.dtype: {} vs. {}".format(
                    params.key_dtype, ids.dtype))

        if ids.get_shape().ndims is None:
            # Unknown rank: look up the flattened ids; the result keeps an
            # unknown static shape.
            ids = array_ops.reshape(ids, shape=[-1])
            initial_shape = (1, params.dim)
            trainable_shape = tensor_shape.unknown_shape()
        else:
            # Known rank: unknown (None) or zero dims become 1 in the shape of
            # the zero initial value.
            initial_shape = (
                [d if d else 1 for d in ids.get_shape().as_list()]
                + [params.dim])
            trainable_shape = ids.get_shape().concatenate([params.dim])

        initial_value = array_ops.zeros(shape=initial_shape,
                                        dtype=params.value_dtype)
        if (isinstance(initial_value, ops.Tensor)
                and hasattr(initial_value, "graph")
                and initial_value.graph.building_function):
            # While a function graph is being built, pass a callable instead of
            # a tensor of that graph (presumably so the wrapper can create its
            # initial value lazily -- confirm against TrainableWrapper).
            initial_value = lambda: array_ops.zeros(initial_shape,
                                                    dtype=params.value_dtype)

        with ops.colocate_with(None, ignore_existing=True):
            collections = [ops.GraphKeys.LOCAL_VARIABLES]
            if params.trainable:
                collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            trainable_ = de.TrainableWrapper(params,
                                             ids,
                                             max_norm=max_norm,
                                             initial_value=initial_value,
                                             dtype=params.value_dtype,
                                             trainable=params.trainable,
                                             collections=collections)
            embeddings = array_ops.identity(trainable_)
            embeddings.set_shape(trainable_shape)

        # Compare by identity: `in` would call `==`, which for TF variables can
        # be an element-wise tensor comparison rather than a Python bool.
        if not any(wrapper is trainable_
                   for wrapper in params.trainable_wrappers):
            params.trainable_wrappers.append(trainable_)

    return (embeddings, trainable_) if return_trainable else embeddings