Example #1
0
  def testMismatchPosLenShapes(self, dtype, unit):
    """Checks that substr rejects position/length tensors of mismatched shape.

    Both calls are expected to raise ValueError while the op is being built;
    nothing is evaluated.
    """
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                   u"f\U0001f604urt\xea\xean",
                                                   u"f\xcd\ua09ctee\ua0e4"]],
                      [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
                                                   u"se\U00010299enteen",
                                                   u"ei\U0001e920h\x86een"]]],
    }[unit]
    position = np.array([[1, 2, 3]], dtype)
    length = np.array([2, 3, 4], dtype)
    # Should fail: position/length have different rank. `unit` is forwarded so
    # the UTF8_CHAR parameterization really builds the op with that unit
    # instead of relying on substr's default `unit`.
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length, unit=unit)

    position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
    length = np.array([[2, 3, 4]], dtype)
    # Should fail: position/length have different dimensionality.
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length, unit=unit)
  def _testBroadcast(self, dtype):
    """Checks substr broadcasting between strings and position/length.

    Three cases are covered:
      * a pos/len vector broadcast onto a matrix of strings,
      * a vector of strings broadcast onto pos/len matrices,
      * a scalar string broadcast onto pos/len vectors.
    """
    cases = [
        # Broadcast pos/len onto input string.
        ([[b"ten", b"eleven", b"twelve"],
          [b"thirteen", b"fourteen", b"fifteen"],
          [b"sixteen", b"seventeen", b"eighteen"],
          [b"nineteen", b"twenty", b"twentyone"]],
         [1, 2, 3],
         [1, 2, 3],
         [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
          [b"i", b"ve", b"hte"], [b"i", b"en", b"nty"]]),
        # Broadcast input string onto pos/len.
        ([b"thirteen", b"fourteen", b"fifteen"],
         [[1, 2, 3], [3, 2, 1], [5, 5, 5]],
         [[3, 2, 1], [1, 2, 3], [2, 2, 2]],
         [[b"hir", b"ur", b"t"], [b"r", b"ur", b"ift"],
          [b"ee", b"ee", b"en"]]),
        # 1D broadcast of a scalar string.
        (b"thirteen", [1, 5, 7], [3, 2, 1], [b"hir", b"ee", b"n"]),
    ]
    for strings, pos_values, len_values, want in cases:
      substr_op = string_ops.substr(strings, np.array(pos_values, dtype),
                                    np.array(len_values, dtype))
      with self.test_session():
        self.assertAllEqual(substr_op.eval(), want)
Example #3
0
  def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
    """Element-wise positions past a string's end must fail on evaluation.

    The string at [1][2] has three characters while every other cell has
    four, so position 4 (and -4) is out of range for that one cell only.
    """
    test_string = {
        "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
                 [b"good", b"good", b"good"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"b\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]]],
    }[unit]
    position_cases = (
        # Matrix/Matrix
        [[1, 2, 3], [1, 2, 4], [1, 2, 3]],
        # Matrix/Matrix (with negative)
        [[1, 2, -3], [1, 2, -4], [1, 2, -3]],
    )
    for pos_rows in position_cases:
      position = np.array(pos_rows, dtype)
      length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
      substr_op = string_ops.substr(test_string, position, length, unit=unit)
      with self.cached_session():
        with self.assertRaises(errors_impl.InvalidArgumentError):
          substr_op.eval()
Example #4
0
    def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
        """Out-of-range matrix positions must raise InvalidArgumentError.

        Only the "bad" cell at [1][2] is three characters long; every other
        cell has four, so positions 4 and -4 overrun that cell.
        """
        good_row = [u"g\xc3\xc3d", u"g\xc3\xc3d", u"g\xc3\xc3d"]
        bad_row = [u"g\xc3\xc3d", u"g\xc3\xc3d", u"b\xc3d"]
        test_string = {
            "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
                     [b"good", b"good", b"good"]],
            "UTF8_CHAR": [[s.encode("utf-8") for s in row]
                          for row in (good_row, bad_row, good_row)],
        }[unit]

        def check_out_of_range(position):
            # Builds the op for `position` and expects evaluation to fail.
            length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
            substr_op = string_ops.substr(test_string, position, length,
                                          unit=unit)
            with self.cached_session():
                with self.assertRaises(errors_impl.InvalidArgumentError):
                    substr_op.eval()

        # Matrix/Matrix
        check_out_of_range(np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype))
        # Matrix/Matrix (with negative)
        check_out_of_range(
            np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype))
