def test_create_when_already_exists(self):
        """Check that ``new`` and ``open`` both raise on a random temp path.

        Each call is made on a brand-new ``Project`` with its own freshly
        generated path under the system temp directory. Any ``Exception``
        satisfies the check.
        """
        for method_name in ("new", "open"):
            with self.assertRaises(Exception):
                project = Project()
                target = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
                getattr(project, method_name)(target)
Example #2
0
    def test_creation(self):
        """Loading ``temp_proj_name`` must fail before the project is created.

        After the failed ``load`` a second ``Project`` is created at the same
        location and its connection is closed.
        """
        project_path = temp_proj_name

        # Nothing has been created at this path yet, so loading must fail.
        with self.assertRaises(FileNotFoundError):
            missing = Project()
            missing.load(project_path)

        created = Project()
        created.new(project_path)
        created.conn.close()
class TestProject(TestCase):
    """Tests creating, opening and closing a ``Project`` in a temp folder."""

    def setUp(self) -> None:
        """Create a new project inside a uniquely named temporary folder."""
        folder_name = uuid.uuid4().hex
        self.temp_proj_folder = os.path.join(tempfile.gettempdir(),
                                             folder_name)
        self.proj = Project()
        self.proj.new(self.temp_proj_folder)

    def tearDown(self) -> None:
        """Close the project created in ``setUp``."""
        self.proj.close()

    def test_opening_wrong_folder(self):
        """Opening a folder that was never created raises FileNotFoundError."""
        missing_folder = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)

        # The fixture project is closed first and re-opened at the end so
        # that tearDown finds it open again.
        self.proj.close()
        with self.assertRaises(FileNotFoundError):
            other = Project()
            other.open(missing_folder)
        self.proj.open(self.temp_proj_folder)

    def test_create_when_already_exists(self):
        """``new`` and ``open`` on a second Project must both raise."""
        for method_name in ("new", "open"):
            with self.assertRaises(Exception):
                second = Project()
                target = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)
                getattr(second, method_name)(target)

    def test_creation(self):
        """The new project has a ``links`` and a ``nodes`` table with key columns."""
        # (table, column that must exist, failure message)
        expectations = (
            ("links", "distance", "Table LINKS was not created correctly"),
            ("nodes", "is_centroid", 'Table NODES was not created correctly'),
        )
        for table, column, message in expectations:
            cursor = self.proj.conn.cursor()
            cursor.execute(f"PRAGMA table_info({table});")
            # Column 1 of each PRAGMA table_info row is the column name.
            column_names = [row[1] for row in cursor.fetchall()]
            if column not in column_names:
                self.fail(message)

    def test_close(self):
        """``database_connection`` works while open and raises once closed."""
        # Kept in a local so the connection object stays referenced for the
        # duration of the test, exactly as in the original.
        open_connection = database_connection()

        self.proj.close()
        with self.assertRaises(FileExistsError):
            open_connection = database_connection()
Example #4
0
    def test_connection_with_new_project(self):
        """A newly created project can be reopened and has an empty links table.

        The project is created in a uniquely named temporary folder, closed,
        reopened with a second ``Project`` instance and then queried through
        ``database_connection()``.
        """
        temp_proj_folder = os.path.join(tempfile.gettempdir(),
                                        uuid.uuid4().hex)
        proj = Project()
        proj.new(temp_proj_folder)
        proj.close()

        proj = Project()
        proj.open(temp_proj_folder)
        # Close the reopened project even when the query or the assertion
        # fails, so a failure here does not leave a project open for the
        # tests that run afterwards.
        try:
            conn = database_connection()
            cursor = conn.cursor()
            cursor.execute('select count(*) from links')

            self.assertEqual(cursor.fetchone()[0], 0,
                             "Returned more links than it should have")
        finally:
            proj.close()
