def generate_case_from_nntxt_str(nntxt_str, param_format, dataset_sample_num, batch_size=None):
    """Build a temporary ``.nnp`` archive from an nntxt string and yield its path.

    This is a generator function: it yields exactly once, while the
    temporary dataset and the temporary nnp file are still alive.
    NOTE(review): presumably wrapped with ``contextlib.contextmanager``
    outside this block -- confirm at the definition site.

    Args:
        nntxt_str (str): Network definition in nntxt text form; parsed with
            ``proto_from_str`` and also passed to ``prepare_parameters``.
        param_format (str): Parameter file extension. It is passed to
            ``nn.parameter.save_parameters`` as ``extension`` and appended to
            ``'parameter'`` to name the archive member.
        dataset_sample_num: Forwarded to ``generate_csv_png`` together with
            the input size derived from the parsed proto.
        batch_size (int, optional): When truthy, overrides the batch size of
            every dataset in the proto; otherwise each dataset keeps its own.

    Yields:
        The file name of the temporary nnp archive.
    """
    proto = proto_from_str(nntxt_str)
    with generate_csv_png(dataset_sample_num, get_input_size(proto)) as dataset_csv_file:
        # To test dataset, we create a randomly generated dataset.
        for ds in proto.dataset:
            ds.batch_size = batch_size if batch_size else ds.batch_size
            ds.uri = dataset_csv_file
            ds.cache_dir = os.path.join(
                os.path.dirname(dataset_csv_file), "data.cache")

        nntxt_io = io.StringIO()
        text_format.PrintMessage(proto, nntxt_io)
        nntxt_io.seek(0)

        version = io.StringIO()
        version.write('{}\n'.format(nnp_version()))
        version.seek(0)

        param = io.BytesIO()
        prepare_parameters(nntxt_str)
        nn.parameter.save_parameters(param, extension=param_format)
        # read() returns data from the current stream position, so rewind the
        # buffer (as is done for the two text buffers above); otherwise the
        # parameter member written below can end up empty or truncated.
        param.seek(0)

        with create_temp_with_dir(NNP_FILE) as temp_nnp_file_name:
            with get_file_handle_save(temp_nnp_file_name, ".nnp") as nnp:
                nnp.writestr('nnp_version.txt', version.read())
                nnp.writestr('network.nntxt', nntxt_io.read())
                nnp.writestr('parameter{}'.format(param_format), param.read())
            # Yield after the archive handle is released but while the
            # temporary file still exists.
            yield temp_nnp_file_name