Example #5
0
 def testBadBroadcast(self, dtype, unit):
   """A 4-element pos/len cannot broadcast onto 3 columns of strings."""
   strings = [[b"ten", b"eleven", b"twelve"],
              [b"thirteen", b"fourteen", b"fifteen"],
              [b"sixteen", b"seventeen", b"eighteen"]]
   pos = np.array([1, 2, -3, 4], dtype)
   lens = np.array([1, 2, 3, 4], dtype)
   # The shape mismatch is reported while the op is built; nothing is evaluated.
   with self.assertRaises(ValueError):
     string_ops.substr(strings, pos, lens, unit=unit)
Example #6
0
  def testScalarString_EdgeCases(self, dtype, unit):
    """Exercises substr boundary cases on scalar strings.

    Cases: an empty input, a full-length slice from position 0, a full-length
    slice from a negative position, and a length that runs past the end of
    the string when starting from a negative position (the expected result
    simply stops at the end of the string).
    """
    hello = {
        "BYTE": b"Hello",
        "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
    }[unit]
    cases = [
        # Empty string
        ({"BYTE": b"", "UTF8_CHAR": u"".encode("utf-8")}[unit], 0, 3, b""),
        # Full string
        (hello, 0, 5, hello),
        # Full string (Negative)
        (hello, -5, 5, hello),
        # Length is larger in magnitude than a negative position
        (hello, -4, 5, {
            "BYTE": b"ello",
            "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
        }[unit]),
    ]
    for text, pos, size, want in cases:
      substr_op = string_ops.substr(text, np.array(pos, dtype),
                                    np.array(size, dtype), unit=unit)
      with self.cached_session():
        self.assertAllEqual(substr_op.eval(), want)
  def _testMismatchPosLenShapes(self, dtype):
    """Checks that substr rejects position/length tensors of mismatched shape.

    Both calls are expected to raise ValueError while the op is being built,
    so no op result is kept.

    Args:
      dtype: dtype for the position and length arrays (presumably an integer
        type).
    """
    test_string = [[b"ten", b"eleven", b"twelve"],
                   [b"thirteen", b"fourteen", b"fifteen"],
                   [b"sixteen", b"seventeen", b"eighteen"]]
    position = np.array([[1, 2, 3]], dtype)
    length = np.array([2, 3, 4], dtype)
    # Should fail: position/length have different rank.
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length)

    position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
    length = np.array([[2, 3, 4]], dtype)
    # Should fail: position/length have different dimensionality.
    with self.assertRaises(ValueError):
      string_ops.substr(test_string, position, length)
