예제 #1
0
def func_set_nnabla_support(version_spec=None):
    """Return the set of nnabla function names that are supported.

    Args:
        version_spec: ``None`` or an ``(nnp, nnp_version)`` pair.

            * ``None``: return every function name listed under any
              category of ``_nnabla_func_info``.
            * ``(nnp, nnp_version)``: return the function set obtained from
              ``func_set_get_from_repo(nnp_version)`` minus the names
              returned by ``_func_set_nnabla_unsupport()``, after checking
              that every function reported by ``func_set_import_nnp(nnp)``
              is contained in that set.

    Returns:
        set: the supported function names.

    Raises:
        ValueError: if ``nnp`` uses functions outside the supported set; the
            offending names are logged (in sorted order) before raising.
    """
    if version_spec is None:
        return {func
                for cat_info in _nnabla_func_info.values()
                for func in cat_info}

    nnp, nnp_version = version_spec
    # load func_info to _nnabla_func_info_cur (flat func -> func_info map;
    # presumably consulted by _func_set_nnabla_unsupport() — verify)
    for cat_info in _nnabla_func_info.values():
        _nnabla_func_info_cur.update(cat_info)

    # Build a new set instead of subtracting in place: the object returned
    # by func_set_get_from_repo() may be shared or cached (its definition is
    # not visible here), and in-place mutation would leak into it.
    repo_set = func_set_get_from_repo(nnp_version)
    supported = repo_set - _func_set_nnabla_unsupport()

    nnp_set = func_set_import_nnp(nnp)
    unsupported_functions = nnp_set - supported
    if unsupported_functions:
        logger.error(
            f"The following functions are not supported in nnabla v{nnp_version} but appear in .nnp file."
        )
        # Sorted so the report is stable across runs (set order is not);
        # key=str avoids TypeError if entries are not all strings.
        for f in sorted(unsupported_functions, key=str):
            logger.error(f"{f}")
        raise ValueError(
            f"nnp file contains unsupported functions by nnabla version: {nnp_version}."
        )
    return supported
예제 #2
0
    def _import_backend(self, backend):
        """Import ``audio_utils.<backend>_utils`` and record it in ``self.backends``.

        If the import raises ``ImportError``, an error is logged and ``None``
        is stored for ``backend`` instead of propagating the exception.
        """
        module_name = "audio_utils.{}_utils".format(backend)
        try:
            imported = importlib.import_module(module_name)
        except ImportError:
            logger.error("Import {} as audio backend failed".format(backend))
            # backend status is logged in _check_backend (not shown here)
            imported = None
        self.backends[backend] = imported
예제 #3
0
    def _get_analysis(self, comm):
        def _analyse_gpu_cost_time(result, threshold):
            aver = np.mean(result, axis=0)[1]
            _node_l = [*filter(lambda n: n[1] > aver * threshold, result)]
            if len(_node_l):
                ranks = ', '.join([str(int(n[0])) for n in _node_l])
                _str = ('Gpu of Rank {} ran slower than average '
                        'by a factor of {} or more'.format(ranks, threshold))
                return _str
            return ''

        result = self._reap_multinode_data(comm)
        if comm.rank == 0:
            error_str = _analyse_gpu_cost_time(result, self._error_threshold)
            if error_str:
                logger.error(error_str)
                raise Exception(error_str)
            else:
                warning_str = _analyse_gpu_cost_time(result,
                                                     self._warning_threshold)
                if warning_str:
                    logger.warning(warning_str)
