def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs.

  Statistics are reduced over every axis except the last one; each data point
  contributes in proportion to its mask value.

  Args:
    inputs: The input tensor. Must have the same rank as `mask`, and `mask`
      must broadcast against it.
    mask: Non-negative weights marking the valid data points. May be of any
      numeric dtype; it is cast to `inputs.dtype` before use.
    enable_cross_replica_sum_on_tpu: If True and running on TPU, the sums are
      aggregated across replicas with `tf.tpu.cross_replica_sum`.

  Returns:
    A `(mean, variance)` tuple reduced over all but the last axis.
  """
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  # Cast once so that every product and the mean/variance divisions below
  # share inputs' dtype. Previously only the first product was cast, so a mask
  # with a different dtype broke `sum_v / count_v` and the variance product.
  mask = tf.cast(mask, inputs.dtype)
  rank = tf.rank(mask)
  reduce_over_dims = tf.range(0, rank - 1)
  sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims)
  count_v = tf.reduce_sum(mask, reduce_over_dims)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
  count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Guard against division by zero when nothing is valid.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v
  centered = inputs - mean
  sum_vv = tf.reduce_sum(centered * centered * mask, reduce_over_dims)

  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
def _InputBatch(self):
  """Returns one batch of examples drawn from arrays stored in a checkpoint.

  The data and label tensors named `p.data` / `p.label` are restored from
  `p.ckpt`, cast to float32 and cached via `ops.cached_call`. Each call then
  gathers a random permutation of sample ids and pads/trims every field to
  the infeed batch size.

  Returns:
    A `.NestedMap` with `raw`, `data` (preprocessed `raw`), `label` and
    `weight` (ones for gathered samples, padded up to the batch size). Off TPU
    it also carries `sample_ids`.
  """
  p = self.params

  @tf.function
  def ReadData():
    values, labels = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                       [p.data_dtype, p.label_dtype])
    # Always convert to float32.
    return tf.cast(values, tf.float32), tf.cast(labels, tf.float32)

  # Loads data and label into memory and keep it around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])
  batch_size = self.InfeedBatchSize()
  example_shape = list(p.data_shape)

  data = tf.reshape(data, [-1] + example_shape)
  label = py_utils.HasShape(tf.reshape(label, [-1]), [tf.shape(data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  num_gathered = tf.shape(sample_ids)[0]

  raw = py_utils.PadOrTrimTo(
      tf.gather(data, sample_ids), [batch_size] + example_shape)
  preprocessed = self._Preprocess(raw)
  batch_labels = py_utils.PadOrTrimTo(
      tf.gather(label, sample_ids), [batch_size])
  batch_weights = py_utils.PadOrTrimTo(
      tf.ones([num_gathered], dtype=tf.float32), [batch_size])
  ret = py_utils.NestedMap(
      raw=raw, data=preprocessed, label=batch_labels, weight=batch_weights)

  if not py_utils.use_tpu():
    ret['sample_ids'] = sample_ids
  return ret
def SplitInputBatch(self, num_splits):
  """Splits the current InputBatch into num_splits ways.

  Args:
    num_splits: The number of splits (must be >= 1). Splitting into more than
      one piece is only supported off TPU.

  Returns:
    A list of `.NestedMap`. Each `.NestedMap` represents the input tensors in
    one split.
  """
  assert num_splits >= 1
  batch = self.GetPreprocessedInputBatch()
  if num_splits == 1:
    # Special case. No split is needed.
    return [batch]

  assert not py_utils.use_tpu()
  # pieces_per_field[i][j] is the j-th piece of the i-th flattened field.
  pieces_per_field = ig_helper.SplitTensors(batch.Flatten(), num_splits)
  return [
      batch.Pack([pieces[split_index] for pieces in pieces_per_field])
      for split_index in range(num_splits)
  ]
def _LoopEnqueue(self, op, session_override=None):
  """Runs the enqueue op in a loop.

  Repeatedly runs `op` in `sess`. The loop only returns once
  `self._dequeue_thread_complete` is set. When a stop condition is hit
  (`_ShouldStop` or `train.enqueue_max_steps`), it does not return: it sleeps
  15 seconds and polls again.

  Args:
    op: The enqueue op to run on every iteration.
    session_override: Optional session to use instead of
      `self._GetSession()`.
  """
  p = self.params
  sess = session_override or self._GetSession()
  # The session is also entered as a context manager here (for a TF1 Session
  # this presumably closes it on exit -- confirm).
  with tf.container(self._container_id), sess:
    if self._initialize_tables is not None:
      sess.run(self._initialize_tables)
    gsteps = py_utils.GetGlobalStep()
    local_enqueue_steps = 0

    # Global enqueue steps measures how many global steps have data enqueued
    # for already. We use this to terminate; note that the enqueue op may
    # hang in session.run if we do not terminate with this check.
    global_enqueue_steps = None

    tf.logging.info(
        'params.train.max_steps: %d, enqueue_max_steps: %d',
        p.train.max_steps, p.train.enqueue_max_steps)
    while True:
      # The only exit from the loop: the consumer side has finished.
      if self._dequeue_thread_complete:
        tf.logging.info(
            'LoopEnqueue done since consuming thread is done.')
        return

      global_step = sess.run(gsteps)
      # Seed the enqueue counter from the current global step on the first
      # iteration.
      if global_enqueue_steps is None:
        global_enqueue_steps = global_step
      if local_enqueue_steps % 1000 == 0:
        tf.logging.info(
            'Current global_enqueue_steps: %d, '
            'local_enqueue_steps: %d, global_step: %d',
            global_enqueue_steps, local_enqueue_steps, global_step)

      # On TPU, round the enqueued step count down to a multiple of
      # tpu_steps_per_loop before checking stop conditions.
      if py_utils.use_tpu():
        global_steps_with_available_data = int(
            global_enqueue_steps // p.train.tpu_steps_per_loop *
            p.train.tpu_steps_per_loop)
      else:
        global_steps_with_available_data = global_enqueue_steps

      # Stop condition reached: idle (sleep and re-poll) instead of returning,
      # so the loop can still observe _dequeue_thread_complete.
      if (self._ShouldStop(sess, global_steps_with_available_data) or
          self._ShouldStop(sess, global_step)):
        tf.logging.info('Done. ShouldStop is True.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      if (p.train.enqueue_max_steps > 0 and
          local_enqueue_steps >= p.train.enqueue_max_steps):
        tf.logging.info('Done. train.enqueue_max_steps reached.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      local_enqueue_steps += 1

      # There are tpu_infeed_parallelism parallel threads enqueuing.
      # We account for all of them when updating global_enqueue_steps.
      global_enqueue_steps += p.input.tpu_infeed_parallelism

      sess.run([op])
def _ProcessSingleInput(self, source_id, src, tgt):
  """Performs strings-to-ids on the given input pair via p.tokenizer_dict.

  Args:
    source_id: Identifier passed to `self._GetTaskIds` to obtain the source
      and target task ids.
    src: The source string.
    tgt: The target string.

  Returns:
    A `.NestedMap` with `src` (ids, ids_indicator, task_ids) and `tgt` (ids,
    labels, ids_indicator, task_ids) sub-maps, with every tensor squeezed.
    Off TPU both sub-maps also carry the raw `strs`.
  """
  _, src_labels, src_paddings = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
      tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)

  # Zero out padded positions for consistency: the tokenizer implementation
  # may use the EOS token to pad.
  src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
  tgt_ids, tgt_labels = [
      py_utils.ApplyPadding(tgt_paddings, t) for t in (tgt_ids, tgt_labels)
  ]

  # 1 exactly where the tokenizer emitted a real (non-padded) id. Unlike
  # weights it is never mutated, so it can be used e.g. for sequence length.
  src_indicator = 1 - src_paddings
  tgt_indicator = 1 - tgt_paddings

  src_task_id, tgt_task_id = self._GetTaskIds(source_id)
  features = py_utils.NestedMap(
      src=py_utils.NestedMap(
          ids=src_labels,
          ids_indicator=src_indicator,
          # Task ids are zero at padded positions.
          task_ids=tf.cast(src_indicator, dtype=tf.int32) * src_task_id),
      tgt=py_utils.NestedMap(
          ids=tgt_ids,
          labels=tgt_labels,
          ids_indicator=tgt_indicator,
          task_ids=tf.cast(tgt_indicator, dtype=tf.int32) * tgt_task_id))

  if not py_utils.use_tpu():
    features.src.strs = src
    features.tgt.strs = tgt
  return features.Transform(tf.squeeze)
def __init__(self, params):
  """Creates the `before_tpl` layers followed by the `cell_tpl` layers.

  On TPU, every `before_tpl` layer is placed on model split 0 and the i-th
  cell on model split i. Off TPU, no explicit device is requested.
  """
  super(SeqLayer, self).__init__(params)
  p = self.params
  assert p.name
  num_cells = len(p.cell_tpl)
  self._before_layers = []
  self._cells = []

  if py_utils.use_tpu():
    cluster = self.cluster
    before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
    cell_devices = [
        cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
    ]
  else:
    before_tpl_device = ''
    cell_devices = [''] * num_cells

  def _AddChild(layer_params, device, registry):
    # Creates one named child on `device` and records (name, layer).
    with tf.device(device):
      assert layer_params.name
      self.CreateChild(layer_params.name, layer_params)
      registry.append((layer_params.name, self.children[layer_params.name]))

  for layer_params in p.before_tpl:
    _AddChild(layer_params, before_tpl_device, self._before_layers)
  for layer_params, device in zip(p.cell_tpl, cell_devices):
    _AddChild(layer_params, device, self._cells)
def _Moments(self, inputs, group_size):
  """Computes mean and variance over N,H,W dimensions in inputs.

  The per-replica sufficient statistics are first recorded in
  `self.accumulators`. On TPU with `group_size > 1` they are then summed
  across groups of `group_size` replicas before being normalized.

  Args:
    inputs: Input tensor; statistics are reduced over axes 0, 1 and 2
      (presumably NHWC -- confirm with callers).
    group_size: Number of replicas whose statistics are combined. Only used on
      TPU and when greater than 1.

  Returns:
    A `(mean, variance)` tuple from `tf.nn.normalize_moments`.

  Raises:
    ValueError: If the number of TPU shards is known and is smaller than, or
      not divisible by, `group_size`.
  """
  counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
      inputs, axes=[0, 1, 2], keepdims=False)
  self.accumulators.counts.Update(counts)
  self.accumulators.mean_ss.Update(mean_ss)
  self.accumulators.variance_ss.Update(variance_ss)
  # Distributed batch norm that computes sufficient statistics from group_size
  # replicas. This is useful when batch_size_per_replica is too small to
  # compute reliable sufficient statistics.
  if py_utils.use_tpu() and group_size > 1:
    group_assignment = None
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is not None:
      if num_shards < group_size:
        raise ValueError('TPU shards={} less than bn_group_size={}.'.format(
            num_shards, group_size))
      if num_shards % group_size:
        raise ValueError(
            'TPU shards={} not divisible by bn_group_size={}.'.format(
                num_shards, group_size))
      num_groups = num_shards // group_size
      # Group g holds the consecutive replicas
      # [g * group_size, (g + 1) * group_size).
      group_assignment = [
          list(range(g * group_size, (g + 1) * group_size))
          for g in range(num_groups)
      ]
    # The summed statistics cover group_size replicas' worth of examples.
    # NOTE(review): when num_shards is unknown, group_assignment is None and
    # the sum spans all replicas, so this is exact only if the replica count
    # equals group_size.
    counts *= group_size
    mean_ss = tf.tpu.cross_replica_sum(mean_ss, group_assignment)
    variance_ss = tf.tpu.cross_replica_sum(variance_ss, group_assignment)
  # At each micro-step, batch_mean and batch_variance are computed
  # to normalize inputs. But they are not used to update moving_mean and
  # moving_variance variables until the last micro batch.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  return mean, variance
def GlobalBatchSize(self):
  """Returns the total batch size (for stats), int or dynamic int tensor.

  With per-host infeed on a cluster that has TPU hosts, the infeed batch
  size is multiplied by the number of TPU hosts.

  Raises:
    ValueError: If per-host scaling is requested while not running on TPU.
  """
  p = self.params
  batch_size = self.InfeedBatchSize()
  cluster = self.cluster
  scale_per_host = p.use_per_host_infeed and cluster.num_tpu_hosts > 0
  if scale_per_host and not py_utils.use_tpu():
    raise ValueError(
        'Scaling to TPU hosts without TPUs. {}'.format(
            cluster.num_tpu_hosts))
  if scale_per_host:
    batch_size *= cluster.num_tpu_hosts
  tf.logging.info('GlobalBatchSize {}'.format(batch_size))
  return batch_size
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Returns a 2-element seed tensor for stateless random ops.

  Replacement for py_utils.GenerateStepSeedPair whose step component comes
  from GetOverWriteGlobalStep() rather than the regular global step.

  Args:
    p: Layer params; `is_inference` and `random_seed` are read.
    unused_global_step: Ignored; kept for signature compatibility.
    op_seed: Optional extra offset added to both seed components.

  Returns:
    A length-2 tensor of int32 (on TPU) or int64 (otherwise) seeds.
  """
  seed_dtype = tf.int64
  if py_utils.use_tpu():
    seed_dtype = tf.int32

  if p.is_inference and p.random_seed is None:
    # Stateless random ops are fully determined by their seeds, so identical
    # inputs would otherwise produce identical outputs even for ops meant to
    # be random (e.g. dropout during inference). Graphs exported with
    # random_seed=None therefore get fresh random seeds as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    step_part = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    name_part = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([step_part, name_part])
    for offset in (p.random_seed, op_seed):
      if offset is not None:
        seeds += offset
    return seeds
