def test_operator_scaling(dom_eq_ran):
    """Compare left/right scalar multiplication of an operator to NumPy.

    The operator is ``MultiplyAndSquareOp`` built from a random matrix that
    is square (3x3) when ``dom_eq_ran`` is true and 4x3 otherwise. Scaling is
    exercised both through the explicit ``OperatorLeftScalarMult`` /
    ``OperatorRightScalarMult`` wrappers and through the ``*`` overloads.
    Invalid scalars (a complex number, a list, a tuple) must raise
    ``TypeError``.
    """
    shape = (3, 3) if dom_eq_ran else (4, 3)
    mat = np.random.rand(*shape)
    op = MultiplyAndSquareOp(mat)
    xarr, x = noise_elements(op.domain)

    # -1, 0 and 1 are included on purpose: scalar multiplication may take
    # shortcuts for these special values.
    for factor in (-1.432, -1, 0, 1, 3.14):
        left = OperatorLeftScalarMult(op, factor)
        right = OperatorRightScalarMult(op, factor)

        # Scaling this operator must not produce a linear operator.
        assert not left.is_linear
        assert not right.is_linear

        # Explicit wrapper classes.
        check_call(left, x, factor * mult_sq_np(mat, xarr))
        check_call(right, x, mult_sq_np(mat, factor * xarr))

        # Equivalent operator-overloading forms.
        check_call(factor * op, x, factor * mult_sq_np(mat, xarr))
        check_call(op * factor, x, mult_sq_np(mat, factor * xarr))

    # Every way of scaling, applied in this fixed order, must reject a
    # scalar of the wrong type.
    builders = (
        lambda s: OperatorLeftScalarMult(op, s),
        lambda s: OperatorRightScalarMult(op, s),
        lambda s: op * s,
        lambda s: s * op,
    )
    for bad in (1j, [1, 2], (1, 2)):
        for build in builders:
            with pytest.raises(TypeError):
                build(bad)
def test_nonlinear_scale():
    """Check scalar scaling of ``MultiplyAndSquareOp`` for a 4x3 matrix.

    The left/right scaled operators are compared against the NumPy
    reference ``mult_sq_np``, both via the explicit wrapper classes and via
    operator overloading. Invalid scalars (a complex number, a list, a
    tuple) must raise ``TypeError``.
    """
    A = np.random.rand(4, 3)
    x = np.random.rand(3)
    Aop = MultiplyAndSquareOp(A)
    xvec = Aop.domain.element(x)

    # Test a range of scalars (scalar multiplication could implement
    # optimizations for (-1, 0, 1)).
    scalars = [-1.432, -1, 0, 1, 3.14]
    for scale in scalars:
        lscaled = OperatorLeftScalarMult(Aop, scale)
        rscaled = OperatorRightScalarMult(Aop, scale)

        # Scaling this operator must not produce a linear operator.
        assert not lscaled.is_linear
        assert not rscaled.is_linear

        # Explicit
        check_call(lscaled, xvec, scale * mult_sq_np(A, x))
        check_call(rscaled, xvec, mult_sq_np(A, scale * x))

        # Using operator overloading
        check_call(scale * Aop, xvec, scale * mult_sq_np(A, x))
        check_call(Aop * scale, xvec, mult_sq_np(A, scale * x))

    # Fail when scaling by a wrong scalar type (complex number, list, tuple).
    # Only the construction/multiplication itself is placed under
    # ``pytest.raises``. Wrapping it in ``print`` would let a TypeError
    # raised while printing satisfy the check, hiding a wrongly
    # successful construction.
    wrongscalars = [1j, [1, 2], (1, 2)]
    for wrongscalar in wrongscalars:
        with pytest.raises(TypeError):
            OperatorLeftScalarMult(Aop, wrongscalar)

        with pytest.raises(TypeError):
            OperatorRightScalarMult(Aop, wrongscalar)

        with pytest.raises(TypeError):
            Aop * wrongscalar

        with pytest.raises(TypeError):
            wrongscalar * Aop