def test_initial_magnetic_moments_invalid(fixture_code, generate_structure, initial_magnetic_moments):
    """Check that ``PwBaseWorkChain.get_builder_from_protocol`` rejects invalid ``initial_magnetic_moments``.

    The moments supplied by the ``initial_magnetic_moments`` fixture must raise a ``ValueError`` both
    when no ``spin_type`` is passed (with a message naming the incompatible spin type) and when they
    are combined with ``SpinType.COLLINEAR``.
    """
    pw_code = fixture_code('quantumespresso.pw')
    input_structure = generate_structure()
    incompatible_spin = r'`initial_magnetic_moments` is specified but spin type `.*` is incompatible.'

    # Without an explicit spin type the default one does not accept magnetic moments.
    with pytest.raises(ValueError, match=incompatible_spin):
        PwBaseWorkChain.get_builder_from_protocol(
            pw_code, input_structure, initial_magnetic_moments=initial_magnetic_moments
        )

    # Even with a collinear spin type these moments must still be refused.
    with pytest.raises(ValueError):
        PwBaseWorkChain.get_builder_from_protocol(
            pw_code,
            input_structure,
            spin_type=SpinType.COLLINEAR,
            initial_magnetic_moments=initial_magnetic_moments,
        )
def test_spin_type(fixture_code, generate_structure):
    """Test ``PwBaseWorkChain.get_builder_from_protocol`` with ``spin_type`` keyword.

    ``SpinType.NON_COLLINEAR`` and ``SpinType.SPIN_ORBIT`` must each raise ``NotImplementedError``;
    ``SpinType.COLLINEAR`` must produce a spin-polarised (``nspin == 2``) parameter set.
    """
    code = fixture_code('quantumespresso.pw')
    structure = generate_structure()

    # One ``pytest.raises`` context per spin type: wrapping the whole loop in a single context
    # would stop at the first exception and silently never check the remaining values.
    for spin_type in [SpinType.NON_COLLINEAR, SpinType.SPIN_ORBIT]:
        with pytest.raises(NotImplementedError):
            PwBaseWorkChain.get_builder_from_protocol(code, structure, spin_type=spin_type)

    builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, spin_type=SpinType.COLLINEAR)
    parameters = builder.pw.parameters.get_dict()

    assert parameters['SYSTEM']['nspin'] == 2
    assert parameters['SYSTEM']['starting_magnetization'] == {'Si': 0.1}
def test_electronic_type(fixture_code, generate_structure):
    """Test ``PwBaseWorkChain.get_builder_from_protocol`` with ``electronic_type`` keyword.

    Unsupported electronic types must raise ``NotImplementedError``; ``ElectronicType.INSULATOR``
    must yield fixed occupations with no ``degauss``/``smearing`` settings.
    """
    code = fixture_code('quantumespresso.pw')
    structure = generate_structure()

    # One ``pytest.raises`` context per value, so every listed type is checked and not only the
    # first one that raises.
    for electronic_type in [ElectronicType.AUTOMATIC]:
        with pytest.raises(NotImplementedError):
            PwBaseWorkChain.get_builder_from_protocol(code, structure, electronic_type=electronic_type)

    builder = PwBaseWorkChain.get_builder_from_protocol(code, structure, electronic_type=ElectronicType.INSULATOR)
    parameters = builder.pw.parameters.get_dict()

    assert parameters['SYSTEM']['occupations'] == 'fixed'
    assert 'degauss' not in parameters['SYSTEM']
    assert 'smearing' not in parameters['SYSTEM']
def test_default(fixture_code, generate_structure, data_regression, serialize_builder):
    """Check the builder that ``PwBaseWorkChain.get_builder_from_protocol`` returns for the default protocol.

    The result must be a ``ProcessBuilder`` whose serialized form matches the stored regression data.
    """
    builder = PwBaseWorkChain.get_builder_from_protocol(
        fixture_code('quantumespresso.pw'),
        generate_structure(),
    )

    assert isinstance(builder, ProcessBuilder)
    serialized = serialize_builder(builder)
    data_regression.check(serialized)
def test_initial_magnetic_moments(fixture_code, generate_structure):
    """Check that ``initial_magnetic_moments`` end up in the ``SYSTEM`` namelist of the ``pw`` inputs.

    With ``SpinType.COLLINEAR`` the calculation is spin polarised, and a moment of ``1.0`` on ``Si``
    appears as a ``starting_magnetization`` of ``0.25``.
    """
    pw_code = fixture_code('quantumespresso.pw')
    input_structure = generate_structure()
    moments = {'Si': 1.0}

    builder = PwBaseWorkChain.get_builder_from_protocol(
        pw_code,
        input_structure,
        spin_type=SpinType.COLLINEAR,
        initial_magnetic_moments=moments,
    )
    system = builder.pw.parameters.get_dict()['SYSTEM']

    assert system['nspin'] == 2
    # NOTE(review): 1.0 -> 0.25 presumably reflects a rescaling done by the protocol (e.g. by the
    # pseudopotential's valence charge) — confirm against the implementation.
    assert system['starting_magnetization'] == {'Si': 0.25}