Example #5
0
class CreatesTranspoNetProcedure(WorkerThread):
    """Worker thread that creates a project and fills its network from layers.

    ``node_fields`` and ``link_fields`` are dictionaries mapping a column name
    of the project's ``nodes``/``links`` table to the index of the matching
    attribute in the source layer. When features are transferred, a negative
    index means "no source attribute" and is written as NULL.

    Problems with individual fields or features do not abort the import; a
    message for each one is appended to ``self.report``.
    """

    def __init__(
        self,
        parentThread,
        output_file,
        node_layer,
        node_fields,
        link_layer,
        link_fields,
    ):
        """Store the inputs; no work is done until ``doWork`` is called.

        Args:
            parentThread: Passed on to ``WorkerThread.__init__``.
            output_file: Path handed to ``Project.new`` and to
                ``qgis.utils.spatialite_connect``.
            node_layer: Layer the node features are read from.
            node_fields: Column name -> attribute index mapping for nodes.
            link_layer: Layer the link features are read from.
            link_fields: Column name -> attribute index mapping for links.
        """
        WorkerThread.__init__(self, parentThread)

        self.output_file = output_file
        self.node_fields = node_fields
        self.link_fields = link_fields
        self.node_layer = node_layer
        self.link_layer = link_layer
        self.report = []
        self.project: Project

    def doWork(self):
        """Create the project, add extra columns, copy features, add triggers."""
        self.emit_messages(message="Initializing project", value=0, max_val=1)
        self.project = Project()
        self.project.new(self.output_file)
        # Re-open the database through QGIS (presumably so that SpatiaLite SQL
        # functions such as GeomFromText are available on this connection) and
        # make the network object use that same connection.
        self.project.conn.close()
        self.project.conn = qgis.utils.spatialite_connect(self.output_file)
        self.project.network.conn = self.project.conn
        self.project.network.create_empty_tables()

        # Add the required extra fields to the link layer
        self.emit_messages(
            message="Adding extra network data fields to database",
            value=0,
            max_val=1)
        self.additional_fields_to_layers('links', self.link_layer,
                                         self.link_fields)
        self.additional_fields_to_layers('nodes', self.node_layer,
                                         self.node_fields)

        # NOTE: the original opened a second, never used and never closed
        # connection to the output file at this point; it has been removed.
        self.transfer_layer_features("links", self.link_layer,
                                     self.link_fields)
        self.transfer_layer_features("nodes", self.node_layer,
                                     self.node_fields)

        self.emit_messages(message="Creating layer triggers",
                           value=0,
                           max_val=1)
        self.project.network.add_triggers()
        self.emit_messages(message="Spatial indices", value=0, max_val=1)
        self.project.network.add_spatial_index()
        self.ProgressText.emit("DONE")

    # Adds the non-standard fields to a layer
    def additional_fields_to_layers(self, table, layer, layer_fields):
        """Add to ``table`` every column of ``layer_fields`` it does not have.

        The column type is derived from the layer field: non-numeric fields
        become ``char``, numeric fields whose type name contains ``Int``
        become ``INTEGER`` and all other numeric fields become ``REAL``.

        Args:
            table: Name of the database table to alter.
            layer: Layer whose data provider supplies the field definitions.
            layer_fields: Column name -> attribute index mapping.

        Returns:
            Names of the non-numeric fields that were processed (including
            those whose ``ALTER TABLE`` failed).
        """
        curr = self.project.conn.cursor()
        try:
            fields = layer.dataProvider().fields()
            string_fields = []

            curr.execute(f'PRAGMA table_info({table});')
            existing_fields = {row[1].lower() for row in curr.fetchall()}

            for f in layer_fields:
                # existing_fields is lower-cased, so compare case-insensitively;
                # SQLite column names are case-insensitive as well.
                if f.lower() in existing_fields:
                    continue
                field = fields[layer_fields[f]]
                field_length = field.length()
                if not field.isNumeric():
                    field_type = "char"
                    string_fields.append(f)
                elif "Int" in field.typeName():
                    field_type = "INTEGER"
                else:
                    field_type = "REAL"

                # Identifiers cannot be bound as SQL parameters, so this
                # statement has to be assembled as text.
                sql = f"alter table {table} add column {f} {field_type}({field_length})"
                try:
                    curr.execute(sql)
                    self.project.conn.commit()
                except Exception:
                    logger.error(sql)
                    self.report.append("field " + str(f) + " could not be added")
        finally:
            curr.close()
        return string_fields

    def transfer_layer_features(self, table, layer, layer_fields):
        """Copy every feature of ``layer`` into ``table``.

        Attribute values and the geometry WKT are passed as bound parameters,
        so NULLs are stored as real NULLs and text needs no quoting. Features
        that cannot be inserted are logged and recorded in ``self.report``.
        For the ``links`` table, every mode found in the ``modes`` attribute
        that ``self.project.network.modes()`` does not contain is inserted
        into the ``modes`` table afterwards.

        Args:
            table: Destination table name (``'links'`` enables mode handling).
            layer: Source layer.
            layer_fields: Column name -> attribute index mapping; a negative
                index writes NULL into that column.
        """
        self.emit_messages(message=f"Transferring features from {table} layer",
                           value=0,
                           max_val=layer.featureCount())
        curr = self.project.conn.cursor()
        try:
            columns = list(layer_fields) + ["geometry"]
            # The SRID comes from the layer CRS (not from feature data) and is
            # the same for every feature. Taking the last ':'-separated piece
            # cannot raise; an unusable code surfaces as a per-feature failure.
            srid = str(layer.crs().authid().split(":")[-1])
            values = ["?"] * len(layer_fields) + [f"GeomFromText(?, {srid})"]
            sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({', '.join(values)})"

            modes_idx = layer_fields.get('modes', -1) if table == 'links' else -1
            all_modes = set()
            for j, f in enumerate(layer.getFeatures()):
                self.emit_messages(value=j)
                attributes = f.attributes()
                params = [
                    None if idx < 0 else self._sql_param(attributes[idx])
                    for idx in layer_fields.values()
                ]
                params.append(f.geometry().asWkt().upper())

                if modes_idx >= 0:
                    modes = attributes[modes_idx]
                    # Each item of the modes value (each character, for a
                    # string) is collected as one mode id.
                    if not self._is_null(modes):
                        all_modes.update(modes)

                try:
                    curr.execute(sql, params)
                except Exception:
                    logger.info(f'Failed inserting link {f.id()}')
                    logger.info(f'{sql} <- {params}')
                    if f.id():
                        msg = f"feature with id {f.id()} could not be added to layer {table}"
                    else:
                        msg = f"feature with no node id present. It could not be added to layer {table}"
                    self.report.append(msg)

            # We check if all modes exist
            known_modes = self.project.network.modes()
            for x in all_modes:
                if x not in known_modes:
                    curr.execute(
                        'INSERT INTO "modes" (mode_name, mode_id, description) VALUES(?, ?, ?)',
                        (f"automatic_{x}", x,
                         "Mode automatically added during project creation from layers"))
                    logger.info(f'New mode inserted during project creation {x}')
            self.project.conn.commit()
        finally:
            curr.close()

    @staticmethod
    def _is_null(value):
        """Return True when ``value`` represents a missing attribute."""
        # NOTE(review): `value == NULL` is assumed to be the QGIS way of
        # testing for NULL - confirm. The original `type(value) == NULL`
        # comparison is kept so nothing it matched is lost.
        return value is None or value == NULL or type(value) == NULL

    def _sql_param(self, value):
        """Convert a layer attribute into a value that can be bound to SQL."""
        if self._is_null(value):
            return None
        if isinstance(value, (int, float, str)):
            return value
        # Anything else (dates, etc.) is stored through its text form.
        return str(value)

    def convert_data(self, value):
        """Return ``value`` as text for a hand-written SQL statement.

        Missing values become the string ``"NULL"``; any double quote in the
        text is replaced with a single quote.
        """
        if self._is_null(value):
            return "NULL"
        return str(value).replace('"', "'")

    def emit_messages(self, message="", value=-1, max_val=-1):
        """Emit progress signals for each argument that was supplied.

        Args:
            message: Text for ``ProgressText``; skipped when empty.
            value: Value for ``ProgressValue``; skipped when negative.
            max_val: Value for ``ProgressMaxValue``; skipped when negative.
        """
        if message:
            self.ProgressText.emit(message)
        if value >= 0:
            self.ProgressValue.emit(value)
        if max_val >= 0:
            self.ProgressMaxValue.emit(max_val)
Example #6
0
class TestNetwork(TestCase):
    """Network tests run against throw-away SQLite files in the temp folder."""

    def setUp(self) -> None:
        """Create a project file and a separate SpatiaLite-enabled database."""
        self.file = os.path.join(gettempdir(),
                                 "aequilibrae_project_test.sqlite")
        self.project = Project()
        self.project.new(self.file)
        self.source = self.file
        self.file2 = os.path.join(gettempdir(),
                                  "aequilibrae_project_test2.sqlite")
        self.conn = sqlite3.connect(self.file2)
        self.conn = spatialite_connection(self.conn)
        # The test case itself is handed to Network in place of a project;
        # presumably Network only needs the ``conn`` attribute - confirm.
        self.network = Network(self)

    def tearDown(self) -> None:
        """Close both connections and delete both database files."""
        # Each file is cleaned up on its own so that a failure on the first
        # one does not leave the second connection open and its file behind.
        for owner, path in ((self.project, self.file), (self, self.file2)):
            try:
                owner.conn.close()
                os.unlink(path)
            except Exception as e:
                warn(f'Could not delete. {e.args}')

    def test_create_from_osm(self):
        """Import a small OSM area and sanity-check link and node counts.

        To avoid loading the OSM servers the check only runs for about 5% of
        the invocations, except in the 'Code coverage' GitHub workflow where
        it always runs.
        """
        thresh = 0.05
        if os.environ.get('GITHUB_WORKFLOW', 'ERROR') == 'Code coverage':
            thresh = 1.01

        if random() >= thresh:
            # Reported as skipped rather than as a pass that checked nothing.
            self.skipTest('Skipped check to not load OSM servers')

        # self.network.create_from_osm(west=153.1136245, south=-27.5095487, east=153.115, north=-27.5085, modes=["car"])
        self.project.network.create_from_osm(west=-112.185,
                                             south=36.59,
                                             east=-112.179,
                                             north=36.60)
        curr = self.project.conn.cursor()

        curr.execute("""select count(*) from links""")
        lks = curr.fetchone()[0]

        curr.execute("""select count(distinct osm_id) from links""")
        osmids = curr.fetchone()[0]

        if osmids == 0:
            # Nothing came back from the server; that is not a failure of
            # the code under test.
            warn('COULD NOT RETRIEVE DATA FROM OSM')
            return

        self.assertLess(osmids, lks, "OSM links not broken down properly")

        curr.execute("""select count(*) from nodes""")
        nds = curr.fetchone()[0]

        self.assertLessEqual(
            lks, nds,
            "We imported more links than nodes. Something wrong here")

    @staticmethod
    def _merged_keys(dicts):
        """Return the distinct keys of a list of dictionaries, in order."""
        merged = {}
        for item in dicts:
            merged.update(item)
        return list(merged)

    def test_create_empty_tables(self):
        """Every field listed in the parameters exists in the empty tables.

        One-way link fields are expected under their own name; each two-way
        field is expected twice, with ``_ab`` and ``_ba`` suffixes.
        """
        self.network.create_empty_tables()
        p = Parameters().parameters["network"]

        curr = self.conn.cursor()
        curr.execute("""PRAGMA table_info(links);""")
        # Column 1 of each PRAGMA table_info row is the column name.
        fields = [x[1] for x in curr.fetchall()]

        owf = self._merged_keys(p["links"]["fields"]["one-way"])
        twf = []
        for k in self._merged_keys(p["links"]["fields"]["two-way"]):
            twf.extend([f"{k}_ab", f"{k}_ba"])

        for f in owf + twf:
            self.assertIn(f, fields, f"Field {f} not added to links table")

        curr = self.conn.cursor()
        curr.execute("""PRAGMA table_info(nodes);""")
        nfields = [x[1] for x in curr.fetchall()]

        for f in self._merged_keys(p["nodes"]["fields"]):
            self.assertIn(f, nfields, f"Field {f} not added to nodes table")
Example #7
0
class TestNetwork(TestCase):
    """Network tests run against a private copy of the Sioux Falls project."""

    def setUp(self) -> None:
        """Copy the Sioux Falls project to a temp folder and open the copy."""
        # Make <tempdir>/temp_data visible on PATH. os.pathsep keeps this
        # correct on every platform (the original hard-coded ';'), and the
        # entry is only added once instead of on every test.
        temp_data = os.path.join(gettempdir(), 'temp_data')
        current_path = os.environ.get('PATH', '')
        if temp_data not in current_path.split(os.pathsep):
            os.environ['PATH'] = temp_data + os.pathsep + current_path

        self.proj_path = os.path.join(gettempdir(), uuid.uuid4().hex)
        copytree(siouxfalls_project, self.proj_path)
        self.siouxfalls = Project()
        self.siouxfalls.open(self.proj_path)
        # Location for the extra project created by test_create_from_osm.
        self.proj_path2 = os.path.join(gettempdir(), uuid.uuid4().hex)

    def tearDown(self) -> None:
        """Close the Sioux Falls project."""
        self.siouxfalls.close()

    def test_create_from_osm(self):
        """Import a small OSM area into a new project and sanity-check counts.

        To avoid loading the OSM servers the check only runs for about 5% of
        the invocations, except in the 'Code coverage' GitHub workflow where
        it always runs.
        """
        thresh = 0.05
        if os.environ.get('GITHUB_WORKFLOW', 'ERROR') == 'Code coverage':
            thresh = 1.01

        if random() >= thresh:
            # Reported as skipped rather than as a pass that checked nothing.
            self.skipTest('Skipped check to not load OSM servers')

        self.siouxfalls.close()
        self.project = Project()
        self.project.new(self.proj_path2)
        try:
            # self.network.create_from_osm(west=153.1136245, south=-27.5095487, east=153.115, north=-27.5085, modes=["car"])
            self.project.network.create_from_osm(west=-112.185,
                                                 south=36.59,
                                                 east=-112.179,
                                                 north=36.60)
            curr = self.project.conn.cursor()

            curr.execute("""select count(*) from links""")
            lks = curr.fetchone()[0]

            curr.execute("""select count(distinct osm_id) from links""")
            osmids = curr.fetchone()[0]

            if osmids == 0:
                warn('COULD NOT RETRIEVE DATA FROM OSM')
                return

            self.assertLess(osmids, lks, "OSM links not broken down properly")

            curr.execute("""select count(*) from nodes""")
            nds = curr.fetchone()[0]

            self.assertLessEqual(
                lks, nds,
                "We imported more links than nodes. Something wrong here")
        finally:
            # Runs on success, on the early return and on a failed check, so
            # tearDown always finds the Sioux Falls project open again.
            self.project.close()
            self.siouxfalls.open(self.proj_path)

    def test_count_centroids(self):
        """The centroid count drops by one after un-flagging node 1."""
        items = self.siouxfalls.network.count_centroids()
        self.assertEqual(24, items, 'Wrong number of centroids found')

        nodes = self.siouxfalls.network.nodes
        node = nodes.get(1)
        node.is_centroid = 0
        node.save()

        items = self.siouxfalls.network.count_centroids()
        self.assertEqual(23, items, 'Wrong number of centroids found')

    def test_count_links(self):
        """The copied project reports 76 links."""
        items = self.siouxfalls.network.count_links()
        self.assertEqual(76, items, 'Wrong number of links found')

    def test_count_nodes(self):
        """The copied project reports 24 nodes."""
        items = self.siouxfalls.network.count_nodes()
        self.assertEqual(24, items, 'Wrong number of nodes found')