def test_check_for_export_format_version(self):
    """Importing an archive with an unsupported export format version must fail.

    A stored ``StructureData`` is exported, the ``export_version`` entry of the archive's
    ``metadata.json`` is overwritten with ``0.0`` and the archive is packed again.  Importing the
    modified archive is expected to raise ``IncompatibleArchiveVersionError``.
    """
    # Scratch directories: one holds the archive, the other its unpacked content
    archive_dir = tempfile.mkdtemp()
    unpacked_dir = tempfile.mkdtemp()

    try:
        structure = orm.StructureData()
        structure.store()

        filename = os.path.join(archive_dir, 'export.tar.gz')
        export([structure], outfile=filename, silent=True)

        with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as archive:
            archive.extractall(unpacked_dir)

        metadata_path = os.path.join(unpacked_dir, 'metadata.json')
        with open(metadata_path, 'r', encoding='utf8') as handle:
            metadata = json.load(handle)

        # Tamper with the version so that the importer no longer recognises the archive
        metadata['export_version'] = 0.0

        # NOTE(review): the handle is opened in binary mode; this presumably relies on ``json``
        # being a wrapper that accepts binary handles rather than the stdlib module - confirm.
        with open(metadata_path, 'wb') as handle:
            json.dump(metadata, handle)

        with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as archive:
            archive.add(unpacked_dir, arcname='')

        # Run the class-level fixtures again before attempting the import
        self.tearDownClass()
        self.setUpClass()

        with self.assertRaises(exceptions.IncompatibleArchiveVersionError):
            import_data(filename, silent=True)
    finally:
        # Remove both scratch directories whatever the outcome of the test
        for scratch in (archive_dir, unpacked_dir):
            shutil.rmtree(scratch, ignore_errors=True)
def test_group_import_existing(self, temp_dir):
    """Check the import of a group whose label already exists in the database.

    The test expects the archived group to be imported next to the existing one under a
    different label, and expects a second import of the same archive not to add another group.

    :param temp_dir: directory in which the archive file is written
    """
    grouplabel = 'node_group_existing'

    def label_filter():
        """Return a fresh filter matching every group label that starts with ``grouplabel``."""
        return {'label': {'like': grouplabel + '%'}}

    # A second user, who will own the exported node
    new_email = '[email protected]'
    owner = orm.User(email=new_email)
    owner.store()

    # The structure that becomes the only member of the exported group
    structure = orm.StructureData()
    structure.user = owner
    structure.label = 'sd'
    structure.store()

    exported_group = orm.Group(label=grouplabel)
    exported_group.store()
    exported_group.add_nodes([structure])

    # Export the group, then start again from a clean database
    filename = os.path.join(temp_dir, 'export1.tar.gz')
    export([exported_group], outfile=filename, silent=True)
    self.clean_db()
    self.insert_data()

    # Create an (empty) group with the very same label before importing
    clashing_group = orm.Group(label=grouplabel)
    clashing_group.store()

    import_data(filename, silent=True)

    # Two groups should now share the label prefix: the pre-existing and the imported one
    all_groups = orm.QueryBuilder().append(orm.Group, filters=label_filter())
    self.assertEqual(all_groups.count(), 2)

    # Only one of them has a structure as member ...
    populated_groups = orm.QueryBuilder()
    populated_groups.append(orm.Group, filters=label_filter(), tag='g', project='label')
    populated_groups.append(orm.StructureData, with_group='g')
    self.assertEqual(populated_groups.count(), 1)

    # ... and its label must differ from the original one
    imported_label = populated_groups.all()[0][0]
    self.assertTrue(imported_label != grouplabel)

    # Importing the same archive once more should not create a third group
    import_data(filename, silent=True)
    all_groups = orm.QueryBuilder()
    all_groups.append(orm.Group, filters=label_filter())
    self.assertEqual(all_groups.count(), 2)
def test_high_level_workflow_links(self, temp_dir):
    """Round-trip the links between data nodes and high-level process nodes.

    For every combination of two calculation classes (``CalcJobNode``/``CalcFunctionNode``) and
    two workflow classes (``WorkChainNode``/``WorkFunctionNode``) a graph is built with
    ``construct_complex_graph``, exported, and imported into a reset database.  The INPUT_CALC,
    INPUT_WORK, CREATE, RETURN, CALL_CALC and CALL_WORK links queried before the export must
    equal the links returned by ``get_all_node_links`` after the import.

    :param temp_dir: directory in which the archive file is written
    """
    calc_class_pairs = [
        ['CalcJobNode', 'CalcJobNode'],
        ['CalcJobNode', 'CalcFunctionNode'],
        ['CalcFunctionNode', 'CalcJobNode'],
        ['CalcFunctionNode', 'CalcFunctionNode'],
    ]
    work_class_pairs = [
        ['WorkChainNode', 'WorkChainNode'],
        ['WorkChainNode', 'WorkFunctionNode'],
        ['WorkFunctionNode', 'WorkChainNode'],
        ['WorkFunctionNode', 'WorkFunctionNode'],
    ]
    checked_link_types = (
        LinkType.INPUT_CALC.value,
        LinkType.INPUT_WORK.value,
        LinkType.CREATE.value,
        LinkType.RETURN.value,
        LinkType.CALL_CALC.value,
        LinkType.CALL_WORK.value,
    )
    # The same archive path is reused (and overwritten) for every combination
    export_file = os.path.join(temp_dir, 'export.tar.gz')

    for calcs in calc_class_pairs:
        for works in work_class_pairs:
            failure_msg = 'Failed with c1={}, c2={}, w1={}, w2={}'.format(*calcs, *works)

            self.reset_database()
            graph_nodes, _ = self.construct_complex_graph(calc_nodes=calcs, work_nodes=works)

            # Collect (source uuid, target uuid, label, type) for all links of interest
            query = orm.QueryBuilder()
            query.append(orm.Node, project='uuid')
            query.append(
                orm.Node,
                project='uuid',
                edge_project=['label', 'type'],
                edge_filters={'type': {'in': checked_link_types}},
            )
            self.assertEqual(query.count(), 13, msg=failure_msg)
            links_before = {tuple(link) for link in query.all()}

            export(graph_nodes, outfile=export_file, silent=True, overwrite=True)

            self.reset_database()
            import_data(export_file, silent=True)

            links_after = {tuple(link) for link in get_all_node_links()}
            self.assertSetEqual(links_before, links_after, msg=failure_msg)
