def test_infohook_chaining(transmon_ham_and_states):
    """Check that ``modify_params_after_iter`` and chained ``info_hook``s
    cooperate correctly.

    Several implementation details are exercised at once:

    - the return values of multiple info_hooks are combined in a tuple
    - a return value of None (from modify_params_after_iter) is ignored
    - shared_data is passed along through multiple hooks
    - shared_data is cleared in each iteration
    """
    H, psi0, psi1 = transmon_ham_and_states
    tlist = np.array([0, 0.01, 0.02])
    objective = krotov.Objective(initial_state=psi0, target=psi1, H=H)
    pulse_options = {H[1][1]: krotov.PulseOptions(1.0)}
    captured = io.StringIO()

    def adjust_lambda_a(**kwargs):
        # Halve the first lambda value and record the change in shared_data
        # so that a later hook in the same iteration can report it.
        lambda_vals = kwargs['lambda_vals']
        old_value = lambda_vals[0]
        lambda_vals[0] *= 0.5
        log = kwargs['shared_data'].setdefault('messages', [])
        log.append('λₐ: %s → %s' % (old_value, lambda_vals[0]))

    def print_fidelity(**kwargs):
        # Real part of the averaged tau values, printed and returned.
        fidelity = np.average(np.array(kwargs['tau_vals']).real)
        print("Iteration %d: \tF = %f" % (kwargs['iteration'], fidelity))
        return fidelity

    def print_messages(**kwargs):
        # Report whatever adjust_lambda_a left in shared_data (if anything).
        shared = kwargs['shared_data']
        if 'messages' not in shared:
            return None
        joined = "; ".join(shared['messages'])
        print("\tmsg: " + joined)
        return joined

    with contextlib.redirect_stdout(captured):
        oct_result = krotov.optimize_pulses(
            [objective],
            pulse_options=pulse_options,
            tlist=tlist,
            propagator=krotov.propagators.expm,
            chi_constructor=krotov.functionals.chis_re,
            info_hook=krotov.info_hooks.chain(print_fidelity, print_messages),
            modify_params_after_iter=adjust_lambda_a,
            iter_stop=2,
        )

    info_vals = oct_result.info_vals
    assert len(info_vals) == 3
    second = info_vals[1]
    assert isinstance(second, tuple)
    assert len(second) == 2
    assert abs(second[0] - 0.001978333994757067) < 1e-8
    assert second[1] == 'λₐ: 0.5 → 0.25'

    output = captured.getvalue()
    for expected_text in (
        'Iteration 0: \tF = 0.000000',
        'msg: λₐ: 1.0 → 0.5',
        'Iteration 1: \tF = 0.001978',
        'msg: λₐ: 0.5 → 0.25',
    ):
        assert expected_text in output
def test_shape_validation():
    """Check that OCT pulse shapes are converted to callables and validated"""
    # (extra keyword arguments for PulseOptions, expected value of shape(0))
    valid_cases = [
        ({}, 1),  # no shape given: default shape
        ({'shape': 0}, 0),
        ({'shape': 1}, 1),
    ]
    for extra_kwargs, expected_value in valid_cases:
        options = krotov.PulseOptions(lambda_a=1, **extra_kwargs)
        assert callable(options.shape)
        assert options.shape(0) == expected_value

    # scalar shapes other than 0/1, and raw arrays, must be rejected
    for invalid_shape in (2, np.array([0, 0.5, 1, 0.5, 0])):
        with pytest.raises(ValueError):
            krotov.PulseOptions(lambda_a=1, shape=invalid_shape)
def test_initialize_krotov_controls():
    """Verify that controls and pulses are initialized such that the correct
    boundary conditions are kept.

    This is the point made in the "Time Discretization Schemes" section of
    the documentation.

    Tests the resolution of #20.
    """
    T = 10
    blackman = qutip_callback(krotov.shapes.blackman, t_start=0, t_stop=T)
    H = ['H0', ['H1', blackman]]
    tlist = np.linspace(0, T, 10)
    pulse_options = {blackman: krotov.PulseOptions(lambda_a=1.0)}
    objectives = [
        krotov.Objective(initial_state=qutip.Qobj(), target=None, H=H),
    ]

    # the guess control itself vanishes at both ends of the time grid
    for t_boundary in (0, T):
        assert abs(blackman(t_boundary, None)) < 1e-15

    init_result = krotov.optimize._initialize_krotov_controls(
        objectives, pulse_options, tlist)
    (guess_controls, guess_pulses, pulses_mapping, lambda_vals,
     shape_arrays) = init_result

    # controls have one value per point of tlist, pulses one value fewer;
    # both must keep the zero boundary values of the guess
    for arrays, expected_len in ((guess_controls, len(tlist)),
                                 (guess_pulses, len(tlist) - 1)):
        values = arrays[0]
        assert isinstance(values, np.ndarray)
        assert len(values) == expected_len
        assert abs(values[0]) < 1e-15
        assert abs(values[-1]) < 1e-15

    # NOTE(review): this checks the *input* dict, presumably to guard against
    # _initialize_krotov_controls mutating pulse_options -- confirm intent
    assert len(pulse_options) == 1

    # each of the four nesting levels of pulses_mapping holds exactly one
    # entry; the innermost value is 1 (presumably the index of the control
    # term in H -- confirm)
    level = pulses_mapping
    for _ in range(4):
        assert len(level) == 1
        level = level[0]
    assert level == 1

    assert len(lambda_vals) == 1
    assert lambda_vals[0] == 1.0

    assert len(shape_arrays) == 1
    first_shape = shape_arrays[0]
    assert isinstance(first_shape, np.ndarray)
    assert len(first_shape) == len(tlist) - 1