def save_parameters(path, params=None, extension=None):
    """Save all parameters into a file with the specified format.

    Currently hdf5 and protobuf formats are supported.

    Args:
      path : path or file object
      params (dict, optional): Parameters to be saved. Dictionary is of a
          parameter name (:obj:`str`) to :obj:`~nnabla.Variable`. When
          ``None``, the result of ``get_parameters(grad_only=False)`` is used.
      extension (str, optional): Format selector (``'.h5'`` or
          ``'.protobuf'``) used when ``path`` is not a :obj:`str`. When
          ``path`` is a :obj:`str`, the format is taken from its file
          extension and this argument is ignored.

    Raises:
      AssertionError: If the resolved format is neither ``'.h5'`` nor
          ``'.protobuf'``.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension
    params = get_parameters(grad_only=False) if params is None else params
    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        # Not referenced directly in this block; imported after the filter is
        # installed (presumably so import-time warnings are muted -- confirm).
        import h5py
        with get_file_handle_save(path, ext) as hd:
            for i, (k, v) in enumerate(params.items()):
                hd[k] = v.d
                hd[k].attrs['need_grad'] = v.need_grad
                # To preserve order of parameters
                hd[k].attrs['index'] = i
    elif ext == '.protobuf':
        proto = nnabla_pb2.NNablaProtoBuf()
        for variable_name, variable in params.items():
            parameter = proto.parameter.add()
            parameter.variable_name = variable_name
            parameter.shape.dim.extend(variable.shape)
            parameter.data.extend(numpy.array(variable.d).flatten().tolist())
            parameter.need_grad = variable.need_grad

        with get_file_handle_save(path, ext) as f:
            f.write(proto.SerializeToString())
    else:
        logger.critical('Only supported hdf5 or protobuf.')
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let an unsupported format
        # fall through and be logged below as a successful save.
        raise AssertionError('Only supported hdf5 or protobuf.')
    logger.info("Parameter save ({}): {}".format(ext, path))
def test_file_close_exception(extension):
    """Check that save/load file handles are closed when the body raises.

    Each half opens a handle through the helper under test, triggers a
    ``ZeroDivisionError`` inside the ``with`` body, and then verifies that the
    handle captured before the failure reports ``closed``.
    """
    # Save side: the error is raised while the save handle is open.
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir("tmp{}".format(extension)) as filename, \
                get_file_handle_save(filename, ext=extension) as save_handle:
            file_handler = save_handle
            1 / 0
    assert file_handler.closed

    # Load side: same scenario, but through the load helper.
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir("tmp{}".format(extension)) as filename:
            # create a file at first
            with open(filename, "w") as seed:
                seed.write("\n")
            with get_file_handle_load(None, filename, ext=extension) as load_handle:
                file_handler = load_handle
                1 / 0
    assert file_handler.closed
def save_optimizer_states(filebase, ext, train_config):
    """Write optimizer states to disk and return the created file names.

    Args:
        filebase (str): Prefix used to build every output file name.
        ext (str): ``'.protobuf'`` stores all optimizers in one file named
            ``<filebase>_optimizer.protobuf.optimizer``. Any other value makes
            each solver save its own states to an ``.h5`` file, which is then
            renamed with an extra ``.optimizer`` suffix.
        train_config: Object whose ``optimizers`` mapping holds entries that
            expose an ``optimizer`` attribute.

    Returns:
        list: Names of the files that were written.
    """
    entries = train_config.optimizers.values()

    if ext == '.protobuf':
        filename = filebase + '_optimizer.protobuf.optimizer'
        proto = nnabla_pb2.NNablaProtoBuf()
        proto.optimizer.extend(
            [_create_optimizer_lite(entry.optimizer) for entry in entries])
        with get_file_handle_save(filename, '.protobuf') as f:
            f.write(proto.SerializeToString())
        return [filename]

    written = []
    for entry in entries:
        opt = entry.optimizer
        # Drop a trailing "Cuda" from the solver name before using it in the
        # file name.
        solver_label = re.sub(r'(|Cuda)$', '', str(opt.solver.name))
        f_name = '{}_{}_optimizer.h5'.format(opt.name, solver_label)
        h5_path = '{}_{}'.format(filebase, f_name)
        opt.solver.save_states(h5_path)
        final_path = '{}.optimizer'.format(h5_path)
        os.rename(h5_path, final_path)
        written.append(final_path)
    return written
def save(filename, contents, include_params=False, variable_batch_size=True, extension=".nnp"):
    '''Save network definition, inference/training execution configurations etc.

    Args:
        filename (str or file object): Filename to store information. The file
            extension is used to determine the saving file format.
            ``.nnp``: (Recommended) Creating a zip archive with nntxt (network
            definition etc.) and protobuf (parameters).
            ``.nntxt``: Protobuf in text format.
            ``.protobuf``: Protobuf in binary format (unsafe in terms of
            backward compatibility).
        contents (dict): Information to store.
        include_params (bool): Includes parameter into single file. This is
            ignored when the extension of filename is nnp.
        variable_batch_size (bool):
            By ``True``, the first dimension of all variables is considered as
            batch size, and left as a placeholder (more specifically ``-1``).
            The placeholder dimension will be filled during/after loading.
        extension: if filename is a file-like object, extension is one of
            ".nntxt", ".prototxt", ".protobuf", ".nnp". It is ignored when
            filename is a :obj:`str`. Any other format (for example ".h5")
            is not handled by this function: nothing is written and a
            critical message is logged.

    Example:
        The following example creates a two inputs and two outputs MLP, and
        save the network structure and the initialized parameters.

        .. code-block:: python

            import nnabla as nn
            import nnabla.functions as F
            import nnabla.parametric_functions as PF
            from nnabla.utils.save import save

            batch_size = 16
            x0 = nn.Variable([batch_size, 100])
            x1 = nn.Variable([batch_size, 100])
            h1_0 = PF.affine(x0, 100, name='affine1_0')
            h1_1 = PF.affine(x1, 100, name='affine1_0')
            h1 = F.tanh(h1_0 + h1_1)
            h2 = F.tanh(PF.affine(h1, 50, name='affine2'))
            y0 = PF.affine(h2, 10, name='affiney_0')
            y1 = PF.affine(h2, 10, name='affiney_1')

            contents = {
                'networks': [
                    {'name': 'net1',
                     'batch_size': batch_size,
                     'outputs': {'y0': y0, 'y1': y1},
                     'names': {'x0': x0, 'x1': x1}}],
                'executors': [
                    {'name': 'runtime',
                     'network': 'net1',
                     'data': ['x0', 'x1'],
                     'output': ['y0', 'y1']}]}
            save('net.nnp', contents)

        To get a trainable model, use following code instead.

        .. code-block:: python

            contents = {
                'global_config': {'default_context': ctx},
                'training_config':
                    {'max_epoch': args.max_epoch,
                     'iter_per_epoch': args_added.iter_per_epoch,
                     'save_best': True},
                'networks': [
                    {'name': 'training',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_t},
                     'names': {'x': x, 'y': t, 'loss': loss_t}},
                    {'name': 'validation',
                     'batch_size': args.batch_size,
                     'outputs': {'loss': loss_v},
                     'names': {'x': x, 'y': t, 'loss': loss_v}}],
                'optimizers': [
                    {'name': 'optimizer',
                     'solver': solver,
                     'network': 'training',
                     'dataset': 'mnist_training',
                     'weight_decay': 0,
                     'lr_decay': 1,
                     'lr_decay_interval': 1,
                     'update_interval': 1}],
                'datasets': [
                    {'name': 'mnist_training',
                     'uri': 'MNIST_TRAINING',
                     'cache_dir': args.cache_dir + '/mnist_training.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': True,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True},
                    {'name': 'mnist_validation',
                     'uri': 'MNIST_VALIDATION',
                     'cache_dir': args.cache_dir + '/mnist_test.cache/',
                     'variables': {'x': x, 'y': t},
                     'shuffle': False,
                     'batch_size': args.batch_size,
                     'no_image_normalization': True
                     }],
                'monitors': [
                    {'name': 'training_loss',
                     'network': 'validation',
                     'dataset': 'mnist_training'},
                    {'name': 'validation_loss',
                     'network': 'validation',
                     'dataset': 'mnist_validation'}],
            }

    '''
    # A str filename decides the format by its own suffix; for file-like
    # objects the caller states the format through `extension`.
    if isinstance(filename, str):
        _, ext = os.path.splitext(filename)
    else:
        ext = extension

    if ext == '.nntxt' or ext == '.prototxt':
        logger.info("Saving {} as prototxt".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with get_file_handle_save(filename, ext) as handle:
            text_format.PrintMessage(proto, handle)
    elif ext == '.protobuf':
        logger.info("Saving {} as protobuf".format(filename))
        proto = create_proto(contents, include_params, variable_batch_size)
        with get_file_handle_save(filename, ext) as handle:
            handle.write(proto.SerializeToString())
    elif ext == '.nnp':
        logger.info("Saving {} as nnp".format(filename))

        # Render the network definition into memory by re-entering save()
        # with a text buffer; parameters are stored as a separate member.
        nntxt = io.StringIO()
        save(nntxt, contents, include_params=False,
             variable_batch_size=variable_batch_size, extension='.nntxt')
        nntxt.seek(0)

        version = io.StringIO()
        version.write('{}\n'.format(nnp_version()))
        version.seek(0)

        param = io.BytesIO()
        save_parameters(param, extension='.protobuf')
        param.seek(0)

        with get_file_handle_save(filename, ext) as nnp:
            nnp.writestr('nnp_version.txt', version.read())
            nnp.writestr('network.nntxt', nntxt.read())
            nnp.writestr('parameter.protobuf', param.read())
    else:
        # Previously this case returned silently without writing anything.
        # Report it, but do not raise, so existing callers keep working.
        logger.critical(
            "Unsupported extension {!r}: nothing was saved. Supported "
            "extensions are .nntxt, .prototxt, .protobuf and .nnp.".format(ext))