def _create_or_get_trainable(trainable_name):
  """Return the trainable registered under `trainable_name`, creating it if needed.

  This helper reads `params`, `ids`, `max_norm`, `initial_value` and
  `collections` from the surrounding scope; they are not parameters of this
  function and their definitions are not visible in this block.

  Two code paths exist:

  * Eager execution, or tracing inside a function: the store
    `params._trainable_store` is consulted under `ops.init_scope()`. An
    existing entry is reused; otherwise a `de.shadow_ops.ShadowVariable` is
    created and stored.
  * Plain graph mode (neither eager nor inside a function): a new
    `de.TrainableWrapper` is always built and written to the store, replacing
    any previous entry under the same key.

  Args:
    trainable_name: Store key and name of the created object. When `None`, a
      unique name based on `_ANONYMOUS_TRAINABLE_STORE_KEY` is requested from
      the default graph.

  Returns:
    The `de.shadow_ops.ShadowVariable` or `de.TrainableWrapper` stored under
    the (possibly generated) name.

  Raises:
    ValueError: If `trainable_name` is `None` while executing eagerly.
  """
  if trainable_name is None:
    if context.executing_eagerly():
      raise ValueError(
          'Must provide a name for embedding_lookup when using eager execution.'
      )
    default_graph = ops.get_default_graph()
    trainable_name = default_graph.unique_name(_ANONYMOUS_TRAINABLE_STORE_KEY)

  if context.executing_eagerly() or ops.inside_function():
    # Look up / create the shadow variable under init_scope, as the original
    # does, so creation is lifted out of any function being traced.
    with ops.init_scope():
      cached = params._trainable_store.get(trainable_name, None)
      if cached is None:
        cached = de.shadow_ops.ShadowVariable(
            params,
            name=trainable_name,
            max_norm=max_norm,
            trainable=params.trainable,
            model_mode=ModelMode.CURRENT_SETTING)
        params._trainable_store[trainable_name] = cached
    return cached

  # Plain graph mode: always build a fresh wrapper and overwrite the entry.
  graph_wrapper = de.TrainableWrapper(
      params,
      ids,
      max_norm=max_norm,
      initial_value=initial_value,
      dtype=params.value_dtype,
      trainable=params.trainable,
      collections=collections,
      model_mode=ModelMode.CURRENT_SETTING,
      name=trainable_name)
  params._trainable_store[trainable_name] = graph_wrapper
  return graph_wrapper
def _create_trainable(trainable_name):
  """Build a new `de.TrainableWrapper` named `trainable_name`.

  `params`, `ids`, `max_norm`, `initial_value` and `collections` are taken
  from the surrounding scope; they are not defined in this block.

  Args:
    trainable_name: Value passed as the wrapper's `name`.

  Returns:
    The newly constructed `de.TrainableWrapper`.
  """
  wrapper_options = {
      "max_norm": max_norm,
      "initial_value": initial_value,
      "dtype": params.value_dtype,
      "trainable": params.trainable,
      "collections": collections,
      "model_mode": ModelMode.CURRENT_SETTING,
      "name": trainable_name,
  }
  return de.TrainableWrapper(params, ids, **wrapper_options)
