def FProp(self, theta, inputs, paddings, class_emb):
    """Apply class-conditional batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. Shaped [batch, ..., dim].
      paddings: The paddings tensor. Shaped [batch, ..., 1], with the same rank
        as the input tensor.
      class_emb: The conditioning inputs, Shaped [batch, emb_dim]. When not
        running on TPU, graph assertions check that each row is one-hot.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    if py_utils.testonly_skip_norm_layers():
        return inputs

    p = self.params
    batch = py_utils.GetShape(inputs)[0]
    class_emb = py_utils.HasShape(class_emb, [batch, p.class_emb_dim])

    if not py_utils.use_tpu():
        # After an int32 cast every entry must lie in [0, 1] and every row must
        # sum to exactly 1, i.e. class_emb must be one-hot.
        one_hot_checks = [
            py_utils.assert_less_equal(
                tf.cast(class_emb, tf.int32), 1, name='one_hot_assert1'),
            py_utils.assert_greater_equal(
                tf.cast(class_emb, tf.int32), 0, name='one_hot_assert2'),
            py_utils.assert_equal(
                tf.ones([batch], tf.int32),
                tf.cast(tf.reduce_sum(class_emb, -1), tf.int32),
                name='one_hot_assert3'),
        ]
        class_emb = py_utils.with_dependencies(one_hot_checks, class_emb)

    with tf.name_scope(p.name):
        moments = self.ComputeAndUpdateMoments(
            theta, inputs, paddings=paddings, class_emb=class_emb)
        norm_mean, norm_variance, beta, gamma = moments
        return self._ComputeBN(inputs, paddings, gamma, beta, norm_mean,
                               norm_variance)
def FProp(self, theta, inputs):
    """Apply projection to inputs.

    Records an estimated flop cost (2 * leading elements * input_dims *
    output_dims) via `computation_cost.Add`. On TPU with a known static rank
    below 26, the projection is a direct matmul (rank 2) or einsum, avoiding
    reshapes. Otherwise inputs are flattened to 2-D, multiplied by `theta.w`,
    and reshaped back.

    Args:
      theta: A NestedMap object containing weights' values of this layer and
        its children layers.
      inputs: The inputs tensor. Shaped [..., input_dims].

    Returns:
      Projected inputs.
    """
    p = self.params
    with tf.name_scope(p.name):
        leading_elems = tf.reduce_prod(
            tf.cast(tf.shape(inputs)[:-1], tf.int64))
        weight_elems = tf.cast(
            symbolic.EvalExpr(symbolic.TENSOR_VALUES,
                              p.input_dims * p.output_dims), tf.int64)
        computation_cost.Add(self, 'flops', leading_elems * weight_elems * 2)

        shape = inputs.shape
        static_rank_ok = (shape is not None and shape.rank is not None and
                          shape.rank < 26)
        if py_utils.use_tpu() and static_rank_ok:
            # Avoids reshape if feasible and uses Einsum.
            if shape.rank == 2:
                return tf.matmul(inputs, theta.w)
            # Batch dims take letters from 'a'; 'y'/'z' are reserved for the
            # contracted and output dims (rank < 26 keeps them disjoint).
            alphabet = ''.join(map(chr, range(ord('a'), ord('z') + 1)))
            batch_letters = alphabet[:shape.rank - 1]
            equation = '{0}y,yz->{0}z'.format(batch_letters)
            return tf.einsum(equation, inputs, theta.w)

        in_dim = py_utils.GetShape(inputs)[-1]
        flat_out = tf.matmul(tf.reshape(inputs, [-1, in_dim]), theta.w)
        out_dim = tf.shape(theta.w)[-1]
        out_shape = tf.concat([tf.shape(inputs)[:-1], [out_dim]], axis=0)
        return tf.reshape(flat_out, out_shape)
def infeed_bucket_batch_limit(self):
    """Returns the bucket batch limit for one infeed host.

    Each entry of `p.bucket_batch_limit` is multiplied by
    `cluster.num_splits_per_client`. With per-host infeed on a cluster that
    has TPU hosts, the result is then floor-divided by `cluster.num_tpu_hosts`.

    Raises:
      ValueError: if per-host infeed scaling is requested for TPU hosts while
        not running on TPU.
    """
    p = self.params
    cluster = self.cluster
    limits = [
        limit * cluster.num_splits_per_client
        for limit in p.bucket_batch_limit
    ]
    num_hosts = cluster.num_tpu_hosts
    if p.use_per_host_infeed and num_hosts > 0:
        if not py_utils.use_tpu():
            raise ValueError(
                'Scaling to TPU hosts without TPUs. {}'.format(num_hosts))
        tf.logging.info(
            'scaling infeed_bucket_batch_limit num_tpu_hosts={}'.format(
                num_hosts))
        limits = [limit // num_hosts for limit in limits]
    tf.logging.info(
        'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
        .format(limits, cluster.num_splits_per_client, p.bucket_batch_limit))
    return limits
def _CreateLayerVariables(self):
    """Creates one embedding-table shard variable per TPU host.

    Each shard has shape [self._ids_per_shard, p.embedding_dim] and is created
    on `self.GetDeviceName(host)` outside all rewrites. The shards are removed
    from the per-variable private maps and registered together with the TPU
    embedding collection. Off TPU they are re-exposed as 'wm'. Unless
    `self.do_eval`, the optimizer's load/retrieve ops are created as well.
    """
    p = self.params
    shard_params = py_utils.WeightParams(
        shape=[self._ids_per_shard, p.embedding_dim],
        init=p.params_init,
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])

    shard_vars = []
    for host in range(p.num_tpu_hosts):
        with tf.device(self.GetDeviceName(host)), (
                py_utils.outside_all_rewrites()):
            name = self.GetVariableName(host)
            self.CreateVariable(name, shard_params)
            shard_vars.append(self.vars[name])
            # Remove from _private_vars / _private_thetas to be added later as
            # wm.
            del self._private_vars[name]
            del self._private_theta[name]

    self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                     shard_vars)

    if not py_utils.use_tpu():
        # We don't want to add this for TrainerTpu, otherwise the identity
        # reference leads to copying the embedding to the TPU for no reason.
        # However, this is needed for CPU (eval/decode/controller).
        self._private_vars['wm'] = shard_vars
        self._private_theta['wm'] = [tf.identity(v) for v in shard_vars]

    # Only trainer and controller need slot variables and load/retrieve ops.
    if not self.do_eval:
        ops = self.optimizer.CreateSlotVariablesAndOps(shard_vars, self)
        self._load_op_list, self._retrieve_op_list = ops
def _Moments(self, inputs, group_size):
    """Computes mean and variance over N,H,W dimensions in inputs.

    The local (per-replica) sufficient statistics are always recorded in
    `self.accumulators`. On TPU with `group_size > 1`, the statistics are then
    summed across groups of `group_size` consecutive replicas before being
    normalized.

    Args:
      inputs: Tensor reduced over axes [0, 1, 2].
      group_size: Number of replicas whose statistics are pooled together.

    Returns:
      A (mean, variance) tuple from `tf.nn.normalize_moments`.
    """
    counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
        inputs, axes=[0, 1, 2], keep_dims=False)
    accumulators = self.accumulators
    accumulators.counts.Update(counts)
    accumulators.mean_ss.Update(mean_ss)
    accumulators.variance_ss.Update(variance_ss)

    if py_utils.use_tpu() and group_size > 1:
        num_shards = tpu_function.get_tpu_context().number_of_shards
        assert num_shards >= group_size
        assert num_shards % group_size == 0
        # Group g holds replicas [g * group_size, (g + 1) * group_size).
        group_assignment = [[g * group_size + i for i in range(group_size)]
                            for g in range(num_shards // group_size)]
        counts *= group_size
        mean_ss = tf.contrib.tpu.cross_replica_sum(mean_ss, group_assignment)
        variance_ss = tf.contrib.tpu.cross_replica_sum(
            variance_ss, group_assignment)

    mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                             None)
    return mean, variance
def _NestedMapFromBatchedOutputs(self, outputs):
    """Create a NestedMap from a tuple of outputs from generic_input_op.

    Every output is padded/trimmed to [batch_size] + its declared shape,
    because static shapes are lost through generic_input_op. On TPU, string
    tensors are dropped from the result.

    Args:
      outputs: Sequence of tensors, one per flattened leaf of `self.Shape()`
        (and `self.DType()`), in the same order.

    Returns:
      A NestedMap with the structure of `self.Shape()`.

    Raises:
      ValueError: if a declared shape is not fully defined. The original
        shape error is chained as the cause.
    """
    batch_size = self.InfeedBatchSize()

    shapes = self.Shape()
    shapes.VLog(0, 'input extractor shape: ')
    flatten_shapes = shapes.Flatten()
    dtypes = self.DType()
    assert dtypes.IsCompatible(shapes), '{} vs. {}'.format(
        dtypes.DebugString(), shapes.DebugString())
    flatten_dtypes = dtypes.FlattenItems()
    assert len(flatten_shapes) == len(outputs), '{} vs. {}'.format(
        len(flatten_shapes), len(outputs))
    assert len(flatten_dtypes) == len(outputs), '{} vs. {}'.format(
        len(flatten_dtypes), len(outputs))

    rets = []
    for output, (name, dtype), shape in zip(outputs, flatten_dtypes,
                                            flatten_shapes):
        assert dtype == output.dtype, '{}: {} vs. {}'.format(
            name, dtype, output.dtype)
        # Pad every output to make shapes fixed according to the corresponding
        # declared shape, since the shapes of outputs are lost through
        # generic_input_op.
        try:
            shape.assert_is_fully_defined()
        except ValueError as e:
            raise ValueError('Invalid shape for %s: %s' % (name, e)) from e
        rets.append(
            py_utils.PadOrTrimTo(output, [batch_size] + shape.as_list()))
    rets = shapes.Pack(rets)
    if py_utils.use_tpu():
        # Drops tf.string tensors, which is not supported on TPUs.
        rets = rets.Filter(lambda x: x.dtype != tf.string)
    return rets
def __init__(self, params):
    """Creates the `before_tpl` children followed by the `cell_tpl` children.

    On TPU, every `before_tpl` layer is placed on model split 0 and the i-th
    `cell_tpl` layer on model split i. Off TPU, no device is forced.
    """
    super().__init__(params)
    p = self.params
    self._before_layers = []
    self._cells = []

    num_cells = len(p.cell_tpl)
    if py_utils.use_tpu():
        cluster = self.cluster
        before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
        cell_devices = [
            cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
        ]
    else:
        before_tpl_device = ''
        cell_devices = [''] * num_cells

    for tpl in p.before_tpl:
        with tf.device(before_tpl_device):
            self.CreateChild(tpl.name, tpl)
        self._before_layers.append((tpl.name, self.children[tpl.name]))

    for tpl, device in zip(p.cell_tpl, cell_devices):
        with tf.device(device):
            self.CreateChild(tpl.name, tpl)
        self._cells.append((tpl.name, self.children[tpl.name]))
def _ProcessMASSInput(self, source_id, src):
    """Perform MASS input processing.

    At eval time, or when no MASS layer is configured, `src` is used as both
    source and target via `self._ProcessSingleInput`. Otherwise `src` is
    tokenized and masked by `self.mass_layer.Mask`, producing source and target
    feature maps. Off TPU the raw string is attached as `strs` on both. All
    features are finally passed through `tf.squeeze`.
    """
    if self.do_eval or self.mass_layer is None:
        # At eval time, we copy src to tgt
        return self._ProcessSingleInput(source_id, src, src)

    _, labels, paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    weights = 1 - paddings
    actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
    src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

    mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

    # Every non-padded source position carries the source language id.
    src_task_ids = tf.cast(weights, dtype=tf.int32) * src_lang_ids
    features = py_utils.NestedMap(
        src=py_utils.NestedMap(
            ids=mass_out.src.ids,
            paddings=paddings,
            weights=weights,
            task_ids=src_task_ids,
            ids_indicator=weights),
        tgt=py_utils.NestedMap(
            ids=mass_out.tgt.ids,
            labels=mass_out.tgt.labels,
            paddings=paddings,
            weights=mass_out.tgt.weights,
            task_ids=tf.ones_like(src_task_ids, dtype=tf.int32) *
            tgt_lang_ids,
            ids_indicator=weights))

    if not py_utils.use_tpu():
        features.src.strs = src
        features.tgt.strs = src
    return features.Transform(tf.squeeze)
def SplitInputBatch(self, num_splits):
    """Splits the current InputBatch into num_splits ways.

    Args:
      num_splits: The number of splits. Must be >= 1.

    Returns:
      A list of `.NestedMap`. Each `.NestedMap` represents the input
      tensors in one split. With `num_splits == 1` the preprocessed batch is
      returned unsplit as the only element.
    """
    assert num_splits >= 1

    batch = self.GetPreprocessedInputBatch()
    if num_splits == 1:
        # Special case. No split is needed.
        return [batch]

    assert not py_utils.use_tpu()
    # field_split[i][j] is the j-th split of the i-th flattened field.
    field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
    return [
        batch.Pack([fields[j] for fields in field_split])
        for j in range(num_splits)
    ]
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
    """Builds per-host TPU embedding load/retrieve ops for plain SGD.

    No slot variables are created here. On TPU, for each host shard, a
    `load_tpu_embedding_stochastic_gradient_descent_parameters` op and a
    retrieve-then-assign op are built in the outer graph scope on that shard's
    device. Off TPU, both returned lists are empty.

    Args:
      table_vars: Per-host embedding table shard variables.
      tpu_embedding_table: Layer providing `params.num_tpu_hosts`,
        `table_name` and `GetDeviceName(host_id)`.

    Returns:
      A (load_op_list, retrieve_op_list) tuple.
    """
    num_hosts = tpu_embedding_table.params.num_tpu_hosts
    name_of_table = tpu_embedding_table.table_name
    load_ops = []
    retrieve_ops = []
    for shard_id, shard_var in zip(range(num_hosts), table_vars):
        # The slot vars should be on the same device as the table var.
        shard_device = tpu_embedding_table.GetDeviceName(shard_id)
        with tf.device(shard_device), py_utils.outside_all_rewrites():
            # Only the Trainer needs these ops.
            if not py_utils.use_tpu():
                continue
            # TPU Embedding load/retrieve ops need to be in the outer graph
            # scope.
            with tf.init_scope():
                tf.logging.info('creating load and retrieve ops.')
                load_op = (
                    tpu_embedding_lib.tpu_ops.
                    load_tpu_embedding_stochastic_gradient_descent_parameters(
                        parameters=shard_var,
                        table_name=name_of_table,
                        num_shards=num_hosts,
                        shard_id=shard_id))
                load_ops.append(load_op)

                retrieved = (
                    tpu_embedding_lib.tpu_ops.
                    retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                        table_name=name_of_table,
                        num_shards=num_hosts,
                        shard_id=shard_id))
                assign_op = tf.assign(shard_var, retrieved)
                retrieve_ops.append(
                    tpu_embedding_lib.control_flow_ops.group(assign_op))
    return load_ops, retrieve_ops
def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch.

    For each infeed host (every TPU host when `p.use_per_host_infeed`,
    otherwise a single host), a preprocessed batch is built on that host's
    CPU and fed through a `tf.contrib.tpu.InfeedQueue`. The resulting enqueue
    ops are grouped and registered `p.tpu_infeed_parallism` times in the
    `py_utils.ENQUEUE_OPS` collection. May only be called once per instance
    (guarded by `self._made_tpu_infeed`).

    Returns:
      A NestedMap with the structure of the first host's batch whose leaves
      are the tensors dequeued from the first queue on TPU core 0.
    """
    p = self.params
    cluster = cluster_factory.Current()
    num_tpu_hosts = cluster.num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
        assert py_utils.use_tpu()
        assert not self._made_tpu_infeed

        # Number of TPU shards served by each infeed queue.
        shards = tpu_function.get_tpu_context(
        ).number_of_shards // num_infeed_hosts
        input_ops_list = []
        queues = []
        first_batch = None
        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                batch = self.GetPreprocessedInputBatch()
                # Only the first host's batch structure is used for Pack().
                if first_batch is None:
                    first_batch = batch
                flat_batch = batch.FlattenItems()

                shapes, types = [], []
                for k, x in flat_batch:
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                    shapes.append(x.shape)
                    types.append(x.dtype)
                q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                               tuple_shapes=shapes)
                queues.append(q)
                assert shards is not None
                q.set_number_of_shards(shards)

                if p.use_per_host_infeed:

                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    # Maps a shard index local to this host to a TPU ordinal.
                    # NOTE(review): reads the loop variable task_id late;
                    # presumably safe because the function is only used
                    # during the enqueue-op construction below, within this
                    # iteration — confirm.
                    def _tpu_ordinal_function(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        [v for _, v in flat_batch],
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=_tpu_ordinal_function)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        [v for _, v in flat_batch],
                        device_assignment=py_utils.GetTpuDeviceAssignment())

                input_ops_list += input_ops
        tf.logging.info('input_ops_list %s', input_ops_list)
        tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
        tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    # Dequeue from the first queue on core 0.
    with tf.device(tf.contrib.tpu.core(0)):
        tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
