def test_extend(self):
        """Test that ``extend()`` fills a List node, both unstored and stored."""
        lst = [1, 2, 3]

        def do_checks(node):
            """Assert that ``node`` holds exactly the elements of ``lst``."""
            self.assertEqual(len(node), len(lst))
            # Do an element wise comparison
            for x, y in zip(lst, node):
                self.assertEqual(x, y)

        node = List()
        node.extend(lst)
        do_checks(node)
        # Further extend
        node.extend(lst)
        self.assertEqual(len(node), len(lst) * 2)

        # Element wise comparison over the WHOLE extended node: position ``i``
        # must hold ``lst[i % len(lst)]``, so the second copy is verified too.
        for i, value in enumerate(node):
            self.assertEqual(lst[i % len(lst)], value)

        # Now try after storing
        node = List()
        node.extend(lst)
        node.store()
        do_checks(node)
示例#2
0
def test_sqs_process(ce_sqs_code):
    """Run the ``ce.gensqs`` calculation and check the outputs it returns.

    The calculation is run on the structure built by ``bulk('Au')`` with the
    Au/Pd chemical symbols; the test then checks that the ``sqs`` and
    ``cluster_space`` outputs exist and that the SQS contains 8 atoms.

    :param ce_sqs_code: fixture providing the code passed as the ``code`` input
    """
    prim = bulk('Au')
    structure = StructureData(ase=prim)
    chemical_symbols = List(list=[['Au', 'Pd']])

    # set up calculation
    inputs = {
        'code': ce_sqs_code,
        'structure': structure,
        'chemical_symbols': chemical_symbols,
        'pbc': List(list=[True, True, True]),
        'cutoffs': List(list=[5.0]),
        'max_size': Int(8),
        'include_smaller_cells': Bool(True),
        'n_steps': Int(2000),
        'target_concentrations': Dict(dict={
            'Au': 0.5,
            'Pd': 0.5
        }),
        'metadata': {
            'options': {
                'max_wallclock_seconds': 30,
            },
        }
    }

    result = run(CalculationFactory('ce.gensqs'), **inputs)
    assert 'sqs' in result
    assert 'cluster_space' in result

    sqs = result['sqs'].get_ase()
    # ``len(atoms)`` is the supported way to count atoms; ``get_number_of_atoms``
    # is deprecated in ASE.
    assert len(sqs) == 8
    def test_extend(self):
        """Test extend() member function."""
        lst = [1, 2, 3]

        def do_checks(node):
            """Check that ``node`` has the same length and elements as ``lst``."""
            self.assertEqual(len(node), len(lst))
            # Do an element wise comparison
            for lst_, node_ in zip(lst, node):
                self.assertEqual(lst_, node_)

        node = List()
        node.extend(lst)
        do_checks(node)
        # Further extend
        node.extend(lst)
        self.assertEqual(len(node), len(lst) * 2)

        # Compare every element of the doubled node, not only the first
        # ``len(lst)`` ones: index ``i`` is expected to equal ``lst[i % len(lst)]``.
        for i, node_ in enumerate(node):
            self.assertEqual(lst[i % len(lst)], node_)

        # Now try after storing
        node = List()
        node.extend(lst)
        node.store()
        do_checks(node)
