Code example #1
  def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
    """Computes mean and variance over the valid data points in inputs."""
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    rank = tf.rank(mask)
    reduce_over_dims = tf.range(0, rank - 1)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims)
    count_v = tf.reduce_sum(mask, reduce_over_dims)
    # The input shape is guaranteed to be a multiple of the mask shape because
    # the inputs * mask op above broadcast successfully.
    mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
    count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_v = tf.tpu.cross_replica_sum(sum_v)
      count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
      sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
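Note: the pattern above computes a masked mean and variance by dividing masked sums by the number of valid elements rather than the full element count. A minimal sketch of the same math in plain TensorFlow 2.x (not Lingvo; it assumes inputs and mask share the same shape, so the mask_multiplier correction above is not needed):

import tensorflow as tf

def masked_moments(inputs, mask):
  # Reduce over every dimension except the last (feature) dimension,
  # mirroring reduce_over_dims = tf.range(0, rank - 1) above.
  reduce_dims = tf.range(0, tf.rank(mask) - 1)
  mask = tf.cast(mask, inputs.dtype)
  count = tf.maximum(tf.reduce_sum(mask, reduce_dims), 1.0)
  mean = tf.reduce_sum(inputs * mask, reduce_dims) / count
  variance = tf.reduce_sum(tf.square(inputs - mean) * mask, reduce_dims) / count
  return mean, variance

x = tf.constant([[1.0], [2.0], [3.0], [100.0]])  # [4, 1]
m = tf.constant([[1.0], [1.0], [1.0], [0.0]])    # last row is padding
mean, variance = masked_moments(x, m)            # mean -> [2.0], variance -> [0.6667]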
Code example #2
    def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Load the data and labels into memory once and keep them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
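The gathered data and labels above are padded or trimmed to the static infeed batch size b so the infeed shape stays fixed. A simplified, hypothetical stand-in for py_utils.PadOrTrimTo in plain TensorFlow (static shapes assumed; the real helper also handles dynamic shapes):

import tensorflow as tf

def pad_or_trim_to(x, target_shape):
  # Trim each dimension to the target size, then zero-pad up to it.
  x = x[tuple(slice(0, d) for d in target_shape)]
  pad = [[0, d - s] for d, s in zip(target_shape, x.shape.as_list())]
  return tf.pad(x, pad)

x = tf.reshape(tf.range(6), [3, 2])
print(pad_or_trim_to(x, [4, 2]))  # one row of zeros appended
print(pad_or_trim_to(x, [2, 2]))  # trimmed to the first two rows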
Code example #3
    def SplitInputBatch(self, num_splits):
        """Splits the current InputBatch into num_splits ways.

        Args:
          num_splits: The number of splits.

        Returns:
          A list of `.NestedMap`. Each `.NestedMap` represents the input
          tensors in one split.
        """
        assert num_splits >= 1

        batch = self.GetPreprocessedInputBatch()
        if num_splits == 1:
            # Special case. No split is needed.
            return [batch]

        assert not py_utils.use_tpu()
        field_split = ig_helper.SplitTensors(batch.Flatten(), num_splits)
        num_fields = len(field_split)
        ret = []
        for j in range(num_splits):
            split_flatten = [field_split[i][j] for i in range(num_fields)]
            split = batch.Pack(split_flatten)
            ret += [split]
        return ret
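The loop above is effectively a transpose: field_split[i][j] is the j-th piece of the i-th field, and each split re-packs one piece of every field. A minimal sketch of the same regrouping using plain tf.split (assuming the leading dimension divides evenly; Lingvo's SplitTensors also handles uneven batches):

import tensorflow as tf

def split_flat_batch(flat_tensors, num_splits):
  # Split every field along the batch dimension, then regroup by split index.
  field_split = [tf.split(t, num_splits, axis=0) for t in flat_tensors]
  return [[field[j] for field in field_split] for j in range(num_splits)]

flat = [tf.zeros([8, 3]), tf.ones([8])]
splits = split_flat_batch(flat, 2)  # two splits, each with leading dim 4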
Code example #4
    def _LoopEnqueue(self, op, session_override=None):
        """Runs the enqueue op in a loop."""
        p = self.params
        sess = session_override or self._GetSession()

        with tf.container(self._container_id), sess:
            if self._initialize_tables is not None:
                sess.run(self._initialize_tables)
            gsteps = py_utils.GetGlobalStep()
            local_enqueue_steps = 0

            # global_enqueue_steps measures how many global steps already have
            # data enqueued for them. We use this to decide when to terminate;
            # note that the enqueue op may hang in session.run if we do not
            # terminate with this check.
            global_enqueue_steps = None

            tf.logging.info(
                'params.train.max_steps: %d, enqueue_max_steps: %d',
                p.train.max_steps, p.train.enqueue_max_steps)
            while True:
                if self._dequeue_thread_complete:
                    tf.logging.info(
                        'LoopEnqueue done since consuming thread is done.')
                    return

                global_step = sess.run(gsteps)
                if global_enqueue_steps is None:
                    global_enqueue_steps = global_step
                if local_enqueue_steps % 1000 == 0:
                    tf.logging.info(
                        'Current global_enqueue_steps: %d, '
                        'local_enqueue_steps: %d, global_step: %d',
                        global_enqueue_steps, local_enqueue_steps, global_step)

                if py_utils.use_tpu():
                    global_steps_with_available_data = int(
                        global_enqueue_steps // p.train.tpu_steps_per_loop *
                        p.train.tpu_steps_per_loop)
                else:
                    global_steps_with_available_data = global_enqueue_steps

                if (self._ShouldStop(sess, global_steps_with_available_data)
                        or self._ShouldStop(sess, global_step)):
                    tf.logging.info('Done. ShouldStop is True.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                if (p.train.enqueue_max_steps > 0
                        and local_enqueue_steps >= p.train.enqueue_max_steps):
                    tf.logging.info('Done. train.enqueue_max_steps reached.')
                    tf.logging.info('Enqueue loop sleeping')
                    time.sleep(15)
                    continue
                local_enqueue_steps += 1

                # There are tpu_infeed_parallelism parallel threads enqueuing.
                # We account for all of them when updating global_enqueue_steps.
                global_enqueue_steps += p.input.tpu_infeed_parallelism

                sess.run([op])
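On TPU, data becomes visible to the training loop only in whole tpu_steps_per_loop chunks, so the expression above rounds global_enqueue_steps down to the nearest loop boundary. A small worked example of that arithmetic:

tpu_steps_per_loop = 100
for global_enqueue_steps in (99, 100, 250):
  available = global_enqueue_steps // tpu_steps_per_loop * tpu_steps_per_loop
  print(global_enqueue_steps, available)  # 99 -> 0, 100 -> 100, 250 -> 200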
Code example #5
  def _ProcessSingleInput(self, source_id, src, tgt):
    """Performs strings-to-ids on the given input pair via p.tokenizer_dict."""
    _, src_labels, src_paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(
        tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)
    # Mask positions to 0 where the padding is 1, for consistency. We do this
    # because the tokenizer implementation may use the EOS token to pad.
    src_labels = py_utils.ApplyPadding(src_paddings, src_labels)
    tgt_ids = py_utils.ApplyPadding(tgt_paddings, tgt_ids)
    tgt_labels = py_utils.ApplyPadding(tgt_paddings, tgt_labels)

    features = py_utils.NestedMap()
    features.src = py_utils.NestedMap()
    features.src.ids = src_labels
    # ids_indicator is 1 if and only if the output from tokenizer has a
    # non-padded id. Unlike weights, it will not mutate and can be used for
    # determining actual sequence length, for example.
    features.src.ids_indicator = 1 - src_paddings
    features.tgt = py_utils.NestedMap()
    features.tgt.ids = tgt_ids
    features.tgt.labels = tgt_labels
    features.tgt.ids_indicator = 1 - tgt_paddings

    src_task_id, tgt_task_id = self._GetTaskIds(source_id)
    # task_ids are padded with zeros.
    features.src.task_ids = tf.cast(
        features.src.ids_indicator, dtype=tf.int32) * src_task_id
    features.tgt.task_ids = tf.cast(
        features.tgt.ids_indicator, dtype=tf.int32) * tgt_task_id

    if not py_utils.use_tpu():
      features.src.strs = src
      features.tgt.strs = tgt
    return features.Transform(tf.squeeze)
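ApplyPadding above zeroes out id/label positions where the padding is 1; this matters because some tokenizers pad with the EOS id rather than 0. A minimal stand-in with the same contract (assuming integer ids and float paddings; the real py_utils.ApplyPadding may be implemented differently):

import tensorflow as tf

def apply_padding(paddings, ids):
  # Keep ids where padding == 0, zero them where padding == 1.
  return ids * tf.cast(1.0 - paddings, ids.dtype)

ids = tf.constant([[5, 7, 2, 2]])              # 2 == EOS, also used to pad
paddings = tf.constant([[0.0, 0.0, 0.0, 1.0]])
print(apply_padding(paddings, ids))            # [[5, 7, 2, 0]]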
Code example #6
  def __init__(self, params):
    super(SeqLayer, self).__init__(params)
    p = self.params
    assert p.name
    num_cells = len(p.cell_tpl)
    self._before_layers = []
    self._cells = []
    before_tpl_device = ''
    cell_devices = [''] * num_cells
    if py_utils.use_tpu():
      cluster = self.cluster
      before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
      cell_devices = [
          cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
      ]
    for l in p.before_tpl:
      with tf.device(before_tpl_device):
        assert l.name
        self.CreateChild(l.name, l)
        self._before_layers.append((l.name, self.children[l.name]))
    for i, l in enumerate(p.cell_tpl):
      with tf.device(cell_devices[i]):
        assert l.name
        self.CreateChild(l.name, l)
        self._cells.append((l.name, self.children[l.name]))
Code example #7
  def _Moments(self, inputs, group_size):
    """Computes mean and variance over N,H,W dimensions in inputs."""
    counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
        inputs, axes=[0, 1, 2], keepdims=False)
    self.accumulators.counts.Update(counts)
    self.accumulators.mean_ss.Update(mean_ss)
    self.accumulators.variance_ss.Update(variance_ss)
    # Distributed batch norm that computes sufficient statistics from group_size
    # replicas. This is useful when batch_size_per_replica is too small to
    # compute reliable sufficient statistics.
    if py_utils.use_tpu() and group_size > 1:
      group_assignment = None
      num_shards = tpu_function.get_tpu_context().number_of_shards
      if num_shards is not None:
        if num_shards < group_size:
          raise ValueError('TPU shards={} less than bn_group_size={}.'.format(
              num_shards, group_size))
        if num_shards % group_size:
          raise ValueError(
              'TPU shards={} not divisible by bn_group_size={}.'.format(
                  num_shards, group_size))
        num_groups = num_shards // group_size
        group_assignment = []
        for g in range(num_groups):
          replica_ids = [g * group_size + i for i in range(group_size)]
          group_assignment.append(replica_ids)
        counts *= group_size
      mean_ss = tf.tpu.cross_replica_sum(mean_ss, group_assignment)
      variance_ss = tf.tpu.cross_replica_sum(variance_ss, group_assignment)
    # At each micro-step, batch_mean and batch_variance are computed
    # to normalize inputs. But they are not used to update moving_mean and
    # moving_variance variables until the last micro batch.
    mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
    return mean, variance
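group_assignment is a list of replica-id groups for the cross-replica sum, and counts is multiplied by group_size because each replica's sufficient statistics get summed over its group. For example, 8 shards with group_size 4 yield two groups; a tiny reproduction of just that list construction:

def make_group_assignment(num_shards, group_size):
  assert num_shards % group_size == 0
  num_groups = num_shards // group_size
  return [[g * group_size + i for i in range(group_size)]
          for g in range(num_groups)]

print(make_group_assignment(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]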
Code example #8
 def GlobalBatchSize(self):
     """Returns the total batch size (for stats), int or dynamic int tensor."""
     p = self.params
     global_batch_size = self.InfeedBatchSize()
     cluster = self.cluster
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
         if not py_utils.use_tpu():
             raise ValueError(
                 'Scaling to TPU hosts without TPUs. {}'.format(
                     cluster.num_tpu_hosts))
         global_batch_size *= cluster.num_tpu_hosts
     tf.logging.info('GlobalBatchSize {}'.format(global_batch_size))
     return global_batch_size
Code example #9
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by the
    # passed-in seeds. This means at inference time the same inputs will produce
    # the same outputs, even if the model is supposed to have randomness such as
    # dropout during inference. We inject additional randomness only during
    # inference if the graph is exported with random_seed=None as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
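The [2]-element seed pair produced here is intended for TF's stateless random ops, which are pure functions of their seed; that is why the global step is folded in, so the randomness varies from step to step. A minimal illustration in plain TensorFlow (the constant seed simply stands in for GenerateStepSeedPair output):

import tensorflow as tf

seeds = tf.constant([1234, 5678])
a = tf.random.stateless_uniform([3], seed=seeds)
b = tf.random.stateless_uniform([3], seed=seeds)
# a == b element-wise: identical seeds give identical samples, which is why
# the step seed must change every step to obtain fresh randomness.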
Code example #10
  def _ProcessMASSInput(self, source_id, src):
    """Perform MASS input processing."""
    # TODO(yuancao): By doing so we assume that, for now, monolingual eval/dev
    # sets (xx->xx) are in double-column format (since they bypass the Mass
    # op). Ideally we should add a dedicated eval/dev processing procedure for
    # unsupervised MT cases, so that single-column eval/dev sets are also
    # supported. This should not be handled by any specific op like Mass, but
    # inside the TextPackedInput class.
    assert not self.do_eval, 'MASS input can only be used for training.'

    _, labels, paddings = self.StringsToIds(
        tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
    weights = 1 - paddings
    actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
    src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

    mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

    features = py_utils.NestedMap()
    features.src = py_utils.NestedMap()
    features.src.ids = mass_out.src.ids
    features.src.paddings = paddings
    features.src.weights = weights
    features.src.task_ids = tf.cast(
        features.src.weights, dtype=tf.int32) * src_lang_ids
    features.src.ids_indicator = weights
    features.tgt = py_utils.NestedMap()
    features.tgt.ids = mass_out.tgt.ids
    features.tgt.labels = mass_out.tgt.labels
    features.tgt.paddings = paddings
    features.tgt.weights = mass_out.tgt.weights
    features.tgt.task_ids = tf.ones_like(
        features.src.task_ids, dtype=tf.int32) * tgt_lang_ids
    features.tgt.ids_indicator = weights

    if not py_utils.use_tpu():
      features.src.strs = src
      features.tgt.strs = src
    return features.Transform(tf.squeeze)
Code example #11
 def infeed_bucket_batch_limit(self):
     """Returns the bucket batch limit for one infeed host."""
     p = self.params
     cluster = self.cluster
     infeed_bucket_batch_limit = [
         b * cluster.num_splits_per_client for b in p.bucket_batch_limit
     ]
     if p.use_per_host_infeed and cluster.num_tpu_hosts > 0:
         if not py_utils.use_tpu():
             raise ValueError(
                 'Scaling to TPU hosts without TPUs. {}'.format(
                     cluster.num_tpu_hosts))
         tf.logging.info(
             'scaling infeed_bucket_batch_limit num_tpu_hosts={}'.format(
                 cluster.num_tpu_hosts))
         infeed_bucket_batch_limit = [
             x // cluster.num_tpu_hosts for x in infeed_bucket_batch_limit
         ]
     tf.logging.info(
         'infeed_bucket_batch_limit={} num_splits_per_client={} bucket_batch_limit={}'
         .format(infeed_bucket_batch_limit, cluster.num_splits_per_client,
                 p.bucket_batch_limit))
     return infeed_bucket_batch_limit
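In other words, each per-bucket limit is first scaled up by num_splits_per_client and then, with per-host infeed enabled, divided evenly across TPU hosts. A small worked example with hypothetical values:

bucket_batch_limit = [16, 32]
num_splits_per_client = 4
num_tpu_hosts = 2

limits = [b * num_splits_per_client for b in bucket_batch_limit]  # [64, 128]
limits = [x // num_tpu_hosts for x in limits]                     # [32, 64] per host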
Code example #12
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers.
          input_batch: A `.NestedMap` with fields:

            - ids: The inputs tensor. It is expected to be of shape [batch, time].
            - paddings: The paddings tensor. Expected shape [batch, time].
            - task_ids: If p.task_emb is provided, must contain per-token task
                ids of shape [batch, time].

        Returns:
          A NestedMap containing

          - encoded: The encoded features, either a tensor of shape
            [time, batch, depth], or a list of tensors if is_transparent is set in
            transformer_stack.
          - padding: of shape [time, batch]
          - segment_id: [time, batch] if packed inputs are supported by the model
            (and all layers), or None otherwise.
          - embedded_inputs: [time, batch, depth] embedded input tokens without
            positional encodings.
        """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            if not p.shared_emb:
                input_embs = self.token_emb.EmbLookup(
                    theta.token_emb, tf.reshape(input_ids, [-1]))
            else:
                input_embs = self.softmax.EmbLookup(
                    theta.softmax, tf.reshape(input_ids, [-1]))

            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        if not self.do_eval and p.apply_source_mask:
            # Augment padding for masked source word positions.
            dtype = paddings.dtype
            source_mask = tf.where(tf.equal(input_ids, p.source_mask_id),
                                   tf.ones_like(input_ids, dtype=dtype),
                                   tf.zeros_like(input_ids, dtype=dtype))
            # Make sure padding is between 0 and 1.
            paddings = tf.clip_by_value(paddings + tf.transpose(source_mask),
                                        0.0, 1.0)

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
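The truncation branch above (taken only off TPU, where dynamic shapes are allowed) finds the longest unpadded sequence in the batch from the paddings and slices every [batch, time] tensor down to that length. A minimal sketch of just that step in plain TensorFlow:

import tensorflow as tf

ids = tf.constant([[3, 4, 5, 0, 0],
                   [6, 7, 0, 0, 0]])
paddings = tf.constant([[0., 0., 0., 1., 1.],
                        [0., 0., 1., 1., 1.]])
max_seq_length = tf.cast(tf.reduce_max(tf.reduce_sum(1.0 - paddings, 1)), tf.int32)
ids = ids[:, :max_seq_length]            # [[3, 4, 5], [6, 7, 0]]
paddings = paddings[:, :max_seq_length]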
Code example #13
  def _CommonParams(self, packed_input):
    p = base_config.SetupTransformerBatchMajorParams(
        name='en_de_ml_perf_transformer',
        vocab_size=self.VOCAB_SIZE,
        model_dim=self.MODEL_DIM,
        hidden_dim=self.HIDDEN_DIM,
        num_heads=self.NUM_HEADS,
        num_layers=6,
        inference_source_language='en',
        inference_target_language='de',
        input_dropout_prob=0.1,
        residual_dropout_prob=0.1,
        atten_dropout_prob=0.1,
        relu_dropout_prob=0.1,
        add_unnormalized_residuals=True,
        learning_rate=self.LEARNING_RATE,
        warmup_steps=self.WARMUP_STEPS,
        packed_input=packed_input,
        use_fast_projection_layer=True,
        enable_per_dim_scale=False,
        use_fused_layernorm=True,
        use_bf16_activations=py_utils.use_tpu(),
        use_bias=False,
        xla_num_partitions=self.XLA_NUM_PARTITIONS)

    for pp in [p.encoder, p.decoder]:
      pp.token_emb = model_helper.ChangeToSimpleEmbedding(pp.token_emb)

    p.decoder.softmax = model_helper.ChangeToSimpleSoftmax(p.decoder.softmax)

    sm_params = p.decoder.softmax.Copy()
    sm_params.input_dim = self.MODEL_DIM
    shared_emb = layers.SharedSoftmaxLayer.Params().Set(
        **dict(sm_params.IterParams()))
    shared_emb.params_init = py_utils.WeightInit.Gaussian(
        1.0 / math.sqrt(self.MODEL_DIM))
    shared_emb.scale_sqrt_depth = True
    shared_emb.use_num_classes_major_weight = True
    p.decoder.shared_emb = shared_emb
    p.decoder.shared_emb.cls = layers.SharedSoftmaxLayer
    # Directly sharing encoder embedding with decoder results in worse model
    # quality, which requires more tuning.
    p.encoder.shared_emb = shared_emb
    p.encoder.shared_emb.cls = layers.SharedSoftmaxLayer

    p.train.lr_schedule = schedule.TransformerMLPerfSchedule.Params().Set(
        warmup_steps=self.WARMUP_STEPS, model_dim=self.MODEL_DIM)
    p.train.max_steps = self.MAX_STEPS
    p.train.scale_gradients = False

    p.train.optimizer.beta1 = self.ADAM_BETA1
    p.train.optimizer.beta2 = self.ADAM_BETA2

    # Fix this
    #p.eval.ml_perf_metrics_only = True

    p.decoder.beam_search.target_sos_id = self.ID_SOS
    p.decoder.beam_search.target_eos_id = self.ID_EOS

    p.decoder.beam_search.beam_size = 4.0
    p.decoder.beam_search.num_hyps_per_beam = 4

    p.decoder.target_sos_id = self.ID_SOS
    p.decoder.target_eos_id = self.ID_EOS

    p.decoder.use_fast_softmax = True

    p.decoder.target_seq_len = 147

    if py_utils.use_tpu():
      p.encoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
      p.decoder.trans_decoder_tpl.fprop_dtype = tf.bfloat16
      p.decoder.input_dropout_tpl.fprop_dtype = tf.bfloat16
      p.train.optimizer.use_bf16_gradients_ar = True
    return p