Example #1
    def _export_files(self, outdir):
        with open('{}/nnp_version.txt'.format(outdir), 'w') as f:
            f.write('0.1\n')
        if self._parameter_type == 'included':
            self._write_nntxt('{}/network.nntxt'.format(outdir), self._nnp)
        else:
            nnp_wo_parameter = nnabla_pb2.NNablaProtoBuf()
            nnp_wo_parameter.CopyFrom(self._nnp)
            nnp_wo_parameter.ClearField('parameter')
            self._write_nntxt('{}/network.nntxt'.format(outdir),
                              nnp_wo_parameter)

            if self._parameter_type == 'protobuf':
                nnp_parameter_only = nnabla_pb2.NNablaProtoBuf()
                for param in self._nnp.parameter:
                    parameter = nnp_parameter_only.parameter.add()
                    parameter.CopyFrom(param)
                self._write_protobuf('{}/parameter.protobuf'.format(outdir),
                                     nnp_parameter_only)
            elif self._parameter_type == 'h5':
                self._write_h5('{}/parameter.h5'.format(outdir), self._nnp)
            elif self._parameter_type == 'none':
                pass  # Store without parameters.
            else:
                print('Unsupported parameter type `{}`.'.format(
                    self._parameter_type))
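
A minimal standalone sketch of the copy-then-clear pattern the exporter above uses to write the network description without its parameters (assumes nnabla is installed so that nnabla_pb2 is importable):

from nnabla.utils import nnabla_pb2

def strip_parameters(nnp):
    # Copy the full proto, then drop the repeated 'parameter' field so that
    # only the network structure remains for the .nntxt file.
    stripped = nnabla_pb2.NNablaProtoBuf()
    stripped.CopyFrom(nnp)
    stripped.ClearField('parameter')
    return stripped
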
Example #2
def load_parameters(path):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    if ext == '.h5':
        # TODO: Temporary workaround to suppress FutureWarning messages.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(key,
                                              ds.shape,
                                              need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
            set_parameter_from_proto(proto)
    elif ext == '.nntxt' or ext == '.prototxt':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'r') as f:
            text_format.Merge(f.read(), proto)
            set_parameter_from_proto(proto)

    elif ext == '.nnp':
        try:
            tmpdir = tempfile.mkdtemp()
            with zipfile.ZipFile(path, 'r') as nnp:
                for name in nnp.namelist():
                    nnp.extract(name, tmpdir)
                    _, ext = os.path.splitext(name)
                    if ext in ['.protobuf', '.h5']:
                        load_parameters(os.path.join(tmpdir, name))
        finally:
            shutil.rmtree(tmpdir)
    logger.info("Parameter load ({}): {}".format(format, path))
Example #3
def save_parameters(path):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    params = get_parameters(grad_only=False)
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'w') as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

        with open(path, "wb") as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only supported hdf5 or protobuf.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
Example #4
def _opti_file_loader(ctx, fileloaders, nnp, filename, ext):
    '''.optimizer file loader.

    This loader only handles .optimizer files.
    '''
    file_type = get_buf_type(filename)
    if file_type == 'protobuf':
        opti_proto = nnabla_pb2.NNablaProtoBuf()
        with get_file_handle_load(nnp, filename, '.protobuf') as f:
            opti_proto.MergeFromString(f.read())
        for p_opti in opti_proto.optimizer:
            o = ctx.optimizers.get(p_opti.name, None)
            if o:
                o.solver.set_states_from_protobuf(p_opti)
            else:
                logger.warn(
                    'No matched optimizer is found for {}.'.format(filename))
    elif file_type == 'h5':
        loaded = False
        for o in ctx.optimizers.values():
            key = '{}_{}_optimizer.h5.optimizer'.format(
                o.name,
                re.sub(r'(|Cuda)$', '', str(o.solver.name))
            )
            if key == filename:
                o.solver.set_states(_load_solve_state_from_h5(nnp, filename))
                loaded = True
        if not loaded:
            logger.warn(
                "No matched optimizer is found for {}.".format(filename))
Example #5
def _load_nnp_to_proto(nnp_path):
    import google.protobuf.text_format as text_format
    import tempfile
    import zipfile
    import shutil
    proto = nnabla_pb2.NNablaProtoBuf()

    tmpdir = tempfile.mkdtemp()
    try:
        with zipfile.ZipFile(nnp_path, "r") as nnp:
            for name in nnp.namelist():
                _, ext = os.path.splitext(name)
                if name == "nnp_version.txt":
                    pass  # Currently nnp_version.txt is ignored
                elif ext in [".nntxt", ".prototxt"]:
                    nnp.extract(name, tmpdir)
                    with open(os.path.join(tmpdir, name), "rt") as f:
                        text_format.Merge(f.read(), proto)
                elif ext in [".protobuf", ".h5"]:
                    nnp.extract(name, tmpdir)
                    nn.load_parameters(os.path.join(tmpdir, name))
    finally:
        shutil.rmtree(tmpdir)

    return proto
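
Since an .nnp file is an ordinary zip archive, its contents can be inspected directly before loading (the path is hypothetical):

import zipfile

with zipfile.ZipFile('model.nnp', 'r') as nnp:
    # Typically lists nnp_version.txt, a .nntxt network definition and a
    # .protobuf or .h5 parameter file, which the loader above extracts.
    print(nnp.namelist())
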
Example #6
def _pb_parameter_file_loader(ctx, file_loaders, nnp, filename, ext):
    with get_file_handle_load(nnp, filename, ext) as f:
        try:
            ctx.proto
        except:
            ctx.proto = nnabla_pb2.NNablaProtoBuf()
        ctx.proto.MergeFromString(f.read())
        nn.parameter.set_parameter_from_proto(ctx.proto)
Example #7
def _load_nntxt_to_proto(nntxt_path):
    import google.protobuf.text_format as text_format
    proto = nnabla_pb2.NNablaProtoBuf()

    with open(nntxt_path, "rt") as f:
        text_format.Merge(f.read(), proto)

    return proto
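
A round-trip sketch showing how the same proto can be written back to the human-readable nntxt form with google.protobuf.text_format:

import google.protobuf.text_format as text_format
from nnabla.utils import nnabla_pb2

proto = nnabla_pb2.NNablaProtoBuf()
proto.network.add().name = 'network1'

# MessageToString produces the nntxt/prototxt text that Merge parses back.
nntxt = text_format.MessageToString(proto)
restored = nnabla_pb2.NNablaProtoBuf()
text_format.Merge(nntxt, restored)
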
Example #8
    def execute(self):
        self._nnp = nnabla_pb2.NNablaProtoBuf()
        other_files = []
        for ifile in self._args:
            print('Importing {}'.format(ifile))
            ext = os.path.splitext(ifile)[1].lower()
            if ext == '.nnp':
                try:
                    tmpdir = tempfile.mkdtemp()
                    with zipfile.ZipFile(ifile, 'r') as nnpzip:
                        for name in nnpzip.namelist():
                            if os.path.splitext(name)[1].lower() in [
                                    '.nntxt', '.prototxt'
                            ]:
                                nnpzip.extract(name, tmpdir)
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), self._nnp)
                        for name in nnpzip.namelist():  # Param
                            if os.path.splitext(name)[1].lower() in [
                                    '.protobuf', '.h5'
                            ]:
                                nnpzip.extract(name, tmpdir)
                                self.load_parameters(os.path.join(
                                    tmpdir, name))
                finally:
                    shutil.rmtree(tmpdir)
            elif ext in ['.nntxt', '.prototxt']:
                with open(ifile, 'rt') as f:
                    text_format.Merge(f.read(), self._nnp)
            elif ext in ['.protobuf', '.h5']:
                self.load_parameters(ifile)
            else:
                other_files.append(ifile)

        executor_name = self._nnp.executor[0].network_name
        network = self.find_network(executor_name)
        parameter_variable_list = self.find_parameter_variable(network)
        if parameter_variable_list and not self._nnp.parameter:
            self.generate_parameters_data(parameter_variable_list,
                                          network.batch_size)

        if self._executor_index is not None:
            if self._executor_index < len(self._nnp.executor):
                self._nnp = self._shrink_with_executor(
                    self._nnp.executor[self._executor_index])

        if self._expand_network:
            self._nnp = expander.NnpExpander(self._nnp).execute()

        class nnp:
            pass

        nnp.protobuf = self._nnp
        nnp.other_files = other_files
        return nnp
Example #9
def _protobuf_parameter_file_saver(ctx, filename, ext):
    proto = nnabla_pb2.NNablaProtoBuf()
    for variable_name, variable in ctx.parameters.items():
        parameter = proto.parameter.add()
        parameter.variable_name = variable_name
        parameter.shape.dim.extend(variable.shape)
        parameter.data.extend(numpy.array(variable.d).flatten().tolist())
        parameter.need_grad = variable.need_grad
    with get_file_handle_save(filename, ext) as f:
        f.write(proto.SerializeToString())
Example #10
def _load_nntxt_to_proto(nntxt_path):
    import google.protobuf.text_format as text_format
    proto = nnabla_pb2.NNablaProtoBuf()
    if hasattr(nntxt_path, 'read'):
        nntxt = nntxt_path.read()
    else:
        with open(nntxt_path, "rt") as f:
            nntxt = f.read()
    text_format.Merge(nntxt, proto)

    return proto
Example #11
    def execute(self):
        nnp = nnabla_pb2.NNablaProtoBuf()
        nnp.CopyFrom(self._nnp)
        for network in nnp.network:
            self._expand_network(network)
        for optimizer in nnp.optimizer:
            self._expand_parameter_variable(optimizer)
        for executor in nnp.executor:
            self._expand_parameter_variable(executor)

        return nnp
Example #12
def load(filenames, prepare_data_iterator=True):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)
        if 'txt' in ext:
            with open(filename, 'rt') as f:
                text_format.Merge(f.read(), proto)
        elif ext in ['.protobuf', '.h5']:
            nn.load_parameters(filename, proto)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
    else:
        default_context = nn.context()

    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if len(proto.dataset) > 0:
        info.datasets = _datasets(proto, prepare_data_iterator)

    if len(proto.network) > 0:
        info.networks = _networks(proto, default_context)

    if len(proto.optimizer) > 0:
        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)

    if len(proto.monitor) > 0:
        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)

    if len(proto.executor) > 0:
        info.executors = _executors(proto, info.networks)

    return info
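
A usage sketch for this loader (file names are hypothetical; load is commonly imported from nnabla.utils.load, and the attributes set on info depend on which fields the files actually define):

from nnabla.utils import load

info = load.load(['net.nntxt', 'parameters.h5'])
# If the nntxt defines networks, they are presumably available as a dict.
print(info.networks.keys())
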
Example #13
def load_parameters(path):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(key,
                                              ds.shape,
                                              need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
            for parameter in proto.parameter:
                var = get_parameter_or_create(parameter.variable_name,
                                              parameter.shape.dim)
                param = numpy.reshape(parameter.data, parameter.shape.dim)
                var.d = param
                var.need_grad = parameter.need_grad
    elif ext == '.nnp':
        try:
            tmpdir = tempfile.mkdtemp()
            with zipfile.ZipFile(path, 'r') as nnp:
                for name in nnp.namelist():
                    nnp.extract(name, tmpdir)
                    _, ext = os.path.splitext(name)
                    if ext in ['.protobuf', '.h5']:
                        load_parameters(os.path.join(tmpdir, name))
        finally:
            shutil.rmtree(tmpdir)
    logger.info("Parameter load ({}): {}".format(format, path))
Example #14
    def execute(self):
        nnp = nnabla_pb2.NNablaProtoBuf()
        nnp.CopyFrom(self._nnp)

        nnp.ClearField('network')
        for network in self._nnp.network:
            net = nnp.network.add()
            net.CopyFrom(self._expand_network(network))

        for optimizer in nnp.optimizer:
            self._expand_parameter_variable(optimizer)

        for executor in nnp.executor:
            self._expand_parameter_variable(executor)

        return nnp
Example #15
def load_parameters(path, proto=None):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    _, ext = os.path.splitext(path)
    if proto is None:
        proto = nnabla_pb2.NNablaProtoBuf()
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(key,
                                              ds.shape,
                                              need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
                parameter = proto.parameter.add()
                parameter.variable_name = key
                parameter.shape.dim.extend(var.shape)
                parameter.data.extend(numpy.array(var.d).flatten().tolist())
                parameter.need_grad = var.need_grad
    elif ext == '.protobuf':
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
            for parameter in proto.parameter:
                var = get_parameter_or_create(parameter.variable_name,
                                              parameter.shape.dim)
                param = numpy.reshape(parameter.data, parameter.shape.dim)
                var.d = param
                var.need_grad = parameter.need_grad
    logger.info("Parameter load ({}): {}".format(format, path))
    return proto
Example #16
    def execute(self):
        self._nnp = nnabla_pb2.NNablaProtoBuf()
        other_files = []
        for ifile in self._args:
            print('Importing {}'.format(ifile))
            ext = os.path.splitext(ifile)[1].lower()
            if ext == '.nnp':
                try:
                    tmpdir = tempfile.mkdtemp()
                    with zipfile.ZipFile(ifile, 'r') as nnpzip:
                        for name in nnpzip.namelist():
                            if os.path.splitext(name)[1].lower() in [
                                    '.nntxt', '.prototxt'
                            ]:
                                nnpzip.extract(name, tmpdir)
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), self._nnp)
                        for name in nnpzip.namelist():  # Param
                            if os.path.splitext(name)[1].lower() in [
                                    '.protobuf', '.h5'
                            ]:
                                nnpzip.extract(name, tmpdir)
                                self.load_parameters(os.path.join(
                                    tmpdir, name))
                finally:
                    shutil.rmtree(tmpdir)
            elif ext in ['.nntxt', '.prototxt']:
                with open(ifile, 'rt') as f:
                    text_format.Merge(f.read(), self._nnp)
            elif ext in ['.protobuf', '.h5']:
                self.load_parameters(ifile)
            else:
                other_files.append(ifile)

        if self._expand_network:
            self._nnp = expander.NnpExpander(self._nnp).execute()

        class nnp:
            pass

        nnp.protobuf = self._nnp
        nnp.other_files = other_files
        return nnp
Example #17
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    ctx = FileHandlerContext()
    if proto is None:
        ctx.proto = nnabla_pb2.NNablaProtoBuf()
    else:
        ctx.proto = proto
    ctx.needs_proto = needs_proto
    # Get parameter file loaders
    file_loaders = get_parameter_file_loader()
    load_files(ctx, file_loaders, path, ext)
    return ctx.proto
Example #18
    def _expand_network(self, network):

        print(' Expanding {}.'.format(network.name))

        repeat_ids = collections.OrderedDict()
        for ri in network.repeat_info:
            repeat_ids[ri.id] = ri.times

        # Check whether parameter name complies with old rule.
        self._old_version_param_name = False
        for param in self._parameters:
            for ri in repeat_ids:
                m = re.search(r'{}\[([0-9]+)\]$'.format(ri), param)
                if m:
                    if int(m.group(1)) < repeat_ids[ri]:
                        self._old_version_param_name = True

        # Expand repeat
        network = self._expand_repeat(network)

        functions = []
        for func in network.function:
            functions.append((func.name,
                              func.type,
                              [n for n in func.input],
                              [n for n in func.output]))

        sorted_functions = self._sort_functions(functions)
        func_list = []
        for f in functions:
            func_list.append(f[0])

        net = nnabla_pb2.NNablaProtoBuf().network.add()
        net.CopyFrom(network)
        net.ClearField('function')
        for f in sorted_functions:
            func = net.function.add()
            func.CopyFrom(network.function[func_list.index(f[0])])

        return net
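
_sort_functions is not shown here; the following is a minimal sketch of what it presumably does, ordering (name, type, inputs, outputs) tuples so that every function appears after the producers of its inputs (an assumption, not the library's implementation):

def sort_functions(functions):
    # Inputs that no listed function produces (data and parameters) are
    # treated as already available.
    produced = set()
    for _, _, _, outputs in functions:
        produced.update(outputs)
    available, result, pending = set(), [], list(functions)
    while pending:
        remaining = []
        for f in pending:
            _, _, inputs, outputs = f
            if all(i in available or i not in produced for i in inputs):
                result.append(f)
                available.update(outputs)
            else:
                remaining.append(f)
        if len(remaining) == len(pending):
            # Cyclic or unresolved dependencies: keep the original order.
            result.extend(remaining)
            break
        pending = remaining
    return result
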
Example #19
def save_parameters(path, params=None, extension=None):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
      path : path or file object
      params (dict, optional): Parameters to be saved. A dictionary mapping a parameter name (:obj:`str`) to an :obj:`~nnabla.Variable`.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    params = get_parameters(grad_only=False) if params is None else params
    if ext == '.h5':
        # TODO: Temporary workaround to suppress FutureWarning messages.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_save(path, ext) as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

        with get_file_handle_save(path, ext) as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only supported hdf5 or protobuf.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
Example #20
def save_optimizer_states(filebase, ext, train_config):
    filelist = []
    if ext == '.protobuf':
        filename = filebase + '_optimizer.protobuf.optimizer'
        proto = nnabla_pb2.NNablaProtoBuf()
        proto_optimizers = []
        for o in train_config.optimizers.values():
            proto_optimizers.append(_create_optimizer_lite(o.optimizer))
        proto.optimizer.extend(proto_optimizers)
        with get_file_handle_save(filename, '.protobuf') as f:
            f.write(proto.SerializeToString())
            filelist.append(filename)
    else:
        for o in train_config.optimizers.values():
            f_name = '{}_{}_optimizer.h5'.format(
                o.optimizer.name,
                re.sub(r'(|Cuda)$', '', str(o.optimizer.solver.name)))
            filename = '{}_{}'.format(filebase, f_name)
            o.optimizer.solver.save_states(filename)
            name_ext = '{}.optimizer'.format(filename)
            os.rename(filename, name_ext)
            filelist.append(name_ext)
    return filelist
Example #21
    def _shrink_with_executor(self, executor):
        print(' Try to leave only executor[{}].'.format(executor.name))
        network = None
        for n in self._nnp.network:
            if n.name == executor.network_name:
                network = n
        if network is None:
            return None

        nnp = nnabla_pb2.NNablaProtoBuf()
        nnp.CopyFrom(self._nnp)

        nnp.ClearField('optimizer')
        nnp.ClearField('monitor')

        nnp.ClearField('network')
        net = nnp.network.add()
        net.CopyFrom(network)

        nnp.ClearField('executor')
        exe = nnp.executor.add()
        exe.CopyFrom(executor)

        return nnp
Example #22
def create_function_nnp(inputs, outputs, func_name, func_args, func_kwargs):
    if func_name is None:
        return

    for category_name, category in \
            nnabla.utils.converter.get_category_info().items():
        if func_name in category:
            function = category[func_name]

    nnp = nnabla_pb2.NNablaProtoBuf()
    net = nnp.network.add()
    net.name = 'network1'
    net.batch_size = 1

    func = net.function.add()
    func.name = func_name
    func.type = func_name

    # Prepare input
    func_inputs = []
    data_names = []
    parameter_names = []
    input_data = []
    for n, i in enumerate(inputs):
        if i is not None:
            if len(list(function['inputs'].items())) == 1:
                input_name, input_info = list(function['inputs'].items())[0]
                if 'variadic' in input_info and input_info['variadic']:
                    input_name += str(n)
            else:
                input_name, input_info = list(function['inputs'].items())[n]
            func_inputs.append(input_name)
            var = net.variable.add()
            var.name = input_name
            if 'parameter' in input_info and input_info['parameter']:
                parameter_names.append(input_name)

                var.type = 'Parameter'
                shape = list(i.d.shape)[:]
                if func.name == 'BatchNormalization':
                    shape = [1] + shape
                var.shape.dim.extend(shape)

                param = nnp.parameter.add()
                param.variable_name = input_name
                param.shape.dim.extend(shape)
                param.data.extend(i.d.flatten())

            else:
                input_data.append(i.d.flatten())
                data_names.append(input_name)

                var.type = 'Buffer'
                shape = list(i.d.shape)[:]
                # Cases where the dimension does not need to be extended.
                if input_name == 'rmean' or input_name == 't':
                    pass
                elif func.name == 'PReLU' and input_name == "x1":
                    pass
                elif func.name == 'Transpose':
                    pass
                elif func.name == 'Concatenate':
                    pass
                else:
                    shape = [1] + shape
                var.shape.dim.extend(shape)

    func.input.extend(func_inputs)

    # Prepare output
    func_outputs = []
    output_data = []
    for n, o in enumerate(outputs):
        output_name = 'y{}'.format(n)
        func_outputs.append(output_name)
        var = net.variable.add()
        var.name = output_name
        var.type = 'Buffer'
        shape = list(o.d.shape)[:]
        shape = [-1] + shape
        var.shape.dim.extend(shape)
        output_data.append(o.d.flatten())

    func.output.extend(func_outputs)

    # Prepare argument
    if 'arguments' in function:
        for n, (arg_name, arg) in enumerate(function['arguments'].items()):
            param = eval('func.{}_param'.format(function['snake_name']))
            if not func_args and not func_kwargs:
                continue
            if func.name == 'Interpolate':
                del func_args[0]
            if n < len(func_args):
                a = func_args[n]
            else:
                if func.name == 'Concatenate' or func.name == 'Stack':
                    a = func_kwargs['axis']
                else:
                    a = func_kwargs.get('keepdims')
            # This is used to fix the problem of flip (axes == None)
            if a is None:
                f = ['Sum', 'Mean', 'Max', 'Min', 'Prod']
                if 'axes' in arg_name:
                    if func.name in f:
                        a = net.variable[0].shape.dim[:-1]
                        a = [x - 1 for x in a]
                    else:
                        a = len(net.variable[0].shape.dim) - 2

            if a is not None:
                if 'axis' == arg_name:
                    if func.name == 'Concatenate':
                        pass
                    else:
                        a += 1
                if 'axes' in arg_name:
                    if func.name == 'Transpose':
                        pass
                    else:
                        if isinstance(a, tuple) or isinstance(a, list):
                            a = list(a)
                        else:
                            a = [a]
                        a = [x + 1 for x in a]
                if isinstance(a, tuple) or isinstance(a, list):
                    if arg['type'] == 'Shape':
                        exec('param.{}.dim.extend(list(a))'.format(arg_name))
                    else:
                        exec('param.{}.extend(a)'.format(arg_name))
                elif isinstance(a, numpy.ndarray):
                    a = a.flatten()
                    if arg['type'] == 'Shape':
                        if function['snake_name'] == 'broadcast':
                            exec('param.{}.dim.extend([1] + list(a))'.format(
                                arg_name))
                        else:
                            exec('param.{}.dim.extend(list(a))'.format(
                                arg_name))
                    else:
                        exec('param.{}.extend(a)'.format(arg_name))
                else:
                    if 'repeated' in arg['type']:
                        exec('param.{}.extend([a])'.format(arg_name))
                    elif arg['type'] == 'string':
                        exec('param.{} = "{}"'.format(arg_name, a))
                    else:
                        if arg_name == 'base_axis':
                            a = a + 1
                        exec('param.{} = {}'.format(arg_name, a))

    # Prepare executor
    exe = nnp.executor.add()
    exe.name = 'inference'
    exe.network_name = 'network1'
    for d in data_names:
        dat = exe.data_variable.add()
        dat.variable_name = d
        dat.data_name = d

    for o in func_outputs:
        out = exe.output_variable.add()
        out.variable_name = o
        out.data_name = o

    for p in parameter_names:
        par = exe.parameter_variable.add()
        par.variable_name = p

    return nnp, input_data, output_data
Example #23
def load(filenames, prepare_data_iterator=True, batch_size=None):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            with open(filename, 'rt') as f:
                text_format.Merge(f.read(), proto)
        elif ext in ['.protobuf', '.h5']:
            nn.load_parameters(filename)

        elif ext == '.nnp':
            try:
                tmpdir = tempfile.mkdtemp()
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # Currently nnp_version.txt is ignored.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                text_format.Merge(f.read(), proto)
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            nn.load_parameters(os.path.join(tmpdir, name))
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            try:
                import nnabla_ext.cuda.cudnn
            except:
                pass
    else:
        default_context = nn.context()

    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if len(proto.dataset) > 0:
        info.datasets = _datasets(proto, prepare_data_iterator)

    if len(proto.network) > 0:
        info.networks = _networks(proto, default_context, batch_size)

    if len(proto.optimizer) > 0:
        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)

    if len(proto.monitor) > 0:
        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)

    if len(proto.executor) > 0:
        info.executors = _executors(proto, info.networks)

    return info
Example #24
def create_proto(contents, include_params=False):
    proto = nnabla_pb2.NNablaProtoBuf()
    if 'global_config' in contents:
        proto.global_config.MergeFrom(
            _create_global_config(
                contents['global_config']['default_context']))
    if 'training_config' in contents:
        proto.training_config.MergeFrom(
            _create_training_config(
                contents['training_config']['max_epoch'],
                contents['training_config']['iter_per_epoch'],
                contents['training_config']['save_best']))
    networks = {}
    if 'networks' in contents:
        proto_nets = []
        for net in contents['networks']:
            networks[net['name']] = _create_network(net)
            proto_nets.append(networks[net['name']])
        proto.network.extend(proto_nets)
    datasets = {}
    if 'datasets' in contents:
        proto_datasets = []
        for d in contents['datasets']:
            if 'cache_dir' in d:
                cache_dir = d['cache_dir']
            else:
                cache_dir = None
            datasets[d['name']] = _create_dataset(d['name'], d['uri'],
                                                  cache_dir, d['variables'],
                                                  d['shuffle'],
                                                  d['batch_size'],
                                                  d['no_image_normalization'])
            proto_datasets.append(datasets[d['name']])
        proto.dataset.extend(proto_datasets)
    if 'optimizers' in contents:
        proto_optimizers = []
        for o in contents['optimizers']:
            proto_optimizers.append(
                _create_optimizer(o['name'], o['solver'],
                                  networks[o['network']],
                                  datasets[o['dataset']]))
        proto.optimizer.extend(proto_optimizers)
    if 'monitors' in contents:
        proto_monitors = []
        for m in contents['monitors']:
            proto_monitors.append(
                _create_monitor(m['name'], m['monitor'],
                                networks[m['network']],
                                datasets[m['dataset']]))
        proto.monitor.extend(proto_monitors)
    if 'executors' in contents:
        proto_executors = []
        for e in contents['executors']:
            proto_executors.append(
                _create_executor(e['name'], networks[e['network']], e['data'],
                                 e['output'], e.get('remp', {})))
        proto.executor.extend(proto_executors)

    if include_params is True:
        params = get_parameters(grad_only=False)
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

    return proto
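
From the keys read above, the contents dictionary may contain 'global_config', 'training_config', 'networks', 'datasets', 'optimizers', 'monitors' and 'executors'. A minimal hedged sketch using only the training configuration (the values are hypothetical):

contents = {
    'training_config': {'max_epoch': 100,
                        'iter_per_epoch': 10,
                        'save_best': True},
}
proto = create_proto(contents, include_params=False)
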
Example #25
def _shrink_nnp(nnp, pos_start, pos_end):
    if len(nnp.protobuf.executor) != 1 or \
            len(nnp.protobuf.network) != 1:
        print('[ERROR] Please make only one network in nnp.')
        sys.exit(-1)
    from nnabla.utils import nnabla_pb2

    class _nnp:
        pass
    _nnp.protobuf = nnabla_pb2.NNablaProtoBuf()
    _nnp.other_files = nnp.other_files
    net = nnabla_pb2.NNablaProtoBuf().network.add()
    net.CopyFrom(nnp.protobuf.network[0])

    # Shrink network
    variables = {}
    net.ClearField('function')
    for i in range(pos_start, pos_end+1):
        f = nnp.protobuf.network[0].function[i]
        func = net.function.add()
        func.CopyFrom(f)
        for v in func.input:
            variables[v] = True
        for v in func.output:
            variables[v] = True

    net.ClearField('variable')
    for v in nnp.protobuf.network[0].variable:
        if v.name in variables:
            variables[v.name] = v.type
            var = net.variable.add()
            var.CopyFrom(v)

    # Shrink parameter
    params = []
    for param in nnp.protobuf.parameter:
        if param.variable_name in variables:
            p = nnabla_pb2.NNablaProtoBuf().parameter.add()
            p.CopyFrom(param)
            params.append(p)
    for p in params:
        param = _nnp.protobuf.parameter.add()
        param.CopyFrom(p)

    # Shrink executor
    exe = nnabla_pb2.NNablaProtoBuf().executor.add()
    exe.CopyFrom(nnp.protobuf.executor[0])

    exe.ClearField('data_variable')
    for vname in nnp.protobuf.network[0].function[pos_start].input:
        if variables[vname] == 'Buffer':
            v = exe.data_variable.add()
            v.variable_name = vname

    exe.ClearField('generator_variable')
    for var in nnp.protobuf.executor[0].generator_variable:
        if var.variable_name in variables:
            v = exe.generator_variable.add()
            v.CopyFrom(var)

    exe.ClearField('output_variable')
    for vname in nnp.protobuf.network[0].function[pos_end].output:
        if variables[vname] == 'Buffer':
            v = exe.output_variable.add()
            v.variable_name = vname

    exe.ClearField('parameter_variable')
    for var in nnp.protobuf.executor[0].parameter_variable:
        if var.variable_name in variables:
            v = exe.parameter_variable.add()
            v.CopyFrom(var)

    n = _nnp.protobuf.network.add()
    n.CopyFrom(net)
    e = _nnp.protobuf.executor.add()
    e.CopyFrom(exe)
    return _nnp
Example #26
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
      path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO: Temporary workaround to suppress FutureWarning messages.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]

                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]

                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = False
                    if ds.attrs['need_grad']:
                        parameter.need_grad = True

    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()

        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext == '.nntxt' or ext == '.prototxt':
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)

        elif ext == '.nnp':
            try:
                tmpdir = tempfile.mkdtemp()
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        _, ext = os.path.splitext(name)
                        if ext in ['.protobuf', '.h5']:
                            proto = load_parameters(os.path.join(
                                tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
                logger.info("Parameter load ({}): {}".format(format, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto
Example #27
def proto_from_str(nntxt_str):
    proto = nnabla_pb2.NNablaProtoBuf()
    text_format.Merge(nntxt_str, proto)
    return proto
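
A usage sketch with a minimal nntxt string (the field names are taken from the Network message used elsewhere in these examples):

nntxt = '''
network {
  name: "network1"
  batch_size: 1
}
'''
proto = proto_from_str(nntxt)
print(proto.network[0].name)   # -> network1
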
Example #28
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            try:
                tmpdir = tempfile.mkdtemp()
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # TODO currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(
                                        os.path.join(tmpdir, name))
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
Example #29
    def _expand_repeat(self, network):
        def _search_repeat_id(mes, rid):
            return list(
                mes.repeat_id).index(rid) if rid in mes.repeat_id else None

        def _add_suffix(name, suffix, num):
            return '{}_{}_{}'.format(name, suffix, num)

        ########################################################################
        # Prepare output network message
        net = nnabla_pb2.NNablaProtoBuf().network.add()
        net.CopyFrom(network)

        ########################################################################
        # Finish when repeat_info is not present
        if len(net.repeat_info) == 0:
            return net

        ########################################################################
        # Use first repeat_info
        ri = net.repeat_info[0]
        del net.repeat_info[0]

        ########################################################################
        # Expand variables
        net.ClearField('variable')
        for vpos, var in enumerate(network.variable):
            if var.type == 'Parameter':
                if var.name not in self._parameter_original_names:
                    self._parameter_original_names[var.name] = []

            pos = _search_repeat_id(var, ri.id)
            if pos is not None:
                for i in range(ri.times):

                    if var.type == 'Parameter':
                        if self._old_version_param_name:
                            name = _add_suffix(var.name, ri.id, i)
                        else:
                            name = var.name.replace('{{{}}}'.format(ri.id),
                                                    '_{}'.format(i))
                        self._parameter_original_names[var.name].append(name)
                    else:
                        name = _add_suffix(var.name, ri.id, i)

                    v = net.variable.add()
                    v.CopyFrom(var)
                    v.name = name
                    del v.repeat_id[pos]
            else:
                if var.type == 'Parameter' and len(var.repeat_id) == 0 and len(
                        self._parameter_original_names[var.name]) == 0:
                    self._parameter_original_names[var.name].append(var.name)
                v = net.variable.add()
                v.CopyFrom(var)

        ########################################################################
        # Expand functions
        ########################################################################

        ########################################################################
        # Prepare delayed inputs
        delay_var = {}
        for fpos, func in enumerate(network.function):
            if func.type == 'Delay':
                if func.recurrent_param.repeat_id == ri.id:
                    delay_var[func.output[0]] = []
                    for i in range(ri.times):
                        if i == 0:
                            delay_var[func.output[0]].append(func.input[1])
                        else:
                            v = func.input[0]
                            v = _add_suffix(v, ri.id, i - 1)
                            delay_var[func.output[0]].append(v)

        ########################################################################
        # Prepare repeat end inputs
        repeat_end_var = {}
        for fpos, func in enumerate(network.function):
            if func.type == 'RepeatEnd':
                if func.repeat_param.repeat_id == ri.id:
                    repeat_end_var[func.output[0]] = []
                    for i in range(func.repeat_param.times):
                        repeat_end_var[func.output[0]].append(
                            _add_suffix(func.input[0],
                                        func.repeat_param.repeat_id, i))

        ########################################################################
        # Prepare repeat start inputs
        repeat_start_var = {}
        for fpos, func in enumerate(network.function):
            if func.type == 'RepeatStart':
                if func.repeat_param.repeat_id == ri.id:
                    repeat_start_var[func.output[0]] = []
                    for i in range(ri.times):
                        if i == 0:
                            v = func.input[0]
                            if v in repeat_end_var:
                                v = repeat_end_var[v][ri.times - 1]
                            repeat_start_var[func.output[0]].append(v)
                        else:
                            v = func.input[1]
                            if v in repeat_end_var:
                                v = repeat_end_var[v][i - 1]
                            else:
                                v = _add_suffix(v, ri.id, i - 1)
                            repeat_start_var[func.output[0]].append(v)

        ########################################################################
        # Expand network
        net.ClearField('function')
        for fpos, func in enumerate(network.function):
            if func.type == 'RepeatStart' or func.type == 'RepeatEnd':
                if func.repeat_param.repeat_id == ri.id:
                    continue
            if func.type == 'Delay':
                if func.recurrent_param.repeat_id == ri.id:
                    continue
            if func.type == 'RecurrentInput':
                if func.recurrent_param.repeat_id == ri.id:

                    f = net.function.add()
                    f.CopyFrom(func)
                    f.type = 'Split'
                    f.split_param.axis = func.recurrent_param.axis

                    f.ClearField('output')
                    for i in range(ri.times):
                        f.output.append(_add_suffix(func.output[0], ri.id, i))

                    pos = _search_repeat_id(func, ri.id)
                    del f.repeat_id[pos]
                    f.ClearField('recurrent_param')
                    continue

            if func.type == 'RecurrentOutput':
                if func.recurrent_param.repeat_id == ri.id:
                    f = net.function.add()
                    f.CopyFrom(func)
                    f.type = 'Stack'
                    f.stack_param.axis = func.recurrent_param.axis

                    f.ClearField('input')
                    for i in range(ri.times):
                        f.input.append(_add_suffix(func.input[0], ri.id, i))

                    f.ClearField('recurrent_param')
                    continue

            pos = _search_repeat_id(func, ri.id)
            if pos is not None:

                for i in range(ri.times):

                    f = net.function.add()
                    f.CopyFrom(func)

                    del f.repeat_id[pos]

                    f.name = _add_suffix(func.name, ri.id, i)
                    for n, v in enumerate(func.input):
                        vname = None
                        if v in self._parameter_original_names:
                            if len(self._parameter_original_names[v]
                                   ) == ri.times:
                                vname = self._parameter_original_names[v][i]
                            else:
                                vname = v
                        elif v in repeat_start_var:
                            vname = repeat_start_var[v][i]
                        elif v in repeat_end_var:
                            vname = repeat_end_var[v][i]
                        elif v in delay_var:
                            vname = delay_var[v][i]
                        else:
                            vname = _add_suffix(v, ri.id, i)
                        f.input[n] = vname
                    for n, v in enumerate(func.output):
                        vname = _add_suffix(v, ri.id, i)
                        f.output[n] = vname

            else:
                f = net.function.add()
                f.CopyFrom(func)
                for n, v in enumerate(func.input):
                    if v in repeat_end_var:
                        vname = repeat_end_var[v][ri.times - 1]
                        f.input[n] = vname

        return self._expand_repeat(net)
Example #30
def load(filenames,
         prepare_data_iterator=True,
         batch_size=None,
         exclude_parameter=False,
         parameter_only=False,
         extension=".nntxt"):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames, or a single file-like object.
        extension: If filenames is a file-like object, extension specifies its format and is one of ".nntxt", ".prototxt", ".protobuf", ".h5" or ".nnp".
    Returns:
        dict: Network information.
    '''
    class Info:
        pass

    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}
    tmpdir = tempfile.mkdtemp()

    if isinstance(filenames, list) or isinstance(filenames, tuple):
        pass
    elif isinstance(filenames, str) or hasattr(filenames, 'read'):
        filenames = [filenames]

    for filename in filenames:
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            ext = extension

        # TODO: There are some known problems here.
        #   - Even when a protobuf file includes the network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')

        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    _, ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO currently do nothing with version.
                    elif ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=ext)
                    elif ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=ext)
                        else:
                            logger.info('Skip loading parameter.')
                    elif ext in OPTI_BUF_EXT:
                        buf_type = get_buf_type(name)
                        if buf_type == 'protobuf':
                            with nnp.open(name, 'r') as f:
                                with get_file_handle_load(
                                        f, '.protobuf') as opti_p:
                                    opti_proto.MergeFromString(opti_p.read())
                        elif buf_type == 'h5':
                            nnp.extract(name, tmpdir)
                            opti_h5_files[name] = os.path.join(tmpdir, name)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except:
                pass
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except:
            logger.warn('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto, prepare_data_iterator if prepare_data_iterator is not None else
        info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)
    _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    shutil.rmtree(tmpdir)

    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
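
A usage sketch for the file-like path of this newer loader (the file name is hypothetical; it assumes get_file_handle_load accepts an open .nnp stream, as the docstring above implies):

from nnabla.utils import load

with open('model.nnp', 'rb') as f:
    info = load.load(f, extension='.nnp', batch_size=1)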