  def testBoolDType(self, use_tf_func):
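    # Write 50 pickled booleans (alternating True / False) to a test TFRecord
    # file.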
    tmp = os.path.join(tf.test.get_temp_dir(), 'bool')
    with tf.python_io.TFRecordWriter(tmp) as w:
      for i in range(50):
        w.write(pickle.dumps(True if i % 2 == 0 else False))

    # A record processor written in TF graph.
    def _process(record):
      bucket_key = 1
      num, = tf.py_func(pickle.loads, [record], [tf.bool])
      return [num], bucket_key

    # Samples random records from the data files and processes them
    # to generate batches.
    resource, out_types, output_tmpl = get_test_input(
        tmp, bucket_upper_bound=[1], processor=_process)
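    # `get_test_input` is assumed to be a helper defined elsewhere in this test
    # file that wraps generic_input.GenericInputV2Create for the file above.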

    def _get_batch():
      return generic_input.GenericInputV2GetNext(resource, out_types,
                                                 output_tmpl)

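    # Optionally trace the batch fetch into a tf.function graph; autograph=False
    # since the body is already plain TF ops.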
    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    for _ in range(10):
      inputs_vals, _ = _get_batch()
      self.assertEqual(inputs_vals[0].dtype, bool)

  def testBasic(self, use_nested_map, bucket_batch_limit, use_tf_func):

    resource, out_types, output_tmpl = setup_basic(
        use_nested_map=use_nested_map, bucket_batch_limit=bucket_batch_limit)
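    # `setup_basic` and `run_basic` are assumed to be helpers defined elsewhere
    # in this test file: `setup_basic` presumably builds a GenericInputV2
    # resource over 100 records of the form b'%08d' with a `num` field holding
    # (i, i**2), and `run_basic` calls GenericInputV2GetNext on that resource
    # (see the assertions below).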

    def _get_batch():
      return run_basic(
          resource,
          use_nested_map=use_nested_map,
          out_types=out_types,
          output_tmpl=output_tmpl)

    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    record_seen = set()
    # Iterate enough batches to cover one full epoch of the 100 records.
    for _ in range(100 // bucket_batch_limit + 1):
      ans_input_batch = _get_batch()
      for s in ans_input_batch.record:
        record_seen.add(s.numpy())
      self.assertEqual(ans_input_batch.source_id.shape, (bucket_batch_limit,))
      self.assertEqual(ans_input_batch.record.shape, (bucket_batch_limit,))
      self.assertEqual(ans_input_batch.num.shape, (bucket_batch_limit, 2))
      ans_vals = ans_input_batch.num
      self.assertAllEqual(np.square(ans_vals[:, 0]), ans_vals[:, 1])
    for i in range(100):
      self.assertIn(('%08d' % i).encode('utf-8'), record_seen)

  def testWithinBatchMixing(self, use_tf_func):
    # Generate a couple of test files.
    def generate_test_data(tag, cnt):
      tmp = os.path.join(tf.test.get_temp_dir(), tag)
      with tf.python_io.TFRecordWriter(tmp) as w:
        for i in range(cnt):
          w.write(('%s:%08d' % (tag, i)).encode('utf-8'))
      return tmp

    path1 = generate_test_data('input1', 100)
    path2 = generate_test_data('input2', 200)
    path3 = generate_test_data('input3', 10)

    # A record processor written in TF graph.
    def _process(source_id, record):
      return py_utils.NestedMap(source_id=source_id, record=record), 1

    # Samples random records from the data files and processes them
    # to generate batches.
    resource, out_types, output_tmpl = generic_input.GenericInputV2Create(
        file_pattern=','.join(
            ['tfrecord:' + path1, 'tfrecord:' + path2, 'tfrecord:' + path3]),
        input_source_weights=[0.2, 0.3, 0.5],
        file_random_seed=0,
        file_buffer_size=32,
        file_parallelism=4,
        bucket_batch_limit=[8],
        bucket_upper_bound=[1],
        processor=_process)
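    # `input_source_weights` asks for roughly 20% / 30% / 50% of the records in
    # each batch to come from input1 / input2 / input3; the mixing proportions
    # are verified statistically below.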

    def _get_batch():
      return generic_input.GenericInputV2GetNext(resource, out_types,
                                                 output_tmpl)

    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    source_id_count = collections.defaultdict(int)
    tags_count = collections.defaultdict(int)
    total_count = 10000
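    # Draw `total_count` batches of 8 records each and tally, per batch element,
    # which source file it came from.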
    for _ in range(total_count):
      ans_input_batch, ans_buckets = _get_batch()
      for s in ans_input_batch.source_id:
        # Use `.numpy()` to read the Tensor's value.
        source_id_count[s.numpy()] += 1
      for s in ans_input_batch.record:
        # Use `.numpy()` to read the Tensor's value.
        tags_count[s.numpy().split(b':')[0]] += 1
      self.assertEqual(ans_input_batch.source_id.shape, (8,))
      self.assertEqual(ans_input_batch.record.shape, (8,))
      self.assertAllEqual(ans_buckets, [1] * 8)
    self.assertEqual(sum(source_id_count.values()), total_count * 8)
    self.assertEqual(sum(tags_count.values()), total_count * 8)
    num_records = 8. * total_count
    self.assertAlmostEqual(tags_count[b'input1'] / num_records, 0.2, delta=0.01)
    self.assertAlmostEqual(tags_count[b'input2'] / num_records, 0.3, delta=0.01)
    self.assertAlmostEqual(tags_count[b'input3'] / num_records, 0.5, delta=0.01)
    self.assertAlmostEqual(source_id_count[0] / num_records, 0.2, delta=0.01)
    self.assertAlmostEqual(source_id_count[1] / num_records, 0.3, delta=0.01)
    self.assertAlmostEqual(source_id_count[2] / num_records, 0.5, delta=0.01)

  def testDropRecordIfNegativeBucketKey(self, use_tf_func):

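    # Bucketing function that returns a negative key for odd records;
    # GenericInput drops any record whose bucket key is negative (verified
    # below).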
    def bucket_fn(num):
      # Drops record if num[0] is odd.
      return tf.cond(
          tf.equal(tf.math.floormod(num[0], 2), 0), lambda: 1,
          lambda: -tf.cast(num[0], tf.int32))

    resource, out_types, output_tmpl = setup_basic(
        use_nested_map=False, bucket_fn=bucket_fn)

    def _get_batch():
      return run_basic(
          resource,
          use_nested_map=False,
          out_types=out_types,
          output_tmpl=output_tmpl)

    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    record_seen = set()
    for i in range(100):
      ans_input_batch = _get_batch()
      for s in ans_input_batch.record:
        record_seen.add(s.numpy())
    for i in range(100):
      if i % 2 == 0:
        self.assertIn(('%08d' % i).encode('utf-8'), record_seen)
      else:
        self.assertNotIn(('%08d' % i).encode('utf-8'), record_seen)

  def testPadding(self, use_tf_func):
    # Generate a test file with 49 records of different lengths (1..49).
    tmp = os.path.join(tf.test.get_temp_dir(), 'basic')
    with tf.python_io.TFRecordWriter(tmp) as w:
      for n in range(1, 50):
        w.write(pickle.dumps(np.full([n, 3, 3], n, np.int32)))

    # A record processor written in TF graph.
    def _process(record):
      num = tf.py_func(pickle.loads, [record], tf.int32)
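      # Bucket by sequence length, i.e. the leading dimension of the record.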
      bucket_key = tf.shape(num)[0]
      return [num, tf.transpose(num, [1, 0, 2])], bucket_key

    # Samples random records from the data files and processes them
    # to generate batches.
    resource, out_types, output_tmpl = get_test_input(
        tmp,
        bucket_upper_bound=[10],
        processor=_process,
        dynamic_padding_dimensions=[0, 1],
        dynamic_padding_constants=[0] * 2)
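    # Records in a batch have different leading lengths, so dimension 0 of the
    # first output and dimension 1 of the second (transposed) output are
    # dynamically padded with 0 up to the longest record in the batch.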

    def _get_batch():
      return generic_input.GenericInputV2GetNext(resource, out_types,
                                                 output_tmpl)

    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    for _ in range(10):
      (vals, transposed_vals), _ = _get_batch()
      self.assertEqual(vals.shape[0], 8)
      self.assertEqual(vals.shape[2], 3)
      self.assertEqual(vals.shape[3], 3)
      largest = np.amax(vals)
      self.assertLessEqual(largest, 10)
      self.assertEqual(vals.shape[1], largest)
      for j in range(8):
        n = vals[j, 0, 0, 0]
        self.assertTrue(np.all(vals[j, :n] == n))
        self.assertTrue(np.all(vals[j, n:] == 0))
      self.assertAllEqual(vals, np.transpose(transposed_vals, [0, 2, 1, 3]))

    # Separate snippet: constructor of a Lingvo layer that wraps a TF-Hub
    # KerasLayer (from a different class than the tests above).
    def __init__(self, params):
        super().__init__(params)
        p = self.params
        for required_param_name in ('name', 'module_path'):
            if not p.Get(required_param_name):
                raise ValueError(f'Must set {required_param_name} param.')

        with tf.variable_scope(p.name):
            # NB: `trainable` merely controls whether the model *can* be run in
            # training mode.
            self._module = hub.KerasLayer(p.module_path, trainable=True)

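        # `_WrapNonLingvoVars` presumably registers the hub module's externally
        # created variables with this Lingvo layer so they are tracked like
        # native Lingvo variables.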
        _WrapNonLingvoVars(
            self,
            variables=self._module.variables,
            trainable_variables=self._module.trainable_variables)

        # Functionalize the module's __call__ so train-mode update ops run eagerly.
        self._hub_module_fn = tf.function(
            lambda images, training: self._module(images, training=training))

  def testTfData(self, use_tf_func):
    """Checks that GenericInput can be invoked from a tf.data.Dataset."""

    resource, out_types, output_tmpl = setup_basic(use_nested_map=True)

    def _get_batch():
      return run_basic(
          resource,
          use_nested_map=True,  # NOTE: matches setup_basic(use_nested_map=True) above.
          out_types=out_types,
          output_tmpl=output_tmpl)

    if use_tf_func:
      _get_batch = tf.function(autograph=False)(_get_batch)  # pylint: disable=invalid-name

    # Trick to create dataset from tensor coming from custom op.
    dummy_dataset = tf.data.Dataset.from_tensors(0).repeat()
    dataset = dummy_dataset.map(lambda _: _get_batch())

    it = iter(dataset)
    for _ in range(10):  # Read 10 batches.
      print(it.get_next())