def _process(source_id, record):
    """Parses `record` as an int32 and squares it unless running on TPU.

    Returns:
      A (NestedMap(num=...), 1) pair.
    """
    del source_id  # Unused.
    value = tf.strings.to_number(record, tf.int32)
    if tf_py_utils.use_tpu():
        return py_utils.NestedMap(num=value), 1
    return py_utils.NestedMap(num=value * value), 1
def _CreateLayerVariables(self):
    """Builds (or reuses) the TPUEmbedding object for all tables on TPU.

    Off TPU this only resets `self._sequence_features`. On TPU it:
      * collects per-table configs and the tables' load/retrieve ops,
      * marks features of tables with `max_sequence_length > 0` in
        `self._sequence_features`,
      * adds the load/retrieve op lists to the TPU_EMBEDDING_LOAD_OPS /
        TPU_EMBEDDING_RETRIEVE_OPS collections,
      * reuses the TPUEmbedding already in the `py_utils.TPU_EMBEDDING`
        collection, or creates and registers a new one in TRAINING mode.
    """
    super()._CreateLayerVariables()
    p = self.params

    load_op_list = []
    retrieve_op_list = []

    # At the feature level, track which are associated
    # with "sequence embeddings".
    self._sequence_features = {}

    if py_utils.use_tpu():
        num_cores = self.cluster.params.worker.tpus_per_replica
        # Batch size across all splits handled by this client.
        global_batch_size = (self.params.batch_size *
                             self.cluster.num_splits_per_client)
        table_to_config_dict = {}
        feature_to_config_dict = {}
        for table in self.tables:
            table_to_config_dict[table.table_name] = table.table_config
            load_op_list += table.load_op_list
            retrieve_op_list += table.retrieve_op_list
            for feature in table.input_keys:
                if table.max_sequence_length > 0:
                    self._sequence_features[feature] = True
                # Every input key of the table gets a FeatureConfig.
                feature_to_config_dict[
                    feature] = tpu_embedding_lib.FeatureConfig(
                        table.table_name,
                        max_sequence_length=table.max_sequence_length)

        tf.logging.info('adding load and retrieve ops to collection.')
        tf.add_to_collection(py_utils.TPU_EMBEDDING_LOAD_OPS, load_op_list)
        tf.add_to_collection(py_utils.TPU_EMBEDDING_RETRIEVE_OPS,
                             retrieve_op_list)

        # The TPUEmbedding object is a graph-wide singleton stored in a
        # collection; at most one may exist.
        tpu_embedding_collection = tf.get_collection(
            py_utils.TPU_EMBEDDING)
        assert len(tpu_embedding_collection) <= 1
        if len(tpu_embedding_collection) == 1:
            tf.logging.info(
                'TPUEmbedding API singleton already exists, reusing')
            self._tpu_embedding = tpu_embedding_collection[0]
        else:
            mode = tpu_embedding_lib.TRAINING
            # Host count is taken from the first table's params.
            device_config = tpu_embedding_lib.DeviceConfig(
                num_cores=num_cores,
                num_hosts=self.params.tables[0].num_tpu_hosts,
                job_name=self.cluster.params.worker.name)
            self._tpu_embedding = tpu_embedding_lib.TPUEmbedding(
                table_to_config_dict,
                feature_to_config_dict,
                global_batch_size,
                mode,
                master=None,
                pipeline_execution_with_tensor_core=(
                    self.params.pipeline_execution_with_tensor_core),
                partition_strategy=p.partition_strategy,
                device_config=device_config)
            tf.add_to_collection(py_utils.TPU_EMBEDDING, self._tpu_embedding)
            tf.add_to_collection(
                py_utils.TPU_EMBEDDING_GRADIENT_MULTIPLIER_SCHEDULE,
                self.gradient_multiplier_schedule)
def ComputePredictions(self, theta, source_encs, source_paddings, targets,
                       src_segment_id):
    """Decodes `targets` given encoded source.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: source encoding, of shape [time, batch, depth]. Checked
        against [time, batch, p.source_dim].
      source_paddings: source encoding's padding, of shape [time, batch].
      targets: A `.NestedMap` of string to tensors representing the targets
        one try to predict. Fields read: `ids`, `paddings` and, when
        p.packed_input, `segment_ids`. Each is of shape [batch, time].
      src_segment_id: source segment id, of shape [time, batch].

    Returns:
      A Tensor with shape [time, batch, params.softmax.input_dim].
    """
    p = self.params
    time, batch = py_utils.GetShape(source_paddings, 2)
    source_encs = py_utils.HasShape(source_encs, [time, batch, p.source_dim])
    with tf.name_scope(p.name):
        # Convert targets to time-major; paddings/segment ids get a trailing
        # unit dim: [time, batch, 1].
        target_ids = tf.transpose(targets.ids)
        target_paddings = py_utils.HasRank(targets.paddings, 2)
        target_paddings = tf.expand_dims(tf.transpose(target_paddings), 2)
        if p.packed_input:
            target_segment_id = tf.expand_dims(tf.transpose(targets.segment_ids), 2)
        else:
            # Unpacked input: every position belongs to segment 0.
            target_segment_id = tf.zeros_like(target_paddings)

        # On TPU, pin the decoder computation below to model split 0.
        if py_utils.use_tpu():
            emb_device = self.cluster.WorkerDeviceInModelSplit(0)
        else:
            emb_device = ''
        with tf.device(emb_device):
            inputs = self.emb.EmbLookup(theta.emb, target_ids)
            inputs = self.ApplyClipping(theta, inputs)
            summary_utils.histogram('input_emb', inputs)
            inputs = self.ApplyDropout(inputs)
            self._emb_out = inputs

            # Layer 0 interwines with attention.
            (atten_ctxs, xs, atten_probs, _) = self.frnn_with_atten.FProp(
                theta.frnn_with_atten,
                source_encs,
                source_paddings,
                inputs,
                target_paddings,
                src_segment_id=src_segment_id,
                segment_id=target_segment_id)
            self._AddAttenProbsSummary(source_paddings, targets, [atten_probs])

            atten_ctxs = self.ApplyClipping(theta, atten_ctxs)
            summary_utils.histogram('atten_ctxs', atten_ctxs)

            for i, (layer, layer_theta) in enumerate(zip(self.frnn, theta.frnn)):
                # Forward through Layer-(i + 1) because Layer-0 handled before.
                # Each layer sees its input concatenated with the attention
                # context along the depth axis.
                ys, _ = layer.FProp(
                    layer_theta,
                    tf.concat([xs, atten_ctxs], 2),
                    target_paddings,
                    segment_id=target_segment_id)
                ys = self.ApplyDropout(ys)
                # Residual connections start at layer p.residual_start.
                if 1 + i >= p.residual_start:
                    xs += ys  # Residual skip
                    xs = self.ApplyClipping(theta, xs)
                else:
                    xs = ys
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.feed_attention_context_vec_to_softmax:
                xs = tf.concat([xs, atten_ctxs], 2)

            return xs
