コード例 #1
0
ファイル: nnp_graph.py プロジェクト: ssgalitsky/nnabla
    def _get_variable_or_create(self, v, callback, current_scope):
        """Return the variable bound to ``v``, creating it on first use.

        Args:
            v: Variable descriptor. This method reads ``v.variable``,
                ``v.proto`` (``name``, ``shape.dim``, ``type``) and
                ``v.need_grad``, and writes ``v.variable``.
            callback: Object providing ``_apply_generate_variable(v)`` and
                ``verbose(message)``.
            current_scope: Forwarded as the second argument of
                ``nn.parameter_scope`` when registering the parameter.

        Returns:
            The variable already attached to ``v`` if there is one;
            otherwise a newly created ``nn.Variable`` (non-parameter case)
            or the result of ``param.get_unlinked_variable(...)``
            (parameter case).

        Raises:
            ValueError: If looking up, initializing or registering a
                parameter fails. The original traceback and the names of
                the currently registered parameters are embedded in the
                message.
            AssertionError: If any resolved dimension is not positive, or
                if the loaded parameter's shape differs from the shape
                declared in the proto.
        """
        if v.variable is not None:
            return v.variable

        # The callback may replace `v` or attach a variable to it.
        v = callback._apply_generate_variable(v)

        if v.variable is not None:
            return v.variable

        pvar = v.proto
        name = pvar.name
        shape = list(pvar.shape.dim)
        # A negative leading dimension is a placeholder for the batch size.
        # The emptiness check keeps 0-dimensional shapes from raising
        # IndexError on `shape[0]`.
        if shape and shape[0] < 0:
            shape[0] = self.batch_size
        shape = tuple(shape)
        assert np.all(np.array(shape) > 0
                      ), "Shape must be positive. Given {}.".format(shape)

        if pvar.type != 'Parameter':
            # Create a new variable and returns.
            var = nn.Variable(shape)
            v.variable = var
            var.name = name
            return var

        # Trying to load the parameter from .nnp file.
        callback.verbose('Loading parameter `{}` from .nnp.'.format(name))
        try:
            param = get_parameter(name)
            if param is None:
                logger.info(
                    'Parameter `{}` is not found. Initializing.'.format(name))
                tmp = _create_variable(pvar, name, shape, self.rng)
                param = tmp.variable_instance
                set_parameter(name, param)
            # Always copy param to current scope even if it already exists.
            with nn.parameter_scope('', current_scope):
                set_parameter(name, param)
        except Exception:
            # `Exception` (not a bare `except`) so that KeyboardInterrupt and
            # SystemExit propagate instead of being turned into ValueError.
            import traceback
            raise ValueError(
                'An error occurs during creation of a variable `{}` as a'
                ' parameter variable. The error was:\n----\n{}\n----\n'
                'The parameters registered was {}'.format(
                    name, traceback.format_exc(), '\n'.join(
                        nn.get_parameters(grad_only=False).keys())))
        assert shape == param.shape
        param = param.get_unlinked_variable(need_grad=v.need_grad)
        v.variable = param
        param.name = name
        return param
コード例 #2
0
ファイル: importer.py プロジェクト: yuhonghong7035/nnabla
    def generate_parameters_data(self, var_list, batch_size):
        """Append one parameter entry to ``self._nnp`` per item of ``var_list``.

        For every item, a variable is built with ``_create_variable`` and its
        name, shape and flattened data are copied into a new entry obtained
        from ``self._nnp.parameter.add()``.

        Args:
            var_list: Iterable of variable descriptors exposing ``name`` and
                ``shape.dim``.
            batch_size: Value substituted for every dimension smaller than 1.
        """
        from nnabla.utils.load import _create_variable
        import numpy as np

        # One generator, seeded with 0, is shared by all created variables.
        seeded_rng = np.random.RandomState(0)
        for proto_var in var_list:
            resolved_dims = []
            for dim in proto_var.shape.dim:
                resolved_dims.append(batch_size if dim < 1 else dim)
            created = _create_variable(
                proto_var, proto_var.name, tuple(resolved_dims), seeded_rng)

            entry = self._nnp.parameter.add()
            entry.variable_name = created.name
            entry.shape.dim.extend(created.shape)
            entry.data.extend(created.variable_instance.d.flatten())