Example #8
0
 def testElementWisePosLen(self, dtype, unit):
   """Each string is sliced with its own position/length (same-shape inputs).

   Positions mix positive and negative offsets.
   """
   strings_by_unit = {
       "BYTE": [[b"ten", b"eleven", b"twelve"],
                [b"thirteen", b"fourteen", b"fifteen"],
                [b"sixteen", b"seventeen", b"eighteen"]],
       "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                  u"\xc6\u053c\u025bv\u025bn",
                                                  u"tw\u0c1dlv\u025b"]],
                     [x.encode("utf-8") for x in [u"He\xc3\xc3o",
                                                  u"W\U0001f604rld",
                                                  u"d\xfcd\xea"]],
                     [x.encode("utf-8") for x in [u"sixt\xea\xean",
                                                  u"se\U00010299enteen",
                                                  u"ei\U0001e920h\x86een"]]],
   }
   wanted_by_unit = {
       "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
                [b"xteen", b"vente", b"hteen"]],
       "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
                                                  u"\u025bv",
                                                  u"lv\u025b"]],
                     [x.encode("utf-8") for x in [u"e\xc3\xc3o",
                                                  u"rld",
                                                  u"d\xfc"]],
                     [x.encode("utf-8") for x in [u"xt\xea\xean",
                                                  u"\U00010299ente",
                                                  u"h\x86een"]]],
   }
   pos = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
   lens = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
   result_op = string_ops.substr(strings_by_unit[unit], pos, lens, unit=unit)
   with self.cached_session():
     self.assertAllEqual(result_op.eval(), wanted_by_unit[unit])
  def _read_test(self, batch_size, num_epochs, file_index=None,
                 num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
    """Reads the test TFRecord files unshuffled and verifies every record.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset`.
      num_epochs: Number of epochs to read.
      file_index: If given, read only `self.test_filenames[file_index]`;
        otherwise read all of `self.test_filenames`.
      num_parallel_reads: Number of parallel reads; also passed to the
        verifier as the interleave cycle length.
      drop_final_batch: Passed through to the dataset and the verifier.
      parser_fn: If true, each record is parsed with a substr that keeps up
        to 999 bytes starting at offset 1 (i.e. drops the first byte).
    """
    file_pattern = (self.test_filenames if file_index is None else
                    self.test_filenames[file_index])
    fn = (lambda x: string_ops.substr(x, 1, 999)) if parser_fn else None

    dataset = readers.make_tf_record_dataset(
        file_pattern=file_pattern,
        num_epochs=num_epochs,
        batch_size=batch_size,
        parser_fn=fn,
        num_parallel_reads=num_parallel_reads,
        drop_final_batch=drop_final_batch,
        shuffle=False)
    outputs = self.getNext(dataset)
    self._verify_records(
        outputs,
        batch_size,
        file_index,
        num_epochs=num_epochs,
        interleave_cycle_length=num_parallel_reads,
        drop_final_batch=drop_final_batch,
        use_parser_fn=parser_fn)
    # Once the expected records are verified, the iterator must be exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(outputs())
  def _read_test(self, batch_size, num_epochs, file_index=None,
                 num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
    """Graph-mode read of the test TFRecord files, verifying every record.

    Args:
      batch_size: Batch size passed to `make_tf_record_dataset`.
      num_epochs: Number of epochs to read.
      file_index: If given, read only `self.test_filenames[file_index]`;
        otherwise read all of `self.test_filenames`.
      num_parallel_reads: Number of parallel reads; also passed to the
        verifier as the interleave cycle length.
      drop_final_batch: Passed through to the dataset and the verifier.
      parser_fn: If true, each record is parsed with a substr that keeps up
        to 999 bytes starting at offset 1 (i.e. drops the first byte).
    """
    file_pattern = (self.test_filenames if file_index is None else
                    self.test_filenames[file_index])
    fn = None
    if parser_fn:
      fn = lambda x: string_ops.substr(x, 1, 999)

    with ops.Graph().as_default() as g:
      with self.session(graph=g) as sess:
        dataset = readers.make_tf_record_dataset(
            file_pattern=file_pattern,
            num_epochs=num_epochs,
            batch_size=batch_size,
            parser_fn=fn,
            num_parallel_reads=num_parallel_reads,
            drop_final_batch=drop_final_batch,
            shuffle=False)
        outputs = dataset.make_one_shot_iterator().get_next()
        self._verify_records(
            sess,
            outputs,
            batch_size,
            file_index,
            num_epochs=num_epochs,
            interleave_cycle_length=num_parallel_reads,
            drop_final_batch=drop_final_batch,
            use_parser_fn=parser_fn)
        # After the expected records are verified, the iterator must be done.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(outputs)
Example #11
0
  def _testOutOfRangeError(self, dtype):
    """Checks that out-of-range positions raise InvalidArgumentError on eval.

    Each case builds a substr op and expects evaluation to fail; the
    evaluated value itself is never used.

    Args:
      dtype: dtype for the position and length arrays (presumably an integer
        type).
    """
    # NOTE(review): several cases use a position equal to len(b"bad") (3).
    # They only error if the kernel under test treats pos == length as out of
    # range -- confirm against the kernel version these tests target.
    cases = [
        # Scalar/Scalar
        (b"Hello", 7, 3),
        # Vector/Scalar
        ([b"good", b"good", b"bad", b"good"], 3, 1),
        # Negative pos
        (b"Hello", -1, 3),
        # Matrix/Matrix
        ([[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
          [b"good", b"good", b"good"]],
         [[1, 2, 3], [1, 2, 3], [1, 2, 3]],
         [[3, 2, 1], [1, 2, 3], [2, 2, 2]]),
        # Broadcast
        ([[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
         [1, 2, 3],
         [1, 2, 3]),
    ]
    for test_string, pos_values, len_values in cases:
      position = np.array(pos_values, dtype)
      length = np.array(len_values, dtype)
      substr_op = string_ops.substr(test_string, position, length)
      with self.test_session():
        with self.assertRaises(errors_impl.InvalidArgumentError):
          substr_op.eval()
Example #12
0
  def testMatrixStrings(self, dtype, unit):
    """A scalar position/length is applied to every string in a matrix.

    Checked once with a positive position and once with a negative one.
    """
    test_string = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"He\xc3\xc3o",
                                                   u"W\U0001f604rld",
                                                   u"d\xfcd\xea"]]],
    }[unit]
    checks = [
        (1, 4, {
            "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
                     [b"ixte", b"even", b"ight"]],
            "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
                                                       u"\u053c\u025bv\u025b",
                                                       u"w\u0c1dlv"]],
                          [x.encode("utf-8") for x in [u"e\xc3\xc3o",
                                                       u"\U0001f604rld",
                                                       u"\xfcd\xea"]]],
        }[unit]),
        (-3, 2, {
            "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
                     [b"ee", b"ee", b"ee"]],
            "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
                                                       u"v\u025b", u"lv"]],
                          [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
                                                       u"\xfcd"]]],
        }[unit]),
    ]
    for pos, size, expected_value in checks:
      substr_op = string_ops.substr(test_string, np.array(pos, dtype),
                                    np.array(size, dtype), unit=unit)
      with self.cached_session():
        self.assertAllEqual(substr_op.eval(), expected_value)
    def _testScalarString(self, dtype):
        """Slicing b"Hello" from position 1 for 3 bytes yields b"ell"."""
        substr_op = string_ops.substr(b"Hello", np.array(1, dtype),
                                      np.array(3, dtype))
        with self.test_session():
            self.assertAllEqual(substr_op.eval(), b"ell")