def test_complex_workflow_graph_export_sets(self, temp_dir):
    """Export single nodes of the complex graph and verify which nodes come along.

    For each of the nine export configurations ``construct_complex_graph`` returns the node to
    export together with the nodes expected in the archive.  After cleaning the database and
    importing the archive, the UUIDs in the database must equal the expected ones.

    :param temp_dir: directory in which the archive file is written
    """
    # The archive path is the same for every configuration; the file is overwritten each time
    export_file = os.path.join(temp_dir, 'export.aiida')

    for export_conf in range(9):
        _, (export_node, export_target) = self.construct_complex_graph(export_conf)
        expected_uuids = {node.uuid for node in export_target}

        export([export_node], filename=export_file, overwrite=True)
        # Keep a printable representation: the node is gone once the database is cleaned
        export_node_str = str(export_node)

        self.clean_db()
        import_data(export_file)

        # Every node that is now in the database came from the archive
        query = orm.QueryBuilder()
        query.append(orm.Node, project='uuid')
        imported_uuids = {row[0] for row in query.all()}

        mismatch = expected_uuids.symmetric_difference(imported_uuids)
        self.assertSetEqual(
            expected_uuids,
            imported_uuids,
            f'Problem in comparison of export node: {export_node_str}\n'
            f'Expected set: {expected_uuids}\n'
            f'Imported set: {imported_uuids}\n'
            f'Difference: {mismatch}',
        )
def test_check_for_export_format_version(aiida_profile, tmp_path):
    """Importing an archive with an unsupported export format version must fail.

    A stored ``StructureData`` is exported as ``tar.gz``, the ``export_version`` entry of its
    ``metadata.json`` is overwritten with ``0.0`` and the archive is packed again.  Importing it
    into a reset database is expected to raise ``IncompatibleArchiveVersionError``.

    :param aiida_profile: fixture used here to reset the database
    :param tmp_path: base directory for the archive and its unpacked content
    """
    # One sub-directory for the archive, one for its unpacked content
    archive_dir = tmp_path / 'export_tmp'
    unpacked_dir = tmp_path / 'unpack_tmp'
    archive_dir.mkdir()
    unpacked_dir.mkdir()

    aiida_profile.reset_db()

    structure = orm.StructureData()
    structure.store()

    filename = str(archive_dir / 'export.aiida')
    export([structure], filename=filename, file_format='tar.gz')

    with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as archive:
        archive.extractall(unpacked_dir)

    metadata_path = unpacked_dir / 'metadata.json'
    with metadata_path.open('r', encoding='utf8') as handle:
        metadata = json.load(handle)

    # Tamper with the version so that the importer no longer recognises the archive
    metadata['export_version'] = 0.0

    # NOTE(review): the handle is opened in binary mode; this presumably relies on ``json``
    # being a wrapper that accepts binary handles rather than the stdlib module - confirm.
    with metadata_path.open('wb') as handle:
        json.dump(metadata, handle)

    with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as archive:
        archive.add(unpacked_dir, arcname='')

    aiida_profile.reset_db()

    with pytest.raises(exceptions.IncompatibleArchiveVersionError):
        import_data(filename)
def test_workcalculation(self, temp_dir):
    """Export the output of a caller/called pair of ``WorkChainNode`` and re-import it.

    After the round trip the returned ``Int`` node must be loadable by UUID with its value intact.

    :param temp_dir: directory in which the archive file is written
    """
    from aiida.common.links import LinkType

    caller = orm.WorkChainNode()
    called = orm.WorkChainNode()

    input_1 = orm.Int(3).store()
    input_2 = orm.Int(5).store()
    output_1 = orm.Int(2).store()

    # Incoming links have to be attached before the process nodes are stored
    caller.add_incoming(input_1, LinkType.INPUT_WORK, 'input_1')
    called.add_incoming(caller, LinkType.CALL_WORK, 'CALL')
    called.add_incoming(input_2, LinkType.INPUT_WORK, 'input_2')

    caller.store()
    called.store()

    output_1.add_incoming(caller, LinkType.RETURN, 'RETURN')

    caller.seal()
    called.seal()

    # Remember what has to be found again after the import
    expected = [(output_1.uuid, output_1.value)]

    filename1 = os.path.join(temp_dir, 'export1.tar.gz')
    export([output_1], outfile=filename1, silent=True)

    self.clean_db()
    self.insert_data()
    import_data(filename1, silent=True)

    for uuid, value in expected:
        self.assertEqual(orm.load_node(uuid).value, value)
def test_input_and_create_links(self, temp_dir):
    """Verify that INPUT_CALC and CREATE links survive an export/import round trip.

    :param temp_dir: directory in which the archive file is written
    """
    calculation = orm.CalculationNode()
    data_in = orm.Int(1).store()
    data_out = orm.Int(2).store()

    # data_in --INPUT_CALC--> calculation --CREATE--> data_out
    calculation.add_incoming(data_in, LinkType.INPUT_CALC, 'input')
    calculation.store()
    data_out.add_incoming(calculation, LinkType.CREATE, 'output')
    calculation.seal()

    links_before = {tuple(link) for link in get_all_node_links()}

    export_file = os.path.join(temp_dir, 'export.aiida')
    export([data_out], filename=export_file)

    self.clean_db()
    import_data(export_file)

    links_after = {tuple(link) for link in get_all_node_links()}

    self.assertSetEqual(links_before, links_after)
def test_critical_log_msg_and_metadata(self, temp_dir):
    """Check that a critical log message and its metadata survive export and import.

    :param temp_dir: directory in which the archive file is written
    """
    message = 'Testing logging of critical failure'
    calc = orm.CalculationNode()

    # Logging on an unstored node is expected to leave the database untouched
    calc.logger.critical(message)
    self.assertEqual(len(orm.Log.objects.all()), 0)

    # Once the node is stored the same call should produce a database entry
    calc.store()
    calc.seal()
    calc.logger.critical(message)

    # Keep the metadata of that entry for comparison after the import
    expected_metadata = orm.Log.objects.get(dbnode_id=calc.id).metadata

    export_file = os.path.join(temp_dir, 'export.tar.gz')
    export([calc], outfile=export_file, silent=True)

    self.reset_database()
    import_data(export_file, silent=True)

    # Exactly one log entry with the original message and metadata is expected
    imported_logs = orm.Log.objects.all()
    self.assertEqual(len(imported_logs), 1)

    imported_log = imported_logs[0]
    self.assertEqual(imported_log.message, message)
    self.assertEqual(imported_log.metadata, expected_metadata)