def embedding_lookup(
    params,
    ids,
    partition_strategy=None,  # pylint: disable=unused-argument
    name=None,
    validate_indices=None,  # pylint: disable=unused-argument
    max_norm=None,
    return_trainable=False):
  """Provides a dynamic version of embedding_lookup, similar to
  tf.nn.embedding_lookup.

  A `de.TrainableWrapper` is built over `params` and `ids`, and its identity
  tensor is returned as the embeddings. If the static shape of `ids` is
  completely unknown, `ids` is first reshaped to 1-D and the static shape of
  the result is left unknown. Otherwise the static shape of the result is set
  to the shape of `ids` plus a trailing dimension of size `params.dim`.

  As a side effect, the wrapper is appended to `params.trainable_wrappers`
  when it is not already present there.

  Args:
    params: A dynamic_embedding.Variable instance, or a list/tuple holding
      exactly one such instance.
    ids: A tensor of any shape whose dtype equals `params.key_dtype`. It must
      expose a `dtype` attribute, because the dtype check runs before the
      value is converted to a tensor.
    partition_strategy: Not used; kept for API compatibility with
      `nn.embedding_lookup`.
    name: A name for the operation (optional). Defaults to
      "embedding_lookup".
    validate_indices: Not used; kept for API compatibility with
      `nn.embedding_lookup`.
    max_norm: Forwarded to `de.TrainableWrapper`. If not `None`, each
      embedding is meant to be clipped when its l2-norm is larger than this
      value.
    return_trainable: Optional. If True, also return the TrainableWrapper.

  Returns:
    A tensor containing the values from `params` for the keys in `ids`.
    trainable_wrap:
      A TrainableWrapper object used to fill the Optimizers `var_list`.
      Only provided if `return_trainable` is True, in which case the result
      is the tuple `(embeddings, trainable_wrap)`.

  Raises:
    ValueError: If `params` is a list/tuple that does not contain exactly one
      element.
    TypeError: If `params` is not a `Variable`, or if `params.key_dtype`
      differs from `ids.dtype`.
  """
  if isinstance(params, (list, tuple)):
    # Exactly one table is supported. An empty sequence is rejected here as
    # well, instead of failing with an IndexError on `params[0]`.
    if len(params) != 1:
      raise ValueError("Only one params is allowed.")
    params = params[0]
  if not isinstance(params, Variable):
    raise TypeError("params should be a Variable instance.")
  if params.key_dtype != ids.dtype:
    raise TypeError(
        "params.key_dtype should be same with ids.dtype: {} vs. {}".format(
            params.key_dtype, ids.dtype))

  # Prefix the op name with the current variable scope so that lookups made
  # under different scopes get distinct names.
  scope = variable_scope.get_variable_scope()
  full_name = scope.name + "/" if scope.name else ""
  full_name += (name + "/") if name else "embedding_lookup/"
  with ops.name_scope(full_name):
    ids = ops.convert_to_tensor(ids, name="ids")
    if ids.get_shape() == tensor_shape.unknown_shape():
      # Nothing is known statically: flatten and use a single placeholder row.
      ids = array_ops.reshape(ids, shape=[-1])
      initial_shape = (1, params.dim)
      trainable_shape = tensor_shape.unknown_shape()
    else:
      # Unknown (None) or zero-sized dimensions are replaced by 1 so that the
      # placeholder initial value has a concrete shape.
      initial_shape = [d if d else 1 for d in ids.get_shape().as_list()
                      ] + [params.dim]
      trainable_shape = ids.get_shape().concatenate([params.dim])

    initial_value = array_ops.zeros(shape=initial_shape,
                                    dtype=params.value_dtype)
    if (isinstance(initial_value, ops.Tensor) and
        hasattr(initial_value, "graph") and
        initial_value.graph.building_function):
      # While a function graph is being built, pass a callable instead of the
      # already-created tensor so the zeros are produced on demand.
      initial_value = lambda: array_ops.zeros(initial_shape,
                                              dtype=params.value_dtype)

    with ops.colocate_with(None, ignore_existing=True):
      collections = [ops.GraphKeys.LOCAL_VARIABLES]
      if params.trainable:
        collections += [ops.GraphKeys.TRAINABLE_VARIABLES]
      trainable_ = de.TrainableWrapper(params,
                                       ids,
                                       max_norm=max_norm,
                                       initial_value=initial_value,
                                       dtype=params.value_dtype,
                                       trainable=params.trainable,
                                       collections=collections)
      embeddings = array_ops.identity(trainable_)
      embeddings.set_shape(trainable_shape)

  if trainable_ not in params.trainable_wrappers:
    params.trainable_wrappers.append(trainable_)

  return (embeddings, trainable_) if return_trainable else embeddings