def test_file_close_exception(extension):
    """Verify that save/load file handles are closed when their body raises.

    For both ``get_file_handle_save`` and ``get_file_handle_load`` a
    ``ZeroDivisionError`` is triggered inside the ``with`` body, and the
    handle captured from the context manager is then checked to be closed.

    Args:
        extension: File extension appended to the temporary file name and
            passed as ``ext`` to the file-handle helpers.
    """
    temp_name = "tmp{}".format(extension)

    # Saving: the handle must be closed even though the body fails.
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir(temp_name) as filename, \
                get_file_handle_save(filename, ext=extension) as handle:
            file_handler = handle
            1 / 0
    assert file_handler.closed

    # Loading: same expectation for the read-side handle.
    with pytest.raises(ZeroDivisionError):
        with create_temp_with_dir(temp_name) as filename:
            # The file has to exist before it can be opened for loading.
            with open(filename, "w") as seed:
                seed.write("\n")
            with get_file_handle_load(None, filename, ext=extension) as handle:
                file_handler = handle
                1 / 0
    assert file_handler.closed
def load(filenames, prepare_data_iterator=True, batch_size=None,
         exclude_parameter=False, parameter_only=False, extension=".nntxt"):
    '''load
    Load network information from files.

    Every given file is merged into one ``NNablaProtoBuf`` message (and its
    parameters are loaded through ``nn.load_parameters``); the execution
    context, datasets, networks, optimizers, monitors and executors are then
    built from that message.

    Args:
        filenames (list): A filename, a file-like object, or a list/tuple of
            those.
        prepare_data_iterator (bool): Forwarded to the dataset setup. When
            ``None``, ``training_config.max_epoch > 0`` is used instead; the
            loaded files must then provide a ``training_config``, because the
            attribute is read unconditionally.
        batch_size: Forwarded to the network setup.
        exclude_parameter (bool): If True, no parameters are loaded.
        parameter_only (bool): If True, text definitions (".nntxt",
            ".prototxt") are not merged into the proto.
        extension: if filenames is file-like object, extension is one of
            ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".

    Returns:
        Info: A plain attribute container with ``datasets``, ``networks``,
        ``optimizers``, ``monitors`` and ``executors``, plus
        ``global_config`` / ``training_config`` when the corresponding field
        is present in the merged proto.
    '''
    class Info:
        pass
    info = Info()

    proto = nnabla_pb2.NNablaProtoBuf()

    # Accept a single path / file-like object as well as a sequence of them.
    if not isinstance(filenames, (list, tuple)):
        if isinstance(filenames, str) or hasattr(filenames, 'read'):
            filenames = [filenames]

    for filename in filenames:
        if isinstance(filename, str):
            _, ext = os.path.splitext(filename)
        else:
            # File-like objects carry no name; rely on the caller's hint.
            ext = extension

        # TODO: Here is some known problems.
        #   - Even when protobuf file includes network structure,
        #     it will not loaded.
        #   - Even when prototxt file includes parameter,
        #     it will not loaded.

        if ext in ['.nntxt', '.prototxt']:
            if not parameter_only:
                with get_file_handle_load(filename, ext) as f:
                    try:
                        text_format.Merge(f.read(), proto)
                    except Exception:
                        logger.critical('Failed to read {}.'.format(filename))
                        logger.critical(
                            '2 byte characters may be used for file name or folder name.')
                        raise
            if len(proto.parameter) > 0:
                if not exclude_parameter:
                    nn.load_parameters(filename, extension=ext)
        elif ext in ['.protobuf', '.h5']:
            if not exclude_parameter:
                nn.load_parameters(filename, extension=ext)
            else:
                logger.info('Skip loading parameter.')
        elif ext == '.nnp':
            with get_file_handle_load(filename, ext) as nnp:
                for name in nnp.namelist():
                    # Separate name so the archive's own ``ext`` is not
                    # clobbered while walking its members.
                    _, member_ext = os.path.splitext(name)
                    if name == 'nnp_version.txt':
                        pass  # TODO currently do nothing with version.
                    elif member_ext in ['.nntxt', '.prototxt']:
                        if not parameter_only:
                            with nnp.open(name, 'r') as f:
                                text_format.Merge(f.read(), proto)
                        if len(proto.parameter) > 0:
                            if not exclude_parameter:
                                with nnp.open(name, 'r') as f:
                                    nn.load_parameters(f, extension=member_ext)
                    elif member_ext in ['.protobuf', '.h5']:
                        if not exclude_parameter:
                            with nnp.open(name, 'r') as f:
                                nn.load_parameters(f, extension=member_ext)
                        else:
                            logger.info('Skip loading parameter.')

    default_context = None
    if proto.HasField('global_config'):
        info.global_config = _global_config(proto)
        default_context = info.global_config.default_context
        if 'cuda' in default_context.backend:
            import nnabla_ext.cudnn
        elif 'cuda:float' in default_context.backend:
            # Best effort: a missing/broken extension is handled by the
            # probe below, but Ctrl-C / SystemExit must still propagate.
            try:
                import nnabla_ext.cudnn
            except Exception:
                pass
        # Probe the requested context with a trivial function; fall back to
        # CPU when it cannot be set up or executed.
        try:
            x = nn.Variable()
            y = nn.Variable()
            func = F.ReLU(default_context, inplace=True)
            func.setup([x], [y])
            func.forward([x], [y])
        except Exception:
            logger.warning('Fallback to CPU context.')
            import nnabla_ext.cpu
            default_context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cpu
        default_context = nnabla_ext.cpu.context()

    comm = current_communicator()
    if comm:
        default_context.device_id = str(comm.rank)

    if proto.HasField('training_config'):
        info.training_config = _training_config(proto)

    info.datasets = _datasets(
        proto,
        prepare_data_iterator if prepare_data_iterator is not None
        else info.training_config.max_epoch > 0)

    info.networks = _networks(proto, default_context, batch_size)

    info.optimizers = _optimizers(
        proto, default_context, info.networks, info.datasets)

    info.monitors = _monitors(
        proto, default_context, info.networks, info.datasets)

    info.executors = _executors(proto, info.networks)

    return info
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    The format is taken from the extension of ``path`` when it is a string,
    otherwise from ``extension``. Supported formats are ".h5", ".protobuf",
    ".nntxt", ".prototxt" and ".nnp"; anything else is reported through
    ``logger.error``.

    Args:
      path : path or file object
      proto : protobuf message to merge into. For non-".h5" formats a new
        ``NNablaProtoBuf`` is created when omitted; for ".h5" one is created
        only if ``needs_proto`` is true.
      needs_proto (bool): for ".h5" input, also append every loaded
        parameter to ``proto``.
      extension : format to assume when ``path`` is not a string.

    Returns:
      The protobuf message, or ``None`` for ".h5" input when no ``proto`` was
      given and ``needs_proto`` is false.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        # NOTE(review): this changes the process-wide warning filters.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)

            def _order(entry):
                # Entries without an ``index`` attribute sort after indexed
                # ones (by name), so a mix of both no longer compares None
                # against a number. Pure-indexed and pure-unindexed inputs
                # keep their previous order.
                index, name = entry
                return (index is None, 0 if index is None else index, name)

            for _, key in sorted(keys, key=_order):
                ds = hd[key]

                var = get_parameter_or_create(
                    key, ds.shape, need_grad=ds.attrs['need_grad'])
                var.data.cast(ds.dtype)[...] = ds[...]

                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = False
                    if ds.attrs['need_grad']:
                        parameter.need_grad = True
    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()

        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext == '.nntxt' or ext == '.prototxt':
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)
        elif ext == '.nnp':
            # Create the directory before entering ``try`` so that the
            # cleanup never touches an unbound name if mkdtemp itself fails.
            tmpdir = tempfile.mkdtemp()
            try:
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        # Keep the archive's ``ext`` intact for the log below.
                        _, member_ext = os.path.splitext(name)
                        if member_ext in ['.protobuf', '.h5']:
                            proto = load_parameters(
                                os.path.join(tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
                # Runs whether or not extraction succeeded; reports the
                # archive format rather than the builtin ``format`` function.
                logger.info("Parameter load ({}): {}".format(ext, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto