Ejemplo n.º 1
0
    def test_anti_aliasing(self, circuit, inpt, degree):
        """Test that the coefficients obtained through anti-aliasing are the
        same as the ones when we don't anti-alias at the correct degree.

        Both calls pass ``lowpass_filter`` explicitly, so the comparison really
        contrasts the filtered and unfiltered code paths instead of depending
        on whatever the function's default happens to be.
        """
        coeffs_regular = coefficients(circuit,
                                      len(inpt),
                                      degree,
                                      lowpass_filter=False)
        # Request the low-pass filter explicitly: relying on the default could
        # make both calls identical and the assertion below vacuous.
        coeffs_anti_aliased = coefficients(circuit,
                                           len(inpt),
                                           degree,
                                           lowpass_filter=True)

        assert np.allclose(coeffs_regular, coeffs_anti_aliased)
Ejemplo n.º 2
0
    def test_anti_aliasing_incorrect(self, circuit, inpt, degree,
                                     expected_coeffs):
        """Test that anti-aliasing gives correct results when we ask for
        coefficients below the maximum degree of the circuit.

        With the low-pass filter enabled (threshold ``degree + 2``) the result
        must match ``expected_coeffs``. Without the filter the result is expected
        to differ from ``expected_coeffs``, which shows the filter is what makes
        the difference.
        """
        coeffs_anti_aliased = coefficients(circuit,
                                           len(inpt),
                                           degree,
                                           lowpass_filter=True,
                                           filter_threshold=degree + 2)
        assert np.allclose(coeffs_anti_aliased, expected_coeffs)

        # Make the unfiltered baseline explicit rather than relying on the
        # default value of ``lowpass_filter``.
        coeffs_regular = coefficients(circuit,
                                      len(inpt),
                                      degree,
                                      lowpass_filter=False)
        assert not np.allclose(coeffs_regular, expected_coeffs)
Ejemplo n.º 3
0
    def test_coefficients_torch_interface(self):
        """Check that ``coefficients`` matches the expected result for a QNode
        built with the PyTorch interface (skipped if torch is not installed)."""
        torch = pytest.importorskip("torch")
        torch_qnode = qml.QNode(self.circuit, self.dev, interface="torch")

        # Fix the trainable weights so only the data inputs remain free.
        bound_circuit = partial(torch_qnode, torch.tensor([0.5, 0.2]))

        assert np.allclose(coefficients(bound_circuit, 2, 1),
                           self.expected_result)
Ejemplo n.º 4
0
    def test_coefficients_jax_interface(self):
        """Test that coefficients are correctly computed when using the JAX interface.

        ``jax_enable_x64`` is enabled for the duration of the test. The previous
        value is restored in a ``finally`` clause, so a failing assertion or an
        exception cannot leave the global JAX configuration modified for
        subsequent tests.
        """
        jax = pytest.importorskip("jax")

        # Need to enable float64 support
        from jax.config import config

        remember = config.read("jax_enable_x64")
        config.update("jax_enable_x64", True)

        try:
            qnode = qml.QNode(self.circuit,
                              self.dev,
                              interface="jax",
                              diff_method="parameter-shift")

            weights = jax.numpy.array([0.5, 0.2])

            obtained_result = coefficients(partial(qnode, weights), 2, 1)

            assert np.allclose(obtained_result, self.expected_result)
        finally:
            # Always restore the global flag; otherwise a failure above would
            # leak float64 mode into every later test in the session.
            config.update("jax_enable_x64", remember)
Ejemplo n.º 5
0
    def test_single_variable_fourier_coeffs(self, freq_dict, expected_coeffs):
        """Check ``coefficients`` on a one-input function built from
        ``fourier_function`` and ``freq_dict``.

        The requested degree is the largest key of ``freq_dict`` (presumably the
        highest frequency present — confirm against the fixture).
        """
        highest_key = max(freq_dict.keys())
        single_input_func = partial(fourier_function, freq_dict)

        computed = coefficients(single_input_func, 1, highest_key)
        assert np.allclose(computed, expected_coeffs)
Ejemplo n.º 6
0
 def test_coefficients_two_param_circuits(self, circuit, inpt, degree,
                                          expected_coeffs):
     """Check that the coefficients computed for ``circuit`` match the by-hand
     ``expected_coeffs`` for the requested degree.

     The number of inputs is taken from ``len(inpt)``. Per the original test
     description, the circuits' maximum degree is 1 (not verified here).
     """
     n_inputs = len(inpt)
     computed = coefficients(circuit, n_inputs, degree)
     assert np.allclose(computed, expected_coeffs)