def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt"):
    '''Load network information from files.

    Args:
        filenames (str, file-like object, list or tuple): A single filename,
            a single file-like object (anything with a ``read`` attribute),
            or a list/tuple of those. A single item is wrapped in a list.
        prepare_data_iterator (bool or None): Forwarded to ``_datasets``.
            When ``None``, ``info.training_config.max_epoch > 0`` is used
            instead (which requires the proto to have a training_config).
        batch_size: Forwarded to ``_networks``.
        exclude_parameter (bool): If true, parameters are not loaded.
        parameter_only (bool): If true, network definition text
            (".nntxt"/".prototxt") is not merged into the proto.
        extension (str): Extension assumed for entries of ``filenames`` that
            are not ``str`` (file-like objects). One of ".nntxt",
            ".prototxt", ".protobuf", ".h5", ".nnp".

    Returns:
        Info: A plain attribute container holding ``datasets``, ``networks``,
        ``optimizers``, ``monitors`` and ``executors``, plus
        ``global_config`` / ``training_config`` when the proto has them.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # Optimizer checkpoint data gathered from '.optimizer' archive members.
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}

    text_exts = ('.nntxt', '.prototxt')
    param_exts = ('.protobuf', '.h5')

    # Accept a single filename / file-like object as well as a sequence.
    if not isinstance(filenames, (list, tuple)):
        if isinstance(filenames, str) or hasattr(filenames, 'read'):
            filenames = [filenames]

    # Scratch directory for h5 optimizer checkpoints extracted from .nnp
    # archives. The try/finally guarantees it is removed even when loading
    # fails part-way through.
    tmpdir = tempfile.mkdtemp()
    try:
        for filename in filenames:
            if isinstance(filename, str):
                _, ext = os.path.splitext(filename)
            else:
                ext = extension

            # TODO: Here is some known problems.
            #   - Even when protobuf file includes network structure,
            #     it will not loaded.
            #   - Even when prototxt file includes parameter,
            #     it will not loaded.

            if ext in text_exts:
                if not parameter_only:
                    with get_file_handle_load(filename, ext) as f:
                        try:
                            text_format.Merge(f.read(), proto)
                        except Exception:
                            logger.critical(
                                'Failed to read {}.'.format(filename))
                            logger.critical(
                                '2 byte characters may be used for file name or folder name.'
                            )
                            raise
                if len(proto.parameter) > 0 and not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
            elif ext in param_exts:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
                else:
                    logger.info('Skip loading parameter.')
            elif ext == '.nnp':
                with get_file_handle_load(filename, ext) as nnp:
                    for name in nnp.namelist():
                        # Separate name so the archive's own `ext` is not
                        # shadowed by the member's extension.
                        _, member_ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif member_ext in text_exts:
                            if not parameter_only:
                                with nnp.open(name, 'r') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0 and not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(
                                        f, extension=member_ext)
                        elif member_ext in param_exts:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(
                                        f, extension=member_ext)
                            else:
                                logger.info('Skip loading parameter.')
                        elif member_ext in OPTI_BUF_EXT:
                            buf_type = get_buf_type(name)
                            if buf_type == 'protobuf':
                                with nnp.open(name, 'r') as f:
                                    with get_file_handle_load(
                                            f, '.protobuf') as opti_p:
                                        opti_proto.MergeFromString(
                                            opti_p.read())
                            elif buf_type == 'h5':
                                # Only extracted here; the path is handed to
                                # _load_optimizer_checkpoint below.
                                nnp.extract(name, tmpdir)
                                opti_h5_files[name] = os.path.join(
                                    tmpdir, name)

        default_context = None
        if proto.HasField('global_config'):
            info.global_config = _global_config(proto)
            default_context = info.global_config.default_context
            if 'cuda' in default_context.backend:
                import nnabla_ext.cudnn
            elif 'cuda:float' in default_context.backend:
                try:
                    import nnabla_ext.cudnn
                except Exception:
                    # Best effort: the probe below falls back to CPU if the
                    # configured context turns out to be unusable.
                    pass
            # Probe the configured context with a trivial ReLU; any failure
            # means we fall back to the CPU context.
            try:
                x = nn.Variable()
                y = nn.Variable()
                func = F.ReLU(default_context, inplace=True)
                func.setup([x], [y])
                func.forward([x], [y])
            except Exception:
                logger.warning('Fallback to CPU context.')
                import nnabla_ext.cpu
                default_context = nnabla_ext.cpu.context()
        else:
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()

        # When a communicator is active, select the device by its local rank.
        comm = current_communicator()
        if comm:
            default_context.device_id = str(comm.local_rank)
        if proto.HasField('training_config'):
            info.training_config = _training_config(proto)

        if prepare_data_iterator is not None:
            prepare = prepare_data_iterator
        else:
            prepare = info.training_config.max_epoch > 0
        info.datasets = _datasets(proto, prepare)

        info.networks = _networks(proto, default_context, batch_size)

        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)
        _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    finally:
        shutil.rmtree(tmpdir)

    info.monitors = _monitors(
        proto, default_context, info.networks, info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False):
    '''Load network information from files.

    Args:
        filenames (list or str): List of filenames. A single filename given
            as ``str`` is treated as a one-element list.
        prepare_data_iterator (bool or None): Forwarded to ``_datasets``.
            When ``None``, ``info.training_config.max_epoch > 0`` is used
            instead (which requires the proto to have a training_config).
        batch_size: Forwarded to ``_networks``.
        exclude_parameter (bool): If true, parameters are not loaded.
        parameter_only (bool): If true, network definition text
            (".nntxt"/".prototxt") is not merged into the proto.

    Returns:
        Info: A plain attribute container holding ``datasets``, ``networks``,
        ``optimizers``, ``monitors`` and ``executors``, plus
        ``global_config`` / ``training_config`` when the proto has them.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    text_exts = ('.nntxt', '.prototxt')
    param_exts = ('.protobuf', '.h5')

    # A bare string would otherwise be iterated character by character.
    if isinstance(filenames, str):
        filenames = [filenames]

    for filename in filenames:
        _, ext = os.path.splitext(filename)

        # TODO: Here is some known problems.
        #   - Even when protobuf file includes network structure,
        #     it will not loaded.
        #   - Even when prototxt file includes parameter,
        #     it will not loaded.

        if ext in text_exts:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except Exception:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.'
                        )
                        raise
            if len(proto.parameter) > 0 and not exclude_parameter:
                nn.load_parameters(filename)
        elif ext in param_exts:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')
        elif ext == '.nnp':
            # Create the scratch directory before entering the try block so
            # the finally clause never sees an unbound or stale `tmpdir`.
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, member_ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            pass  # TODO currently do nothing with version.
                        elif member_ext in text_exts:
                            nnp.extract(name, tmpdir)
                            extracted = os.path.join(tmpdir, name)
                            if not parameter_only:
                                with open(extracted, 'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0 and not exclude_parameter:
                                nn.load_parameters(extracted)
                        elif member_ext in param_exts:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except Exception:
                # Best effort: the probe below falls back to CPU if the
                # configured context turns out to be unusable.
                pass
        # Probe the configured context with a trivial ReLU; any failure
        # means we fall back to the CPU context.
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except Exception:
            logger.warning('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    # When a communicator is active, select the device by the local rank
    # rather than the global rank.
    # NOTE(review): assumes the communicator exposes `local_rank` - confirm.
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    if prepare_data_iterator is not None:
        prepare = prepare_data_iterator
    else:
        prepare = info.training_config.max_epoch > 0
    info.datasets = _datasets(proto, prepare)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(
        proto, default_context, info.networks, info.datasets)

    info.monitors = _monitors(
        proto, default_context, info.networks, info.datasets)

    info.executors = _executors(proto, info.networks)

    return info