예제 #4
0
파일: parameter.py 프로젝트: ujtakk/nnabla
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
    """Load parameters from a file with the specified format.

    The format is taken from the file extension when ``path`` is a ``str``;
    otherwise (e.g. a file object) ``extension`` is used. Recognised
    formats are ``.h5``, ``.protobuf``, ``.nntxt``/``.prototxt`` and
    ``.nnp``. For ``.nnp``, the archive is extracted to a temporary
    directory and every ``.protobuf``/``.h5`` member is loaded by a
    recursive call. Any other extension is logged as an error and nothing
    is loaded.

    Args:
      path : path or file object
      proto : message to merge into and return. When ``None`` and one is
        needed, a new ``nnabla_pb2.NNablaProtoBuf`` is created.
      needs_proto (bool): for ``.h5`` input, also append every loaded
        parameter to ``proto``.
      extension (str): format to assume when ``path`` is not a ``str``.

    Returns:
      ``proto``. This may be ``None`` for ``.h5`` input when ``needs_proto``
      is false and no ``proto`` was passed in.
    """
    if isinstance(path, str):
        _, ext = os.path.splitext(path)
    else:
        ext = extension

    if ext == '.h5':
        # TODO temporary work around to suppress FutureWarning message.
        # NOTE: simplefilter() changes the process-wide warning filters; the
        # change stays in effect after this function returns.
        import warnings
        warnings.simplefilter('ignore', category=FutureWarning)
        import h5py
        with get_file_handle_load(path, ext) as hd:
            keys = []

            def _get_keys(name):
                ds = hd[name]
                if not isinstance(ds, h5py.Dataset):
                    # Group
                    return
                # To preserve order of parameters
                keys.append((ds.attrs.get('index', None), name))
            hd.visit(_get_keys)

            def _order(item):
                # Datasets without an 'index' attribute sort after indexed
                # ones (then by name), so sorted() never compares None with
                # a number.
                index, name = item
                return (index is None, 0 if index is None else index, name)

            for _, key in sorted(keys, key=_order):
                ds = hd[key]
                need_grad = ds.attrs['need_grad']
                data = ds[...]  # read once; reused for the proto below

                var = get_parameter_or_create(
                    key, ds.shape, need_grad=need_grad)
                var.data.cast(ds.dtype)[...] = data

                if needs_proto:
                    if proto is None:
                        proto = nnabla_pb2.NNablaProtoBuf()
                    parameter = proto.parameter.add()
                    parameter.variable_name = key
                    parameter.shape.dim.extend(ds.shape)
                    parameter.data.extend(
                        numpy.array(data).flatten().tolist())
                    parameter.need_grad = bool(need_grad)

    else:
        if proto is None:
            proto = nnabla_pb2.NNablaProtoBuf()

        if ext == '.protobuf':
            with get_file_handle_load(path, ext) as f:
                proto.MergeFromString(f.read())
                set_parameter_from_proto(proto)
        elif ext in ('.nntxt', '.prototxt'):
            with get_file_handle_load(path, ext) as f:
                text_format.Merge(f.read(), proto)
                set_parameter_from_proto(proto)

        elif ext == '.nnp':
            # Create the directory before ``try`` so ``finally`` never sees
            # an unbound ``tmpdir`` if mkdtemp() itself fails.
            tmpdir = tempfile.mkdtemp()
            try:
                with get_file_handle_load(path, ext) as nnp:
                    for name in nnp.namelist():
                        nnp.extract(name, tmpdir)
                        # Separate name so ``ext`` keeps describing the
                        # archive itself.
                        _, member_ext = os.path.splitext(name)
                        if member_ext in ('.protobuf', '.h5'):
                            proto = load_parameters(os.path.join(
                                tmpdir, name), proto, needs_proto)
            finally:
                shutil.rmtree(tmpdir)
            # Logged only after a successful load, with the archive's format.
            logger.info("Parameter load ({}): {}".format(ext, path))
        else:
            logger.error("Invalid parameter file '{}'".format(path))
    return proto
예제 #5
0
# Copyright (c) 2019 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ``enabled`` becomes True only if both imports below succeed.
enabled = False
try:
    import nvidia.dali
    from ._dali_iterator import *
    enabled = True
except ImportError as e:
    # Report through the nnabla logger instead of printing the raw exception
    # to stdout. The ImportError text is included because the failure may
    # also come from ._dali_iterator itself, not only from nvidia.dali.
    from nnabla.logger import logger
    logger.error(
        "Skip importing nnabla daliIterator wrapper because nvidia.dali is not found."
        " Please make sure you've installed nvidia.dali."
        " ({})".format(e))