Example #1
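Builds a file cache from a CSV data source with CreateCache, re-opens it with CacheDataSource, and verifies that the size, the variable names, and every sample match the original CsvDataSource, using order.csv to recover the shuffle order.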
def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20,
                      input_file_fmt, cache_file_fmt, shuffle, normalize,
                      num_of_threads):
    if input_file_fmt == 'csv':
        csvfilename = test_data_csv_csv_20
    else:
        csvfilename = test_data_csv_png_20

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as tmpdir:
        cc = CreateCache(csvfilename,
                         shuffle=shuffle,
                         num_of_threads=num_of_threads)
        cc.create(tmpdir, normalize=normalize)

        # get cache data source and csv file data source
        with closing(CacheDataSource(tmpdir)) as cache_source:
            csv_source = CsvDataSource(csvfilename, normalize=normalize)

            check_relative_csv_file_result(cache_file_fmt, csvfilename, tmpdir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                with open(os.path.join(tmpdir, 'order.csv'), 'r') as f:
                    csv_source._order = [int(row[1]) for row in csv.reader(f)]

            for _ in range(cache_source.size):
                cache_data = associate_variables_and_data(cache_source)
                csv_data = associate_variables_and_data(csv_source)

                for v in cache_source.variables:
                    assert_allclose(cache_data[v], csv_data[v])
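
For reference, here is a minimal standalone sketch of the same cache-creation flow outside the test harness. The CSV path, the cache directory, and the import paths for CreateCache and CacheDataSource are assumptions for illustration, not taken from the example above.

from nnabla.config import nnabla_config
from nnabla.utils.create_cache import CreateCache                  # assumed import path
from nnabla.utils.data_source_implements import CacheDataSource    # assumed import path

nnabla_config.set('DATA_ITERATOR', 'cache_file_format', '.npy')

# Convert a CSV data source into an on-disk cache, then re-open it.
cc = CreateCache('data/train.csv', shuffle=True, num_of_threads=4)
cc.create('cache_dir', normalize=True)   # writes cache files (and order.csv when shuffled)

cache_source = CacheDataSource('cache_dir')
print(cache_source.size, cache_source.variables)
cache_source.close()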
Example #2
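Configures small file-cache and buffer limits through nnabla_config, then iterates a CSV dataset with data_iterator_csv_dataset, either as a context manager or with an explicit close(), and validates the batches with check_data_iterator_result.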
def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, normalize,
                                   with_memory_cache, with_file_cache,
                                   with_context):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()
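
For reference, a minimal standalone usage sketch of data_iterator_csv_dataset outside pytest. The CSV path, the batch size, and the assumption that the CSV defines exactly two variables are placeholders, not taken from the test above.

from nnabla.utils.data_iterator import data_iterator_csv_dataset

# Iterate roughly one epoch; variables are returned in CSV header order.
with data_iterator_csv_dataset(uri='data/train.csv', batch_size=32,
                               shuffle=True, normalize=True) as di:
    for _ in range(di.size // 32):
        x, y = di.next()   # assumes the CSV defines exactly two variables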
Example #3
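Wraps two CsvDataSource objects in data_iterator_concat_datasets and checks that the concatenated iterator yields the expected batches, covering both the context-manager and explicit-close code paths.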
def test_data_iterator_concat_datasets(test_data_csv_png_10,
                                       test_data_csv_png_20,
                                       batch_size,
                                       shuffle,
                                       use_thread,
                                       normalize,
                                       with_memory_cache,
                                       with_file_cache,
                                       with_context,
                                       stop_exhausted):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    csvfilename_1 = test_data_csv_png_10
    csvfilename_2 = test_data_csv_png_20

    ds1 = CsvDataSource(csvfilename_1,
                        shuffle=shuffle,
                        normalize=normalize)

    ds2 = CsvDataSource(csvfilename_2,
                        shuffle=shuffle,
                        normalize=normalize)

    if with_context:
        with data_iterator_concat_datasets([ds1, ds2],
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           with_memory_cache=with_memory_cache,
                                           with_file_cache=with_file_cache,
                                           use_thread=use_thread,
                                           stop_exhausted=stop_exhausted) as di:
            check_data_iterator_concat_result(
                di, batch_size, normalize, ds1.size, ds2.size, stop_exhausted)
    else:
        di = data_iterator_concat_datasets([ds1, ds2],
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           with_memory_cache=with_memory_cache,
                                           with_file_cache=with_file_cache,
                                           use_thread=use_thread,
                                           stop_exhausted=stop_exhausted)
        check_data_iterator_concat_result(
            di, batch_size, normalize, ds1.size, ds2.size, stop_exhausted)
        di.close()
Example #4
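A second variant of the CSV data-iterator test: the cache and buffer limits are set first, then the iterator is created and its batches are validated, with and without the context-manager form.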
def test_data_iterator_csv_dataset(test_data_csv_png_10,
                                   test_data_csv_png_20,
                                   size,
                                   batch_size,
                                   shuffle,
                                   normalize,
                                   with_memory_cache,
                                   with_file_cache,
                                   with_context):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()
Example #5
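Fragment of a command-line script: it logs the parsed options, applies the cache settings through nnabla_config, and dispatches to the MNIST data iterator imported from the vision/mnist example for the selected URI.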
                        '"MNIST_TRAIN", "MNIST_TEST", "TINY_IMAGENET_TRAIN",'
                        '"TINY_IMAGENET_VAL"')
    args = parser.parse_args()

    logger.debug('memory_cache: {}'.format(args.memory_cache))
    logger.debug('file_cache: {}'.format(args.file_cache))
    logger.debug('shuffle: {}'.format(args.shuffle))
    logger.debug('batch_size: {}'.format(args.batch_size))
    logger.debug('cache_size: {}'.format(args.cache_size))
    logger.debug('memory_size: {}'.format(args.memory_size))
    logger.debug('output: {}'.format(args.output))
    logger.debug('normalize: {}'.format(args.normalize))
    logger.debug('max_epoch: {}'.format(args.max_epoch))
    logger.debug('wait: {}'.format(args.wait))

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size',
                      '{}'.format(args.cache_size))
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size',
                      '{}'.format(args.memory_size))

    if args.uri == 'MNIST_TRAIN':
        sys.path.append(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                         'vision', 'mnist'))
        from mnist_data import data_iterator_mnist
        with data_iterator_mnist(args.batch_size, True, None, args.shuffle,
                                 args.memory_cache, args.file_cache) as di:
            test_data_iterator(di, args)
    elif args.uri == 'MNIST_TEST':
        sys.path.append(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                         'vision', 'mnist'))
Example #6
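main() of an ImageNet cache-creation tool: it parses the command-line options, classifies each input tar archive as training or validation data from its member names, reads the validation ground-truth labels, redirects nnabla logging to a temporary directory, configures the cache format and size, and then writes the train and validation caches.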
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('input',
                        type=str,
                        nargs='+',
                        help='Source file or directory.')
    parser.add_argument('output', type=str, help='Destination directory.')
    parser.add_argument('-W',
                        '--width',
                        type=int,
                        default=320,
                        help='width of output image (default:320)')
    parser.add_argument('-H',
                        '--height',
                        type=int,
                        default=320,
                        help='height of output image (default:320)')
    parser.add_argument(
        '-m',
        '--mode',
        default='trimming',
        choices=['trimming', 'padding'],
        help='shaping mode (trimming or padding)  (default:trimming)')
    parser.add_argument(
        '-S',
        '--shuffle',
        choices=['True', 'False'],
        help='Shuffle mode. If not specified, train uses True and val uses' +
        ' False; otherwise the specified value is used for both.')
    parser.add_argument('-N',
                        '--file-cache-size',
                        type=int,
                        default=100,
                        help='num of data in cache file (default:100)')
    parser.add_argument('-C',
                        '--cache-type',
                        default='npy',
                        choices=['h5', 'npy'],
                        help='cache format (h5 or npy) (default:npy)')
    parser.add_argument('--thinning',
                        type=int,
                        default=1,
                        help='Thinning rate')

    args = parser.parse_args()
    ############################################################################
    # Analyze tar archives.
    # If an archive consists only of members matching the regular expression
    # 'n[0-9]{8}\.tar', it is treated as the training data archive.
    # If it consists only of members matching the regular expression
    # 'ILSVRC2012_val_[0-9]{8}\.JPEG', it is treated as the validation data archive.

    archives = {'train': None, 'val': None}
    for inputarg in args.input:
        print('Checking input file [{}]'.format(inputarg))
        archive = tarfile.open(inputarg)
        is_train = False
        is_val = False
        names = []
        for name in archive.getnames():
            if re.match(r'n[0-9]{8}\.tar', name):
                if is_val:
                    print('Train data {} is included in validation tar'.format(
                        name))
                    exit(-1)
                is_train = True
            elif re.match(r'ILSVRC2012_val_[0-9]{8}\.JPEG', name):
                if is_train:
                    print('Validation data {} is included in train tar'.format(
                        name))
                    exit(-1)
                is_val = True
            else:
                print('Invalid member {} found in tar file'.format(name))
                exit(-1)
            names.append(name)
        if is_train:
            if archives['train'] is None:
                archives['train'] = (archive, names)
            else:
                print('Please specify only 1 training tar archive.')
                exit(-1)
        if is_val:
            if archives['val'] is None:
                archives['val'] = (archive, names)
            else:
                print('Please specify only 1 validation tar archive.')
                exit(-1)

    # Read the validation ground-truth labels (assigned in ascending wordnet_id order)
    validation_ground_truth = []
    g_file = VALIDATION_DATA_LABEL
    with open(g_file, 'r') as f:
        for l in f.readlines():
            validation_ground_truth.append(int(l.rstrip()))

    ############################################################################
    # Prepare logging
    tmpdir = tempfile.mkdtemp()
    logfilename = os.path.join(tmpdir, 'nnabla.log')

    # Temporarily chdir to tmpdir just before importing nnabla to reflect nnabla.conf.
    cwd = os.getcwd()
    os.chdir(tmpdir)
    with open('nnabla.conf', 'w') as f:
        f.write('[LOG]\n')
        f.write('log_file_name = {}\n'.format(logfilename))
        f.write('log_file_format = %(funcName)s : %(message)s\n')
        f.write('log_console_level = CRITICAL\n')

    from nnabla.config import nnabla_config
    os.chdir(cwd)

    ############################################################################
    # Data iterator setting
    nnabla_config.set('DATA_ITERATOR', 'cache_file_format',
                      '.' + args.cache_type)
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size',
                      str(args.file_cache_size))
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_num_of_threads',
                      '1')

    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    ############################################################################
    # Prepare status monitor
    from nnabla.utils.progress import configure_progress
    configure_progress(None, _progress)

    ############################################################################
    # Converter

    try:
        if archives['train'] is not None:
            from nnabla.logger import logger
            logger.info('StartCreatingCache')
            archive, names = archives['train']
            output = os.path.join(args.output, 'train')
            if not os.path.isdir(output):
                os.makedirs(output)
            _create_train_cache(archive, output, names, args)
        if archives['val'] is not None:
            from nnabla.logger import logger
            logger.info('StartCreatingCache')
            archive, names = archives['val']
            output = os.path.join(args.output, 'val')
            if not os.path.isdir(output):
                os.makedirs(output)
            _create_validation_cache(archive, output, names,
                                     validation_ground_truth, args)
    except KeyboardInterrupt:
        shutil.rmtree(tmpdir, ignore_errors=True)

        # Even if CTRL-C is pressed, the process does not stop while worker
        # threads are still running, so send SIGKILL to the process itself.
        os.kill(os.getpid(), 9)

    ############################################################################
    # Finish
    _finish = True
    shutil.rmtree(tmpdir, ignore_errors=True)
Example #7
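Script fragment that parses cache options, forces the .h5 cache format, imports MnistDataSource from the nnabla-examples repository, and wraps it in DataSourceWithFileCache to write the MNIST training cache into the output directory.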
                    '--cache_size',
                    type=int,
                    default=100,
                    help='Cache size (num of data).')
parser.add_argument('-o',
                    '--output',
                    type=str,
                    default='cache',
                    help='If specified, cache data will be written here.')
args = parser.parse_args()

logger.debug('file_cache: {}'.format(args.file_cache))
logger.debug('cache_size: {}'.format(args.cache_size))
logger.debug('output: {}'.format(args.output))

nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size',
                  '{}'.format(args.cache_size))
nnabla_config.set('DATA_ITERATOR', 'cache_file_format', '.h5')

HERE = os.path.dirname(__file__)
nnabla_examples_root = os.path.join(HERE, '../../../../nnabla-examples')
mnist_examples_root = os.path.realpath(
    os.path.join(nnabla_examples_root, 'mnist-collection'))
sys.path.append(mnist_examples_root)

from mnist_data import MnistDataSource
mnist_training_cache = args.output + '/mnist_training.cache'
if not os.path.exists(mnist_training_cache):
    os.makedirs(mnist_training_cache)
DataSourceWithFileCache(data_source=MnistDataSource(train=True,
                                                    shuffle=False,
                                                    rng=None),
Example #8
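CSV data-iterator test with epoch callbacks: begin_epoch and end_epoch check the expected epoch counter and that callbacks run on the main thread; they are registered on the iterator before the batches are validated with check_data_iterator_result.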
def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, use_thread,
                                   normalize, with_memory_cache,
                                   with_file_cache, with_context,
                                   stop_exhausted):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    main_thread = threading.current_thread().ident
    expect_epoch = [0]

    def end_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for end epoch check"
        else:
            print(f"E: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for begin epoch check"
        else:
            print(f"B: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted) as di:
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
            check_data_iterator_result(di, batch_size, shuffle, normalize,
                                       stop_exhausted, expect_epoch)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted)
        di.register_epoch_end_callback(begin_epoch)
        di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, normalize,
                                   stop_exhausted, expect_epoch)
        di.close()