示例#4
0
    def parse(self, **kwargs):  # pylint: disable=too-many-locals, inconsistent-return-statements
        """
        Parse outputs, store results in database.

        The output file is attached as the ``log`` output. The pickle file named
        by the ``data_file`` input is loaded and its ``coverage_map``,
        ``rate_map`` and ``production_rate_map`` entries are attached as List
        outputs of the same names.

        :returns: an exit code, if parsing fails (or nothing if parsing succeeds)
        """
        from aiida.orm import SinglefileData, List

        def as_float_rows(entries):
            """Turn ``(label, values)`` pairs into ``[label, [float, ...]]`` rows."""
            return [[entry[0], [float(value) for value in entry[1]]]
                    for entry in entries]

        output_filename = self.node.get_option('output_filename')
        pickle_filename = self.node.inputs.data_file.value

        # Check that folder content is as expected
        files_retrieved = self.retrieved.list_object_names()
        files_expected = [output_filename, pickle_filename]
        # Note: set(A) <= set(B) checks whether A is a subset of B
        if not set(files_expected) <= set(files_retrieved):
            self.logger.error(
                f"Found files '{files_retrieved}', expected to find '{files_expected}'"
            )
            return self.exit_codes.ERROR_MISSING_OUTPUT_FILES

        # add output file
        self.logger.info(f"Parsing '{output_filename}'")
        with self.retrieved.open(output_filename, 'rb') as handle:
            output_node = SinglefileData(file=handle)
        self.out('log', output_node)

        # Parsing the pickle file
        self.logger.info(f"Parsing '{pickle_filename}'")
        # NOTE(review): unpickling can execute arbitrary code; this assumes the
        # retrieved file was produced by a trusted calculation -- confirm.
        # NOTE(review): per the original comment the pickle holds mpmath values,
        # so mpmath presumably has to be importable wherever this parser runs.
        with self.retrieved.open(pickle_filename, 'rb') as handle:
            pickledata = pickle.load(handle)

        # Look up all three maps up front so that a missing key in ANY of them
        # is reported through the exit code instead of raising a KeyError.
        try:
            raw_coverage = pickledata['coverage_map']
            raw_rate = pickledata['rate_map']
            raw_production = pickledata['production_rate_map']
        except KeyError:
            return self.exit_codes.ERROR_NO_PICKLE_FILE

        coverage_map = List(list=as_float_rows(raw_coverage))
        rate_map = List(list=as_float_rows(raw_rate))
        production_map = List(list=as_float_rows(raw_production))

        ## The three main outputs
        ## The solution to the kinetic model - coverages
        ## The rate and the production rate also provided
        self.out('coverage_map', coverage_map)
        self.out('rate_map', rate_map)
        self.out('production_rate_map', production_map)
    def test_append(self):
        """Appending one value yields a one-element List, stored or not."""
        appended = 4

        def verify(candidate):
            """Assert ``candidate`` contains only the appended value."""
            self.assertEqual(len(candidate), 1)
            self.assertEqual(candidate[0], appended)

        # Unstored node
        fresh = List()
        fresh.append(appended)
        verify(fresh)

        # Same sequence, but the node is stored before it is checked
        persisted = List()
        persisted.append(appended)
        persisted.store()
        verify(persisted)
示例#6
0
    def return_results(self):
        """
        Attach the results to the output.

        Gathers the output nodes of the sub-processes kept in the context,
        passes them to ``extract_wrap_results`` and attaches the returned node
        as the ``results`` output. When all ``should_run_visvoro`` flags are
        set, the ``voro_accessible`` outputs found per component are attached
        as well.
        """
        all_outputs = {}
        # ZeoppCalculation Section
        all_outputs['zeopp_res'] = self.ctx.zeopp_res.outputs.output_parameters

        if all(self.ctx.should_run_visvoro):
            voro_accessible = {}
            for key in self.ctx.components.keys():
                zeopp_label = "zeopp_{}".format(key)
                # Not every zeopp run necessarily exposes this output
                if 'voro_accessible' in self.ctx[zeopp_label].outputs:
                    voro_accessible[key] = self.ctx[
                        zeopp_label].outputs.voro_accessible
            self.out('voro_accessible', voro_accessible)

        if all(self.ctx.should_run_comp) or all(self.ctx.should_run_pld):
            all_outputs['pm_out'] = self.ctx.pm_ev.outputs.output_parameters

            if all(self.ctx.should_run_comp):
                for key in self.inputs.components.keys():
                    ev_label = "Ev_vdw_{}_{}_{}".format(
                        self.ctx.label, key, key)
                    all_outputs['pm_ev_' + ev_label] = self.ctx.pm_ev.outputs[
                        'ev_output_file__' + ev_label]

            if all(self.ctx.should_run_pld):
                for key in self.inputs.components.keys():
                    ev_label = "Ev_vdw_{}_PLD_{}".format(self.ctx.label, key)
                    all_outputs['pm_ev_' + ev_label] = self.ctx.pm_ev.outputs[
                        'ev_output_file__' + ev_label]

        # Fall back to a fixed setting when the parameters do not provide one
        if 'ev_setting' in self.ctx.parameters.get_dict():
            all_outputs['ev_setting'] = List(
                list=self.ctx.parameters['ev_setting'])
        else:
            all_outputs['ev_setting'] = List(list=[90, 80, 50])

        output_parameters = extract_wrap_results(**all_outputs)

        # Finalizing the results and report!
        self.out("results", output_parameters)
        self.report(
            "Workchain completed successfully! | Result Dict is <{}>".format(
                self.outputs["results"].pk))
