def CreateTpuEmbeddingEnqueueOps(self):
        """Creates the TpuEmbedding enqueue ops on the host.

    Note that this must be called after the instantiation of the
    monolithic TPUEmbeddingLayer.
    """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
        tpu_embedding = (tpu_embedding_collection[0]
                         if tpu_embedding_collection else None)

        enqueue_ops = []

        if num_tpu_hosts > 1 and tpu_embedding is not None:
            if not p.use_per_host_infeed:
                tf.logging.fatal(
                    'TPU Embedding must be used with per_host_infeed with multiple '
                    'TPU host topologies.')
        tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                              if tpu_embedding is not None else [])
        tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
        if not tpu_embedding:
            return

        for task_id in range(num_infeed_hosts):
            host_device = '/task:{}/device:CPU:0'.format(task_id)
            with tf.device(host_device):
                if isinstance(self._batch, py_utils.NestedMap):
                    # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                    # Note that when MultiTaskData is used, bucket_keys will be at the
                    # second level of the dictionary.
                    self._batch = self._batch.FilterKeyVal(
                        lambda k, _: not k.endswith('bucket_keys'))
                tf.logging.info('host_device: %s, batch: %r', host_device,
                                self._batch)

                enqueue_dict_per_core = [
                    {} for _ in range(tpu_embedding.num_cores_per_host)
                ]
                num_cores_per_host = tpu_embedding.num_cores_per_host
                for key in tpu_emb_input_keys:
                    feat = self._batch[key]
                    tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
                    for core, split in enumerate(tpu_emb_feat_splitted):
                        # Dense to sparse. Note the assumption of a padding id.
                        sample_indices = tf.where(tf.not_equal(split, -1))
                        embedding_indices = tf.gather_nd(split, sample_indices)
                        enqueue_data = tpu_embedding_lib.EnqueueData(
                            embedding_indices, sample_indices)
                        enqueue_dict_per_core[core][key] = enqueue_data
                enqueue_ops += tpu_embedding.generate_enqueue_ops(
                    enqueue_dict_per_core)
        self._tpu_infeed_op.append(tf.group(*enqueue_ops))
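A minimal standalone sketch of the dense-to-sparse conversion used above, assuming -1 is the padding id (the id values below are hypothetical):

import tensorflow as tf

dense_ids = tf.constant([[3, 7, -1],
                         [5, -1, -1]])                       # [batch, max_ids]
sample_indices = tf.where(tf.not_equal(dense_ids, -1))       # [[0,0],[0,1],[1,0]]
embedding_indices = tf.gather_nd(dense_ids, sample_indices)  # [3, 7, 5]
# (embedding_indices, sample_indices) is what EnqueueData receives per core.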
Example #2
    def FProp(self, theta, inputs, paddings):
        """Builds FProp graph.

    Args:
      theta: A NestedMap of Tensors, see base class.
      inputs: A Tensor of shape [batch, seqlen, dim0].
      paddings: A Tensor of shape [batch, seqlen].

    Returns:
      output: A Tensor of shape [batch, seqlen, dim0].
      out_paddings: A Tensor of shape [batch, seqlen].
    """

        p = self.params
        with tf.name_scope(p.name):
            unnormalized_inputs = inputs

            inputs = self.ln.FProp(theta.ln, inputs)
            if p.split_act_gated_linear_start:
                act_inputs = self.linear_start_act.FProp(
                    theta.linear_start_act, inputs)
                gated_inputs = self.linear_start_gated.FProp(
                    theta.linear_start_gated, inputs)
            else:
                inputs = self.linear_start.FProp(theta.linear_start, inputs)
                gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
            inputs = self._GLU(gated_inputs, act_inputs)

            # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
            # [b, t, d] --> [b, t, 1, d]
            inputs = tf.expand_dims(inputs, 2)
            adapted_blf_dims_mapping = None
            if p.activation_split_dims_mapping.blf is not None:
                adapted_blf_dims_mapping = p.activation_split_dims_mapping.blf.copy(
                )
                adapted_blf_dims_mapping.insert(2, -1)
            inputs = xla_sharding_utils.MeshSplit(inputs, p.device_mesh,
                                                  adapted_blf_dims_mapping)
            theta.depthwise_conv1d.w = xla_sharding_utils.MeshSplit(
                theta.depthwise_conv1d.w, p.device_mesh,
                p.weight_split_dims_mapping.hwim)
            inputs, paddings = self.depthwise_conv1d.FProp(
                theta.depthwise_conv1d, inputs, paddings)

            inputs = xla_sharding_utils.MeshSplit(inputs, p.device_mesh,
                                                  adapted_blf_dims_mapping)
            inputs = self._Normalize(theta, inputs, paddings)
            inputs = xla_sharding_utils.MeshSplit(
                inputs, p.device_mesh, p.activation_split_dims_mapping.blf)

            inputs = self._ApplyActivation(inputs, p.conv_activation)

            inputs = self.linear_end.FProp(theta.linear_end, inputs)
            inputs = self.dropout.FProp(theta.dropout, inputs)

            output = inputs + unnormalized_inputs
            return output, paddings
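A small sketch of the sharding adaptation above: inserting -1 at position 2 extends a hypothetical [b, l, f] mesh-split spec to the [b, t, 1, d] conv2d layout, leaving the new singleton axis unsharded:

blf = [0, -1, 1]      # hypothetical split spec for [b, l, f]
blf1f = list(blf)
blf1f.insert(2, -1)   # [0, -1, -1, 1], matching [b, l, 1, f]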
Example #3
def IsWithinBBox3D(points_3d, bboxes_3d):
    """Checks if points are within a 3-d bbox.

  Args:
    points_3d: [num_points, 3] float32 Tensor specifying points in 3-d space as
      [x, y, z] coordinates.
    bboxes_3d: [num_bboxes, 7] float32 Tensor specifying 3-d bboxes as
      [x, y, z, dx, dy, dz, phi], where (x, y, z) is the center of each box.

  Returns:
    boolean Tensor of shape [num_points, num_bboxes] indicating whether each
    point is within each box.
  """
    points_3d = py_utils.HasRank(points_3d, 2)
    points_3d = py_utils.HasShape(points_3d, [-1, 3])
    num_points, _ = py_utils.GetShape(points_3d, 2)

    bboxes_3d = py_utils.HasRank(bboxes_3d, 2)
    bboxes_3d = py_utils.HasShape(bboxes_3d, [-1, 7])
    num_bboxes, _ = py_utils.GetShape(bboxes_3d, 2)

    # Compute the 3-D corners of the bounding boxes.
    bboxes_3d_b = tf.expand_dims(bboxes_3d, 0)
    bbox_corners = BBoxCorners(bboxes_3d_b)
    bbox_corners = py_utils.HasShape(bbox_corners, [1, -1, 8, 3])
    # First four points are the top of the bounding box.
    # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
    #   (x0, y1) <--- (x1, y1)
    #                    ^
    #                    |
    #                    |
    #   (x0, y0) ---> (x1, y0)
    bboxes_2d_corners = bbox_corners[0, :, 0:4, 0:2]
    bboxes_2d_corners = py_utils.HasShape(bboxes_2d_corners, [-1, 4, 2])
    # Determine if points lie within 2-D (x, y) plane for all bounding boxes.
    points_2d = points_3d[:, :2]
    is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)
    is_inside_2d = py_utils.HasShape(is_inside_2d, [num_points, num_bboxes])

    # Determine if points lie within the z-dimension for all bounding boxes.
    [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)

    def _ComputeLimits(center, width):
        left = center - width / 2.0
        right = center + width / 2.0
        return left, right

    z0, z1 = _ComputeLimits(z, dz)
    z_points = tf.expand_dims(points_3d[:, 2], -1)

    is_inside_z = tf.math.logical_and(
        tf.less_equal(z_points, z1[tf.newaxis, :, 0]),
        tf.greater_equal(z_points, z0[tf.newaxis, :, 0]))
    is_inside_z = py_utils.HasShape(is_inside_z, [num_points, num_bboxes])

    return tf.math.logical_and(is_inside_z, is_inside_2d)
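A standalone sketch of the z-interval test above with hypothetical values; broadcasting yields the [num_points, num_bboxes] membership matrix:

import tensorflow as tf

z = tf.constant([[1.0], [4.0]])          # [num_bboxes, 1] box centers
dz = tf.constant([[2.0], [2.0]])         # [num_bboxes, 1] box heights
z0, z1 = z - dz / 2.0, z + dz / 2.0
z_points = tf.constant([[0.5], [3.9]])   # [num_points, 1]
is_inside_z = tf.math.logical_and(
    tf.less_equal(z_points, z1[tf.newaxis, :, 0]),
    tf.greater_equal(z_points, z0[tf.newaxis, :, 0]))
# is_inside_z == [[True, False], [False, True]]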
Example #4
  def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask, dec_state,
                   t):
    del tgt_pos, tgt_segment_id

    [buf] = dec_state
    if tgt_id.shape == (self.batch_size, self.beam_size):
      buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
    else:
      div = int(tgt_id.shape[1] // self.beam_size)
      for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
        buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

    buf1 = tf.transpose(buf, [1, 0, 2])
    buf1 = tf.reshape(buf1, [self.batch_size, self.max_steps * self.beam_size])

    # select next_tgt_id as a function of previous target tokens
    if self.rule == '+1':
      next_tgt_id = (tgt_id + 1)
      next_tgt_id %= self.vocab_size
    elif self.rule == 'sum':
      # sum over all previous tokens in tgt_mask
      next_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(tgt_mask, tf.int32))
      next_tgt_id %= self.vocab_size
    elif self.rule == 'fib':
      # select last token according to tgt_mask
      m = tgt_mask
      m *= tf.cast(
          tf.equal(tf.cumsum(m, -1),
                   tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
      last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
      next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

    # With a lower probability, add an extra +1..+4 to the correct next_tgt_id.
    n = self.vocab_size
    logits = 5 * tf.one_hot(next_tgt_id % n, n)
    logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
    logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
    logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
    logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

    # increase eos_score if current tgt_id contains 9
    eos_id = 0
    tgt_id_contains_9 = tf.logical_or(
        tf.equal(tgt_id % 10, 9), tf.equal((tgt_id // 10) % 10, 9))
    logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(eos_id, self.vocab_size),
                            tf.cast(tgt_id_contains_9, tf.float32))

    # tie-breaking -- lower token id wins a little bit
    tie = np.arange(0., 1., 1. / n)
    tie /= tie.sum()
    logits -= tie

    logits = tf.nn.log_softmax(logits)

    dec_state = [buf]
    return logits, dec_state
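A numeric sketch of the mask manipulation in the 'fib' branch above: comparing the running sum against (total - 1) keeps exactly the position of the last-but-one 1 in each row, i.e. the previous token, since tgt_mask also covers the current step:

import tensorflow as tf

m = tf.constant([[1., 1., 1., 0.],
                 [1., 1., 0., 0.]])
m = m * tf.cast(
    tf.equal(tf.cumsum(m, -1), tf.reduce_sum(m, -1, keepdims=True) - 1),
    m.dtype)
# m == [[0., 1., 0., 0.], [1., 0., 0., 0.]]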
Example #5
  def testDecoderFPropSplitBatch(self, dtype=tf.float32):
    with self.session(use_gpu=True) as sess:
      tf.random.set_seed(_TF_RANDOM_SEED)
      p = self._DecoderParams(dtype=dtype)
      dec = decoder.TransformerDecoder(p)

      encoder_outputs, targets, _ = self._Inputs(dtype=dtype)
      src_enc1, src_enc2 = tf.split(encoder_outputs.encoded, 2, 1)
      src_paddings1, src_paddings2 = tf.split(encoder_outputs.padding, 2, 1)

      # source idx <-> target idx:
      # 0 <-> (0, 4), 1 <-> (1, 5), 2 <-> (2, 6), 3 <-> (3, 7)
      tgts = ig_helper.SplitDictOfTensors(targets, 4)
      targets1 = py_utils.NestedMap({
          'ids': tf.concat([tgts[0]['ids'], tgts[2]['ids']], 0),
          'labels': tf.concat([tgts[0]['labels'], tgts[2]['labels']], 0),
          'weights': tf.concat([tgts[0]['weights'], tgts[2]['weights']], 0),
          'paddings': tf.concat([tgts[0]['paddings'], tgts[2]['paddings']], 0)
      })
      targets2 = py_utils.NestedMap({
          'ids': tf.concat([tgts[1]['ids'], tgts[3]['ids']], 0),
          'labels': tf.concat([tgts[1]['labels'], tgts[3]['labels']], 0),
          'weights': tf.concat([tgts[1]['weights'], tgts[3]['weights']], 0),
          'paddings': tf.concat([tgts[1]['paddings'], tgts[3]['paddings']], 0)
      })

      loss, _ = dec.FPropDefaultTheta(encoder_outputs, targets).metrics['loss']
      encoder_outputs1 = py_utils.NestedMap(
          encoded=src_enc1, padding=src_paddings1, segment_id=None)
      loss1, _ = dec.FPropDefaultTheta(encoder_outputs1,
                                       targets1).metrics['loss']
      encoder_outputs2 = py_utils.NestedMap(
          encoded=src_enc2, padding=src_paddings2, segment_id=None)
      loss2, _ = dec.FPropDefaultTheta(encoder_outputs2,
                                       targets2).metrics['loss']

      tf.global_variables_initializer().run()
      actual_loss, actual_loss1, actual_loss2 = sess.run([loss, loss1, loss2])
      print('actual loss = ', actual_loss)
      print('actual loss1 = ', actual_loss1)
      print('actual loss2 = ', actual_loss2)
      self.assertAlmostEqual(
          actual_loss, np.mean([actual_loss1, actual_loss2]), delta=0.0001)
Example #6
    def ComputeNormalizedWER(self, hyps, refs, num_hyps_per_beam):
        # Filter out all '<epsilon>' tokens for norm_wer computation.
        hyps_no_epsilon = tf.strings.regex_replace(hyps, '(<epsilon>)+', ' ')
        # norm_wer is size [num_transcripts * hyps_per_beam, 2]
        norm_wer = decoder_utils.ComputeWer(hyps_no_epsilon, refs)
        # Split into two tensors of size [num_transcripts * hyps_per_beam, 1]
        norm_wer_errors, norm_wer_words = tf.split(norm_wer, [1, 1], 1)
        shape = [-1, num_hyps_per_beam]
        norm_wer_errors = tf.reshape(norm_wer_errors, shape)
        norm_wer_words = tf.reshape(norm_wer_words, shape)

        return norm_wer_errors, norm_wer_words
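A minimal sketch of the column split above: size splits [1, 1] on axis 1 separate an [n, 2] tensor into two [n, 1] tensors (values hypothetical):

import tensorflow as tf

norm_wer = tf.constant([[2, 5], [0, 4]])        # [errors, words] per hyp
errors, words = tf.split(norm_wer, [1, 1], 1)   # each of shape [n, 1]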
Example #7
  def testForwardPassSplitBatch(self):
    with self.session(use_gpu=False):
      bs = 8
      sl = 20
      tf.random.set_seed(8372749040)
      p = self._EncoderParams()
      p.random_seed = 1234
      mt_enc = encoder.TransformerEncoder(p)

      batch = py_utils.NestedMap()
      batch.ids = tf.constant(
          np.random.randint(low=0, high=63, size=[bs, sl], dtype=np.int32))
      batch.paddings = tf.zeros([bs, sl])
      out = mt_enc.FPropDefaultTheta(batch)
      enc_out = out.encoded
      emb_out = out.embedded_inputs

      inputs1, inputs2 = tf.split(batch.ids, 2, 0)
      paddings1, paddings2 = tf.split(batch.paddings, 2, 0)

      batch.ids = inputs1
      batch.paddings = paddings1
      out1 = mt_enc.FPropDefaultTheta(batch)
      enc_out1 = out1.encoded
      emb_out1 = out1.embedded_inputs

      batch.ids = inputs2
      batch.paddings = paddings2
      out2 = mt_enc.FPropDefaultTheta(batch)
      enc_out2 = out2.encoded
      emb_out2 = out2.embedded_inputs

      self.evaluate(tf.global_variables_initializer())
      actual_enc_out, actual_enc_out1, actual_enc_out2, \
          actual_emb_out, actual_emb_out1, actual_emb_out2 = self.evaluate(
              [enc_out, enc_out1, enc_out2, emb_out, emb_out1, emb_out2])
      self.assertAllClose(actual_enc_out,
                          np.concatenate([actual_enc_out1, actual_enc_out2], 1))
      self.assertAllClose(actual_emb_out,
                          np.concatenate([actual_emb_out1, actual_emb_out2], 1))
Example #8
    def StreamStep(self, theta, inputs, paddings, state0):
        """Runs single step.

    Args:
      theta: A NestedMap of layer params.
      inputs: [b, 1, d].
      paddings: A 0/1 valued tensor of shape [b, 1].
      state0: A NestedMap of tensors of the same struct as returned by
        zero_state().

    Returns:
      output: A Tensor of shape [b, 1, d].
      paddings: The same as the input paddings.
      state1: A NestedMap of tensors of the same struct as state0.
    """
        p = self.params
        assert p.is_causal

        state1 = py_utils.NestedMap()
        with tf.name_scope(f'{p.name}/StreamStep'):
            unnormalized_inputs = inputs

            inputs = self.ln.FProp(theta.ln, inputs)
            if p.split_act_gated_linear_start:
                act_inputs = self.linear_start_act.FProp(
                    theta.linear_start_act, inputs)
                gated_inputs = self.linear_start_gated.FProp(
                    theta.linear_start_gated, inputs)
            else:
                inputs = self.linear_start.FProp(theta.linear_start, inputs)
                gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
            inputs = self._GLU(gated_inputs, act_inputs)

            # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
            # TODO(jamesqin): optimize DepthwiseConv1D.StreamStep()
            # [b, t, d] --> [b, t, 1, d]
            inputs = tf.expand_dims(inputs, 2)
            # [b, t, 1, d]
            inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
                theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
            state1.conv_state = conv_state1
            # [b, t, d]
            inputs = self._NormalizeStep(theta, inputs, paddings, state0,
                                         state1)

            inputs = self._ApplyActivation(inputs, p.conv_activation)

            inputs = self.linear_end.FProp(theta.linear_end, inputs)
            inputs = self.dropout.FProp(theta.dropout, inputs)

            output = inputs + unnormalized_inputs
            return output, paddings, state1
Example #9
  def _InputBatch(self):
    np.random.seed(1)
    bs, sl = 10, 7
    src_ids = tf.constant(
        np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
    tgt_ids = tf.constant(
        np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
    tgt_labels = tf.constant(
        np.random.randint(low=0, high=8192 - 1, size=[bs, sl], dtype=np.int32))
    tgt_weights = tf.constant(np.ones(shape=[bs, sl], dtype=np.float32))

    src_paddings = tf.zeros([bs, sl])
    tgt_paddings = tf.zeros([bs, sl])

    ret = py_utils.NestedMap()
    ret.src = py_utils.NestedMap()
    ret.tgt = py_utils.NestedMap()

    if self.params.split:
      src_ids = tf.split(src_ids, 2, 0)
      src_paddings = tf.split(src_paddings, 2, 0)
      tgt_ids = tf.split(tgt_ids, 2, 0)
      tgt_labels = tf.split(tgt_labels, 2, 0)
      tgt_paddings = tf.split(tgt_paddings, 2, 0)
      tgt_weights = tf.split(tgt_weights, 2, 0)

      ret.src.ids = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: src_ids[0], lambda: src_ids[1])
      ret.src.paddings = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: src_paddings[0], lambda: src_paddings[1])
      ret.tgt.ids = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: tgt_ids[0], lambda: tgt_ids[1])
      ret.tgt.labels = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: tgt_labels[0], lambda: tgt_labels[1])
      ret.tgt.paddings = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: tgt_paddings[0], lambda: tgt_paddings[1])
      ret.tgt.weights = tf.cond(
          tf.equal(tf.math.floormod(py_utils.GetGlobalStep(), 2), 0),
          lambda: tgt_weights[0], lambda: tgt_weights[1])
    else:
      ret.src.ids = src_ids
      ret.src.paddings = src_paddings
      ret.tgt.ids = tgt_ids
      ret.tgt.labels = tgt_labels
      ret.tgt.paddings = tgt_paddings
      ret.tgt.weights = tgt_weights

    return ret
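A sketch of the alternating selection above, with a hypothetical step value: even global steps pick the first half of each split, odd steps the second:

import tensorflow as tf

halves = tf.split(tf.range(10), 2, 0)   # [0..4], [5..9]
step = tf.constant(3)
picked = tf.cond(
    tf.equal(tf.math.floormod(step, 2), 0),
    lambda: halves[0], lambda: halves[1])  # -> [5..9]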
Example #10
 def partition_tensor(cls, tensor, partition_info):
     """Returns partitioned tensors."""
     metadata = (TensorPartitioner.partition_metadata(
         tensor, partition_info))
     # Split from last to first axis.
     partitioned_tensors = [tensor]
     rank = len(metadata.num_splits_per_dim)
     for raxis, (num_splits, sizes) in enumerate(
             zip(reversed(metadata.num_splits_per_dim),
                 reversed(metadata.split_sizes_per_dim))):
         if num_splits > 1:
             tmp_partitioned_tensors = []
             for item in partitioned_tensors:
                 tmp_partitioned_tensors += tf.split(item,
                                                     sizes,
                                                     axis=rank - raxis - 1)
             partitioned_tensors = tmp_partitioned_tensors
     return partitioned_tensors
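A standalone sketch of the last-to-first axis splitting above, assuming a plain list of per-dimension split counts in place of the TensorPartitioner metadata:

import tensorflow as tf

def Partition(tensor, num_splits_per_dim):
    parts = [tensor]
    rank = len(num_splits_per_dim)
    for raxis, num_splits in enumerate(reversed(num_splits_per_dim)):
        if num_splits > 1:
            # Split every accumulated block along the current axis.
            parts = [p for item in parts
                     for p in tf.split(item, num_splits, axis=rank - raxis - 1)]
    return parts

blocks = Partition(tf.zeros([4, 6]), [2, 3])  # 6 blocks, each of shape [2, 2]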
Example #11
    def FProp(self, theta, inputs, paddings):
        """Builds FProp graph.

    Args:
      theta: A NestedMap of Tensors, see base class.
      inputs: A Tensor of shape [batch, seqlen, dim0].
      paddings: A Tensor of shape [batch, seqlen].

    Returns:
      output: A Tensor of shape [batch, seqlen, dim0].
      out_paddings: A Tensor of shape [batch, seqlen].
    """

        p = self.params
        with tf.name_scope(p.name):
            unnormalized_inputs = inputs

            inputs = self.ln.FProp(theta.ln, inputs)
            if p.split_act_gated_linear_start:
                act_inputs = self.linear_start_act.FProp(
                    theta.linear_start_act, inputs)
                gated_inputs = self.linear_start_gated.FProp(
                    theta.linear_start_gated, inputs)
            else:
                inputs = self.linear_start.FProp(theta.linear_start, inputs)
                gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
            inputs = self._GLU(gated_inputs, act_inputs)

            # TODO(jamesqin): introduce depthwise conv2d with 3d inputs.
            # [b, t, d] --> [b, t, 1, d]
            inputs = tf.expand_dims(inputs, 2)
            theta.depthwise_conv1d.w = moe_layers.Split(
                theta.depthwise_conv1d.w, 2, p.xla_num_partitions)
            inputs, paddings = self.depthwise_conv1d.FProp(
                theta.depthwise_conv1d, inputs, paddings)
            inputs = self._Normalize(theta, inputs, paddings)

            inputs = self._ApplyActivation(inputs, p.conv_activation)

            inputs = self.linear_end.FProp(theta.linear_end, inputs)
            inputs = self.dropout.FProp(theta.dropout, inputs)

            output = inputs + unnormalized_inputs
            return output, paddings
Example #12
def SplitTensors(xs, num_splits):
    """Splits tensors in `xs` evenly into num_splits along the 1st dimenion.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors, num elements in the tuple = len(xs).

    i-th element in each list corresponds to i-th split of each tensor in xs
    along the first dimension of each tensor.
  """
    # assert first dim of all tensors in xs is equal
    batch_dims = [tf.shape(x)[0] for x in xs]
    all_batch_dims = tf.stack(batch_dims)

    all_batch_dims = py_utils.with_dependencies([
        py_utils.assert_equal(all_batch_dims,
                              tf.shape(xs[0])[0],
                              message='first dim of tensors in xs must match'),
        py_utils.assert_greater_equal(
            tf.shape(xs[0])[0],
            num_splits,
            message='first dim of tensors in xs must be at least num_splits'
        )
    ], all_batch_dims)

    splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
    print("splits " + str(splits))
    # add the above assertion into the compute graph
    splits = py_utils.with_dependencies([all_batch_dims], splits)
    print("splits 2 " + str(splits))
    print("xs " + str(xs))
    # this step get x
    # splits is not the number of spilits, it is the
    #split_xs = [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]
    split_xs = [
        tf.split(axis=0, num_or_size_splits=num_splits, value=x) for x in xs
    ]
    print("split_xs " + str(split_xs))

    return split_xs
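A usage sketch under the same contract: a batch of 5 split into 2 pieces needs explicit per-split sizes, which is why ComputeSplits returns sizes rather than a count:

import tensorflow as tf

x = tf.zeros([5, 3])
a, b = tf.split(x, [3, 2], axis=0)   # shapes [3, 3] and [2, 3];
                                     # tf.split(x, 2, 0) would fail on 5 % 2 != 0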
Example #13
def _DistortBrightnessAndColor(image):
  """Distorts brightness and color of the input image.

  Args:
    image: 3-D Tensor containing single image in [0, 1].

  Returns:
    3-D Tensor color-distorted image in range [0, 1]
  """
  br_delta = tf.random.uniform([], -32. / 255., 32. / 255.)
  cb_factor = tf.random.uniform([], -0.1, 0.1)
  cr_factor = tf.random.uniform([], -0.1, 0.1)

  channels = tf.split(axis=2, num_or_size_splits=3, value=image)
  red_offset = 1.402 * cr_factor + br_delta
  green_offset = -0.344136 * cb_factor - 0.714136 * cr_factor + br_delta
  blue_offset = 1.772 * cb_factor + br_delta
  channels[0] += red_offset
  channels[1] += green_offset
  channels[2] += blue_offset
  return tf.clip_by_value(tf.concat(channels, axis=2), 0., 1.)
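A hypothetical usage sketch, assuming an input image already scaled to [0, 1]:

import tensorflow as tf

image = tf.random.uniform([224, 224, 3])       # hypothetical input in [0, 1]
distorted = _DistortBrightnessAndColor(image)  # same shape, clipped to [0, 1]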
Example #14
 def _GLUFn(inputs):
   gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
   return act_inputs * tf.sigmoid(gated_inputs)
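A numeric sketch of the GLU above: the last dimension is halved, and the first half gates the second through a sigmoid:

import tensorflow as tf

inputs = tf.constant([[1.0, -2.0, 0.5, 3.0]])   # [..., 2 * d]
gated, act = tf.split(inputs, 2, axis=-1)       # each [..., d]
out = act * tf.sigmoid(gated)                   # [0.5*sigmoid(1.0), 3.0*sigmoid(-2.0)]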
Example #15
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #16
File: gpipe.py  Project: jairsan/lingvo-1
  def FProp(self, theta, *args):
    """Run multiple cells in different devices in a pipelining manner.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      *args: Non-keyworded variable length argument list of input tensors.

    Returns:
      A list of output tensors
    """
    # TODO(huangyp): handle optional None inputs.
    p = self.params
    if p.is_eval:
      outputs = _ToTuple(args)
      for (name, l) in self._before_layers:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      for (name, l) in self._cells:
        outputs = _ToTuple(outputs)
        outputs = l.FProp(theta[name], *outputs)
      return outputs

    num_cells = len(p.cell_tpl)
    cluster = self.cluster

    # Compute shapes of input and output tensors.
    input_tenors = _ToTuple(args)
    mini_batch_size = input_tenors[0].get_shape().as_list()[p.batch_dim]
    if p.state_dtype:
      state_dtype = p.state_dtype
    else:
      state_dtype = input_tenors[0].dtype
    if p.num_micro_batches > mini_batch_size:
      p.num_micro_batches = mini_batch_size
    micro_batch_size = mini_batch_size // p.num_micro_batches

    input_shapes = ()
    for input_tensor in input_tenors:
      if input_tensor is not None:
        input_shape = input_tensor.get_shape().as_list()
        input_shape[p.batch_dim] = micro_batch_size
        input_shapes += (tf.TensorShape(input_shape),)
      else:
        input_shapes += (None,)

    state_shapes = self._CalculateOutputShapes(input_shapes)

    def GetCellFn(i):
      """Get the ith feature extraction layer."""

      def CellFn(theta, state0, inputs):
        """A cell fn is exectued inside of StackedRecurrent."""
        del state0
        frop_inputs = []
        for input_idx in range(len(state_shapes[i])):
          name = 's{}'.format(input_idx)
          if state_shapes[i][input_idx] is not None:
            inputs[name].set_shape(state_shapes[i][input_idx])
            frop_inputs.append(inputs[name])
          else:
            frop_inputs.append(None)

        with CellFnFropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, frop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          outputs = cell.FProp(theta, *frop_inputs)

        state1 = py_utils.NestedMap()
        state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
        return state1, py_utils.NestedMap()

      return CellFn

    cell_fns = []
    accumulator_layers = []
    thetas = []
    init_states = []
    devices = []
    for cell_idx in range(num_cells):
      cell_name, cell = self._cells[cell_idx]
      accumulator_layers.append(cell)
      cell_fns.append(GetCellFn(cell_idx))
      thetas.append(theta[cell_name])
      init_state = py_utils.NestedMap()
      init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
      for output_idx in range(len(state_shapes[cell_idx + 1])):
        name = 's{}'.format(output_idx)
        if state_shapes[cell_idx + 1][output_idx] is not None:
          init_state[name] = tf.zeros(
              state_shapes[cell_idx + 1][output_idx], dtype=state_dtype)
      init_states.append(init_state)
      devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

    cell_grads = [None] * num_cells
    cell_outs = [lambda x: x] * num_cells
    cell_out_grads = [lambda x: x] * num_cells

    with tf.device(devices[0]):
      previous = input_tenors
      for (name, l) in self._before_layers:
        previous = l.FProp(theta[name], *previous)
        previous = _ToTuple(previous)
      inputs = py_utils.NestedMap()
      gs_tensor = py_utils.GetGlobalStep()
      inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
          tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
          for t in range(p.num_micro_batches)
      ])

      # TODO(huangyp, dehao): apply dehao's trick to reshape the input tensor
      # to [p.num_micro_batches, -1, 128].
      for output_idx, output_tenor in enumerate(previous):
        name = 's{}'.format(output_idx)
        if output_tenor is not None:
          output_tenor = tf.stack(
              tf.split(output_tenor, p.num_micro_batches, axis=p.batch_dim))
          inputs[name] = output_tenor

    output, _ = recurrent.StackedRecurrent(
        devices=devices,
        cell_fns=cell_fns,
        cell_grads=cell_grads,
        cell_outs=cell_outs,
        cell_out_grads=cell_out_grads,
        thetas=thetas,
        init_states=init_states,
        inputs=inputs,
        accumulator_layers=accumulator_layers,
        unused_acc_state=True)

    with tf.device(devices[-1]):
      output_tensors = []
      for output_idx in range(len(state_shapes[-1])):
        state_shape = state_shapes[-1][output_idx]
        if state_shape is None:
          output_tensors.append(None)
          continue
        output_name = 's{}'.format(output_idx)
        output_tensor = output[output_name]
        if p.batch_dim != 0:
          perm = list(range(1, p.batch_dim + 1)) + [0]
          perm += list(range(p.batch_dim + 1, len(state_shape) + 1))
          output_tensor = tf.transpose(output_tensor, perm=perm)
        state_shape[p.batch_dim] *= p.num_micro_batches
        output_tensor = tf.reshape(output_tensor, state_shape)
        output_tensors.append(output_tensor)
      tf.logging.info('pipeline output = {}'.format(output_tensors))
      if len(output_tensors) == 1:
        return output_tensors[0]
      return tuple(output_tensors)
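A sketch of the micro-batching step inside FProp above: the mini-batch is split along the batch dim and stacked, so each pipeline iteration consumes one [micro_batch, ...] slice:

import tensorflow as tf

num_micro_batches = 4
x = tf.zeros([8, 16])                                    # mini-batch of 8
mb = tf.stack(tf.split(x, num_micro_batches, axis=0))    # [4, 2, 16]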
Example #17
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example #18
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
    """Flat beam search.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above)
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
    assert beam_size > 0
    assert batch_size > 0
    assert max_steps > 0

    buf_size = beam_size * max_steps
    output_len = max_steps

    if prefix is None:
        assert prefix_len is None
        prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        prefix += tf.one_hot(0, beam_size, dtype=tf.int32) * bos_id
        prefix_len = tf.ones([batch_size], dtype=tf.int32)
    else:
        assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
        assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                        prefix_len.shape)
        output_len += int(prefix.shape[1])

    if debug:
        tpu_summary.tensor('prefix', prefix)
        tpu_summary.tensor('prefix_len', prefix_len)

    with tf.name_scope('init_state'):
        t = tf.constant(0)
        tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_id += bos_id
        tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size),
                               buf_size,
                               dtype=fprop_dtype)
        hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
        # penalize all hyps except the first
        hyp_score -= tf.cast(tf.range(beam_size, dtype=tf.float32) * 1e5,
                             dtype=fprop_dtype)
        nbest_size = nbest_size or beam_size
        nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
        nbest_score -= 1e9
        nbest_score_norm = nbest_score
        nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                              dtype=fprop_dtype)

    with tf.name_scope('init_ext'):
        # Initialize the extension buffer.
        #
        # Extension buffer stores a (potentially large) set of 'extensions',
        # which consist of a hypothesis (represented by ext_mask) and next token
        # (represented by ext_id). At each decoder iteration, top_k extensions
        # from each hypothesis are added to the buffer and sorted by score.
        #
        # Then top beam_size extensions are removed from the buffer and used
        # in the next decoder iteration. And top 'ext_size' remaining extensions
        # are carried over to be possibly evaluated at a later step.
        #
        # As a result of this manipulation, the decoder is no longer restricted
        # to always compare hyps of the same token length at each iteration.
        # In particular, for a fixed length N it can generate more than beam_size
        # terminated hyps.
        #
        # Setting ext_size = 0 disables this feature.
        if ext_size:
            ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
            ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
            ext_score -= 1e9
            ext_mask = tf.zeros([batch_size, ext_size, buf_size],
                                dtype=fprop_dtype)
        else:
            ext_size = ext_id = ext_score = ext_mask = 0

    with tf.name_scope('init_prefix'):
        # rename prefix->pfx for shorter variables
        pfx = tf.cast(prefix, tf.int32)
        pfx_len = tf.cast(prefix_len, tf.int32)
        del prefix, prefix_len
        # Before the first call to dec_callback() the prefix shall be packed into
        # the tgt_id buffer as follows:
        #
        # [ P P P P P P - - - - - - P* - - - ]   ^
        # [ P P P P P P P P P P - - P* - - - ]   | batch
        # [ P - - - - - - - - - - - P* - - - ]   V
        # |<---- prefix len ---->  |<-- beam -->
        #
        # The last meaningful token in the prefix (P*)
        # must be located at the same position in all batch rows.
        #
        # We then make one dec_callback() with full prefix (minus P*)
        # which will populate the initial dec_state
        # (for transformer -- self-attention key/value cache)
        #
        # The last block [batch, beam] then becomes the first tgt_id for the loop.
        pfx_max = int(pfx.shape[1])
        pfx_mul = pfx_max // beam_size
        assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
        pfx_time = tf.range(pfx_max)
        pfx_pad = tf.cast(
            tf.less(tf.expand_dims(pfx_time, 0),
                    tf.expand_dims(pfx_len - 1, 1)), tf.int32)
        pfx_id = pfx * pfx_pad
        pfx_last = einsum_i32(
            'BT,BT->B', pfx, tf.one_hot(pfx_len - 1,
                                        pfx_max,
                                        dtype=fprop_dtype))

        buf_time = tf.range(buf_size)
        pfx_time_mask = tf.cast(
            tf.less_equal(tf.expand_dims(buf_time, 0),
                          tf.expand_dims(pfx_time, 1)), fprop_dtype)
        pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                             pfx_time_mask)
        pfx_segment_id = pfx_pad
        pfx_pos = pfx_time * pfx_pad

        if debug:
            tpu_summary.tensor('pfx_id', pfx_id)
            tpu_summary.tensor('pfx_len', pfx_len)
            tpu_summary.tensor('pfx_pos', pfx_pos)
            tpu_summary.tensor('pfx_last', pfx_last)

        # Now call decoder with prefix minus P*:
        # 'dec_state' now shall contain the key/value cache for prefix tokens
        # (for transformer models), and 'logits' we can either discard or
        # roll into the initial hyp_score. Discard is simpler.
        with tf.name_scope('prefix_fprop'):
            # TODO(krikun): remove extra type checks
            assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
            assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
            assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
            assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
            assert (t.dtype == tf.int32), (t.dtype)
            logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos,
                                             pfx_mask, dec_state, t)
            del logits

        # Now construct the initial state for the rest of the beam search loop.
        # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
        # 'tgt_pos' is different for each batch row and is equal to prefix_len
        # 'tgt_segment_id' always 1 (no packing)
        # 'hyp_score' is 0 for beam=0 and negative for beam>=1
        tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            pfx_last, 1)
        tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
            (pfx_len - 1), 1)
        hyp_score = tf.zeros(
            [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
                tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

        # TODO(krikun) Here the initial 't' is constant, determined by the
        # shape of the prefix tensor 'pfx_max'. It is possible to make it
        # dynamic as t ~ max(pfx_len) / beam_size, which would allow more
        # steps for beam search; however, 'max' results in a very slow
        # all-to-all on 16x16, and a variable number of decoder steps may
        # result in bad latency.
        t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

        # Initial tgt_mask is such that each token P* has attention on itself
        # (as usual) and on all prefix tokens before it, which are not padding.
        tgt_mask = tf.zeros([batch_size, beam_size, buf_size],
                            dtype=fprop_dtype)
        tgt_mask += tf.cast(
            tf.expand_dims(
                tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
            fprop_dtype)
        tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                               buf_size,
                               dtype=fprop_dtype)

        if debug:
            tpu_summary.tensor('tgt_id', tgt_id)
            tpu_summary.tensor('tgt_pos', tgt_pos)
            tpu_summary.tensor('tgt_mask', tgt_mask)
            tpu_summary.tensor('t', t)

    with tf.name_scope('init_hist'):
        # h_tgt_id is used to recover topk_ids from nbest_mask
        h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
        h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

        # When non-trivial prefix is present we also write prefix ids to
        # h_tgt_id so that the full sequence including prefix can be recovered
        # by unmask() below.  When prefix is empty, pfx_id shape is [batch, 0]
        # and the loop below becomes a no-op.
        # TODO(krikun): maybe a tf.while_loop is more appropriate here.
        for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
            h_tgt_id = h_tgt_id.write(i, x_i)
        for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
            h_tgt_pos = h_tgt_pos.write(i, x_i)

        hist = (h_tgt_id, h_tgt_pos)
        tf.logging.info('hist=%r', hist)

    nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
    tf.logging.info('nbest_hyps=%r', nbest_hyps)

    ext = (ext_id, ext_score, ext_mask)
    tf.logging.info('ext=%r', ext)

    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)

    def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
         hist) = loop_vars
        (ext_id, ext_score, ext_mask) = ext
        (h_tgt_id, h_tgt_pos) = hist
        h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
        h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
        # not using tf.ones() here because of XLA compilation error
        tgt_segment_id = tgt_id * 0 + 1
        logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos,
                                         tgt_mask, dec_state, t)
        # take predicted EOS score for each hyp and compute normalized score
        eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

        def length_norm(t):
            t = tf.cast(t, fprop_dtype)
            alpha = length_norm_alpha
            tf.logging.info('length_norm.alpha=%r', alpha)
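            # Standard GNMT-style length penalty ((5 + t) / 5) ** alpha, so
            # hypotheses of different lengths are compared on normalized
            # scores.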
            return tf.math.pow((t + 5.) / 5., alpha)

        hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
        eos_score_norm = eos_score / length_norm(hyp_len)
        # update the n-best list
        nbest_hyps = update_nbest(nbest_hyps,
                                  (tgt_mask, hyp_score, eos_score_norm))

        if debug:
            tpu_summary.tensor('eos_score', eos_score)
            tpu_summary.tensor('hyp_len', hyp_len)

        # take top k tokens for each hyp
        k = beam_size
        with tf.name_scope('topk1'):
            top_score, top_id = top_k_fn(logits, k)
            top_score = tf.cast(top_score, fprop_dtype)

        top_score += tf.expand_dims(hyp_score, -1)
        top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)

        top_score = tf.reshape(top_score, [batch_size, beam_size * k])
        top_id = tf.reshape(top_id, [batch_size, beam_size * k])
        top_mask = tf.repeat(tgt_mask, beam_size, 1)
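        # top_mask repeats each hyp's attention mask k times so the masks stay
        # aligned with the reshaped [batch, beam * k] candidate extensions.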

        if debug:
            tpu_summary.tensor('top_id', top_id)
            tpu_summary.tensor('top_score', top_score)
            # tpu_summary.tensor('top_mask', top_mask)

        with tf.name_scope('update_ext'):
            # combine top k tokens with extension buffer (if any)
            if ext_size:
                ext_id = tf.concat([ext_id, top_id], 1)
                ext_score = tf.concat([ext_score, top_score], 1)
                ext_mask = tf.concat([ext_mask, top_mask], 1)
            else:
                ext_id, ext_score, ext_mask = top_id, top_score, top_mask

            # sort by score
            ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
            i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
            ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
            ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

            # pick top beam_size extensions to evaluate at next iteration
            if ext_size:
                hyp_score = ext_score[:, :beam_size]
                ext_score = ext_score[:, beam_size:]
                tgt_id = ext_id[:, :beam_size]
                ext_id = ext_id[:, beam_size:]
                tgt_mask = ext_mask[:, :beam_size]
                ext_mask = ext_mask[:, beam_size:]
            else:
                hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
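                # Zero scalars act as empty placeholders so the loop_vars
                # structure stays uniform when no extension buffer is used.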
                ext_score = ext_id = ext_mask = 0

        tgt_pos = tf.reduce_sum(tgt_mask, -1)
        tgt_pos = tf.cast(tgt_pos, tf.int32)

        t += 1
        with tf.name_scope('tgt_mask_extend'):
            tgt_mask += tf.one_hot(tf.range(beam_size) + t * beam_size,
                                   buf_size,
                                   dtype=fprop_dtype)

        ext = (ext_id, ext_score, ext_mask)
        hist = (h_tgt_id, h_tgt_pos)
        loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                     hist)
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        return loop_vars, dec_state

    def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
        tf.logging.info('loop_vars=%r', loop_vars)
        tf.logging.info('dec_state=%r', dec_state)
        if beam_gap is None:
            (t, _, _, _, _, _, _, _) = loop_vars
            return t < max_steps
        else:
            (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
            (_, nbest_score, _) = nbest_hyps
            # stop early if all current hyps are significantly worse than nbest
            diff = tf.reduce_min(
                tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
            return tf.math.logical_and(t < max_steps, diff < beam_gap)

    with tf.name_scope('flat_beam_search_loop'):
        (loop_vars, dec_state) = tf.while_loop(loop_cond,
                                               loop_step,
                                               loop_vars=(loop_vars,
                                                          dec_state),
                                               back_prop=False,
                                               swap_memory=False,
                                               maximum_iterations=max_steps)

    # flatten all tensorarrays into tensors
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.stack()
    h_tgt_pos = h_tgt_pos.stack()
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)

    # recover topk_ids from nbest_mask and tgt_id history
    h = tf.transpose(h_tgt_id, [1, 0, 2])
    h = tf.reshape(h, [batch_size, buf_size])

    def unmask(h, m):
        with tf.name_scope('unmask'):
            tpu_summary.tensor('unmask_h', h)
            tpu_summary.tensor('unmask_m', m)
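            # cumsum(m) * m - 1 gives each kept token its 0-based position in
            # the compacted row and leaves -1 in masked slots; one_hot maps -1
            # to an all-zero row, so masked slots drop out of the einsum.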
            t = tf.cumsum(m, -1) * m - 1
            mh = einsum_i32('bkt,bt->bkt', m, h)
            t2 = tf.one_hot(tf.cast(t, tf.int32),
                            output_len,
                            dtype=fprop_dtype)
            x = einsum_i32('bkt,bktT->bkT', mh, t2)
            return tf.cast(x, h.dtype)

    topk_ids = unmask(h, nbest_mask)
    topk_len = tf.reduce_sum(nbest_mask, -1)
    topk_len = tf.cast(topk_len, tf.int32)
    # add eos, because nbest_mask does not encode eos
    topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
    topk_len += 1
    topk_len = tf.minimum(topk_len, output_len)
    topk_score = nbest_score_norm

    nbest = (topk_ids, topk_len, topk_score)

    return loop_vars, dec_state, nbest
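

# The while_loop above assumes a decoder callback with the signature
# dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t)
# returning (logits, dec_state), where logits is [batch, beam, vocab].
# A minimal hedged sketch of such a callback (a stand-in, not the real
# decoder; the vocabulary size and uniform logits are illustrative only):
def _DummyDecCallback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t):
    del tgt_segment_id, tgt_pos, tgt_mask, t  # unused by this sketch
    vocab_size = 32  # hypothetical toy vocabulary size
    # Uniform (all-zero) logits: every token equally likely.
    logits = tf.zeros(
        tf.concat([tf.shape(tgt_id), [vocab_size]], axis=0), dtype=tf.float32)
    return logits, dec_state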
Example #19
    def _GLU(self, inputs):
        p = self.params
        gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
        return self._ApplyActivation(
            act_inputs, p.glu_activation) * tf.sigmoid(gated_inputs)
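
    # _GLU implements a gated linear unit: the last dimension is split in
    # half, the first half is squashed by a sigmoid gate, and the configurable
    # activation of the second half is scaled by that gate, so the output
    # width is half the input width.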
Example #20
def IsWithinBBox3D(points_3d, bboxes_3d):
    """Checks if points are within a 3-d bbox.

  Args:
    points_3d: [..., num_points, 3] float32 Tensor specifying points in 3-d
      space as [x, y, z] coordinates.
    bboxes_3d: [..., num_bboxes, 7] float32 Tensor specifying 3-d bboxes, each
      given as [x, y, z, dx, dy, dz, phi], where (x, y, z) is the center of
      the box.

  Returns:
    boolean Tensor of shape [..., num_points, num_bboxes] indicating whether
    each point falls within each box.
  """
    # Check that points_3d and bboxes_3d have the same rank.
    bboxes_rank = py_utils.GetRank(bboxes_3d)
    points_3d = py_utils.HasRank(points_3d, bboxes_rank)
    leading_shape = py_utils.GetShape(bboxes_3d)[:-2]

    # Check that both points_3d and bboxes_3d have the same leading shape.
    points_3d = py_utils.HasShape(points_3d, leading_shape + [-1, 3])
    bboxes_3d = py_utils.HasShape(bboxes_3d, leading_shape + [-1, 7])

    num_points = py_utils.GetShape(points_3d)[-2]
    num_bboxes = py_utils.GetShape(bboxes_3d)[-2]

    bbox_corners = BBoxCorners(bboxes_3d)
    bbox_corners = py_utils.HasShape(bbox_corners,
                                     leading_shape + [num_bboxes, 8, 3])
    # First four points are the top of the bounding box.
    # Counter-clockwise arrangement of points specifying 2-d Euclidean box.
    #   (x0, y1) <--- (x1, y1)
    #                    ^
    #                    |
    #                    |
    #   (x0, y0) ---> (x1, y0)
    bboxes_2d_corners = bbox_corners[..., 0:4, 0:2]
    # Determine if points lie within the 2-D (x, y) plane for all boxes.
    points_2d = points_3d[..., :2]
    is_inside_2d = IsWithinBBox(points_2d, bboxes_2d_corners)

    is_inside_2d = py_utils.HasShape(is_inside_2d,
                                     leading_shape + [num_points, num_bboxes])

    # Determine if points lie within the z-extent of each bounding box.
    [_, _, z, _, _, dz, _] = tf.split(bboxes_3d, 7, axis=-1)

    def _ComputeLimits(center, width):
        left = center - width / 2.0
        right = center + width / 2.0
        return left, right

    z0, z1 = _ComputeLimits(z[..., 0], dz[..., 0])
    z_points = points_3d[..., 2:]

    is_inside_z = tf.math.logical_and(
        tf.less_equal(z_points, z1[..., tf.newaxis, :]),
        tf.greater_equal(z_points, z0[..., tf.newaxis, :]))
    is_inside_z = py_utils.HasShape(is_inside_z,
                                    leading_shape + [num_points, num_bboxes])

    return tf.math.logical_and(is_inside_z, is_inside_2d)
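

# A hedged usage sketch (toy values; assumes BBoxCorners, IsWithinBBox and
# py_utils from the surrounding geometry module are importable): one 2x2x2
# axis-aligned box centered at the origin.
points_3d = tf.constant([[[0., 0., 0.],
                          [5., 5., 5.]]])                  # [1, 2, 3]
bboxes_3d = tf.constant([[[0., 0., 0., 2., 2., 2., 0.]]])  # [1, 1, 7]
mask = IsWithinBBox3D(points_3d, bboxes_3d)                # [1, 2, 1] bool
# Expected: the origin lies inside the box; (5, 5, 5) does not.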
Example #21
def _StackAndSplit(x):
    # Split tensors into microbatches.
    if x is None:
        return None
    return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))
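
# Worked example (hypothetical values): with p.num_micro_batches = 4 and
# p.batch_dim = 0, an [8, d] tensor is split into four [2, d] slices and
# stacked into a [4, 2, d] tensor, one leading index per microbatch.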
Example #22
def Gate(x):
    u, v = tf.split(x, 2, axis=-1)
    return u * tf.sigmoid(v)


def _GatedTanhFn(inputs):
    gated_inputs, act_inputs = tf.split(inputs, 2, axis=-1)
    return tf.tanh(act_inputs) * tf.sigmoid(gated_inputs)
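

# Both helpers follow the same gating pattern: split the last axis in half,
# squash one half with a sigmoid gate, and multiply. They differ in which half
# drives the gate and whether the value path passes through tanh. A quick
# sketch (illustrative shapes):
x = tf.random.normal([4, 8])  # hypothetical [batch, 2*d] input
print(Gate(x).shape)          # (4, 4)
print(_GatedTanhFn(x).shape)  # (4, 4)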