def testSimple(self):
  with self.test_session(use_gpu=False):
    indices = [ops.convert_to_tensor([0, 1]), ops.convert_to_tensor([2, 3])]
    values = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor([1, 1])]
    self.assertAllEqual(
        data_flow_ops.parallel_dynamic_stitch(indices, values).eval(),
        [2, 3, 1, 1])
def testInt32Cpu(self):
  with self.test_session(use_gpu=False):
    indices = [
        ops.convert_to_tensor([0, 1, 5, 6, 7]),
        ops.convert_to_tensor([2, 4, 3])
    ]
    values = [
        ops.convert_to_tensor([12, 23, 34, 45, 56]),
        ops.convert_to_tensor([1, 3, 2])
    ]
    self.assertAllEqual(
        data_flow_ops.parallel_dynamic_stitch(indices, values).eval(),
        [12, 23, 1, 2, 3, 34, 45, 56])
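# A minimal reference sketch (not part of the original tests) of what
# parallel_dynamic_stitch computes: values[i][j] is scattered to position
# indices[i][j] of the output, making the op the inverse of dynamic_partition.
# The pure-Python helper and its example call are illustrative only.
def _stitch_reference(indices, values):
  size = max(max(idx) for idx in indices) + 1
  out = [None] * size
  for idx, val in zip(indices, values):
    for i, v in zip(idx, val):
      out[i] = v
  return out

# _stitch_reference([[0, 1, 5, 6, 7], [2, 4, 3]],
#                   [[12, 23, 34, 45, 56], [1, 3, 2]])
# == [12, 23, 1, 2, 3, 34, 45, 56]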
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is done
  through the `transform_fn` argument. If provided, the function is applied
  to each partitioned tensor of retrieved embeddings, colocated with the
  embeddings. This function will be called with a single `Tensor` argument
  of the same type as the `params` tensor and should return a `Tensor`. The
  shape of the argument will be the same as `params` except for the size of
  the first dimension. The first dimension of the result's shape must be the
  same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.

  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(array_ops.gather(params[0], ids, name=name),
                       ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      return result
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = array_ops.gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret
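# A minimal sketch (assumed example values, not from the library) of the "div"
# partition arithmetic used above. With num_total_ids = 10 and np = 3
# partitions, ids_per_partition = 3 and extras = 1, so the first `extras`
# partitions receive ids_per_partition + 1 consecutive ids each and the rest
# receive ids_per_partition ids. The pure-Python helper name is hypothetical.
def _div_assignment(flat_id, num_total_ids, np):
  ids_per_partition = num_total_ids // np
  extras = num_total_ids % np
  # Same formula as the math_ops.maximum expression in the helper.
  p = max(flat_id // (ids_per_partition + 1),
          (flat_id - extras) // ids_per_partition)
  # Same conditional emulated by array_ops.where in the helper.
  if p < extras:
    new_id = flat_id % (ids_per_partition + 1)
  else:
    new_id = (flat_id - extras) % ids_per_partition
  return p, new_id

# ids 0..9 -> partitions [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
# e.g. _div_assignment(4, 10, 3) == (1, 0): id 4 is row 0 of partition 1.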
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is done
  through the `transform_fn` argument. If provided, the function is applied
  to each partitioned tensor of retrieved embeddings, colocated with the
  embeddings. This function will be called with a single `Tensor` argument
  of the same type as the `params` tensor and should return a `Tensor`. The
  shape of the argument will be the same as `params` except for the size of
  the first dimension. The first dimension of the result's shape must be the
  same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: An optional function to apply to each retrieved embedding.
      If max_norm is provided, transform_fn is applied to the norm-limited
      embeddings.

  Returns:
    See embedding_lookup for details.

  Raises:
    ValueError: If `params` is empty.
  """
  if params is None or params in ((), []):
    raise ValueError("Need at least one param")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]

  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    ids = ops.convert_to_tensor(ids, name="ids")
    if np == 1 and (not transform_fn or ids.get_shape().ndims == 1):
      with ops.colocate_with(params[0]):
        result = _clip(array_ops.gather(params[0], ids, name=name),
                       ids, max_norm)
        if transform_fn:
          result = transform_fn(result)
      # Make sure the final result does not have colocation constraints on the
      # params. Similar to the case np > 1 where parallel_dynamic_stitch is
      # outside the scope of all with ops.colocate_with(params[p]).
      return array_ops.identity(result)
    else:
      # Flatten the ids. There are two cases where we need to do this.
      # - There is more than one params tensor.
      # - There is a transform_fn and ids is not statically known to be 1-D.
      #   We must flatten in this case because transform_fn expects a flat
      #   tensor of embeddings.
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = tensor_shape.Dimension(
            tensor_shape.dimension_value(params[0].get_shape()[0]))
        for p in xrange(1, np):
          dim_0_size += tensor_shape.Dimension(
              tensor_shape.dimension_value(params[p].get_shape()[0]))
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value,
                                               flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            param_p_dim = tensor_shape.dimension_value(
                params[p].get_shape()[0])
            if param_p_dim is not None:
              dim_0_sizes.append(param_p_dim)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1),
                                         (flat_ids - extras) //
                                         ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = array_ops.gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(
          ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret
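# A minimal usage sketch (assumed constants and values, not from the library's
# tests) of the private helper above: two shards looked up with the "div"
# strategy and a caller-supplied transform_fn applied per shard, colocated
# with each shard's embeddings. Assumes the module's own imports are in scope.
import tensorflow as tf

shard_0 = tf.constant([[1.0, 1.0], [2.0, 2.0]])  # rows 0-1 of the id space
shard_1 = tf.constant([[3.0, 3.0], [4.0, 4.0]])  # rows 2-3 of the id space
ids = tf.constant([0, 3, 2])

embedded = _embedding_lookup_and_transform(
    [shard_0, shard_1], ids,
    partition_strategy="div",
    transform_fn=lambda emb: emb * 2.0)  # applied to each partitioned result
# embedded -> [[2., 2.], [8., 8.], [6., 6.]]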
def testSimple(self):
  with self.session(use_gpu=False):
    indices = [ops.convert_to_tensor([0, 1]), ops.convert_to_tensor([2, 3])]
    values = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor([1, 1])]
    self.assertAllEqual(
        data_flow_ops.parallel_dynamic_stitch(indices, values), [2, 3, 1, 1])
def adaptive_embedding_lookup(params, ids, transforms,
                              max_norm=None, name=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of tensors
  in `params`. It is a generalization of `tf.gather`, where `params` is
  interpreted as a partitioning of a large embedding tensor in which the
  partitions may have different numbers of rows and columns. `params` may be
  a `PartitionedVariable` as returned by using `tf.compat.v1.get_variable()`
  with a partitioner.

  Each element `id` of `ids` is partitioned between the elements of `params`
  according to the `div` partition strategy. The id space must match the
  total number of rows in `params`.

  The results of the lookup are transformed with the corresponding functions
  from `transforms` and concatenated into a dense tensor. The shape of the
  returned tensor is determined by `transforms`.

  Args:
    params: A list of P tensors of different shape, representing sharded
      embedding tensors. Alternatively, a `PartitionedVariable`, created by
      partitioning along dimension 0.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be
      looked up in `params`.
    transforms: Functions applied to each retrieved embedding before
      concatenation. Required due to the possibly different embedding sizes
      of `params`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.
  """
  if isinstance(ids, tf.RaggedTensor):
    return tf.ragged.map_flat_values(adaptive_embedding_lookup, params, ids,
                                     transforms, max_norm, name)

  if not isinstance(params, (list, tuple)) or len(params) < 2:
    raise ValueError('At least 2 variables required in params')

  if (not isinstance(transforms, (list, tuple)) or
      len(transforms) != len(params)):
    raise ValueError('Each param should have corresponding transform')

  if not all([callable(t) for t in transforms]):
    raise ValueError('Each transform should be callable')

  with tf.name_scope(name or 'adaptive_embedding_lookup') as name:
    np = len(params)  # Number of partitions

    # Preserve the resource variable status to avoid accidental dense reads.
    if not any(
        isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
      params = ops.convert_n_to_tensor_or_indexed_slices(params, name='params')
    ids = tf.convert_to_tensor(ids, name='ids')

    # Flatten the ids. There is more than one params tensor.
    flat_ids = tf.reshape(ids, [-1])
    original_indices = tf.range(tf.size(flat_ids))

    # Create p_assignments and set new_ids for the adaptive strategy.
    # Compute total_ids_capacity as the sum of dim-0 of params, then assign to
    # partitions based on a variable number of ids per partition. Optimize
    # if we already know the full shape statically.
    dim_0_sizes = []
    for p in range(np):
      param_p_dim = tensor_shape.Dimension(
          tensor_shape.dimension_value(params[p].get_shape()[0]))
      dim_0_sizes.append(param_p_dim)
    dim_0_size_value = sum(dim_0_sizes).value

    if dim_0_size_value:
      dim_0_sizes = tf.TensorShape(dim_0_sizes).as_list()
      total_ids_capacity = tf.constant(dim_0_size_value, dtype=flat_ids.dtype)
    else:
      dim_0_sizes = []
      for p in range(np):
        param_p_dim = tensor_shape.dimension_value(params[p].get_shape()[0])
        if param_p_dim is not None:
          dim_0_sizes.append(param_p_dim)
        else:
          with ops.colocate_with(params[p]):
            dim_0_sizes.append(tf.shape(params[p])[0])
      dim_0_sizes = tf.stack(dim_0_sizes)
      total_ids_capacity = tf.reduce_sum(dim_0_sizes)

    p_cumsum = tf.cumsum(tf.cast(dim_0_sizes, dtype=flat_ids.dtype))
    assert_max_id = tf.debugging.assert_less(
        tf.math.reduce_max(flat_ids), total_ids_capacity,
        'Invalid id. Maximum id should be less than the total number of '
        'params rows')
    with tf.control_dependencies([assert_max_id]):
      p_assignments = tf.searchsorted(p_cumsum, flat_ids, side='right')

    # Cast partition assignments to int32 for use in dynamic_partition.
    # There really should not be more than 2^32 partitions.
    p_assignments = tf.cast(p_assignments, tf.int32)

    # Partition list of ids based on assignments into np separate lists
    p_intervals = tf.concat(([0], p_cumsum), 0)
    new_ids = flat_ids - tf.gather(p_intervals, p_assignments)
    gather_ids = tf.dynamic_partition(new_ids, p_assignments, np)

    # Similarly, partition the original indices.
    p_indices = tf.dynamic_partition(original_indices, p_assignments, np)

    # Do np separate lookups, finding embeddings for plist[p] in params[p]
    partitioned_result = []
    for p in range(np):
      pids = gather_ids[p]
      transform_fn = transforms[p]
      with ops.colocate_with(params[p]):
        result = tf.gather(params[p], pids)
        result = embedding_ops._clip(transform_fn(result), pids, max_norm)
      partitioned_result.append(result)

    # Stitch these back together
    ret = data_flow_ops.parallel_dynamic_stitch(
        p_indices, partitioned_result, name=name)

    # Determine the static element shape.
    element_shape_s = ret.get_shape()[1:]

    # Compute the dynamic element shape.
    if element_shape_s.is_fully_defined():
      element_shape_d = element_shape_s
    else:
      element_shape_d = tf.shape(ret)[1:]

    # Reshape to reverse the flattening of ids.
    ret = tf.reshape(ret, tf.concat([tf.shape(ids), element_shape_d], 0))

    # Normally the reshape is sufficient, but setting shape explicitly
    # teaches shape inference that params[1:].get_shape() matters.
    ret.set_shape(ids.get_shape().concatenate(element_shape_s))

    return ret
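# A minimal usage sketch (assumed shapes and values, not from the original
# tests): two shards with different embedding widths, where `transforms`
# project each shard to a common output width so the stitched result is dense.
import tensorflow as tf

shard_0 = tf.random.uniform([8, 16])   # ids 0-7, 16-dim embeddings
shard_1 = tf.random.uniform([24, 4])   # ids 8-31, 4-dim embeddings

proj_1 = tf.random.uniform([4, 16])    # projects shard_1 up to 16 dims
transforms = [
    lambda e: e,                       # shard_0 already has the target width
    lambda e: tf.matmul(e, proj_1),
]

ids = tf.constant([[0, 9], [31, 7]])
out = adaptive_embedding_lookup([shard_0, shard_1], ids, transforms)
# out.shape == (2, 2, 16): the leading dims follow `ids`, the trailing dim
# is the common width produced by the transforms.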