def FProp(self, theta):
    """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    On TPU, one batch is obtained from `input_generator.CreateTpuFeeds()` and
    `FPropTower` runs once. Otherwise the input batch is split into
    `cluster.num_splits_per_client` pieces on the input device, and
    `FPropTower` runs once per split with a worker-local copy of `theta`. Per-
    tower metrics are combined with `py_utils.WeightedAvgOfMetrics`. As side
    effects, every metric is registered via `AddEvalMetric`, and
    `self._loss` / `self._num_predicts` are set from `metrics['loss']`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.

    Returns:
      A dict containing metrics pairs. One of the keys should be 'loss' and its
      value should be a (loss, num_predictions) pair.
    """
    p = self.params
    cluster = cluster_factory.Current()

    with tf.name_scope('fprop'), tf.name_scope(p.name):
        all_fprop_metrics = []

        if py_utils.use_tpu():
            batch = self.input_generator.CreateTpuFeeds()
            with tf.name_scope('tower_0_0'):
                dec_metrics = self.FPropTower(theta, batch)
            all_fprop_metrics.append(dec_metrics)
        else:
            # Splits the input batch on the input device.
            num_splits = cluster.num_splits_per_client
            with tf.device(cluster.input_device):
                batches = self.input_generator.SplitInputBatch(num_splits)
                assert num_splits == len(batches)

            # dev_list_per_replica[i][j] is the i-th worker's j-th device.
            dev_list_per_replica = cluster.available_devices.tolist()

            # Asserts invariant of the total number of splits w.r.t.,
            # splits per worker.
            splits_per_replica = cluster.num_splits_per_replica
            assert num_splits == splits_per_replica * len(
                dev_list_per_replica)

            for w_id, w_devs in enumerate(dev_list_per_replica):
                # Make local copy of the vars, shard on devices for this worker.
                theta_local = py_utils.CreateLocalTheta(theta,
                                                        w_devs,
                                                        label='worker %d' % w_id)

                for s_id in range(splits_per_replica):
                    # s_id-th split for the w_id-th worker.
                    split_id = splits_per_replica * w_id + s_id
                    with py_utils.ModelSplit(split_id):
                        with tf.device(
                                cluster.WorkerDeviceInModelSplit(0)):
                            with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                                batch = self.input_generator.PreprocessInputBatch(
                                    batches[split_id])
                                dec_metrics = self.FPropTower(
                                    theta_local, batch)
                    all_fprop_metrics.append(dec_metrics)
        metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

    # Adds stats about the input batch.
    metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
        self.input_generator.InputBatchSize()), tf.constant(1.0))
    # Generates summaries.
    for name, (value, weight) in six.iteritems(metrics):
        self.AddEvalMetric(name, value, weight)

    # Loss.
    self._loss, self._num_predicts = metrics['loss']
    self._loss = py_utils.CheckNumerics(self._loss)

    return metrics
def FProp(self, theta, input_batch):
    """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - segment_ids, segment_pos: read only when p.packed_input.

    Returns:
      A NestedMap containing:

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """
    p = self.params
    with tf.name_scope(p.name):
        src_segment_id = None
        src_segment_pos = None
        # Graph-time checks: ids and paddings have the same shape, ids is 2-D.
        input_ids = py_utils.with_dependencies([
            py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                        tf.shape(input_batch.paddings)),
            py_utils.assert_equal(tf.rank(input_batch.ids), 2)
        ], input_batch.ids)

        # Off TPU, the flag allows trimming the time axis to the longest
        # unpadded sequence in the batch.
        if (not py_utils.use_tpu() and
                tf.flags.FLAGS.transformer_encoder_truncates_inputs):
            max_seq_length = tf.cast(
                tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
                tf.int32)
            # Assert that every position beyond max_seq_length is padding,
            # so truncation drops no real tokens.
            paddings = py_utils.with_dependencies([
                py_utils.assert_equal(
                    tf.constant(True, tf.bool),
                    tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
            ], input_batch.paddings)
            input_ids = input_ids[:, :max_seq_length]
            paddings = paddings[:, :max_seq_length]
            if p.packed_input:
                src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
        else:
            paddings = input_batch.paddings
            if p.packed_input:
                src_segment_id = input_batch.segment_ids
                src_segment_pos = input_batch.segment_pos

        max_time = tf.shape(input_ids)[1]

        # Input token embeddings + positional embeddings
        input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                              tf.reshape(input_ids, [-1]))
        input_embs = tf.reshape(input_embs,
                                [-1, max_time, p.token_emb.embedding_dim])
        # [time, batch, dim]
        orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

        if p.packed_input:
            # Packed inputs use explicit per-token positions.
            position_embs = self.position_emb.FPropWithPosition(
                theta.position_emb, src_segment_pos)
        else:
            # Shared positions 0..max_time-1, broadcast across the batch.
            position_embs = self.position_emb.FProp(theta.position_emb,
                                                    max_time)
            position_embs = tf.reshape(
                position_embs, [1, max_time, p.token_emb.embedding_dim])
        input_embs += position_embs

        if p.model_dim != p.token_emb.embedding_dim:
            input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

        # The transformer stack consumes time-major paddings / segment ids.
        paddings = tf.transpose(paddings)
        if p.packed_input:
            src_segment_id = tf.transpose(src_segment_id)
        input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

        # [time, batch, dim]
        transformer_input = tf.transpose(input_embs, [1, 0, 2])

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(encoded=encoded,
                              padding=padding,
                              segment_id=segment_id,
                              embedded_inputs=orig_input_embs)
def Transform(self, dataset):
    """Batches a dataset containing NestedMaps of tensors.

    Each example gets a `bucket_keys` field from `p.seqlen_fn`. Examples whose
    key exceeds the last bucket upper bound are dropped. The rest are padded
    and batched with `tf.data.experimental.bucket_by_sequence_length`. On TPU
    the remainder is dropped and every batch gets a static leading dimension.

    Args:
      dataset: A `tf.data.Dataset` of NestedMaps.

    Returns:
      The batched dataset.

    Raises:
      ValueError: on TPU, if the (infeed-scaled) bucket batch limits differ.
    """
    p = self.params
    gen = self._input_generator
    deterministic = p.require_sequential_order or self.do_eval
    seqlen_fn = getattr(gen, p.seqlen_fn)

    def SetBucketKeys(example):
        example.bucket_keys = seqlen_fn(example)
        return example

    dataset = dataset.map(SetBucketKeys,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE,
                          deterministic=deterministic)
    dataset = dataset.filter(
        lambda x: x.bucket_keys <= p.bucket_upper_bound[-1])

    structure = py_utils.NestedMap.FromNestedDict(
        tf.data.experimental.get_structure(dataset))
    shape_fn = getattr(gen, p.input_shape_fn)
    padded_shapes = structure.TransformWithKey(
        lambda k, _: tf.TensorShape(shape_fn(k)))
    padding_fn = getattr(gen, p.input_padding_fn)
    padding_values = structure.TransformWithKey(padding_fn)

    structure.VLog(0, 'dataset_structure:')
    padded_shapes.VLog(0, 'padded_shapes:')

    batch_limits = [
        batch_utils.scale_split_to_infeed(limit,
                                          gen.params.use_per_host_infeed)
        for limit in p.bucket_batch_limit
    ]
    on_tpu = py_utils.use_tpu()
    # Upper-bound for bucket_by_sequence_length is exclusive, so add 1
    # TODO(jeffreyzhao): There is a off-by-one bug with the upper bound
    # boundary check, so add 2 instead. Remove when fixed.
    boundaries = [bound + 2 for bound in p.bucket_upper_bound]
    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            lambda x: x.bucket_keys,
            boundaries,
            batch_limits + [1],
            padded_shapes=padded_shapes,
            padding_values=padding_values,
            pad_to_bucket_boundary=True,
            drop_remainder=on_tpu))

    if not on_tpu:
        return dataset

    # Set static shapes for TPU.
    if min(batch_limits) != max(batch_limits):
        raise ValueError('TPU requires constant batch sizes.')
    static_batch = batch_limits[0]

    def SetShape(element):
        for t in element.Flatten():
            t.set_shape((static_batch, ) + t.shape[1:])
        return element

    return dataset.map(SetShape,
                       num_parallel_calls=tf.data.experimental.AUTOTUNE,
                       deterministic=deterministic)