示例#7
0
    def define(cls, spec):
        """Declare the inputs, outline, outputs and exit code of the work chain."""
        super(ElasticWorkChain, cls).define(spec)

        # --- inputs -------------------------------------------------------
        spec.input('structure', valid_type=StructureData)
        spec.input(
            'symmetric_strains_only',
            valid_type=Bool,
            default=lambda: Bool(True),
        )
        spec.input(
            'skip_input_relax',
            valid_type=Bool,
            default=lambda: Bool(False),
        )
        spec.input(
            'strain_magnitudes',
            valid_type=List,
            default=lambda: List(list=[-0.01,-0.005,0.005,0.01]),
        )
        spec.input(
            'clean_workdir',
            valid_type=Bool,
            default=lambda: Bool(True),
        )

        # Both relax namespaces expose PwRelaxWorkChain minus the same two ports
        excluded_ports = ('structure', 'clean_workdir')
        for relax_namespace in ('initial_relax', 'elastic_relax'):
            spec.expose_inputs(
                PwRelaxWorkChain,
                namespace=relax_namespace,
                exclude=excluded_ports,
            )

        # --- outline ------------------------------------------------------
        # NOTE: may wish to add a check of elastic constant quality after
        # ``fit_elastic_tensor``
        spec.outline(
            cls.relax_input_structure,
            cls.get_relaxed_structure_stress,
            cls.get_deformed_structures,
            cls.compute_deformed_structures,
            cls.gather_computed_stresses,
            cls.fit_elastic_tensor,
            cls.set_outputs,
        )

        # --- outputs ------------------------------------------------------
        spec.output('equilibrium_structure', valid_type=StructureData)
        spec.output('elastic_outputs', valid_type=ArrayData)
        spec.output('symmetry_mapping', valid_type=Dict)

        # --- exit codes ---------------------------------------------------
        spec.exit_code(
            401,
            'ERROR_SUB_PROCESS_FAILED_RELAX',
            message='one of the PwRelaxWorkChain subprocesses failed',
        )
示例#8
0
    def _values_list_serializer(list_to_parse):
        '''
        Turn a list of objects into an aiida List of node pks.

        aiida's List does not accept items with certain data structures
        (e.g. StructureData), so every item is normalized to the pk of a stored
        node; the value can then be accessed through the node at each iteration.
        Note that if you modify/overwrite this method you should take a look to the `_next_val` method.

        :param list_to_parse: a list with aiida data object
        :return: an aiida List containing the pk of each element of list_to_parse
        '''

        def stored_pk(item):
            '''Return the pk of ``item``, converting and storing it when needed.'''
            # Python types are first wrapped into a node
            node = item if isinstance(item, Node) else to_aiida_type(item)

            # Freshly converted or otherwise unstored nodes have no pk yet
            if not node.is_stored:
                node.store()

            return node.pk

        return List(list=[stored_pk(item) for item in list_to_parse])
示例#9
0
def choose_pressure_points(inp_param, geom, raspa_widom_out):
    """Return the pressure points at which to evaluate an isotherm, as a List.

    If 'pressure_list' is provided (truthy) it is used as is. Otherwise the
    isotherm is modelled as single-site Langmuir and the points are generated
    from 'pressure_min' up to and including 'pressure_max', with a step that
    is capped by 'pressure_maxstep'.

    :param inp_param: mapping with the keys 'pressure_list', 'pressure_min',
        'pressure_max', 'pressure_maxstep' and 'pressure_precision'
    :param geom: mapping providing 'Estimated_saturation_loading'
    :param raspa_widom_out: mapping providing the 'henry_coefficient_average'
        of the first component of 'framework_1'
    :return: a List with the pressure points
    :raises ValueError: if the computed pressure step is not positive (the
        generation loop could never reach 'pressure_max')
    """
    if inp_param["pressure_list"]:
        pressure_points = inp_param["pressure_list"]
    else:
        khenry = list(
            raspa_widom_out["framework_1"]["components"].values())[0][
                'henry_coefficient_average']  #mol/kg/Pa
        b_value = khenry / geom['Estimated_saturation_loading'] * 1e5  #(1/bar)
        pressure_max = inp_param['pressure_max']
        pressure_points = [inp_param['pressure_min']]
        while True:
            pold = pressure_points[-1]
            delta_p = min(
                inp_param['pressure_maxstep'],
                inp_param['pressure_precision'] *
                (b_value * pold**2 + 2 * pold + 1 / b_value))
            if delta_p <= 0:
                raise ValueError(
                    "pressure step must be positive, got {}".format(delta_p))
            pnew = pold + delta_p
            if pnew <= pressure_max:
                pressure_points.append(pnew)
            else:
                break
        # Always end at pressure_max, but do not repeat it when the last step
        # already landed exactly on it.
        if pressure_points[-1] != pressure_max:
            pressure_points.append(pressure_max)
    return List(list=pressure_points)