Example #14
0
 def _testBadBroadcast(self, dtype):
   """Checks that a 4-element pos/len cannot broadcast onto 3 string columns.

   Args:
     dtype: dtype for the position and length arrays (presumably an integer
       type).
   """
   test_string = [[b"ten", b"eleven", b"twelve"],
                  [b"thirteen", b"fourteen", b"fifteen"],
                  [b"sixteen", b"seventeen", b"eighteen"]]
   position = np.array([1, 2, 3, 4], dtype)
   length = np.array([1, 2, 3, 4], dtype)
   # The shape error is raised while the op is built; no result is produced,
   # so there is nothing to compare against.
   with self.assertRaises(ValueError):
     string_ops.substr(test_string, position, length)
 def _testBadBroadcast(self, dtype):
   """Checks that a 4-element pos/len cannot broadcast onto 3 string columns.

   Args:
     dtype: dtype for the position and length arrays (presumably an integer
       type).
   """
   test_string = [[b"ten", b"eleven", b"twelve"],
                  [b"thirteen", b"fourteen", b"fifteen"],
                  [b"sixteen", b"seventeen", b"eighteen"]]
   position = np.array([1, 2, 3, 4], dtype)
   length = np.array([1, 2, 3, 4], dtype)
   # The shape error is raised while the op is built; no result is produced,
   # so there is nothing to compare against.
   with self.assertRaises(ValueError):
     string_ops.substr(test_string, position, length)
Example #16
0
  def _testVectorStrings(self, dtype):
    """A scalar position/length is applied to each string of a vector."""
    words = [b"Hello", b"World"]
    start = np.array(1, dtype)
    count = np.array(3, dtype)
    sliced = string_ops.substr(words, start, count)
    with self.test_session():
      self.assertAllEqual(sliced.eval(), [b"ell", b"orl"])