def _ProcessMASSInput(self, source_id, src):
  """Perform MASS input processing.

  Tokenizes the single source string and masks it with `self.mass_layer`.
  The masked sequence becomes the source and the MASS targets become the
  target. Training only.

  Args:
    source_id: Identifier passed to `self._GetTaskIds` to obtain the source
      and target language ids.
    src: The monolingual source string.

  Returns:
    A `.NestedMap` with `src` and `tgt` sub-maps, every tensor squeezed. Off
    TPU both sub-maps also carry the raw string as `strs`.
  """
  # TODO(yuancao): By doing so we assume that right now for monolingual
  # eval/dev sets (xx->xx) are in double-column format (since it bypasses
  # the Mass op). Ideally we should add a dedicated eval/dev processing
  # procedure for unsupervised MT cases, so that single-column eval/devs sets
  # are also supported. This should not be handled by any specific ops like
  # Mass, but inside the TextPackedInput class.
  assert not self.do_eval, 'MASS input can only be used for training.'

  _, labels, paddings = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  weights = 1 - paddings
  actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
  src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

  mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

  # Source task ids are zero at padded positions.
  src_task_ids = tf.cast(weights, dtype=tf.int32) * src_lang_ids
  features = py_utils.NestedMap(
      src=py_utils.NestedMap(
          ids=mass_out.src.ids,
          paddings=paddings,
          weights=weights,
          task_ids=src_task_ids,
          ids_indicator=weights),
      tgt=py_utils.NestedMap(
          ids=mass_out.tgt.ids,
          labels=mass_out.tgt.labels,
          paddings=paddings,
          weights=mass_out.tgt.weights,
          task_ids=tf.ones_like(src_task_ids, dtype=tf.int32) * tgt_lang_ids,
          ids_indicator=weights))

  if not py_utils.use_tpu():
    features.src.strs = src
    features.tgt.strs = src
  return features.Transform(tf.squeeze)