示例#10
0
def test_enum_process(ce_enum_code):
    """Run the ``ce.genenum`` calculation and check the enumerated structures.

    :param ce_enum_code: fixture providing the code passed as the ``code`` input
    """
    StructureSet = DataFactory('ce.structures')

    from ase.build import bulk
    primitive = bulk('Ag')

    # set up calculation
    calc_inputs = {
        'code': ce_enum_code,
        'structure': StructureData(ase=primitive),
        'chemical_symbols': List(list=[['Au', 'Pd']]),
        'min_volume': Int(1),
        'max_volume': Int(4),
        'metadata': {
            'options': {
                'max_wallclock_seconds': 30
            },
        },
    }

    outputs = run(CalculationFactory('ce.genenum'), **calc_inputs)
    enumerated = outputs['enumerate_structures']
    first_atoms = enumerated.get_structure(0).get_ase()

    # The first enumerated structure must match the input cell and positions
    assert numpy.allclose(first_atoms.cell, primitive.cell)
    assert numpy.allclose(first_atoms.positions, primitive.positions)
    assert isinstance(enumerated, StructureSet)

    assert outputs['number_of_structures'] == 10
示例#11
0
def get_pressure_list(wc_params):
    """Return the 'pressure_list' entry of ``wc_params`` as an AiiDA List.

    :raises ValueError: if the 'pressure_list' entry is empty or otherwise falsy
    """
    requested = wc_params["pressure_list"]
    if not requested:
        raise ValueError("pressure list is not provided properly!")
    return List(list=requested)
    def test_append(self):
        """Check append() on a List node, first unstored and then stored."""
        # The second pass repeats the same sequence, storing before the checks
        for store_before_check in (False, True):
            node = List()
            node.append(4)
            if store_before_check:
                node.store()

            self.assertEqual(len(node), 1)
            self.assertEqual(node[0], 4)
    def test_store_load():
        """Check that a stored List reloaded by pk holds the same content."""
        original = List(list=[1, 2, 3])
        original.store()

        reloaded = load_node(original.pk)

        stored_content = original.get_list()
        loaded_content = reloaded.get_list()
        assert stored_content == loaded_content
def test_strained_fp_tb(
        configure_with_daemon,  # pylint: disable=unused-argument
        get_optimize_fp_tb_input,  # pylint: disable=redefined-outer-name
):
    """
    Run the strained first-principles tight-binding optimization for three
    strain values and check that every strain produced its outputs.
    """
    from aiida.engine import run
    from aiida.orm import Code
    from aiida.orm import Str, List
    from aiida_tbextraction.optimize_strained_fp_tb import OptimizeStrainedFirstPrinciplesTightBinding

    workflow_inputs = get_optimize_fp_tb_input

    workflow_inputs['strain_kind'] = Str('three_five.Biaxial001')
    workflow_inputs['strain_parameters'] = Str('InSb')

    strain_values = [-0.1, 0., 0.1]
    strengths_node = List()
    strengths_node.extend(strain_values)
    workflow_inputs['strain_strengths'] = strengths_node

    workflow_inputs['symmetry_repr_code'] = Code.get_from_string(
        'symmetry_repr')

    result = run(OptimizeStrainedFirstPrinciplesTightBinding,
                 **workflow_inputs)
    print(result)

    # Each strain value gets its own suffixed set of output keys
    output_prefixes = ['cost_value', 'tb_model', 'window']
    for strain in strain_values:
        suffix = '_{}'.format(strain).replace('.', '_dot_')
        for prefix in output_prefixes:
            assert prefix + suffix in result
示例#15
0
def get_optimize_fp_tb_input(get_fp_tb_input):  # pylint: disable=redefined-outer-name
    """
    Add the ``initial_window`` input to the first-principles tight-binding
    inputs and return them (the given mapping is updated in place).
    """
    from aiida.orm import List

    initial_window = List(list=[-4.5, -4, 6.5, 16])
    get_fp_tb_input['initial_window'] = initial_window

    return get_fp_tb_input
示例#16
0
    def report_wf(self):
        """Report the final step and attach the ``path`` and ``relaxed_scf`` outputs."""
        self.report('Final step. The workflow now will collect some info about the calculations in the "path" output node, and the relaxed scf calc')

        scheme = self.ctx.conv_options['relaxation_scheme']
        self.report('Relaxation scheme performed: {}'.format(scheme))

        # Both nodes are stored before being attached as outputs
        path_node = List(list=self.ctx.path).store()
        scf_pk_node = Int(self.ctx.scf.pk).store()

        self.out('path', path_node)
        self.out('relaxed_scf', scf_pk_node)