Example #17
0
  def _testScalarString(self, dtype):
    """Checks substr on a scalar byte string (b"Hello"[1:4] == b"ell")."""
    pos, size = np.array(1, dtype), np.array(3, dtype)
    substr_op = string_ops.substr(b"Hello", pos, size)
    with self.test_session():
      actual = substr_op.eval()
      self.assertAllEqual(actual, b"ell")
 def testForwarding(self, op):
   """Runs the regex `op` on a tensor that only it consumes."""
   with self.test_session():
     # The substr result below has exactly one consumer -- the regex op --
     # which exercises code paths optimized for that case (e.g. forwarding).
     inp = string_ops.substr(
         constant_op.constant(["AbCdEfG", "HiJkLmN"], dtypes.string),
         pos=0, len=5)
     replaced = op(inp, "\\p{Ll}", ".")
     stripped = replaced.eval()
     self.assertAllEqual([b"A.C.E", b"H.J.L"], stripped)
Example #19
0
 def testOutOfRangeError_Scalar(self, dtype, pos, unit):
   """A scalar `pos` outside a scalar string must fail on evaluation.

   `pos` presumably comes from a parameterization with out-of-range values.
   """
   strings = {
       "BYTE": b"Hello",
       "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
   }
   # Scalar/Scalar
   substr_op = string_ops.substr(strings[unit], np.array(pos, dtype),
                                 np.array(3, dtype), unit=unit)
   with self.cached_session():
     with self.assertRaises(errors_impl.InvalidArgumentError):
       substr_op.eval()
def decode_image(contents, channels=None, name=None):
  """Convenience function for `decode_gif`, `decode_jpeg`, and `decode_png`.

  Detects whether an image is a GIF, JPEG, or PNG from its leading magic
  bytes, and performs the appropriate operation to convert the input bytes
  `string` into a `Tensor` of type `uint8`.

  Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
  opposed to `decode_jpeg` and `decode_png`, which return 3-D arrays
  `[height, width, num_channels]`. Make sure to take this into account when
  constructing your graph if you are intermixing GIF files with JPEG and/or PNG
  files.

  Args:
    contents: 0-D `string`. The encoded image bytes.
    channels: An optional `int`. Defaults to `0`. Number of color channels for
      the decoded image. Must be one of `None`, `0`, `1` or `3`; `1` is
      rejected (at run time) for GIF images.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with type `uint8` with shape `[height, width, num_channels]` for
      JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF
      images.

  Raises:
    ValueError: If `channels` is not one of `None`, `0`, `1`, `3`.
  """
  with ops.name_scope(name, 'decode_image'):
    if channels not in (None, 0, 1, 3):
      raise ValueError('channels must be in (None, 0, 1, 3)')
    # First 4 bytes, used for the GIF ("GIF8") and PNG signatures.
    substr = string_ops.substr(contents, 0, 4)
    # A JPEG stream starts with the SOI marker FF D8 followed by the first
    # segment marker's FF. The byte after that (E0 for JFIF, E1 for EXIF,
    # DB, EE, ...) varies between valid files, so only 3 bytes are compared.
    jpeg_prefix = string_ops.substr(contents, 0, 3)

    def _gif():
      # Create assert op to check that bytes are GIF decodable
      is_gif = math_ops.equal(substr, b'\x47\x49\x46\x38', name='is_gif')
      decode_msg = 'Unable to decode bytes as JPEG, PNG, or GIF'
      assert_decode = control_flow_ops.Assert(is_gif, [decode_msg])
      # Create assert to make sure that channels is not set to 1
      # Already checked above that channels is in (None, 0, 1, 3)
      gif_channels = 0 if channels is None else channels
      good_channels = math_ops.not_equal(gif_channels, 1, name='check_channels')
      channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
      assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
      with ops.control_dependencies([assert_decode, assert_channels]):
        return gen_image_ops.decode_gif(contents)

    def _png():
      return gen_image_ops.decode_png(contents, channels)

    def check_png():
      # Not a JPEG: decode as PNG if the PNG signature matches, else try GIF.
      is_png = math_ops.equal(substr, b'\211PNG', name='is_png')
      return control_flow_ops.cond(is_png, _png, _gif, name='cond_png')

    def _jpeg():
      return gen_image_ops.decode_jpeg(contents, channels)

    is_jpeg = math_ops.equal(jpeg_prefix, b'\xff\xd8\xff', name='is_jpeg')
    return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg')
