示例#1
0
def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that SimpleDataSource yields every sample once per pass.

    Each CSV row is ``<image file>, <label>`` with the image path relative to
    the CSV's directory; rows whose image file is missing are skipped. The
    labels read must be ``0 .. size - 1``. Without shuffling they must come
    back in that order; with shuffling they must be a non-identity
    permutation of it.
    """
    csv_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: an empty first column (e.g. a blank
            # line) joins to ``csv_dir`` itself, which exists but is a
            # directory and cannot be opened as an image.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))
    # Fail loudly instead of passing vacuously when no fixture image exists.
    assert src_data, f"no images were loaded from {test_data_csv_png_20}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        # The test expects each fixture image's first element to equal its label.
        assert data[0][0][0] == label
        order.append(label)
    expected = list(range(size))
    if shuffle:
        # Every label exactly once, but not in the original order.
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
示例#2
0
def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that SimpleDataSource yields every sample once per pass.

    Each CSV row is ``<image file>, <label>`` with the image path relative to
    the CSV's directory; rows whose image file is missing are skipped. The
    labels read must be ``0 .. size - 1``. Without shuffling they must come
    back in that order; with shuffling they must be a non-identity
    permutation of it.
    """
    csv_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: an empty first column (e.g. a blank
            # line) joins to ``csv_dir`` itself, which exists but is a
            # directory and cannot be opened as an image.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))
    # Fail loudly instead of passing vacuously when no fixture image exists.
    assert src_data, f"no images were loaded from {test_data_csv_png_20}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        # The test expects each fixture image's first element to equal its label.
        assert data[0][0][0] == label
        order.append(label)
    expected = list(range(size))
    if shuffle:
        # Every label exactly once, but not in the original order.
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
示例#3
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle):
    """Build a data_iterator_simple over the CSV fixture and check its output.

    Each CSV row is ``<image file>, <label>`` with the image path relative to
    the CSV's directory; rows whose image file is missing are skipped, and
    each label is stored as a one-element list. The actual checks are done
    by ``check_data_iterator_result``.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: an empty first column (e.g. a blank
            # line) joins to ``csv_dir`` itself, which exists but is a
            # directory and cannot be opened as an image.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail with a clear message when no fixture image could be found.
    assert src_data, f"no images were loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle) as di:
        check_data_iterator_result(di, batch_size, shuffle, False)
示例#4
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Build a data_iterator_simple over the CSV fixture and check its output.

    Each CSV row is ``<image file>, <label>`` with the image path relative to
    the CSV's directory; rows whose image file is missing are skipped, and
    each label is stored as a one-element list. ``stop_exhausted`` is passed
    both to the iterator and to ``check_data_iterator_result``, which does
    the actual checks.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: an empty first column (e.g. a blank
            # line) joins to ``csv_dir`` itself, which exists but is a
            # directory and cannot be opened as an image.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail with a clear message when no fixture image could be found.
    assert src_data, f"no images were loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        check_data_iterator_result(
            di, batch_size, shuffle, False, stop_exhausted)
示例#5
0
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Check data_iterator_simple together with registered epoch callbacks.

    Each CSV row is ``<image file>, <label>`` with the image path relative to
    the CSV's directory; rows whose image file is missing are skipped, and
    each label is stored as a one-element list. The registered callbacks
    assert two things:
    - they run on the same thread that runs this test;
    - the epoch they receive equals ``expect_epoch[0]``.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: an empty first column (e.g. a blank
            # line) joins to ``csv_dir`` itself, which exists but is a
            # directory and cannot be opened as an image.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail with a clear message instead of a ZeroDivisionError at
    # ``batch_size // size`` below when no fixture image could be found.
    assert src_data, f"no images were loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    def end_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for end epoch check"
        current = threading.current_thread().ident
        assert current == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for begin epoch check"
        current = threading.current_thread().ident
        assert current == main_thread, "Failed for thread checking"

    size = len(src_data)
    # Read by the callbacks above through their closures.
    main_thread = threading.current_thread().ident
    # Mutable one-element list read by the callbacks and passed to
    # check_data_iterator_result (presumably advanced there in place -- confirm).
    expect_epoch = [0]
    with data_iterator_simple(test_load_func,
                              size,
                              batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        # NOTE(review): ``begin_epoch`` is registered through
        # ``register_epoch_end_callback``. When batch_size // size == 0,
        # both callbacks are also registered a second time. Kept unchanged;
        # confirm whether an epoch-begin registration was intended.
        if batch_size // size == 0:
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
        di.register_epoch_end_callback(begin_epoch)
        di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, False,
                                   stop_exhausted, expect_epoch)
示例#6
0
 def _get_data(self, _position):
     """Return the ``(image, label)`` pair for slot ``_position``.

     The slot is first mapped through ``self._indexes`` to a sample index.
     The image is then loaded from ``self.paths`` and the label read from
     ``self.labels``. An image whose leading dimension is 1 is repeated 3
     times along axis 0.
     """
     sample = self._indexes[_position]
     image = load_image(self.paths[sample])
     # Presumably a channel-first image: replicate a single channel into
     # three (grayscale -> RGB) -- TODO confirm layout returned by load_image.
     if image.shape[0] != 1:
         return image, self.labels[sample]
     return image.repeat(3, 0), self.labels[sample]