Пример #1
0
def test_numpy_functions_same_dimensions():
    '''
    Check that dimension-preserving functions do not change the dimensions.

    A selection of numpy functions is applied to a 1-d and a 2-d array for
    several units, and the builtins ``abs``, ``max``, ``min`` and ``sum`` are
    applied to a 1-d array. The dimensions of every result have to be the
    identical object as the ``dim`` of the input.

    Raises
    ------
    AssertionError
        If a function returns a result with different dimensions.
    '''
    failure_template = '%s failed on %s -- dim was %s, is now %s'
    test_units = [volt, second, siemens, mV, kHz]
    test_arrays = [np.array([1, 2]), np.ones((3, 3))]

    # numpy functions
    numpy_funcs = [np.abs, np.cumsum, np.max, np.mean, np.min, np.negative,
                   np.ptp, np.round, np.squeeze, np.std, np.sum, np.transpose]
    for array in test_arrays:
        for unit in test_units:
            quantity = array * unit
            for numpy_func in numpy_funcs:
                result = numpy_func(quantity)
                if get_dimensions(result) is not quantity.dim:
                    raise AssertionError(failure_template % (
                        numpy_func.__name__, repr(quantity), quantity.dim,
                        get_dimensions(result)))

    # Python builtins should work on one-dimensional arrays
    one_dim_array = np.arange(5)
    for unit in test_units:
        quantity = one_dim_array * unit
        for builtin_func in (abs, max, min, sum):
            result = builtin_func(quantity)
            if get_dimensions(result) is not quantity.dim:
                raise AssertionError(failure_template % (
                    builtin_func.__name__, repr(quantity), quantity.dim,
                    get_dimensions(result)))
Пример #2
0
def assert_quantity(q, values, unit):
    '''
    Assert that `q` has the numerical `values` and the dimensions of `unit`.

    `q` has to be a `Quantity`; a plain numpy array is only accepted if
    `unit` is dimensionless.
    '''
    def plain_array_allowed():
        # Only consulted when `q` is not a Quantity
        return isinstance(q, np.ndarray) and have_same_dimensions(unit, 1)

    assert isinstance(q, Quantity) or plain_array_allowed()
    assert_equal(np.asarray(q), values)
    assert have_same_dimensions(q, unit), (
        'Dimension mismatch: (%s) (%s)' % (get_dimensions(q),
                                           get_dimensions(unit)))
Пример #3
0
def test_numpy_functions_same_dimensions():
    '''
    Test that functions that should keep the dimensions of their argument
    (a set of numpy functions and the builtins ``abs``, ``max``, ``min``,
    ``sum``) return a result whose dimensions are the input's ``dim`` object.
    '''
    def check_keeps_dimensions(func, quantity):
        # Apply `func` and complain if the dimensions object changed
        outcome = func(quantity)
        if get_dimensions(outcome) is not quantity.dim:
            raise AssertionError(
                '%s failed on %s -- dim was %s, is now %s'
                % (func.__name__, repr(quantity), quantity.dim,
                   get_dimensions(outcome)))

    values = [np.array([1, 2]), np.ones((3, 3))]
    units = [volt, second, siemens, mV, kHz]

    # numpy functions
    keep_dim_funcs = [
        np.abs, np.cumsum, np.max, np.mean, np.min, np.negative,
        np.ptp, np.round, np.squeeze, np.std, np.sum, np.transpose,
    ]
    for value, unit in itertools.product(values, units):
        with_unit = value * unit
        for func in keep_dim_funcs:
            check_keeps_dimensions(func, with_unit)

    # Python builtins should work on one-dimensional arrays
    flat_values = np.arange(5)
    for unit in units:
        with_unit = flat_values * unit
        for func in [abs, max, min, sum]:
            check_keeps_dimensions(func, with_unit)
Пример #4
0
def test_get_dimensions():
    '''
    Test various ways of getting/comparing the dimensions of a quantity.

    Covers quantities with units, plain Python and numpy scalars (which have
    to be reported as dimensionless and as scalar types), and the
    ``TypeError`` raised for unsupported arguments.
    '''
    q = 500 * ms
    assert get_dimensions(q) is get_or_create_dimension(q.dimensions._dims)
    assert get_dimensions(q) is q.dimensions
    assert q.has_same_dimensions(3 * second)
    dims = q.dimensions
    assert_equal(dims.get_dimension('time'), 1.)
    assert_equal(dims.get_dimension('length'), 0)

    # Note: the builtin ``int`` is used as dtype instead of the ``np.int``
    # alias, which was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    assert get_dimensions(5) is DIMENSIONLESS
    assert get_dimensions(5.0) is DIMENSIONLESS
    assert get_dimensions(np.array(5, dtype=int)) is DIMENSIONLESS
    assert get_dimensions(np.array(5.0)) is DIMENSIONLESS
    assert get_dimensions(np.float32(5.0)) is DIMENSIONLESS
    assert get_dimensions(np.float64(5.0)) is DIMENSIONLESS
    assert is_scalar_type(5)
    assert is_scalar_type(5.0)
    assert is_scalar_type(np.array(5, dtype=int))
    assert is_scalar_type(np.array(5.0))
    assert is_scalar_type(np.float32(5.0))
    assert is_scalar_type(np.float64(5.0))
    with pytest.raises(TypeError):
        get_dimensions('a string')
    # wrong number of indices
    with pytest.raises(TypeError):
        get_or_create_dimension([1, 2, 3, 4, 5, 6])
    # not a sequence
    with pytest.raises(TypeError):
        get_or_create_dimension(42)
Пример #5
0
def assert_quantity(q, values, unit):
    '''
    Assert that `q` is close to `values` and has the dimensions of `unit`.

    `q` has to be a `Quantity`, except for a dimensionless `unit`, where a
    numpy array (or anything, if ``values.shape`` is ``()``) is accepted.
    On failure of this first check, `q` is used as the assertion message.
    '''
    def plain_value_allowed():
        # Only consulted when `q` is not a Quantity
        if not have_same_dimensions(unit, 1):
            return False
        return values.shape == () or isinstance(q, np.ndarray)

    assert isinstance(q, Quantity) or plain_value_allowed(), q
    assert_allclose(np.asarray(q), values)
    assert have_same_dimensions(q, unit), (
        'Dimension mismatch: (%s) (%s)' % (get_dimensions(q),
                                           get_dimensions(unit)))
Пример #6
0
def assert_quantity(q, values, unit):
    '''
    Check a result `q` against expected `values` and an expected `unit`.

    The type check passes if `q` is a `Quantity`, or if `unit` is
    dimensionless and either ``values.shape`` is ``()`` or `q` is a numpy
    array. The numerical values are compared with `assert_allclose`, the
    dimensions with `have_same_dimensions`.
    '''
    assert (
        isinstance(q, Quantity)
        or (
            have_same_dimensions(unit, 1)
            and (values.shape == () or isinstance(q, np.ndarray))
        )
    ), q
    assert_allclose(np.asarray(q), values)
    assert have_same_dimensions(q, unit), (
        'Dimension mismatch: (%s) (%s)'
        % (get_dimensions(q), get_dimensions(unit))
    )
Пример #7
0
def test_get_dimensions():
    '''
    Test various ways of getting/comparing the dimensions of a quantity.

    Covers quantities with units, plain Python and numpy scalars (which have
    to be reported as dimensionless and as scalar types), and the
    ``TypeError`` raised for unsupported arguments.
    '''
    q = 500 * ms
    assert get_dimensions(q) is get_or_create_dimension(q.dimensions._dims)
    assert get_dimensions(q) is q.dimensions
    assert q.has_same_dimensions(3 * second)
    dims = q.dimensions
    assert_equal(dims.get_dimension('time'), 1.)
    assert_equal(dims.get_dimension('length'), 0)

    # Note: the builtin ``int`` is used as dtype instead of the ``np.int``
    # alias, which was deprecated in NumPy 1.20 and removed in NumPy 1.24.
    assert get_dimensions(5) is DIMENSIONLESS
    assert get_dimensions(5.0) is DIMENSIONLESS
    assert get_dimensions(np.array(5, dtype=int)) is DIMENSIONLESS
    assert get_dimensions(np.array(5.0)) is DIMENSIONLESS
    assert get_dimensions(np.float32(5.0)) is DIMENSIONLESS
    assert get_dimensions(np.float64(5.0)) is DIMENSIONLESS
    assert is_scalar_type(5)
    assert is_scalar_type(5.0)
    assert is_scalar_type(np.array(5, dtype=int))
    assert is_scalar_type(np.array(5.0))
    assert is_scalar_type(np.float32(5.0))
    assert is_scalar_type(np.float64(5.0))
    assert_raises(TypeError, lambda: get_dimensions('a string'))
    # wrong number of indices
    assert_raises(TypeError, lambda: get_or_create_dimension([1, 2, 3, 4, 5, 6]))
    # not a sequence
    assert_raises(TypeError, lambda: get_or_create_dimension(42))