Example #21
0
 def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
   """A scalar `pos` that overruns a vector element must fail on evaluation.

   `pos` presumably comes from a parameterization with out-of-range values.
   """
   utf8_words = [u"g\xc3\xc3d", u"b\xc3d", u"g\xc3\xc3d"]
   candidates = {
       "BYTE": [b"good", b"good", b"bad", b"good"],
       "UTF8_CHAR": [w.encode("utf-8") for w in utf8_words],
   }
   test_string = candidates[unit]
   # Vector/Scalar
   substr_op = string_ops.substr(test_string, np.array(pos, dtype),
                                 np.array(1, dtype), unit=unit)
   with self.cached_session():
     with self.assertRaises(errors_impl.InvalidArgumentError):
       substr_op.eval()
Example #22
0
  def _testMatrixStrings(self, dtype):
    """A scalar position/length slices every string of a 3x3 matrix."""
    matrix = [[b"ten", b"eleven", b"twelve"],
              [b"thirteen", b"fourteen", b"fifteen"],
              [b"sixteen", b"seventeen", b"eighteen"]]
    want = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
            [b"ixte", b"even", b"ight"]]
    sliced = string_ops.substr(matrix, np.array(1, dtype), np.array(4, dtype))
    with self.test_session():
      self.assertAllEqual(sliced.eval(), want)
Example #23
0
  def _testElementWisePosLen(self, dtype):
    """Position and length matrices are applied element-wise to the strings."""
    rows = [[b"ten", b"eleven", b"twelve"],
            [b"thirteen", b"fourteen", b"fifteen"],
            [b"sixteen", b"seventeen", b"eighteen"]]
    starts = np.array([[1, 2, 3]] * 3, dtype)
    sizes = np.array([[2, 3, 4], [4, 3, 2], [5, 5, 5]], dtype)
    expected = [[b"en", b"eve", b"lve"], [b"hirt", b"urt", b"te"],
                [b"ixtee", b"vente", b"hteen"]]
    sliced = string_ops.substr(rows, starts, sizes)
    with self.test_session():
      self.assertAllEqual(sliced.eval(), expected)
  def test_forwarding(self):
    """Runs unicode_transcode on an input that only it consumes."""
    with self.cached_session():
      # The first five bytes of each string feed only the transcode op, which
      # exercises code paths optimized for that case (e.g. forwarding).
      prefixes = string_ops.substr(
          constant_op.constant([b"AbCdEfG", b"HiJkLmN"], dtypes.string),
          pos=0,
          len=5)
      result = string_ops.unicode_transcode(
          prefixes, input_encoding="UTF-8", output_encoding="UTF-8")

      self.assertAllEqual([b"AbCdE", b"HiJkL"], result)
  def testOutOfRangeError_Broadcast(self, dtype, unit):
    """Broadcast positions that overrun the short "bad" string must fail.

    The string at [1][2] has three characters; positions 4 and -4 in the last
    column are out of range for it.
    """
    test_string = {
        "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"g\xc3\xc3d"]],
                      [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
                                                   u"b\xc3d"]]],
    }[unit]
    for pos_values in ([1, 2, 4],      # Broadcast
                       [-1, -2, -4]):  # Broadcast (with negative)
      position = np.array(pos_values, dtype)
      length = np.array([1, 2, 3], dtype)
      substr_op = string_ops.substr(test_string, position, length, unit=unit)
      with self.cached_session():
        with self.assertRaises(errors_impl.InvalidArgumentError):
          self.evaluate(substr_op)
Example #26
0
 def testScalarString(self, dtype, pos, unit):
   """Takes 3 units starting at `pos` from a scalar string."""
   source = {
       "BYTE": b"Hello",
       "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
   }[unit]
   expected = {
       "BYTE": b"ell",
       "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
   }[unit]
   sliced = string_ops.substr(source, np.array(pos, dtype),
                              np.array(3, dtype), unit=unit)
   with self.cached_session():
     self.assertAllEqual(sliced.eval(), expected)
