Example #1
def load_train_state(filename, info):
    info.exclude_parameter = False
    info.parameter_only = False
    file_loaders = get_decorated_file_loader()
    info.parameter_scope = nn.parameter.get_current_parameter_scope()
    load_files(info, file_loaders, filename)
    logger.info("Load training resume states: {}".format(filename))
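A minimal usage sketch for resuming training. The file name "states.h5" and reusing the `info` object returned by the `load()` helper shown in Example #3 are assumptions, not part of the snippet above:

info = load(["result.nnp"])          # network/optimizer info, see Example #3
load_train_state("states.h5", info)  # loads the saved training resume states
                                     # into info.parameter_scope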
Example #2
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    ctx = FileHandlerContext()
    if proto is None:
        ctx.proto = nnabla_pb2.NNablaProtoBuf()
    else:
        ctx.proto = proto
    ctx.needs_proto = needs_proto
    # Get parameter file loaders
    file_loaders = get_parameter_file_loader()
    load_files(ctx, file_loaders, path, ext)
    return ctx.proto
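A minimal usage sketch. The file name "parameters.h5" is an assumption; such a file would typically be produced by nnabla's save_parameters():

proto = load_parameters("parameters.h5")
print(len(proto.parameter))  # number of parameter entries read into the protobuf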
Example #3
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt", context=None):
    '''Load network information from files.

    Args:
        filenames (list): File-like object or list of file names.
        extension (str): File format to assume when ``filenames`` is a
            file-like object; one of ".nntxt", ".prototxt", ".protobuf",
            ".h5" or ".nnp".
        context (str): Optional context specifier, either 'cpu' or
            'cudnn[:<device_id>]'; overrides the default context from the
            file's global_config.

    Returns:
        Info: Network information (networks, optimizers, monitors,
        executors, datasets and configuration).
    '''
    class Info:
        pass
    info = Info()

    info.prepare_data_iterator = prepare_data_iterator
    info.batch_size = batch_size
    info.exclude_parameter = exclude_parameter
    info.parameter_only = parameter_only
    info.proto = nnabla_pb2.NNablaProtoBuf()

    # first stage file loaders
    file_loaders = get_initial_file_loader()

    # Use the global parameter scope to stay consistent with the legacy
    # implementation and avoid surprising existing users, although a
    # stand-alone OrderedDict() instance would be cleaner.
    info.parameter_scope = nn.parameter.get_current_parameter_scope()
    load_files(info, file_loaders, filenames, extension)

    default_context = None
    if context:
        if context == 'cpu':
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
        else:
            cs = context.split(':')
            if cs[0] == 'cudnn':
                if len(cs) == 1:
                    devid = 0
                else:
                    devid = int(cs[1])
                import nnabla_ext.cudnn
                default_context = nnabla_ext.cudnn.context(device_id=devid)
        if default_context is None:
            logger.warn('Invalid context [{}]'.format(context))
        elif info.proto.HasField('global_config'):
            info.global_config = _global_config(info.proto)
            info.global_config.default_context = default_context

    if default_context is None:
        if info.proto.HasField('global_config'):
            info.global_config = _global_config(info.proto)
            default_context = info.global_config.default_context
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except ImportError:
                    pass
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
            info.global_config = _global_config(
                None, default_context=default_context)

    default_context = _check_context(default_context)
    logger.log(99, 'Using context "{}"'.format(default_context))
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if info.proto.HasField('training_config'):
        info.training_config = _training_config(info.proto)

    info.default_context = default_context
    info.datasets = _datasets(
        info.proto, prepare_data_iterator if prepare_data_iterator is not None else info.training_config.max_epoch > 0)

    info.renamed_variables = {}
    info.networks = _networks(info, nn.graph_def.ProtoGraph.from_proto(info.proto, param_scope=info.parameter_scope,
                                                                       rng=numpy.random.RandomState(0)))

    info.optimizers = _optimizers(info)
    info.monitors = _monitors(info)
    info.executors = _executors(info)

    return info
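A minimal usage sketch. The file name "result.nnp" is an assumption and stands for an .nnp archive exported from nnabla or Neural Network Console:

info = load(["result.nnp"], prepare_data_iterator=False, batch_size=1)
print(info.default_context)  # context resolved from the archive's global_config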