def test_import(aiida_profile, benchmark, tmp_path, depth, breadth, num_objects):
    """Benchmark the import of a provenance graph.

    The graph is generated once with ``recursive_provenance`` and exported to ``test.aiida``.
    Each benchmark round resets the database in its setup step and then imports that archive.
    Finally the root node is loaded by UUID to confirm that it is present.

    :param aiida_profile: fixture used here to reset the database
    :param benchmark: benchmark fixture; its ``pedantic`` method drives the rounds
    :param tmp_path: directory in which the archive file is written
    :param depth: forwarded to ``recursive_provenance``
    :param breadth: forwarded to ``recursive_provenance``
    :param num_objects: forwarded to ``recursive_provenance``
    """
    aiida_profile.reset_db()

    root_node = Dict()
    recursive_provenance(root_node, depth=depth, breadth=breadth, num_objects=num_objects)
    root_uuid = root_node.uuid

    out_path = tmp_path / 'test.aiida'
    export_kwargs = get_export_kwargs(filename=str(out_path))
    export([root_node], **export_kwargs)

    def reset_before_round():
        """Start every round from an empty database."""
        aiida_profile.reset_db()

    def import_archive():
        """The operation that is being timed."""
        import_data(str(out_path), silent=True)

    benchmark.pedantic(import_archive, setup=reset_before_round, iterations=1, rounds=12, warmup_rounds=1)

    # Raises if the last import did not bring back the root node
    load_node(root_uuid)
def test_missing_node_repo_folder_export(self, temp_dir):
    """Exporting a node without repository folder must raise ``ArchiveExportError``.

    A node is stored, its repository folder is deleted by hand and an export is attempted.  The
    export has to fail with `~aiida.tools.importexport.common.exceptions.ArchiveExportError`,
    mention the node's UUID in the message, and leave no archive file behind.

    :param temp_dir: directory in which the archive file would be written
    """
    node = orm.CalculationNode().store()
    node.seal()
    node_uuid = node.uuid

    repo_folder = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
    self.assertTrue(
        repo_folder.exists(), msg='Newly created and stored Node should have had an existing repository folder'
    )

    # Delete the repository folder behind the node's back
    shutil.rmtree(repo_folder.abspath, ignore_errors=True)
    self.assertFalse(
        repo_folder.exists(), msg='Newly created and stored Node should have had its repository folder removed'
    )

    # The export must fail and report which node is affected
    filename = os.path.join(temp_dir, 'export.aiida')
    with self.assertRaises(exceptions.ArchiveExportError) as raised:
        export([node], filename=filename)

    expected_message = f'Unable to find the repository folder for Node with UUID={node_uuid}'
    self.assertIn(expected_message, str(raised.exception))
    self.assertFalse(os.path.exists(filename), msg='The archive file should not exist')
def test_exclude_logs_flag(self, temp_dir):
    """Check that ``include_logs=False`` keeps log entries out of the archive.

    :param temp_dir: directory in which the archive file is written
    """
    log_msg = 'Testing logging of critical failure'

    # A stored node with a single log entry attached to it
    calc = orm.CalculationNode()
    calc.store()
    calc.seal()
    calc.logger.critical(log_msg)

    # Remember the UUID: the node object is useless once the database is reset
    calc_uuid = calc.uuid

    export_file = os.path.join(temp_dir, 'export.tar.gz')
    export([calc], outfile=export_file, silent=True, include_logs=False)

    self.reset_database()
    import_data(export_file, silent=True)

    imported_calcs = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid']).all()
    imported_logs = orm.QueryBuilder().append(orm.Log, project=['uuid']).all()

    # Expect the node to be back, but without any log entry
    self.assertEqual(len(imported_calcs), 1)
    self.assertEqual(len(imported_logs), 0)

    # And it has to be the very same node
    self.assertEqual(str(imported_calcs[0][0]), calc_uuid)
def test_calcfunction(self, temp_dir):
    """Export the result of a ``@calcfunction`` with ``return_backward=True`` and re-import it.

    The result and the two nodes used to compute it must be importable with their values intact,
    whereas the nodes that were only candidates in the (undecorated) selection step must be absent.

    :param temp_dir: directory in which the archive file is written
    """
    from aiida.engine import calcfunction
    from aiida.common.exceptions import NotExistent

    @calcfunction
    def add(a, b):
        """Add 2 numbers"""
        return {'res': orm.Float(a + b)}

    def max_(**kwargs):
        """Select the node with the largest value"""
        _, largest = max((node.value, node) for node in kwargs.values())
        return {'res': largest}

    # Five stored floats with the values 0 to 4
    a, b, c, d, e = (orm.Float(number).store() for number in range(5))

    # Add the largest of b, c, d, e to a
    selected = max_(b=b, c=c, d=d, e=e)['res']
    res = add(a=a, b=selected)['res']

    # Nodes expected in the archive: the result and the two inputs of ``add``
    expected = [(node.uuid, node.value) for node in (a, e, res)]
    # Nodes that took part in the selection only and must therefore not be exported
    not_wanted_uuids = [node.uuid for node in (b, c, d)]

    filename1 = os.path.join(temp_dir, 'export1.tar.gz')
    export([res], outfile=filename1, silent=True, return_backward=True)

    self.clean_db()
    self.insert_data()
    import_data(filename1, silent=True)

    # Exported nodes are back with their original values ...
    for uuid, value in expected:
        self.assertEqual(orm.load_node(uuid).value, value)

    # ... while the others cannot be loaded
    for uuid in not_wanted_uuids:
        with self.assertRaises(NotExistent):
            orm.load_node(uuid)