def PostProcess(self, dec_out_dict, dec_metrics_dict):
    """Updates decoder metrics from one batch of decoder outputs.

    Args:
      dec_out_dict: Decoder outputs. Keys read: 'topk_scores', 'topk_decoded',
        'transcripts', 'norm_wer_errors', 'norm_wer_words', 'target_labels',
        'target_paddings', 'topk_ids', 'topk_lens', plus 'utt_id' when not
        running on TPU. 'norm_wer_errors' / 'norm_wer_words' are indexed as
        2-D arrays ([i, n]).
      dec_metrics_dict: Metric objects exposing `Update`. Always updated:
        'num_samples_in_batch', 'norm_wer', 'corpus_bleu'. When
        p.include_auxiliary_metrics, also 'wer', 'oracle_norm_wer', 'sacc',
        'ter'.

    Returns:
      A list of key/value pairs; this implementation always returns an empty
      list.
    """
    p = self.params
    assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
    topk_scores = dec_out_dict['topk_scores']
    topk_decoded = dec_out_dict['topk_decoded']
    transcripts = dec_out_dict['transcripts']
    on_tpu = py_utils.use_tpu()
    if not on_tpu:
        utt_id = dec_out_dict['utt_id']
        assert len(utt_id) == len(transcripts)
    norm_wer_errors = dec_out_dict['norm_wer_errors']
    norm_wer_words = dec_out_dict['norm_wer_words']
    target_labels = dec_out_dict['target_labels']
    target_paddings = dec_out_dict['target_paddings']
    topk_ids = dec_out_dict['topk_ids']
    topk_lens = dec_out_dict['topk_lens']
    assert len(transcripts) == len(target_labels)
    assert len(transcripts) == len(target_paddings)
    assert len(transcripts) == len(topk_decoded)
    assert len(norm_wer_errors) == len(transcripts)
    assert len(norm_wer_words) == len(transcripts)

    num_samples_in_batch = len(transcripts)
    dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

    def GetRefIds(ref_ids, ref_paddings):
        """Returns the reference ids at positions whose padding is 0."""
        assert len(ref_ids) == len(ref_paddings)
        return [
            ref_id for ref_id, pad in zip(ref_ids, ref_paddings) if pad == 0
        ]

    def _FilterText(text):
        """Removes noise and epsilon tokens from a transcript/hypothesis."""
        return decoder_utils.FilterEpsilon(decoder_utils.FilterNoise(text))

    total_norm_wer_errs = norm_wer_errors[:, 0].sum()
    total_norm_wer_words = norm_wer_words[:, 0].sum()

    # Guard the denominator like the other rate metrics below: with no
    # reference words the plain division would be 0/0 (NaN), which is then
    # reported with weight 0. NOTE(review): presumably a NaN value poisons a
    # weighted running average even at weight 0 — confirm against the metric
    # implementation.
    dec_metrics_dict['norm_wer'].Update(
        total_norm_wer_errs / max(1., total_norm_wer_words),
        total_norm_wer_words)

    for ref_str, hyps in zip(transcripts, topk_decoded):
        dec_metrics_dict['corpus_bleu'].Update(
            _FilterText(ref_str), _FilterText(hyps[0]))

    total_errs = 0
    total_oracle_errs = 0
    total_ref_words = 0
    total_token_errs = 0
    total_ref_tokens = 0
    total_accurate_sentences = 0
    key_value_pairs = []

    if p.include_auxiliary_metrics:
        for i, ref_str in enumerate(transcripts):
            if not on_tpu:
                tf.logging.info('utt_id: %s', utt_id[i])
            if self.cluster.add_summary:
                tf.logging.info(
                    ' ref_str: %s',
                    ref_str.decode('utf-8') if p.log_utf8 else ref_str)
            hyps = topk_decoded[i]
            num_hyps_per_beam = len(hyps)
            ref_ids = GetRefIds(target_labels[i], target_paddings[i])
            # topk_ids / topk_lens are flat over (utterance, hypothesis); the
            # top hypothesis of utterance i is at i * num_hyps_per_beam.
            hyp_index = i * num_hyps_per_beam
            top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
            if self.cluster.add_summary:
                tf.logging.info(' ref_ids: %s', ref_ids)
                tf.logging.info(' top_hyp_ids: %s', top_hyp_ids)
            total_ref_tokens += len(ref_ids)
            _, _, _, token_errs = decoder_utils.EditDistanceInIds(
                ref_ids, top_hyp_ids)
            total_token_errs += token_errs

            filtered_ref = _FilterText(ref_str)
            oracle_errs = norm_wer_errors[i][0]
            for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
                if self.cluster.add_summary:
                    tf.logging.info(
                        ' %f: %s', score,
                        hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
                filtered_hyp = _FilterText(hyp_str)
                ins, subs, dels, errs = decoder_utils.EditDistance(
                    filtered_ref, filtered_hyp)
                # Note that these numbers are not consistent with what is used to
                # compute normalized WER. In particular, these numbers will be
                # inflated when the transcript contains punctuation.
                tf.logging.info(' ins: %d, subs: %d, del: %d, total: %d', ins,
                                subs, dels, errs)
                # Only aggregate scores of the top hypothesis.
                if n == 0:
                    total_errs += errs
                    total_ref_words += len(
                        decoder_utils.Tokenize(filtered_ref))
                    if norm_wer_errors[i, n] == 0:
                        total_accurate_sentences += 1
                # Oracle: best normalized-WER errors over all hypotheses.
                oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
            total_oracle_errs += oracle_errs

        dec_metrics_dict['wer'].Update(
            total_errs / max(1., total_ref_words), total_ref_words)
        dec_metrics_dict['oracle_norm_wer'].Update(
            total_oracle_errs / max(1., total_ref_words), total_ref_words)
        # An empty batch would otherwise raise ZeroDivisionError here.
        dec_metrics_dict['sacc'].Update(
            total_accurate_sentences / max(1, len(transcripts)),
            len(transcripts))
        dec_metrics_dict['ter'].Update(
            total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

    return key_value_pairs
def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch.

    For each infeed host (every TPU host when `p.use_per_host_infeed`,
    otherwise a single host), a preprocessed batch is built on the host CPU.
    If a TPUEmbedding object is registered in the `py_utils.TPU_EMBEDDING`
    collection, its input features are popped from the batch, split evenly
    across the embedding's cores, and enqueued via its enqueue ops. The
    remaining tensors go through a `tf.contrib.tpu.InfeedQueue`. All enqueue
    ops are grouped and registered `p.tpu_infeed_parallism` times in
    `py_utils.ENQUEUE_OPS`. May only be called once per instance.

    Returns:
      A NestedMap with the structure of the first host's batch (without the
      TPU-embedding features) whose leaves are the tensors dequeued from the
      first queue on TPU core 0.
    """
    p = self.params
    cluster = self.cluster
    num_tpu_hosts = cluster.num_tpu_hosts
    # Check before dividing by num_tpu_hosts so a misconfiguration surfaces as
    # this assertion rather than a ZeroDivisionError.
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
    tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
    tf.logging.info('num_devices_per_split {}'.format(
        cluster.num_devices_per_split))

    if (cluster.num_devices_per_split > num_cores_per_host and
            p.use_per_host_infeed):
        # NOTE(review): with TF's Python logger, logging.fatal only logs at
        # CRITICAL level and does not stop execution — confirm whether a
        # raise is intended here.
        tf.logging.fatal(
            'Doesn\'t support per host infeed mode when '
            'num_devices_per_split({}) > num_cores_per_host({})'.format(
                cluster.num_devices_per_split, num_cores_per_host))
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
        assert py_utils.use_tpu()
        assert not self._made_tpu_infeed

        shards = tpu_function.get_tpu_context(
        ).number_of_shards // num_infeed_hosts
        input_ops_list = []
        queues = []
        first_batch = None

        # At most one TPUEmbedding object is expected in the collection.
        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)
        tpu_embedding_input_keys = (
            tpu_embedding.feature_to_config_dict.keys()
            if tpu_embedding is not None else [])

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                batch = self.GetPreprocessedInputBatch()
                # TPU embedding features are fed through the TPUEmbedding
                # enqueue ops, not the regular infeed, so remove them here.
                tpu_embedding_features = []
                for tpu_embedding_input_key in tpu_embedding_input_keys:
                    tpu_embedding_feature = batch.pop(tpu_embedding_input_key)
                    tpu_embedding_features.append(
                        (tpu_embedding_input_key, tpu_embedding_feature))

                if first_batch is None:
                    first_batch = batch
                flat_batch = batch.FlattenItems()

                if tpu_embedding is not None:
                    emb_cores_per_host = tpu_embedding.num_cores_per_host
                    # One independent dict per core. `[{}] * n` would alias a
                    # single dict, so every core would see the last split.
                    enqueue_dict_per_core = [
                        {} for _ in range(emb_cores_per_host)
                    ]
                    for key, feature in tpu_embedding_features:
                        feature_splits = tf.split(feature, emb_cores_per_host)
                        for core, split in enumerate(feature_splits):
                            enqueue_dict_per_core[core][key] = (
                                tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1])))
                    input_ops_list += tpu_embedding.generate_enqueue_ops(
                        enqueue_dict_per_core)

                shapes, types = [], []
                for k, x in flat_batch:
                    assert x.shape.is_fully_defined(), (
                        'Shape must be fully defined: %s: %s' % (k, x))
                    # TODO(cwhipkey): if it's a string (or other type not supported on
                    # TPU), drop it from feeding and on the other end add in an op that
                    # fails if used.
                    shapes.append(x.shape)
                    types.append(x.dtype)
                q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                               tuple_shapes=shapes)
                queues.append(q)
                assert shards is not None
                q.set_number_of_shards(shards)

                if p.use_per_host_infeed:

                    # TODO(ylc/zhifengc): Add this to a policy module and test it.
                    def _tpu_ordinal_function(shard_index_in_host):
                        device_assignment = py_utils.GetTpuDeviceAssignment()
                        if device_assignment:
                            # We put both enqueue/dequeue ops at core 0 in each replica.
                            replica = device_assignment.lookup_replicas(
                                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                            return device_assignment.tpu_ordinal(
                                replica=replica)
                        else:
                            return shard_index_in_host

                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        [v for _, v in flat_batch],
                        placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                        tpu_ordinal_function=_tpu_ordinal_function)
                else:
                    input_ops = q.split_inputs_and_generate_enqueue_ops(
                        [v for _, v in flat_batch],
                        device_assignment=py_utils.GetTpuDeviceAssignment())

                input_ops_list += input_ops
        tf.logging.info('input_ops_list %s', input_ops_list)
        tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
        tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.compat.v1.tpu.core(0)):
        tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first
  (when `p.apply_layer_norm`), and the returned 'dists' is computed using the
  normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L].
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
           Shape [B, L, N, K]. Padded positions get a very large distance.
    k_means_loss: the squared Euclidean distance to the closest centroid
                  (summed over H, averaged over heads), averaged over the
                  non-padded positions; a scalar. It is 0 when every position
                  is padded.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product to
  # approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)
  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  # Average over the heads axis N -> shape [B, L, H].
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids.
  # The denominator is clamped to at least 1 so that an all-padded batch
  # (where the numerator is also 0) yields 0 instead of NaN.
  num_valid = tf.maximum(tf.math.reduce_sum(1.0 - paddings), 1.0)
  k_means_loss = tf.math.reduce_sum(diff) / num_valid
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  # Per-centroid L2 norms, shape [N, K]. Reducing over the last (H) axis only
  # is what makes the min/mean summaries below meaningful; a norm over the
  # whole tensor would make both summaries the same scalar.
  means_norm = tf.norm(theta.means, axis=-1)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #    loss = tf.math.reduce_mean(
  #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  # Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    # Aggregate statistics across replicas so every replica applies the same
    # centroid update.
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension will
  # be 0.
  # NOTE(review): such an unused centroid therefore gets new_means == 0 and is
  # pulled towards the origin by the update below; confirm this is intended.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use exponential moving average. TODO(zhouwk): investigate smooth this
  # over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                              self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  For each infeed host, this builds the host's preprocessed batch on that
  host's CPU and creates an infeed queue for it. An infeed host is every TPU
  host when `p.use_per_host_infeed`; otherwise it is only '/task:0'. The
  queue's enqueue ops are added to a single grouped op, and TPU embedding
  enqueue ops are included when a TPU embedding is registered under
  `py_utils.TPU_EMBEDDING`.

  The grouped op is added to the `py_utils.ENQUEUE_OPS` collection
  `p.tpu_infeed_parallelism` times. It is also stored in
  `self._tpu_infeed_op`.

  Returns:
    The tensors dequeued from the first queue on `tf.tpu.core(0)`, packed
    into the structure of the last host's batch (via `batch.Pack`).
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  # Cores per host; assumes worker devices are spread evenly over hosts
  # (integer division) -- TODO confirm.
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
      format(cluster.num_splits_per_client, cluster.num_devices_per_split,
             num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  # Per-host infeed feeds from every TPU host; otherwise a single host does.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    # Each infeed host feeds an equal share of the TPU shards.
    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(
        py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    # Batch keys that are fed through the TPU embedding enqueue path.
    tpu_emb_input_keys = (list(
        tpu_embedding.feature_to_config_dict.keys())
                          if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          # Rebinds the outer value with the embedding's own core count.
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            # One slice of the feature per core on this host.
            tpu_emb_feat_splitted = tf.split(
                feat, num_cores_per_host)
            for core, split in enumerate(
                tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              # Entries equal to -1 are treated as padding and dropped.
              sample_indices = tf.where(
                  tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(
                  split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()

          # NOTE: this rebinds `host_device` for the rest of this iteration;
          # on this path it is only used by the log line below.
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          # Parses the task index out of '.../task:<id>/device:...'.
          host_id = int(
              host_device.split('/task:')[1].split('/device:')
              [0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            """Maps a shard index on this host to a TPU ordinal."""
            device_assignment = py_utils.GetTpuDeviceAssignment(
            )
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(
                  replica=replica)
            else:
              return shard_index_in_host

          # The closures read loop variables, but they are consumed within
          # this iteration by the call below.
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment(
              ))

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
  self._made_tpu_infeed = True
  # Let trainer.py use multiple threads to drive the infeed op.
  for _ in range(p.tpu_infeed_parallelism):
    tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  self._tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  # Uses the last host's batch as the packing template; presumably all hosts
  # produce batches with the same structure -- TODO confirm.
  return batch.Pack(tensors)
def __init__(self,
             cell_fn,
             cell_grad,
             theta,
             state0,
             inputs,
             extras,
             implicit_captures=None,
             unused_acc_state=None):
  """RNN helper class.

  Builds the TF Defuns that run `cell_fn` over the time dimension of
  `inputs` (Forward) and backpropagate through it with `cell_grad`
  (Backward, wired in as Forward's python gradient). The resulting forward
  Defun is stored in `self._forward`.

  Args:
    cell_fn: A python function, which computes:
      state1, extras = cell_fn(theta, state0, inputs[t, :])
    cell_grad: A python function which computes:
      dtheta, dstate0, dinputs[t, :] = cell_grad(
        theta, state0, inputs[t, :], extras, dstate1)
    theta: weights. A `.NestedMap`.
    state0: initial state. A `.NestedMap`.
    inputs: inputs. A `.NestedMap`.
    extras: A `.NestedMap` of Tensors. The 2nd return value of every
      invocation of cell_fn is a `.NestedMap` with matching keys and shapes
      of this 'extras'.
    implicit_captures: A `.NestedMap` corresponding to implicit captures of
      the cell_fn. If empty/None, implicit captures are either not present
      or disallowed.
    unused_acc_state: If None, we assume every field of acc_state is consumed
      in the following timesteps. If True, none of the acc_state is
      consumed, and we reduce_sum each timestep's new state into a scalar.
      Note, this feature should be used with StackedRecurrent where we send
      out the new state to the other devices.
  """
  self._theta = theta
  self._state = state0
  self._inputs = inputs
  self._cell_fn = cell_fn
  self._cell_grad = cell_grad
  self._extras = extras
  self._implicit_captures = implicit_captures
  self._unused_acc_state = unused_acc_state
  if self._implicit_captures is None:
    self._implicit_captures = _EmptyCaptures()

  # pylint: disable=unbalanced-tuple-unpacking

  # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
  # Forward and Backward defined below) simply takes a list of
  # Tensors and returns a list of Tensors. When we pass in a
  # structure (a list of NestedMap of Tensors), we use _Flatten to
  # convert the structure into a list of tensor. Conversely, the
  # following code often uses _Pack to formulate a structure from a
  # list of tensors based on a "template".

  # Wraps cell_fn in a TF Function:
  #    state1 = cell_fn(theta, state0, inputs)
  fwd_sig = [self._theta, self._state, self._inputs]

  # On TPU the loops are rewritten as While ops and the Defuns may be
  # inlined; off TPU, Forward/Backward are kept as noinline functions.
  compiled = py_utils.use_tpu()
  noinline = not compiled
  # dtype of the device-side mirror of the loop counter.
  dev_t_type = tf.int32 if py_utils.use_tpu() else tf.int64

  @function.Defun(*_Dtypes(fwd_sig))
  def Fwd(*args):
    """Runs one step of cell_fn and checks its outputs' structure."""
    (theta, state0, inputs) = _Pack(args, fwd_sig)
    state1, extras = self._cell_fn(theta, state0, inputs)
    _AssertIsCompatible(state1, self._state)
    _AssertIsCompatible(extras, self._extras)
    return _Flatten([state1, extras])

  # Wraps cell_fn in a TF Function as a for-loop's body.
  #
  # The loop state is composed of:
  #  t: The loop variable. Timestep id.
  #  dev_t: The loop variable mirrored on the device.
  #  theta: the recurrent net's weights.
  #  state0: the previous recurrent state.
  #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
  #  acc_state: Each timestep's computed new state is also stashed into
  #    acc_state.
  #  acc_extras: Each timestep's computed extras is stashed into acc_extras
  fwdloop_sig = [
      self._theta, self._state, self._inputs, self._state, self._extras
  ]

  @function.Defun(tf.int32, dev_t_type, *_Dtypes(fwdloop_sig))
  def ForwardLoopBody(*args):
    """The body of forward loop."""
    t, dev_t = args[0], args[1]
    (theta, state0, inputs, acc_state, acc_extras) = _Pack(
        args[2:], fwdloop_sig)
    inputs_t = _Index(inputs, t)  # external input at time step t.
    state1, extras = _Pack(
        Fwd(*_Flatten([theta, state0, inputs_t])),
        [self._state, self._extras])
    # Saves state1 and extras in their accumulators.
    if not self._unused_acc_state:
      acc_state = _Update(acc_state, state1, dev_t)
    acc_extras = _Update(acc_extras, extras, dev_t)

    # Loop-carried values for the next iteration; dev_t is advanced here,
    # while `t` itself is supplied by functional_ops.For.
    return [tf.add(dev_t, 1)] + _Flatten(
        [theta, state1, inputs, acc_state, acc_extras])

  def Grad(op, *args):
    """The python grad function for the Forward function.

    Flowchart:

      Backward() DEFUN -> [d_fwd..., acc_extras, dcaptured]
        |
        v
      For(BackwardLoopBody())
        |
        v
      BackwardLoopBody() DEFUN ->
          ..., d_theta, d_state0, d_inputs, d_acc_state, d_captured
        |
        v
      Bak(..., inputs[t], extras[t]) DEFUN ->
          d_theta_t, d_state0, d_inputs_t, d_captured_t
        |
        v
      CellGrad(theta, state0, inputs, extras, d_state1) ->
          dtheta, dstate0, dinputs, dcaptured

    The key thing is that this function must return a dx value for each of
    the inputs to the Fwd function (theta, state0, inputs, captured...).
    The tricky part is that implicitly captured inputs are carried through
    function boundaries implicitly by the function call as the last
    arguments. When assembling gradients, we must account for these implicit
    captures even though they are not passed explicitly from function to
    function.

    Args:
      op: The forward operation.
      *args: Args to the forward operation (includes implicit captures).

    Returns:
      Tuple of derivatives.

    Raises:
      ValueError: on argument mismatch issues.
    """
    expected_num_inputs = 0
    for nmap in [
        self._theta,
        self._state,
        self._inputs,
        self._extras,
        # Implicit captured tensors always come last
        self._implicit_captures
    ]:
      expected_num_inputs += len(nmap.Flatten())
    if len(op.inputs) != expected_num_inputs:
      if len(op.inputs) > expected_num_inputs:
        raise ValueError(
            ('Too many inputs. The most likely cause is that cell_fn '
             'captures additional tensors: extra inputs %r vs captures %r') %
            (list(op.inputs), list(self._implicit_captures.Flatten())))
      raise ValueError(
          ('Mismatched inputs to cell fn: Found %d vs expected %d: %r'
           '. Implicit captures(%d) = %r') %
          (len(op.inputs), expected_num_inputs, list(op.inputs),
           len(self._implicit_captures.Flatten()), self._implicit_captures))

    # NOTE: tf.gradient backprops None for int32/int64 while zeros
    # for float32/float64. For consistency, we always backprop
    # zeros.
    args = list(args)
    for i, dy in enumerate(args):
      if dy is None:
        args[i] = tf.zeros_like(op.outputs[i])
    (theta, state0, inputs, _, unused_captured) = _Pack(
        [x for x in op.inputs],
        [
            self._theta,
            self._state,
            self._inputs,
            self._extras,
            # Implicit captured tensors always come last
            self._implicit_captures,
        ])
    # acc_state and acc_extras are computed by the Forward pass and
    # needed by the Backward pass.
    acc_state, _, acc_extras = _Pack([x for x in op.outputs],
                                     [self._state, self._state, self._extras])

    # Forward computes acc_state, the final state and
    # acc_extras. tf.gradients gives us their gradients w.r.t. the
    # final loss. Because acc_extras are not exposed by Compute(),
    # it has no gradients w.r.t. the final loss (i.e., by
    # construction, it must be zeros).
    d_acc_state, d_state1, _ = _Pack(args,
                                     [self._state, self._state, self._extras])

    if self._unused_acc_state:
      # XLA While op requires the same shape for the init and carry on values.
      state0 = state0.Transform(tf.reduce_sum)
      d_state1 = d_state1.Transform(tf.reduce_sum)

    return Backward(*_Flatten([
        theta,
        state0,
        inputs,
        acc_state,
        acc_extras,
        d_acc_state,
        d_state1,
    ]))

  # Forward calls ForwardLoopBody n times. Each time computes one
  # time step of the recurrent net.
  forward_sig = [self._theta, self._state, self._inputs, self._extras]

  @function.Defun(
      *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
  def Forward(*args):
    """Forward pass of the recurrent net."""
    theta, state0, inputs, extras = _Pack(args, forward_sig)

    # The sequence length.
    pad_begin, pad_end = _SeqPaddingLength(inputs)
    slen_dim = _SeqLenDim(inputs)

    # Creates accumulators for state0 and extras.
    if self._unused_acc_state:
      acc_state = _EmptyWithFixShape([slen_dim], state0)
    else:
      acc_state = _EmptyAcc(slen_dim, state0)
    acc_extras = _EmptyAcc(slen_dim, extras)

    if py_utils.use_tpu():
      dev_t = tf.to_int32(pad_begin)
    else:
      dev_t = tf.to_int64(pad_begin)

    # Iterates t over [pad_begin, slen_dim - pad_end).
    run = functional_ops.For(
        start=pad_begin,
        limit=slen_dim - pad_end,
        delta=1,
        inputs=[dev_t] + _Flatten(
            [theta, state0, inputs, acc_state, acc_extras]),
        body=ForwardLoopBody,
        rewrite_with_while=compiled)
    # run[0] is the final dev_t and is dropped.
    _, state1, _, acc_state, acc_extras = _Pack(
        run[1:],
        [self._theta, self._state, self._inputs, self._state, self._extras])

    return _Flatten([acc_state, state1, acc_extras])

  # The per-step backward computes:
  #    d_theta, d_state0, d_inputs = cell_grad(
  #        theta, state0, inputs, extras, d_state1)
  # where d_state1 is the backprop-ed gradient for state1, and
  # extras is the computed by the forward step to facilitate the
  # backward step.
  bak_sig = [
      self._theta,
      self._state,
      self._inputs,
      self._extras,
      self._state,
  ]

  @function.Defun(*_Dtypes(bak_sig))
  def Bak(*args):
    """Backward step."""
    (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
    (dtheta, dstate0, dinputs, dcaptures) = self._cell_grad(
        theta, state0, inputs, extras, d_state1)
    _AssertIsCompatible(dtheta, self._theta)
    _AssertIsCompatible(dstate0, self._state)
    _AssertIsCompatible(dinputs, self._inputs)
    if dcaptures is None:
      # NOTE: Custom gradient fns can return None if they do not support
      # captured tensors. The return value is reserved for the future when
      # that may be supported.
      dcaptures = _EmptyLike(self._implicit_captures)
    _AssertIsCompatible(dcaptures, self._implicit_captures)

    # Make sure this function didn't capture anything different than the
    # cell_fn when reflected on at the beginning. Must come after the call
    # to cell_grad() which adds to the captured list.
    _AssertSameTensors(function.get_extra_inputs(),
                       self._implicit_captures.Flatten())

    (captured,) = _Pack(function.get_extra_args(), [self._implicit_captures])
    return _Flatten(
        _ConvertNoneGradientToZeros([theta, state0, inputs, captured],
                                    [dtheta, dstate0, dinputs, dcaptures]))

  # Define defuns used by a functional.if in BackwardLoopBody.
  state_if_sig = [self._state, self._state]

  @function.Defun(*_Dtypes(state_if_sig))
  def ReturnOrigState0(*args):
    """Returns original state0 from inputs."""
    (_, orig_state0) = _Pack(args, state_if_sig)
    return orig_state0.Flatten()

  @function.Defun(*_Dtypes(state_if_sig))
  def ReturnAccState(*args):
    """Returns acc_state[t-1] from inputs."""
    (acc_state, _) = _Pack(args, state_if_sig)
    return acc_state.Flatten()

  # Wraps cell_grad gradient function in a TF Function as a
  # for-loop's body for the Backward pass.
  #
  # The loop state is composed of:
  #  t: The loop variable. Timestep id.
  #  state0: the initial state for the entire backward loop.
  #  dev_t: The loop variable mirrored on the device.
  #  theta: the recurrent net's weights.
  #  inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
  #  acc_state: Each timestep's computed new state was stashed into
  #    acc_state by the Forward pass.
  #  acc_extras: Each timestep's computed extras was stashed into
  #    acc_extras by the Forward pass.
  #  d_theta: All timestep's gradient for theta is accumulated (added) into
  #      d_theta.
  #  d_state1: The backprop-ed gradient for the new state computed by
  #      timestep t.
  #  d_inputs: d_inputs[t, :] is populated by the backward time step t.
  #  d_acc_state: The backprop-ed gradient for acc_state.
  #  d_captured: All timestep's gradient for the implicit captures is
  #      accumulated (added) into d_captured.
  bakloop_sig = [
      self._theta,
      self._state,
      self._inputs,
      self._state,
      self._extras,
      # End of forward params
      self._theta,
      self._state,
      self._inputs,
      self._state,
      self._implicit_captures,
  ]

  @function.Defun(tf.int32, dev_t_type, *_Dtypes(bakloop_sig))
  def BackwardLoopBody(*args):
    """Backward loop body function."""
    t, dev_t = args[0], args[1]
    (
        theta,
        orig_state0,
        inputs,
        acc_state,
        acc_extras,
        # End of forward params
        d_theta,
        d_state1,
        d_inputs,
        d_acc_state,
        d_captured) = (
            _Pack(args[2:], bakloop_sig))

    # The input recurrent state for time step t is previous time step's
    # output, or the original state0 when on time step 0.
    # tf.maximum keeps the index valid when t == 0; that value is then
    # discarded by the If below.
    state_from_acc = _Index(acc_state, tf.maximum(0, t - 1))
    state0 = functional_ops.If(
        tf.equal(t, tf.constant(0, tf.int32)),
        _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
        ReturnAccState)
    state0 = orig_state0.Pack(state0)

    # The external inputs for time step t.
    inputs_t = _Index(inputs, t)
    # The extras for time step t.
    extras_t = _Index(acc_extras, t)

    # Gradient flowing into state1 = gradient from the next step plus the
    # gradient of the loss w.r.t. the accumulated state at step t.
    d_state1 = _Add(_Index(d_acc_state, t), d_state1)
    (d_theta_t, d_state0, d_inputs_t, d_captured_t) = _Pack(
        Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
        [self._theta, self._state, self._inputs, self._implicit_captures])

    if self._unused_acc_state:
      # XLA IF op requires the same shape for if and else branches.
      d_state0 = d_state0.Transform(tf.reduce_sum)
    d_theta = _Add(d_theta, d_theta_t)
    d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
    d_captured = _Add(d_captured, d_captured_t)

    # Make sure this function didn't capture anything different than the
    # cell_fn when reflected on at the beginning. Must come after the call
    # to Bak() which adds to the captured list.
    _AssertSameTensors(function.get_extra_inputs(),
                       self._implicit_captures.Flatten())

    # d_state0 of this step becomes d_state1 of the previous step (the loop
    # runs backwards, so dev_t is decremented).
    return [tf.subtract(dev_t, 1)] + _Flatten([
        theta,
        orig_state0,
        inputs,
        acc_state,
        acc_extras,
        # End of forward params
        d_theta,
        d_state0,
        d_inputs,
        d_acc_state,
        d_captured,
    ])

  # Backward calls BackwardLoopBody n times. Each time computes the backprop
  # for one time step of the recurrent net.
  backward_sig = [
      self._theta,
      self._state,
      self._inputs,
      self._state,
      self._extras,
      # End of forward params.
      self._state,
      self._state,
  ]

  @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
  def Backward(*args):
    """Backward pass for the recurrent net."""
    # theta, state0, inputs are Forward's inputs.
    # acc_state is the accumulated 1st output of Forward.
    # acc_extras is the accumulated 2nd output of Forward.
    # d_acc_state is the gradient for acc_state.
    # d_state1 is the gradient for the final state computed by Forward.
    (theta, state0, inputs, acc_state, acc_extras, d_acc_state,
     d_state1) = _Pack(args, backward_sig)

    # Accumulators for gradients.
    d_theta = _EmptyLike(theta)
    d_inputs = _EmptyLike(inputs)
    d_captured = _EmptyLike(self._implicit_captures)

    # The sequence length.
    pad_begin, pad_end = _SeqPaddingLength(inputs)
    start = _SeqLenDim(inputs) - pad_end - 1

    if py_utils.use_tpu():
      dev_t = tf.to_int32(start)
    else:
      dev_t = tf.to_int64(start)

    # Iterates t from the last non-padded step down to pad_begin.
    run = functional_ops.For(
        start=start,
        limit=pad_begin - 1,
        delta=-1,
        inputs=[dev_t] + _Flatten([
            theta,
            state0,
            inputs,
            acc_state,
            acc_extras,
            d_theta,
            d_state1,
            d_inputs,
            d_acc_state,
            d_captured,
        ]),
        body=BackwardLoopBody,
        rewrite_with_while=compiled)

    (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
     d_inputs, d_acc_state, d_captured) = _Pack(run[1:], bakloop_sig)

    # Make sure this function didn't capture anything different than the
    # cell_fn when reflected on at the beginning. Must come after the
    # call to BackwardLoopBody, which adds to the captured list.
    _AssertSameTensors(function.get_extra_inputs(),
                       self._implicit_captures.Flatten())

    if self._unused_acc_state:
      # Match the shape of gradient of the init_state.
      d_state0 = self._state.Transform(tf.zeros_like)
    # NOTE(review): acc_extras is returned in the slot of the gradient for
    # Forward's `extras` input; presumably that input only acts as a
    # template -- confirm.
    return _Flatten([d_theta, d_state0, d_inputs, acc_extras, d_captured])

  self._forward = Forward
def ComputeMetrics(self, decoder_outs, input_batch, ids_to_strings_fn):
  """Builds the graph tensors for decode outputs and normalized-WER metrics.

  Args:
    decoder_outs: A `BeamSearchDecodeOutput`, a namedtuple containing the
      decode results.
    input_batch: A `NestedMap` of tensors representing the source, target,
      and other components of the input batch.
    ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
      shape [batch, length], lens has shape [batch], and strings has shape
      [batch].

  Returns:
    A dict of Tensors containing decoder output and metrics. 'utt_id' is only
    present off-TPU and when the input batch carries 'sample_ids'.
  """
  topk = self.GetTopK(decoder_outs, ids_to_strings_fn=ids_to_strings_fn)
  scores_shape = tf.shape(topk.scores)
  batch_size, hyps_per_beam = scores_shape[0], scores_shape[1]

  tgt = input_batch.tgt
  # Count of non-padded target positions per example.
  num_real_positions = tf.reduce_sum(1.0 - tgt.paddings, 1)
  tgt_lens = py_utils.HasShape(
      tf.cast(tf.round(num_real_positions), tf.int32), [batch_size])
  # One position is excluded from each transcript (presumably the trailing
  # end-of-sentence label -- TODO confirm).
  transcripts = ids_to_strings_fn(tgt.labels, tgt_lens - 1)

  def _DropIsolatedNoise(strings):
    # Filter out all isolated '<noise>' tokens (each becomes one space).
    return tf.strings.regex_replace(
        strings, ' <noise> |^<noise> | <noise>$|^<noise>$', ' ')

  refs = _DropIsolatedNoise(transcripts)
  hyps = _DropIsolatedNoise(topk.decoded)

  # Compute translation quality scores for all hyps: repeat each reference
  # once per hypothesis so refs and hyps line up after flattening.
  refs = tf.reshape(
      tf.tile(tf.reshape(refs, [-1, 1]), [1, hyps_per_beam]), [-1])
  hyps = tf.reshape(hyps, [-1])
  tf.logging.info('filtered_refs=%s', refs)
  norm_wer_errors, norm_wer_words = self.ComputeNormalizedWER(
      hyps, refs, hyps_per_beam)

  ret_dict = {}
  ret_dict['target_ids'] = tgt.ids
  ret_dict['target_labels'] = tgt.labels
  ret_dict['target_weights'] = tgt.weights
  ret_dict['target_paddings'] = tgt.paddings
  ret_dict['transcripts'] = transcripts
  ret_dict['topk_decoded'] = topk.decoded
  ret_dict['topk_ids'] = topk.ids
  ret_dict['topk_lens'] = topk.lens
  ret_dict['topk_scores'] = topk.scores
  ret_dict['norm_wer_errors'] = norm_wer_errors
  ret_dict['norm_wer_words'] = norm_wer_words

  on_cpu_with_ids = (not py_utils.use_tpu()) and ('sample_ids' in input_batch)
  if on_cpu_with_ids:
    ret_dict['utt_id'] = input_batch.sample_ids

  extra_metrics = self.AddAdditionalDecoderMetricsToGraph(
      topk, hyps, refs, input_batch, decoder_outs)
  ret_dict.update(extra_metrics)
  return ret_dict
def PostProcessDecodeOut(self, dec_out_dict, dec_metrics_dict):
  """Logs decoded hypotheses and accumulates WER-style metrics for a batch.

  Args:
    dec_out_dict: A dict of per-batch decode outputs. Its keys match the ones
      built by `ComputeMetrics` ('transcripts', 'topk_decoded',
      'topk_scores', 'topk_ids', 'topk_lens', 'target_labels',
      'target_paddings', 'norm_wer_errors', 'norm_wer_words' and, optionally,
      'utt_id'). Presumably these are the fetched values of those tensors.
    dec_metrics_dict: A dict of metric objects exposing `Update()`. Updated
      keys are 'num_samples_in_batch', 'wer', 'oracle_norm_wer', 'sacc',
      'norm_wer', 'ter' and 'corpus_bleu', plus whatever
      `UpdateAdditionalMetrics` touches.

  Returns:
    key_value_pairs: currently always an empty list.
  """
  p = self.params
  num_hyps_per_beam = p.decoder.beam_search.num_hyps_per_beam
  use_tpu = py_utils.use_tpu()
  assert 'topk_scores' in dec_out_dict, dec_out_dict.keys()
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  num_utts = len(transcripts)
  # ComputeMetrics only emits 'utt_id' off-TPU AND when the input batch has
  # 'sample_ids', so it may legitimately be missing.
  utt_id = None
  if not use_tpu and 'utt_id' in dec_out_dict:
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == num_utts
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  # topk_ids is flattened over [utterance, hypothesis].
  assert len(topk_ids) == num_hyps_per_beam * num_utts
  assert len(norm_wer_errors) == num_utts
  assert len(norm_wer_words) == num_utts

  dec_metrics_dict['num_samples_in_batch'].Update(num_utts)

  def GetRefIds(ref_ids, ref_paddings):
    """Returns the entries of `ref_ids` whose padding value is 0."""
    assert len(ref_ids) == len(ref_paddings)
    return [rid for rid, pad in zip(ref_ids, ref_paddings) if pad == 0]

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_norm_wer_errs = 0
  total_norm_wer_words = 0
  total_accurate_sentences = 0
  key_value_pairs = []
  for i in range(num_utts):
    ref_str = transcripts[i]
    if utt_id is not None:
      tf.logging.info('utt_id: %s', utt_id[i])
    tf.logging.info(' ref_str: %s', ref_str)
    hyps = topk_decoded[i]
    ref_ids = GetRefIds(target_labels[i], target_paddings[i])
    # Token error rate is computed against the top hypothesis only.
    hyp_index = i * num_hyps_per_beam
    top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
    total_ref_tokens += len(ref_ids)
    _, _, _, token_errs = decoder_utils.EditDistanceInIds(
        ref_ids, top_hyp_ids)
    total_token_errs += token_errs

    assert num_hyps_per_beam == len(hyps)
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    # Best (lowest) normalized-WER error count over all hypotheses.
    oracle_errs = norm_wer_errors[i][0]
    for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
      tf.logging.info(' %f: %s', score, hyp_str)
      filtered_hyp = decoder_utils.FilterNoise(hyp_str)
      filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
      ins, subs, dels, errs = decoder_utils.EditDistance(
          filtered_ref, filtered_hyp)
      # Note that these numbers are not consistent with what is used to
      # compute normalized WER. In particular, these numbers will be inflated
      # when the transcript contains punctuation.
      tf.logging.info(' ins: %d, subs: %d, del: %d, total: %d', ins, subs,
                      dels, errs)
      hyp_norm_wer_errors = norm_wer_errors[i][n]
      hyp_norm_wer_words = norm_wer_words[i][n]
      # Only aggregate scores of the top hypothesis.
      if n == 0:
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        total_norm_wer_errs += hyp_norm_wer_errors
        if hyp_norm_wer_errors == 0:
          total_accurate_sentences += 1
        total_norm_wer_words += hyp_norm_wer_words
        dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)
      if hyp_norm_wer_errors < oracle_errs:
        oracle_errs = hyp_norm_wer_errors
    total_oracle_errs += oracle_errs

  # Denominators are clamped to at least 1 (as floats) so an empty batch, or
  # one with no reference words/tokens, does not raise ZeroDivisionError; the
  # corresponding weight stays 0 in that case.
  dec_metrics_dict['wer'].Update(total_errs / max(1., total_ref_words),
                                 total_ref_words)
  # NOTE(review): oracle errors are normalized-WER counts but are divided by
  # the un-normalized reference word count; confirm this is intended.
  dec_metrics_dict['oracle_norm_wer'].Update(
      total_oracle_errs / max(1., total_ref_words), total_ref_words)
  dec_metrics_dict['sacc'].Update(
      total_accurate_sentences / max(1., num_utts), num_utts)
  dec_metrics_dict['norm_wer'].Update(
      total_norm_wer_errs / max(1., total_norm_wer_words),
      total_norm_wer_words)
  dec_metrics_dict['ter'].Update(total_token_errs / max(1., total_ref_tokens),
                                 total_ref_tokens)

  # Update any additional metrics.
  dec_metrics_dict = self.UpdateAdditionalMetrics(dec_out_dict,
                                                  dec_metrics_dict)
  return key_value_pairs
def FProp(self, theta, batch, state0=None):
  """Encodes source as represented by 'inputs' and 'paddings'.

  Stages, in order:

  1. Optional SpecAugment (training only).
  2. Optional extra padded timesteps at the end.
  3. Conv layers.
  4. Conv-LSTM blocks (rnn then cnn).
  5. A stack of RNN layers with optional residual/highway skips, projections
     and a time-stacking layer.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    batch: A NestedMap with fields:
      - src_inputs - The inputs tensor. It is expected to be of shape [batch,
        time, feature_dim, channels].
      - paddings - The paddings tensor. It is expected to be of shape [batch,
        time].
    state0: Recurrent input state. Not supported/ignored by this encoder.

  Returns:
    A NestedMap containing

    - 'encoded': a feature tensor of shape [time, batch, depth]
    - 'padding': a 0/1 tensor of shape [time, batch]
    - 'state': the updated recurrent state (always an empty NestedMap here,
      since this encoder keeps no recurrent state across calls)
    - '${layer_type}_${layer_index}': The per-layer encoder output, only
      present when p.extra_per_layer_outputs is set. Each one is a NestedMap
      containing 'encoded' and 'padding' similar to regular final outputs,
      except that 'encoded' from conv or conv_lstm layers are of shape
      [time, batch, depth, channels].
  """
  p = self.params
  inputs, paddings = batch.src_inputs, batch.paddings
  outputs = py_utils.NestedMap()
  with tf.name_scope(p.name):
    # Adding specAugmentation.
    # Only applied when enabled in params and not decoding/evaluating.
    if p.use_specaugment and not self.do_eval:
      inputs, paddings = self.specaugment.FProp(theta.specaugment, inputs,
                                                paddings)
    # Add a few extra padded timesteps at the end. This is for ensuring the
    # correctness of the conv-layers at the edges.
    if p.pad_steps > 0:
      # inplace_update() is not supported by TPU for now. Since we have done
      # padding on the input_generator, we may avoid this additional padding.
      assert not py_utils.use_tpu()
      # Build pad tensors shaped like inputs/paddings except that dim 1 (time)
      # is p.pad_steps; inputs are padded with 0s, paddings with 1s.
      inputs_pad = tf.zeros(
          inplace_ops.inplace_update(tf.shape(inputs), 1, p.pad_steps),
          inputs.dtype)
      paddings_pad = tf.ones(
          inplace_ops.inplace_update(tf.shape(paddings), 1, p.pad_steps),
          paddings.dtype)
      inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
      paddings = tf.concat([paddings, paddings_pad], 1)

    # One summary plot per stage; for the 4-D conv tensors the last two axes
    # are swapped before plotting.
    plots = [
        summary_utils.PrepareSequenceForPlot(
            tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')
    ]

    # Convolution layers (batch-major).
    conv_out = inputs
    out_padding = paddings
    for i, conv_layer in enumerate(self.conv):
      conv_out, out_padding = conv_layer.FProp(theta.conv[i], conv_out,
                                               out_padding)
      if p.extra_per_layer_outputs:
        # Zero out padded frames before exporting the per-layer output.
        conv_out *= (1.0 - out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_out, [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
              'conv_%d_out' % i))

    def TransposeFirstTwoDims(t):
      """Swaps dims 0 and 1 of `t`, keeping the remaining dims unchanged."""
      first_dim = tf.shape(t)[0]
      second_dim = tf.shape(t)[1]
      # Flatten trailing dims so a rank-3 transpose works for any rank >= 2.
      t_new = tf.transpose(
          tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
      t_shape_new = tf.concat([[second_dim], [first_dim], tf.shape(t)[2:]], 0)
      return tf.reshape(t_new, t_shape_new)

    # Now the conv-lstm part.
    conv_lstm_out = conv_out
    conv_lstm_out_padding = out_padding
    for i, (rnn, cnn) in enumerate(
        zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
      conv_lstm_in = conv_lstm_out
      # Move time dimension to be the first.
      conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
      # Insert a singleton axis at dim 2 for the rnn; removed again by the
      # tf.squeeze(..., 2) below.
      conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
      conv_lstm_in_padding = tf.expand_dims(
          tf.transpose(conv_lstm_out_padding), 2)
      lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                           conv_lstm_in_padding)
      # Move time dimension to be the second.
      cnn_in = TransposeFirstTwoDims(lstm_out)
      cnn_in = tf.squeeze(cnn_in, 2)
      cnn_in_padding = conv_lstm_out_padding
      cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i], cnn_in,
                                           cnn_in_padding)
      conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
      if p.extra_per_layer_outputs:
        # Zero out padded frames before exporting the per-layer output.
        conv_lstm_out *= (
            1.0 - conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_lstm_out,
                                 [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(conv_lstm_out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(conv_lstm_out,
                                               conv_lstm_out_padding,
                                               'conv_lstm_%d_out' % i))

    # Need to do a reshape before starting the rnn layers.
    # Merge all dims after the first two into a single feature dim.
    conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
    conv_lstm_out_shape = tf.shape(conv_lstm_out)
    new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
    conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
    if self._first_lstm_input_dim_pad:
      # Zero-pad the feature dim so it matches self._first_lstm_input_dim
      # (checked by HasShape below).
      conv_lstm_out = tf.pad(
          conv_lstm_out,
          [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])

    conv_lstm_out = py_utils.HasShape(conv_lstm_out,
                                      [-1, -1, self._first_lstm_input_dim])

    # Transpose to move the time dimension to be the first.
    rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
    rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding), 2)
    # rnn_in is of shape [time, batch, depth]
    # rnn_padding is of shape [time, batch, 1]

    # Now the rnn layers.
    # Index into self.highway_skip; advanced each time a highway skip is used.
    num_skips = 0
    for i in range(p.num_lstm_layers):
      rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
      # Skip connections start at layer p.residual_start - 1. residual_in is
      # captured when residual_index % stride == 0 and added back when
      # residual_index % stride == stride - 1.
      residual_index = i - p.residual_start + 1
      if p.residual_start > 0 and residual_index >= 0:
        if residual_index % p.residual_stride == 0:
          residual_in = rnn_in
        if residual_index % p.residual_stride == p.residual_stride - 1:
          # Highway skip connection.
          if p.highway_skip:
            rnn_out = self.highway_skip[num_skips].FProp(
                theta.highway_skip[num_skips], residual_in, rnn_out)
            num_skips += 1
          else:
            # Residual skip connection.
            rnn_out += py_utils.HasShape(residual_in, tf.shape(rnn_out))
      if p.project_lstm_output and (i < p.num_lstm_layers - 1):
        # Projection layers.
        rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out, rnn_padding)
      if i == p.num_lstm_layers - 1:
        # Zero out padded frames of the final rnn layer's output.
        rnn_out *= (1.0 - rnn_padding)
      if p.extra_per_layer_outputs:
        rnn_out *= (1.0 - rnn_padding)
        outputs['rnn_%d' % i] = py_utils.NestedMap(
            encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
      # Stacking layer connection.
      if p.layer_index_before_stacking == i:
        # Stacking layer expects input tensor shape as [batch, time, feature].
        # So transpose the tensors before and after the layer.
        rnn_out, rnn_padding = self.stacking.FProp(
            tf.transpose(rnn_out, [1, 0, 2]), tf.transpose(rnn_padding,
                                                           [1, 0, 2]))
        rnn_out = tf.transpose(rnn_out, [1, 0, 2])
        rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(rnn_out, [1, 0, 2]),
              tf.transpose(rnn_padding, [1, 0, 2]), 'rnn_%d_out' % i))
      rnn_in = rnn_out
    final_out = rnn_in

  # Plots are collected input-first; reversed here before plotting.
  summary_utils.PlotSequenceFeatures(
      list(reversed(plots)), 'encoder_example', xlabel='Time')

  outputs['encoded'] = final_out
  outputs['padding'] = tf.squeeze(rnn_padding, [2])
  outputs['state'] = py_utils.NestedMap()
  return outputs
def PostProcess(self, dec_out_dict, dec_metrics_dict):
  """Computes and accumulates (optionally weighted) decoding metrics.

  'num_samples_in_batch' and 'norm_wer' are weighted by
  dec_out_dict['example_weights'], which defaults to all ones. 'corpus_bleu'
  is always updated from the top hypotheses.

  When p.include_auxiliary_metrics is set, 'wer', 'oracle_norm_wer', 'sacc'
  and 'ter' are also computed, unweighted, from the top hypothesis. Logging
  of references and hypotheses is gated on self.cluster.add_summary.

  Args:
    dec_out_dict: dict of decoder outputs for the batch. Reads the keys
      'topk_scores', 'topk_decoded', 'transcripts', 'norm_wer_errors',
      'norm_wer_words', 'target_labels', 'target_paddings', 'topk_ids',
      'topk_lens', optionally 'example_weights' and, when not on TPU,
      'utt_id'. norm_wer_errors / norm_wer_words are indexed as 2-D arrays
      ([utterance, hypothesis]).
    dec_metrics_dict: dict of metric objects exposing `Update(...)`; updated
      in place.

  Returns:
    A list of key/value pairs. This implementation never appends to it, so it
    is always empty.
  """
  p = self.params
  assert 'topk_scores' in dec_out_dict, list(dec_out_dict.keys())
  topk_scores = dec_out_dict['topk_scores']
  topk_decoded = dec_out_dict['topk_decoded']
  transcripts = dec_out_dict['transcripts']
  # use_tpu() is loop-invariant; query it once.
  on_tpu = py_utils.use_tpu()
  if not on_tpu:
    utt_id = dec_out_dict['utt_id']
    assert len(utt_id) == len(transcripts)
  norm_wer_errors = dec_out_dict['norm_wer_errors']
  norm_wer_words = dec_out_dict['norm_wer_words']
  target_labels = dec_out_dict['target_labels']
  target_paddings = dec_out_dict['target_paddings']
  topk_ids = dec_out_dict['topk_ids']
  topk_lens = dec_out_dict['topk_lens']
  if 'example_weights' in dec_out_dict:
    example_weights = dec_out_dict['example_weights']
  else:
    example_weights = np.ones([len(transcripts)], np.float32)
  assert len(transcripts) == len(target_labels)
  assert len(transcripts) == len(target_paddings)
  assert len(transcripts) == len(topk_decoded)
  assert len(norm_wer_errors) == len(transcripts)
  assert len(norm_wer_words) == len(transcripts)

  num_samples_in_batch = example_weights.sum()
  dec_metrics_dict['num_samples_in_batch'].Update(num_samples_in_batch)

  def GetRefIds(ref_ids, ref_paddings):
    """Returns the entries of `ref_ids` whose padding value is 0."""
    assert len(ref_ids) == len(ref_paddings)
    return [ref_id for ref_id, pad in zip(ref_ids, ref_paddings) if pad == 0]

  # Weighted normalized WER of the top hypothesis (column 0).
  total_norm_wer_errs = (norm_wer_errors[:, 0] * example_weights).sum()
  total_norm_wer_words = (norm_wer_words[:, 0] * example_weights).sum()

  # With all-zero weights (or an empty batch) the denominator is 0 and the
  # numpy division would yield NaN/inf. Report 0 with weight 0 instead. An
  # explicit > 0 check is used rather than max(1., ...) because fractional
  # weights can make a legitimate denominator smaller than 1.
  if total_norm_wer_words > 0:
    norm_wer = total_norm_wer_errs / total_norm_wer_words
  else:
    norm_wer = 0.
  dec_metrics_dict['norm_wer'].Update(norm_wer, total_norm_wer_words)

  # Filter noise/epsilon from references and top hypotheses once; these are
  # reused by the auxiliary metrics below.
  filtered_transcripts = []
  filtered_top_hyps = []
  for ref_str, hyps in zip(transcripts, topk_decoded):
    filtered_ref = decoder_utils.FilterNoise(ref_str)
    filtered_ref = decoder_utils.FilterEpsilon(filtered_ref)
    filtered_transcripts.append(filtered_ref)
    filtered_hyp = decoder_utils.FilterNoise(hyps[0])
    filtered_hyp = decoder_utils.FilterEpsilon(filtered_hyp)
    filtered_top_hyps.append(filtered_hyp)
    dec_metrics_dict['corpus_bleu'].Update(filtered_ref, filtered_hyp)

  total_errs = 0
  total_oracle_errs = 0
  total_ref_words = 0
  total_token_errs = 0
  total_ref_tokens = 0
  total_accurate_sentences = 0
  key_value_pairs = []

  if p.include_auxiliary_metrics:
    for i in range(len(transcripts)):
      ref_str = transcripts[i]
      if not on_tpu:
        tf.logging.info('utt_id: %s', utt_id[i])
      if self.cluster.add_summary:
        tf.logging.info(' ref_str: %s',
                        ref_str.decode('utf-8') if p.log_utf8 else ref_str)
      hyps = topk_decoded[i]
      num_hyps_per_beam = len(hyps)
      ref_ids = GetRefIds(target_labels[i], target_paddings[i])
      # The hypotheses of utterance i start at row i * num_hyps_per_beam; the
      # first of them is the top hypothesis.
      hyp_index = i * num_hyps_per_beam
      top_hyp_ids = topk_ids[hyp_index][:topk_lens[hyp_index]]
      if self.cluster.add_summary:
        tf.logging.info(' ref_ids: %s', ref_ids)
        tf.logging.info(' top_hyp_ids: %s', top_hyp_ids)
      total_ref_tokens += len(ref_ids)
      _, _, _, token_errs = decoder_utils.EditDistanceInIds(
          ref_ids, top_hyp_ids)
      total_token_errs += token_errs

      filtered_ref = filtered_transcripts[i]
      # The oracle is the lowest normalized error over all hypotheses.
      oracle_errs = norm_wer_errors[i][0]
      for n, (score, hyp_str) in enumerate(zip(topk_scores[i], hyps)):
        oracle_errs = min(oracle_errs, norm_wer_errors[i, n])
        if self.cluster.add_summary:
          tf.logging.info(' %f: %s', score,
                          hyp_str.decode('utf-8') if p.log_utf8 else hyp_str)
        # Only aggregate scores of the top hypothesis.
        if n != 0:
          continue
        filtered_hyp = filtered_top_hyps[i]
        _, _, _, errs = decoder_utils.EditDistance(filtered_ref, filtered_hyp)
        total_errs += errs
        total_ref_words += len(decoder_utils.Tokenize(filtered_ref))
        if norm_wer_errors[i, n] == 0:
          total_accurate_sentences += 1

      total_oracle_errs += oracle_errs

    num_utts = len(transcripts)
    dec_metrics_dict['wer'].Update(total_errs / max(1., total_ref_words),
                                   total_ref_words)
    dec_metrics_dict['oracle_norm_wer'].Update(
        total_oracle_errs / max(1., total_ref_words), total_ref_words)
    # Guard against an empty batch, which would otherwise raise
    # ZeroDivisionError.
    dec_metrics_dict['sacc'].Update(
        total_accurate_sentences / max(1., num_utts), num_utts)
    dec_metrics_dict['ter'].Update(
        total_token_errs / max(1., total_ref_tokens), total_ref_tokens)

  return key_value_pairs
def Task(self):
  """Builds the params of the feature-neighborhood transformer model.

  Reads vocabulary paths and sequence limits from FLAGS and model
  hyper-parameters from private attributes of `self`. It sets up a
  transformer encoder/decoder via `base_config.SetupTransformerParams` and,
  when `self._use_neighbors` is set, extra spelling/pronunciation encoders
  for the neighbors.

  Returns:
    The populated `FeatureNeighborhoodModelTrans` params.

  Raises:
    ValueError: if the attention type is "CONCATAVE" while embeddings are not
      shared.
  """
  p = feature_neighborhood_model_trans.FeatureNeighborhoodModelTrans.Params()
  # With shared embeddings, the output side reuses the input symbol table.
  if self._share_embeddings:
    output_symbol_path = FLAGS.input_symbols
  else:
    output_symbol_path = FLAGS.output_symbols
  _, p.input_symbols, p.output_symbols = (
      fn.FeatureNeighborhoodInput.ParameterizedConfigs(
          input_symbol_path=FLAGS.input_symbols,
          output_symbol_path=output_symbol_path,
          append_eos=FLAGS.append_eos,
          max_spelling_len=FLAGS.max_spelling_len,
          max_pronunciation_len=FLAGS.max_pronunciation_len,
          max_neighbors=FLAGS.max_neighbors))
  p.input_vocab_size = p.input_symbols.num_symbols()
  p.output_vocab_size = p.output_symbols.num_symbols()
  p.max_neighbors = FLAGS.max_neighbors
  p.max_pronunciation_len = FLAGS.max_pronunciation_len
  p.max_spelling_len = FLAGS.max_spelling_len
  # Start-of-sequence id in the output symbol table.
  p.start = p.output_symbols.find("<s>")
  p.share_embeddings = self._share_embeddings

  if self._share_embeddings:
    vocab_size = p.input_vocab_size
  else:
    vocab_size = p.output_vocab_size

  p = base_config.SetupTransformerParams(
      p,
      name="feature_neighborhood_with_neighbors",
      vocab_size=vocab_size,
      model_dim=p.embedding_dim,
      hidden_dim=p.enc_units,
      num_heads=self._num_heads,
      num_layers=self._num_layers,
      learning_rate=3.0,
      warmup_steps=40000,
      residual_dropout_prob=self._residual_dropout_prob,
      relu_dropout_prob=self._relu_dropout_prob,
      input_dropout_prob=self._input_dropout_prob,
      atten_dropout_prob=self._atten_dropout_prob,
      label_smoothing_uncertainty=self._label_smoothing_uncertainty)
  # SetupTransformerParams received a single vocab_size; without shared
  # embeddings the encoder's token embedding must use the input vocab.
  if not self._share_embeddings:
    p.encoder.token_emb.vocab_size = p.input_vocab_size

  p.eval.samples_per_summary = 20000
  # TODO(llion): Might need to change the output vocab size to one that can
  # be sharded to run efficiently on TPUs.
  p.decoder.softmax.num_shards = 1
  p.decoder.target_seq_len = p.max_pronunciation_len

  if py_utils.use_tpu():
    p.decoder.beam_search = model_helper.ChangeToBeamSearchTpuHelper(
        p.decoder.beam_search)

  if FLAGS.neigh_use_tpu:
    for pp in [p.encoder, p.decoder]:
      pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)
    p.decoder.softmax = model_helper.ChangeToSimpleSoftmax(p.decoder.softmax)

  p.use_neighbors = self._use_neighbors
  if self._use_neighbors:
    # Encoder for the neighbors' spellings (input vocabulary).
    p.spell_encoder = base_config.SetupTransformerEncoder(
        vocab_size=p.input_vocab_size,
        model_dim=p.embedding_dim,
        hidden_dim=p.enc_units,
        num_heads=self._num_heads,
        num_layers=self._num_layers,
        residual_dropout_prob=self._residual_dropout_prob,
        relu_dropout_prob=self._relu_dropout_prob,
        input_dropout_prob=self._input_dropout_prob,
        atten_dropout_prob=self._atten_dropout_prob)
    if self._attention_type != "CONCATAVE":
      # Separate encoder for the neighbors' pronunciations (output vocab).
      p.pron_encoder = base_config.SetupTransformerEncoder(
          vocab_size=p.output_vocab_size,
          model_dim=p.embedding_dim,
          hidden_dim=p.enc_units,
          num_heads=self._num_heads,
          num_layers=self._num_layers,
          residual_dropout_prob=self._residual_dropout_prob,
          relu_dropout_prob=self._relu_dropout_prob,
          input_dropout_prob=self._input_dropout_prob,
          atten_dropout_prob=self._atten_dropout_prob)
    else:
      # CONCATAVE builds no pron_encoder; it requires a shared vocabulary.
      if not self._share_embeddings:
        raise ValueError("Must share embeddings to concat spelling and pron.")
    if FLAGS.neigh_use_tpu:
      # NOTE(review): `if pp` presumably skips p.pron_encoder when it was not
      # built above (CONCATAVE), i.e. it defaults to a falsy value -- confirm.
      for pp in [p.spell_encoder, p.pron_encoder]:
        if pp:
          pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)
    p.also_shuffle_neighbors = self._also_shuffle_neighbors

  if self._use_neigh_id_emb:
    assert self._use_neighbors
    p.use_neigh_id_emb = True
    if self._attention_type == "CONCAT":
      neigh_id_emb = layers.EmbeddingLayer.Params().Set(
          vocab_size=FLAGS.max_neighbors + 1,  # +1 to include the main input
          embedding_dim=p.embedding_dim,
          max_num_shards=1,
          params_init=py_utils.WeightInit.Gaussian(
              1.0 / maths.sqrt(p.embedding_dim)),
          scale_sqrt_depth=True)
      p.encoder.task_emb = neigh_id_emb
    elif self._attention_type == "AVERAGE":
      # One id per neighbor, shared by the spelling and pron encoders.
      neigh_id_emb = layers.EmbeddingLayer.Params().Set(
          vocab_size=FLAGS.max_neighbors,
          embedding_dim=p.embedding_dim,
          max_num_shards=1,
          params_init=py_utils.WeightInit.Gaussian(
              1.0 / maths.sqrt(p.embedding_dim)),
          scale_sqrt_depth=True)
      p.spell_encoder.task_emb = neigh_id_emb
      p.pron_encoder.task_emb = neigh_id_emb

  p.neigh_att_type = self._attention_type
  p.aux_dropout_prob = self._aux_dropout_prob

  return p