Example #27
0
 def testVectorStrings(self, dtype, pos, unit):
   """Takes 3 units starting at `pos` from each string of a vector."""
   utf8_inputs = [u"H\xc3llo", u"W\U0001f604rld"]
   utf8_outputs = [u"\xc3ll", u"\U0001f604rl"]
   source = {
       "BYTE": [b"Hello", b"World"],
       "UTF8_CHAR": [s.encode("utf-8") for s in utf8_inputs],
   }[unit]
   expected = {
       "BYTE": [b"ell", b"orl"],
       "UTF8_CHAR": [s.encode("utf-8") for s in utf8_outputs],
   }[unit]
   sliced = string_ops.substr(source, np.array(pos, dtype),
                              np.array(3, dtype), unit=unit)
   with self.cached_session():
     self.assertAllEqual(sliced.eval(), expected)
Example #28
0
    def testStateBasedSentenceBreaker(self, test_description, doc,
                                      expected_fragment_text):
        """Checks fragment text and that the offsets slice `doc` back to it.

        For each document, `doc[start:end]` (via substr with length
        `end - start`) must reproduce the corresponding fragment text.
        """
        # Named `doc_tensor` rather than `input` to avoid shadowing the builtin.
        doc_tensor = constant_op.constant(doc)
        sentence_breaker = (
            state_based_sentence_breaker_op.StateBasedSentenceBreaker())
        fragment_text, fragment_starts, fragment_ends = (
            sentence_breaker.break_sentences_with_offsets(doc_tensor))

        texts, starts, ends = self.evaluate(
            (fragment_text, fragment_starts, fragment_ends))
        self.assertAllEqual(expected_fragment_text, fragment_text)
        for d, text, start, end in zip(doc, texts.to_list(), starts.to_list(),
                                       ends.to_list()):
            # broadcast d to match start/end's shape
            start = constant_op.constant(start)
            end = constant_op.constant(end)
            d = array_ops.broadcast_to(d, start.shape)
            self.assertAllEqual(string_ops.substr(d, start, end - start), text)
 def _make_csv_dataset(self,
                       filenames,
                       defaults,
                       label_key=LABEL,
                       batch_size=1,
                       num_epochs=1,
                       shuffle=False,
                       shuffle_seed=None):
   """Builds a CSV dataset over `filenames` using this test's columns.

   Lines whose first character is "#" are dropped by the filter. `skip=1` is
   passed through to `make_csv_dataset` (presumably skipping one header line
   per file -- confirm against the readers documentation).
   """

   def _not_a_comment(line):
     first_char = string_ops.substr(line, 0, 1)
     return math_ops.not_equal(first_char, "#")

   return readers.make_csv_dataset(
       filenames,
       column_keys=self.COLUMNS,
       column_defaults=defaults,
       label_key=label_key,
       batch_size=batch_size,
       num_epochs=num_epochs,
       shuffle=shuffle,
       shuffle_seed=shuffle_seed,
       skip=1,
       filter_fn=_not_a_comment,
   )
 def _make_csv_dataset(self,
                       filenames,
                       defaults,
                       label_key=LABEL,
                       batch_size=1,
                       num_epochs=1,
                       shuffle=False,
                       shuffle_seed=None):
   """Creates a CSV dataset over `filenames` that filters out "#" lines.

   `skip=1` is forwarded to `make_csv_dataset` (presumably one header line
   per file -- confirm against the readers documentation).
   """
   csv_options = {
       "column_keys": self.COLUMNS,
       "column_defaults": defaults,
       "label_key": label_key,
       "batch_size": batch_size,
       "num_epochs": num_epochs,
       "shuffle": shuffle,
       "shuffle_seed": shuffle_seed,
       "skip": 1,
       # Keep only lines whose first character is not "#".
       "filter_fn": lambda line: math_ops.not_equal(
           string_ops.substr(line, 0, 1), "#"),
   }
   return readers.make_csv_dataset(filenames, **csv_options)