Пример #8
0
def test_parse_equations():
    '''
    Test the parsing of equation strings: a single equation, a multi-line
    set of equations with flags and variable types, and strings that have
    to be rejected.
    '''
    # A simple equation
    simple = parse_string_equations('dv/dt = -v / tau : 1')
    assert len(simple.keys()) == 1
    assert 'v' in simple
    assert simple['v'].type == DIFFERENTIAL_EQUATION
    assert get_dimensions(simple['v'].unit) == DIMENSIONLESS

    # A complex one
    eqs = parse_string_equations('''dv/dt = -(v +
                                             ge + # excitatory conductance
                                             I # external current
                                             )/ tau : volt
                                    dge/dt = -ge / tau_ge : volt
                                    I = sin(2 * pi * f * t) : volt
                                    f : Hz (constant)
                                    b : boolean
                                    n : integer
                                 ''')
    assert len(eqs.keys()) == 6
    expected_types = [('v', DIFFERENTIAL_EQUATION),
                      ('ge', DIFFERENTIAL_EQUATION),
                      ('I', SUBEXPRESSION),
                      ('f', PARAMETER),
                      ('b', PARAMETER),
                      ('n', PARAMETER)]
    for name, expected_type in expected_types:
        assert name in eqs and eqs[name].type == expected_type
    for name, expected_var_type in [('f', FLOAT), ('b', BOOLEAN),
                                    ('n', INTEGER)]:
        assert eqs[name].var_type == expected_var_type
    for name, expected_unit in [('v', volt), ('ge', volt), ('I', volt),
                                ('f', Hz)]:
        assert get_dimensions(eqs[name].unit) == expected_unit.dim
    for name in ['v', 'ge', 'I']:
        assert eqs[name].flags == []
    assert eqs['f'].flags == ['constant']

    # The same variable defined twice
    duplicate_eqs = ('\n'
                     '    dv/dt = -v / tau : 1\n'
                     '    v = 2 * t : 1\n'
                     '    ')
    assert_raises(EquationError, lambda: parse_string_equations(duplicate_eqs))

    # Strings that cannot be parsed into valid equations
    parse_error_eqs = (
        'dv/d = -v / tau : 1\n        x = 2 * t : 1',
        'dv/dt = -v / tau : 1 : volt\n    x = 2 * t : 1',
        ' dv/dt = -v / tau : 2 * volt',
        'dv/dt = v / second : boolean',
    )
    for error_eqs in parse_error_eqs:
        assert_raises((ValueError, EquationError, TypeError),
                      lambda: parse_string_equations(error_eqs))
Пример #9
0
def test_parse_equations():
    '''
    Test `parse_string_equations` with valid equation strings (checking
    types, variable types, dimensions and flags of the parsed equations) and
    with invalid strings that have to raise an error.
    '''
    def assert_type(equations, varname, eq_type):
        # The variable has to be present and of the expected equation type
        assert varname in equations and equations[varname].type == eq_type

    def assert_dim(equations, varname, unit):
        assert get_dimensions(equations[varname].unit) == unit.dim

    # A simple equation
    eqs = parse_string_equations('dv/dt = -v / tau : 1')
    n_equations = len(eqs.keys())
    assert (n_equations == 1 and 'v' in eqs
            and eqs['v'].type == DIFFERENTIAL_EQUATION)
    assert get_dimensions(eqs['v'].unit) == DIMENSIONLESS

    # A complex one
    complex_eqs = '''dv/dt = -(v +
                                             ge + # excitatory conductance
                                             I # external current
                                             )/ tau : volt
                                    dge/dt = -ge / tau_ge : volt
                                    I = sin(2 * pi * f * t) : volt
                                    f : Hz (constant)
                                    b : boolean
                                    n : integer
                                 '''
    eqs = parse_string_equations(complex_eqs)
    assert len(eqs.keys()) == 6
    assert_type(eqs, 'v', DIFFERENTIAL_EQUATION)
    assert_type(eqs, 'ge', DIFFERENTIAL_EQUATION)
    assert_type(eqs, 'I', SUBEXPRESSION)
    assert_type(eqs, 'f', PARAMETER)
    assert_type(eqs, 'b', PARAMETER)
    assert_type(eqs, 'n', PARAMETER)
    assert eqs['f'].var_type == FLOAT
    assert eqs['b'].var_type == BOOLEAN
    assert eqs['n'].var_type == INTEGER
    assert_dim(eqs, 'v', volt)
    assert_dim(eqs, 'ge', volt)
    assert_dim(eqs, 'I', volt)
    assert_dim(eqs, 'f', Hz)
    assert eqs['v'].flags == []
    assert eqs['ge'].flags == []
    assert eqs['I'].flags == []
    assert eqs['f'].flags == ['constant']

    # Defining the same variable twice is an error
    duplicate_eqs = '''
    dv/dt = -v / tau : 1
    v = 2 * t : 1
    '''
    assert_raises(EquationError, lambda: parse_string_equations(duplicate_eqs))

    # Malformed equations
    parse_error_eqs = [
        'dv/d = -v / tau : 1\n        x = 2 * t : 1',
        'dv/dt = -v / tau : 1 : volt\n    x = 2 * t : 1',
        ' dv/dt = -v / tau : 2 * volt',
        'dv/dt = v / second : boolean',
    ]
    expected_errors = (ValueError, EquationError, TypeError)
    for error_eqs in parse_error_eqs:
        assert_raises(expected_errors,
                      lambda: parse_string_equations(error_eqs))
Пример #10
0
            def wrapper_function(*args):
                '''
                Call ``orig_func`` with unit-aware arguments and check the
                dimensions of its result.

                Each argument is converted into a quantity with the
                dimensions of the corresponding declared argument unit;
                arguments declared as ``bool``, ``None`` or a string are
                passed through unchanged. The result is returned as a numpy
                array.

                Raises
                ------
                ValueError
                    If the number of arguments does not match the declaration.
                TypeError
                    If a boolean return value was declared but not returned.
                '''
                expected_units = list(self._function._arg_units)
                # With auto_vectorise, one additional dimensionless argument
                # is expected
                if self._function.auto_vectorise:
                    expected_units += [DIMENSIONLESS]
                if len(args) != len(expected_units):
                    raise ValueError(
                        'Function %s got %d arguments, expected %d'
                        % (self._function.pyfunc.__name__, len(args),
                           len(expected_units)))

                def passes_unchanged(unit):
                    return (unit == bool or unit is None
                            or isinstance(unit, str))

                new_args = [
                    arg if passes_unchanged(arg_unit)
                    else Quantity.with_dimensions(arg,
                                                  get_dimensions(arg_unit))
                    for arg, arg_unit in zip(args, expected_units)
                ]
                result = orig_func(*new_args)

                # The return unit is either fixed or a function of the
                # dimensions of the arguments
                return_unit = self._function._return_unit
                if isinstance(return_unit, Callable):
                    return_unit = return_unit(*[get_dimensions(a)
                                                for a in args])

                if return_unit == bool:
                    is_boolean = (isinstance(result, bool)
                                  or np.asarray(result).dtype == bool)
                    if not is_boolean:
                        raise TypeError(
                            'The function %s returned %s, but it was '
                            'expected to return a boolean value '
                            % (orig_func.__name__, result))
                else:
                    if ((isinstance(return_unit, int) and return_unit == 1)
                            or return_unit.dim is DIMENSIONLESS):
                        error_message = (
                            'The function %s returned {value}, but it was '
                            'expected to return a dimensionless quantity'
                            % orig_func.__name__)
                    else:
                        error_message = (
                            'The function %s returned {value}, but it was '
                            'expected to return a quantity with units %r'
                            % (orig_func.__name__, return_unit))
                    fail_for_dimension_mismatch(result, return_unit,
                                                error_message, value=result)
                return np.asarray(result)
Пример #11
0
            def wrapper_function(*args):
                """
                Run ``orig_func`` on arguments with dimensions attached and
                verify what it returns.

                Arguments whose declared unit is ``bool``, ``None`` or a
                string are forwarded as they are; all others are turned into
                quantities with the declared dimensions. The return value is
                checked against the declared return unit (which may be a
                callable of the argument dimensions) and returned as a numpy
                array.

                Raises
                ------
                ValueError
                    If the number of arguments is not the expected one.
                TypeError
                    If the function should return booleans but does not.
                """
                arg_units = list(self._function._arg_units)
                if self._function.auto_vectorise:
                    # one extra, dimensionless, argument is expected
                    arg_units += [DIMENSIONLESS]

                n_given, n_expected = len(args), len(arg_units)
                if n_given != n_expected:
                    func_name = self._function.pyfunc.__name__
                    raise ValueError(
                        f"Function {func_name} got {len(args)} arguments, "
                        f"expected {len(arg_units)}.")

                new_args = []
                for arg, arg_unit in zip(args, arg_units):
                    if (arg_unit == bool or arg_unit is None
                            or isinstance(arg_unit, str)):
                        new_args.append(arg)
                        continue
                    arg_dim = get_dimensions(arg_unit)
                    new_args.append(Quantity.with_dimensions(arg, arg_dim))
                result = orig_func(*new_args)

                declared_return = self._function._return_unit
                if isinstance(declared_return, Callable):
                    arg_dims = [get_dimensions(a) for a in args]
                    return_unit = declared_return(*arg_dims)
                else:
                    return_unit = declared_return

                if return_unit == bool:
                    if not (isinstance(result, bool)
                            or np.asarray(result).dtype == bool):
                        raise TypeError(
                            f"The function {orig_func.__name__} returned "
                            f"'{result}', but it was expected to return a "
                            f"boolean value ")
                    return np.asarray(result)

                if ((isinstance(return_unit, int) and return_unit == 1)
                        or return_unit.dim is DIMENSIONLESS):
                    expectation = "a dimensionless quantity."
                else:
                    expectation = f"a quantity with units {return_unit!r}."
                fail_for_dimension_mismatch(
                    result, return_unit,
                    f"The function '{orig_func.__name__}' "
                    f"returned {result}, but it was "
                    f"expected to return {expectation}")
                return np.asarray(result)