示例#17
0
    def return_results(self):
        '''
        Report whether convergence was reached and attach the convergence results
        as outputs, then delegate to the parent implementation.
        '''

        # Wrap the convergence state in nodes for the results generator
        converged_node = Bool(self.converged)
        index_node = Int(getattr(self.ctx, 'converged_index', -1))
        keys_node = List(list=list(self.ctx.iteration_keys))
        used_node = List(list=self.ctx.used_values)
        target_node = List(list=self.ctx.target_values)

        outputs = generate_convergence_results(
            keys_node, used_node, target_node, converged_node, index_node
        )

        if not converged_node:
            self.report('\n\nWARNING: Workchain ended without finding convergence\n ')
        else:
            self.report(
                '\n\nConvergence has been reached! Converged parameters:'
                f'{outputs["converged_parameters"].get_dict()}\n'
            )

        self.out_many(outputs)

        super().return_results()
示例#18
0
def cod_query(cod_values):
    """
    Query the COD importer with the keyword arguments held by ``cod_values``
    and return the ids of the entries found.

    :param cod_values: node whose ``get_dict()`` gives the query keyword arguments
    :return: a List with the ``source['id']`` of every fetched entry
    """
    query_kwargs = cod_values.get_dict()
    matches = CodDbImporter().query(**query_kwargs)

    # ids of the retrieved structures
    entry_ids = [entry.source['id'] for entry in matches.fetch_all()]
    return List(list=entry_ids)
示例#19
0
def get_temperature_points(vlcparams):
    """Choose the temperature points for VLCCWorkChain.

    If 'temperature_list' is provided (truthy) it is used as is. Otherwise the
    points run from 'T_min' in steps of 'dT'; the range stops before
    'T_max' + 1, so 'T_max' itself is included when it lies on the grid.

    Current version: Only gets initial and final T with spacing.
    TODO: also read the reference data and get the info from there.

    :param vlcparams: mapping with 'temperature_list', 'T_min', 'T_max', 'dT'
    :return: a List with the temperature points
    """
    if vlcparams["temperature_list"]:
        temperatures = vlcparams["temperature_list"]
    else:
        import numpy as np

        # NOTE(review): with dT < 1 the '+ 1' also lets points above T_max
        # through -- confirm this is intended.
        # ``tolist()`` converts the numpy scalars to native Python numbers
        # before they are put into the List node.
        temperatures = np.arange(vlcparams['T_min'], vlcparams['T_max'] + 1,
                                 vlcparams['dT']).tolist()

    return List(list=temperatures)
示例#20
0
文件: gensqs.py 项目: unkcpz/aiida-ce
    def parse(self, **kwargs):
        """
        Parse outputs, store results in database.

        Reads the JSON file ``sqs.out`` from the retrieved folder and attaches
        the ``sqs``, ``cluster_vector`` and ``cluster_space`` outputs.

        :returns: ``ERROR_MISSING_OUTPUT_FILES`` if an expected file was not
            retrieved, ``ExitCode(0)`` otherwise
        """

        # Check that folder content is as expected
        output_filename = self.node.get_option('output_filename')

        files_retrieved = self.retrieved.list_object_names()
        files_expected = [output_filename, 'sqs.out']
        # set(A) <= set(B): every expected file must be among the retrieved ones
        if not set(files_expected) <= set(files_retrieved):
            self.logger.error("Found files '{}', expected to find '{}'".format(
                files_retrieved, files_expected))
            return self.exit_codes.ERROR_MISSING_OUTPUT_FILES

        # add output node
        self.logger.info("Parsing sqs.out")
        with self.retrieved.open('sqs.out', 'rb') as handle:
            data = json.load(handle)

        # Rebuild the structure from its JSON description
        structure_data = data['structure']
        structure = Atoms(cell=structure_data['cell'],
                          positions=structure_data['positions'],
                          numbers=structure_data['atomic_numbers'],
                          pbc=structure_data['pbc'])
        sqs = StructureData(ase=structure)
        cluster_vector = data['cluster_vector']

        self.out('sqs', sqs)
        self.out('cluster_vector', List(list=cluster_vector))

        # The cluster space is described by the input structure together with
        # the cutoffs and chemical symbols given to the calculation.
        calc = self.node
        prim = calc.inputs.structure.get_ase()
        cluster_space = {
            'cell': prim.cell.tolist(),
            'positions': prim.positions.tolist(),
            'pbc': prim.pbc.tolist(),
            'cutoffs': calc.inputs.cutoffs.get_list(),
            'chemical_symbols': calc.inputs.chemical_symbols.get_list(),
        }
        self.out('cluster_space', ClusterSpaceData(cluster_space))

        return ExitCode(0)