def test_dangling_link_to_existing_db_node(self, temp_dir):
    """A dangling link that references a Node that is not included in the archive should `not` be importable

    A ``StructureData`` that is input to a ``CalculationNode`` is exported on its own.  The
    archive's ``data.json`` is then edited to contain a link to the calculation, which exists in
    the database but not in the archive.  Importing must raise ``DanglingLinkError`` unless
    ``ignore_unknown_nodes=True`` is passed.

    :param temp_dir: directory in which the archive file is written
    """
    struct = orm.StructureData()
    struct.store()
    struct_uuid = struct.uuid

    calc = orm.CalculationNode()
    calc.add_incoming(struct, LinkType.INPUT_CALC, 'input')
    calc.store()
    calc.seal()
    calc_uuid = calc.uuid

    filename = os.path.join(temp_dir, 'export.aiida')
    export([struct], filename=filename, file_format='tar.gz')

    # Use the sandbox as a context manager (as other tests in this file do) so that the scratch
    # folder is released when the block is left instead of lingering after the test.
    with SandboxFolder() as unpack:
        with tarfile.open(filename, 'r:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.extractall(unpack.abspath)

        data_path = unpack.get_abs_path('data.json')
        with open(data_path, 'r', encoding='utf8') as fhandle:
            data = json.load(fhandle)

        # Inject a link whose target (the calculation) is not part of the archive
        data['links_uuid'].append({
            'output': calc_uuid,
            'input': struct_uuid,
            'label': 'input',
            'type': LinkType.INPUT_CALC.value
        })

        # NOTE(review): the handle is opened in binary mode; this presumably relies on ``json``
        # being a wrapper that accepts binary handles rather than the stdlib module - confirm.
        with open(data_path, 'wb') as fhandle:
            json.dump(data, fhandle)

        # Overwrite the original archive with the modified content
        with tarfile.open(filename, 'w:gz', format=tarfile.PAX_FORMAT) as tar:
            tar.add(unpack.abspath, arcname='')

    # Make sure the CalculationNode is still in the database
    builder = orm.QueryBuilder().append(orm.CalculationNode, project='uuid')
    self.assertEqual(
        builder.count(),
        1,
        msg=f'There should be a single CalculationNode, instead {builder.count()} has been found'
    )
    self.assertEqual(builder.all()[0][0], calc_uuid)

    with self.assertRaises(DanglingLinkError):
        import_data(filename)

    # Using the flag `ignore_unknown_nodes` should import it without problems
    import_data(filename, ignore_unknown_nodes=True)
    builder = orm.QueryBuilder().append(orm.StructureData, project='uuid')
    self.assertEqual(
        builder.count(),
        1,
        msg=f'There should be a single StructureData, instead {builder.count()} has been found'
    )
    self.assertEqual(builder.all()[0][0], struct_uuid)
def test_missing_node_repo_folder_import(self, temp_dir):
    """Importing an archive that lacks a node's repository folder must raise ``CorruptArchive``.

    A node is exported, the archive is unpacked, the node's repository folder is deleted from the
    unpacked content and a second, corrupt archive is packed from it.  Importing that archive has
    to fail with `~aiida.tools.importexport.common.exceptions.CorruptArchive` and mention the
    node's UUID in the message.

    :param temp_dir: directory in which both archive files are written
    """
    import tarfile

    from aiida.common.folders import SandboxFolder
    from aiida.tools.importexport.common.archive import extract_tar
    from aiida.tools.importexport.common.config import NODES_EXPORT_SUBFOLDER
    from aiida.tools.importexport.common.utils import export_shard_uuid

    node = orm.CalculationNode().store()
    node.seal()
    node_uuid = node.uuid

    repo_folder = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
    self.assertTrue(
        repo_folder.exists(), msg='Newly created and stored Node should have had an existing repository folder'
    )

    # Create a valid archive first, then empty the database
    filename = os.path.join(temp_dir, 'export.aiida')
    export([node], filename=filename, file_format='tar.gz', silent=True)
    self.reset_database()

    # Relative location of the node inside the archive, and the first component of that path
    sharded_path = export_shard_uuid(node_uuid)
    top_level_shard = sharded_path.split('/')[0]

    filename_corrupt = os.path.join(temp_dir, 'export_corrupt.aiida')

    with SandboxFolder() as folder:
        extract_tar(filename, folder, silent=True, nodes_export_subfolder=NODES_EXPORT_SUBFOLDER)

        node_folder = folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, sharded_path))
        self.assertTrue(
            node_folder.exists(), msg="The Node's repository folder should still exist in the export file"
        )

        # Delete the top-level shard directory, which takes the node's folder with it
        shard_folder = folder.get_subfolder(os.path.join(NODES_EXPORT_SUBFOLDER, top_level_shard))
        shutil.rmtree(shard_folder.abspath, ignore_errors=True)
        self.assertFalse(
            node_folder.exists(),
            msg="The Node's repository folder should now have been removed in the export file"
        )

        # Pack what is left into a new archive
        with tarfile.open(filename_corrupt, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar:
            tar.add(folder.abspath, arcname='')

    # The import must fail and report which node is affected
    with self.assertRaises(exceptions.CorruptArchive) as raised:
        import_data(filename_corrupt, silent=True)

    expected_message = f'Unable to find the repository folder for Node with UUID={node_uuid}'
    self.assertIn(expected_message, str(raised.exception))
def create(
    output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
    create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs
):
    """
    Export subsets of the provenance graph to file for sharing.

    Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs.

    By default, the export file will include not only the entities explicitly provided via the command line but also
    their provenance, according to the rules outlined in the documentation.
    You can modify some of those rules using options of this command.
    """
    from aiida.tools.importexport import export, ExportFileFormat
    from aiida.tools.importexport.common.exceptions import ArchiveExportError

    # Collect the explicitly requested entities; empty or missing selections are skipped
    entities = []
    for selection in (codes, computers, groups, nodes):
        if selection:
            entities.extend(selection)

    # Traversal rules and content switches are handed to ``export`` unchanged;
    # ``force`` is passed on as the ``overwrite`` argument
    kwargs = {
        'input_calc_forward': input_calc_forward,
        'input_work_forward': input_work_forward,
        'create_backward': create_backward,
        'return_backward': return_backward,
        'call_calc_backward': call_calc_backward,
        'call_work_backward': call_work_backward,
        'include_comments': include_comments,
        'include_logs': include_logs,
        'overwrite': force
    }

    # Supported values of ``archive_format`` mapped to the export file format and the
    # additional keyword arguments that format requires
    format_options = {
        'zip': (ExportFileFormat.ZIP, {'use_compression': True}),
        'zip-uncompressed': (ExportFileFormat.ZIP, {'use_compression': False}),
        'tar.gz': (ExportFileFormat.TAR_GZIPPED, {}),
    }

    try:
        export_format, format_kwargs = format_options[archive_format]
    except KeyError:
        # Previously an unknown format left ``export_format`` unbound and surfaced as an
        # UnboundLocalError; report the actual problem instead.
        echo.echo_critical('unsupported archive format: {}'.format(archive_format))
        return

    kwargs.update(format_kwargs)

    try:
        export(entities, filename=output_file, file_format=export_format, **kwargs)
    except ArchiveExportError as exception:
        echo.echo_critical('failed to write the archive file. Exception: {}'.format(exception))
    else:
        echo.echo_success('wrote the export archive file to {}'.format(output_file))
def test_empty_repo_folder_export(self, temp_dir):
    """Check that a node whose repository folder is empty is exported and imported properly.

    The repository folder of a stored ``Dict`` node is emptied by hand.  The node is then written
    with ``export_tree`` to a folder and with ``export`` to a tar and a zip archive, and each of
    the three results is imported into a reset database.

    :param temp_dir: directory in which the archive folder and files are written
    """
    from aiida.common.folders import Folder
    from aiida.tools.importexport.dbexport import export_tree

    node = orm.Dict().store()
    node_uuid = node.uuid

    repo_folder = RepositoryFolder(section=Repository._section_name, uuid=node_uuid)  # pylint: disable=protected-access
    self.assertTrue(
        repo_folder.exists(), msg='Newly created and stored Node should have had an existing repository folder'
    )

    # Empty the repository folder without removing the folder itself
    for entry_name, is_file in repo_folder.get_content_list(only_paths=False):
        entry_path = os.path.join(repo_folder.abspath, entry_name)
        if not is_file:
            shutil.rmtree(entry_path, ignore_errors=False)
        else:
            os.remove(entry_path)

    self.assertFalse(
        repo_folder.get_content_list(),
        msg=f'Repository folder should be empty, instead the following was found: {repo_folder.get_content_list()}'
    )

    archive_variants = {
        'archive folder': os.path.join(temp_dir, 'export_tree'),
        'tar archive': os.path.join(temp_dir, 'export.tar.gz'),
        'zip archive': os.path.join(temp_dir, 'export.zip')
    }

    # Write all three variants before the database is reset for the first time
    export_tree([node], folder=Folder(archive_variants['archive folder']), silent=True)
    export([node], filename=archive_variants['tar archive'], file_format='tar.gz', silent=True)
    export([node], filename=archive_variants['zip archive'], file_format='zip', silent=True)

    for variant, archive_path in archive_variants.items():
        self.reset_database()
        node_count = orm.QueryBuilder().append(orm.Dict, project='uuid').count()
        self.assertEqual(node_count, 0, msg='After DB reset {} Dict Nodes was (wrongly) found'.format(node_count))

        import_data(archive_path, silent=True)

        query = orm.QueryBuilder().append(orm.Dict, project='uuid')
        imported_node_count = query.count()
        self.assertEqual(
            imported_node_count,
            1,
            msg='After {} import a single Dict Node should have been found, '
            'instead {} was/were found'.format(variant, imported_node_count)
        )

        imported_node_uuid = query.all()[0][0]
        self.assertEqual(
            imported_node_uuid,
            node_uuid,
            msg='The wrong UUID was found for the imported {}: '
            '{}. It should have been: {}'.format(variant, imported_node_uuid, node_uuid)
        )
def test_base_data_type_change(self, temp_dir):
    """
    Base Data types type string changed
    Example: Bool: “data.base.Bool.” → “data.bool.Bool.”

    ``Str``, ``Int``, ``Float`` and ``Bool`` nodes and a ``List`` holding them are exported and
    re-imported; afterwards the values and the ``node_type`` strings are checked.

    :param temp_dir: directory in which the archive file is written
    """
    # Test content
    test_content = ('Hello', 6, -1.2399834e12, False)
    # Expected type strings, e.g. 'data.str.Str.', in the same order as ``test_content``
    test_types = tuple(
        'data.{}.{}.'.format(node_type, node_type.capitalize()) for node_type in ['str', 'int', 'float', 'bool']
    )

    # Create list of base type nodes
    nodes = [cls(val).store() for val, cls in zip(test_content, (orm.Str, orm.Int, orm.Float, orm.Bool))]

    # Collect uuids for created nodes
    uuids = [n.uuid for n in nodes]

    # Create List() and insert already created nodes into it
    list_node = orm.List()
    list_node.set_list(nodes)
    list_node.store()
    list_node_uuid = list_node.uuid

    # Nodes to be exported: the base type nodes followed by the list node
    export_nodes = [*nodes, list_node]

    # Export nodes
    filename = os.path.join(temp_dir, 'export.aiida')
    export(export_nodes, filename=filename, silent=True)

    # Clean the database
    self.reset_database()

    # Import nodes again
    import_data(filename, silent=True)

    # Check whether types are correctly imported
    nlist = orm.load_node(list_node_uuid)  # List
    for uuid, list_value, refval, reftype in zip(uuids, nlist.get_list(), test_content, test_types):
        # Str, Int, Float, Bool
        base = orm.load_node(uuid)
        # Check value/content
        self.assertEqual(base.value, refval)
        # Check type
        msg = "type of node ('{}') is not updated according to db schema v0.4".format(base.node_type)
        self.assertEqual(base.node_type, reftype, msg=msg)

        # List
        # Check value
        self.assertEqual(list_value, refval)

    # Check List type
    msg = "type of node ('{}') is not updated according to db schema v0.4".format(nlist.node_type)
    self.assertEqual(nlist.node_type, 'data.list.List.', msg=msg)
def test_mtime_of_imported_comments(self, temp_dir):
    """
    Test mtime does not change for imported comments
    This is related to correct usage of `comment_mode` when importing.

    :param temp_dir: directory in which the archive file is written
    """
    # Get user
    user = orm.User.objects.get_default()

    comment_content = 'You get what you give'

    # Create node
    calc = orm.CalculationNode().store()
    calc.seal()

    # Create comment
    orm.Comment(calc, user, comment_content).store()
    # NOTE(review): the node is already stored at this point; kept as in the original - confirm
    # whether this second call is needed.
    calc.store()

    # Save UUIDs and mtime as they are recorded in the database
    comment_rows = orm.QueryBuilder().append(orm.Comment, project=['uuid', 'mtime']).all()
    comment_uuid = str(comment_rows[0][0])
    comment_mtime = comment_rows[0][1]

    calc_rows = orm.QueryBuilder().append(orm.CalculationNode, project=['uuid', 'mtime']).all()
    calc_uuid = str(calc_rows[0][0])
    calc_mtime = calc_rows[0][1]

    # Export, reset database and reimport
    export_file = os.path.join(temp_dir, 'export.aiida')
    export([calc], filename=export_file, silent=True)
    self.reset_database()
    import_data(export_file, silent=True)

    # Retrieve node and comment
    builder = orm.QueryBuilder().append(orm.CalculationNode, tag='calc', project=['uuid', 'mtime'])
    builder.append(orm.Comment, with_node='calc', project=['uuid', 'mtime'])

    import_entities = builder.all()[0]

    self.assertEqual(len(import_entities), 4)  # Check we have the correct amount of returned values

    import_calc_uuid = str(import_entities[0])
    import_calc_mtime = import_entities[1]
    import_comment_uuid = str(import_entities[2])
    import_comment_mtime = import_entities[3]

    # Check we have the correct UUIDs
    self.assertEqual(import_calc_uuid, calc_uuid)
    self.assertEqual(import_comment_uuid, comment_uuid)

    # Make sure the mtime is the same after import as it was before export
    self.assertEqual(import_comment_mtime, comment_mtime)
    self.assertEqual(import_calc_mtime, calc_mtime)
def setUpClass(cls, *args, **kwargs):
    """Create the archive file that is shared by the tests of this class.

    A ``Data`` node labelled ``my_test_data_node`` with the extras ``b`` and ``c`` is stored and
    exported to ``export.aiida`` inside a fresh temporary folder.  The folder and the archive
    path are kept on the class as ``tmp_folder`` and ``export_file``.
    """
    super().setUpClass()

    node = orm.Data()
    node.label = 'my_test_data_node'
    node.store()
    # Extras are set after the node has been stored
    node.set_extra_many({'b': 2, 'c': 3})

    scratch_dir = tempfile.mkdtemp()
    cls.tmp_folder = scratch_dir
    cls.export_file = os.path.join(scratch_dir, 'export.aiida')

    export([node], outfile=cls.export_file, silent=True)
def test_nodes_in_group(self, temp_dir):
    """Check that nodes belonging to a group are exported and imported together with it.

    After the round trip both nodes must still be owned by the user created here, and the group
    must be found by its UUID.

    :param temp_dir: directory in which the archive file is written
    """
    from aiida.common.links import LinkType

    # A second user, who will own both exported nodes
    new_email = '[email protected]'
    owner = orm.User(email=new_email)
    owner.store()

    # A structure that is the input of a calculation job
    structure = orm.StructureData()
    structure.user = owner
    structure.label = 'sd1'
    structure.store()

    calc_job = orm.CalcJobNode()
    calc_job.computer = self.computer
    calc_job.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1})
    calc_job.user = owner
    calc_job.label = 'jc1'
    calc_job.add_incoming(structure, link_type=LinkType.INPUT_CALC, link_label='link')
    calc_job.store()
    calc_job.seal()

    # Put both nodes into one group
    group = orm.Group(label='node_group')
    group.store()
    group.add_nodes([structure, calc_job])
    group_uuid = group.uuid

    # Export the nodes and the group explicitly
    filename1 = os.path.join(temp_dir, 'export1.tar.gz')
    export([structure, calc_job, group], outfile=filename1, silent=True)
    node_uuids = [structure.uuid, calc_job.uuid]

    self.clean_db()
    self.insert_data()
    import_data(filename1, silent=True)

    # The nodes are back and still belong to the user created above
    for uuid in node_uuids:
        self.assertEqual(orm.load_node(uuid).user.email, new_email)

    # The group itself has been imported as well
    query = orm.QueryBuilder()
    query.append(orm.Group, filters={'uuid': {'==': group_uuid}})
    self.assertEqual(query.count(), 1, 'The group was not found.')
