コード例 #1
0
def tst_optimise_amplitudes_single_group():
    """Optimise one group and check it recovers the generating amplitudes.

    Loops over every combination of:
      * number of datasets (1 or 5),
      * starting value of the group amplitudes (0.0, 1.0, or random values
        drawn with ``rran``),
      * atomic amplitude multiplier passed to ``get_optimisation_test_set``
        (0.0 or 1.0).

    For each combination a fresh test set is generated and
    ``OptimiseAmplitudes`` is run with ``atomic_amplitudes=None``. Two
    comparisons are made, each to a tolerance of 1e-6:
      * ``opt.initial`` against the expected starting values;
      * ``opt.result`` against the generating amplitudes.
    Prints ``OK`` once every combination has passed.
    """

    n_groups = 1
    n_atoms = 10

    for n_datasets in [1, 5]:
        n_base = n_groups * n_datasets
        for group_start in [0.0, 1.0, None]:
            for atomic_mult in [0.0, 1.0]:

                # Fresh random test set for this combination
                (target_uijs, target_weights,
                 base_uijs, base_sels, dataset_hash,
                 atomic_base,
                 real_group_amps, real_atomic_amps) = get_optimisation_test_set(
                    n_groups, n_datasets, n_atoms,
                    atomic_amplitude=atomic_mult)

                # A zero multiplier must leave no atomic contribution at all
                if atomic_mult == 0.0:
                    assert not real_atomic_amps.any()

                # None selects random starting group amplitudes
                if group_start is None:
                    group_amps_start = flex.double(rran(n_base))
                else:
                    group_amps_start = flex.double(n_base, group_start)

                # atomic_amplitudes=None is passed below; this test expects
                # the optimiser then to start each atomic amplitude at 1.0
                atomic_amps_start = flex.double(n_atoms, 1.0)

                opt = OptimiseAmplitudes(
                    target_uijs=target_uijs,
                    target_weights=target_weights,
                    base_amplitudes=group_amps_start,
                    base_uijs=base_uijs,
                    base_atom_indices=base_sels,
                    base_dataset_hash=dataset_hash,
                    atomic_uijs=atomic_base,
                    atomic_amplitudes=None,
                    atomic_optimisation_mask=None,
                    optimisation_weights=None,
                    convergence_tolerance=1e-08,
                ).run()

                initial_values = list(opt.initial)
                assert approx_equal(
                    initial_values,
                    list(group_amps_start) + list(atomic_amps_start),
                    1e-6)

                final_values = list(opt.result)
                assert approx_equal(
                    final_values,
                    list(real_group_amps) + list(real_atomic_amps),
                    1e-6)

    print('OK')
コード例 #2
0
def tst_optimise_amplitudes_multiple_groups_permuted_dataset_order():
    """Optimise several groups whose datasets are generated in random order.

    A test set with 3 groups, 5 datasets and 10 atoms is generated with
    ``random_dataset_order=True``. The generating group amplitudes are passed
    through ``resort_amplitudes_by_dataset_hash`` to get the order the
    optimiser is expected to return them in.

    The optimiser starts from all-zero group and atomic amplitudes. Two
    comparisons are made, each to a tolerance of 1e-6:
      * ``opt.initial`` must reproduce those starting values;
      * ``opt.result`` must equal the reordered group amplitudes followed by
        the atomic amplitudes.
    Prints ``OK`` on success.
    """

    n_grp, n_dst, n_atm = 3, 5, 10

    (target_uijs, target_weights,
     base_uijs, base_sels, dataset_hash,
     atomic_base,
     real_group_amps, real_atomic_amps) = get_optimisation_test_set(
        n_grp, n_dst, n_atm, random_dataset_order=True)

    # Expected group amplitudes, reordered according to dataset_hash
    expected_group_amps = resort_amplitudes_by_dataset_hash(
        n_grp=n_grp,
        n_dst=n_dst,
        dataset_hash=dataset_hash,
        real_group_amps=real_group_amps,
    )

    # Every amplitude starts from zero
    base_start = flex.double(n_grp * n_dst, 0.0)
    atomic_start = flex.double(n_atm, 0.0)

    opt = OptimiseAmplitudes(
        target_uijs=target_uijs,
        target_weights=target_weights,
        base_amplitudes=base_start,
        base_uijs=base_uijs,
        base_atom_indices=base_sels,
        base_dataset_hash=dataset_hash,
        atomic_uijs=atomic_base,
        atomic_amplitudes=atomic_start,
        atomic_optimisation_mask=None,
        optimisation_weights=None,
        convergence_tolerance=1e-08,
    ).run()

    assert approx_equal(
        list(opt.initial),
        list(base_start) + list(atomic_start),
        1e-6)
    assert approx_equal(
        list(opt.result),
        list(expected_group_amps) + list(real_atomic_amps),
        1e-6)

    print('OK')
