Beispiel #1
0
def conv_dataset_command(args):
    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] is exists use `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            print('Cannot overwrite file [{}] please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)
    datasource = None
    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':
        with CsvDataSource(args.source,
                           shuffle=args.shuffle,
                           normalize=args.normalize) as source:
            _convert(args, source)
    elif ext.lower() == '.cache':
        with CacheDataSource(args.source,
                             shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
Beispiel #2
0
def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20,
                      input_file_fmt, cache_file_fmt, shuffle, normalize,
                      num_of_threads):
    """Check that a cache built by ``CreateCache`` matches its CSV source.

    A cache is created from the selected CSV into a temporary directory and
    then read back through ``CacheDataSource``. Its size, variable names and
    every record are compared against a ``CsvDataSource`` opened on the same
    CSV file.

    Args:
        test_data_csv_csv_20: CSV path used when ``input_file_fmt`` is
            ``'csv'``.
        test_data_csv_png_20: CSV path used for any other
            ``input_file_fmt``.
        input_file_fmt: Selects which of the two CSV paths is used.
        cache_file_fmt: Value written to the ``cache_file_format`` entry of
            the ``DATA_ITERATOR`` section of ``nnabla_config``.
        shuffle: Forwarded to ``CreateCache``; when truthy the CSV source is
            re-ordered from ``order.csv`` before comparing records.
        normalize: Forwarded to ``CreateCache.create`` and ``CsvDataSource``.
        num_of_threads: Forwarded to ``CreateCache``.
    """
    source_csv = (test_data_csv_csv_20 if input_file_fmt == 'csv'
                  else test_data_csv_png_20)

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as tmpdir:
        cache_creator = CreateCache(source_csv,
                                    shuffle=shuffle,
                                    num_of_threads=num_of_threads)
        cache_creator.create(tmpdir, normalize=normalize)

        # Open the freshly written cache and the original CSV side by side.
        with closing(CacheDataSource(tmpdir)) as cache_source:
            csv_source = CsvDataSource(source_csv, normalize=normalize)

            check_relative_csv_file_result(cache_file_fmt, source_csv, tmpdir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                # Take the second column of order.csv as the CSV source's
                # order; presumably this aligns it with the shuffled cache.
                order_path = os.path.join(tmpdir, 'order.csv')
                with open(order_path, 'r') as order_file:
                    csv_source._order = [
                        int(row[1]) for row in csv.reader(order_file)
                    ]

            for _ in range(cache_source.size):
                from_cache = associate_variables_and_data(cache_source)
                from_csv = associate_variables_and_data(csv_source)

                for name in cache_source.variables:
                    assert_allclose(from_cache[name], from_csv[name])
Beispiel #3
0
def conv_dataset_command(args):
    """Convert the dataset at ``args.source`` into ``args.destination``.

    The conversion path is chosen from the (case-insensitive) extension of
    ``args.source``:

    * ``.csv`` and the path exists on the local filesystem: a
      ``CreateCache`` instance writes the result into the destination.
    * ``.csv`` and the path does not exist locally: the source is opened
      with ``CsvDataSource`` and passed to ``_convert``.
      NOTE(review): presumably this covers non-local sources -- confirm.
    * ``.cache``: the source is opened with ``CacheDataSource`` and passed
      to ``_convert``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``source``, ``destination``, ``force``, ``shuffle``,
            ``normalize`` and ``num_of_threads``.

    Returns:
        bool: True when a conversion was run. False when
        ``num_of_threads`` is an integer that is not positive, when the
        source extension is unsupported, when the destination exists and
        ``force`` is not set, or when the destination exists but is not a
        directory.
    """
    threads = args.num_of_threads
    if isinstance(threads, int) and threads <= 0:
        print("The numbers of threads [{}] must be positive integer.".format(
            threads))
        return False

    # Validate the source *before* touching the destination so that an
    # unsupported input never wipes or creates the output directory.
    ext = os.path.splitext(args.source)[1].lower()
    if ext not in ('.csv', '.cache'):
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
        return False

    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] is exists use `-F` option to overwrite it.'
                .format(args.destination))
            return False
        if not os.path.isdir(args.destination):
            # Only directories are replaced; a regular file is left alone.
            print('Cannot overwrite file [{}] please delete it.'.format(
                args.destination))
            return False
        print('Overwrite destination [{}].'.format(args.destination))
        shutil.rmtree(args.destination, ignore_errors=True)
    # Reached both for a fresh destination and after removing the old one.
    os.mkdir(args.destination)

    if ext == '.csv' and os.path.exists(args.source):
        cc = CreateCache(args.source,
                         shuffle=args.shuffle,
                         num_of_threads=threads)
        print('Number of Data: {}'.format(cc._size))
        print('Shuffle:        {}'.format(cc._shuffle))
        print('Normalize:      {}'.format(args.normalize))
        cc.create(args.destination, normalize=args.normalize)
        return True

    # Names are resolved lazily here so only the needed class is looked up.
    if ext == '.csv':
        source_class = CsvDataSource
    else:
        source_class = CacheDataSource
    with source_class(args.source,
                      shuffle=args.shuffle,
                      normalize=args.normalize) as source:
        _convert(args, source)
    return True