def test_metadata_overrides(fixture_code, generate_structure):
    """Check that ``pw.metadata`` values given in ``overrides`` reach the builder unchanged."""
    pw_code = fixture_code('quantumespresso.pw')
    input_structure = generate_structure()

    scheduler_options = {'resources': {'num_machines': 1e90}, 'max_wallclock_seconds': 1}
    overrides = {'pw': {'metadata': {'options': scheduler_options}}}

    builder = PwBaseWorkChain.get_builder_from_protocol(pw_code, input_structure, overrides=overrides)
    pw_options = builder.pw.metadata['options']

    # Compare against literals rather than ``scheduler_options`` in case the call mutates overrides.
    assert pw_options['resources']['num_machines'] == 1e90
    assert pw_options['max_wallclock_seconds'] == 1
def test_parallelization_overrides(fixture_code, generate_structure):
    """Check that ``pw.parallelization`` flags given in ``overrides`` reach the builder unchanged."""
    pw_code = fixture_code('quantumespresso.pw')
    input_structure = generate_structure()

    requested = {'npool': 4, 'ndiag': 12}
    builder = PwBaseWorkChain.get_builder_from_protocol(
        pw_code,
        input_structure,
        overrides={'pw': {'parallelization': requested}},
    )
    flags = builder.pw.parallelization

    # Compare against literals rather than ``requested`` in case the call mutates overrides.
    assert flags['npool'] == 4
    assert flags['ndiag'] == 12
def test_get_default_protocol():
    """Check that ``PwBaseWorkChain.get_default_protocol`` reports ``'moderate'``."""
    default_protocol = PwBaseWorkChain.get_default_protocol()
    assert default_protocol == 'moderate'
def test_get_available_protocols():
    """Test ``PwBaseWorkChain.get_available_protocols``.

    The available protocols must be exactly ``fast``, ``moderate`` and ``precise``, and every
    protocol definition must contain a ``description`` entry.
    """
    protocols = PwBaseWorkChain.get_available_protocols()
    assert sorted(protocols.keys()) == ['fast', 'moderate', 'precise']
    # Inspect the protocol definitions (the mapping's values), not their names, and assert the
    # result: a bare ``all(...)`` expression is evaluated and discarded, so it can never fail.
    assert all('description' in protocol for protocol in protocols.values())
# Example no. 10
# 0
def main(options):
    """Build a ``PwBaseWorkChain`` builder for an SCF run on a hexagonal boron-nitride cell.

    The cell holds four atoms (B, N, N, B) in two layers stacked along the c axis.

    :param options: mapping of run settings. ``max_wallclock_seconds``, ``resources``, ``code_id``
        and ``pseudo_family`` are required; ``queue_name``, ``qos``, ``account`` and
        ``prepend_text`` are optional and are only set when present.
    :return: the populated ``PwBaseWorkChain`` process builder (it is not submitted here).
    :raises KeyError: if one of the required keys is missing from ``options``.
    """

    ###### setting the lattice structure ######

    # In-plane vectors form a hexagonal lattice (0.866025 ~ sqrt(3)/2) scaled by ``alat``;
    # the out-of-plane length is given directly.
    alat = 2.4955987320  # Angstrom
    the_cell = [[1.000000 * alat, 0.000000, 0.000000],
                [-0.500000 * alat, 0.866025 * alat, 0.000000],
                [0.000000, 0.000000, 6.4436359260]]

    # Cartesian positions. The placeholder cell is replaced with ``scale_atoms=False`` so these
    # positions are kept as given rather than rescaled into the new cell.
    atoms = Atoms('BNNB', [(1.2477994910, 0.7204172280, 0.0000000000),
                           (-0.0000001250, 1.4408346720, 0.0000000000),
                           (1.2477994910, 0.7204172280, 3.2218179630),
                           (-0.0000001250, 1.4408346720, 3.2218179630)],
                  cell=[1, 1, 1])
    atoms.set_cell(the_cell, scale_atoms=False)
    atoms.set_pbc([True, True, True])

    StructureData = DataFactory('structure')
    structure = StructureData(ase=atoms)

    ###### setting the kpoints mesh ######

    KpointsData = DataFactory('array.kpoints')
    kpoints = KpointsData()
    kpoints.set_kpoints_mesh([6, 6, 2])

    ###### setting the scf parameters ######

    Dict = DataFactory('dict')
    params_scf = {
        'CONTROL': {
            'calculation': 'scf',
            'verbosity': 'high',
            'wf_collect': True
        },
        'SYSTEM': {
            'ecutwfc': 130.,
            'force_symmorphic': True,
            'nbnd': 20
        },
        'ELECTRONS': {
            'mixing_mode': 'plain',
            'mixing_beta': 0.7,
            'conv_thr': 1.e-8,
            'diago_thr_init': 5.0e-6,
            'diago_full_acc': True
        },
    }

    parameter_scf = Dict(dict=params_scf)

    ###### creation of the workchain ######

    builder = PwBaseWorkChain.get_builder()
    builder.pw.structure = structure
    builder.pw.parameters = parameter_scf
    builder.kpoints = kpoints

    # All scheduler options belong to the ``pw`` calculation's metadata namespace.
    pw_options = builder.pw.metadata.options
    pw_options.max_wallclock_seconds = options['max_wallclock_seconds']
    # Plain assignment: do not chain through a name like ``dict``, which would shadow the builtin.
    pw_options.resources = options['resources']

    # Optional settings are only forwarded when the caller supplied them.
    for key in ('queue_name', 'qos', 'account', 'prepend_text'):
        if key in options:
            setattr(pw_options, key, options[key])

    builder.pw.code = load_code(options['code_id'])
    builder.pw.pseudos = validate_and_prepare_pseudos_inputs(
        builder.pw.structure, pseudo_family=Str(options['pseudo_family']))

    return builder