Example #1
    def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
        """Computes mean and variance over the valid data points in inputs."""
        inputs = py_utils.with_dependencies([
            py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
            py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
        ], inputs)
        rank = tf.rank(mask)
        reduce_over_dims = tf.range(0, rank - 1)
        sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                              reduce_over_dims)
        count_v = tf.reduce_sum(mask, reduce_over_dims)
        # Input shape is guaranteed to be a multiple of mask shape because the
        # inputs * mask op above was successfully broadcasted.
        mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
        count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_v = tf.tpu.cross_replica_sum(sum_v)
            count_v = tf.tpu.cross_replica_sum(count_v)

        count_v = tf.maximum(count_v, 1.0)
        mean = sum_v / count_v
        sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                               reduce_over_dims)

        if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
            sum_vv = tf.tpu.cross_replica_sum(sum_vv)

        variance = py_utils.with_dependencies([
            py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
        ], sum_vv / count_v)
        return mean, variance
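
A minimal sketch (plain TensorFlow, not part of the Lingvo source) of the masked-moments idea above: only rows where mask == 1 contribute to the per-feature mean and variance.

import tensorflow as tf

inputs = tf.constant([[1., 2., 3.],
                      [5., 6., 7.],
                      [9., 9., 9.]])        # [batch=3, dim=3]; last row is padding
mask = tf.constant([[1.], [1.], [0.]])      # same rank as inputs, broadcastable

sum_v = tf.reduce_sum(inputs * mask, axis=0)             # [6., 8., 10.]
count_v = tf.maximum(tf.reduce_sum(mask, axis=0), 1.0)   # [2.]
mean = sum_v / count_v                                   # [3., 4., 5.]
variance = tf.reduce_sum((inputs - mean) ** 2 * mask, axis=0) / count_v  # [4., 4., 4.]
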
Example #2
 def __init__(self, params):
     super(SeqLayer, self).__init__(params)
     p = self.params
     assert p.name
     num_cells = len(p.cell_tpl)
     self._before_layers = []
     self._cells = []
     before_tpl_device = ''
     cell_devices = [''] * num_cells
     if py_utils.use_tpu():
         cluster = self.cluster
         before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
         cell_devices = [
             cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
         ]
     for l in p.before_tpl:
         with tf.device(before_tpl_device):
             assert l.name
             self.CreateChild(l.name, l)
             self._before_layers.append((l.name, self.children[l.name]))
     for i, l in enumerate(p.cell_tpl):
         with tf.device(cell_devices[i]):
             assert l.name
             self.CreateChild(l.name, l)
             self._cells.append((l.name, self.children[l.name]))
Example #3
    def _ProcessSingleInput(self, source_id, src, tgt):
        """Performs strings-to-ids on the given input pair via p.tokenizer_dict."""
        _, src_labels, src_paddings = self.StringsToIds(
            tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
        tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
            tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)
        # Mask positions to 0 where padding is 1 for consistency. We do this
        # because the tokenizer implementation may use the EOS token to pad.
        src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
        tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
        tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

        features = py_utils.NestedMap()
        features.src = py_utils.NestedMap()
        features.src.ids = src_labels
        # ids_indicator is 1 if and only if the output from tokenizer has a
        # non-padded id. Unlike weights, it will not mutate and can be used for
        # determining actual sequence length, for example.
        features.src.ids_indicator = 1 - src_paddings
        features.tgt = py_utils.NestedMap()
        features.tgt.ids = tgt_ids
        features.tgt.labels = tgt_labels
        features.tgt.ids_indicator = 1 - tgt_paddings

        src_task_id, tgt_task_id = self._GetTaskIds(source_id)
        # task_ids are padded with zeros.
        features.src.task_ids = tf.cast(features.src.ids_indicator,
                                        dtype=tf.int32) * src_task_id
        features.tgt.task_ids = tf.cast(features.tgt.ids_indicator,
                                        dtype=tf.int32) * tgt_task_id

        if not py_utils.use_tpu():
            features.src.strs = src
            features.tgt.strs = tgt
        return features.Transform(tf.squeeze)
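
A hedged sketch of what py_utils.ApplyPadding is assumed to do in this example: clear token ids at padded positions, so a tokenizer that pads with the EOS id does not leave spurious tokens behind. The apply_padding helper below is a hypothetical stand-in, not the Lingvo implementation.

import tensorflow as tf

def apply_padding(paddings, ids):
  """Zeros out ids at positions where paddings == 1 (assumed semantics)."""
  return ids * tf.cast(1.0 - paddings, ids.dtype)

ids = tf.constant([[11, 12, 2, 2]])            # 2 == EOS, also used to pad
paddings = tf.constant([[0., 0., 0., 1.]])
masked = apply_padding(paddings, ids)          # [[11, 12, 2, 0]]
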
Example #4
    def _LoopEnqueue(self, op, session_override=None):
        """Runs the enqueue op in a loop."""
        p = self.params
        sess = session_override or self._GetSession()

        with tf.container(self._container_id), sess:
            if self._initialize_tables is not None:
                sess.run(self._initialize_tables)
            gsteps = py_utils.GetGlobalStep()
            local_enqueue_steps = 0

            # Global enqueue steps measures how many global steps' worth of data
            # have already been enqueued. We use this to terminate; note that the
            # enqueue op may hang in session.run if we do not terminate with this
            # check.
            global_enqueue_steps = None

            tf.logging.info(
                'params.train.max_steps: %d, enqueue_max_steps: %d',
                p.train.max_steps, p.train.enqueue_max_steps)
            while True:
                if self._dequeue_thread_complete:
                    tf.logging.info(
                        'LoopEnqueue done since consuming thread is done.')
                    return

                global_step = sess.run(gsteps)
                if global_enqueue_steps is None:
                    global_enqueue_steps = global_step
                if local_enqueue_steps % 1000 == 0:
                    tf.logging.info(
                        'Current global_enqueue_steps: %d, '
                        'local_enqueue_steps: %d, global_step: %d',
                        global_enqueue_steps, local_enqueue_steps, global_step)

                if py_utils.use_tpu():
                    global_steps_with_available_data = int(
                        global_enqueue_steps // p.train.tpu_steps_per_loop *
                        p.train.tpu_steps_per_loop)
                else:
                    global_steps_with_available_data = global_enqueue_steps

                if (self._ShouldStop(sess, global_steps_with_available_data)
                        or self._ShouldStop(sess, global_step)):
                    tf.logging.info('Done. ShouldStop is True.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                if (p.train.enqueue_max_steps > 0
                        and local_enqueue_steps >= p.train.enqueue_max_steps):
                    tf.logging.info('Done. train.enqueue_max_steps reached.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                local_enqueue_steps += 1

                # There are tpu_infeed_parallelism parallel threads enqueuing.
                # We account for all of them when updating global_enqueue_steps.
                global_enqueue_steps += p.input.tpu_infeed_parallelism

                sess.run([op])
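
A worked example of the flooring above, with assumed values: on TPU, data only counts as available for complete infeed loops, so 437 enqueued steps cover global steps up to 400 when tpu_steps_per_loop is 100.

# Assumed values, for illustration only.
tpu_steps_per_loop = 100
global_enqueue_steps = 437
global_steps_with_available_data = (
    global_enqueue_steps // tpu_steps_per_loop * tpu_steps_per_loop)  # 400
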
    def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Loads data and labels into memory and keeps them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
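
A minimal sketch (plain TensorFlow, assumed values) of why the weight field is built this way: for a final short batch, PadOrTrimTo is expected to zero-pad up to the infeed batch size, so real samples get weight 1.0 and padded slots get 0.0.

import tensorflow as tf

b = 4                                        # InfeedBatchSize()
n = 3                                        # samples actually drawn this step
weight = tf.pad(tf.ones([n]), [[0, b - n]])  # [1., 1., 1., 0.]
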
    def SplitInputBatch(self, num_splits):
        """Splits the current InputBatch into num_splits ways.

    Args:
      num_splits: The number of splits.

    Returns:
      A list of `.NestedMap`. Each `.NestedMap` represents the input
      tensors in one split.
    """
        assert num_splits >= 1

        batch = self.GetPreprocessedInputBatch()
        if num_splits == 1:
            # Special case. No split is needed.
            return [batch]

        assert not py_utils.use_tpu()
        field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
        num_fields = len(field_split)
        ret = []
        for j in range(num_splits):
            split_flatten = [field_split[i][j] for i in range(num_fields)]
            split = batch.Pack(split_flatten)
            ret += [split]
        return ret
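
A minimal sketch of the assumed flatten/split/pack round-trip, using a plain dict in place of a NestedMap: every tensor in the batch is split along its leading (batch) dimension and the j-th pieces are reassembled into the original structure.

import tensorflow as tf

batch = {'ids': tf.zeros([8, 10]), 'paddings': tf.zeros([8, 10])}
num_splits = 2
field_split = [tf.split(t, num_splits, axis=0) for t in batch.values()]
splits = [dict(zip(batch.keys(), [fs[j] for fs in field_split]))
          for j in range(num_splits)]
# splits[0]['ids'].shape == (4, 10)
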
 def GlobalBatchSize(self):
     """Returns the total batch size (for stats), int or dynamic int tensor."""
     p = self.params
     global_batch_size = self.InfeedBatchSize()
     cluster = self.cluster
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
         if not py_utils.use_tpu():
             raise ValueError(
                 'Scaling to TPU hosts without TPUs. {}'.format(
                     cluster.num_tpu_hosts))
         global_batch_size *= cluster.num_tpu_hosts
     tf.logging.info('GlobalBatchSize {}'.format(global_batch_size))
     return global_batch_size
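
A worked example of the scaling above, with assumed values: with per-host infeed on 4 TPU hosts, a per-host infeed batch of 32 corresponds to a global batch of 128.

# Assumed values, for illustration only.
infeed_batch_size = 32        # per-host infeed batch
num_tpu_hosts = 4
global_batch_size = infeed_batch_size * num_tpu_hosts   # 128
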
Example #8
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
    """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
    seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
    if p.is_inference and p.random_seed is None:
        # Unlike tf.random*, stateless random ops are completely determined by the
        # passed-in seeds. This means at inference time the same inputs will produce
        # the same outputs, even if the model is supposed to have randomness such as
        # dropout during inference. We inject additional randomness only during
        # inference if the graph is exported with random_seed=None as a workaround.
        return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

    with tf.name_scope('op_seed') as scope:
        global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
        step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
        seeds = tf.stack([global_step, step_seed])

        if p.random_seed is not None:
            seeds += p.random_seed
        if op_seed is not None:
            seeds += op_seed
        return seeds
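
An illustration of how a seed pair like the one returned above is typically consumed: TensorFlow's stateless random ops take a 2-element seed, so the same [global_step, step_seed] pair always reproduces the same values. The numbers here are made up.

import tensorflow as tf

seeds = tf.constant([1000, 42], dtype=tf.int32)        # e.g. [global_step, step_seed]
noise = tf.random.stateless_uniform([3], seed=seeds)   # same seeds -> same values
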
Example #9
    def _ProcessMASSInput(self, source_id, src):
        """Perform MASS input processing."""
        # TODO(yuancao): By doing so we assume that, for now, monolingual
        # eval/dev sets (xx->xx) are in double-column format (since this bypasses
        # the Mass op). Ideally we should add a dedicated eval/dev processing
        # procedure for unsupervised MT cases, so that single-column eval/dev sets
        # are also supported. This should not be handled by any specific op like
        # Mass, but inside the TextPackedInput class.
        assert not self.do_eval, 'MASS input can only be used for training.'

        _, labels, paddings = self.StringsToIds(tf.reshape(src, [1]),
                                                is_source=True,
                                                key=self._src_tokenizer_key)
        weights = 1 - paddings
        actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
        src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

        mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

        features = py_utils.NestedMap()
        features.src = py_utils.NestedMap()
        features.src.ids = mass_out.src.ids
        features.src.paddings = paddings
        features.src.weights = weights
        features.src.task_ids = tf.cast(features.src.weights,
                                        dtype=tf.int32) * src_lang_ids
        features.src.ids_indicator = weights
        features.tgt = py_utils.NestedMap()
        features.tgt.ids = mass_out.tgt.ids
        features.tgt.labels = mass_out.tgt.labels
        features.tgt.paddings = paddings
        features.tgt.weights = mass_out.tgt.weights
        features.tgt.task_ids = tf.ones_like(features.src.task_ids,
                                             dtype=tf.int32) * tgt_lang_ids
        features.tgt.ids_indicator = weights

        if not py_utils.use_tpu():
            features.src.strs = src
            features.tgt.strs = src
        return features.Transform(tf.squeeze)
Example #10
 def _Moments(self, inputs, group_size):
     """Computes mean and variance over N,H,W dimensions in inputs."""
     counts, mean_ss, variance_ss, _, = tf.nn.sufficient_statistics(
         inputs, axes=[0, 1, 2], keepdims=False)
     self.accumulators.counts.Update(counts)
     self.accumulators.mean_ss.Update(mean_ss)
     self.accumulators.variance_ss.Update(variance_ss)
     # Distributed batch norm that computes sufficient statistics from group_size
     # replicas. This is useful when batch_size_per_replica is too small to
     # compute reliable sufficient statistics.
     if py_utils.use_tpu() and group_size > 1:
         group_assignment = None
         num_shards = tpu_function.get_tpu_context().number_of_shards
         if num_shards is not None:
             if num_shards < group_size:
                 raise ValueError(
                      'TPU shards={} less than bn_group_size={}.'.format(
                         num_shards, group_size))
             if num_shards % group_size:
                 raise ValueError(
                     'TPU shards={} not divisible by bn_group_size={}.'.
                     format(num_shards, group_size))
             num_groups = num_shards // group_size
             group_assignment = []
             for g in range(num_groups):
                 replica_ids = [
                     g * group_size + i for i in range(group_size)
                 ]
                 group_assignment.append(replica_ids)
             counts *= group_size
         mean_ss = tf.tpu.cross_replica_sum(mean_ss, group_assignment)
         variance_ss = tf.tpu.cross_replica_sum(variance_ss,
                                                group_assignment)
     # At each micro-step, batch_mean and batch_variance are computed
     # to normalize inputs. But they are not used to update moving_mean and
     # moving_variance variables until the last micro batch.
     mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss,
                                              None)
     return mean, variance
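
A worked example of the group assignment built above, with assumed values: 8 TPU shards and bn_group_size = 2 give 4 groups, and the sufficient statistics are all-reduced within each pair of replicas.

# Assumed values, for illustration only.
num_shards, group_size = 8, 2
group_assignment = [[g * group_size + i for i in range(group_size)]
                    for g in range(num_shards // group_size)]
# [[0, 1], [2, 3], [4, 5], [6, 7]]
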
 def infeed_bucket_batch_limit(self):
     """Returns the bucket batch limit for one infeed host."""
     p = self.params
     cluster = self.cluster
     infeed_bucket_batch_limit = [
         b * cluster.num_splits_per_client for b in p.bucket_batch_limit
     ]
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
         if not py_utils.use_tpu():
             raise ValueError(
                 'Scaling to TPU hosts without TPUs. {}'.format(
                     cluster.num_tpu_hosts))
         tf.logging.info(
             'scaling infeed_bucket_batch_limit num_tpu_hosts={}'.format(
                 cluster.num_tpu_hosts))
         infeed_bucket_batch_limit = [
             x // cluster.num_tpu_hosts for x in infeed_bucket_batch_limit
         ]
     tf.logging.info(
         'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
         .format(infeed_bucket_batch_limit, cluster.num_splits_per_client,
                 p.bucket_batch_limit))
     return infeed_bucket_batch_limit
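
A worked example of the two scalings above, with assumed values: bucket_batch_limit [64, 32] with 2 splits per client becomes [128, 64], and with per-host infeed on 4 TPU hosts each host infeeds [32, 16].

# Assumed values, for illustration only.
bucket_batch_limit = [64, 32]
num_splits_per_client = 2
num_tpu_hosts = 4
limit = [b * num_splits_per_client for b in bucket_batch_limit]  # [128, 64]
limit = [x // num_tpu_hosts for x in limit]                      # [32, 16]
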
Example #12
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].
        - task_ids: If p.task_emb is provided, must contain per-token task
            ids of shape [batch, time].

    Returns:
      A NestedMap containing

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
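
A minimal sketch (plain TensorFlow, not the Lingvo layers used above) of the shape flow the docstring describes: ids [batch, time] are embedded, positional embeddings are added, and the result is transposed to the [time, batch, dim] layout the transformer stack consumes.

import tensorflow as tf

batch, time, dim, vocab = 2, 5, 8, 100
ids = tf.zeros([batch, time], dtype=tf.int32)
emb_table = tf.random.normal([vocab, dim])
input_embs = tf.nn.embedding_lookup(emb_table, ids)   # [batch, time, dim]
position_embs = tf.random.normal([1, time, dim])      # broadcast over batch
transformer_input = tf.transpose(input_embs + position_embs, [1, 0, 2])  # [time, batch, dim]
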
Example #13
    def _CommonParams(self, packed_input):
        p = base_config.SetupTransformerBatchMajorParams(
            name='en_de_ml_perf_transformer',
            vocab_size=self.VOCAB_SIZE,
            model_dim=self.MODEL_DIM,
            hidden_dim=self.HIDDEN_DIM,
            num_heads=self.NUM_HEADS,
            num_layers=6,
            inference_source_language='en',
            inference_target_language='de',
            input_dropout_prob=0.1,
            residual_dropout_prob=0.1,
            atten_dropout_prob=0.1,
            relu_dropout_prob=0.1,
            add_unnormalized_residuals=True,
            learning_rate=self.LEARNING_RATE,
            warmup_steps=self.WARMUP_STEPS,
            packed_input=packed_input,
            use_fast_projection_layer=True,
            enable_per_dim_scale=False,
            use_fused_layernorm=True,
            use_bf16_activations=py_utils.use_tpu(),
            use_bias=False,
            xla_num_partitions=self.XLA_NUM_PARTITIONS)

        for pp in [p.encoder, p.decoder]:
            pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)

        p.decoder.softmax = model_helper.ChangeToSimpleSoftmax(
            p.decoder.softmax)

        sm_params = p.decoder.softmax.Copy()
        sm_params.input_dim = self.MODEL_DIM
        shared_emb = layers.SharedSoftmaxLayer.Params().Set(
            **dict(sm_params.IterParams()))
        shared_emb.params_init = py_utils.WeightInit.Gaussian(
            1.0 / math.sqrt(self.MODEL_DIM))
        shared_emb.scale_sqrt_depth = True
        shared_emb.use_num_classes_major_weight = True
        p.decoder.shared_emb = shared_emb
        p.decoder.shared_emb.cls = layers.SharedSoftmaxLayer
        # Directly sharing encoder embedding with decoder results in worse model
        # quality, which requires more tuning.
        p.encoder.shared_emb = shared_emb
        p.encoder.shared_emb.cls = layers.SharedSoftmaxLayer

        p.train.lr_schedule = schedule.TransformerMLPerfSchedule.Params().Set(
            warmup_steps=self.WARMUP_STEPS, model_dim=self.MODEL_DIM)
        p.train.max_steps = self.MAX_STEPS
        p.train.scale_gradients = False

        p.train.optimizer.beta1 = self.ADAM_BETA1
        p.train.optimizer.beta2 = self.ADAM_BETA2

        # Fix this
        #p.eval.ml_perf_metrics_only = True

        p.decoder.beam_search.target_sos_id = self.ID_SOS
        p.decoder.beam_search.target_eos_id = self.ID_EOS

        p.decoder.beam_search.beam_size = 4.0
        p.decoder.beam_search.num_hyps_per_beam = 4

        p.decoder.target_sos_id = self.ID_SOS
        p.decoder.target_eos_id = self.ID_EOS

        p.decoder.use_fast_softmax = True

        p.decoder.target_seq_len = 147

        if py_utils.use_tpu():
            p.encoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
            p.decoder.trans_decoder_tpl.fprop_dtype = tf.bfloat16
            p.decoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
            p.train.optimizer.use_bf16_gradients_ar = True
        return p