Example 1
    def cache_fn(batch_size):
        """Creates the initial memory cache: one zero-filled tensor per layer."""
        mems = []
        if FLAGS.mem_len > 0:
            for _ in range(FLAGS.n_layer):
                if FLAGS.chunk_len is not None:
                    zeros = tf.zeros(
                        [batch_size, mem_len, FLAGS.chunk_len, FLAGS.d_model],
                        dtype=tf_float)
                else:
                    zeros = tf.zeros([batch_size, mem_len, FLAGS.d_model],
                                     dtype=tf_float)
                mems.append(zeros)

        return mems
Example 2
    def cache_fn(batch_size):
        """Call back function used to create initial cache in TPUEstimator."""
        mems = []
        for obj_len in [FLAGS.seq_len]:
            if obj_len > 0:
                for _ in range(FLAGS.n_layer + int(FLAGS.use_extra_layer)):
                    zeros = tf.zeros([batch_size, mem_len, FLAGS.d_model],
                                     dtype=tf_float)
                    mems.append(zeros)

        return mems
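Both cache_fn variants above follow the same pattern: allocate one zero-filled tensor per layer so the model starts from an empty memory state, and return the list so the estimator can use it as the initial cache. Below is a minimal standalone sketch of that pattern; n_layer, mem_len, and d_model are placeholder values, not the FLAGS used in the originals.

    import tensorflow as tf

    def make_initial_mems(batch_size, n_layer=6, mem_len=128, d_model=512,
                          dtype=tf.float32):
        """Returns one zero-filled memory tensor per layer (an empty cache)."""
        return [tf.zeros([batch_size, mem_len, d_model], dtype=dtype)
                for _ in range(n_layer)]

    mems = make_initial_mems(batch_size=8)
    # len(mems) == n_layer; each entry has shape [batch_size, mem_len, d_model].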
Example 3
    def _transform_images(self, params, features, labels=None):
        """Transforms images."""

        images = features['images']
        batch_size, _, _, c = images.get_shape().as_list()
        if params['conv0_space_to_depth_block_size'] != 0:
            # Transforms (space-to-depth) images for TPU performance.

            def _fused_transform(images, image_size):
                return spatial_transform.fused_transpose_and_space_to_depth(
                    images, image_size,
                    params['conv0_space_to_depth_block_size'],
                    params['transpose_input'])

            images = tf.cond(
                tf.less(features['image_info'][0, 3],
                        features['image_info'][0, 4]),
                lambda: _fused_transform(images, params['image_size']),
                lambda: _fused_transform(images, params['image_size'][::-1]))

        else:
            # Transposes images for TPU performance.
            image_area = params['image_size'][0] * params['image_size'][1]
            if params['transpose_input']:
                images = tf.transpose(images, [1, 2, 0, 3])
                # Flattens spatial dimensions so that the image tensor has a static
                # shape.
                images = tf.reshape(images, [image_area, batch_size, c])
            else:
                images = tf.reshape(images, [batch_size, image_area, c])

        if params['use_bfloat16']:
            images = tf.cast(images, dtype=tf.bfloat16)

        features['images'] = images

        if labels is not None:
            return features, labels
        else:
            return features, tf.zeros([batch_size])
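The two branches above rearrange the image tensor for TPU efficiency: one applies a fused transpose plus space-to-depth (via the repo's spatial_transform helper, whose internals are not shown here), the other only transposes and flattens the spatial dimensions. A rough sketch of the underlying ops on a toy batch, using the standard tf.nn.space_to_depth op and the same permutation as the else branch:

    import tensorflow as tf

    # Toy NHWC batch: 2 images of 8x8 pixels with 3 channels.
    images = tf.zeros([2, 8, 8, 3])

    # Space-to-depth with block size 2 folds each 2x2 spatial block into the
    # channel dimension: [2, 8, 8, 3] -> [2, 4, 4, 12].
    s2d = tf.nn.space_to_depth(images, block_size=2)

    # The else branch instead transposes NHWC -> HWNC and flattens H and W so
    # the tensor keeps a static shape: [2, 8, 8, 3] -> [64, 2, 3].
    transposed = tf.transpose(images, [1, 2, 0, 3])
    flattened = tf.reshape(transposed, [8 * 8, 2, 3])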
Example 4
    def zeroed_box_fn():
        return tf.zeros([0, self._ori_height, self._ori_width, 1])
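zeroed_box_fn returns a tensor whose leading dimension is zero, the usual way to represent "no instances" while keeping the spatial dimensions fixed. The surrounding code is not shown, but a typical use is as the fallback branch of a tf.cond; a small illustration with made-up sizes:

    import tensorflow as tf

    ori_height, ori_width = 640, 640  # placeholder image size

    def zeroed_masks():
        # Zero instances, but the per-instance spatial shape is preserved.
        return tf.zeros([0, ori_height, ori_width, 1])

    gt_masks = tf.ones([3, ori_height, ori_width, 1])  # pretend ground truth
    masks = tf.cond(tf.shape(gt_masks)[0] > 0,
                    lambda: gt_masks,
                    zeroed_masks)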
Example 5
  def parser(record):
    """function used to parse tfrecord."""

    record_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # Parse the serialized example according to record_spec.
    example = tf.parse_single_example(
        serialized=record,
        features=record_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len],
        target[:reuse_len],
        is_masked[:reuse_len],
        perm_size,
        reuse_len)

    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:],
        target[reuse_len:],
        is_masked[reuse_len:],
        perm_size,
        non_reuse_len)

    perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
                            axis=1)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after prepro
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      target_mask = tf.concat(
          [tf.ones([actual_num_predict], dtype=tf.float32),
           tf.zeros([pad_len], dtype=tf.float32)],
          axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for k, v in example.items():
      tf.logging.info("%s: %s", k, v)

    return example
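The parser above is meant to be mapped over a TFRecord dataset inside an input_fn; that wiring is not part of the snippet, so the sketch below is only an assumption about typical usage, with file_paths and bsz_per_core as placeholder arguments and tensorflow imported as tf, as in the snippet:

  def build_dataset(file_paths, bsz_per_core):
    """Maps the parser over TFRecords and batches per core."""
    dataset = tf.data.TFRecordDataset(file_paths)
    dataset = dataset.map(parser,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(bsz_per_core, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset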