Пример #12
0
 def wrapper_function(*args):
     '''
     Call ``orig_func`` with arguments converted to quantities and check
     the dimensions of its result.

     Each argument gets the dimensions of the corresponding entry in the
     declared argument units. The result is checked against the declared
     return unit and returned as a numpy array.

     Raises
     ------
     ValueError
         If the number of arguments does not match the declared number.
     '''
     if not len(args) == len(self._function._arg_units):
         raise ValueError(('Function %s got %d arguments, '
                           'expected %d') % (self._function.pyfunc.__name__, len(args),
                                             len(self._function._arg_units)))
     new_args = [Quantity.with_dimensions(arg, get_dimensions(arg_unit))
                 for arg, arg_unit in zip(args, self._function._arg_units)]
     result = orig_func(*new_args)
     return_unit = self._function._return_unit
     # Compare a plain integer return unit by value: identity comparison
     # with a literal (``is 1``) depends on interpreter-level caching of
     # small integers and triggers a SyntaxWarning on Python >= 3.8. The
     # isinstance check makes sure that ``== 1`` is only evaluated for plain
     # integers.
     if ((isinstance(return_unit, int) and return_unit == 1)
             or return_unit.dim is DIMENSIONLESS):
         fail_for_dimension_mismatch(result,
                                     return_unit,
                                     'The function %s returned '
                                     '{value}, but it was expected '
                                     'to return a dimensionless '
                                     'quantity' % orig_func.__name__,
                                     value=result)
     else:
         fail_for_dimension_mismatch(result,
                                     return_unit,
                                     ('The function %s returned '
                                      '{value}, but it was expected '
                                      'to return a quantity with '
                                      'units %r') % (orig_func.__name__,
                                                     return_unit),
                                     value=result)
     return np.asarray(result)
Пример #13
0
 def check_units(self, unit, variable_units):
     '''
     Verify that the expression has the dimensions of the given unit.

     Parameters
     ----------
     unit : `Unit` or 1
         The unit the expression is expected to have (1 stands for
         dimensionless).
     variable_units : dict
         Mapping from internal variable names to their units.

     Notes
     -----
     The namespace has to be resolved using the
     `~brian2.equations.codestrings.CodeString.resolve` method first.

     Raises
     ------
     DimensionMismatchError
         If the expression uses inconsistent units or the resulting unit does
         not match the expected `unit`.
     '''
     actual = self.get_dimensions(variable_units)
     expected = get_dimensions(unit)
     if actual == expected:
         return
     raise DimensionMismatchError(
         'Dimensions of expression does not match its definition',
         actual, expected)
Пример #14
0
    def __init__(self,
                 type,
                 varname,
                 dimensions,
                 var_type=FLOAT,
                 expr=None,
                 flags=None):
        '''
        Store the description of a single equation or parameter.

        Parameters
        ----------
        type
            The type of the equation (e.g. ``DIFFERENTIAL_EQUATION``).
        varname
            The name of the variable defined by the equation.
        dimensions
            The dimensions of the variable; normalised with
            `get_dimensions` and stored as ``dim``.
        var_type
            The type of the variable (``FLOAT`` by default).
        expr
            The expression defining the variable, if any.
        flags : list, optional
            Flags of the equation; a new empty list if not given.

        Raises
        ------
        TypeError
            If a boolean or integer variable is not dimensionless, or if a
            differential equation defines a non-float variable.
        '''
        self.type = type
        self.varname = varname
        self.dim = get_dimensions(dimensions)
        self.var_type = var_type
        # Check the normalised dimensions instead of the raw argument, so
        # that a dimensionless argument that is not the DIMENSIONLESS object
        # itself is not mistaken for a variable with dimensions.
        if self.dim is not DIMENSIONLESS:
            if var_type == BOOLEAN:
                raise TypeError(
                    'Boolean variables are necessarily dimensionless.')
            elif var_type == INTEGER:
                raise TypeError(
                    'Integer variables are necessarily dimensionless.')

        if type == DIFFERENTIAL_EQUATION:
            if var_type != FLOAT:
                raise TypeError(
                    'Differential equations can only define floating point variables'
                )
        self.expr = expr
        if flags is None:
            self.flags = []
        else:
            self.flags = flags

        # will be set later in the sort_subexpressions method of Equations
        self.update_order = -1
Пример #15
0
def test_parse_equations():
    '''
    Test the parsing of equation strings, for valid equations as well as for
    strings that have to be rejected by the parser.
    '''
    def parse(eq_string):
        # Every call gets its own fresh (empty) dictionary
        return parse_string_equations(eq_string, {}, False, 0)

    # A simple equation
    eqs = parse('dv/dt = -v / tau : 1')
    assert len(eqs.keys()) == 1
    assert 'v' in eqs
    assert eqs['v'].eq_type == DIFFERENTIAL_EQUATION
    assert get_dimensions(eqs['v'].unit) == DIMENSIONLESS

    # A complex one
    complex_eqs = '''dv/dt = -(v +
                                             ge + # excitatory conductance
                                             I # external current
                                             )/ tau : volt
                                    dge/dt = -ge / tau_ge : volt
                                    I = sin(2 * pi * f * t) : volt
                                    f : Hz (constant)
                                 '''
    eqs = parse(complex_eqs)
    assert len(eqs.keys()) == 4
    for varname, expected_type in [('v', DIFFERENTIAL_EQUATION),
                                   ('ge', DIFFERENTIAL_EQUATION),
                                   ('I', STATIC_EQUATION),
                                   ('f', PARAMETER)]:
        assert varname in eqs and eqs[varname].eq_type == expected_type
    for varname, expected_unit in [('v', volt), ('ge', volt), ('I', volt),
                                   ('f', Hz)]:
        assert get_dimensions(eqs[varname].unit) == expected_unit.dim
    for varname in ('v', 'ge', 'I'):
        assert eqs[varname].flags == []
    assert eqs['f'].flags == ['constant']

    # The same variable defined twice
    duplicate_eqs = ('\n'
                     '    dv/dt = -v / tau : 1\n'
                     '    v = 2 * t : 1\n'
                     '    ')
    assert_raises(SyntaxError, lambda: parse(duplicate_eqs))

    # Strings that cannot be parsed
    parse_error_eqs = [
        'dv/d = -v / tau : 1\n        x = 2 * t : 1',
        'dv/dt = -v / tau : 1 : volt\n    x = 2 * t : 1',
        ' dv/dt = -v / tau : 2 * volt',
    ]
    for error_eqs in parse_error_eqs:
        assert_raises((ValueError, SyntaxError), lambda: parse(error_eqs))
Пример #16
0
 def wrapper_function(*args):
     '''
     Call ``orig_func`` with the arguments converted to quantities of the
     declared argument units, check the dimensions of the result against the
     declared return unit and return the result as a numpy array.

     Raises
     ------
     ValueError
         If the number of arguments does not match the declared number.
     '''
     declared_units = self._function._arg_units
     if len(args) != len(declared_units):
         raise ValueError(
             'Function %s got %d arguments, expected %d'
             % (self._function.name, len(args), len(declared_units)))
     with_units = []
     for arg, arg_unit in zip(args, declared_units):
         arg_dim = get_dimensions(arg_unit)
         with_units.append(Quantity.with_dimensions(arg, arg_dim))
     result = orig_func(*with_units)
     fail_for_dimension_mismatch(result, self._function._return_unit)
     return np.asarray(result)