def test_input_code(self, temp_dir):
    """Check that exporting a calculation also exports the code it uses as input.

    After the round trip the code must be loadable by UUID, and the single link from the code to
    the calculation must be present both before the export and after the import.

    :param temp_dir: directory in which the archive file is written
    """
    code_label = 'test_code1'
    expected_links = 1

    code = orm.Code()
    code.set_remote_computer_exec((self.computer, '/bin/true'))
    code.label = code_label
    code.store()
    code_uuid = code.uuid

    calc = orm.CalcJobNode()
    calc.computer = self.computer
    calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1})
    calc.add_incoming(code, LinkType.INPUT_CALC, 'code')
    calc.store()
    calc.seal()

    links_before = get_all_node_links()

    export_file = os.path.join(temp_dir, 'export.aiida')
    export([calc], filename=export_file)

    self.clean_db()
    import_data(export_file)

    # The code has been exported along with the calculation
    self.assertEqual(orm.load_node(code_uuid).label, code_label)

    # The links are the same as before, irrespective of their order
    links_after = get_all_node_links()
    self.assertListEqual(sorted(links_before), sorted(links_after))

    found_before = len(links_before)
    self.assertEqual(
        found_before, expected_links, 'Expected to find only one link from code to '
        'the calculation node before export. {} found.'.format(found_before)
    )

    found_after = len(links_after)
    self.assertEqual(
        found_after, expected_links, 'Expected to find only one link from code to '
        'the calculation node after import. {} found.'.format(found_after)
    )
