Example #1
def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20,
                      input_file_fmt, cache_file_fmt, shuffle, normalize,
                      num_of_threads):
    if input_file_fmt == 'csv':
        csvfilename = test_data_csv_csv_20
    else:
        csvfilename = test_data_csv_png_20

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as tmpdir:
        cc = CreateCache(csvfilename,
                         shuffle=shuffle,
                         num_of_threads=num_of_threads)
        cc.create(tmpdir, normalize=normalize)

        # get cache data source and csv file data source
        with closing(CacheDataSource(tmpdir)) as cache_source:
            csv_source = CsvDataSource(csvfilename, normalize=normalize)

            check_relative_csv_file_result(cache_file_fmt, csvfilename, tmpdir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                with open(os.path.join(tmpdir, 'order.csv'), 'r') as f:
                    csv_source._order = [int(row[1]) for row in csv.reader(f)]

            for _ in range(cache_source.size):
                cache_data = associate_variables_and_data(cache_source)
                csv_data = associate_variables_and_data(csv_source)

                for v in cache_source.variables:
                    assert_allclose(cache_data[v], csv_data[v])
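
The test above round-trips a CSV dataset through a cache directory and compares every sample. The same flow can be exercised directly; a minimal sketch, assuming the usual nnabla import locations and a hypothetical dataset.csv:

from contextlib import closing

# Import paths are assumptions about the nnabla utils layout.
from nnabla.utils.create_cache import CreateCache
from nnabla.utils.data_source_implements import CacheDataSource

cc = CreateCache('dataset.csv', shuffle=False, num_of_threads=2)  # hypothetical CSV
cc.create('cache_dir', normalize=True)  # write cache files into cache_dir

# Read the cache back and inspect it.
with closing(CacheDataSource('cache_dir')) as source:
    print(source.size, source.variables)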
Example #2
def test_data_iterator_concat_datasets(test_data_csv_png_10,
                                       test_data_csv_png_20,
                                       batch_size,
                                       shuffle,
                                       use_thread,
                                       normalize,
                                       with_memory_cache,
                                       with_file_cache,
                                       with_context,
                                       stop_exhausted):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    csvfilename_1 = test_data_csv_png_10
    csvfilename_2 = test_data_csv_png_20

    ds1 = CsvDataSource(csvfilename_1,
                        shuffle=shuffle,
                        normalize=normalize)

    ds2 = CsvDataSource(csvfilename_2,
                        shuffle=shuffle,
                        normalize=normalize)

    if with_context:
        with data_iterator_concat_datasets([ds1, ds2],
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           with_memory_cache=with_memory_cache,
                                           with_file_cache=with_file_cache,
                                           use_thread=use_thread,
                                           stop_exhausted=stop_exhausted) as di:
            check_data_iterator_concat_result(
                di, batch_size, normalize, ds1.size, ds2.size, stop_exhausted)
    else:
        di = data_iterator_concat_datasets([ds1, ds2],
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           with_memory_cache=with_memory_cache,
                                           with_file_cache=with_file_cache,
                                           use_thread=use_thread,
                                           stop_exhausted=stop_exhausted)
        check_data_iterator_concat_result(
            di, batch_size, normalize, ds1.size, ds2.size, stop_exhausted)
        di.close()
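
Outside the test harness the iterator is used the same way; a minimal sketch with hypothetical CSV paths, assuming the usual nnabla import locations:

from nnabla.utils.data_iterator import data_iterator_concat_datasets
from nnabla.utils.data_source_implements import CsvDataSource

ds1 = CsvDataSource('train_small.csv', shuffle=True, normalize=True)  # hypothetical
ds2 = CsvDataSource('train_large.csv', shuffle=True, normalize=True)  # hypothetical

# The context manager closes the iterator and its sources on exit.
with data_iterator_concat_datasets([ds1, ds2], batch_size=8, shuffle=True) as di:
    for _ in range(3):
        batch = di.next()  # one array per registered variable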
Example #3
def conv_dataset_command(args):
    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] already exists. Use the `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            print('Cannot overwrite file [{}]. Please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)
    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':
        with CsvDataSource(args.source,
                           shuffle=args.shuffle,
                           normalize=args.normalize) as source:
            _convert(args, source)
    elif ext.lower() == '.cache':
        with CacheDataSource(args.source,
                             shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
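
conv_dataset_command expects parsed command-line options; a minimal sketch of driving it directly with an argparse.Namespace (attribute names are taken from the function body, paths are hypothetical):

import argparse

args = argparse.Namespace(
    source='dataset.csv',     # input CSV (a .cache directory also works)
    destination='converted',  # output directory, created if missing
    force=True,               # allow overwriting an existing directory
    shuffle=False,
    normalize=True)
conv_dataset_command(args)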
Example #4
def test_csv_data_source(test_data_csv_csv_10, test_data_csv_csv_20, size, shuffle):

    if size == 10:
        csvfilename = test_data_csv_csv_10
    elif size == 20:
        csvfilename = test_data_csv_csv_20

    cds = CsvDataSource(csvfilename, shuffle=shuffle)
    cds.reset()
    order = []
    for _ in range(cds.size):
        data, label = cds.next()
        assert data[0][0] == label[0]
        order.append(int(round(data[0][0])))
    if shuffle:
        assert not list(range(size)) == order
        assert list(range(size)) == sorted(order)
    else:
        assert list(range(size)) == order
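
size and shuffle arrive via pytest parametrization; a minimal sketch of decorators that could drive this test (the value lists are assumptions, not taken from the original suite):

import pytest

@pytest.mark.parametrize("size", [10, 20])
@pytest.mark.parametrize("shuffle", [False, True])
def test_csv_data_source(test_data_csv_csv_10, test_data_csv_csv_20, size, shuffle):
    ...  # body as above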
Example #5
def conv_dataset_command(args):
    if isinstance(args.num_of_threads, int) and args.num_of_threads <= 0:
        print("The number of threads [{}] must be a positive integer.".format(
            args.num_of_threads))
        return False

    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] already exists. Use the `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            print('Cannot overwrite file [{}]. Please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)

    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':

        if os.path.exists(args.source):
            cc = CreateCache(args.source,
                             shuffle=args.shuffle,
                             num_of_threads=args.num_of_threads)
            print('Number of Data: {}'.format(cc._size))
            print('Shuffle:        {}'.format(cc._shuffle))
            print('Normalize:      {}'.format(args.normalize))
            cc.create(args.destination, normalize=args.normalize)
        else:
            with CsvDataSource(args.source,
                               shuffle=args.shuffle,
                               normalize=args.normalize) as source:
                _convert(args, source)

    elif ext.lower() == '.cache':
        with CacheDataSource(args.source,
                             shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
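
This variant differs from Example #3 by validating num_of_threads and converting existing CSV sources through the multi-threaded CreateCache; a minimal sketch of invoking it, extending the Namespace from Example #3 (paths hypothetical):

import argparse

args = argparse.Namespace(
    source='dataset.csv',
    destination='converted',
    force=True,
    shuffle=True,
    normalize=True,
    num_of_threads=4)  # non-positive integers are refused by the guard above
conv_dataset_command(args)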
Example #6
def test_concat_data_source(test_data_csv_csv_10, test_data_csv_csv_20, shuffle):
    data_list = [test_data_csv_csv_10, test_data_csv_csv_20]
    ds_list = [CsvDataSource(csvfilename, shuffle=shuffle)
               for csvfilename in data_list]

    cds = ConcatDataSource(ds_list, shuffle=shuffle)
    cds.reset()
    order = []
    for _ in range(cds.size):
        data, label = cds.next()
        assert data[0][0] == label[0]
        order.append(int(round(data[0][0])))
    original_order = list(range(10)) + list(range(20))
    if shuffle:
        assert not original_order == order
        assert sorted(original_order) == sorted(order)
    else:
        assert original_order == order
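
ConcatDataSource presents the concatenated sources as a single source; a minimal sketch of direct use, assuming the usual nnabla import location and hypothetical CSV paths:

from nnabla.utils.data_source_implements import ConcatDataSource, CsvDataSource

sources = [CsvDataSource(p, shuffle=False) for p in ('first.csv', 'second.csv')]
cds = ConcatDataSource(sources, shuffle=False)
cds.reset()
print(cds.size)           # sum of the component source sizes
data, label = cds.next()  # without shuffling, the first sample of the first source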