Example #31
0
    def testSingleStringHighRankFails(self, dtype, unit, rank):
        """Substr on a single string wrapped to `rank` dims must be rejected.

        The op is expected to raise UnimplementedError when evaluated.
        """
        test_string = {
            "BYTE": [b"abcdefghijklmnopqrstuvwxyz"],
            "UTF8_CHAR": [
                u"\U0001d229\U0001d227n\U0001d229\U0001d227n\U0001d229\U0001d227n"
                .encode("utf-8")
            ],
        }[unit]
        position = np.array([1, 2, 3], dtype)
        length = np.array([1, 2, 1], dtype)

        # Prepend (rank - 1) singleton axes to the 1-D array of one string.
        strings = np.array(test_string)
        strings = strings.reshape((1,) * (rank - 1) + strings.shape)

        with self.assertRaises(errors_impl.UnimplementedError):
            # substr is only supported up to rank 2
            self.evaluate(
                string_ops.substr(strings, position, length, unit=unit))
Example #32
0
 def filter_fn(line):
   """Boolean tensor: true where `line`'s first byte differs from `comment`.

   `comment` is a free variable from the enclosing scope (not visible here).
   """
   first_char = string_ops.substr(line, 0, 1)
   return math_ops.not_equal(first_char, comment)
Example #33
0
 def testInvalidUnit(self):
   """An unrecognized `unit` name ("UTF8") is rejected with ValueError."""
   with self.cached_session(), self.assertRaises(ValueError):
     string_ops.substr(b"test", 3, 1, unit="UTF8")
Example #34
0
  def testBroadcast(self, dtype, unit):
    """Checks substr broadcasting in three directions.

    1. A pos/len vector broadcast across the columns of a string matrix.
    2. A string vector broadcast across the rows of pos/len matrices.
    3. A scalar string broadcast onto pos/len vectors.
    Positions include negative offsets in every case.
    """

    def check(strings, pos_values, len_values, expected):
      # Builds the substr op for one case and compares the evaluated result.
      substr_op = string_ops.substr(strings, np.array(pos_values, dtype),
                                    np.array(len_values, dtype), unit=unit)
      with self.cached_session():
        self.assertAllEqual(substr_op.eval(), expected)

    # Broadcast pos/len onto input string
    matrix_strings = {
        "BYTE": [[b"ten", b"eleven", b"twelve"],
                 [b"thirteen", b"fourteen", b"fifteen"],
                 [b"sixteen", b"seventeen", b"eighteen"],
                 [b"nineteen", b"twenty", b"twentyone"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
                                                   u"\xc6\u053c\u025bv\u025bn",
                                                   u"tw\u0c1dlv\u025b"]],
                      [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                   u"f\U0001f604urt\xea\xean",
                                                   u"f\xcd\ua09ctee\ua0e4"]],
                      [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
                                                   u"se\U00010299enteen",
                                                   u"ei\U0001e920h\x86een"]],
                      [x.encode("utf-8") for x in [u"nineteen",
                                                   u"twenty",
                                                   u"twentyone"]]],
    }[unit]
    matrix_expected = {
        "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
                 [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
                                                   u"\u025bv", u"lv\u025b"]],
                      [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
                      [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
                      [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
    }[unit]
    check(matrix_strings, [1, -4, 3], [1, 2, 3], matrix_expected)

    # Broadcast input string onto pos/len
    vector_strings = {
        "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
                                                  u"f\U0001f604urt\xea\xean",
                                                  u"f\xcd\ua09ctee\ua0e4"]],
    }[unit]
    vector_expected = {
        "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
                 [b"ee", b"ee", b"ft"]],
        "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
                      [x.encode("utf-8") for x in [u"\xea", u"ur",
                                                   u"\xcd\ua09ct"]],
                      [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
                                                   u"\ua09ct"]]],
    }[unit]
    check(vector_strings,
          [[1, -2, 3], [-3, 2, 1], [5, 5, -5]],
          [[3, 2, 1], [1, 2, 3], [2, 2, 2]],
          vector_expected)

    # Test 1D broadcast
    scalar_string = {
        "BYTE": b"thirteen",
        "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
    }[unit]
    scalar_expected = {
        "BYTE": [b"hir", b"te", b"n"],
        "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
    }[unit]
    check(scalar_string, [1, -4, 7], [3, 2, 1], scalar_expected)
Example #35
0
 def testWrongDtype(self):
   """A float position or length must be rejected with TypeError."""
   with self.test_session():
     for pos, length in ((3.0, 1), (3, 1.0)):
       with self.assertRaises(TypeError):
         string_ops.substr(b"test", pos, length)