def _export_files(self, outdir):
    with open('{}/nnp_version.txt'.format(outdir), 'w') as f:
        f.write('0.1\n')
    if self._parameter_type == 'included':
        self._write_nntxt('{}/network.nntxt'.format(outdir), self._nnp)
    else:
        nnp_wo_parameter = nnabla_pb2.NNablaProtoBuf()
        nnp_wo_parameter.CopyFrom(self._nnp)
        nnp_wo_parameter.ClearField('parameter')
        self._write_nntxt('{}/network.nntxt'.format(outdir), nnp_wo_parameter)
        if self._parameter_type == 'protobuf':
            nnp_parameter_only = nnabla_pb2.NNablaProtoBuf()
            for param in self._nnp.parameter:
                parameter = nnp_parameter_only.parameter.add()
                parameter.CopyFrom(param)
            self._write_protobuf('{}/parameter.protobuf'.format(outdir),
                                 nnp_parameter_only)
        elif self._parameter_type == 'h5':
            self._write_h5('{}/parameter.h5'.format(outdir), self._nnp)
        elif self._parameter_type == 'none':
            pass  # Store without parameters.
        else:
            print('Unsupported parameter type `{}`.'.format(
                self._parameter_type))
def load_parameters(path):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
    """
    _, ext = os.path.splitext(path)
    if ext == '.h5':
        # TODO: temporary workaround to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
        set_parameter_from_proto(proto)
    elif ext == '.nntxt' or ext == '.prototxt':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'r') as f:
            text_format.Merge(f.read(), proto)
        set_parameter_from_proto(proto)
    elif ext == '.nnp':
        tmpdir = tempfile.mkdtemp()
        try:
            with zipfile.ZipFile(path, 'r') as nnp:
                for name in nnp.namelist():
                    nnp.extract(name, tmpdir)
                    _, ext = os.path.splitext(name)
                    if ext in ['.protobuf', '.h5']:
                        load_parameters(os.path.join(tmpdir, name))
        finally:
            shutil.rmtree(tmpdir)
    logger.info("Parameter load ({}): {}".format(ext, path))
def save_parameters(path):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
        path : path or file object
    """
    _, ext = os.path.splitext(path)
    params = get_parameters(grad_only=False)
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'w') as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad
        with open(path, "wb") as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only hdf5 and protobuf formats are supported.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
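# Usage sketch for save_parameters() above; assumes nnabla is importable and
# the global parameter scope is already populated. The layer and file names
# are illustrative only.
import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((1, 3, 8, 8))
h = PF.convolution(x, 4, (3, 3), name='conv1')  # registers conv1/conv/{W, b}

save_parameters('params.h5')        # hdf5: one dataset per parameter
save_parameters('params.protobuf')  # protobuf: serialized NNablaProtoBuf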
def _opti_file_loader(ctx, file_loaders, nnp, filename, ext):
    '''Loader for ".optimizer" files.

    This loader only handles .optimizer files.
    '''
    file_type = get_buf_type(filename)
    if file_type == 'protobuf':
        opti_proto = nnabla_pb2.NNablaProtoBuf()
        with get_file_handle_load(nnp, filename, '.protobuf') as f:
            opti_proto.MergeFromString(f.read())
        for p_opti in opti_proto.optimizer:
            o = ctx.optimizers.get(p_opti.name, None)
            if o:
                o.solver.set_states_from_protobuf(p_opti)
            else:
                logger.warn(
                    'No matched optimizer is found for {}.'.format(filename))
    elif file_type == 'h5':
        loaded = False
        for o in ctx.optimizers.values():
            key = '{}_{}_optimizer.h5.optimizer'.format(
                o.name,
                re.sub(r'(|Cuda)$', '', str(o.solver.name)))
            if key == filename:
                o.solver.set_states(_load_solve_state_from_h5(nnp, filename))
                loaded = True
        if not loaded:
            logger.warn(
                'No matched optimizer is found for {}.'.format(filename))
def _load_nnp_to_proto(nnp_path):
    import google.protobuf.text_format as text_format
    import tempfile
    import zipfile
    import shutil
    proto = nnabla_pb2.NNablaProtoBuf()
    tmpdir = tempfile.mkdtemp()
    try:
        with zipfile.ZipFile(nnp_path, "r") as nnp:
            for name in nnp.namelist():
                _, ext = os.path.splitext(name)
                if name == "nnp_version.txt":
                    pass  # Currently nnp_version.txt is ignored.
                elif ext in [".nntxt", ".prototxt"]:
                    nnp.extract(name, tmpdir)
                    with open(os.path.join(tmpdir, name), "rt") as f:
                        text_format.Merge(f.read(), proto)
                elif ext in [".protobuf", ".h5"]:
                    nnp.extract(name, tmpdir)
                    nn.load_parameters(os.path.join(tmpdir, name))
    finally:
        shutil.rmtree(tmpdir)
    return proto
def _pb_parameter_file_loader(ctx, file_loaders, nnp, filename, ext):
    with get_file_handle_load(nnp, filename, ext) as f:
        # Create the proto lazily on first use.
        try:
            ctx.proto
        except AttributeError:
            ctx.proto = nnabla_pb2.NNablaProtoBuf()
        ctx.proto.MergeFromString(f.read())
        nn.parameter.set_parameter_from_proto(ctx.proto)
def _load_nntxt_to_proto(nntxt_path):
    import google.protobuf.text_format as text_format
    proto = nnabla_pb2.NNablaProtoBuf()
    with open(nntxt_path, "rt") as f:
        text_format.Merge(f.read(), proto)
    return proto
def execute(self):
    self._nnp = nnabla_pb2.NNablaProtoBuf()
    other_files = []
    for ifile in self._args:
        print('Importing {}'.format(ifile))
        ext = os.path.splitext(ifile)[1].lower()
        if ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(ifile, 'r') as nnpzip:
                    # Network definitions
                    for name in nnpzip.namelist():
                        if os.path.splitext(name)[1].lower() in [
                                '.nntxt', '.prototxt']:
                            nnpzip.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                text_format.Merge(f.read(), self._nnp)
                    # Parameters
                    for name in nnpzip.namelist():
                        if os.path.splitext(name)[1].lower() in [
                                '.protobuf', '.h5']:
                            nnpzip.extract(name, tmpdir)
                            self.load_parameters(os.path.join(tmpdir, name))
            finally:
                shutil.rmtree(tmpdir)
        elif ext in ['.nntxt', '.prototxt']:
            with open(ifile, 'rt') as f:
                text_format.Merge(f.read(), self._nnp)
        elif ext in ['.protobuf', '.h5']:
            self.load_parameters(ifile)
        else:
            other_files.append(ifile)
    executor_name = self._nnp.executor[0].network_name
    network = self.find_network(executor_name)
    parameter_variable_list = self.find_parameter_variable(network)
    if parameter_variable_list and not self._nnp.parameter:
        self.generate_parameters_data(parameter_variable_list,
                                      network.batch_size)
    if self._executor_index is not None:
        if self._executor_index < len(self._nnp.executor):
            self._nnp = self._shrink_with_executor(
                self._nnp.executor[self._executor_index])
    if self._expand_network:
        self._nnp = expander.NnpExpander(self._nnp).execute()

    class nnp:
        pass

    nnp.protobuf = self._nnp
    nnp.other_files = other_files
    return nnp
def _protobuf_parameter_file_saver(ctx, filename, ext):
    proto = nnabla_pb2.NNablaProtoBuf()
    for variable_name, variable in ctx.parameters.items():
        parameter = proto.parameter.add()
        parameter.variable_name = variable_name
        parameter.shape.dim.extend(variable.shape)
        parameter.data.extend(numpy.array(variable.d).flatten().tolist())
        parameter.need_grad = variable.need_grad
    with get_file_handle_save(filename, ext) as f:
        f.write(proto.SerializeToString())
def _load_nntxt_to_proto(nntxt_path):
    import google.protobuf.text_format as text_format
    proto = nnabla_pb2.NNablaProtoBuf()
    if hasattr(nntxt_path, 'read'):
        nntxt = nntxt_path.read()
    else:
        with open(nntxt_path, "rt") as f:
            nntxt = f.read()
    text_format.Merge(nntxt, proto)
    return proto
def execute(self):
    nnp = nnabla_pb2.NNablaProtoBuf()
    nnp.CopyFrom(self._nnp)
    for network in nnp.network:
        self._expand_network(network)
    for optimizer in nnp.optimizer:
        self._expand_parameter_variable(optimizer)
    for executor in nnp.executor:
        self._expand_parameter_variable(executor)
    return nnp
def load(filenames, prepare_data_iterator=True):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''

    class Info:
        pass

    info = Info()
    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)
        if 'txt' in ext:
            with open(filename, 'rt') as f:
                text_format.Merge(f.read(), proto)
        elif ext in ['.protobuf', '.h5']:
            nn.load_parameters(filename, proto)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
    else:
        default_context = nn.context()
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)
    if len(proto.dataset) > 0:
        info.datasets = _datasets(proto, prepare_data_iterator)
    if len(proto.network) > 0:
        info.networks = _networks(proto, default_context)
    if len(proto.optimizer) > 0:
        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)
    if len(proto.monitor) > 0:
        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)
    if len(proto.executor) > 0:
        info.executors = _executors(proto, info.networks)
    return info
def load_parameters(path):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
    """
    _, ext = os.path.splitext(path)
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
        for parameter in proto.parameter:
            var = get_parameter_or_create(parameter.variable_name,
                                          parameter.shape.dim)
            param = numpy.reshape(parameter.data, parameter.shape.dim)
            var.d = param
            var.need_grad = parameter.need_grad
    elif ext == '.nnp':
        tmpdir = tempfile.mkdtemp()
        try:
            with zipfile.ZipFile(path, 'r') as nnp:
                for name in nnp.namelist():
                    nnp.extract(name, tmpdir)
                    _, ext = os.path.splitext(name)
                    if ext in ['.protobuf', '.h5']:
                        load_parameters(os.path.join(tmpdir, name))
        finally:
            shutil.rmtree(tmpdir)
    logger.info("Parameter load ({}): {}".format(ext, path))
def execute(self):
    nnp = nnabla_pb2.NNablaProtoBuf()
    nnp.CopyFrom(self._nnp)
    nnp.ClearField('network')
    for network in self._nnp.network:
        net = nnp.network.add()
        net.CopyFrom(self._expand_network(network))
    for optimizer in nnp.optimizer:
        self._expand_parameter_variable(optimizer)
    for executor in nnp.executor:
        self._expand_parameter_variable(executor)
    return nnp
def load_parameters(path, proto=None):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
    """
    _, ext = os.path.splitext(path)
    if proto is None:
        proto = nnabla_pb2.NNablaProtoBuf()
    if ext == '.h5':
        import h5py
        with h5py.File(path, 'r') as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
                parameter = proto.parameter.add()
                parameter.variable_name = key
                parameter.shape.dim.extend(var.shape)
                parameter.data.extend(numpy.array(var.d).flatten().tolist())
                parameter.need_grad = var.need_grad
    elif ext == '.protobuf':
        with open(path, 'rb') as f:
            proto.MergeFromString(f.read())
        for parameter in proto.parameter:
            var = get_parameter_or_create(parameter.variable_name,
                                          parameter.shape.dim)
            param = numpy.reshape(parameter.data, parameter.shape.dim)
            var.d = param
            var.need_grad = parameter.need_grad
    logger.info("Parameter load ({}): {}".format(ext, path))
    return proto
def execute(self):
    self._nnp = nnabla_pb2.NNablaProtoBuf()
    other_files = []
    for ifile in self._args:
        print('Importing {}'.format(ifile))
        ext = os.path.splitext(ifile)[1].lower()
        if ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(ifile, 'r') as nnpzip:
                    # Network definitions
                    for name in nnpzip.namelist():
                        if os.path.splitext(name)[1].lower() in [
                                '.nntxt', '.prototxt']:
                            nnpzip.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                text_format.Merge(f.read(), self._nnp)
                    # Parameters
                    for name in nnpzip.namelist():
                        if os.path.splitext(name)[1].lower() in [
                                '.protobuf', '.h5']:
                            nnpzip.extract(name, tmpdir)
                            self.load_parameters(os.path.join(tmpdir, name))
            finally:
                shutil.rmtree(tmpdir)
        elif ext in ['.nntxt', '.prototxt']:
            with open(ifile, 'rt') as f:
                text_format.Merge(f.read(), self._nnp)
        elif ext in ['.protobuf', '.h5']:
            self.load_parameters(ifile)
        else:
            other_files.append(ifile)
    if self._expand_network:
        self._nnp = expander.NnpExpander(self._nnp).execute()

    class nnp:
        pass

    nnp.protobuf = self._nnp
    nnp.other_files = other_files
    return nnp
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    ctx = FileHandlerContext()
    if proto is None:
        ctx.proto = nnabla_pb2.NNablaProtoBuf()
    else:
        ctx.proto = proto
    ctx.needs_proto = needs_proto
    # Get parameter file loaders
    file_loaders = get_parameter_file_loader()
    load_files(ctx, file_loaders, path, ext)
    return ctx.proto
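# Usage sketch for load_parameters() above; assumes the 'params.h5' file saved
# earlier exists. With needs_proto=True every restored parameter is also
# mirrored into the returned NNablaProtoBuf.
import nnabla as nn

nn.clear_parameters()
proto = load_parameters('params.h5', needs_proto=True)
print(len(proto.parameter), 'parameters restored')
for name, v in nn.get_parameters(grad_only=False).items():
    print(name, v.shape, v.need_grad)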
def _expand_network(self, network):
    print(' Expanding {}.'.format(network.name))
    repeat_ids = collections.OrderedDict()
    for ri in network.repeat_info:
        repeat_ids[ri.id] = ri.times
    # Check whether parameter names comply with the old naming rule.
    self._old_version_param_name = False
    for param in self._parameters:
        for ri in repeat_ids:
            m = re.search(r'{}\[([0-9]+)\]$'.format(ri), param)
            if m:
                if int(m.group(1)) < repeat_ids[ri]:
                    self._old_version_param_name = True
    # Expand repeat
    network = self._expand_repeat(network)
    functions = []
    for func in network.function:
        functions.append((func.name, func.type,
                          [n for n in func.input],
                          [n for n in func.output]))
    sorted_functions = self._sort_functions(functions)
    func_list = []
    for f in functions:
        func_list.append(f[0])
    net = nnabla_pb2.NNablaProtoBuf().network.add()
    net.CopyFrom(network)
    net.ClearField('function')
    for f in sorted_functions:
        func = net.function.add()
        func.CopyFrom(network.function[func_list.index(f[0])])
    return net
def save_parameters(path, params=None, extension=None):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
        path : path or file object
        params (dict, optional): Parameters to be saved. Dictionary mapping
            a parameter name (:obj:`str`) to a :obj:`~nnabla.Variable`.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    params = get_parameters(grad_only=False) if params is None else params
    if ext == '.h5':
        # TODO: temporary workaround to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_save(path, ext) as hd:
            for i, (k, v) in enumerate(iteritems(params)):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad
        with get_file_handle_save(path, ext) as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only hdf5 and protobuf formats are supported.')
        assert False
    logger.info("Parameter save ({}): {}".format(ext, path))
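# Sketch of the `params` argument above: save only a chosen subset rather
# than the whole scope (file name illustrative).
import nnabla as nn

trainable = nn.get_parameters(grad_only=True)
save_parameters('trainable.protobuf', params=trainable)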
def save_optimizer_states(filebase, ext, train_config):
    filelist = []
    if ext == '.protobuf':
        filename = filebase + '_optimizer.protobuf.optimizer'
        proto = nnabla_pb2.NNablaProtoBuf()
        proto_optimizers = []
        for o in train_config.optimizers.values():
            proto_optimizers.append(_create_optimizer_lite(o.optimizer))
        proto.optimizer.extend(proto_optimizers)
        with get_file_handle_save(filename, '.protobuf') as f:
            f.write(proto.SerializeToString())
        filelist.append(filename)
    else:
        for o in train_config.optimizers.values():
            f_name = '{}_{}_optimizer.h5'.format(
                o.optimizer.name,
                re.sub(r'(|Cuda)$', '', str(o.optimizer.solver.name)))
            filename = '{}_{}'.format(filebase, f_name)
            o.optimizer.solver.save_states(filename)
            name_ext = '{}.optimizer'.format(filename)
            os.rename(filename, name_ext)
            filelist.append(name_ext)
    return filelist
def _shrink_with_executor(self, executor):
    print(' Try to leave only executor[{}].'.format(executor.name))
    network = None
    for n in self._nnp.network:
        if n.name == executor.network_name:
            network = n
    if network is None:
        return None
    nnp = nnabla_pb2.NNablaProtoBuf()
    nnp.CopyFrom(self._nnp)
    nnp.ClearField('optimizer')
    nnp.ClearField('monitor')
    nnp.ClearField('network')
    net = nnp.network.add()
    net.CopyFrom(network)
    nnp.ClearField('executor')
    exe = nnp.executor.add()
    exe.CopyFrom(executor)
    return nnp
def create_function_nnp(inputs, outputs, func_name, func_args, func_kwargs):
    if func_name is None:
        return
    for category_name, category in nnabla.utils.converter.get_category_info(
    ).items():
        if func_name in category:
            function = category[func_name]
    nnp = nnabla_pb2.NNablaProtoBuf()
    net = nnp.network.add()
    net.name = 'network1'
    net.batch_size = 1
    func = net.function.add()
    func.name = func_name
    func.type = func_name

    # Prepare input
    func_inputs = []
    data_names = []
    parameter_names = []
    input_data = []
    for n, i in enumerate(inputs):
        if i is not None:
            if len(list(function['inputs'].items())) == 1:
                input_name, input_info = list(function['inputs'].items())[0]
                if 'variadic' in input_info and input_info['variadic']:
                    input_name += str(n)
            else:
                input_name, input_info = list(function['inputs'].items())[n]
            func_inputs.append(input_name)
            var = net.variable.add()
            var.name = input_name
            if 'parameter' in input_info and input_info['parameter']:
                parameter_names.append(input_name)
                var.type = 'Parameter'
                shape = list(i.d.shape)[:]
                if func.name == 'BatchNormalization':
                    shape = [1] + shape
                var.shape.dim.extend(shape)
                param = nnp.parameter.add()
                param.variable_name = input_name
                param.shape.dim.extend(shape)
                param.data.extend(i.d.flatten())
            else:
                input_data.append(i.d.flatten())
                data_names.append(input_name)
                var.type = 'Buffer'
                shape = list(i.d.shape)[:]
                # Exclude the cases where no dimension extension is needed.
                if input_name == 'rmean' or input_name == 't':
                    pass
                elif func.name == 'PReLU' and input_name == 'x1':
                    pass
                elif func.name == 'Transpose':
                    pass
                elif func.name == 'Concatenate':
                    pass
                else:
                    shape = [1] + shape
                var.shape.dim.extend(shape)
    func.input.extend(func_inputs)

    # Prepare output
    func_outputs = []
    output_data = []
    for n, o in enumerate(outputs):
        output_name = 'y{}'.format(n)
        func_outputs.append(output_name)
        var = net.variable.add()
        var.name = output_name
        var.type = 'Buffer'
        shape = list(o.d.shape)[:]
        shape = [-1] + shape
        var.shape.dim.extend(shape)
        output_data.append(o.d.flatten())
    func.output.extend(func_outputs)

    # Prepare arguments; the proto field is picked dynamically via eval/exec.
    if 'arguments' in function:
        for n, (arg_name, arg) in enumerate(function['arguments'].items()):
            param = eval('func.{}_param'.format(function['snake_name']))
            if not func_args and not func_kwargs:
                continue
            if func.name == 'Interpolate':
                del func_args[0]
            if n < len(func_args):
                a = func_args[n]
            else:
                if func.name == 'Concatenate' or func.name == 'Stack':
                    a = func_kwargs['axis']
                else:
                    a = func_kwargs.get('keepdims')
            # This is used to fix the problem of flip (axes == None).
            if a is None:
                f = ['Sum', 'Mean', 'Max', 'Min', 'Prod']
                if 'axes' in arg_name:
                    if func.name in f:
                        a = net.variable[0].shape.dim[:-1]
                        a = [x - 1 for x in a]
                    else:
                        a = len(net.variable[0].shape.dim) - 2
            if a is not None:
                if 'axis' == arg_name:
                    if func.name == 'Concatenate':
                        pass
                    else:
                        a += 1
                if 'axes' in arg_name:
                    if func.name == 'Transpose':
                        pass
                    else:
                        if isinstance(a, tuple) or isinstance(a, list):
                            a = list(a)
                        else:
                            a = [a]
                        a = [x + 1 for x in a]
                if isinstance(a, tuple) or isinstance(a, list):
                    if arg['type'] == 'Shape':
                        exec('param.{}.dim.extend(list(a))'.format(arg_name))
                    else:
                        exec('param.{}.extend(a)'.format(arg_name))
                elif isinstance(a, numpy.ndarray):
                    a = a.flatten()
                    if arg['type'] == 'Shape':
                        if function['snake_name'] == 'broadcast':
                            exec('param.{}.dim.extend([1] + list(a))'.format(
                                arg_name))
                        else:
                            exec('param.{}.dim.extend(list(a))'.format(
                                arg_name))
                    else:
                        exec('param.{}.extend(a)'.format(arg_name))
                else:
                    if 'repeated' in arg['type']:
                        exec('param.{}.extend([a])'.format(arg_name))
                    elif arg['type'] == 'string':
                        exec('param.{} = "{}"'.format(arg_name, a))
                    else:
                        if arg_name == 'base_axis':
                            a = a + 1
                        exec('param.{} = {}'.format(arg_name, a))

    # Prepare executor
    exe = nnp.executor.add()
    exe.name = 'inference'
    exe.network_name = 'network1'
    for d in data_names:
        dat = exe.data_variable.add()
        dat.variable_name = d
        dat.data_name = d
    for o in func_outputs:
        out = exe.output_variable.add()
        out.variable_name = o
        out.data_name = o
    for p in parameter_names:
        par = exe.parameter_variable.add()
        par.variable_name = p
    return nnp, input_data, output_data
def load(filenames, prepare_data_iterator=True, batch_size=None):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''

    class Info:
        pass

    info = Info()
    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)
        # TODO: Here are some known problems.
        #   - Even when a protobuf file includes a network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.
        if ext in ['.nntxt', '.prototxt']:
            with open(filename, 'rt') as f:
                text_format.Merge(f.read(), proto)
        elif ext in ['.protobuf', '.h5']:
            nn.load_parameters(filename)
        elif ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # Currently nnp_version.txt is ignored.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                text_format.Merge(f.read(), proto)
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            nn.load_parameters(os.path.join(tmpdir, name))
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            try:
                import nnabla_ext.cuda.cudnn
            except ImportError:
                pass
    else:
        default_context = nn.context()
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)
    if len(proto.dataset) > 0:
        info.datasets = _datasets(proto, prepare_data_iterator)
    if len(proto.network) > 0:
        info.networks = _networks(proto, default_context, batch_size)
    if len(proto.optimizer) > 0:
        info.optimizers = _optimizers(
            proto, default_context, info.networks, info.datasets)
    if len(proto.monitor) > 0:
        info.monitors = _monitors(
            proto, default_context, info.networks, info.datasets)
    if len(proto.executor) > 0:
        info.executors = _executors(proto, info.networks)
    return info
def create_proto(contents, include_params=False):
    proto = nnabla_pb2.NNablaProtoBuf()
    if 'global_config' in contents:
        proto.global_config.MergeFrom(
            _create_global_config(
                contents['global_config']['default_context']))
    if 'training_config' in contents:
        proto.training_config.MergeFrom(
            _create_training_config(
                contents['training_config']['max_epoch'],
                contents['training_config']['iter_per_epoch'],
                contents['training_config']['save_best']))
    networks = {}
    if 'networks' in contents:
        proto_nets = []
        for net in contents['networks']:
            networks[net['name']] = _create_network(net)
            proto_nets.append(networks[net['name']])
        proto.network.extend(proto_nets)
    datasets = {}
    if 'datasets' in contents:
        proto_datasets = []
        for d in contents['datasets']:
            if 'cache_dir' in d:
                cache_dir = d['cache_dir']
            else:
                cache_dir = None
            datasets[d['name']] = _create_dataset(
                d['name'], d['uri'], cache_dir, d['variables'],
                d['shuffle'], d['batch_size'], d['no_image_normalization'])
            proto_datasets.append(datasets[d['name']])
        proto.dataset.extend(proto_datasets)
    if 'optimizers' in contents:
        proto_optimizers = []
        for o in contents['optimizers']:
            proto_optimizers.append(
                _create_optimizer(o['name'], o['solver'],
                                  networks[o['network']],
                                  datasets[o['dataset']]))
        proto.optimizer.extend(proto_optimizers)
    if 'monitors' in contents:
        proto_monitors = []
        for m in contents['monitors']:
            proto_monitors.append(
                _create_monitor(m['name'], m['monitor'],
                                networks[m['network']],
                                datasets[m['dataset']]))
        proto.monitor.extend(proto_monitors)
    if 'executors' in contents:
        proto_executors = []
        for e in contents['executors']:
            proto_executors.append(
                _create_executor(e['name'], networks[e['network']],
                                 e['data'], e['output'], e.get('remp', {})))
        proto.executor.extend(proto_executors)
    if include_params is True:
        params = get_parameters(grad_only=False)
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad
    return proto
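# Sketch of the `contents` dict that create_proto() consumes, inferred from
# the keys read above. Values are illustrative placeholders; `network_spec`
# stands in for whatever dict _create_network() expects (it must at least
# carry a 'name' key).
network_spec = {'name': 'net1'}  # hypothetical; real specs carry more fields
contents = {
    'training_config': {'max_epoch': 100, 'iter_per_epoch': 23,
                        'save_best': True},
    'networks': [network_spec],
    'datasets': [{'name': 'train', 'uri': 'train.csv',
                  'variables': ['x', 'y'], 'shuffle': True,
                  'batch_size': 64, 'no_image_normalization': False}],
    'executors': [{'name': 'runtime', 'network': 'net1',
                   'data': ['x'], 'output': ['y']}],
}
proto = create_proto(contents, include_params=True)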
def _shrink_nnp(nnp, pos_start, pos_end):
    if len(nnp.protobuf.executor) != 1 or \
            len(nnp.protobuf.network) != 1:
        print('[ERROR] Please make only one network in nnp.')
        sys.exit(-1)
    from nnabla.utils import nnabla_pb2

    class _nnp:
        pass

    _nnp.protobuf = nnabla_pb2.NNablaProtoBuf()
    _nnp.other_files = nnp.other_files
    net = nnabla_pb2.NNablaProtoBuf().network.add()
    net.CopyFrom(nnp.protobuf.network[0])

    # Shrink network
    variables = {}
    net.ClearField('function')
    for i in range(pos_start, pos_end + 1):
        f = nnp.protobuf.network[0].function[i]
        func = net.function.add()
        func.CopyFrom(f)
        for v in func.input:
            variables[v] = True
        for v in func.output:
            variables[v] = True
    net.ClearField('variable')
    for v in nnp.protobuf.network[0].variable:
        if v.name in variables:
            variables[v.name] = v.type
            var = net.variable.add()
            var.CopyFrom(v)

    # Shrink parameter
    params = []
    for param in nnp.protobuf.parameter:
        if param.variable_name in variables:
            p = nnabla_pb2.NNablaProtoBuf().parameter.add()
            p.CopyFrom(param)
            params.append(p)
    for p in params:
        param = _nnp.protobuf.parameter.add()
        param.CopyFrom(p)

    # Shrink executor
    exe = nnabla_pb2.NNablaProtoBuf().executor.add()
    exe.CopyFrom(nnp.protobuf.executor[0])
    exe.ClearField('data_variable')
    for vname in nnp.protobuf.network[0].function[pos_start].input:
        if variables[vname] == 'Buffer':
            v = exe.data_variable.add()
            v.variable_name = vname
    exe.ClearField('generator_variable')
    for var in nnp.protobuf.executor[0].generator_variable:
        if var.variable_name in variables:
            v = exe.generator_variable.add()
            v.CopyFrom(var)
    exe.ClearField('output_variable')
    for vname in nnp.protobuf.network[0].function[pos_end].output:
        if variables[vname] == 'Buffer':
            v = exe.output_variable.add()
            v.variable_name = vname
    exe.ClearField('parameter_variable')
    for var in nnp.protobuf.executor[0].parameter_variable:
        if var.variable_name in variables:
            v = exe.parameter_variable.add()
            v.CopyFrom(var)

    n = _nnp.protobuf.network.add()
    n.CopyFrom(net)
    e = _nnp.protobuf.executor.add()
    e.CopyFrom(exe)
    return _nnp
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    Args:
        path : path or file object
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    if ext == '.h5':
        # TODO: temporary workaround to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))

            hd.visit(_get_keys)
            for _, key in sorted(keys):
                ds = hd[key]
                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]
                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = False
                    if ds.attrs['need_grad']:
                        parameter.need_grad = True
    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()
        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext == '.nntxt' or ext == '.prototxt':
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)
        elif ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        _, ext = os.path.splitext(name)
                        if ext in ['.protobuf', '.h5']:
                            proto = load_parameters(
                                os.path.join(tmpdir, name), proto,
                                needs_proto)
            finally:
                shutil.rmtree(tmpdir)
            logger.info("Parameter load ({}): {}".format(ext, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto
def proto_from_str(nntxt_str):
    proto = nnabla_pb2.NNablaProtoBuf()
    text_format.Merge(nntxt_str, proto)
    return proto
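# Quick sketch of proto_from_str() on a tiny nntxt snippet; the network shown
# is illustrative and carries no functions or variables.
nntxt = '''
network {
  name: "net1"
  batch_size: 1
}
'''
proto = proto_from_str(nntxt)
print(proto.network[0].name)        # -> net1
print(proto.network[0].batch_size)  # -> 1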
def load(filenames, prepare_data_iterator=True, batch_size=None,
         exclude_parameter=False, parameter_only=False):
    '''load
    Load network information from files.

    Args:
        filenames (list): List of filenames.
    Returns:
        dict: Network information.
    '''

    class Info:
        pass

    info = Info()
    proto = nnabla_pb2.NNablaProtoBuf()
    for filename in filenames:
        _, ext = os.path.splitext(filename)
        # TODO: Here are some known problems.
        #   - Even when a protobuf file includes a network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.
        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with open(filename, 'rt') as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2-byte characters may be used for file or'
                            ' folder names.')
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename)
            else:
                logger.info('Skip loading parameter.')
        elif ext == '.nnp':
            tmpdir = tempfile.mkdtemp()
            try:
                with zipfile.ZipFile(filename, 'r') as nnp:
                    for name in nnp.namelist():
                        _, ext = os.path.splitext(name)
                        if name == 'nnp_version.txt':
                            nnp.extract(name, tmpdir)
                            with open(os.path.join(tmpdir, name), 'rt') as f:
                                pass  # TODO: currently do nothing with version.
                        elif ext in ['.nntxt', '.prototxt']:
                            nnp.extract(name, tmpdir)
                            if not parameter_only:
                                with open(os.path.join(tmpdir, name),
                                          'rt') as f:
                                    text_format.Merge(f.read(), proto)
                            if len(proto.parameter) > 0:
                                if not exclude_parameter:
                                    nn.load_parameters(
                                        os.path.join(tmpdir, name))
                        elif ext in ['.protobuf', '.h5']:
                            nnp.extract(name, tmpdir)
                            if not exclude_parameter:
                                nn.load_parameters(os.path.join(tmpdir, name))
                            else:
                                logger.info('Skip loading parameter.')
            finally:
                shutil.rmtree(tmpdir)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except ImportError:
                pass
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)
    info.datasets = _datasets(
        proto,
        prepare_data_iterator if prepare_data_iterator is not None
        else info.training_config.max_epoch > 0)
    info.networks = _networks(proto, default_context, batch_size)
    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)
    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)
    info.executors = _executors(proto, info.networks)
    return info
def _expand_repeat(self, network):
    def _search_repeat_id(mes, rid):
        return list(
            mes.repeat_id).index(rid) if rid in mes.repeat_id else None

    def _add_suffix(name, suffix, num):
        return '{}_{}_{}'.format(name, suffix, num)

    ########################################################################
    # Prepare output network message
    net = nnabla_pb2.NNablaProtoBuf().network.add()
    net.CopyFrom(network)

    ########################################################################
    # Finish when repeat_info is not present
    if len(net.repeat_info) == 0:
        return net

    ########################################################################
    # Use first repeat_info
    ri = net.repeat_info[0]
    del net.repeat_info[0]

    ########################################################################
    # Expand variables
    net.ClearField('variable')
    for vpos, var in enumerate(network.variable):
        if var.type == 'Parameter':
            if var.name not in self._parameter_original_names:
                self._parameter_original_names[var.name] = []
        pos = _search_repeat_id(var, ri.id)
        if pos is not None:
            for i in range(ri.times):
                if var.type == 'Parameter':
                    if self._old_version_param_name:
                        name = _add_suffix(var.name, ri.id, i)
                    else:
                        name = var.name.replace('{{{}}}'.format(ri.id),
                                                '_{}'.format(i))
                    self._parameter_original_names[var.name].append(name)
                else:
                    name = _add_suffix(var.name, ri.id, i)
                v = net.variable.add()
                v.CopyFrom(var)
                v.name = name
                del v.repeat_id[pos]
        else:
            if var.type == 'Parameter' and len(var.repeat_id) == 0 and len(
                    self._parameter_original_names[var.name]) == 0:
                self._parameter_original_names[var.name].append(var.name)
            v = net.variable.add()
            v.CopyFrom(var)

    ########################################################################
    # Expand functions
    ########################################################################

    ########################################################################
    # Prepare delayed inputs
    delay_var = {}
    for fpos, func in enumerate(network.function):
        if func.type == 'Delay':
            if func.recurrent_param.repeat_id == ri.id:
                delay_var[func.output[0]] = []
                for i in range(ri.times):
                    if i == 0:
                        delay_var[func.output[0]].append(func.input[1])
                    else:
                        v = func.input[0]
                        v = _add_suffix(v, ri.id, i - 1)
                        delay_var[func.output[0]].append(v)

    ########################################################################
    # Prepare repeat end inputs
    repeat_end_var = {}
    for fpos, func in enumerate(network.function):
        if func.type == 'RepeatEnd':
            if func.repeat_param.repeat_id == ri.id:
                repeat_end_var[func.output[0]] = []
                for i in range(func.repeat_param.times):
                    repeat_end_var[func.output[0]].append(
                        _add_suffix(func.input[0],
                                    func.repeat_param.repeat_id, i))

    ########################################################################
    # Prepare repeat start inputs
    repeat_start_var = {}
    for fpos, func in enumerate(network.function):
        if func.type == 'RepeatStart':
            if func.repeat_param.repeat_id == ri.id:
                repeat_start_var[func.output[0]] = []
                for i in range(ri.times):
                    if i == 0:
                        v = func.input[0]
                        if v in repeat_end_var:
                            v = repeat_end_var[v][ri.times - 1]
                        repeat_start_var[func.output[0]].append(v)
                    else:
                        v = func.input[1]
                        if v in repeat_end_var:
                            v = repeat_end_var[v][i - 1]
                        else:
                            v = _add_suffix(v, ri.id, i - 1)
                        repeat_start_var[func.output[0]].append(v)

    ########################################################################
    # Expand network
    net.ClearField('function')
    for fpos, func in enumerate(network.function):
        if func.type == 'RepeatStart' or func.type == 'RepeatEnd':
            if func.repeat_param.repeat_id == ri.id:
                continue
        if func.type == 'Delay':
            if func.recurrent_param.repeat_id == ri.id:
                continue
        if func.type == 'RecurrentInput':
            if func.recurrent_param.repeat_id == ri.id:
                f = net.function.add()
                f.CopyFrom(func)
                f.type = 'Split'
                f.split_param.axis = func.recurrent_param.axis
                f.ClearField('output')
                for i in range(ri.times):
                    f.output.append(_add_suffix(func.output[0], ri.id, i))
                pos = _search_repeat_id(func, ri.id)
                del f.repeat_id[pos]
                f.ClearField('recurrent_param')
                continue
        if func.type == 'RecurrentOutput':
            if func.recurrent_param.repeat_id == ri.id:
                f = net.function.add()
                f.CopyFrom(func)
                f.type = 'Stack'
                f.stack_param.axis = func.recurrent_param.axis
                f.ClearField('input')
                for i in range(ri.times):
                    f.input.append(_add_suffix(func.input[0], ri.id, i))
                f.ClearField('recurrent_param')
                continue
        pos = _search_repeat_id(func, ri.id)
        if pos is not None:
            for i in range(ri.times):
                f = net.function.add()
                f.CopyFrom(func)
                del f.repeat_id[pos]
                f.name = _add_suffix(func.name, ri.id, i)
                for n, v in enumerate(func.input):
                    vname = None
                    if v in self._parameter_original_names:
                        if len(self._parameter_original_names[v]) == ri.times:
                            vname = self._parameter_original_names[v][i]
                        else:
                            vname = v
                    elif v in repeat_start_var:
                        vname = repeat_start_var[v][i]
                    elif v in repeat_end_var:
                        vname = repeat_end_var[v][i]
                    elif v in delay_var:
                        vname = delay_var[v][i]
                    else:
                        vname = _add_suffix(v, ri.id, i)
                    f.input[n] = vname
                for n, v in enumerate(func.output):
                    vname = _add_suffix(v, ri.id, i)
                    f.output[n] = vname
        else:
            f = net.function.add()
            f.CopyFrom(func)
            for n, v in enumerate(func.input):
                if v in repeat_end_var:
                    vname = repeat_end_var[v][ri.times - 1]
                    f.input[n] = vname
    return self._expand_repeat(net)
def load(filenames, prepare_data_iterator=True, batch_size=None,
         exclude_parameter=False, parameter_only=False, extension=".nntxt"):
    '''load
    Load network information from files.

    Args:
        filenames (list): file-like object or list of filenames.
        extension: if filenames is a file-like object, extension is one of
            ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
    Returns:
        dict: Network information.
    '''

    class Info:
        pass

    info = Info()
    proto = nnabla_pb2.NNablaProtoBuf()

    # Optimizer checkpoint
    opti_proto = nnabla_pb2.NNablaProtoBuf()
    OPTI_BUF_EXT = ['.optimizer']
    opti_h5_files = {}
    tmpdir = tempfile.mkdtemp()

    if isinstance(filenames, list) or isinstance(filenames, tuple):
        pass
    elif isinstance(filenames, str) or hasattr(filenames, 'read'):
        filenames = [filenames]

    for filename in filenames:
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            ext = extension
        # TODO: Here are some known problems.
        #   - Even when a protobuf file includes a network structure,
        #     it will not be loaded.
        #   - Even when a prototxt file includes parameters,
        #     they will not be loaded.
        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2-byte characters may be used for file or'
                            ' folder names.')
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')
        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    _, ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO: currently do nothing with version.
                    elif ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=ext)
                    elif ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=ext)
                        else:
                            logger.info('Skip loading parameter.')
                    elif ext in OPTI_BUF_EXT:
                        buf_type = get_buf_type(name)
                        if buf_type == 'protobuf':
                            with nnp.open(name, 'r') as f:
                                with get_file_handle_load(
                                        f, '.protobuf') as opti_p:
                                    opti_proto.MergeFromString(opti_p.read())
                        elif buf_type == 'h5':
                            nnp.extract(name, tmpdir)
                            opti_h5_files[name] = os.path.join(tmpdir, name)

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            try:
                import nnabla_ext.cudnn
            except ImportError:
                pass
            try:
                # Sanity-check that the context actually works.
                x = nn.Variable()
                y = nn.Variable()
                func = F.ReLU(default_context, inplace=True)
                func.setup([x], [y])
                func.forward([x], [y])
            except Exception:
                logger.warn('Fallback to CPU context.')
                import nnabla_ext.cpu
                default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()
    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.local_rank)
    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)
    info.datasets = _datasets(
        proto,
        prepare_data_iterator if prepare_data_iterator is not None
        else info.training_config.max_epoch > 0)
    info.networks = _networks(proto, default_context, batch_size)
    info.optimizers = _optimizers(proto, default_context, info.networks,
                                  info.datasets)
    _load_optimizer_checkpoint(opti_proto, opti_h5_files, info)
    shutil.rmtree(tmpdir)
    info.monitors = _monitors(proto, default_context, info.networks,
                              info.datasets)
    info.executors = _executors(proto, info.networks)
    return info
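# Usage sketch for load() above; the .nnp file name is illustrative. `info`
# bundles everything parsed from the archive.
info = load(['model.nnp'], batch_size=1)
print(list(info.networks.keys()))   # networks defined in the bundled .nntxt
print(list(info.executors.keys()))  # executors ready for inference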