def GetProjectLastDim(cls, inputs, weight, input_dim, output_dim, proj_obj):
  """Linear projection on the last dim of the input tensor along with pruning.

  This is a TPU efficient implementation to avoid reshaping inputs to Rank-2
  tensor by using Einsum for the compute.

  Off TPU, or when the static rank is unknown or >= 26, the inputs are
  flattened to rank 2 and multiplied by the factors in `proj_obj.theta`. The
  factors used depend on `proj_obj.params.pruning_hparams_dict`:

  - compression_option 9: optional block-wise compression of the inputs
    (`b_matrix_tfvar`, when 'compress_input') and of the outputs
    (`d_matrix_tfvar`, when 'compress_output') around `c_matrix_tfvar`.
  - compression_option 10: block compression of `c_matrix_tfvar`, via a mask
    ('mask') or a per-block loop ('loop'), selected by 'block_method'.
  - otherwise: a plain matmul with `c_matrix_tfvar`.

  Args:
    inputs: An input Tensor, the last dimension of which is input_dim.
    weight: A weight matrix with shape [input_dim, output_dim]. Only used on
      the TPU rank-2 path.
    input_dim: An integer or a symbolic dim, the last dimension of the inputs.
    output_dim: An integer or a symbolic dim, the last dimension of the
      outputs.
    proj_obj: a ProjectionLayer object.

  Returns:
    An output Tensor of the same rank as inputs, the last dimension is
    output_dim.

  Raises:
    ValueError: if compression_option is 10 and block_method is neither
      'mask' nor 'loop'.
  """
  theta = proj_obj.theta
  p = proj_obj.params
  input_dim = int(
      symbolic.ToStatic(input_dim) if symbolic.IsExpr(input_dim) else input_dim)
  output_dim = int(
      symbolic.ToStatic(output_dim)
      if symbolic.IsExpr(output_dim) else output_dim)
  # NOTE(review): the rank < 26 limit presumably keeps the einsum subscripts
  # within the single letters a-z -- confirm against GetEinSumResult.
  if (py_utils.use_tpu() and inputs.shape is not None and
      inputs.shape.rank is not None and inputs.shape.rank < 26):
    # Avoids reshape if feasible and uses Einsum.
    if inputs.shape.rank == 2:
      outputs = tf.matmul(inputs, weight)
    else:
      outputs = cls.GetEinSumResult(inputs, proj_obj)
  else:
    hparams = p.pruning_hparams_dict
    compression_option = hparams['compression_option']
    if compression_option == 9 and hparams['compress_input']:
      # Compress the inputs block-wise with b_matrix_tfvar.
      blocked_inputs = tf.reshape(
          inputs, py_utils.ToStaticShape([-1, hparams['input_block_size']]))
      compressed_inputs = tf.reshape(
          py_utils.Matmul(blocked_inputs, theta.b_matrix_tfvar),
          py_utils.ToStaticShape(
              [-1, input_dim // hparams['input_compression_factor']]))
    else:
      compressed_inputs = tf.reshape(inputs,
                                     py_utils.ToStaticShape([-1, input_dim]))

    if compression_option == 10:
      block_method = hparams['block_method']
      if block_method == 'mask':
        intermediate_result = py_utils.Matmul(
            compressed_inputs,
            tf.multiply(theta.c_matrix_tfvar, theta.c_mask_tfvar))
      elif block_method == 'loop':
        # Multiply each input block by its own slice of c_matrix_tfvar.
        num_blocks = hparams['block_compression_factor']
        input_splitted = tf.split(compressed_inputs, num_blocks, axis=-1)
        output_splitted = [
            py_utils.Matmul(input_i, theta.c_matrix_tfvar[i, :, :])
            for i, input_i in enumerate(input_splitted)
        ]
        intermediate_result = tf.concat(output_splitted, axis=-1)
      else:
        # Previously fell through and failed with an UnboundLocalError.
        raise ValueError(
            'Unsupported block_method %r for compression_option 10; '
            "expected 'mask' or 'loop'." % (block_method,))
    else:
      intermediate_result = py_utils.Matmul(compressed_inputs,
                                            theta.c_matrix_tfvar)

    if compression_option == 9 and hparams['compress_output']:
      # Decompress the outputs block-wise with d_matrix_tfvar.
      blocked_intermediate_result = tf.reshape(
          intermediate_result,
          py_utils.ToStaticShape([
              -1, hparams['output_block_size'] //
              hparams['output_compression_factor']
          ]))
      outputs = py_utils.Matmul(blocked_intermediate_result,
                                theta.d_matrix_tfvar)
    else:
      outputs = intermediate_result

    # Restore the leading dims of `inputs`.
    outputs = tf.reshape(
        outputs,
        tf.concat([
            tf.cast(py_utils.GetShape(inputs)[:-1], tf.int32),
            py_utils.ToStaticShape([output_dim])
        ],
                  axis=0))

  return outputs
def _DecoderDevice(self):
  """Returns the `tf.device` scope under which decoder ops are placed.

  On TPU this is the worker device of model split 1, as reported by
  `self.cluster.WorkerDeviceInModelSplit(1)`; otherwise it is the empty
  device string, i.e. no explicit placement.
  """
  if not py_utils.use_tpu():
    return tf.device('')
  return tf.device(self.cluster.WorkerDeviceInModelSplit(1))
def __init__(self, params):
  """Adjusts `params` for fixed-length input, then defers to the base class.

  Always enables `pad_to_max_seq_length`. If `fixed_input_shape` is not
  already truthy, it is set to the result of `py_utils.use_tpu()`.
  """
  params.pad_to_max_seq_length = True
  if not params.fixed_input_shape:
    params.fixed_input_shape = py_utils.use_tpu()
  super().__init__(params)