def cmd_export(
    group, max_atoms, max_atomic_number, number_species, partial_occupancies, include_duplicates, no_cod_hydrogen,
    sssp_only, filename
):
    """Export the (optionally filtered) structures of a group to an archive file."""
    # Arguments, as used below:
    #   group               group whose ``StructureData`` members are exported (selected by ``pk``)
    #   max_atoms           if not ``None``, only structures with at most this many sites are kept
    #   max_atomic_number   if truthy, structures containing a heavier element are dropped
    #   include_duplicates  if truthy, nodes listed in each structure's ``duplicates`` extra are added
    #   no_cod_hydrogen     if truthy, ids returned by ``get_cod_hydrogen_structure_ids`` are excluded
    #   sssp_only           if truthy, structures containing an element with Z > 86 or Z == 85 are dropped
    #   filename            path of the archive, passed to ``export`` as ``outfile``
    # NOTE(review): ``number_species`` and ``partial_occupancies`` are accepted but never used in
    # this function - confirm whether the corresponding filters are still to be implemented.
    from aiida import orm
    from aiida.common.constants import elements
    from aiida.tools.importexport import export

    filters_elements = set()
    filters_structures = {'and': []}

    # Filters that can be expressed in the database query
    if no_cod_hydrogen:
        filters_structures['and'].append({'id': {'!in': get_cod_hydrogen_structure_ids()}})

    if max_atoms is not None:
        filters_structures['and'].append({'attributes.sites': {'shorter': max_atoms + 1}})

    # Element based filters are applied in Python: collect the symbols that are not allowed
    if max_atomic_number:
        filters_elements = filters_elements.union({e['symbol'] for z, e in elements.items() if z > max_atomic_number})

    if sssp_only:
        # All elements with atomic number of Radon or lower, with the exception of Astatine
        filters_elements = filters_elements.union({e['symbol'] for z, e in elements.items() if z > 86 or z == 85})

    builder = orm.QueryBuilder().append(
        orm.Group, filters={'id': group.pk}, tag='group').append(
        orm.StructureData, with_group='group', filters=filters_structures)

    if max_atomic_number or sssp_only:
        # Keep only the structures that contain none of the disallowed symbols
        structures = [
            structure for (structure,) in builder.iterall()
            if filters_elements.isdisjoint(structure.get_symbols_set())
        ]
    else:
        structures = builder.all(flat=True)

    duplicates = []

    if include_duplicates:
        for structure in structures:
            # The extra maps a database name to a list of UUIDs; only the UUIDs matter here
            for uuids in structure.get_extra('duplicates').values():
                for uuid in uuids:
                    # The structure may be listed among its own duplicates
                    if uuid != structure.uuid:
                        duplicates.append(orm.load_node(uuid))

    export(structures + duplicates, outfile=filename, create_backward=False, return_backward=False)
