def coerce(self, conn, function):
    """Convert a connection ``function`` value into a ``FunctionInfo``.

    The value is first passed through the parent class's ``coerce`` and is
    then normalized according to its kind:

    * ``None`` becomes ``FunctionInfo(function=None, size=None)``.
    * A ``FunctionInfo`` is used unchanged.
    * An array-like is converted to a float64 array, checked with
      ``self.check_array``, and wrapped with ``size=array.shape[1]``.
    * A callable is wrapped with the size reported by
      ``self.determine_size``.

    The result is finally checked with ``self.check_function_can_be_applied``.

    Parameters
    ----------
    conn
        The object this parameter is being set on (presumably a
        connection, judging by the error message -- confirm).
    function
        The candidate value: ``None``, a ``FunctionInfo``, an array-like,
        or a callable.

    Returns
    -------
    FunctionInfo
        The normalized description of the function.

    Raises
    ------
    ValidationError
        If ``function`` is not ``None``, a ``FunctionInfo``, array-like, or
        callable. The helper checks called here may raise as well.
    """
    function = super().coerce(conn, function)

    if function is None:
        function_info = FunctionInfo(function=None, size=None)
    elif isinstance(function, FunctionInfo):
        function_info = function
    elif is_array_like(function):
        # ``np.asarray`` copies only when it has to (e.g. for list input).
        # ``np.array(..., copy=False)`` raises ValueError on NumPy >= 2
        # whenever a copy is unavoidable, which would reject plain lists.
        array = np.asarray(function, dtype=np.float64)
        self.check_array(conn, array)
        # NOTE(review): assumes check_array rejects arrays that are not
        # 2-D, so that shape[1] exists -- confirm.
        function_info = FunctionInfo(function=array, size=array.shape[1])
    elif callable(function):
        function_info = FunctionInfo(
            function=function, size=self.determine_size(conn, function)
        )
        # TODO: necessary? Re-runs the parent's coerce on the wrapped value;
        # kept as-is to preserve existing behavior.
        super().coerce(conn, function_info)
    else:
        raise ValidationError(
            "Invalid connection function type %r "
            "(must be callable or array-like)" % type(function).__name__,
            attr=self.name,
            obj=conn,
        )

    self.check_function_can_be_applied(conn, function_info)
    return function_info
def __set__(self, conn, function):
    """Normalize ``function`` into a ``FunctionInfo``, validate it, and store it.

    Normalization by kind:

    * ``None`` becomes ``FunctionInfo(function=None, size=None)``.
    * A ``FunctionInfo`` is used unchanged.
    * An array-like is converted to a float64 array, checked with
      ``self.validate_array``, and wrapped with ``size=array.shape[1]``.
    * A callable is wrapped with the size reported by
      ``self.determine_size`` and checked with ``self.validate_callable``.

    The result is then passed to ``self.validate`` and stored in
    ``self.data`` under the key ``conn``.

    Parameters
    ----------
    conn
        The object this parameter is being set on (presumably a
        connection, judging by the error message -- confirm).
    function
        The candidate value: ``None``, a ``FunctionInfo``, an array-like,
        or a callable.

    Raises
    ------
    ValidationError
        If ``function`` is not ``None``, a ``FunctionInfo``, array-like, or
        callable. The validation helpers called here may raise as well.
    """
    if function is None:
        function_info = FunctionInfo(function=None, size=None)
    elif isinstance(function, FunctionInfo):
        function_info = function
    elif is_array_like(function):
        # ``np.asarray`` copies only when it has to (e.g. for list input).
        # ``np.array(..., copy=False)`` raises ValueError on NumPy >= 2
        # whenever a copy is unavoidable, which would reject plain lists.
        array = np.asarray(function, dtype=np.float64)
        self.validate_array(conn, array)
        # NOTE(review): assumes validate_array rejects arrays that are not
        # 2-D, so that shape[1] exists -- confirm.
        function_info = FunctionInfo(function=array, size=array.shape[1])
    elif callable(function):
        function_info = FunctionInfo(
            function=function, size=self.determine_size(conn, function)
        )
        self.validate_callable(conn, function_info)
    else:
        raise ValidationError(
            "Invalid connection function type %r "
            "(must be callable or array-like)" % type(function).__name__,
            attr=self.name,
            obj=conn,
        )

    self.validate(conn, function_info)
    self.data[conn] = function_info
def test_functionparam():
    """A FunctionParam takes one-argument callables and rejects anything else."""

    class Test:
        fp = params.FunctionParam("fp", default=None)

    inst = Test()

    # Before any assignment the declared default is returned.
    assert inst.fp is None

    # A bare callable gets wrapped, and its output size is inferred.
    inst.fp = np.sin
    assert inst.fp.function is np.sin
    assert inst.fp.size == 1

    # A pre-built FunctionInfo is accepted as given.
    inst.fp = FunctionInfo(function=np.cos, size=1)
    assert inst.fp.function is np.cos
    assert inst.fp.size == 1

    # Each invalid value is paired with the error pattern it must trigger.
    rejected = (
        # needs two arguments instead of one
        (lambda x, y: x + y, "function.*must accept a single.*argu"),
        # not callable at all
        (0, "function.*must be callable"),
    )
    for bad_value, pattern in rejected:
        with pytest.raises(ValidationError, match=pattern):
            inst.fp = bad_value