def func_set_nnabla_support(version_spec=None):
    """Return the set of function names supported by nnabla.

    Args:
        version_spec: ``None`` to get every function name listed in
            ``_nnabla_func_info``. Otherwise an ``(nnp, nnp_version)`` pair:
            the functions used by ``nnp`` are checked against the set
            supported by ``nnp_version``.

    Returns:
        set: All known function names when ``version_spec`` is ``None``;
        otherwise the functions available in ``nnp_version`` minus those
        reported by ``_func_set_nnabla_unsupport()``.

    Raises:
        ValueError: If the ``.nnp`` file uses a function outside that set.
    """
    if version_spec is None:
        return {name
                for cat_info in _nnabla_func_info.values()
                for name in cat_info}

    nnp, nnp_version = version_spec

    # Flatten the per-category table into _nnabla_func_info_cur (name -> info)
    # before querying the helpers below.
    for cat_info in _nnabla_func_info.values():
        for name, info in cat_info.items():
            _nnabla_func_info_cur[name] = info

    supported = func_set_get_from_repo(nnp_version)
    supported -= _func_set_nnabla_unsupport()

    used_in_nnp = func_set_import_nnp(nnp)
    missing = used_in_nnp - supported
    if missing:
        logger.error(
            f"The following functions are not supported in nnabla v{nnp_version} but appear in .nnp file."
        )
        for name in missing:
            logger.error(f"{name}")
        raise ValueError(
            f"nnp file contains unsupported functions by nnabla version: {nnp_version}."
        )
    return supported
def _import_backend(self, backend):
    """Import ``audio_utils.<backend>_utils`` and register it in ``self.backends``.

    If the import raises ImportError, the failure is logged and ``None`` is
    stored for ``backend`` instead of the module.
    """
    module_path = "audio_utils.{}_utils".format(backend)
    try:
        loaded = importlib.import_module(module_path)
    except ImportError:
        logger.error("Import {} as audio backend failed".format(backend))
        # The backend's status is reported later, in _check_backend.
        loaded = None
    self.backends[backend] = loaded
def _get_analysis(self, comm): def _analyse_gpu_cost_time(result, threshold): aver = np.mean(result, axis=0)[1] _node_l = [*filter(lambda n: n[1] > aver * threshold, result)] if len(_node_l): ranks = ', '.join([str(int(n[0])) for n in _node_l]) _str = ('Gpu of Rank {} ran slower than average ' 'by a factor of {} or more'.format(ranks, threshold)) return _str return '' result = self._reap_multinode_data(comm) if comm.rank == 0: error_str = _analyse_gpu_cost_time(result, self._error_threshold) if error_str: logger.error(error_str) raise Exception(error_str) else: warning_str = _analyse_gpu_cost_time(result, self._warning_threshold) if warning_str: logger.warning(warning_str)
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    The format comes from the extension of ``path`` when it is a string, and
    from ``extension`` otherwise:

    * ``.h5``: every dataset is registered with ``get_parameter_or_create``;
      with ``needs_proto`` each one is also appended to ``proto``.
    * ``.protobuf`` / ``.nntxt`` / ``.prototxt``: the file is merged into
      ``proto`` and applied with ``set_parameter_from_proto``.
    * ``.nnp``: the archive is extracted to a temporary directory and each
      ``.protobuf`` / ``.h5`` member is loaded recursively.

    Any other extension is logged as an error and nothing is loaded.

    Args:
        path: path or file object
        proto: Protobuf message to merge into / append to. A new
            ``nnabla_pb2.NNablaProtoBuf`` is created when one is needed and
            ``proto`` is ``None``.
        needs_proto (bool): For ``.h5`` input, also record every loaded
            parameter in ``proto``.
        extension (str): Format to assume when ``path`` is not a string.

    Returns:
        The (possibly newly created) protobuf message. This is ``None`` for
        ``.h5`` input when ``proto`` was ``None`` and ``needs_proto`` is False.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        # NOTE: simplefilter installs a process-wide filter, not a scoped one.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):  # Group
                    return
                # Keep the stored 'index' attribute to preserve parameter order.
                keys.append((ds.attrs.get('index', None), name))

            def _order(entry):
                # Sort by index, then name. Entries without an index go last
                # (by name) instead of raising TypeError when None is
                # compared with an int.
                index, name = entry
                return (index is None, 0 if index is None else index, name)

            hd.visit(_get_keys)
            for _, key in sorted(keys, key=_order):
                ds = hd[key]
                need_grad = ds.attrs['need_grad']
                var = get_parameter_or_create(
                    key, ds.shape, need_grad=need_grad)
                var.data.cast(ds.dtype)[...] = ds[...]
                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(ds[...]).flatten().tolist())
                    parameter.need_grad = bool(need_grad)
    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()
        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext in ('.nntxt', '.prototxt'):
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)
        elif ext == '.nnp':
            # Create the directory before entering try, so that a failing
            # mkdtemp cannot leave `finally` referencing an unbound name.
            tmpdir = tempfile.mkdtemp()
            try:
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        # Separate name so the archive's own `ext` is not
                        # clobbered.
                        _, member_ext = os.path.splitext(name)
                        if member_ext in ('.protobuf', '.h5'):
                            proto = load_parameters(os.path.join(
                                tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
            logger.info("Parameter load ({}): {}".format(ext, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# `enabled` becomes True only when both nvidia.dali and the local
# _dali_iterator module import successfully. In that case the public names of
# _dali_iterator are re-exported from here via the wildcard import.
enabled = False

try:
    import nvidia.dali  # noqa: F401  (imported only to probe availability)
    from ._dali_iterator import *  # noqa: F401,F403
    enabled = True
except ImportError as e:
    # Imported here rather than at the top; it is only needed on this path.
    from nnabla.logger import logger
    # Report through the logger instead of printing to stdout. Keep the
    # original ImportError text, because the failure may come from
    # _dali_iterator itself rather than from a missing nvidia.dali.
    logger.error(
        "Skip importing nnabla daliIterator wrapper because nvidia.dali is not found."
        " Please make sure you've installed nvidia.dali."
        " (ImportError: {})".format(e))