def CpuEmbLookup(self, ids_map: Dict[str, tf.Tensor],
                 partition_strategy: str) -> Dict[str, tf.Tensor]:
  """CPU evaluation embedding lookup for dense tensors.

  Args:
    ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
      For sequence embeddings, -1 is used as a padding id. Non-sequence
      embeddings do not support padded ids.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    An activations dict of string -> float32 Tensor.
    For non-sequence embeddings: [batch, 1, embedding_dim]
    For sequence embeddings: [batch, max_sequence_length, embedding_dim]
  """
  rets = py_utils.NestedMap()
  if self.max_sequence_length > 0:
    # "Sequence embedding", no combiner case
    for k, ids in ids_map.items():
      rets[k] = self._SequenceEmbLookup(ids, partition_strategy)
  else:
    # Non-"Sequence embedding", combiner case
    for k, ids in ids_map.items():
      # Dense to sparse.
      dense_shape = tf.shape(ids, out_type=tf.int64)
      sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
      embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
      sparse_ids = tf.SparseTensor(
          indices=sample_indices,
          values=embedding_indices,
          dense_shape=dense_shape)
      # [?, embedding_dim]
      rets[k] = self._CombinerEmbLookup(sparse_ids, partition_strategy)
  return rets
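# A minimal, self-contained sketch (plain TF2 eager; the tensors here are
# illustrative, not part of the layer above) of the dense-to-sparse step:
# -1 entries are treated as padding and dropped from the SparseTensor, so the
# combiner only aggregates real ids.
import tensorflow as tf

ids = tf.constant([[3, 7, -1],
                   [5, -1, -1]])  # [batch=2, sequence=3]; -1 = padding id
dense_shape = tf.shape(ids, out_type=tf.int64)
sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
sparse_ids = tf.SparseTensor(
    indices=sample_indices, values=embedding_indices, dense_shape=dense_shape)
# sparse_ids.indices -> [[0, 0], [0, 1], [1, 0]]
# sparse_ids.values  -> [3, 7, 5]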
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the
  monolithic TPUEmbeddingLayer.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  enqueue_ops = []

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if not tpu_embedding:
    return

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(embedding_indices,
                                                       sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
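# A hedged sketch of the per-core enqueue split above, using plain TF2
# stand-ins (the feature tensor and num_cores_per_host=2 are assumptions):
# each host's feature batch is split evenly across its cores, and each split
# is converted to the (values, indices) pair that EnqueueData expects.
import tensorflow as tf

num_cores_per_host = 2
feat = tf.constant([[3, -1], [7, 8], [5, -1], [-1, -1]])  # [batch=4, seq=2]
for core, split in enumerate(tf.split(feat, num_cores_per_host)):
  sample_indices = tf.where(tf.not_equal(split, -1))       # real-id positions
  embedding_indices = tf.gather_nd(split, sample_indices)  # the real ids
  # On a TPU host these two tensors would be wrapped as
  # tpu_embedding_lib.EnqueueData(embedding_indices, sample_indices).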
def CpuEmbLookup(self, ids_map, partition_strategy):
  """CPU evaluation embedding lookup.

  Args:
    ids_map: A dict of `input_key` string -> [batch, sequence] int32 Tensor.
      -1 is used as a padding id.
    partition_strategy: See TPUEmbeddingLayer partition_strategy param.

  Returns:
    An activations dict of string -> float32 Tensor.
    For non-sequence embeddings: [batch, 1, embedding_dim]
    For sequence embeddings: [batch, max_sequence_length, embedding_dim]
  """
  p = self.params
  rets = py_utils.NestedMap()
  if self.max_sequence_length > 0:
    # "Sequence embedding", no combiner case
    for k, ids in ids_map.items():
      embs = tf.nn.embedding_lookup(
          self.theta.wm,
          tf.reshape(ids, [-1]),
          partition_strategy=partition_strategy)
      out_shape = tf.concat([tf.shape(ids), [p.embedding_dim]], 0)
      rets[k] = tf.reshape(embs, out_shape)
  else:
    # Non-"Sequence embedding", combiner case
    for k, ids in ids_map.items():
      # Dense to sparse.
      dense_shape = tf.shape(ids, out_type=tf.int64)
      sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
      embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
      sparse_ids = tf.SparseTensor(
          indices=sample_indices,
          values=embedding_indices,
          dense_shape=dense_shape)
      # [?, embedding_dim]
      # For tf.nn.embedding_lookup_sparse, output.dim0 might be different
      # from sparse_ids.dense_shape.dim0. In fact, the '?' is the smallest
      # span starting from index=0 that covers all the results.
      embs = tf.nn.embedding_lookup_sparse(
          self.theta.wm,
          sparse_ids,
          None,  # sp_weights
          combiner=p.combiner,
          partition_strategy=partition_strategy)
      batch_size = dense_shape[0]
      # Explicitly pad results to maintain dim0=batch.
      dim0_padlen = tf.cast(batch_size, tf.int32) - tf.shape(embs)[0]
      embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])
      # [batch, 1, embedding_dim]
      embs = py_utils.HasShape(embs, [batch_size], ndims=1)
      rets[k] = tf.expand_dims(embs, 1)
  return rets
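# A small sketch (TF2 eager; the weight matrix and ids are made up, and
# partition_strategy is omitted since it only exists in the tf.compat.v1
# lookup) of why the explicit tf.pad above is needed: when trailing rows of
# the batch are all padding, embedding_lookup_sparse returns fewer than
# `batch` rows.
import tensorflow as tf

wm = tf.random.normal([10, 4])         # [vocab=10, embedding_dim=4]
ids = tf.constant([[3, 7], [-1, -1]])  # row 1 is all padding
dense_shape = tf.shape(ids, out_type=tf.int64)
sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
values = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
sparse_ids = tf.SparseTensor(sample_indices, values, dense_shape)
embs = tf.nn.embedding_lookup_sparse(wm, sparse_ids, None, combiner='mean')
# embs.shape[0] == 1, not 2: the all-padding row is simply absent.
dim0_padlen = tf.cast(dense_shape[0], tf.int32) - tf.shape(embs)[0]
embs = tf.pad(embs, [[0, dim0_padlen], [0, 0]])  # back to [batch, emb_dim]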
def _Notvisible(x):
  a, b = tf.expand_dims(x, -1), tf.expand_dims(x, -2)
  return tf.cast(
      tf.math.logical_or(
          tf.not_equal(a, b),
          # also ignoring segment_id=0
          tf.math.logical_not(
              tf.math.logical_or(tf.cast(a, tf.bool), tf.cast(b, tf.bool)))),
      tf.float32)
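# A tiny worked example of the mask above (uses the definition just given and
# assumes tf is in scope): tokens attend only within their own packed segment,
# and segment_id 0 (padding) is never visible, not even to itself.
seg_id = tf.constant([[1, 1, 2, 0]])
mask = _Notvisible(seg_id)  # [1, 4, 4]
# mask[0] ->
# [[0, 0, 1, 1],
#  [0, 0, 1, 1],
#  [1, 1, 0, 1],
#  [1, 1, 1, 1]]   # 1.0 means NOT visible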
def _Notvisible(seg_id, seg_pos):
  a, b = tf.expand_dims(seg_id, -1), tf.expand_dims(seg_id, -2)
  return tf.cast(
      tf.math.logical_or(
          tf.less(tf.expand_dims(seg_pos, -1), tf.expand_dims(seg_pos, -2)),
          tf.math.logical_or(
              tf.not_equal(a, b),
              tf.math.logical_not(
                  tf.math.logical_or(
                      tf.cast(a, tf.bool), tf.cast(b, tf.bool))))),
      tf.float32)
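# The seg_pos variant adds causality on top of the segment check: within a
# segment, a token cannot see later positions. A worked example (definition
# above, tf in scope) with two packed segments of length 2 each:
seg_id = tf.constant([[1, 1, 2, 2]])
seg_pos = tf.constant([[0, 1, 0, 1]])
mask = _Notvisible(seg_id, seg_pos)  # [1, 4, 4]
# mask[0] ->
# [[0, 1, 1, 1],
#  [0, 0, 1, 1],
#  [1, 1, 0, 1],
#  [1, 1, 0, 0]]   # 1.0 means NOT visible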
def _Lookup(ids):
  # Dense to sparse.
  dense_shape = tf.shape(ids, out_type=tf.int64)
  sample_indices = tf.cast(tf.where(tf.not_equal(ids, -1)), tf.int64)
  embedding_indices = tf.cast(tf.gather_nd(ids, sample_indices), tf.int64)
  sparse_ids = tf.SparseTensor(
      indices=sample_indices,
      values=embedding_indices,
      dense_shape=dense_shape)
  # [?, embedding_dim]
  return self._CombinerEmbLookup(sparse_ids, partition_strategy)
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(
          tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
          tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))
  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
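# A worked example of SequenceConcat (int32 tokens, float paddings; assumes
# the definition above plus lingvo's py_utils are importable):
x = tf.constant([[1, 2, 0]])
x_paddings = tf.constant([[0., 0., 1.]])  # x has length 2
y = tf.constant([[3, 4]])
y_paddings = tf.constant([[0., 0.]])      # y has length 2
xy, xy_paddings = SequenceConcat(x, x_paddings, y, y_paddings, pad=0)
# xy          -> [[1, 2, 3, 4, 0]]
# xy_paddings -> [[0., 0., 0., 0., 1.]]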
def TransformFeatures(self, features):
  # We assume that the lasers are not padded, and all points are real.
  if ('points_padding' in features.lasers and
      features.lasers.points_padding is not None):
    raise ValueError('FilterNLZPoints preprocessor does not support '
                     'padded lasers.')
  # The 3rd feature in the laser is 1.0 for points in a no-label-zone
  # and -1.0 for normal points.
  is_not_nlz = tf.not_equal(features.lasers.points_feature[:, 2], 1.0)
  features.lasers.points_xyz = tf.boolean_mask(features.lasers.points_xyz,
                                               is_not_nlz)
  features.lasers.points_feature = tf.boolean_mask(
      features.lasers.points_feature, is_not_nlz)
  return features
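# A tiny sketch of the masking step (made-up point tensors, TF2 eager):
# points whose third feature channel equals 1.0 fall in a no-label-zone
# and are dropped; all other points are kept.
import tensorflow as tf

points_feature = tf.constant([[0.5, 0.2, -1.0],
                              [0.9, 0.1,  1.0]])  # second point is in an NLZ
is_not_nlz = tf.not_equal(points_feature[:, 2], 1.0)
kept = tf.boolean_mask(points_feature, is_not_nlz)  # keeps only the first row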
def _ShouldMerge(unused_tokens, candidates):
  """Merge until not possible, or we abort early according to merge_prob."""
  return tf.logical_and(
      tf.reduce_any(tf.not_equal(candidates, NO_TOKEN)),
      tf.random.uniform([]) < self._merge_prob)
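# A hedged sketch of how such a predicate is typically consumed: as the
# `cond` of a tf.while_loop. The body name `_MergeOnePair` is hypothetical;
# the real merge body is not part of this excerpt.
tokens, candidates = tf.while_loop(
    cond=_ShouldMerge, body=_MergeOnePair, loop_vars=[tokens, candidates])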
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'
      .format(cluster.num_splits_per_client, cluster.num_devices_per_split,
              num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()

          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)
        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops

    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  self._tpu_infeed_op = tpu_infeed_op
  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
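# Worked numbers for the per-host sharding above (the topology is assumed,
# for illustration only):
total_shards = 32     # tpu_function.get_tpu_context().number_of_shards
num_infeed_hosts = 4  # equals num_tpu_hosts when use_per_host_infeed=True
shards = total_shards // num_infeed_hosts
assert shards == 8    # each host's InfeedQueue serves 8 of the 32 shards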
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:

          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops

    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  # For executor-driven multiple programs, we need more fine-grained
  # access rather than using a single global graph collection.
  self.tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def ComputeLoss(self, theta, activation, labels, segment_ids):
  p = self.params
  activation = self._MaybeSplit(
      self._MaybeSplit(activation) * (p.embedding_dim**-0.5))
  softmax_weights = theta.embedding
  if activation.dtype != softmax_weights.dtype:
    softmax_weights = tf.cast(softmax_weights, activation.dtype)
  logits = self._MaybeSplit(
      tf.einsum('BLM,VM->BLV', activation, softmax_weights))
  if p.logits_abs_max is not None:
    logits = self._MaybeSplit(
        py_utils.clip_by_value(logits, -p.logits_abs_max, p.logits_abs_max))

  off_value = p.label_smoothing / p.vocab_size
  on_value = 1.0 - p.label_smoothing + off_value
  soft_targets = self._MaybeSplit(
      tf.one_hot(labels, p.vocab_size, on_value=on_value,
                 off_value=off_value))
  xent = self._MaybeSplit(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=tf.one_hot(labels, p.vocab_size), logits=logits))
  loss = self._MaybeSplit(
      tf.nn.softmax_cross_entropy_with_logits(
          labels=soft_targets, logits=logits))
  soft_targets_xent = loss

  if p.z_loss_coef > 0.0:
    log_z = tf.math.reduce_logsumexp(logits, -1)
    z_loss_inc = p.z_loss_coef * tf.math.square(log_z)
    loss += z_loss_inc

  non_padding = self._MaybeSplit(
      tf.cast(tf.not_equal(segment_ids, 0), py_utils.FPropDtype(p)))
  per_token_loss = self._MaybeSplit(loss * non_padding)
  if p.z_loss_coef > 0.0:
    per_token_z_loss_inc = self._MaybeSplit(z_loss_inc * non_padding)

  if p.use_tgt_labels_size_as_loss_denominator:
    # E.g. loss is going to be tiny if inputs are not packed and only a
    # fraction of tgt_labels are non-padding.
    loss_denom = tf.reduce_sum(tf.ones_like(non_padding))
    per_example_loss_denom = tf.reduce_sum(tf.ones_like(non_padding), 1)
  else:
    loss_denom = tf.reduce_sum(non_padding)
    per_example_loss_denom = tf.reduce_sum(non_padding, 1)
  avg_loss = tf.reduce_sum(per_token_loss) / loss_denom
  avg_z_loss_inc = (tf.reduce_sum(per_token_z_loss_inc) /
                    loss_denom) if p.z_loss_coef > 0.0 else 0.0

  soft_targets_xent = (
      tf.reduce_sum(self._MaybeSplit(soft_targets_xent * non_padding)) /
      tf.reduce_sum(non_padding))

  # TODO(lepikhin): consider returning
  #   {'loss': (unnormalized per_token_loss, tf.reduce_sum(non_padding))}
  per_example_loss = {
      'loss': tf.reduce_sum(per_token_loss, 1) / per_example_loss_denom
  }
  return {
      'mean_xent': (tf.reduce_sum(self._MaybeSplit(xent * non_padding)) /
                    tf.reduce_sum(non_padding), tf.reduce_sum(non_padding)),
      'soft_targets_xent': (soft_targets_xent, tf.reduce_sum(non_padding)),
      'weight': (tf.reduce_sum(non_padding), 1.0),
      'loss': (avg_loss, 1.0),
      'avg_z_loss_inc': (avg_z_loss_inc, 1.0),
  }, per_example_loss
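# A worked instance of the label-smoothing targets above (TF2 eager;
# vocab_size=4 and label_smoothing=0.1 are made-up numbers):
# off_value = 0.1 / 4 = 0.025 and on_value = 1 - 0.1 + 0.025 = 0.925, so each
# row of soft_targets still sums to 1. The optional z-loss term penalizes
# log_z = logsumexp(logits), nudging the softmax normalizer toward 1.
import tensorflow as tf

vocab_size, label_smoothing = 4, 0.1
off_value = label_smoothing / vocab_size
on_value = 1.0 - label_smoothing + off_value
soft_targets = tf.one_hot(
    tf.constant([2]), vocab_size, on_value=on_value, off_value=off_value)
# soft_targets -> [[0.025, 0.025, 0.925, 0.025]]
# row sum = 0.925 + 3 * 0.025 = 1.0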