def test_calc_and_data_nodes_with_comments(self, temp_dir):
    """Check that comments on a CalculationNode and on a Data node are correctly ex-/imported.

    Two comments are attached to each node.  After the round trip every (node, comment) pair
    found in the database must connect a comment to the node it was originally written for.

    :param temp_dir: directory in which the archive file is written
    """
    user = orm.User.objects.get_default()

    calc_node = orm.CalculationNode().store()
    calc_node.seal()
    data_node = orm.Data().store()

    # First two comments go to the calculation, the last two to the data node
    comment_one = orm.Comment(calc_node, user, self.comments[0]).store()
    comment_two = orm.Comment(calc_node, user, self.comments[1]).store()
    comment_three = orm.Comment(data_node, user, self.comments[2]).store()
    comment_four = orm.Comment(data_node, user, self.comments[3]).store()

    # Record which comment UUIDs belong to which node UUID
    calc_uuid = calc_node.uuid
    data_uuid = data_node.uuid
    comments_of = {
        calc_uuid: [comment.uuid for comment in (comment_one, comment_two)],
        data_uuid: [comment.uuid for comment in (comment_three, comment_four)],
    }

    export_file = os.path.join(temp_dir, 'export.tar.gz')
    export([calc_node, data_node], outfile=export_file, silent=True)

    self.reset_database()
    import_data(export_file, silent=True)

    # One row per (node uuid, comment uuid) pair
    query = orm.QueryBuilder()
    query.append(orm.Node, tag='node', project=['uuid'])
    query.append(orm.Comment, with_node='node', project=['uuid'])
    nodes_and_comments = query.all()

    self.assertEqual(len(nodes_and_comments), len(self.comments))

    for entry in nodes_and_comments:
        self.assertEqual(len(entry), 2)  # 1 Node + 1 Comment

        import_node_uuid, import_comment_uuid = (str(value) for value in entry)

        self.assertIn(import_node_uuid, [calc_uuid, data_uuid])
        # The comment has to be one of those written for exactly this node
        self.assertIn(import_comment_uuid, comments_of[import_node_uuid])
def test_calc_of_structuredata(aiida_profile, tmp_path, file_format):
    """Simple ex-/import of a ``CalcJobNode`` with a ``StructureData`` input.

    The attributes of both nodes are recorded before the export and compared with the attributes
    of the nodes loaded by UUID after the import.

    :param aiida_profile: fixture used here to reset the database
    :param tmp_path: directory for the archive file and the computer's work directory
    :param file_format: archive format passed on to ``export``
    """
    aiida_profile.reset_db()

    structure = orm.StructureData()
    structure.store()

    computer = orm.Computer(
        label='localhost-test',
        description='localhost computer set up by test manager',
        hostname='localhost-test',
        workdir=str(tmp_path / 'workdir'),
        transport_type='local',
        scheduler_type='direct'
    )
    computer.store()
    computer.configure()

    calc = orm.CalcJobNode()
    calc.computer = computer
    calc.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1})
    calc.add_incoming(structure, link_type=LinkType.INPUT_CALC, link_label='link')
    calc.store()
    calc.seal()

    # Snapshot of the attributes, keyed by node UUID, taken from freshly loaded nodes
    attrs = {}
    for pk in (structure.pk, calc.pk):
        node = orm.load_node(pk)
        attrs[node.uuid] = {key: node.get_attribute(key) for key in node.attributes.keys()}

    filename = str(tmp_path / 'export.aiida')
    export([calc], filename=filename, file_format=file_format)

    aiida_profile.reset_db()
    import_data(filename)

    # Every recorded attribute must be found unchanged on the imported node
    for uuid, expected in attrs.items():
        node = orm.load_node(uuid)
        for key, value in expected.items():
            assert value == node.get_attribute(key)
def prepare_link_flags_export(nodes_to_export, test_data):
    """Export the same nodes once per entry of ``test_data`` and extend the expected nodes.

    The traversal rules start from the default value of every toggleable export rule.
    For each ``(export_file, rule_changes, expected_nodes)`` value in ``test_data`` the
    rules are updated with ``rule_changes``, ``nodes_to_export[0]`` is exported to
    ``export_file``, and the content of ``nodes_to_export[1]`` is merged into
    ``expected_nodes`` in place.

    Note that the rules dictionary is shared between iterations: changes applied for one
    entry remain in effect for the following entries unless they are overridden.

    :param nodes_to_export: indexable whose item 0 is passed to ``export`` and whose item 1
        maps a node type to the nodes merged into every ``expected_nodes``.
    :param test_data: mapping whose values are ``(export_file, rule_changes, expected_nodes)``.
    """
    from aiida.common.links import GraphTraversalRules

    # Collect the default of each export rule that can be toggled
    traversal_rules = {}
    for rule_name, rule in GraphTraversalRules.EXPORT.value.items():
        if rule.toggleable:
            traversal_rules[rule_name] = rule.default

    for export_file, rule_changes, expected_nodes in test_data.values():
        # Updated in place on purpose: see the note in the docstring
        traversal_rules.update(rule_changes)
        export(nodes_to_export[0], outfile=export_file, silent=True, **traversal_rules)

        always_exported = nodes_to_export[1]
        for node_type in always_exported:
            extra_nodes = always_exported[node_type]
            if node_type not in expected_nodes:
                expected_nodes[node_type] = extra_nodes
            else:
                expected_nodes[node_type].update(extra_nodes)
def test_exclude_comments_flag(self, temp_dir):
    """Test that `include_comments=False` leaves out comments and the users who only commented.

    A node owned by the default user receives comments from two users. After exporting
    with ``include_comments=False`` and importing into a reset database, only the node and
    its owner must be present.

    :param temp_dir: directory in which the export archive is written.
    """

    def _project(entity, field):
        """Return all rows of ``entity`` projected on the single ``field``."""
        return orm.QueryBuilder().append(entity, project=[field]).all()

    # Two users, one node, and two comments per user on that node
    user_one = orm.User.objects.get_default()
    user_two = orm.User(email='[email protected]').store()
    node = orm.Data().store()
    for author, idx in ((user_one, 0), (user_one, 1), (user_two, 2), (user_two, 3)):
        orm.Comment(node, author, self.comments[idx]).store()

    # Remember what is needed after the database has been reset
    emails = [user_one.email, user_two.email]
    node_uuid = node.uuid

    # The node must be owned by the first user
    self.assertEqual(node.user.email, emails[0])

    # Round trip without comments: export, reset, import
    archive = os.path.join(temp_dir, 'export.tar.gz')
    export([node], outfile=archive, silent=True, include_comments=False)
    self.reset_database()
    import_data(archive, silent=True)

    imported_nodes = _project(orm.Node, 'uuid')
    imported_comments = _project(orm.Comment, 'uuid')
    imported_users = _project(orm.User, 'email')

    # Expect exactly one node, no comments and one user
    self.assertEqual(len(imported_nodes), 1)
    self.assertEqual(len(imported_comments), 0)
    self.assertEqual(len(imported_users), 1)

    # The surviving node and user must be the original node and its owner
    self.assertEqual(str(imported_nodes[0][0]), node_uuid)
    self.assertEqual(str(imported_users[0][0]), emails[0])
