コード例 #1
0
def test_file_close_exception(extension):
    """Check that file handles are closed even when the ``with`` body raises.

    Both the save-side (``get_file_handle_save``) and load-side
    (``get_file_handle_load``) context managers are exercised: a
    ``ZeroDivisionError`` is raised inside each ``with`` body, and the handle
    captured inside the body must report ``closed`` afterwards.

    Args:
        extension: file extension under test; presumably supplied by a
            pytest parametrization/fixture — TODO confirm.
    """
    # Reset before each section so a stale handle from a previous section
    # can never make the ``closed`` assertion pass by accident.
    file_handler = None
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir("tmp{}".format(extension)) as filename:
            with get_file_handle_save(filename, ext=extension) as f:
                file_handler = f
                1 / 0
    assert file_handler is not None
    assert file_handler.closed

    file_handler = None
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir("tmp{}".format(extension)) as filename:
            # Create the file first so that there is something to open.
            with open(filename, "w") as f:
                f.write("\n")

            with get_file_handle_load(None, filename, ext=extension) as f:
                file_handler = f
                1 / 0
    assert file_handler is not None
    assert file_handler.closed
コード例 #2
0
ファイル: load.py プロジェクト: ishihara-y/nnabla
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False,
         extension=".nntxt"):
    '''load
    Load network information from files.

    Each entry of ``filenames`` is dispatched on its extension (taken from
    the path when it is a ``str``, otherwise from ``extension``):
    ``.nntxt``/``.prototxt`` are parsed as text protobuf, ``.protobuf``/``.h5``
    are loaded as parameters, and ``.nnp`` archives are scanned member by
    member with the same rules. All network definitions are merged into a
    single ``NNablaProtoBuf`` before the result object is built.

    Args:
        filenames (list): file-like object, filename, or list/tuple of them.
        prepare_data_iterator (bool or None): passed to the dataset builder.
            When ``None``, data iterators are prepared only if the loaded
            training config has ``max_epoch > 0`` (``False`` when the files
            contain no training config).
        batch_size (int or None): passed to the network builder.
        exclude_parameter (bool): if True, parameters are not loaded.
        parameter_only (bool): if True, network definitions in
            ``.nntxt``/``.prototxt`` files are not merged.
        extension: if filenames is file-like object, extension is one of
            ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        Info: object with ``datasets``, ``networks``, ``optimizers``,
        ``monitors`` and ``executors`` attributes, plus ``global_config`` and
        ``training_config`` when those fields are present in the protobuf.
    '''
    class Info:
        pass

    info = Info()

    # Shared across all input files: every definition is merged into it.
    proto = nnabla_pb2.NNablaProtoBuf()

    # Normalize a single filename / file-like object into a list.
    if isinstance(filenames, (list, tuple)):
        pass
    elif isinstance(filenames, str) or hasattr(filenames, 'read'):
        filenames = [filenames]

    for filename in filenames:
        # File-like objects have no name to inspect, so fall back to the
        # caller-supplied ``extension``.
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            ext = extension

        # TODO: Known problems:
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except Exception:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            # ``proto`` accumulates across files, so this also fires when an
            # earlier file contributed parameters.
            if len(proto.parameter) > 0 and not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    _, member_ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO currently do nothing with version.
                    elif member_ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0 and not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=member_ext)
                    elif member_ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=member_ext)
                        else:
                            logger.info('Skip loading parameter.')

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            # Best effort: if the extension cannot be imported, the probe
            # below fails and the CPU fallback takes over.
            try:
                import nnabla_ext.cudnn
            except Exception:
                pass
        # Probe the configured context with a trivial in-place ReLU; if it
        # cannot run, fall back to the CPU context.
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except Exception:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    # In a distributed run, set the device id to this process's rank.
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if prepare_data_iterator is None:
        # ``info.training_config`` only exists when the protobuf has one;
        # without it there are no epochs to iterate, so do not prepare.
        training_config = getattr(info, 'training_config', None)
        prepare_data_iterator = (training_config is not None and
                                 training_config.max_epoch > 0)
    info.datasets = _datasets(proto, prepare_data_iterator)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
コード例 #3
0
ファイル: parameter.py プロジェクト: ujtakk/nnabla
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    The format is taken from the extension of ``path`` when it is a ``str``,
    otherwise from ``extension``. Supported formats are ``.h5``,
    ``.protobuf``, ``.nntxt``/``.prototxt`` and ``.nnp`` (whose
    ``.protobuf``/``.h5`` members are extracted to a temporary directory and
    loaded recursively). Any other extension is logged as an error.

    Args:
      path : path or file object
      proto : ``NNablaProtoBuf`` to merge into / append to; created on
        demand when ``None``.
      needs_proto (bool): for ``.h5`` input, also record every loaded
        parameter (name, shape, data, need_grad) in ``proto``.
      extension (str): format to assume when ``path`` is not a ``str``.

    Returns:
      The (possibly newly created) proto. This is ``None`` when an ``.h5``
      file is loaded with ``needs_proto=False`` and no ``proto`` was given.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        # The filter is scoped to the import so the process-wide warning
        # configuration is left untouched.
        import warnings
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=FutureWarning)
            import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)

            def _order(entry):
                # Sort by (index, name); datasets without an 'index'
                # attribute go last, ordered by name, instead of making
                # ``sorted`` compare ``None`` with an int (TypeError).
                index, name = entry
                return (index is None, 0 if index is None else index, name)

            for _, key in sorted(keys, key=_order):
                ds = hd[key]

                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]

                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = bool(ds.attrs['need_grad'])

    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()

        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext in ('.nntxt', '.prototxt'):
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)

        elif ext == '.nnp':
            # Create the directory before ``try`` so the cleanup in
            # ``finally`` never refers to an unbound ``tmpdir``.
            tmpdir = tempfile.mkdtemp()
            try:
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        # Separate name so the archive's own ``ext`` is kept.
                        _, member_ext = os.path.splitext(name)
                        if member_ext in ['.protobuf', '.h5']:
                            proto = load_parameters(os.path.join(
                                tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
            logger.info("Parameter load ({}): {}".format(ext, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto