def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the
  monolithic TPUEmbeddingLayer.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  enqueue_ops = []

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if not tpu_embedding:
    return

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(embedding_indices,
                                                       sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def FProp(self, theta, inputs, paddings):
  """Builds FProp graph.

  Args:
    theta: A NestedMap of Tensors, see base class.
    inputs: A Tensor of shape [batch, seqlen, dim0].
    paddings: A Tensor of shape [batch, seqlen].

  Returns:
    output: A Tensor of shape [batch, seqlen, dim0].
    out_paddings: A Tensor of shape [batch, seqlen].
  """
  p = self.params
  with tf.name_scope(p.name):
    unnormalized_inputs = inputs

    inputs = self.ln.FProp(theta.ln, inputs)
    if p.split_act_gated_linear_start:
      act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
      gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
                                                   inputs)
    else:
      inputs = self.linear_start.FProp(theta.linear_start, inputs)
      gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
    inputs = self._GLU(gated_inputs, act_inputs)

    # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
    # [b, t, d] --> [b, t, 1, d]
    inputs = tf.expand_dims(inputs, 2)

    adapted_blf_dims_mapping = None
    if p.activation_split_dims_mapping.blf is not None:
      adapted_blf_dims_mapping = p.activation_split_dims_mapping.blf.copy()
      adapted_blf_dims_mapping.insert(2, -1)
    inputs = xla_sharding_utils.MeshSplit(inputs, p.device_mesh,
                                          adapted_blf_dims_mapping)
    theta.depthwise_conv1d.w = xla_sharding_utils.MeshSplit(
        theta.depthwise_conv1d.w, p.device_mesh,
        p.weight_split_dims_mapping.hwim)
    inputs, paddings = self.depthwise_conv1d.FProp(theta.depthwise_conv1d,
                                                   inputs, paddings)
    inputs = xla_sharding_utils.MeshSplit(inputs, p.device_mesh,
                                          adapted_blf_dims_mapping)
    inputs = self._Normalize(theta, inputs, paddings)
    inputs = xla_sharding_utils.MeshSplit(inputs, p.device_mesh,
                                          p.activation_split_dims_mapping.blf)

    inputs = self._ApplyActivation(inputs, p.conv_activation)

    inputs = self.linear_end.FProp(theta.linear_end, inputs)
    inputs = self.dropout.FProp(theta.dropout, inputs)

    output = inputs + unnormalized_inputs
    return output, paddings
def IsWithinBBox3D(points_3d, bboxes_3d):
  """Checks if points are within a 3-d bbox.

  Args:
    points_3d: [num_points, 3] float32 Tensor specifying points in 3-d space as
      [x, y, z] coordinates.
    bboxes_3d: [num_bboxes, 7] float32 Tensor specifying 3-d bboxes specified as
      [x, y, z, dx, dy, dz, phi] where x, y and z is the center of the box.

  Returns:
    boolean Tensor of shape [num_points, num_bboxes] indicating whether the
    points belong within each box.
  """
  points_3d = py_utils.HasRank(points_3d, 2)
  points_3d = py_utils.HasShape(points_3d, [-1, 3])
  num_points, _ = py_utils.GetShape(points_3d, 2)

  bboxes_3d = py_utils.HasRank(bboxes_3d, 2)
  bboxes_3d = py_utils.HasShape(bboxes_3d, [-1, 7])
  num_bboxes, _ = py_utils.GetShape(bboxes_3d, 2)

  # Compute the 3-D corners of the bounding boxes.
  bboxes_3d_b = tf.expand_dims(bboxes_3d, 0)
  bbox_corners = BBoxCorners(bboxes_3d_b)
  bbox_corners = py_utils.HasShape(bbox_corners, [1, -1, 8, 3])
  # First four points are the top of the bounding box.
  # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
  #   (x0, y1) <--- (x1, y1)
  #                    ^
  #                    |
  #                    |
  #   (x0, y0) ---> (x1, y0)
  bboxes_2d_corners = bbox_corners[0, :, 0:4, 0:2]
  bboxes_2d_corners = py_utils.HasShape(bboxes_2d_corners, [-1, 4, 2])
  # Determine if points lie within the 2-D (x, y) plane for all bounding boxes.
  points_2d = points_3d[:, :2]
  is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)
  is_inside_2d = py_utils.HasShape(is_inside_2d, [num_points, num_bboxes])

  # Determine if points lie within the z-dimension for all bounding boxes.
  [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)

  def _ComputeLimits(center, width):
    left = center - width / 2.0
    right = center + width / 2.0
    return left, right

  z0, z1 = _ComputeLimits(z, dz)
  z_points = tf.expand_dims(points_3d[:, 2], -1)

  is_inside_z = tf.math.logical_and(
      tf.less_equal(z_points, z1[tf.newaxis, :, 0]),
      tf.greater_equal(z_points, z0[tf.newaxis, :, 0]))
  is_inside_z = py_utils.HasShape(is_inside_z, [num_points, num_bboxes])

  return tf.math.logical_and(is_inside_z, is_inside_2d)
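# A minimal usage sketch (illustration only, not part of the original source;
# it assumes IsWithinBBox3D above and its lingvo helpers such as py_utils,
# BBoxCorners and IsWithinBBox are in scope):
def _ExampleIsWithinBBox3DUsage():
  # One 2x2x2 box centered at the origin with no rotation, and two query points.
  points_3d = tf.constant([[0., 0., 0.], [5., 5., 5.]], dtype=tf.float32)
  bboxes_3d = tf.constant([[0., 0., 0., 2., 2., 2., 0.]], dtype=tf.float32)
  # Expected result: [[True], [False]] -- only the origin lies inside the box.
  return IsWithinBBox3D(points_3d, bboxes_3d)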
def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask, dec_state,
                 t):
  del tgt_pos, tgt_segment_id
  [buf] = dec_state
  if tgt_id.shape == (self.batch_size, self.beam_size):
    buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
  else:
    div = int(tgt_id.shape[1] // self.beam_size)
    for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
      buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

  buf1 = tf.transpose(buf, [1, 0, 2])
  buf1 = tf.reshape(buf1, [self.batch_size, self.max_steps * self.beam_size])

  # Select next_tgt_id as a function of previous target tokens.
  if self.rule == '+1':
    next_tgt_id = (tgt_id + 1)
    next_tgt_id %= self.vocab_size
  elif self.rule == 'sum':
    # Sum over all previous tokens in tgt_mask.
    next_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(tgt_mask, tf.int32))
    next_tgt_id %= self.vocab_size
  elif self.rule == 'fib':
    # Select the last token according to tgt_mask.
    m = tgt_mask
    m *= tf.cast(
        tf.equal(tf.cumsum(m, -1),
                 tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
    last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
    next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

  # With a lower probability, add an extra +1 to the correct next_tgt_id.
  n = self.vocab_size
  logits = 5 * tf.one_hot(next_tgt_id % n, n)
  logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
  logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
  logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
  logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

  # Increase eos_score if the current tgt_id contains a 9.
  eos_id = 0
  tgt_id_contains_9 = tf.logical_or(
      tf.equal(tgt_id % 10, 9), tf.equal((tgt_id // 10) % 10, 9))
  logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(eos_id, self.vocab_size),
                          tf.cast(tgt_id_contains_9, tf.float32))

  # Tie-breaking -- a lower token id wins a little bit.
  tie = np.arange(0., 1., 1. / n)
  tie /= tie.sum()
  logits -= tie

  logits = tf.nn.log_softmax(logits)
  dec_state = [buf]
  return logits, dec_state
def testDecoderFPropSplitBatch(self, dtype=tf.float32):
  with self.session(use_gpu=True) as sess:
    tf.random.set_seed(_TF_RANDOM_SEED)
    p = self._DecoderParams(dtype=dtype)
    dec = decoder.TransformerDecoder(p)

    encoder_outputs, targets, _ = self._Inputs(dtype=dtype)
    src_enc1, src_enc2 = tf.split(encoder_outputs.encoded, 2, 1)
    src_paddings1, src_paddings2 = tf.split(encoder_outputs.padding, 2, 1)

    # source idx <-> target idx:
    # 0 <-> (0, 4), 1 <-> (1, 5), 2 <-> (2, 6), 3 <-> (3, 7)
    tgts = ig_helper.SplitDictOfTensors(targets, 4)
    targets1 = py_utils.NestedMap({
        'ids': tf.concat([tgts[0]['ids'], tgts[2]['ids']], 0),
        'labels': tf.concat([tgts[0]['labels'], tgts[2]['labels']], 0),
        'weights': tf.concat([tgts[0]['weights'], tgts[2]['weights']], 0),
        'paddings': tf.concat([tgts[0]['paddings'], tgts[2]['paddings']], 0)
    })
    targets2 = py_utils.NestedMap({
        'ids': tf.concat([tgts[1]['ids'], tgts[3]['ids']], 0),
        'labels': tf.concat([tgts[1]['labels'], tgts[3]['labels']], 0),
        'weights': tf.concat([tgts[1]['weights'], tgts[3]['weights']], 0),
        'paddings': tf.concat([tgts[1]['paddings'], tgts[3]['paddings']], 0)
    })

    loss, _ = dec.FPropDefaultTheta(encoder_outputs, targets).metrics['loss']
    encoder_outputs1 = py_utils.NestedMap(
        encoded=src_enc1, padding=src_paddings1, segment_id=None)
    loss1, _ = dec.FPropDefaultTheta(encoder_outputs1,
                                     targets1).metrics['loss']
    encoder_outputs2 = py_utils.NestedMap(
        encoded=src_enc2, padding=src_paddings2, segment_id=None)
    loss2, _ = dec.FPropDefaultTheta(encoder_outputs2,
                                     targets2).metrics['loss']

    tf.global_variables_initializer().run()
    actual_loss, actual_loss1, actual_loss2 = sess.run([loss, loss1, loss2])
    print('actual loss = ', actual_loss)
    print('actual loss1 = ', actual_loss1)
    print('actual loss2 = ', actual_loss2)
    self.assertAlmostEqual(
        actual_loss, np.mean([actual_loss1, actual_loss2]), delta=0.0001)
def ComputeNormalizedWER(self, hyps, refs, num_hyps_per_beam):
  # Filter out all '<epsilon>' tokens for norm_wer computation.
  hyps_no_epsilon = tf.strings.regex_replace(hyps, '(<epsilon>)+', ' ')
  # norm_wer is size [num_transcripts * hyps_per_beam, 2]
  norm_wer = decoder_utils.ComputeWer(hyps_no_epsilon, refs)
  # Split into two tensors of size [num_transcripts * hyps_per_beam, 1]
  norm_wer_errors, norm_wer_words = tf.split(norm_wer, [1, 1], 1)
  shape = [-1, num_hyps_per_beam]
  norm_wer_errors = tf.reshape(norm_wer_errors, shape)
  norm_wer_words = tf.reshape(norm_wer_words, shape)
  return norm_wer_errors, norm_wer_words
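# Illustration only (hypothetical values, not from the original source): the
# reshape at the end of ComputeNormalizedWER groups the flat per-hypothesis WER
# counts by beam. A plain-TF sketch of just that step:
def _ExampleNormWerReshape():
  num_hyps_per_beam = 2
  # Pretend ComputeWer returned [num_transcripts * num_hyps_per_beam, 2] rows
  # of (num_errors, num_reference_words).
  norm_wer = tf.constant([[0., 3.], [1., 3.], [2., 3.], [1., 3.]])
  errors, words = tf.split(norm_wer, [1, 1], 1)
  # Both become [num_transcripts, num_hyps_per_beam]; row i holds the counts
  # for all hypotheses of transcript i.
  errors = tf.reshape(errors, [-1, num_hyps_per_beam])
  words = tf.reshape(words, [-1, num_hyps_per_beam])
  return errors, words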
def testForwardPassSplitBatch(self):
  with self.session(use_gpu=False):
    bs = 8
    sl = 20
    tf.random.set_seed(8372749040)
    p = self._EncoderParams()
    p.random_seed = 1234
    mt_enc = encoder.TransformerEncoder(p)

    batch = py_utils.NestedMap()
    batch.ids = tf.constant(
        np.random.randint(low=0, high=63, size=[bs, sl], dtype=np.int32))
    batch.paddings = tf.zeros([bs, sl])
    out = mt_enc.FPropDefaultTheta(batch)
    enc_out = out.encoded
    emb_out = out.embedded_inputs

    inputs1, inputs2 = tf.split(batch.ids, 2, 0)
    paddings1, paddings2 = tf.split(batch.paddings, 2, 0)

    batch.ids = inputs1
    batch.paddings = paddings1
    out1 = mt_enc.FPropDefaultTheta(batch)
    enc_out1 = out1.encoded
    emb_out1 = out1.embedded_inputs

    batch.ids = inputs2
    batch.paddings = paddings2
    out2 = mt_enc.FPropDefaultTheta(batch)
    enc_out2 = out2.encoded
    emb_out2 = out2.embedded_inputs

    self.evaluate(tf.global_variables_initializer())
    actual_enc_out, actual_enc_out1, actual_enc_out2, \
        actual_emb_out, actual_emb_out1, actual_emb_out2 = self.evaluate(
            [enc_out, enc_out1, enc_out2, emb_out, emb_out1, emb_out2])
    self.assertAllClose(actual_enc_out,
                        np.concatenate([actual_enc_out1, actual_enc_out2], 1))
    self.assertAllClose(actual_emb_out,
                        np.concatenate([actual_emb_out1, actual_emb_out2], 1))
def StreamStep(self, theta, inputs, paddings, state0):
  """Runs a single step.

  Args:
    theta: A NestedMap of layer params.
    inputs: [b, 1, d].
    paddings: A 0/1 valued tensor of shape [b, 1].
    state0: A NestedMap of tensors of the same struct as returned by
      zero_state().

  Returns:
    output: A Tensor of shape [b, 1, d].
    paddings: the same as the input paddings.
    state1: A NestedMap of tensors of the same struct as state0.
  """
  p = self.params
  assert p.is_causal

  state1 = py_utils.NestedMap()
  with tf.name_scope(f'{p.name}/StreamStep'):
    unnormalized_inputs = inputs

    inputs = self.ln.FProp(theta.ln, inputs)
    if p.split_act_gated_linear_start:
      act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
      gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
                                                   inputs)
    else:
      inputs = self.linear_start.FProp(theta.linear_start, inputs)
      gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
    inputs = self._GLU(gated_inputs, act_inputs)

    # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
    # TODO(jamesqin): optimize DepthwiseConv1D.StreamStep()
    # [b, t, d] --> [b, t, 1, d]
    inputs = tf.expand_dims(inputs, 2)
    # [b, t, 1, d]
    inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
        theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
    state1.conv_state = conv_state1
    # [b, t, d]
    inputs = self._NormalizeStep(theta, inputs, paddings, state0, state1)

    inputs = self._ApplyActivation(inputs, p.conv_activation)

    inputs = self.linear_end.FProp(theta.linear_end, inputs)
    inputs = self.dropout.FProp(theta.dropout, inputs)

    output = inputs + unnormalized_inputs
    return output, paddings, state1
def _InputBatch(self):
  np.random.seed(1)
  bs, sl = 10, 7
  src_ids = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_ids = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_labels = tf.constant(
      np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
  tgt_weights = tf.constant(np.ones(shape=[bs, sl], dtype=np.float32))

  src_paddings = tf.zeros([bs, sl])
  tgt_paddings = tf.zeros([bs, sl])

  ret = py_utils.NestedMap()
  ret.src = py_utils.NestedMap()
  ret.tgt = py_utils.NestedMap()

  if self.params.split:
    src_ids = tf.split(src_ids, 2, 0)
    src_paddings = tf.split(src_paddings, 2, 0)
    tgt_ids = tf.split(tgt_ids, 2, 0)
    tgt_labels = tf.split(tgt_labels, 2, 0)
    tgt_paddings = tf.split(tgt_paddings, 2, 0)
    tgt_weights = tf.split(tgt_weights, 2, 0)

    ret.src.ids = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: src_ids[0], lambda: src_ids[1])
    ret.src.paddings = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: src_paddings[0], lambda: src_paddings[1])
    ret.tgt.ids = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_ids[0], lambda: tgt_ids[1])
    ret.tgt.labels = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_labels[0], lambda: tgt_labels[1])
    ret.tgt.paddings = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_paddings[0], lambda: tgt_paddings[1])
    ret.tgt.weights = tf.cond(
        tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
        lambda: tgt_weights[0], lambda: tgt_weights[1])
  else:
    ret.src.ids = src_ids
    ret.src.paddings = src_paddings
    ret.tgt.ids = tgt_ids
    ret.tgt.labels = tgt_labels
    ret.tgt.paddings = tgt_paddings
    ret.tgt.weights = tgt_weights

  return ret
def partition_tensor(cls, tensor, partition_info):
  """Returns partitioned tensors."""
  metadata = TensorPartitioner.partition_metadata(tensor, partition_info)
  # Split from last to first axis.
  partitioned_tensors = [tensor]
  rank = len(metadata.num_splits_per_dim)
  for raxis, (num_splits, sizes) in enumerate(
      zip(reversed(metadata.num_splits_per_dim),
          reversed(metadata.split_sizes_per_dim))):
    if num_splits > 1:
      tmp_partitioned_tensors = []
      for item in partitioned_tensors:
        tmp_partitioned_tensors += tf.split(item, sizes, axis=rank - raxis - 1)
      partitioned_tensors = tmp_partitioned_tensors
  return partitioned_tensors
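# Illustration only (hypothetical helper, not part of the original code): the
# loop above splits the tensor one axis at a time, starting from the last axis,
# and re-splits every partition produced so far. The same idea with plain
# tf.split and explicit per-axis split counts:
def _ExamplePartitionTensor():
  tensor = tf.reshape(tf.range(24), [4, 6])
  num_splits_per_dim = [2, 3]  # 2 splits on axis 0, 3 splits on axis 1.
  partitioned = [tensor]
  rank = len(num_splits_per_dim)
  for raxis, num_splits in enumerate(reversed(num_splits_per_dim)):
    if num_splits > 1:
      next_parts = []
      for item in partitioned:
        next_parts += tf.split(item, num_splits, axis=rank - raxis - 1)
      partitioned = next_parts
  # 6 partitions of shape [2, 2] each.
  return partitioned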
def FProp(self, theta, inputs, paddings):
  """Builds FProp graph.

  Args:
    theta: A NestedMap of Tensors, see base class.
    inputs: A Tensor of shape [batch, seqlen, dim0].
    paddings: A Tensor of shape [batch, seqlen].

  Returns:
    output: A Tensor of shape [batch, seqlen, dim0].
    out_paddings: A Tensor of shape [batch, seqlen].
  """
  p = self.params
  with tf.name_scope(p.name):
    unnormalized_inputs = inputs

    inputs = self.ln.FProp(theta.ln, inputs)
    if p.split_act_gated_linear_start:
      act_inputs = self.linear_start_act.FProp(theta.linear_start_act, inputs)
      gated_inputs = self.linear_start_gated.FProp(theta.linear_start_gated,
                                                   inputs)
    else:
      inputs = self.linear_start.FProp(theta.linear_start, inputs)
      gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
    inputs = self._GLU(gated_inputs, act_inputs)

    # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
    # [b, t, d] --> [b, t, 1, d]
    inputs = tf.expand_dims(inputs, 2)
    theta.depthwise_conv1d.w = moe_layers.Split(theta.depthwise_conv1d.w, 2,
                                                p.xla_num_partitions)
    inputs, paddings = self.depthwise_conv1d.FProp(theta.depthwise_conv1d,
                                                   inputs, paddings)
    inputs = self._Normalize(theta, inputs, paddings)

    inputs = self._ApplyActivation(inputs, p.conv_activation)

    inputs = self.linear_end.FProp(theta.linear_end, inputs)
    inputs = self.dropout.FProp(theta.dropout, inputs)

    output = inputs + unnormalized_inputs
    return output, paddings
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimension.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors, num elements in the tuple = len(xs).

    i-th element in each list corresponds to the i-th split of each tensor in
    xs along the first dimension of each tensor.
  """
  # Assert that the first dim of all tensors in xs is equal.
  batch_dims = [tf.shape(x)[0] for x in xs]
  all_batch_dims = tf.stack(batch_dims)
  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          tf.shape(xs[0])[0],
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          tf.shape(xs[0])[0],
          num_splits,
          message='first dim of tensors in xs must be greater than num_splits')
  ], all_batch_dims)

  splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
  # Add the above assertions into the compute graph.
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  # `splits` holds the per-split sizes (not the number of splits), so it is
  # passed as num_or_size_splits.
  split_xs = [
      tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs
  ]
  return split_xs
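# A minimal usage sketch (illustration only; it assumes SplitTensors above and
# its lingvo helpers ComputeSplits and py_utils are in scope). With a batch of
# 5 and num_splits=2, the per-split size list lets tf.split handle batch sizes
# that are not divisible by the number of splits.
def _ExampleSplitTensors():
  xs = (tf.zeros([5, 3]), tf.ones([5, 7]))
  split_xs = SplitTensors(xs, 2)
  # split_xs[0] holds two pieces of the first tensor with batch sizes 3 and 2
  # (or 2 and 3, depending on how ComputeSplits rounds); split_xs[1] holds the
  # matching pieces of the second tensor.
  return split_xs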
def _DistortBrightnessAndColor(image):
  """Distorts brightness and color of the input image.

  Args:
    image: 3-D Tensor containing a single image in [0, 1].

  Returns:
    3-D Tensor color-distorted image in range [0, 1].
  """
  br_delta = tf.random.uniform([], -32. / 255., 32. / 255.)
  cb_factor = tf.random.uniform([], -0.1, 0.1)
  cr_factor = tf.random.uniform([], -0.1, 0.1)

  channels = tf.split(axis=2, num_or_size_splits=3, value=image)
  red_offset = 1.402 * cr_factor + br_delta
  green_offset = -0.344136 * cb_factor - 0.714136 * cr_factor + br_delta
  blue_offset = 1.772 * cb_factor + br_delta
  channels[0] += red_offset
  channels[1] += green_offset
  channels[2] += blue_offset

  return tf.clip_by_value(tf.concat(channels, axis=2), 0., 1.)
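# A minimal usage sketch (illustration only): apply the distortion to a random
# HWC image in [0, 1]; the output stays in [0, 1] because of the final
# clip_by_value.
def _ExampleDistortBrightnessAndColor():
  image = tf.random.uniform([32, 32, 3], minval=0., maxval=1.)
  distorted = _DistortBrightnessAndColor(image)
  return distorted  # Shape [32, 32, 3], values in [0, 1].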
def _GLUFn(inputs):
  gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
  return act_inputs * tf.sigmoid(gated_inputs)
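# Illustration only: _GLUFn is a gated linear unit. It splits the last
# dimension in half and uses the first half (through a sigmoid) to gate the
# second half, so the output's last dimension is half of the input's.
def _ExampleGLUFn():
  inputs = tf.random.normal([2, 5, 8])
  outputs = _GLUFn(inputs)
  return outputs  # Shape [2, 5, 4].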
def CreateTpuFeeds(self): """Creates the TPU infeed queue from preprocessed batch.""" p = self.params cluster = self.cluster num_tpu_hosts = cluster.num_tpu_hosts num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts tf.logging.info( 'CreateTPUFeeds num_splits_per_client={} ' 'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'. format(cluster.num_splits_per_client, cluster.num_devices_per_split, num_tpu_hosts, p.use_per_host_infeed)) assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts) if (cluster.num_devices_per_split > num_cores_per_host and p.use_per_host_infeed): tf.logging.fatal( 'Doesn\'t support per host infeed mode when ' 'num_devices_per_split({}) > num_cores_per_host({})'.format( cluster.num_devices_per_split, num_cores_per_host)) num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1 with py_utils.outside_all_rewrites(): assert py_utils.use_tpu() assert not self._made_tpu_infeed shards = tpu_function.get_tpu_context( ).number_of_shards // num_infeed_hosts tf.logging.info('shards {}'.format(shards)) input_ops_list = [] queues = [] tpu_embedding_collection = tf.get_collection( py_utils.TPU_EMBEDDING) tpu_embedding = (tpu_embedding_collection[0] if tpu_embedding_collection else None) if num_tpu_hosts > 1 and tpu_embedding is not None: if not p.use_per_host_infeed: tf.logging.fatal( 'TPU Embedding must be used with per_host_infeed with multiple ' 'TPU host topologies.') tpu_emb_input_keys = (list( tpu_embedding.feature_to_config_dict.keys()) if tpu_embedding is not None else []) tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys) batch = None for task_id in range(num_infeed_hosts): host_device = '/task:{}/device:CPU:0'.format(task_id) with tf.device(host_device): batch = self.GetPreprocessedInputBatch() if isinstance(batch, py_utils.NestedMap): # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU. # Note that when MultiTaskData is used, bucket_keys will be at the # second level of the dictionary. batch = batch.FilterKeyVal( lambda k, _: not k.endswith('bucket_keys')) tf.logging.info('host_device: %s, batch: %r', host_device, batch) if tpu_embedding is not None: enqueue_dict_per_core = [ {} for _ in range(tpu_embedding.num_cores_per_host) ] num_cores_per_host = tpu_embedding.num_cores_per_host for key in tpu_emb_input_keys: feat = batch[key] tpu_emb_feat_splitted = tf.split( feat, num_cores_per_host) for core, split in enumerate( tpu_emb_feat_splitted): # Dense to sparse. Note the assumption of a padding id. sample_indices = tf.where( tf.not_equal(split, -1)) embedding_indices = tf.gather_nd( split, sample_indices) enqueue_data = tpu_embedding_lib.EnqueueData( embedding_indices, sample_indices) enqueue_dict_per_core[core][key] = enqueue_data input_ops_list += tpu_embedding.generate_enqueue_ops( enqueue_dict_per_core) for k, x in batch.FlattenItems(): assert x.shape.is_fully_defined(), ( 'Shape must be fully defined: %s: %s' % (k, x)) # TODO(cwhipkey): if it's a string (or other type not supported on # TPU), drop it from feeding and on the other end add in an op that # fails if used. 
shapes = batch.Transform(lambda x: x.shape).Flatten() dtypes = batch.Transform(lambda x: x.dtype).Flatten() tf.logging.info('host_device: %s infeed shapes: %r', host_device, shapes) tf.logging.info('host_device: %s infeed dtypes: %r', host_device, dtypes) if p.use_partitioned_infeed_queue: device_assignment = py_utils.GetTpuDeviceAssignment() host_device = device_assignment.host_device( replica=0, job=tf.flags.FLAGS.tf_master) host_id = int( host_device.split('/task:')[1].split('/device:') [0]) tf.logging.info('host_id: {} host_device: {}'.format( host_id, host_device)) q = tpu_feed._PartitionedInfeedQueue( # pylint: disable=protected-access number_of_tuple_elements=len(dtypes), device_assignment=device_assignment, host_id=host_id, input_partition_dims=[[p.num_partitions, 1] for _ in dtypes], tuple_types=dtypes, tuple_shapes=shapes) else: q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes) assert shards is not None q.set_number_of_shards(shards) queues.append(q) tf.logging.info('q=%r', q) if p.use_partitioned_infeed_queue: input_ops = q.generate_enqueue_ops([batch.Flatten()]) elif p.use_per_host_infeed: # TODO(ylc/zhifengc): Add this to a policy module and test it. def TPUOrdinalFunction(shard_index_in_host): device_assignment = py_utils.GetTpuDeviceAssignment( ) if device_assignment: # We put both enqueue/dequeue ops at core 0 in each replica. replica = device_assignment.lookup_replicas( task_id, 0)[shard_index_in_host] # pylint: disable=cell-var-from-loop return device_assignment.tpu_ordinal( replica=replica) else: return shard_index_in_host input_ops = q.split_inputs_and_generate_enqueue_ops( batch.Flatten(), placement_function=lambda x: host_device, # pylint: disable=cell-var-from-loop tpu_ordinal_function=TPUOrdinalFunction) else: input_ops = q.split_inputs_and_generate_enqueue_ops( batch.Flatten(), device_assignment=py_utils.GetTpuDeviceAssignment( )) input_ops_list += input_ops tf.logging.info('input_ops_list %s', input_ops_list) tpu_infeed_op = tf.group(*input_ops_list) self._made_tpu_infeed = True # Let trainer.py use multiple threads to drive the infeed op. for _ in range(p.tpu_infeed_parallelism): tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op) self._tpu_infeed_op = tpu_infeed_op with tf.device(tf.tpu.core(0)): tensors = queues[0].generate_dequeue_op() return batch.Pack(tensors)
def FProp(self, theta, *args): """Run multiple cells in different devices in a pipelining manner. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. *args: Non-keyworded variable length argument list of input tensors. Returns: A list of output tensors """ # TODO(huangyp): handle optional None inputs. p = self.params if p.is_eval: outputs = _ToTuple(args) for (name, l) in self._before_layers: outputs = _ToTuple(outputs) outputs = l.FProp(theta[name], *outputs) for (name, l) in self._cells: outputs = _ToTuple(outputs) outputs = l.FProp(theta[name], *outputs) return outputs num_cells = len(p.cell_tpl) cluster = self.cluster # Compute shapes of input and output tenors. input_tenors = _ToTuple(args) mini_batch_size = input_tenors[0].get_shape().as_list()[p.batch_dim] if p.state_dtype: state_dtype = p.state_dtype else: state_dtype = input_tenors[0].dtype if p.num_micro_batches > mini_batch_size: p.num_micro_batches = mini_batch_size micro_batch_size = mini_batch_size // p.num_micro_batches input_shapes = () for input_tensor in input_tenors: if input_tensor is not None: input_shape = input_tensor.get_shape().as_list() input_shape[p.batch_dim] = micro_batch_size input_shapes += (tf.TensorShape(input_shape),) else: input_shapes += (None,) state_shapes = self._CalculateOutputShapes(input_shapes) def GetCellFn(i): """Get the ith feature extraction layer.""" def CellFn(theta, state0, inputs): """A cell fn is exectued inside of StackedRecurrent.""" del state0 frop_inputs = [] for input_idx in range(len(state_shapes[i])): name = 's{}'.format(input_idx) if state_shapes[i][input_idx] is not None: inputs[name].set_shape(state_shapes[i][input_idx]) frop_inputs.append(inputs[name]) else: frop_inputs.append(None) with CellFnFropOpReplacementWrapper(): tf.logging.info('cell {} input {}'.format(i, frop_inputs)) mb_tensor = inputs[_MICRO_BATCH_STATE_NAME] SetOverWriteGlobalStep(mb_tensor) _, cell = self._cells[i] outputs = cell.FProp(theta, *frop_inputs) state1 = py_utils.NestedMap() state1[_MICRO_BATCH_STATE_NAME] = mb_tensor outputs = _ToTuple(outputs) assert len(outputs) == len(state_shapes[i + 1]) for output_idx in range(len(outputs)): if outputs[output_idx] is not None: name = 's{}'.format(output_idx) state1[name] = outputs[output_idx] return state1, py_utils.NestedMap() return CellFn cell_fns = [] accumulator_layers = [] thetas = [] init_states = [] devices = [] for cell_idx in range(num_cells): cell_name, cell = self._cells[cell_idx] accumulator_layers.append(cell) cell_fns.append(GetCellFn(cell_idx)) thetas.append(theta[cell_name]) init_state = py_utils.NestedMap() init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype) for output_idx in range(len(state_shapes[cell_idx + 1])): name = 's{}'.format(output_idx) if state_shapes[cell_idx + 1][output_idx] is not None: init_state[name] = tf.zeros( state_shapes[cell_idx + 1][output_idx], dtype=state_dtype) init_states.append(init_state) devices.append(cluster.WorkerDeviceInModelSplit(cell_idx)) cell_grads = [None] * num_cells cell_outs = [lambda x: x] * num_cells cell_out_grads = [lambda x: x] * num_cells with tf.device(devices[0]): previous = input_tenors for (name, l) in self._before_layers: previous = l.FProp(theta[name], *previous) previous = _ToTuple(previous) inputs = py_utils.NestedMap() gs_tensor = py_utils.GetGlobalStep() inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([ tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype) for t in range(p.num_micro_batches) ]) # TODO(huangyp, dehao): 
apply dehao's trick to reshape the input tensor # to [p.num_micro_batches, -1, 128]. for output_idx, output_tenor in enumerate(previous): name = 's{}'.format(output_idx) if output_tenor is not None: output_tenor = tf.stack( tf.split(output_tenor, p.num_micro_batches, axis=p.batch_dim)) inputs[name] = output_tenor output, _ = recurrent.StackedRecurrent( devices=devices, cell_fns=cell_fns, cell_grads=cell_grads, cell_outs=cell_outs, cell_out_grads=cell_out_grads, thetas=thetas, init_states=init_states, inputs=inputs, accumulator_layers=accumulator_layers, unused_acc_state=True) with tf.device(devices[-1]): output_tensors = [] for output_idx in range(len(state_shapes[-1])): state_shape = state_shapes[-1][output_idx] if state_shape is None: output_tensors.append(None) continue output_name = 's{}'.format(output_idx) output_tensor = output[output_name] if p.batch_dim != 0: perm = list(range(1, p.batch_dim + 1)) + [0] perm += list(range(p.batch_dim + 1, len(state_shape) + 1)) output_tensor = tf.transpose(output_tensor, perm=perm) state_shape[p.batch_dim] *= p.num_micro_batches output_tensor = tf.reshape(output_tensor, state_shape) output_tensors.append(output_tensor) tf.logging.info('pipeline output = {}'.format(output_tensors)) if len(output_tensors) == 1: return output_tensors[0] return tuple(output_tensors)
def CreateTpuFeeds(self): """Creates the TPU infeed queue from preprocessed batch.""" p = self.params cluster = self.cluster num_tpu_hosts = cluster.num_tpu_hosts num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host)) tf.logging.info('num_devices_per_split {}'.format( cluster.num_devices_per_split)) assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts) if (cluster.num_devices_per_split > num_cores_per_host and p.use_per_host_infeed): tf.logging.fatal( 'Doesn\'t support per host infeed mode when ' 'num_devices_per_split({}) > num_cores_per_host({})'.format( cluster.num_devices_per_split, num_cores_per_host)) num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1 with py_utils.outside_all_rewrites(): assert py_utils.use_tpu() assert not self._made_tpu_infeed shards = tpu_function.get_tpu_context( ).number_of_shards // num_infeed_hosts input_ops_list = [] queues = [] tpu_embedding_collection = tf.get_collection( py_utils.TPU_EMBEDDING) tpu_embedding = (tpu_embedding_collection[0] if tpu_embedding_collection else None) tpu_emb_input_keys = (list( tpu_embedding.feature_to_config_dict.keys()) if tpu_embedding is not None else []) tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys) batch = None for task_id in range(num_infeed_hosts): host_device = '/task:{}/device:CPU:0'.format(task_id) with tf.device(host_device): batch = self.GetPreprocessedInputBatch() if 'bucket_keys' in batch: # Hack: bucket_keys are not needed on TPU. del batch['bucket_keys'] tf.logging.info('host_device: %s, batch: %r', host_device, batch) if tpu_embedding is not None: enqueue_dict_per_core = [ {} for _ in range(tpu_embedding.num_cores_per_host) ] num_cores_per_host = tpu_embedding.num_cores_per_host for key in tpu_emb_input_keys: feat = batch[key] tpu_emb_feat_splitted = tf.split( feat, num_cores_per_host) for core, split in enumerate( tpu_emb_feat_splitted): # Dense to sparse. Note the assumption of a padding id. sample_indices = tf.where( tf.not_equal(split, -1)) embedding_indices = tf.gather_nd( split, sample_indices) enqueue_data = tpu_embedding_lib.EnqueueData( embedding_indices, sample_indices) enqueue_dict_per_core[core][key] = enqueue_data input_ops_list += tpu_embedding.generate_enqueue_ops( enqueue_dict_per_core) for k, x in batch.FlattenItems(): assert x.shape.is_fully_defined(), ( 'Shape must be fully defined: %s: %s' % (k, x)) # TODO(cwhipkey): if it's a string (or other type not supported on # TPU), drop it from feeding and on the other end add in an op that # fails if used. shapes = batch.Transform(lambda x: x.shape).Flatten() dtypes = batch.Transform(lambda x: x.dtype).Flatten() tf.logging.info('host_device: %s infeed shapes: %r', host_device, shapes) tf.logging.info('host_device: %s infeed dtypes: %r', host_device, dtypes) q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes) queues.append(q) assert shards is not None q.set_number_of_shards(shards) if p.use_per_host_infeed: # TODO(ylc/zhifengc): Add this to a policy module and test it. def TPUOrdinalFunction(shard_index_in_host): device_assignment = py_utils.GetTpuDeviceAssignment( ) if device_assignment: # We put both enqueue/dequeue ops at core 0 in each replica. 
replica = device_assignment.lookup_replicas( task_id, 0)[shard_index_in_host] # pylint: disable=cell-var-from-loop return device_assignment.tpu_ordinal( replica=replica) else: return shard_index_in_host input_ops = q.split_inputs_and_generate_enqueue_ops( batch.Flatten(), placement_function=lambda x: host_device, # pylint: disable=cell-var-from-loop tpu_ordinal_function=TPUOrdinalFunction) else: input_ops = q.split_inputs_and_generate_enqueue_ops( batch.Flatten(), device_assignment=py_utils.GetTpuDeviceAssignment( )) input_ops_list += input_ops tf.logging.info('input_ops_list %s', input_ops_list) tpu_infeed_op = tf.group(*input_ops_list) self._made_tpu_infeed = True # Let trainer.py use multiple threads to drive the infeed op. for _ in range(p.tpu_infeed_parallelism): tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op) # For executor-driven multiple programs, we need more fine-grained # access rather than using a single global graph collection. self.tpu_infeed_op = tpu_infeed_op with tf.device(tf.tpu.core(0)): tensors = queues[0].generate_dequeue_op() return batch.Pack(tensors)
def flat_beam_search(batch_size, beam_size, max_steps, dec_callback, dec_state, bos_id=1, eos_id=2, length_norm_alpha=0.8, beam_gap=3.0, top_k_fn=tf.math.top_k, prefix=None, prefix_len=None, fprop_dtype=tf.float32, ext_size=0, nbest_size=None, debug=True): """Flat beam search. Args: batch_size: batch size beam_size: beam size limit in number of hyps max_steps: max steps dec_callback: decoder callback (see above) dec_state: decoder state bos_id: <s> token id eos_id: </s> token id length_norm_alpha: length normalization parameter beam_gap: early stopping threshold; None to disable top_k_fn: top_k function to call prefix: (optional) int32 tensor [batch_size, prefix_max] prefix_len: (optional) int32 tensor [batch_size] fprop_dtype: fprop dtype ext_size: int >= beam_size, extension buffer size nbest_size: number of returned hyps, default is beam_size debug: log intermediate vlaues with tpu_summary.tensor() Returns: (loop_vars, dec_state, nbest) where nbest = (topk_ids, topk_len, topk_score) """ assert beam_size > 0 assert batch_size > 0 assert max_steps > 0 buf_size = beam_size * max_steps output_len = max_steps if prefix is None: assert prefix_len is None prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32) prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id prefix_len = tf.ones([batch_size], dtype=tf.int32) else: assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape) assert int(prefix_len.shape[0]) == batch_size, (batch_size, prefix_len.shape) output_len += int(prefix.shape[1]) if debug: tpu_summary.tensor('prefix', prefix) tpu_summary.tensor('prefix_len', prefix_len) with tf.name_scope('init_state'): t = tf.constant(0) tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32) tgt_id += bos_id tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32) tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype) tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype) hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype) # penalize all hyps except the first hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype) nbest_size = nbest_size or beam_size nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype) nbest_score -= 1e9 nbest_score_norm = nbest_score nbest_mask = tf.zeros([batch_size, nbest_size, buf_size], dtype=fprop_dtype) with tf.name_scope('init_ext'): # Initialize the extension buffer. # # Extension buffer stores a (potentially large) set of 'extensions', # which consist of a hypothesis (represented by ext_mask) and next token # (represented by ext_id). At each decoder iteration, top_k extensions # from each hypothesis are added to the buffer and sorted by score. # # Then top beam_size extensions are removed from the buffer and used # in the next decoder iteration. And top 'ext_size' remaining extensions # are carried over to be possibly evaluated at a later step. # # As a result of this manipulation, the decoder is no longer restricted # to always compare hyps of the same token length at each iteration. # In particular, for a fixed length N it can generate more than beam_size # terminated hyps. # # Setting ext_size = 0 disables this feautre. 
if ext_size: ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32) ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype) ext_score -= 1e9 ext_mask = tf.zeros([batch_size, ext_size, buf_size], dtype=fprop_dtype) else: ext_size = ext_id = ext_score = ext_mask = 0 with tf.name_scope('init_prefix'): # rename prefix->pfx for shorter variables pfx = tf.cast(prefix, tf.int32) pfx_len = tf.cast(prefix_len, tf.int32) del prefix, prefix_len # Before the first call to dec_callback() the prefix shall be packed into # the tgt_id buffer as follows: # # [ P P P P P P - - - - - - P* - - - ] ^ # [ P P P P P P P P P P - - P* - - - ] | batch # [ P - - - - - - - - - - - P* - - - ] V # |<---- prefix len ----> |<-- beam --> # # The last meaningful token in the prefix (P*) # must be located at the same position in all batch rows. # # We then make one dec_callback() with full prefix (minus P*) # which will populate the initial dec_state # (for transformer -- self-attention key/value cache) # # The last block [batch, beam] then becomes the first tgt_id for the loop. pfx_max = int(pfx.shape[1]) pfx_mul = pfx_max // beam_size assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size) pfx_time = tf.range(pfx_max) pfx_pad = tf.cast( tf.less(tf.expand_dims(pfx_time, 0), tf.expand_dims(pfx_len - 1, 1)), tf.int32) pfx_id = pfx * pfx_pad pfx_last = einsum_i32( 'BT,BT->B', pfx, tf.one_hot(pfx_len - 1, pfx_max, dtype=fprop_dtype)) buf_time = tf.range(buf_size) pfx_time_mask = tf.cast( tf.less_equal(tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time, 1)), fprop_dtype) pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype), pfx_time_mask) pfx_segment_id = pfx_pad pfx_pos = pfx_time * pfx_pad if debug: tpu_summary.tensor('pfx_id', pfx_id) tpu_summary.tensor('pfx_len', pfx_len) tpu_summary.tensor('pfx_pos', pfx_pos) tpu_summary.tensor('pfx_last', pfx_last) # Now call decoder with prefix minus P*: # 'dec_state' now shall contain the key/value cache for prefix tokens # (for transformer models), and 'logits' we can either discard or # roll into the initial hyp_score. Discard is simpler. with tf.name_scope('prefix_fprop'): # TODO(krikun): remove extra type checks assert (pfx_id.dtype == tf.int32), (pfx_id.dtype) assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype) assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype) assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype) assert (t.dtype == tf.int32), (t.dtype) logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos, pfx_mask, dec_state, t) del logits # Now construct the initial state for the rest of the beam search loop. # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape # 'tgt_pos' is different for each batch row and is equal to prefix_len # 'tgt_segment_id' always 1 (no packing) # 'hyp_score' is 0 for beam=0 and negative for beam>=1 tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims( pfx_last, 1) tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims( (pfx_len - 1), 1) hyp_score = tf.zeros( [batch_size, beam_size], dtype=fprop_dtype) - tf.cast( tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype) # TODO(krikun) Here we make initial 't' constant and determined by the # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic # as t ~ max(pfx_len) / beam_size and this will more steps for beam search # however 'max' results in a very slow all-to-all for 'max' on 16x16 # and variable number of decoder steps may result in bad latency. 
t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32) # Initial tgt_mask is such that each token P* has attention on itself # (as usual) and on all prefix tokens before it, which are not padding. tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype) tgt_mask += tf.cast( tf.expand_dims( tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1), fprop_dtype) tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype) if debug: tpu_summary.tensor('tgt_id', tgt_id) tpu_summary.tensor('tgt_pos', tgt_pos) tpu_summary.tensor('tgt_mask', tgt_mask) tpu_summary.tensor('t', t) with tf.name_scope('init_hist'): # h_tgt_id is used to recover topk_ids from nbest_mask h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps) h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps) # When non-trivial prefix is present we also write prefix ids to # h_tgt_id so that the full sequence including prefix can be recovered # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0] # and the loop below becomes a no-op. # TODO(krikun): maybe a tf.while_loop is more appropriate here. for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)): h_tgt_id = h_tgt_id.write(i, x_i) for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)): h_tgt_pos = h_tgt_pos.write(i, x_i) hist = (h_tgt_id, h_tgt_pos) tf.logging.info('hist=%r', hist) nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm) tf.logging.info('nbest_hyps=%r', nbest_hyps) ext = (ext_id, ext_score, ext_mask) tf.logging.info('ext=%r', ext) loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) tf.logging.info('loop_vars=%r', loop_vars) def loop_step(loop_vars, dec_state): # pylint: disable=missing-docstring tf.logging.info('loop_vars=%r', loop_vars) tf.logging.info('dec_state=%r', dec_state) (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars (ext_id, ext_score, ext_mask) = ext (h_tgt_id, h_tgt_pos) = hist h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id') h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos') # not using tf.ones() here because of XLA compilation error tgt_segment_id = tgt_id * 0 + 1 logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t) # take predicted EOS score for each hyp and compute normalized score eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype) def length_norm(t): t = tf.cast(t, fprop_dtype) alpha = length_norm_alpha tf.logging.info('length_norm.alpha=%r', alpha) return tf.math.pow((t + 5.) 
/ 5., alpha) hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1) eos_score_norm = eos_score / length_norm(hyp_len) # update the n-best list nbest_hyps = update_nbest(nbest_hyps, (tgt_mask, hyp_score, eos_score_norm)) if debug: tpu_summary.tensor('eos_score', eos_score) tpu_summary.tensor('hyp_len', hyp_len) # take top k tokens for each hyp k = beam_size with tf.name_scope('topk1'): top_score, top_id = top_k_fn(logits, k) top_score = tf.cast(top_score, fprop_dtype) top_score += tf.expand_dims(hyp_score, -1) top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype) top_score = tf.reshape(top_score, [batch_size, beam_size * k]) top_id = tf.reshape(top_id, [batch_size, beam_size * k]) top_mask = tf.repeat(tgt_mask, beam_size, 1) if debug: tpu_summary.tensor('top_id', top_id) tpu_summary.tensor('top_score', top_score) # tpu_summary.tensor('top_mask', top_mask) with tf.name_scope('update_ext'): # combine top k tokens with extension buffer (if any) if ext_size: ext_id = tf.concat([ext_id, top_id], 1) ext_score = tf.concat([ext_score, top_score], 1) ext_mask = tf.concat([ext_mask, top_mask], 1) else: ext_id, ext_score, ext_mask = top_id, top_score, top_mask # sort by score ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size) i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype) ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1) ext_id = einsum_i32('bk,bjk->bj', ext_id, i1) # pick top beam_size extensions to evaluate at next iteration if ext_size: hyp_score = ext_score[:, :beam_size] ext_score = ext_score[:, beam_size:] tgt_id = ext_id[:, :beam_size] ext_id = ext_id[:, beam_size:] tgt_mask = ext_mask[:, :beam_size] ext_mask = ext_mask[:, beam_size:] else: hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask ext_score = ext_id = ext_mask = 0 tgt_pos = tf.reduce_sum(tgt_mask, -1) tgt_pos = tf.cast(tgt_pos, tf.int32) t += 1 with tf.name_scope('tgt_mask_extend'): tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype) ext = (ext_id, ext_score, ext_mask) hist = (h_tgt_id, h_tgt_pos) loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) tf.logging.info('loop_vars=%r', loop_vars) tf.logging.info('dec_state=%r', dec_state) return loop_vars, dec_state def loop_cond(loop_vars, dec_state): # pylint: disable=missing-docstring tf.logging.info('loop_vars=%r', loop_vars) tf.logging.info('dec_state=%r', dec_state) if beam_gap is None: (t, _, _, _, _, _, _, _) = loop_vars return t < max_steps else: (t, _, _, _, _, nbest_hyps, _, _) = loop_vars (_, nbest_score, _) = nbest_hyps # stop early if all current hyps are significantly worse than nbest diff = tf.reduce_min( tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1)) return tf.math.logical_and(t < max_steps, diff < beam_gap) with tf.name_scope('flat_beam_search_loop'): (loop_vars, dec_state) = tf.while_loop(loop_cond, loop_step, loop_vars=(loop_vars, dec_state), back_prop=False, swap_memory=False, maximum_iterations=max_steps) # flatten all tensorarrays into tensors (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps (h_tgt_id, h_tgt_pos) = hist h_tgt_id = h_tgt_id.stack() h_tgt_pos = h_tgt_pos.stack() hist = (h_tgt_id, h_tgt_pos) loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) # recover topk_ids from nbest_mask and tgt_id history h = tf.transpose(h_tgt_id, [1, 0, 2]) h = tf.reshape(h, [batch_size, buf_size]) def unmask(h, m): with tf.name_scope('unmask'): 
tpu_summary.tensor('unmask_h', h) tpu_summary.tensor('unmask_m', m) t = tf.cumsum(m, -1) * m - 1 mh = einsum_i32('bkt,bt->bkt', m, h) t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype) x = einsum_i32('bkt,bktT->bkT', mh, t2) return tf.cast(x, h.dtype) topk_ids = unmask(h, nbest_mask) topk_len = tf.reduce_sum(nbest_mask, -1) topk_len = tf.cast(topk_len, tf.int32) # add eos, because nbest_mask does not encode eos topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32) topk_len += 1 topk_len = tf.minimum(topk_len, output_len) topk_score = nbest_score_norm nbest = (topk_ids, topk_len, topk_score) return loop_vars, dec_state, nbest
def _GLU(self, inputs):
  p = self.params
  gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
  return self._ApplyActivation(act_inputs,
                               p.glu_activation) * tf.sigmoid(gated_inputs)
def IsWithinBBox3D(points_3d, bboxes_3d):
  """Checks if points are within a 3-d bbox.

  Args:
    points_3d: [..., num_points, 3] float32 Tensor specifying points in 3-d
      space as [x, y, z] coordinates.
    bboxes_3d: [..., num_bboxes, 7] float32 Tensor specifying 3-d bboxes
      specified as [x, y, z, dx, dy, dz, phi] where x, y and z is the center
      of the box.

  Returns:
    boolean Tensor of shape [..., num_points, num_bboxes] indicating whether
    the points belong within each box.
  """
  # Check that points_3d and bboxes_3d have the same rank.
  bboxes_rank = py_utils.GetRank(bboxes_3d)
  points_3d = py_utils.HasRank(points_3d, bboxes_rank)
  leading_shape = py_utils.GetShape(bboxes_3d)[:-2]

  # Check that both points_3d and bboxes_3d have the same leading shape.
  points_3d = py_utils.HasShape(points_3d, leading_shape + [-1, 3])
  bboxes_3d = py_utils.HasShape(bboxes_3d, leading_shape + [-1, 7])

  num_points = py_utils.GetShape(points_3d)[-2]
  num_bboxes = py_utils.GetShape(bboxes_3d)[-2]

  bbox_corners = BBoxCorners(bboxes_3d)
  bbox_corners = py_utils.HasShape(bbox_corners,
                                   leading_shape + [num_bboxes, 8, 3])
  # First four points are the top of the bounding box.
  # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
  #   (x0, y1) <--- (x1, y1)
  #                    ^
  #                    |
  #                    |
  #   (x0, y0) ---> (x1, y0)
  bboxes_2d_corners = bbox_corners[..., 0:4, 0:2]

  # Determine if points lie within the 2-D (x, y) plane for all bounding boxes.
  points_2d = points_3d[..., :2]
  is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)
  is_inside_2d = py_utils.HasShape(is_inside_2d,
                                   leading_shape + [num_points, num_bboxes])

  # Determine if points lie within the z-dimension for all bounding boxes.
  [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)

  def _ComputeLimits(center, width):
    left = center - width / 2.0
    right = center + width / 2.0
    return left, right

  z0, z1 = _ComputeLimits(z[..., 0], dz[..., 0])
  z_points = points_3d[..., 2:]

  is_inside_z = tf.math.logical_and(
      tf.less_equal(z_points, z1[..., tf.newaxis, :]),
      tf.greater_equal(z_points, z0[..., tf.newaxis, :]))
  is_inside_z = py_utils.HasShape(is_inside_z,
                                  leading_shape + [num_points, num_bboxes])

  return tf.math.logical_and(is_inside_z, is_inside_2d)
def _StackAndSplit(x):
  # Split tensors into microbatches.
  if x is None:
    return None
  return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))
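# Illustration only (hypothetical shapes): with p.num_micro_batches=4 and
# p.batch_dim=0, _StackAndSplit reshapes a [batch, ...] tensor into
# [num_micro_batches, batch // num_micro_batches, ...] so each microbatch can
# be fed through the pipeline separately. The same transform with plain TF:
def _ExampleStackAndSplit():
  x = tf.zeros([8, 16])  # batch=8
  num_micro_batches = 4
  y = tf.stack(tf.split(x, num_micro_batches, axis=0))
  return y  # Shape [4, 2, 16].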
def Gate(x):
  u, v = tf.split(x, 2, axis=-1)
  return u * tf.sigmoid(v)
def _GatedTanhFn(inputs):
  gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
  return tf.tanh(act_inputs) * tf.sigmoid(gated_inputs)