def test_calc_of_structuredata(self, temp_dir):
    """Simple ex-/import of a CalcJobNode with an input StructureData.

    The attributes of both nodes are recorded before the export and compared, key by key,
    with the attributes of the nodes loaded by UUID after the import.

    :param temp_dir: directory in which the export archive is written.
    """
    from aiida.common.links import LinkType

    structure = orm.StructureData()
    structure.store()

    calc_job = orm.CalcJobNode()
    calc_job.computer = self.computer
    calc_job.set_option('resources', {'num_machines': 1, 'num_mpiprocs_per_machine': 1})
    calc_job.add_incoming(structure, link_type=LinkType.INPUT_CALC, link_label='link')
    calc_job.store()
    calc_job.seal()

    # Snapshot the attributes of both nodes, keyed by node UUID
    expected = {}
    for pk in [structure.pk, calc_job.pk]:
        stored = orm.load_node(pk)
        expected[stored.uuid] = {name: stored.get_attribute(name) for name in stored.attributes.keys()}

    archive = os.path.join(temp_dir, 'export.aiida')
    export([calc_job], filename=archive, silent=True)

    self.clean_db()
    self.create_user()

    # NOTE: the imported nodes are loaded by UUID rather than by pk, because the pks are
    # not guaranteed to be reused: a recommended policy in databases is that the pk always
    # increments, even after elements have been deleted.
    import_data(archive, silent=True)

    for uuid, snapshot in expected.items():
        imported = orm.load_node(uuid)
        for name, value in snapshot.items():
            self.assertEqual(value, imported.get_attribute(name))
def test_node_process_type(self, temp_dir):
    """Test that `node_type` and `process_type` of a process node survive an export/import round trip.

    :param temp_dir: directory in which the export archive is written.
    """
    from aiida.engine import run_get_node
    from tests.utils.processes import AddProcess

    expected_node_type = 'process.workflow.WorkflowNode.'
    expected_process_type = 'tests.utils.processes.AddProcess'

    def _assert_types(process_node):
        """Check both type strings of ``process_node`` against the expected values."""
        self.assertEqual(process_node.node_type, expected_node_type)
        self.assertEqual(process_node.process_type, expected_process_type)

    # Run the process and keep the node it produced
    _, original = run_get_node(AddProcess, a=orm.Int(2), b=orm.Int(3))
    original_uuid = str(original.uuid)
    _assert_types(original)

    # Round trip: export, clean, import
    archive = os.path.join(temp_dir, 'export.aiida')
    export([original], filename=archive)
    self.clean_db()
    import_data(archive)

    # Exactly one process node must have been imported
    query = orm.QueryBuilder()
    query.append(orm.ProcessNode, project=['uuid'])
    self.assertEqual(query.count(), 1)

    # ... and it must carry the UUID of the exported node
    rows = query.all()
    imported_uuid = str(rows[0][0])
    self.assertEqual(imported_uuid, original_uuid)

    # The type strings must be unchanged on the imported node
    _assert_types(orm.load_node(imported_uuid))
def test_double_return_links_for_workflows(self, temp_dir):
    """Test that two RETURN links into the same node can be exported and imported.

    A data node is returned by two workflow nodes, one of which calls the other. After
    the round trip the set of node UUIDs and the full list of links must be unchanged.

    :param temp_dir: directory in which the export archive is written.
    """
    expected_link_total = 4

    called_wf = orm.WorkflowNode()
    caller_wf = orm.WorkflowNode().store()
    input_data = orm.Int(1).store()
    returned_data = orm.Int(2).store()

    # Incoming links of the called workflow are added before it is stored
    called_wf.add_incoming(input_data, LinkType.INPUT_WORK, 'input_i1')
    called_wf.add_incoming(caller_wf, LinkType.CALL_WORK, 'call')
    called_wf.store()

    # The same data node is returned by both workflows
    returned_data.add_incoming(called_wf, LinkType.RETURN, 'return1')
    returned_data.add_incoming(caller_wf, LinkType.RETURN, 'return2')

    called_wf.seal()
    caller_wf.seal()

    # State of the graph before the export
    uuids_before = {node.uuid for node in (called_wf, returned_data, input_data, caller_wf)}
    links_before = get_all_node_links()

    # Round trip: export, reset, import
    archive = os.path.join(temp_dir, 'export.tar.gz')
    export([returned_data, called_wf, caller_wf, input_data], outfile=archive, silent=True)
    self.reset_database()
    import_data(archive, silent=True)

    # Each row of the query holds a single projected UUID
    rows = orm.QueryBuilder().append(orm.Node, project='uuid').all()
    uuids_after = [str(row_uuid) for (row_uuid,) in rows]
    self.assertListEqual(sorted(uuids_before), sorted(uuids_after))

    links_after = get_all_node_links()
    self.assertListEqual(sorted(links_before), sorted(links_after))

    # Both RETURN links must be counted, before the export as well as after the import
    self.assertEqual(len(links_before), expected_link_total)
    self.assertEqual(len(links_after), expected_link_total)
def test_simple_import(self):
    """Test that an archive with nodes that are not associated to a computer is imported correctly.

    In Django, when such nodes are exported, the export file contains an empty set for
    computers. In SQLA that set is only present when a computer is associated with the
    exported nodes. The SQLA import code used to crash when it found an empty computer
    set in the export file; this test demonstrates that problem.
    """

    def _count_nodes():
        """Return the number of nodes currently in the database."""
        return orm.QueryBuilder().append(orm.Node).count()

    payload = {
        'Pr': {
            'cutoff': 50.0,
            'pseudo_type': 'Wentzcovitch',
            'dual': 8,
            'cutoff_units': 'Ry'
        },
        'Ru': {
            'cutoff': 40.0,
            'pseudo_type': 'SG15',
            'dual': 4,
            'cutoff_units': 'Ry'
        },
    }
    parameters = orm.Dict(dict=payload).store()

    with tempfile.NamedTemporaryFile() as archive:
        nodes = [parameters]
        export(nodes, outfile=archive.name, overwrite=True, silent=True)

        # The database holds exactly the nodes that were just exported
        self.assertEqual(_count_nodes(), len(nodes))

        # After cleaning the database no node may be left
        self.clean_db()
        self.create_user()
        self.assertEqual(_count_nodes(), 0)

        # Importing the archive restores the original number of nodes
        import_data(archive.name, silent=True)
        self.assertEqual(_count_nodes(), len(nodes))