Пример #17
0
    def __init__(self, target, target_var, N, rate, weight, when='synapses',
                 order=0):
        '''
        Set up the code that adds binomially sampled input to a variable.

        Parameters
        ----------
        target
            The group receiving the input.
        target_var : str
            Name of the variable in `target` that is updated.
        N
            First argument passed on to `BinomialFunction`.
        rate
            Multiplied with ``target.clock.dt`` and passed on to
            `BinomialFunction`.
        weight : str or quantity
            Either a string expression (used in parentheses in the generated
            code), or a value with the same dimensions as `target_var`.
        when : str, optional
            Passed on to `CodeRunner`.
        order : int, optional
            Passed on to `CodeRunner`.

        Raises
        ------
        KeyError
            If `target_var` is not a variable of `target`.
        DimensionMismatchError
            If a non-string `weight` does not have the dimensions of the
            target variable.
        '''
        if target_var not in target.variables:
            raise KeyError('%s is not a variable of %s' %
                           (target_var, target.name))

        self._weight = weight
        self._target_var = target_var

        if not isinstance(weight, str):
            # This will be checked automatically in the abstract code as well
            # but doing an explicit check here allows for a clearer error
            # message
            weight_dims = get_dimensions(weight)
            target_dims = target.variables[target_var].dim
            if not have_same_dimensions(weight_dims, target_dims):
                error_message = ('The provided weight does not '
                                 'have the same unit as the '
                                 'target variable "%s"') % target_var
                raise DimensionMismatchError(error_message, weight_dims,
                                             target_dims)
            weight_code = repr(weight)
        else:
            weight_code = '(%s)' % weight

        self._N = N
        self._rate = rate
        sampler = BinomialFunction(N, rate * target.clock.dt,
                                   name='poissoninput_binomial*')

        code = '{targetvar} += {binomial}()*{weight}'.format(
            targetvar=target_var, binomial=sampler.name, weight=weight_code)
        self._stored_dt = target.dt_[:]  # make a copy
        # FIXME: we need an explicit reference here for on-the-fly subgroups
        # For example: PoissonInput(group[:N], ...)
        self._group = target
        CodeRunner.__init__(self, group=target, template='stateupdate',
                            code=code, user_code='', when=when, order=order,
                            name='poissoninput*', clock=target.clock)
        self.variables = Variables(self)
        self.variables._add_variable(sampler.name, sampler)
Пример #18
0
 def wrapper_function(*args):
     '''
     Wrap ``orig_func``: attach the declared dimensions to the arguments,
     call the function, check the dimensions of its result against the
     declared return unit and return the result as a numpy array.

     Raises
     ------
     ValueError
         If the number of arguments differs from the number of declared
         argument units.
     '''
     n_given = len(args)
     n_expected = len(self._function._arg_units)
     if n_given != n_expected:
         raise ValueError('Function %s got %d arguments, expected %d'
                          % (self._function.name, n_given, n_expected))

     def attach_dimensions(value, unit):
         return Quantity.with_dimensions(value, get_dimensions(unit))

     new_args = [attach_dimensions(arg, arg_unit)
                 for arg, arg_unit in zip(args, self._function._arg_units)]
     result = orig_func(*new_args)
     fail_for_dimension_mismatch(result, self._function._return_unit)
     return np.asarray(result)
Пример #19
0
    def __init__(self, model, input_var, input, output_var, output, dt,
                 n_samples=30, method=None, reset=None, refractory=False,
                 threshold=None, level=0, param_init=None,
                 t_start=0 * second):
        """
        Initialize the fitter.

        After the base class initialization, the output is stored both as a
        `Quantity` (``output``) and as a plain array (``output_``), and the
        model is extended by a ``total_error`` variable that has the squared
        dimensions of the output.

        Raises
        ------
        NameError
            If `output_var` is not a variable of the model.
        ValueError
            If `input` and `output` do not have the same shape, or if
            `param_init` refers to a name that is neither a model variable
            nor an identifier in the model.
        """
        super().__init__(dt, model, input, output, input_var, output_var,
                         n_samples, threshold, reset, refractory, method,
                         param_init)

        self.output = Quantity(output)
        self.output_ = array(output)

        if output_var not in self.model.names:
            raise NameError("%s is not a model variable" % output_var)
        if output.shape != input.shape:
            raise ValueError("Input and output must have the same size")

        # NOTE(review): this TimedArray is created but not used any further
        # in this method
        output_traces = TimedArray(output.transpose(), dt=dt)

        # The error variable has the squared dimensions of the output
        output_dim = get_dimensions(output)
        if output_dim is DIMENSIONLESS:
            error_unit = '1'
        else:
            error_unit = repr(output_dim**2)
        self.model = self.model + Equations(
            'total_error : {}'.format(error_unit))

        self.t_start = t_start

        if param_init:
            for param, _ in param_init.items():
                if (param not in self.model.identifiers
                        and param not in self.model.names):
                    raise ValueError("%s is not a model variable or an "
                                     "identifier in the model" % param)
            self.param_init = param_init

        self.simulator = None
Пример #20
0
 def __init__(self, values, dt, name=None):
     '''
     Store the values as a double-precision array together with their
     dimensions and the time step ``dt`` (converted to a float), then run
     the initialization specific to 1-d or 2-d values.

     Raises
     ------
     NotImplementedError
         If `values` has neither one nor two dimensions.
     '''
     Nameable.__init__(self, '_timedarray*' if name is None else name)
     # The dimensions have to be determined before the values are converted
     # into a plain array
     self.dim = get_dimensions(values)
     plain_values = np.asarray(values, dtype=np.double)
     self.values = plain_values
     self.dt = float(dt)
     n_dims = plain_values.ndim
     if n_dims == 1:
         self._init_1d()
     elif n_dims == 2:
         self._init_2d()
     else:
         raise NotImplementedError('Only 1d and 2d arrays are supported '
                                   'for TimedArray')
Пример #21
0
 def __init__(self, values, dt, name=None):
     '''
     Initialize the object from `values` and the time step `dt`.

     The dimensions of `values` are stored as ``dim``, the values themselves
     as an array of doubles, and `dt` as a float. Only 1-d and 2-d values are
     supported.

     Raises
     ------
     NotImplementedError
         For values with any other number of dimensions.
     '''
     if name is None:
         name = '_timedarray*'
     Nameable.__init__(self, name)

     self.dim = get_dimensions(values)
     as_double = np.asarray(values, dtype=np.double)
     self.values = as_double
     self.dt = float(dt)

     if as_double.ndim not in (1, 2):
         raise NotImplementedError(
             'Only 1d and 2d arrays are supported for TimedArray')
     if as_double.ndim == 1:
         self._init_1d()
     else:
         self._init_2d()
Пример #22
0
 def __init__(self, values, dt, name=None, tOffset=0 * second):
     """
     Store the values (as an array of doubles), their dimensions, the time
     step `dt` (as a float) and the time offset `tOffset`, then run the
     initialization specific to 1-d or 2-d values.

     Raises
     ------
     NotImplementedError
         If `values` has neither one nor two dimensions.
     """
     Nameable.__init__(self, "_timedarray*" if name is None else name)
     # Determine the dimensions before the values are converted to a plain
     # array
     self.dim = get_dimensions(values)
     plain_values = np.asarray(values, dtype=np.double)
     self.values = plain_values
     self.dt = float(dt)
     self.tOffset = tOffset
     n_dims = plain_values.ndim
     if n_dims == 1:
         self._init_1d()
     elif n_dims == 2:
         self._init_2d()
     else:
         raise NotImplementedError("Only 1d and 2d arrays are supported "
                                   "for TimedArray")
Пример #23
0
    def __init__(self, target, target_var, N, rate, weight, when='synapses',
                 order=0):
        '''
        Set up the code that adds binomially sampled input to a variable.

        Parameters
        ----------
        target
            The group receiving the input.
        target_var : str
            Name of the variable in `target` that is updated.
        N
            First argument passed on to `BinomialFunction`.
        rate
            Multiplied with ``target.clock.dt`` and passed on to
            `BinomialFunction`.
        weight : str or quantity
            Either a string expression (used in parentheses in the generated
            code), or a value with the same dimensions as `target_var`.
        when : str, optional
            Passed on to `CodeRunner`.
        order : int, optional
            Passed on to `CodeRunner`.

        Raises
        ------
        KeyError
            If `target_var` is not a variable of `target`.
        DimensionMismatchError
            If a non-string `weight` does not have the dimensions of the
            target variable.
        '''
        if target_var not in target.variables:
            raise KeyError('%s is not a variable of %s' % (target_var, target.name))

        if not isinstance(weight, basestring):
            # This will be checked automatically in the abstract code as well
            # but doing an explicit check here allows for a clearer error
            # message
            weight_dims = get_dimensions(weight)
            target_dims = target.variables[target_var].dim
            if not have_same_dimensions(weight_dims, target_dims):
                error_message = ('The provided weight does not '
                                 'have the same unit as the '
                                 'target variable "%s"') % target_var
                raise DimensionMismatchError(error_message, weight_dims,
                                             target_dims)
            weight_code = repr(weight)
        else:
            weight_code = '(%s)' % weight

        self._N = N
        self._rate = rate
        sampler = BinomialFunction(N, rate*target.clock.dt,
                                   name='poissoninput_binomial*')

        code = '{targetvar} += {binomial}()*{weight}'.format(
            targetvar=target_var, binomial=sampler.name, weight=weight_code)
        self._stored_dt = target.dt_[:]  # make a copy
        # FIXME: we need an explicit reference here for on-the-fly subgroups
        # For example: PoissonInput(group[:N], ...)
        self._group = target
        CodeRunner.__init__(self, group=target, template='stateupdate',
                            code=code, user_code='', when=when, order=order,
                            name='poissoninput*', clock=target.clock)
        self.variables = Variables(self)
        self.variables._add_variable(sampler.name, sampler)
