Ejemplo n.º 1
0
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False,
         extension=".nntxt"):
    '''load
    Load network information from files.

    Network definitions (".nntxt"/".prototxt"), parameter files
    (".protobuf"/".h5") and ".nnp" archives are merged into a single
    NNablaProtoBuf, which is then converted into runtime objects.

    Args:
        filenames (list): file-like object or List of filenames.
        prepare_data_iterator (bool or None): passed to ``_datasets``. When
            None, it is derived from ``training_config.max_epoch > 0``
            (False if the files define no training config).
        batch_size (int or None): passed to ``_networks``.
        exclude_parameter (bool): if True, parameter files are not loaded.
        parameter_only (bool): if True, network-definition text is not
            merged into the protobuf.
        extension: if filenames is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        Info: object whose attributes hold the network information
        (``datasets``, ``networks``, ``optimizers``, ``monitors``,
        ``executors`` and, when present in the files, ``global_config`` and
        ``training_config``).
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}

    # A single path or a single file-like object is wrapped into a list.
    if not isinstance(filenames, (list, tuple)):
        if isinstance(filenames, str) or hasattr(filenames, 'read'):
            filenames = [filenames]

    # Scratch directory for optimizer states extracted from .nnp archives.
    # It must live until _load_optimizer_checkpoint has read them, and it is
    # removed even if loading fails part-way.
    tmpdir = tempfile.mkdtemp()
    try:
        for filename in filenames:
            if isinstance(filename, str):
                _, ext = os.path.splitext(filename)
            else:
                ext = extension

            # TODO: Here is some known problems.
            #   - Even when protobuf file includes network structure,
            #     it will not loaded.
            #   - Even when prototxt file includes parameter,
            #     it will not loaded.

            if ext in ['.nntxt', '.prototxt']:
                if not parameter_only:
                    with get_file_handle_load(filename, ext) as f:
                        try:
                            text_format.Merge(f.read(), proto)
                        except Exception:
                            logger.critical(
                                'Failed to read {}.'.format(filename))
                            logger.critical(
                                '2 byte characters may be used for file name or folder name.'
                            )
                            raise
                if len(proto.parameter) > 0:
                    if not exclude_parameter:
                        # NOTE(review): for a file-like ``filename`` the
                        # stream may already have been consumed by the read
                        # above; confirm nn.load_parameters handles that.
                        nn.load_parameters(filename, extension=ext)
            elif ext in ['.protobuf', '.h5']:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
                else:
                    logger.info('Skip loading parameter.')

            elif ext == '.nnp':
                with get_file_handle_load(filename, ext) as nnp:
                    for name in nnp.namelist():
                        # Extension of the archive member; kept separate from
                        # ``ext`` (the archive's own extension).
                        _, member_ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif member_ext in ['.nntxt', '.prototxt']:
                            if not parameter_only:
                                with nnp.open(name, 'r') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    with nnp.open(name, 'r') as f:
                                        nn.load_parameters(
                                            f, extension=member_ext)
                        elif member_ext in ['.protobuf', '.h5']:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(
                                        f, extension=member_ext)
                            else:
                                logger.info('Skip loading parameter.')
                        elif member_ext in OPTI_BUF_EXT:
                            buf_type = get_buf_type(name)
                            if buf_type == 'protobuf':
                                with nnp.open(name, 'r') as f:
                                    with get_file_handle_load(
                                            f, '.protobuf') as opti_p:
                                        opti_proto.MergeFromString(
                                            opti_p.read())
                            elif buf_type == 'h5':
                                # extract() returns the path it actually
                                # wrote (member names are sanitised), so
                                # record that rather than re-joining paths.
                                opti_h5_files[name] = nnp.extract(
                                    name, tmpdir)

        default_context = None
        if proto.HasField('global_config'):
            info.global_config = _global_config(proto)
            default_context = info.global_config.default_context
            # NOTE(review): if ``backend`` is a list of names these are exact
            # membership tests; if it is a single string they are substring
            # tests and the second branch cannot be reached — confirm type.
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except Exception:
                    pass
            # Probe the requested context with a tiny in-place ReLU; if any
            # step raises, fall back to the CPU context.
            try:
                x = nn.Variable()
                y = nn.Variable()
                func = F.ReLU(default_context, inplace=True)
                func.setup([x], [y])
                func.forward([x], [y])
            except Exception:
                logger.warning('Fallback to CPU context.')
                import nnabla_ext.cpu
                default_context = nnabla_ext.cpu.context()
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()

        comm = current_communicator()
        if comm:
            default_context.device_id = str(comm.local_rank)
        if proto.HasField('training_config'):
            info.training_config = _training_config(proto)

        if prepare_data_iterator is None:
            # Derive from the training config; without one, default to False
            # instead of failing on a missing ``info.training_config``.
            training_config = getattr(info, 'training_config', None)
            prepare_data_iterator = (training_config is not None and
                                     training_config.max_epoch > 0)
        info.datasets = _datasets(proto, prepare_data_iterator)

        info.networks = _networks(proto, default_context, batch_size)

        info.optimizers = _optimizers(proto, default_context, info.networks,
                                      info.datasets)
        _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
Ejemplo n.º 2
0
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False):
    '''load
    Load network information from files.

    Network definitions (".nntxt"/".prototxt"), parameter files
    (".protobuf"/".h5") and ".nnp" archives are merged into a single
    NNablaProtoBuf, which is then converted into runtime objects.

    Args:
        filenames (list): List of filenames. A single filename string is
            also accepted.
        prepare_data_iterator (bool or None): passed to ``_datasets``. When
            None, it is derived from ``training_config.max_epoch > 0``
            (False if the files define no training config).
        batch_size (int or None): passed to ``_networks``.
        exclude_parameter (bool): if True, parameter files are not loaded.
        parameter_only (bool): if True, network-definition text is not
            merged into the protobuf.
    Returns:
        Info: object whose attributes hold the network information
        (``datasets``, ``networks``, ``optimizers``, ``monitors``,
        ``executors`` and, when present in the files, ``global_config`` and
        ``training_config``).
    '''
    class Info:
        pass

    info = Info()

    if isinstance(filenames, str):
        # Treat a bare path as a one-element list instead of iterating over
        # its characters.
        filenames = [filenames]

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: Here is some known problems.
        #   - Even when protobuf file includes network structure,
        #     it will not loaded.
        #   - Even when prototxt file includes parameter,
        #     it will not loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except Exception:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            # Created before ``try`` so ``finally`` never sees an unbound
            # ``tmpdir`` if mkdtemp itself fails.
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        # Extension of the archive member; kept separate from
                        # ``ext`` (the archive's own extension).
                        _, member_ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif member_ext in ['.nntxt', '.prototxt']:
                            # extract() returns the (sanitised) path written.
                            path = nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(path, 'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(path)
                        elif member_ext in ['.protobuf', '.h5']:
                            path = nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(path)
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir, ignore_errors=True)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        # NOTE(review): if ``backend`` is a single string these are substring
        # tests and the second branch cannot be reached; if it is a list they
        # are exact membership tests — confirm the type.
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except Exception:
                pass
        # Probe the requested context with a tiny in-place ReLU; if any step
        # raises, fall back to the CPU context.
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except Exception:
            logger.warning('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        # The device index is node-local: use the local rank, not the global
        # rank (which exceeds the per-node device count on multi-node runs).
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if prepare_data_iterator is None:
        # Derive from the training config; without one, default to False
        # instead of failing on a missing ``info.training_config``.
        training_config = getattr(info, 'training_config', None)
        prepare_data_iterator = (training_config is not None and
                                 training_config.max_epoch > 0)
    info.datasets = _datasets(proto, prepare_data_iterator)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info