示例#21
0
    def parse(self, **kwargs):
        """Parse the per-system output files of the retrieved folder.

        Attaches a Dict ``output_parameters`` (keyed by system name) and a List
        ``warnings``. Returns an exit code as soon as the retrieved folder, the
        output folder or the start/finish markers of a simulation are missing,
        and ``ExitCode(0)`` otherwise.
        """
        try:
            retrieved = self.retrieved
        except NotExistent:
            return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER
        output_dir = self.node.process_class.OUTPUT_FOLDER

        if output_dir not in retrieved._repository.list_object_names():  # pylint: disable=protected-access
            return self.exit_codes.ERROR_NO_OUTPUT_FILE

        parameters_by_system = {}
        all_warnings = []
        ncomponents = len(self.node.inputs.parameters.get_dict()['Component'])
        for index, system_name in enumerate(self.node.get_extra('system_order')):
            # folder holding the output of this system
            system_dir = "System_{}".format(index)
            # the first object listed in that folder is taken as the output file
            listing = retrieved._repository.list_object_names(os.path.join(output_dir, system_dir))  # pylint: disable=protected-access
            fname = listing[0]

            # get absolute path of the output file
            base_path = retrieved._repository._get_base_folder().abspath  # pylint: disable=protected-access
            output_abs_path = os.path.join(base_path, self.node.process_class.OUTPUT_FOLDER, system_dir, fname)

            # Check for possible errors
            with open(output_abs_path) as fobj:
                content = fobj.read()
            if "Starting simulation" not in content:
                return self.exit_codes.ERROR_SIMULATION_DID_NOT_START
            if "Simulation finished" not in content:
                return self.exit_codes.TIMEOUT

            # parse output parameters and warnings
            parsed_parameters, parsed_warnings = parse_base_output(output_abs_path, system_name, ncomponents)
            parameters_by_system[system_name] = parsed_parameters
            all_warnings.extend(parsed_warnings)

        self.out("output_parameters", Dict(dict=parameters_by_system))
        self.out("warnings", List(list=all_warnings))

        return ExitCode(0)
示例#22
0
def run_slice():
    """
    Build a SliceCalculation, run it and print the result and its PID.
    """
    builder = SliceCalculation.get_builder()
    builder.code = Code.get_from_string('tbmodels')

    builder.tb_model = get_singlefile_instance(
        description=u'InSb TB model',
        path='./reference_input/model.hdf5',
    )

    # single-core on local machine
    builder.metadata.options = {
        'resources': {
            'num_machines': 1,
            'tot_num_mpiprocs': 1,
        },
        'withmpi': False,
    }

    builder.slice_idx = List(list=[0, 3, 2, 1])

    result, pid = run_get_pk(builder)
    print('\nRan calculation with PID', pid)
    print('Result:\n', result)
示例#23
0
    def structures(self):
        """
        Build the scaling factors around the guess and create one structure
        per scaling factor (i.e. per volume / lattice constant).
        """
        npoints = self.ctx.points
        spacing = self.ctx.step
        center = self.ctx.guess
        # The scaling factors are spaced by ``spacing`` and centred on ``center``
        first_scale = center - (npoints - 1) / 2 * spacing

        self.ctx.scalelist.extend(
            [first_scale + index * spacing for index in range(npoints)])

        self.report('scaling factors which will be calculated:{}'.format(
            self.ctx.scalelist))
        self.ctx.org_volume = self.inputs.structure.get_cell_volume()

        scaled = eos_structures(self.inputs.structure,
                                List(list=self.ctx.scalelist))
        # since cf this has to be a dict, we sort to assure ordering of scale
        self.ctx.structures = [scaled[label] for label in sorted(scaled)]
    def test_mutability(self):
        """Check that every mutating call is rejected once the List is stored."""
        node = List()
        node.append(5)
        node.store()

        # Test all mutable calls are now disallowed
        mutating_calls = (
            lambda: node.append(5),
            lambda: node.extend([5]),
            lambda: node.insert(0, 2),
            lambda: node.remove(0),
            lambda: node.pop(),
            lambda: node.sort(),
            lambda: node.reverse(),
        )
        for mutate in mutating_calls:
            with self.assertRaises(ModificationNotAllowed):
                mutate()