Пример #24
0
 def get_dimensions(self, variable_units):
     '''
     Evaluate the expression with the given units for its variables and
     return the dimensions of the result.

     Parameters
     ----------
     variable_units : dict
         Mapping from variable names to their units.

     Notes
     -----
     The namespace has to be resolved using the
     `~brian2.equations.codestrings.CodeString.resolve` method first.

     Raises
     ------
     DimensionMismatchError
         If the expression uses inconsistent units.
     '''
     evaluated = self.eval(variable_units)
     # This is the module-level function, not this method
     return get_dimensions(evaluated)
Пример #25
0
        def _callback_wrapper(params, iter, resid, *args, **kwds):
            """
            Record the error and parameters of the current iteration and
            pass them on to ``callback_func``, together with the best error
            and the best parameters seen so far.

            The error is the mean of the squared residuals. If ``use_units``
            is set, errors and parameters are passed on as quantities,
            otherwise as plain array/floats.
            """
            errors.append(mean(resid**2))

            if not self.use_units:
                all_errors = array(errors)
                current = {name: float(val) for name, val in params.items()}
            else:
                norm_dim = get_dimensions(normalization)
                all_errors = Quantity(errors,
                                      dim=self.output_dim**2 * norm_dim**2)
                current = {name: Quantity(val, dim=self.model[name].dim)
                           for name, val in params.items()}
            tested_parameters.append(current)

            # Index of the smallest error recorded so far
            best_idx = argmin(errors)
            best_error = all_errors[best_idx]
            best_params = tested_parameters[best_idx]
            return callback_func(current, all_errors, best_params,
                                 best_error, iter)
Пример #26
0
    def __init__(self, type, varname, dimensions, var_type=FLOAT, expr=None,
                 flags=None):
        '''
        Store the description of a single equation or parameter.

        Parameters
        ----------
        type
            The type of the equation (e.g. ``DIFFERENTIAL_EQUATION``).
        varname
            The name of the variable defined by the equation.
        dimensions
            The dimensions of the variable; normalised with
            `get_dimensions` and stored as ``dim``.
        var_type
            The type of the variable (``FLOAT`` by default).
        expr
            The expression defining the variable, if any.
        flags : list, optional
            Flags of the equation; a new empty list if not given.

        Raises
        ------
        TypeError
            If a boolean or integer variable is not dimensionless, or if a
            differential equation defines a non-float variable.
        '''
        self.type = type
        self.varname = varname
        self.dim = get_dimensions(dimensions)
        self.var_type = var_type
        # Check the normalised dimensions instead of the raw argument, so
        # that a dimensionless argument that is not the DIMENSIONLESS object
        # itself is not mistaken for a variable with dimensions.
        if self.dim is not DIMENSIONLESS:
            if var_type == BOOLEAN:
                raise TypeError('Boolean variables are necessarily dimensionless.')
            elif var_type == INTEGER:
                raise TypeError('Integer variables are necessarily dimensionless.')

        if type == DIFFERENTIAL_EQUATION:
            if var_type != FLOAT:
                raise TypeError('Differential equations can only define floating point variables')
        self.expr = expr
        if flags is None:
            self.flags = []
        else:
            self.flags = flags

        # will be set later in the sort_subexpressions method of Equations
        self.update_order = -1