def infeed_bucket_batch_limit(self):
  """Returns the bucket batch limit for one infeed host.

  Each entry of `p.bucket_batch_limit` is multiplied by the cluster's
  `num_splits_per_client`. With per-host infeed on a cluster that has TPU
  hosts, it is then floor-divided by the number of TPU hosts.

  Raises:
    ValueError: If per-host scaling is requested while not running on TPU.
  """
  p = self.params
  cluster = self.cluster
  splits_per_client = cluster.num_splits_per_client
  per_host_limits = [
      limit * splits_per_client for limit in p.bucket_batch_limit
  ]

  if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
    num_hosts = cluster.num_tpu_hosts
    if not py_utils.use_tpu():
      raise ValueError(
          'Scaling to TPU hosts without TPUs. {}'.format(num_hosts))
    tf.logging.info(
        'scaling infeed_bucket_batch_limit num_tpu_hosts={}'.format(num_hosts))
    per_host_limits = [limit // num_hosts for limit in per_host_limits]

  tf.logging.info(
      'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
      .format(per_host_limits, splits_per_client, p.bucket_batch_limit))
  return per_host_limits
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  When running off TPU with --transformer_encoder_truncates_inputs, every
  per-token input is truncated to the longest non-padded sequence in the
  batch. This covers ids, paddings, segment ids/positions and (with
  p.task_emb) task_ids.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # Only read when a task embedding is configured, as before.
    task_ids = input_batch.task_ids if p.task_emb else None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything past max_seq_length must be padding, otherwise truncation
      # would drop real tokens.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      # task_ids must be truncated too; otherwise the task embedding keeps the
      # full length and cannot be added to the truncated input_embs.
      if task_ids is not None:
        task_ids = task_ids[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
def _CommonParams(self, packed_input):
  """Returns the shared model params for the en->de MLPerf Transformer.

  Starts from `base_config.SetupTransformerBatchMajorParams`, swaps in simple
  embedding/softmax layers, and gives encoder and decoder a
  `SharedSoftmaxLayer` embedding. It then configures the learning-rate
  schedule, optimizer, beam search and, on TPU, bfloat16 compute.

  Args:
    packed_input: Forwarded as the `packed_input` argument of
      `SetupTransformerBatchMajorParams`.

  Returns:
    The configured params object `p`.
  """
  p = base_config.SetupTransformerBatchMajorParams(
      name='en_de_ml_perf_transformer',
      vocab_size=self.VOCAB_SIZE,
      model_dim=self.MODEL_DIM,
      hidden_dim=self.HIDDEN_DIM,
      num_heads=self.NUM_HEADS,
      num_layers=6,
      inference_source_language='en',
      inference_target_language='de',
      input_dropout_prob=0.1,
      residual_dropout_prob=0.1,
      atten_dropout_prob=0.1,
      relu_dropout_prob=0.1,
      add_unnormalized_residuals=True,
      learning_rate=self.LEARNING_RATE,
      warmup_steps=self.WARMUP_STEPS,
      packed_input=packed_input,
      use_fast_projection_layer=True,
      enable_per_dim_scale=False,
      use_fused_layernorm=True,
      # bfloat16 activations are requested only when running on TPU.
      use_bf16_activations=py_utils.use_tpu(),
      use_bias=False,
      xla_num_partitions=self.XLA_NUM_PARTITIONS)

  # Replace the default token embeddings / softmax with the simple variants.
  for pp in [p.encoder, p.decoder]:
    pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)
  p.decoder.softmax = model_helper.ChangeToSimpleSoftmax(p.decoder.softmax)

  # Build SharedSoftmaxLayer params from a copy of the decoder softmax params.
  sm_params = p.decoder.softmax.Copy()
  sm_params.input_dim = self.MODEL_DIM
  shared_emb = layers.SharedSoftmaxLayer.Params().Set(
      **dict(sm_params.IterParams()))
  # Gaussian initialization scaled by 1 / sqrt(MODEL_DIM).
  shared_emb.params_init = py_utils.WeightInit.Gaussian(
      1.0 / math.sqrt(self.MODEL_DIM))
  shared_emb.scale_sqrt_depth = True
  shared_emb.use_num_classes_major_weight = True
  p.decoder.shared_emb = shared_emb
  p.decoder.shared_emb.cls = layers.SharedSoftmaxLayer
  # Directly sharing encoder embedding with decoder results in worse model
  # quality, which requires more tuning.
  # NOTE(review): the same `shared_emb` params object is assigned to both
  # encoder and decoder below. Whether the Params setter copies it is not
  # visible here; if it does not, later edits to one side affect the other.
  p.encoder.shared_emb = shared_emb
  p.encoder.shared_emb.cls = layers.SharedSoftmaxLayer

  p.train.lr_schedule = schedule.TransformerMLPerfSchedule.Params().Set(
      warmup_steps=self.WARMUP_STEPS, model_dim=self.MODEL_DIM)
  p.train.max_steps = self.MAX_STEPS
  p.train.scale_gradients = False

  p.train.optimizer.beta1 = self.ADAM_BETA1
  p.train.optimizer.beta2 = self.ADAM_BETA2

  # TODO: Fix this. Currently disabled:
  #p.eval.ml_perf_metrics_only = True

  p.decoder.beam_search.target_sos_id = self.ID_SOS
  p.decoder.beam_search.target_eos_id = self.ID_EOS

  p.decoder.beam_search.beam_size = 4.0
  p.decoder.beam_search.num_hyps_per_beam = 4

  p.decoder.target_sos_id = self.ID_SOS
  p.decoder.target_eos_id = self.ID_EOS

  p.decoder.use_fast_softmax = True

  p.decoder.target_seq_len = 147

  # On TPU, run the listed sub-layers in bfloat16 and enable the optimizer's
  # use_bf16_gradients_ar option.
  if py_utils.use_tpu():
    p.encoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
    p.decoder.trans_decoder_tpl.fprop_dtype = tf.bfloat16
    p.decoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
    p.train.optimizer.use_bf16_gradients_ar = True
  return p