# Example #4
# 0  (leftover header from the code listing this snippet was copied from)
def test_pulse_options_dict_to_list(caplog):
    """Check the conversion of a pulse_options dict into a list"""
    # three distinct dummy controls; only the first two are actual controls
    u1, u2, u3 = (np.array([]) for _ in range(3))
    controls = [u1, u2]

    assert u1 is not u2
    assert u2 is not u3

    def keyed(pairs):
        # build a pulse_options dict keyed by id() of each dummy control
        return {
            id(control): krotov.PulseOptions(lambda_a=lambda_a)
            for control, lambda_a in pairs
        }

    pulse_options = keyed([(u1, 1.0), (u2, 2.0)])
    converted = pulse_options_dict_to_list(pulse_options, controls)
    assert len(converted) == 2
    for position, control in enumerate(controls):
        assert converted[position] == pulse_options[id(control)]

    # a control without any PulseOptions is an error
    incomplete = keyed([(u1, 1.0)])
    with pytest.raises(ValueError) as exc_info:
        pulse_options_dict_to_list(incomplete, controls)
    assert 'does not have any associated pulse options' in str(exc_info.value)

    # surplus PulseOptions (for u3, which is not a control) trigger a warning
    surplus = keyed([(u1, 1.0), (u2, 1.0), (u3, 1.0)])
    with caplog.at_level(logging.WARNING):
        pulse_options_dict_to_list(surplus, controls)
    assert 'extra elements' in caplog.text
def test_complex_control_rejection():
    """Test that complex controls are rejected.

    The guess control ``eps0`` returns complex values, so ``optimize_pulses``
    must raise a ValueError before running any iteration.
    """
    H0 = qutip.Qobj(0.5 * np.diag([-1, 1]))
    # np.array instead of np.mat: the np.mat alias was removed in NumPy 2.0
    H1 = qutip.Qobj(np.array([[1, 2], [3, 4]]))

    psi0 = qutip.Qobj(np.array([1, 0]))
    psi1 = qutip.Qobj(np.array([0, 1]))

    def eps0(t, args):
        # complex-valued control: this is what must be rejected
        return 0.2 * np.exp(1j * t)

    def S(t):
        """Shape function for the field update"""
        return krotov.shapes.flattop(t,
                                     t_start=0,
                                     t_stop=5,
                                     t_rise=0.3,
                                     t_fall=0.3,
                                     func='sinsq')

    H = [H0, [H1, eps0]]

    objectives = [krotov.Objective(initial_state=psi0, target=psi1, H=H)]

    pulse_options = {H[1][1]: krotov.PulseOptions(lambda_a=5, shape=S)}

    tlist = np.linspace(0, 5, 500)

    with pytest.raises(ValueError) as exc_info:
        krotov.optimize_pulses(objectives,
                               pulse_options,
                               tlist,
                               propagator=krotov.propagators.expm,
                               chi_constructor=krotov.functionals.chis_re,
                               iter_stop=0)
    assert 'All controls must be real-valued' in str(exc_info.value)
def test_reject_invalid_shapes():
    """Test that invalid control shapes are rejected.

    A complex-valued shape, a negative shape, and a shape scaled beyond 1
    must each make ``optimize_pulses`` raise a ValueError whose message
    names the violated requirement.
    """
    H0 = qutip.Qobj(0.5 * np.diag([-1, 1]))
    # np.array instead of np.mat: the np.mat alias was removed in NumPy 2.0
    H1 = qutip.Qobj(np.array([[1, 2], [3, 4]]))

    psi0 = qutip.Qobj(np.array([1, 0]))
    psi1 = qutip.Qobj(np.array([0, 1]))

    def eps0(t, args):
        return 0.2

    H = [H0, [H1, eps0]]

    objectives = [krotov.Objective(initial_state=psi0, target=psi1, H=H)]

    tlist = np.linspace(0, 5, 500)

    def flattop(t):
        """Reference flattop shape; the invalid shapes below derive from it"""
        return krotov.shapes.flattop(
            t, t_start=0, t_stop=5, t_rise=0.3, t_fall=0.3, func='sinsq')

    def S_complex(t):
        """Complex-valued shape (invalid)"""
        return 1j * flattop(t)

    def S_negative(t):
        """Negative shape (invalid)"""
        return -1 * flattop(t)

    def S_large(t):
        """Shape scaled by 2, expected to exceed 1 (invalid)"""
        return 2 * flattop(t)

    # (shape, fragment expected in the ValueError message)
    invalid_shapes = [
        (S_complex, 'must be real-valued'),
        (S_negative, 'must have values in the range [0, 1]'),
        (S_large, 'must have values in the range [0, 1]'),
    ]

    for shape, expected_msg in invalid_shapes:
        # PulseOptions is built inside the `raises` block, as in the original
        # checks, in case validation already happens at construction time
        with pytest.raises(ValueError) as exc_info:
            pulse_options = {
                H[1][1]: krotov.PulseOptions(lambda_a=5, shape=shape)
            }
            krotov.optimize_pulses(objectives,
                                   pulse_options,
                                   tlist,
                                   propagator=krotov.propagators.expm,
                                   chi_constructor=krotov.functionals.chis_re,
                                   iter_stop=0)
        assert expected_msg in str(exc_info.value), shape.__name__