示例#25
0
def choose_pressure_points(wc_params):
    """Return the pressure points at which to evaluate an isotherm, as a List.

    If 'pressure_list' is provided (truthy) it is used as is. Otherwise a
    linear range is built from 'pressure_min' in steps of
    'pressure_precision', ending with 'pressure_max'.

    :param wc_params: mapping with the keys 'pressure_list', 'pressure_min',
        'pressure_max' and 'pressure_precision'
    :return: a List with the pressure points
    :raises ValueError: if 'pressure_precision' is not positive (the range
        could never reach 'pressure_max')
    """
    if wc_params["pressure_list"]:
        pressure_points = wc_params["pressure_list"]
    else:
        # Simply create a linear range of pressure points.
        # TODO: Make it possible to guess but needs benchmarking. #pylint: disable=fixme
        pressure_max = wc_params['pressure_max']
        delta_p = wc_params['pressure_precision']
        if delta_p <= 0:
            raise ValueError(
                "pressure step must be positive, got {}".format(delta_p))
        pressure_points = [wc_params['pressure_min']]
        while True:
            pnew = pressure_points[-1] + delta_p
            if pnew <= pressure_max:
                pressure_points.append(pnew)
            else:
                break
        # Always end at pressure_max, but do not repeat it when the last step
        # already landed exactly on it.
        if pressure_points[-1] != pressure_max:
            pressure_points.append(pressure_max)
    return List(list=pressure_points)
示例#26
0
def test_calcfunction_band_gap(db_test_app, data_regression):
    """Run ``calcfunction_band_gap`` on the "edge_at_fermi" data and check its results."""
    data = get_test_data("edge_at_fermi")

    # Energies and densities go into a single array node
    doss_array = ArrayData()
    for array_name, values in (("energies", data.energies),
                               ("total", data.densities)):
        doss_array.set_array(array_name, np.array(values))

    doss_results = Dict(dict={
        "fermi_energy": data.fermi,
        "units": {
            "energy": "eV"
        }
    })

    outputs, node = calcfunction_band_gap.run_get_node(
        doss_array=doss_array,
        doss_results=doss_results,
        dtol=Float(1e-6),
        try_fshifts=List(list=data.try_fshifts),
        metadata={"store_provenance": True},
    )

    assert node.is_finished_ok, node.exit_status
    assert "results" in node.outputs
    rounded_results = recursive_round(node.outputs.results.attributes, 4)
    data_regression.check(rounded_results)
示例#27
0
    def define(cls, spec):
        """Declare the inputs, outputs and exit code of the SQS calculation.

        Node-valued defaults are given as lambdas so that a fresh node is
        created for every process instead of one shared node being created
        when the class is defined.
        """
        #yapf: disable
        super(SqsCalculation, cls).define(spec)
        spec.input('metadata.options.resources', valid_type=dict, default={'num_machines':1, 'num_mpiprocs_per_machine':1}, non_db=True)
        spec.input('metadata.options.parser_name', valid_type=str, default='ce.gensqs', non_db=True)
        spec.input('metadata.options.input_filename', valid_type=str, default='aiida.json', non_db=True)
        spec.input('metadata.options.output_filename', valid_type=str, default='aiida.out', non_db=True)
        spec.input('structure', valid_type=StructureData, help='prototype structure to expand')
        spec.input('pbc', valid_type=List, default=lambda: List(list=[True, True, True]))
        spec.input('chemical_symbols', valid_type=List, help='An N elements list of which that each element is the possible symbol of the site.')
        spec.input('target_concentrations', valid_type=Dict, help='target concentration of elements of the sqs')
        spec.input('include_smaller_cells', valid_type=Bool, default=lambda: Bool(False), help='if false, only cell with >32 atoms will calculated')
        spec.input('cutoffs', valid_type=List, help='cutoffs of each NN distance')
        spec.input('max_size', valid_type=Int, default=lambda: Int(16), help='structures having up to max size times in the supercell')
        spec.input('n_steps', valid_type=Int, default=lambda: Int(10000), help='max annealing steps to run')

        spec.output('sqs', valid_type=StructureData, help='sqs structure')
        spec.output('cluster_vector', valid_type=List, help='cluster vector of sqs')
        spec.output('cluster_space', valid_type=ClusterSpaceData, help='cluster space used to generate sqs')

        spec.exit_code(100, 'ERROR_MISSING_OUTPUT_FILES', message='Calculation did not produce all expected output files.')
示例#28
0
def test_slice(
        configure_with_daemon,  # pylint: disable=unused-argument
        sample,
        get_tbmodels_process_builder,
        check_calc_ok):
    """
    Run the tbmodels.slice calculation on the sample model and check that the
    ``tb_model`` output is a single-file node.
    """
    from aiida.plugins import DataFactory
    from aiida.orm import List
    from aiida.engine import run_get_node

    slice_builder = get_tbmodels_process_builder('tbmodels.slice')

    SinglefileData = DataFactory('singlefile')  # pylint: disable=invalid-name
    model_path = sample('model.hdf5')
    slice_builder.tb_model = SinglefileData(file=model_path)

    slice_builder.slice_idx = List(list=[0, 3, 2, 1])

    outputs, calc_node = run_get_node(slice_builder)
    check_calc_ok(calc_node)

    sliced_model = outputs['tb_model']
    assert isinstance(sliced_model, SinglefileData)