Пример #27
0
def parse_expression_unit(expr, variables):
    '''
    Returns the unit value of an expression, and checks its validity

    Parameters
    ----------
    expr : str
        The expression to check. Recursive calls pass already parsed `ast`
        nodes instead of strings.
    variables : dict
        Dictionary of all variables used in the `expr` (including `Constant`
        objects for external variables)

    Returns
    -------
    unit : Quantity
        The output unit of the expression

    Raises
    ------
    SyntaxError
        If the expression cannot be parsed, or if it uses ``a**b`` for ``b``
        anything other than a constant number. Also raised for unknown
        functions, functions used like variables, function calls with the
        wrong number of arguments and unsupported operations.
    KeyError
        If the expression uses an identifier that is not in `variables`.
    TypeError
        If a function argument that is expected to be boolean is not a
        boolean expression.
    ValueError
        If a function is called with keyword or variable-length arguments,
        if a function does not specify how it deals with units, or if a
        constant such as ``None`` is used.
    DimensionMismatchError
        If any part of the expression is dimensionally inconsistent.
    '''

    # If we are working on a string, convert to the top level node
    if isinstance(expr, basestring):
        mod = ast.parse(expr, mode='eval')
        expr = mod.body
    if expr.__class__ is getattr(ast, 'NameConstant', None):
        # new class for True, False, None in Python 3.4
        value = expr.value
        if value is True or value is False:
            return Unit(1)
        else:
            raise ValueError('Do not know how to handle value %s' % value)
    if expr.__class__ is ast.Name:
        name = expr.id
        # Raise an error if a function is called as if it were a variable
        # (most of the time this happens for a TimedArray)
        if name in variables and isinstance(variables[name], Function):
            raise SyntaxError('%s was used like a variable/constant, but it is '
                              'a function.' % name)
        if name in variables:
            return variables[name].unit
        elif name in ['True', 'False']:
            return Unit(1)
        else:
            raise KeyError('Unknown identifier %s' % name)
    elif expr.__class__ is ast.Num:
        return get_unit_fast(1)
    elif expr.__class__ is ast.BoolOp:
        # check that the units are valid in each subexpression
        for node in expr.values:
            parse_expression_unit(node, variables)
        # but the result is a bool, so we just return 1 as the unit
        return get_unit_fast(1)
    elif expr.__class__ is ast.Compare:
        # check that the units are consistent in each subexpression
        subexprs = [expr.left]+expr.comparators
        checked = [(node, parse_expression_unit(node, variables))
                   for node in subexprs]
        # Compare every operand with its direct neighbour and report the pair
        # that actually mismatches (matters for chained comparisons such as
        # "a < b < c", where the offending pair may be "b" and "c")
        for (left_node, left), (right_node, right) in zip(checked[:-1],
                                                          checked[1:]):
            if not have_same_dimensions(left, right):
                msg = ('Comparison of expressions with different units. Expression '
                       '"{}" has unit ({}), while expression "{}" has units ({})').format(
                            NodeRenderer().render_node(left_node), get_dimensions(left),
                            NodeRenderer().render_node(right_node), get_dimensions(right))
                raise DimensionMismatchError(msg)
        # but the result is a bool, so we just return 1 as the unit
        return get_unit_fast(1)
    elif expr.__class__ is ast.Call:
        if len(expr.keywords):
            raise ValueError("Keyword arguments not supported.")
        elif getattr(expr, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        elif getattr(expr, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")

        func = variables.get(expr.func.id, None)
        if func is None:
            raise SyntaxError('Unknown function %s' % expr.func.id)
        if not hasattr(func, '_arg_units') or not hasattr(func, '_return_unit'):
            raise ValueError(('Function %s does not specify how it '
                              'deals with units.') % expr.func.id)

        if len(func._arg_units) != len(expr.args):
            raise SyntaxError('Function %s was called with %d parameters, '
                              'needs %d.' % (expr.func.id,
                                             len(expr.args),
                                             len(func._arg_units)))

        for idx, (arg, expected_unit) in enumerate(zip(expr.args,
                                                       func._arg_units)):
            # A "None" in func._arg_units means: No matter what unit
            if expected_unit is None:
                continue
            elif expected_unit == bool:
                if not is_boolean_expression(arg, variables):
                    raise TypeError(('Argument number %d for function %s was '
                                     'expected to be a boolean value, but is '
                                     '"%s".') % (idx + 1, expr.func.id,
                                                 NodeRenderer().render_node(arg)))
            else:
                arg_unit = parse_expression_unit(arg, variables)
                if not have_same_dimensions(arg_unit, expected_unit):
                    msg = ('Argument number {} for function {} does not have the '
                           'correct units. Expression "{}" has units ({}), but '
                           'should be ({}).').format(
                        idx+1, expr.func.id,
                        NodeRenderer().render_node(arg),
                        get_dimensions(arg_unit), get_dimensions(expected_unit))
                    raise DimensionMismatchError(msg)

        if func._return_unit == bool:
            return Unit(1)
        elif isinstance(func._return_unit, (Unit, int)):
            # Function always returns the same unit
            return get_unit_fast(func._return_unit)
        else:
            # Function returns a unit that depends on the arguments
            arg_units = [parse_expression_unit(arg, variables)
                         for arg in expr.args]
            return func._return_unit(*arg_units)

    elif expr.__class__ is ast.BinOp:
        op = expr.op.__class__.__name__
        left = parse_expression_unit(expr.left, variables)
        right = parse_expression_unit(expr.right, variables)
        if op=='Add' or op=='Sub':
            u = left+right
        elif op=='Mult':
            u = left*right
        elif op=='Div':
            u = left/right
        elif op=='Pow':
            # dimensionless ** dimensionless does not need a constant exponent
            if have_same_dimensions(left, 1) and have_same_dimensions(right, 1):
                return get_unit_fast(1)
            n = _get_value_from_expression(expr.right, variables)
            u = left**n
        elif op=='Mod':
            u = left % right
        else:
            raise SyntaxError("Unsupported operation "+op)
        return u
    elif expr.__class__ is ast.UnaryOp:
        op = expr.op.__class__.__name__
        # check validity of operand and get its unit
        u = parse_expression_unit(expr.operand, variables)
        if op=='Not':
            return get_unit_fast(1)
        else:
            return u
    else:
        raise SyntaxError('Unsupported operation ' + str(expr.__class__))
Пример #28
0
def get_unit_fast(x):
    """Build a `Quantity` of value 1 that carries the dimensions of ``x``."""
    dims = get_dimensions(x)
    return Quantity.with_dimensions(1, dims)
Пример #29
0
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables)

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If an unit mismatch occurs during the evaluation.
    AssertionError
        If `code` contains unknown identifiers, an unknown operator, or
        assigns to a variable that is neither known nor newly defined.
    '''
    # Work on a copy: temporary variables get added below and should not
    # leak into the dictionary of the caller
    variables = dict(variables)
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    newly_defined, _, unknown = analyse_identifiers(code, variables)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                             'not happen at this stage. Unknown identifiers: %s'
                             % unknown))

    # Statements are separated by newlines or semicolons
    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, _comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *=2" by "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
            op = '='
        elif op == '=':
            pass
        else:
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_dimensions(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].dim
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit %r') % (line,
                                                               expected_unit))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname,
                                          dimensions=get_dimensions(expr_unit),
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Пример #30
0
def check_units_statements(code, variables):
    '''
    Check the units for a series of statements. Setting a model variable has to
    use the correct unit. For newly introduced temporary variables, the unit
    is determined and used to check the following statements to ensure
    consistency.

    Parameters
    ----------
    code : str
        The statements as a (multi-line) string
    variables : dict of `Variable` objects
        The information about all variables used in `code` (including
        `Constant` objects for external variables)

    Raises
    ------
    KeyError
        In case one of the identifiers cannot be resolved.
    DimensionMismatchError
        If an unit mismatch occurs during the evaluation.
    AssertionError
        If `code` contains unknown identifiers, an unknown operator, or
        assigns to a variable that is neither known nor newly defined.
    '''
    # Work on a copy: temporary variables get added below and should not
    # leak into the dictionary of the caller
    variables = dict(variables)
    # Avoid a circular import
    from brian2.codegen.translation import analyse_identifiers
    # Only the names of the known variables are passed on here
    known = set(variables.keys())
    newly_defined, _, unknown = analyse_identifiers(code, known)

    if len(unknown):
        raise AssertionError(('Encountered unknown identifiers, this should '
                             'not happen at this stage. Unknown identifiers: %s'
                             % unknown))

    # Statements are separated by newlines or semicolons
    for line in re.split(r'[;\n]', code):
        line = line.strip()
        if not len(line):
            continue  # skip empty lines

        varname, op, expr, _comment = parse_statement(line)
        if op in ('+=', '-=', '*=', '/=', '%='):
            # Replace statements such as "w *=2" by "w = w * 2"
            expr = '{var} {op_first} {expr}'.format(var=varname,
                                                    op_first=op[0],
                                                    expr=expr)
            op = '='
        elif op == '=':
            pass
        else:
            raise AssertionError('Unknown operator "%s"' % op)

        expr_unit = parse_expression_dimensions(expr, variables)

        if varname in variables:
            expected_unit = variables[varname].dim
            fail_for_dimension_mismatch(expr_unit, expected_unit,
                                        ('The right-hand-side of code '
                                         'statement "%s" does not have the '
                                         'expected unit %r') % (line,
                                                               expected_unit))
        elif varname in newly_defined:
            # note the unit for later
            variables[varname] = Variable(name=varname,
                                          dimensions=get_dimensions(expr_unit),
                                          scalar=False)
        else:
            raise AssertionError(('Variable "%s" is neither in the variables '
                                  'dictionary nor in the list of undefined '
                                  'variables.' % varname))
Пример #31
0
 def __init__(self, dimensions):
     """Store the result of ``get_dimensions(dimensions)`` as ``self.dim``."""
     dim = get_dimensions(dimensions)
     self.dim = dim
Пример #32
0
 def __init__(self, dimensions):
     """Normalize ``dimensions`` with `get_dimensions` and keep it as ``dim``."""
     normalized = get_dimensions(dimensions)
     self.dim = normalized
Пример #33
0
def parse_expression_dimensions(expr, variables):
    '''
    Returns the dimensions of an expression, and checks its validity

    Parameters
    ----------
    expr : str
        The expression to check. Recursive calls pass already parsed `ast`
        nodes instead of strings.
    variables : dict
        Dictionary of all variables used in the `expr` (including `Constant`
        objects for external variables)

    Returns
    -------
    dim
        The physical dimensions of the expression (`DIMENSIONLESS` for
        numbers and for boolean results).

    Raises
    ------
    SyntaxError
        If the expression cannot be parsed, or if it uses ``a**b`` for ``b``
        anything other than a constant number. Also raised for unknown
        functions, functions used like variables, function calls with the
        wrong number of arguments, floor division of values with dimensions
        and unsupported operations.
    KeyError
        If the expression uses an identifier that is not in `variables`.
    TypeError
        If a function argument that is expected to be boolean is not a
        boolean expression.
    ValueError
        If a function is called with keyword or variable-length arguments,
        if a function does not specify how it deals with units, or if a
        constant such as ``None`` is used.
    DimensionMismatchError
        If any part of the expression is dimensionally inconsistent.
    '''

    # If we are working on a string, convert to the top level node
    if isinstance(expr, basestring):
        mod = ast.parse(expr, mode='eval')
        expr = mod.body
    if expr.__class__ is getattr(ast, 'NameConstant', None):
        # new class for True, False, None in Python 3.4
        value = expr.value
        if value is True or value is False:
            return DIMENSIONLESS
        else:
            raise ValueError('Do not know how to handle value %s' % value)
    if expr.__class__ is ast.Name:
        name = expr.id
        # Raise an error if a function is called as if it were a variable
        # (most of the time this happens for a TimedArray)
        if name in variables and isinstance(variables[name], Function):
            raise SyntaxError(
                '%s was used like a variable/constant, but it is '
                'a function.' % name)
        if name in variables:
            return variables[name].dim
        elif name in ['True', 'False']:
            return DIMENSIONLESS
        else:
            raise KeyError('Unknown identifier %s' % name)
    elif (expr.__class__ is ast.Num
          or expr.__class__ is getattr(ast, 'Constant', None)):  # Python 3.8
        return DIMENSIONLESS
    elif expr.__class__ is ast.BoolOp:
        # check that the units are valid in each subexpression
        for node in expr.values:
            parse_expression_dimensions(node, variables)
        # but the result is a bool, so the result is dimensionless
        return DIMENSIONLESS
    elif expr.__class__ is ast.Compare:
        # check that the units are consistent in each subexpression
        subexprs = [expr.left] + expr.comparators
        checked = [(node, parse_expression_dimensions(node, variables))
                   for node in subexprs]
        # Compare every operand with its direct neighbour and report the pair
        # that actually mismatches (matters for chained comparisons such as
        # "a < b < c", where the offending pair may be "b" and "c")
        for (left_node, left_dim), (right_node, right_dim) in zip(
                checked[:-1], checked[1:]):
            if not have_same_dimensions(left_dim, right_dim):
                msg = (
                    'Comparison of expressions with different units. Expression '
                    '"{}" has unit ({}), while expression "{}" has units ({})'
                ).format(NodeRenderer().render_node(left_node),
                         get_dimensions(left_dim),
                         NodeRenderer().render_node(right_node),
                         get_dimensions(right_dim))
                raise DimensionMismatchError(msg)
        # but the result is a bool, so the result is dimensionless
        return DIMENSIONLESS
    elif expr.__class__ is ast.Call:
        if len(expr.keywords):
            raise ValueError("Keyword arguments not supported.")
        elif getattr(expr, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        elif getattr(expr, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")

        func = variables.get(expr.func.id, None)
        if func is None:
            raise SyntaxError('Unknown function %s' % expr.func.id)
        if not hasattr(func, '_arg_units') or not hasattr(
                func, '_return_unit'):
            raise ValueError(('Function %s does not specify how it '
                              'deals with units.') % expr.func.id)

        if len(func._arg_units) != len(expr.args):
            raise SyntaxError(
                'Function %s was called with %d parameters, '
                'needs %d.' %
                (expr.func.id, len(expr.args), len(func._arg_units)))

        for idx, (arg,
                  expected_unit) in enumerate(zip(expr.args, func._arg_units)):
            # A "None" in func._arg_units means: No matter what unit
            if expected_unit is None:
                continue
            elif expected_unit == bool:
                if not is_boolean_expression(arg, variables):
                    raise TypeError(
                        ('Argument number %d for function %s was '
                         'expected to be a boolean value, but is '
                         '"%s".') % (idx + 1, expr.func.id,
                                     NodeRenderer().render_node(arg)))
            else:
                arg_unit = parse_expression_dimensions(arg, variables)
                if not have_same_dimensions(arg_unit, expected_unit):
                    msg = (
                        'Argument number {} for function {} does not have the '
                        'correct units. Expression "{}" has units ({}), but '
                        'should be ({}).').format(
                            idx + 1, expr.func.id,
                            NodeRenderer().render_node(arg),
                            get_dimensions(arg_unit),
                            get_dimensions(expected_unit))
                    raise DimensionMismatchError(msg)

        if func._return_unit == bool:
            return DIMENSIONLESS
        elif isinstance(func._return_unit, (Unit, int)):
            # Function always returns the same unit; an int has no ``dim``
            # attribute and stands for a dimensionless result
            return getattr(func._return_unit, 'dim', DIMENSIONLESS)
        else:
            # Function returns a unit that depends on the arguments
            arg_units = [
                parse_expression_dimensions(arg, variables)
                for arg in expr.args
            ]
            return func._return_unit(*arg_units).dim

    elif expr.__class__ is ast.BinOp:
        op = expr.op.__class__.__name__
        left_dim = parse_expression_dimensions(expr.left, variables)
        right_dim = parse_expression_dimensions(expr.right, variables)
        if op in ['Add', 'Sub', 'Mod']:
            # dimensions should be the same
            if left_dim is not right_dim:
                op_symbol = {'Add': '+', 'Sub': '-', 'Mod': '%'}.get(op)
                left_str = NodeRenderer().render_node(expr.left)
                right_str = NodeRenderer().render_node(expr.right)
                left_unit = get_unit_for_display(left_dim)
                right_unit = get_unit_for_display(right_dim)
                error_msg = ('Expression "{left} {op} {right}" uses '
                             'inconsistent units ("{left}" has unit '
                             '{left_unit}; "{right}" '
                             'has unit {right_unit})').format(
                                 left=left_str,
                                 right=right_str,
                                 op=op_symbol,
                                 left_unit=left_unit,
                                 right_unit=right_unit)
                raise DimensionMismatchError(error_msg)
            u = left_dim
        elif op == 'Mult':
            u = left_dim * right_dim
        elif op == 'Div':
            u = left_dim / right_dim
        elif op == 'FloorDiv':
            if not (left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS):
                raise SyntaxError('Floor division can only be used on '
                                  'dimensionless values.')
            u = DIMENSIONLESS
        elif op == 'Pow':
            # dimensionless ** dimensionless does not need a constant exponent
            if left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS:
                return DIMENSIONLESS
            n = _get_value_from_expression(expr.right, variables)
            u = left_dim**n
        else:
            raise SyntaxError("Unsupported operation " + op)
        return u
    elif expr.__class__ is ast.UnaryOp:
        op = expr.op.__class__.__name__
        # check validity of operand and get its dimensions
        u = parse_expression_dimensions(expr.operand, variables)
        if op == 'Not':
            return DIMENSIONLESS
        else:
            return u
    else:
        raise SyntaxError('Unsupported operation ' + str(expr.__class__))
Пример #34
0
def get_unit_fast(x):
    """Return a `Quantity` equal to 1 with the dimensions of ``x``."""
    dimensions_of_x = get_dimensions(x)
    unit_quantity = Quantity.with_dimensions(1, dimensions_of_x)
    return unit_quantity
Пример #35
0
    def __init__(self, dt, model, input, output, input_var, output_var,
                 n_samples, threshold, reset, refractory, method, param_init,
                 use_units=True):
        """Validate the arguments and set up the state of the fitter.

        A string ``model`` is converted to `Equations`, and an equation that
        reads ``input_var`` from ``input_var(t, i % n_traces)`` is appended to
        the stored model. ``input`` is unpacked as ``(n_traces, n_steps)`` and
        wrapped (transposed) in a `TimedArray` with time step ``dt``.

        Raises
        ------
        ValueError
            If ``dt`` is None, or if a key of ``param_init`` is neither a
            differential equation name nor a parameter name of the model.
        NameError
            If ``input_var`` is not an identifier of the model.
        """
        if dt is None:
            raise ValueError("dt-sampling frequency of the input must be set")

        if isinstance(model, str):
            model = Equations(model)
        if input_var not in model.identifiers:
            raise NameError("%s is not an identifier in the model" % input_var)

        # Basic bookkeeping
        self.dt = dt
        self.simulator = None
        self.parameter_names = model.parameter_names

        # One row per input trace; each trace is simulated n_samples times
        self.n_traces, n_steps = input.shape
        self.duration = n_steps * dt
        self.n_neurons = self.n_traces * n_samples

        # Simulation settings
        self.n_samples = n_samples
        self.method = method
        self.threshold = threshold
        self.reset = reset
        self.refractory = refractory

        self.input = input
        self.output_var = output_var
        # Spikes carry no dimensions, everything else uses the model's
        self.output_dim = (DIMENSIONLESS if output_var == 'spikes'
                           else model[output_var].dim)
        self.model = model

        self.use_units = use_units

        # Append the equation that feeds the recorded input into the model
        dim = get_dimensions(input)
        if dim is DIMENSIONLESS:
            dim_str = '1'
        else:
            dim_str = repr(dim)
        self.model += "{} = input_var(t, i % n_traces) : {}".format(
            input_var, dim_str)

        self.input_traces = TimedArray(input.transpose(), dt=dt)

        # initialization of attributes used later
        self._best_params = None
        self._best_error = None
        self.optimizer = None
        self.metric = None

        # Initial values may only refer to state variables or parameters
        if not param_init:
            param_init = {}
        for param, _ in param_init.items():
            if (param not in self.model.diff_eq_names
                    and param not in self.model.parameter_names):
                raise ValueError("%s is not a model variable or a "
                                 "parameter in the model" % param)
        self.param_init = param_init
Пример #36
0
def assert_quantity(q, values, unit):
    """Assert that ``q`` is a `Quantity` with the given values and dimensions.

    ``values`` is compared element-wise against the plain array of ``q``;
    ``unit`` only has to have the same dimensions as ``q``.
    """
    assert isinstance(q, Quantity)
    plain = np.asarray(q)
    assert np.all(plain == values)
    assert have_same_dimensions(q, unit), (
        'Dimension mismatch: (%s) (%s)' % (get_dimensions(q),
                                           get_dimensions(unit)))
Пример #37
0
def assert_quantity(q, values, unit):
    """Assert that ``q`` has the given values and the dimensions of ``unit``.

    ``q`` has to be a `Quantity`; a plain ``np.ndarray`` is accepted as well
    if ``unit`` is dimensionless.
    """
    assert (isinstance(q, Quantity)
            or (isinstance(q, np.ndarray)
                and have_same_dimensions(unit, 1)))
    plain = np.asarray(q)
    assert_equal(plain, values)
    assert have_same_dimensions(q, unit), (
        'Dimension mismatch: (%s) (%s)' % (get_dimensions(q),
                                           get_dimensions(unit)))
Пример #38
0
def parse_expression_dimensions(expr, variables, orig_expr=None):
    """
    Returns the dimensions of an expression, and checks its validity

    Parameters
    ----------
    expr : str or ast.AST
        The expression to check. A string is parsed with `ast.parse`; the
        recursive calls on subexpressions pass AST nodes directly.
    variables : dict
        Dictionary of all variables used in the `expr` (including `Constant`
        objects for external variables)
    orig_expr : str, optional
        The source text of the complete expression. It is set automatically
        when `expr` is a string and is only used as the source line attached
        to raised `SyntaxError` exceptions.

    Returns
    -------
    dim
        The dimensions of the expression, as returned by `get_dimensions`
        (``DIMENSIONLESS`` for numbers and for boolean results).

    Raises
    ------
    SyntaxError
        If the expression cannot be parsed, or if it uses ``a**b`` for ``b``
        anything other than a constant number. Also raised for unknown
        functions, functions used like variables, a wrong number of function
        arguments, floor division on values with dimensions and unsupported
        operations.
    DimensionMismatchError
        If any part of the expression is dimensionally inconsistent.
    KeyError
        If the expression refers to an identifier that is not in `variables`.
    ValueError
        If a function call uses keyword or variable arguments, or if the
        called function does not specify how it deals with units.
    TypeError
        If a function argument is expected to be boolean but is not.
    """

    # If we are working on a string, convert to the top level node
    if isinstance(expr, str):
        orig_expr = expr
        mod = ast.parse(expr, mode='eval')
        expr = mod.body
    if expr.__class__ is getattr(ast, 'NameConstant', None):
        # new class for True, False, None in Python 3.4
        value = expr.value
        if value is True or value is False:
            return DIMENSIONLESS
        else:
            raise ValueError(f'Do not know how to handle value {value}')
    if expr.__class__ is ast.Name:
        name = expr.id
        # Raise an error if a function is called as if it were a variable
        # (most of the time this happens for a TimedArray)
        if name in variables and isinstance(variables[name], Function):
            raise SyntaxError(
                f'{name} was used like a variable/constant, but it is a function.',
                ("<string>", expr.lineno, expr.col_offset + 1, orig_expr))
        if name in variables:
            return get_dimensions(variables[name])
        elif name in ['True', 'False']:
            return DIMENSIONLESS
        else:
            raise KeyError(f'Unknown identifier {name}')
    elif (expr.__class__ is getattr(ast, 'Constant', None)  # Python 3.8
          or expr.__class__ is getattr(ast, 'Num', None)):
        # ``ast.Num`` is a deprecated alias that newer Python versions no
        # longer provide, so it is looked up defensively
        return DIMENSIONLESS
    elif expr.__class__ is ast.BoolOp:
        # check that the units are valid in each subexpression
        for node in expr.values:
            parse_expression_dimensions(node, variables, orig_expr=orig_expr)
        # but the result is a bool, so we just return 1 as the unit
        return DIMENSIONLESS
    elif expr.__class__ is ast.Compare:
        # check that the units are consistent in each subexpression
        subexprs = [expr.left] + expr.comparators
        operands = []
        for node in subexprs:
            operands.append(
                (node,
                 parse_expression_dimensions(node,
                                             variables,
                                             orig_expr=orig_expr)))
        # Compare neighbouring operands (``a < b < c`` checks a/b and b/c) and
        # report the pair that actually mismatches
        for (left_node, left_dim), (right_node, right_dim) in zip(
                operands[:-1], operands[1:]):
            if not have_same_dimensions(left_dim, right_dim):
                left_expr = NodeRenderer().render_node(left_node)
                right_expr = NodeRenderer().render_node(right_node)
                dim_left = get_dimensions(left_dim)
                dim_right = get_dimensions(right_dim)
                msg = (
                    f"Comparison of expressions with different units. Expression "
                    f"'{left_expr}' has unit ({dim_left}), while expression "
                    f"'{right_expr}' has units ({dim_right}).")
                raise DimensionMismatchError(msg)
        # but the result is a bool, so we just return 1 as the unit
        return DIMENSIONLESS
    elif expr.__class__ is ast.Call:
        if len(expr.keywords):
            raise ValueError("Keyword arguments not supported.")
        elif getattr(expr, 'starargs', None) is not None:
            raise ValueError("Variable number of arguments not supported")
        elif getattr(expr, 'kwargs', None) is not None:
            raise ValueError("Keyword arguments not supported")

        func_name = expr.func.id
        func = variables.get(func_name, None)
        if func is None:
            raise SyntaxError(
                f'Unknown function {func_name}',
                ("<string>", expr.lineno, expr.col_offset + 1, orig_expr))
        if not hasattr(func, '_arg_units') or not hasattr(
                func, '_return_unit'):
            raise ValueError(
                f"Function {func_name} does not specify how it "
                f"deals with units.")

        if len(func._arg_units) != len(expr.args):
            raise SyntaxError(
                f"Function '{func_name}' was called with "
                f"{len(expr.args)} parameters, needs "
                f"{len(func._arg_units)}.",
                ("<string>", expr.lineno,
                 expr.col_offset + len(func_name) + 1, orig_expr))

        # Dimensions of all arguments, collected while checking them so that
        # they do not have to be parsed again for the return unit
        arg_dims = []
        for idx, (arg,
                  expected_unit) in enumerate(zip(expr.args, func._arg_units)):
            arg_unit = parse_expression_dimensions(arg,
                                                   variables,
                                                   orig_expr=orig_expr)
            arg_dims.append(arg_unit)
            # A "None" in func._arg_units means: No matter what unit
            if expected_unit is None:
                continue
            # A string means: same unit as other argument
            elif isinstance(expected_unit, str):
                arg_idx = func._arg_names.index(expected_unit)
                expected_unit = parse_expression_dimensions(
                    expr.args[arg_idx], variables, orig_expr=orig_expr)
                if not have_same_dimensions(arg_unit, expected_unit):
                    msg = (
                        f'Argument number {idx + 1} for function '
                        f'{func_name} was supposed to have the '
                        f'same units as argument number {arg_idx + 1}, but '
                        f"'{NodeRenderer().render_node(arg)}' has unit "
                        f'{get_unit_for_display(arg_unit)}, while '
                        f"'{NodeRenderer().render_node(expr.args[arg_idx])}' "
                        f'has unit {get_unit_for_display(expected_unit)}')
                    raise DimensionMismatchError(msg)
            elif expected_unit == bool:
                if not is_boolean_expression(arg, variables):
                    rendered_arg = NodeRenderer().render_node(arg)
                    raise TypeError(
                        f"Argument number {idx + 1} for function "
                        f"'{func_name}' was expected to be a boolean "
                        f"value, but is '{rendered_arg}'.")
            else:
                if not have_same_dimensions(arg_unit, expected_unit):
                    rendered_arg = NodeRenderer().render_node(arg)
                    arg_unit_dim = get_dimensions(arg_unit)
                    expected_unit_dim = get_dimensions(expected_unit)
                    msg = (
                        f"Argument number {idx+1} for function {func_name} does "
                        f"not have the correct units. Expression '{rendered_arg}' "
                        f"has units ({arg_unit_dim}), but "
                        f"should be "
                        f"({expected_unit_dim}).")
                    raise DimensionMismatchError(msg)

        if func._return_unit == bool:
            return DIMENSIONLESS
        elif isinstance(func._return_unit, (Unit, int)):
            # Function always returns the same unit
            return getattr(func._return_unit, 'dim', DIMENSIONLESS)
        else:
            # Function returns a unit that depends on the arguments
            return func._return_unit(*arg_dims).dim

    elif expr.__class__ is ast.BinOp:
        op = expr.op.__class__.__name__
        left_dim = parse_expression_dimensions(expr.left,
                                               variables,
                                               orig_expr=orig_expr)
        right_dim = parse_expression_dimensions(expr.right,
                                                variables,
                                                orig_expr=orig_expr)
        if op in ['Add', 'Sub', 'Mod']:
            # dimensions should be the same
            if left_dim is not right_dim:
                op_symbol = {'Add': '+', 'Sub': '-', 'Mod': '%'}.get(op)
                left_str = NodeRenderer().render_node(expr.left)
                right_str = NodeRenderer().render_node(expr.right)
                left_unit = get_unit_for_display(left_dim)
                right_unit = get_unit_for_display(right_dim)
                error_msg = (
                    f"Expression '{left_str} {op_symbol} {right_str}' uses "
                    f"inconsistent units ('{left_str}' has unit "
                    f"{left_unit}; '{right_str}' "
                    f"has unit {right_unit}).")
                raise DimensionMismatchError(error_msg)
            u = left_dim
        elif op == 'Mult':
            u = left_dim * right_dim
        elif op == 'Div':
            u = left_dim / right_dim
        elif op == 'FloorDiv':
            if not (left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS):
                # Point the error at the first operand that has dimensions
                if left_dim is DIMENSIONLESS:
                    col_offset = expr.right.col_offset + 1
                else:
                    col_offset = expr.left.col_offset + 1
                raise SyntaxError(
                    "Floor division can only be used on "
                    "dimensionless values.",
                    ("<string>", expr.lineno, col_offset, orig_expr))
            u = DIMENSIONLESS
        elif op == 'Pow':
            if left_dim is DIMENSIONLESS and right_dim is DIMENSIONLESS:
                return DIMENSIONLESS
            # The exponent's value is needed to compute the new dimensions
            n = _get_value_from_expression(expr.right, variables)
            u = left_dim**n
        else:
            raise SyntaxError(
                f"Unsupported operation {op}",
                ("<string>", expr.lineno,
                 getattr(expr.left, 'end_col_offset',
                         len(NodeRenderer().render_node(expr.left))) + 1,
                 orig_expr))
        return u
    elif expr.__class__ is ast.UnaryOp:
        op = expr.op.__class__.__name__
        # check validity of operand and get its unit
        u = parse_expression_dimensions(expr.operand,
                                        variables,
                                        orig_expr=orig_expr)
        if op == 'Not':
            return DIMENSIONLESS
        else:
            return u
    else:
        raise SyntaxError(
            f"Unsupported operation {str(expr.__class__.__name__)}",
            ("<string>", expr.lineno, expr.col_offset + 1, orig_expr))