コード例 #3
0
def tst_optimise_amplitudes_multiple_groups_with_atomic():
    """Optimise several partial groups together with an atomic level.

    A random test set with 3 groups, 5 datasets and 10 atoms is optimised
    from all-zero starting amplitudes. The result must match the generating
    group and atomic amplitudes to a tolerance of 1e-6.

    The same problem is then re-run with the base elements shuffled by one
    random permutation. The shuffled elements are the uijs, atom selections,
    dataset hashes and starting amplitudes. The optimised group amplitudes
    must come back permuted in the same way, with the atomic amplitudes
    unchanged. Prints ``OK`` on success.
    """

    n_grp = 3
    n_dst = 5
    n_atm = 10
    n_base = n_grp * n_dst

    (target_uijs, target_weights,
     base_uijs, base_sels, dataset_hash,
     atomic_base,
     real_group_amps, real_atomic_amps) = get_optimisation_test_set(
        n_grp, n_dst, n_atm)

    # Every amplitude starts from zero
    base_amplitudes_start = flex.double(n_base, 0.0)
    atomic_amplitudes_start = flex.double(n_atm, 0.0)

    def optimise_and_check(uijs, sels, hashes, base_start, order=None):
        # Run the optimiser on the supplied base elements; the targets and the
        # atomic level are shared. Then verify the recorded start and result.
        # `order`, when given, is the permutation that was applied to the base
        # elements, so the expected group amplitudes are permuted the same way.
        opt = OptimiseAmplitudes(
            target_uijs=target_uijs,
            target_weights=target_weights,
            base_amplitudes=base_start,
            base_uijs=uijs,
            base_atom_indices=sels,
            base_dataset_hash=hashes,
            atomic_uijs=atomic_base,
            atomic_amplitudes=atomic_amplitudes_start,
            atomic_optimisation_mask=None,
            optimisation_weights=None,
            convergence_tolerance=1e-08,
        ).run()

        assert approx_equal(
            list(opt.initial),
            list(base_start) + list(atomic_amplitudes_start), 1e-6)

        if order is None:
            expected_group_amps = list(real_group_amps)
        else:
            expected_group_amps = [real_group_amps[i] for i in order]
        assert approx_equal(
            list(opt.result),
            expected_group_amps + list(real_atomic_amps), 1e-6)

    # Unshuffled problem: results map directly onto the generating values
    optimise_and_check(base_uijs, base_sels, dataset_hash,
                       base_amplitudes_start)

    ########################################################
    # Shuffle the base elements with one random permutation and check that
    # the optimised group amplitudes follow the same permutation
    ########################################################

    i_perm = iran(n_base, size=n_base, replace=False)
    perm_uijs = [base_uijs[i] for i in i_perm]
    perm_sels = [base_sels[i] for i in i_perm]
    perm_hash = flex.size_t([dataset_hash[i] for i in i_perm])
    perm_start = flex.double([base_amplitudes_start[i] for i in i_perm])

    optimise_and_check(perm_uijs, perm_sels, perm_hash, perm_start,
                       order=i_perm)

    print('OK')