示例#1
0
文件: test_text.py 项目: deeprtc/io
def test_text_input():
    """test_text_input

    Graph-mode (tf.compat.v1.Session) variant: read lorem.txt, its gzipped
    copy and lorem.txt again through a single TextDataset with batch=2, and
    check every batch against the lines read directly from disk. The final
    batch may hold a single leftover line; after it the iterator must raise
    OutOfRangeError.
    """
    text_filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 "test_text", "lorem.txt")
    with open(text_filename, 'rb') as f:
        lines = [line.strip() for line in f]
    text_filename = "file://" + text_filename

    gzip_text_filename = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_text",
        "lorem.txt.gz")
    gzip_text_filename = "file://" + gzip_text_filename

    lines = lines * 3
    filenames = [text_filename, gzip_text_filename, text_filename]
    dataset = text_io.TextDataset(filenames, batch=2)
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()
    with tf.compat.v1.Session() as sess:
        sess.run(init_op)
        # Walk the expected lines two at a time. The second slot of a batch
        # is only checked when that line exists, so this works for both odd
        # and even line counts (and for an empty file) instead of assuming
        # exactly one trailing line.
        for i in range(0, len(lines), 2):
            v = sess.run(get_next)
            assert lines[i] == v[0]
            if i + 1 < len(lines):
                assert lines[i + 1] == v[1]
        with pytest.raises(errors.OutOfRangeError):
            sess.run(get_next)
示例#2
0
def test_re2_extract():
    """test_re2_extract

    Run text_io.re2_full_match over every line of lorem.txt and compare the
    match flag and both capture groups with Python's re.fullmatch on the same
    line: matching lines must yield "ipsum" and "dolor", non-matching lines
    must yield empty groups. Every line must be produced exactly once.
    """
    filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            "test_text", "lorem.txt")
    with open(filename, 'rb') as f:
        lines = [line.strip() for line in f]
    filename = "file://" + filename

    dataset = text_io.TextDataset(filename).map(
        lambda x: text_io.re2_full_match(x, ".+(ipsum).+(dolor).+")).apply(
            tf.data.experimental.unbatch())
    # Reference matcher, compiled once outside the loop. The file lines are
    # bytes, so the pattern is bytes too. fullmatch is used to mirror the
    # whole-string semantics that re2_full_match's name implies.
    expected_pattern = re.compile(".+(ipsum).+(dolor).+".encode())
    i = 0
    for r, g in dataset:
        if expected_pattern.fullmatch(lines[i]):
            assert r.numpy()
            assert g[0].numpy().decode() == "ipsum"
            assert g[1].numpy().decode() == "dolor"
        else:
            assert not r.numpy()
            assert g[0].numpy().decode() == ""
            assert g[1].numpy().decode() == ""
        i += 1
    assert i == len(lines)
示例#3
0
def test_text_input():
    """Stream lorem.txt, its gzipped copy and lorem.txt again in pairs.

    A single TextDataset with batch=2 spans all three files; every yielded
    pair must match consecutive lines read directly from disk, and the last
    batch may contain one leftover line.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    plain_path = os.path.join(here, "test_text", "lorem.txt")
    with open(plain_path, 'rb') as handle:
        expected = [raw.strip() for raw in handle]
    plain_url = "file://" + plain_path
    gzip_url = "file://" + os.path.join(here, "test_text", "lorem.txt.gz")

    expected = expected * 3
    total = len(expected)
    paired = text_io.TextDataset([plain_url, gzip_url, plain_url],
                                 batch=2)
    pos = 0
    for element in paired:
        pair = element.numpy()
        assert expected[pos] == pair[0]
        pos += 1
        if pos < total:
            assert expected[pos] == pair[1]
            pos += 1
    assert pos == total
示例#4
0
def test_text_input():
    """test_text_input

    Read lorem.txt, its gzipped copy and lorem.txt again through a single
    TextDataset with batch=2 and check the batches against the lines read
    directly from disk. Then check core_io.rebatch in its default mode
    (every line preserved), "drop" mode and "pad" mode.
    """
    text_filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 "test_text", "lorem.txt")
    with open(text_filename, 'rb') as f:
        lines = [line.strip() for line in f]
    text_filename = "file://" + text_filename

    gzip_text_filename = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_text",
        "lorem.txt.gz")
    gzip_text_filename = "file://" + gzip_text_filename

    lines = lines * 3
    filenames = [text_filename, gzip_text_filename, text_filename]
    text_dataset = text_io.TextDataset(filenames, batch=2)
    i = 0
    for v in text_dataset:
        values = v.numpy()
        assert lines[i] == values[0]
        i += 1
        # The final batch may hold a single leftover line.
        if i < len(lines):
            assert lines[i] == values[1]
            i += 1
    assert i == len(lines)

    for batch in [1, 2, 3, 4, 5]:
        rebatch_dataset = text_dataset.apply(core_io.rebatch(batch))
        i = 0
        for v in rebatch_dataset:
            for vv in v.numpy():
                assert lines[i] == vv
                i += 1
        assert i == len(lines)

    # Expected element counts for rebatch(5, ...) are derived from the
    # fixture size instead of being hard-coded (previously 145 / 150 for a
    # 147-line input), so the test survives a change to lorem.txt.
    batch = 5
    rebatch_dataset = text_dataset.apply(core_io.rebatch(batch, "drop"))
    i = 0
    for v in rebatch_dataset:
        for vv in v.numpy():
            assert lines[i] == vv
            i += 1
    # "drop" keeps only complete batches: round down to a multiple of batch.
    assert i == len(lines) // batch * batch

    rebatch_dataset = text_dataset.apply(core_io.rebatch(batch, "pad"))
    i = 0
    for v in rebatch_dataset:
        for vv in v.numpy():
            if i < len(lines):
                assert lines[i] == vv
            else:
                assert vv.decode() == ""
            i += 1
    # "pad" fills the last partial batch with empty strings: round up to a
    # multiple of batch (ceil division via negated floor division).
    assert i == -(-len(lines) // batch) * batch
示例#5
0
def test_text_input():
  """test_text_input

  Read lorem.txt through TextDataset, regroup it into pairs via
  unbatch().batch(2), and check the pairs against the lines read directly
  from disk. Then check core_io.rebatch in its default mode (every line
  preserved), "drop" mode and "pad" mode.
  """
  text_filename = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt")
  with open(text_filename, 'rb') as f:
    lines = [line.strip() for line in f]
  text_filename = "file://" + text_filename

  text_dataset = text_io.TextDataset(text_filename).apply(
      tf.data.experimental.unbatch()).batch(2)
  i = 0
  for v in text_dataset:
    values = v.numpy()
    assert lines[i] == values[0]
    i += 1
    # The final batch may hold a single leftover line.
    if i < len(lines):
      assert lines[i] == values[1]
      i += 1
  assert i == len(lines)

  for batch in [1, 2, 3, 4, 5]:
    rebatch_dataset = text_dataset.apply(core_io.rebatch(batch))
    i = 0
    for v in rebatch_dataset:
      for vv in v.numpy():
        assert lines[i] == vv
        i += 1
    assert i == len(lines)

  # Expected element counts for rebatch(5, ...) are derived from the fixture
  # size instead of being hard-coded (previously 45 / 50 for a 49-line
  # input), so the test survives a change to lorem.txt.
  batch = 5
  rebatch_dataset = text_dataset.apply(core_io.rebatch(batch, "drop"))
  i = 0
  for v in rebatch_dataset:
    for vv in v.numpy():
      assert lines[i] == vv
      i += 1
  # "drop" keeps only complete batches: round down to a multiple of batch.
  assert i == len(lines) // batch * batch

  rebatch_dataset = text_dataset.apply(core_io.rebatch(batch, "pad"))
  i = 0
  for v in rebatch_dataset:
    for vv in v.numpy():
      if i < len(lines):
        assert lines[i] == vv
      else:
        assert vv.decode() == ""
      i += 1
  # "pad" fills the last partial batch with empty strings: round up to a
  # multiple of batch (ceil division via negated floor division).
  assert i == -(-len(lines) // batch) * batch
示例#6
0
def test_text_input():
    """Stream lorem.txt as individual lines regrouped into pairs.

    The TextDataset output is unbatched and re-batched in twos; each pair
    must match consecutive lines read straight from the file, and the last
    batch may contain one leftover line.
    """
    source_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_text", "lorem.txt")
    with open(source_path, 'rb') as handle:
        expected = [raw.strip() for raw in handle]
    source_url = "file://" + source_path

    paired = text_io.TextDataset(source_url).apply(
        tf.data.experimental.unbatch()).batch(2)
    pos = 0
    for element in paired:
        pair = element.numpy()
        assert expected[pos] == pair[0]
        pos += 1
        if pos < len(expected):
            assert expected[pos] == pair[1]
            pos += 1
    assert pos == len(expected)
示例#7
0
def test_text_output():
    """test_text_output

    Take the first five lines of lorem.txt through TextDataset, write them
    with text_io.save_text to a temporary file, and check the file contains
    exactly those five lines. The temporary file is always removed.
    """
    text_filename = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 "test_text", "lorem.txt")
    with open(text_filename, 'rb') as f:
        lines = [line.strip() for line in f]
    text_filename = "file://" + text_filename

    # mkstemp creates and opens the file; only its path is needed, so the
    # descriptor is closed immediately.
    fd, filename = tempfile.mkstemp()
    os.close(fd)
    try:
        df = text_io.TextDataset(text_filename)
        df = df.take(5)
        text_io.save_text(df, filename)

        with open(filename, 'rb') as f:
            saved_lines = [line.strip() for line in f]
    finally:
        # mkstemp files are never deleted automatically; clean up even when
        # saving or reading back fails.
        os.remove(filename)

    assert len(saved_lines) == 5
    assert saved_lines == lines[:5]
示例#8
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for TextDataset with stdin."""

import tensorflow as tf

# Enable eager execution unless the installed TensorFlow reports a "2."
# version (presumed eager by default). The dataset loop below iterates in
# Python and calls .numpy(), which needs eager mode.
if not hasattr(tf, "version") or not tf.version.VERSION.startswith("2."):
    tf.compat.v1.enable_eager_execution()
import tensorflow_io.text as text_io  # pylint: disable=wrong-import-position

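# Usage: pipe tab-separated tshark fields (frame number, destination IP,
# IP protocol) into this script on standard input, e.g.:
#  tshark -T fields -e frame.number -e ip.dst -e ip.proto -r attack-trace.pcap | python stdin_test.py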


def f(v):
    """Split tab-separated tshark output in `v` into its three columns.

    Columns are decoded with the record defaults [0], [""], [0]
    (frame number, destination IP, IP protocol) and returned as a
    3-tuple in that order.
    """
    defaults = [[0], [""], [0]]
    columns = tf.io.decode_csv(v, defaults, field_delim="\t")
    frame_number, ip_dst, ip_proto = columns
    return frame_number, ip_dst, ip_proto


# "file://-" makes TextDataset read from standard input; every element is
# parsed into its three columns by f.
text_dataset = text_io.TextDataset("file://-").map(
    f)

for frame_number_value, ip_dst_value, ip_proto_value in text_dataset:
    # Only the destination-IP column is reported; the frame number and
    # protocol are unpacked but unused.
    print(
        ip_dst_value.numpy()
    )