Example #1
def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None
        if cache_dir and create_cache_explicitly:
            if not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0 or overwrite_cache:
                if not os.path.exists(cache_dir):
                    os.mkdir(cache_dir)
                logger.log(99, 'Creating cache data for "' + uri + '"')
                with data_iterator_csv_dataset(uri, batch_size, shuffle, normalize=False, cache_dir=cache_dir) as di:
                    index = 0
                    while index < di.size:
                        progress('', (1.0 * di.position) / di.size)
                        di.next()
                        index += batch_size
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0:
            if cache_dir and not os.path.exists(cache_dir):
                os.mkdir(cache_dir)
            dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                uri, batch_size, shuffle, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
    else:
        dataset.data_iterator = None
    return dataset
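
A minimal usage sketch for the _create_dataset helper above; the CSV path 'train.csv' and the argument values are placeholders, and the context-manager pattern follows the other examples on this page.

dataset = _create_dataset(
    uri='train.csv',               # hypothetical CSV dataset path
    batch_size=64,
    shuffle=True,
    no_image_normalization=False,
    cache_dir='',                  # empty string: no cache directory
    overwrite_cache=False,
    create_cache_explicitly=False,
    prepare_data_iterator=True)

# dataset.data_iterator is a factory; calling it returns a data iterator
# that can be used as a context manager.
with dataset.data_iterator() as di:
    batch = di.next()              # one mini-batch, one array per CSV variable
    print(di.variables, [b.shape for b in batch])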
Example #2
def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None
        if cache_dir and create_cache_explicitly:
            if not os.path.exists(cache_dir) or overwrite_cache:
                if not os.path.exists(cache_dir):
                    os.mkdir(cache_dir)
                logger.info('Creating cache data for "' + uri + '"')
                with data_iterator_csv_dataset(uri, batch_size, shuffle, normalize=False, cache_dir=cache_dir) as di:
                    index = 0
                    while index < di.size:
                        progress('', (1.0 * di.position) / di.size)
                        di.next()
                        index += batch_size
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir):
            if cache_dir and not os.path.exists(cache_dir):
                os.mkdir(cache_dir)
            dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                uri, batch_size, shuffle, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
    else:
        dataset.data_iterator = None
    return dataset
Example #3
def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, normalize,
                                   with_memory_cache, with_file_cache,
                                   with_context):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()
Example #4
def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization

    comm = current_communicator()

    # use same random state for each process until slice is called
    rng = numpy.random.RandomState(0)
    use_memory_cache = comm.size == 1 if comm else True

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None

        # Disable implicit cache creation when MPI is available.
        if cache_dir and (create_cache_explicitly or comm):
            cache_index = os.path.join(cache_dir, "cache_index.csv")
            if not os.path.exists(cache_index) or overwrite_cache:
                if single_or_rankzero():
                    logger.log(99, 'Creating cache data for "' + uri + '"')

                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # Python 2 os.makedirs has no exist_ok argument

                    with data_iterator_csv_dataset(uri, batch_size, shuffle, rng=rng, normalize=False, cache_dir=cache_dir, with_memory_cache=False) as di:
                        pass

            rng = numpy.random.RandomState(0)
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, rng=rng, normalize=dataset.normalize, with_memory_cache=use_memory_cache))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0:
            if comm:
                logger.critical(
                    'Implicit cache creation is not supported with MPI')
                import sys
                sys.exit(-1)
            else:
                if cache_dir:
                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # Python 2 os.makedirs has no exist_ok argument
                dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                    uri, batch_size, shuffle, rng=rng, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, rng=rng, normalize=dataset.normalize, with_memory_cache=use_memory_cache))
    else:
        dataset.data_iterator = None
    return dataset
Example #5
def _load_dataset(self, dataset_path, batch_size=100, shuffle=False):
    if os.path.isfile(dataset_path):
        logger.info("Load a dataset from {}.".format(dataset_path))
        return data_iterator_csv_dataset(dataset_path,
                                         batch_size,
                                         shuffle=shuffle)
    return None
Example #6
def loadData(batch_size):

    cache_dir = "./cache"
    if not os.path.isdir(cache_dir):
        os.mkdir(cache_dir)
        dataset = data_iterator_csv_dataset(
            "../AIStudy.BoardGame/experience.csv",
            batch_size,
            shuffle=True,
            normalize=True,
            cache_dir=cache_dir)
    else:
        dataset = data_iterator_cache(cache_dir,
                                      batch_size,
                                      shuffle=True,
                                      normalize=True)

    variables = dataset.variables
    print(variables)
    """

    s0Index = variables.index('s0')
    print(("index(s0)={}").format(s0Index))
    for n in range(1000):
        data = dataset.next()
        if n==0:
            print(("shape(s0)={}").format(data[s0Index].shape))
        print(("epoch={},position={},size={}, data_size={}").format(dataset.epoch, dataset.position,dataset.size, len(data[0])))
    """

    return dataset
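
A short usage sketch for loadData above, based on the commented-out block in the example; the batch size and the 's0' variable name are assumptions about the dataset and may not match yours.

dataset = loadData(batch_size=64)
s0_index = dataset.variables.index('s0')   # 's0' as in the commented-out block
for n in range(10):
    data = dataset.next()
    print("epoch={}, position={}, shape(s0)={}".format(
        dataset.epoch, dataset.position, data[s0_index].shape))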
Example #7
def test_data_iterator_csv_dataset(test_data_csv_png_10,
                                   test_data_csv_png_20,
                                   size,
                                   batch_size,
                                   shuffle,
                                   normalize,
                                   with_memory_cache,
                                   with_file_cache,
                                   with_context):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set(
        'DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()
Example #8
def common_forward(info, forward_func):
    batch_size = 1

    class ForwardConfig:
        pass

    class Args:
        pass

    args = Args()

    config = ForwardConfig
    if hasattr(info, 'global_config'):
        config.global_config = info.global_config
    config.executors = info.executors.values()
    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            assert False, "{} is not found.".format(e.network.name)

    normalize = True
    for d in info.datasets.values():
        args.dataset = d.uri
        normalize = d.normalize
        break
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    data_iterator = (lambda: data_iterator_csv_dataset(
        uri=args.dataset,
        batch_size=config.networks[0].batch_size,
        shuffle=False,
        normalize=normalize,
        with_memory_cache=False,
        with_file_cache=False))

    result = []
    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            avg = forward_func(
                args, index, config, data, di.variables)
            index += len(avg[0])
            result.append(avg[0])

    return np.array(result)
Example #9
def forward_command(args):
    callback.update_status(args)

    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)
    batch_size = args.batch_size
    if batch_size < 1:
        batch_size = None

    class ForwardConfig:
        pass

    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False, batch_size=batch_size)
    config.global_config = info.global_config

    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                e.network.name))
            return False

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset or d.cache_dir == args.dataset:
            normalize = d.normalize
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    orders = {}
    # With CSV
    if os.path.splitext(args.dataset)[1] == '.csv':
        data_iterator = (lambda: data_iterator_csv_dataset(
            uri=args.dataset,
            batch_size=config.networks[0].batch_size,
            shuffle=False,
            normalize=normalize,
            with_memory_cache=False,
            with_file_cache=False))

        # load dataset as csv
        filereader = FileReader(args.dataset)
        with filereader.open(textmode=True, encoding='utf-8-sig') as f:
            rows = [row for row in csv.reader(f)]
        row0 = rows.pop(0)
        if args.replace_path:
            root_path = os.path.dirname(args.dataset)
            root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
        else:
            root_path = '.'
        rows = [row for row in rows if len(row)]
        rows = list(
            map(
                lambda row: list(
                    map(
                        lambda i, x: x if row0[i][0] == '#' or is_float(
                            x) else compute_full_path(root_path, x),
                        range(len(row)), row)), rows))
        for i in range(len(rows)):
            orders[i] = i
    # With Cache
    elif os.path.splitext(args.dataset)[1] == '.cache':
        data_iterator = (lambda: data_iterator_cache(uri=args.dataset,
                                                     batch_size=config.
                                                     networks[0].batch_size,
                                                     shuffle=False,
                                                     normalize=normalize))

        # Get original CSV
        original_csv = os.path.join(args.dataset, 'original.csv')
        try:
            # load dataset as csv
            filereader = FileReader(original_csv)
            with filereader.open(textmode=True, encoding='utf-8-sig') as f:
                rows = [row for row in csv.reader(f)]
            row0 = rows.pop(0)
            root_path = '.'
            rows = list(
                map(
                    lambda row: list(
                        map(
                            lambda x: x if is_float(x) else compute_full_path(
                                root_path, x), row)), rows))
        except:
            print('Cannot open', original_csv)
            pass

        # Get original Data order.
        order_csv = os.path.join(args.dataset, 'order.csv')
        try:
            filereader = FileReader(order_csv)
            with filereader.open(textmode=True) as f:
                for original, shuffled in [[int(x) for x in row]
                                           for row in csv.reader(f)]:
                    orders[original] = shuffled
        except:
            print('Cannot open', order_csv)
            for i in range(len(rows)):
                orders[i] = i
    else:
        print('Unsupported extension "{}" in "{}".'.format(
            os.path.splitext(args.dataset)[1], args.dataset))

    callback.update_status(('data.max', len(rows)))
    callback.update_status(('data.current', 0))
    callback.update_status('processing', True)

    result_csv_filename = os.path.join(args.outdir, args.outfile)
    with open(result_csv_filename, 'w', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        with data_iterator() as di:
            index = 0
            while index < di.size:
                data = di.next()
                result, outputs = _forward(args, index, config, data,
                                           di.variables)
                if index == 0:
                    for name, dim in zip(result.names, result.dims):
                        if dim == 1:
                            if e.repeat_evaluation_type == "std":
                                name = "Uncertainty(Std)"
                            row0.append(name)
                        else:
                            for d in range(dim):
                                row0.append(name + '__' + str(d))
                    writer.writerow(row0)
                for i, output in enumerate(outputs):
                    if index + i < len(rows):
                        import copy
                        row = copy.deepcopy(rows[orders[index + i]])
                        row.extend(output)
                        writer.writerow(row)
                index += len(outputs)

                callback.update_status(('data.current', min([index,
                                                             len(rows)])))
                callback.update_forward_time()
                callback.update_status()

                logger.log(
                    99, 'data {} / {}'.format(min([index, len(rows)]),
                                              len(rows)))

    callback.process_evaluation_result(args.outdir, result_csv_filename)

    logger.log(99, 'Forward Completed.')
    progress(None)

    callback.update_status(('output_result.csv_header', ','.join(row0)))
    callback.update_status(('output_result.column_num', len(row0)))
    callback.update_status(('output_result.data_num', len(rows)))
    callback.update_status('finished')

    return True
Example #10
    elif args.uri == 'TINY_IMAGENET_TRAIN':
        sys.path.append(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                         'vision', 'imagenet'))
        from tiny_imagenet_data import data_iterator_tiny_imagenet
        with data_iterator_tiny_imagenet(args.batch_size, 'train') as di:
            test_data_iterator(di, args)
    elif args.uri == 'TINY_IMAGENET_VAL':
        sys.path.append(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                         'vision', 'imagenet'))
        from tiny_imagenet_data import data_iterator_tiny_imagenet
        with data_iterator_tiny_imagenet(args.batch_size, 'val') as di:
            test_data_iterator(di, args)
    else:
        if os.path.splitext(args.uri)[1].lower() == '.cache':
            from nnabla.utils.data_iterator import data_iterator_cache
            with data_iterator_cache(uri=args.uri,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle,
                                     with_memory_cache=args.memory_cache,
                                     normalize=args.normalize) as di:
                test_data_iterator(di, args)
        else:
            from nnabla.utils.data_iterator import data_iterator_csv_dataset
            with data_iterator_csv_dataset(uri=args.uri,
                                           batch_size=args.batch_size,
                                           shuffle=args.shuffle,
                                           normalize=args.normalize,
                                           with_memory_cache=args.memory_cache,
                                           with_file_cache=args.file_cache,
                                           cache_dir=args.output) as di:
                test_data_iterator(di, args)
Example #11
def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)
    batch_size = args.batch_size
    if batch_size < 1:
        batch_size = None

    class ForwardConfig:
        pass

    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False, batch_size=batch_size)
    config.global_config = info.global_config

    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                e.network.name))
            return False

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    data_iterator = (lambda: data_iterator_csv_dataset(uri=args.dataset,
                                                       batch_size=config.
                                                       networks[0].batch_size,
                                                       shuffle=False,
                                                       normalize=normalize,
                                                       with_memory_cache=False,
                                                       with_file_cache=False))

    # load dataset as csv
    filereader = FileReader(args.dataset)
    with filereader.open(textmode=True) as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    rows = list(
        map(
            lambda row: list(
                map(
                    lambda x: x
                    if is_float(x) else compute_full_path(root_path, x), row)),
            rows))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        with data_iterator() as di:
            index = 0
            while index < di.size:
                data = di.next()
                result, outputs = _forward(args, index, config, data,
                                           di.variables)
                if index == 0:
                    for name, dim in zip(result.names, result.dims):
                        if dim == 1:
                            row0.append(name)
                        else:
                            for d in range(dim):
                                row0.append(name + '__' + str(d))
                    writer.writerow(row0)
                for i, output in enumerate(outputs):
                    if index + i < len(rows):
                        import copy
                        row = copy.deepcopy(rows[index + i])
                        row.extend(output)
                        writer.writerow(row)
                index += len(outputs)
                logger.log(
                    99, 'data {} / {}'.format(min([index, len(rows)]),
                                              len(rows)))

    logger.log(99, 'Forward Completed.')
    progress(None)
    return True
Example #12
def get_data_iterator_and_num_class(args):
    """
        Get data iterators for the training and test datasets.
        Also obtain class / category information from the data.
    """
    if args.train_csv:
        from nnabla.utils.data_iterator import data_iterator_csv_dataset
        data_iterator = data_iterator_csv_dataset

        if args.test_csv:
            assert os.path.isfile(
                args.test_csv), "csv file for test not found."

            # check the number of the classes / categories
            with open(args.train_csv, "r") as f:
                csv_data_train = f.readlines()[1:]  # line 1:"x:image,y:label"
            classes_train = {
                line.split(",")[-1].strip()
                for line in csv_data_train
            }

            with open(args.test_csv, "r") as f:
                # first line:"x:image,y:label"
                csv_data_test = f.readlines()[1:]
            classes_test = {
                line.split(",")[-1].strip()
                for line in csv_data_test
            }
            classes_train.update(classes_test)

            num_class = len(classes_train)

            data_iterator_train = data_iterator_csv_dataset(args.train_csv,
                                                            args.batch_size,
                                                            args.shuffle,
                                                            normalize=False)

            data_iterator_valid = data_iterator_csv_dataset(args.test_csv,
                                                            args.batch_size,
                                                            args.shuffle,
                                                            normalize=False)
        else:
            print("No csv file for test given. So split the training data")
            assert isintance(args.ratio, float), "ratio must be in (0.0, 1.0)"

            # check the number of the classes / categories
            with open(args.train_csv, "r") as f:
                # first line is "x:image,y:label"
                csv_data_train = f.readlines()[1:]
            all_classes = {
                line.split(",")[-1].strip()
                for line in csv_data_train
            }
            num_class = len(all_classes)
            all_data = data_iterator_csv_dataset(args.train_csv,
                                                 args.batch_size,
                                                 args.shuffle,
                                                 normalize=False)

            num_samples = all_data.size
            num_train_samples = int(args.ratio * num_samples)

            data_iterator_train = all_data.slice(rng=None,
                                                 slice_start=0,
                                                 slice_end=num_train_samples)

            data_iterator_valid = all_data.slice(rng=None,
                                                 slice_start=num_train_samples,
                                                 slice_end=num_samples)

    else:
        # use caltech101 data like tutorial
        from caltech101_data import data_iterator_caltech101
        assert isinstance(args.ratio, float), "ratio must be in (0.0, 1.0)"
        data_iterator = data_iterator_caltech101
        num_class = 101  # pre-defined (excluding background class)
        all_data = data_iterator(args.batch_size,
                                 width=args.width,
                                 height=args.height)

        num_samples = all_data.size
        num_train_samples = int(args.ratio * num_samples)

        data_iterator_train = all_data.slice(rng=None,
                                             slice_start=0,
                                             slice_end=num_train_samples)

        data_iterator_valid = all_data.slice(rng=None,
                                             slice_start=num_train_samples,
                                             slice_end=num_samples)

    print("training images: {}".format(data_iterator_train.size))
    print("validation images: {}".format(data_iterator_valid.size))
    print("{} categories included.".format(num_class))

    return data_iterator_train, data_iterator_valid, num_class
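
A minimal consumption sketch for the iterators returned above; args is assumed to be prepared by the caller as in the function, and the training step itself is left as a stub.

data_iterator_train, data_iterator_valid, num_class = \
    get_data_iterator_and_num_class(args)

# Iterate roughly one epoch of mini-batches; .next() returns one array per
# CSV variable, here assumed to be an "x:image,y:label" dataset.
for _ in range(data_iterator_train.size // args.batch_size):
    images, labels = data_iterator_train.next()
    # ... forward / loss / backward / update would go here ...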
Example #13
def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False)
    config.global_config = info.global_config

    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                e.network.name))
            return

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    data_iterator = (lambda: data_iterator_csv_dataset(
        args.dataset, config.networks[0].batch_size, False, normalize=normalize))

    # load dataset as csv
    with open(args.dataset, 'rt') as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    rows = list(map(lambda row: list(map(lambda x: x if is_float(
        x) else compute_full_path(root_path, x), row)), rows))

    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            result, outputs = forward(args, index, config, data, di.variables)
            if index == 0:
                for name, dim in zip(result.names, result.dims):
                    if dim == 1:
                        row0.append(name)
                    else:
                        for d in range(dim):
                            row0.append(name + '__' + str(d))
            for i, output in enumerate(outputs):
                if index + i < len(rows):
                    rows[index + i].extend(output)
            index += len(outputs)
            logger.log(
                99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(row0)
        writer.writerows(rows)

    logger.log(99, 'Forward Completed.')
    progress(None)
Example #14
    return args


if __name__ == '__main__':
    config = get_args()
    logger.info("Running in %s" % config.context)

    seed(0)
    ctx = get_extension_context(config.context, device_id=config.device_id)
    nn.set_default_context(ctx)
    nn.clear_parameters()
    net = MLP(config)

    if config.process == 'train':
        net.train()
    elif config.process == 'evaluate':
        net.evaluate()
    elif config.process == 'infer':
        net.init_for_infer()
        if os.path.isfile(config.evaluation_dataset_path):
            logger.info("Load a dataset from {}.".format(
                config.evaluation_dataset_path))
            edata = data_iterator_csv_dataset(config.evaluation_dataset_path,
                                              1,
                                              shuffle=False)
            for i in range(edata.size):
                data = edata.next()
                x_d = data[0]
                result = net.infer(x_d)
                print("inference result = {}".format(result))
Example #15
def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False)
    config.global_config = info.global_config

    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                e.network.name))
            return

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    data_iterator = (lambda: data_iterator_csv_dataset(
        args.dataset, config.networks[0].batch_size, False, padding=True, normalize=normalize))

    # load dataset as csv
    with open(args.dataset, 'rt') as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    rows = list(map(lambda row: list(map(lambda x: x if is_float(
        x) else compute_full_path(root_path, x), row)), rows))

    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            result, outputs = forward(args, index, config, data, di.variables)
            if index == 0:
                for name, dim in zip(result.names, result.dims):
                    if dim == 1:
                        row0.append(name)
                    else:
                        for d in range(dim):
                            row0.append(name + '__' + str(d))
            for i, output in enumerate(outputs):
                if index + i < len(rows):
                    rows[index + i].extend(output)
            index += len(outputs)
            logger.log(
                99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(row0)
        writer.writerows(rows)

    logger.log(99, 'Forward Completed.')
    progress(None)
Example #16
def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, use_thread,
                                   normalize, with_memory_cache,
                                   with_file_cache, with_context,
                                   stop_exhausted):

    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    main_thread = threading.current_thread().ident
    expect_epoch = [0]

    def end_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for end epoch check"
        else:
            print(f"E: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for begin epoch check"
        else:
            print(f"B: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted) as di:
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
            check_data_iterator_result(di, batch_size, shuffle, normalize,
                                       stop_exhausted, expect_epoch)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted)
        di.register_epoch_end_callback(begin_epoch)
        di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, normalize,
                                   stop_exhausted, expect_epoch)
        di.close()