示例#29
0
 def do_test(self):
     """Check the input arrives as a plain list and return it as a stored List."""
     received = self.inputs.namespace.input
     # The input must be a python list, not a List node
     assert isinstance(received, list)
     assert not isinstance(received, List)
     output_node = List(list=list(received))
     self.out('output', output_node.store())
示例#30
0
def get_fp_tb_input(configure, get_insb_input, sample, request):  # pylint: disable=too-many-locals,unused-argument,redefined-outer-name,too-many-statements
    """
    Build the inputs for the DFT-based tight-binding workflows (without optimization).

    ``request.param`` selects the first-principles run: 'split' uses
    SplitFirstPrinciplesRun, 'combined' uses VaspFirstPrinciplesRun.
    """
    from aiida.plugins import DataFactory
    from aiida.orm import List, Bool
    from aiida.orm import Dict
    from aiida.orm import Code
    from aiida_tools.process_inputs import get_fullname
    from aiida_tbextraction.fp_run import VaspFirstPrinciplesRun
    from aiida_tbextraction.fp_run import SplitFirstPrinciplesRun
    from aiida_tbextraction.fp_run.reference_bands import VaspReferenceBands
    from aiida_tbextraction.fp_run.wannier_input import VaspWannierInput
    from aiida_tbextraction.model_evaluation import BandDifferenceModelEvaluation

    inputs = {}

    vasp_inputs = get_insb_input

    # These entries are moved out of the VASP inputs into the sub-workflow inputs
    vasp_subwf_inputs = {
        'code': vasp_inputs.pop('code'),
        'parameters': vasp_inputs.pop('parameters'),
        'calculation_kwargs': vasp_inputs.pop('calculation_kwargs'),
    }

    run_mode = request.param
    if run_mode == 'split':
        inputs['fp_run_workflow'] = SplitFirstPrinciplesRun
        inputs['fp_run'] = {
            'reference_bands_workflow': get_fullname(VaspReferenceBands),
            'reference_bands': dict(merge_kpoints=Bool(True),
                                    **vasp_subwf_inputs),
            'wannier_input_workflow': get_fullname(VaspWannierInput),
            'wannier_input': vasp_subwf_inputs,
        }
    else:
        assert request.param == 'combined'
        inputs['fp_run_workflow'] = VaspFirstPrinciplesRun
        fp_run = copy.copy(vasp_subwf_inputs)
        inputs['fp_run'] = fp_run
        fp_run['scf'] = {
            'parameters': Dict(dict={'isym': 2}),
        }
        fp_run['bands'] = {'merge_kpoints': Bool(True)}

    # Whatever is left of the VASP inputs goes to the top level
    inputs.update(vasp_inputs)

    path_kpoints = orm.KpointsData()
    path_kpoints.set_kpoints_path([('G', (0, 0, 0), 'M', (0.5, 0.5, 0.5))])
    inputs['kpoints'] = path_kpoints
    mesh_kpoints = orm.KpointsData()
    mesh_kpoints.set_kpoints_mesh([2, 2, 2])
    inputs['kpoints_mesh'] = mesh_kpoints

    inputs['wannier_code'] = Code.get_from_string('wannier90')
    inputs['tbmodels_code'] = Code.get_from_string('tbmodels')

    inputs['model_evaluation_workflow'] = BandDifferenceModelEvaluation
    inputs['model_evaluation'] = {
        'bands_inspect_code': Code.get_from_string('bands_inspect')
    }

    inputs['wannier_parameters'] = orm.Dict(dict={
        'num_wann': 14,
        'num_bands': 36,
        'dis_num_iter': 1000,
        'num_iter': 0,
        'spinors': True,
    })

    projections = List()
    projections.extend(['In : s; px; py; pz', 'Sb : px; py; pz'])
    inputs['wannier_projections'] = projections

    inputs['wannier_calculation_kwargs'] = {
        'options': {
            'resources': {
                'num_machines': 1,
                'tot_num_mpiprocs': 1
            },
            'withmpi': False
        }
    }
    inputs['symmetries'] = orm.SinglefileData(file=sample('symmetries.hdf5'))

    reference_band_indices = List()
    reference_band_indices.extend(list(range(12, 26)))
    inputs['slice_reference_bands'] = reference_band_indices

    tb_model_indices = List()
    tb_model_indices.extend([0, 2, 3, 1, 5, 6, 4, 7, 9, 10, 8, 12, 13, 11])
    inputs['slice_tb_model'] = tb_model_indices

    return inputs