예제 #1
0
def test_query_check_on_pumpstation(session):
    """A pumpstation whose lower_stop_level is not above the bottom_level of
    the manhole on its start connection node is reported as invalid."""
    high_node = factories.ConnectionNodeFactory()
    low_node = factories.ConnectionNodeFactory()
    factories.ManholeFactory(connection_node=high_node, bottom_level=1.0)
    factories.ManholeFactory(connection_node=low_node, bottom_level=-1.0)
    # 0.0 <= 1.0: this pumpstation violates the rule.
    bad_pump = factories.PumpstationFactory(
        connection_node_start=high_node, lower_stop_level=0.0
    )
    # 2.0 > -1.0: this pumpstation satisfies the rule.
    factories.PumpstationFactory(
        connection_node_start=low_node, lower_stop_level=2.0
    )

    pump = models.Pumpstation
    node = models.ConnectionNode
    manhole = models.Manhole
    too_low = Query(pump)
    too_low = too_low.join(node, pump.connection_node_start_id == node.id)
    too_low = too_low.join(manhole, manhole.connection_node_id == node.id)
    too_low = too_low.filter(pump.lower_stop_level <= manhole.bottom_level)

    check = QueryCheck(
        column=pump.lower_stop_level,
        invalid=too_low,
        message=(
            "Pumpstation lower_stop_level should be higher than Manhole "
            "bottom_level"
        ),
    )

    reported = check.get_invalid(session)
    assert [row.id for row in reported] == [bad_pump.id]
예제 #2
0
def test_type_check_varchar(session):
    """Neither a string nor an integer written to Manhole.code (a varchar
    column, judging by the test name) is flagged by TypeCheck."""
    backend = session.bind.name
    if backend == "postgresql":
        pytest.skip("type checks not working on postgres")
    for code_value in ("abc", 123):
        factories.ManholeFactory(code=code_value)

    offending = TypeCheck(models.Manhole.code).get_invalid(session)

    assert len(offending) == 0
예제 #3
0
def test_type_check(session):
    """Plain integer values in Manhole.zoom_category produce no TypeCheck
    violations."""
    backend = session.bind.name
    if backend == "postgresql":
        pytest.skip("type checks not working on postgres")
    for zoom in (123, 456):
        factories.ManholeFactory(zoom_category=zoom)

    offending = TypeCheck(models.Manhole.zoom_category).get_invalid(session)

    assert len(offending) == 0
예제 #4
0
def test_type_check_integer(session):
    """In Manhole.zoom_category (an integer column, judging by the test name),
    a string and a float are flagged while an int and NULL are not."""
    if session.bind.name == "postgresql":
        pytest.skip("type checks not working on postgres")
    factories.ManholeFactory(zoom_category=123)
    factories.ManholeFactory(zoom_category=None)
    as_text = factories.ManholeFactory(zoom_category="abc")
    as_float = factories.ManholeFactory(zoom_category=1.23)

    flagged = TypeCheck(models.Manhole.zoom_category).get_invalid(session)

    assert len(flagged) == 2
    flagged_ids = {row.id for row in flagged}
    assert {as_text.id, as_float.id} <= flagged_ids
예제 #5
0
def test_type_check_float_can_store_integer(session):
    """TypeCheck on the float column Manhole.surface_level accepts integer
    values alongside floats and NULL.

    The check must run on the same column the rows below populate
    (surface_level); checking any other column would not exercise the
    float/integer case this test is named after.
    """
    if session.bind.name == "postgresql":
        pytest.skip("type checks not working on postgres")
    factories.ManholeFactory(surface_level=1.3)
    factories.ManholeFactory(surface_level=None)
    factories.ManholeFactory(surface_level=1)  # integer in a float column

    type_check = TypeCheck(models.Manhole.surface_level)
    invalid_rows = type_check.get_invalid(session)
    valid_rows = type_check.get_valid(session)

    # An integer is an acceptable value for a float column, and NULL is not
    # treated as a type error, so no row is invalid.
    assert len(invalid_rows) == 0
    # NOTE(review): assumes get_valid() returns every non-invalid row,
    # including the NULL one.
    assert len(valid_rows) == 3
예제 #6
0
def test_threedi_db_and_factories(threedi_db):
    """A manhole created through the factories is visible through the
    session returned by threedi_db, so both share one session object."""
    db_session = threedi_db.get_session()
    factories.ManholeFactory()
    manhole_count = db_session.query(models.Manhole).count()
    assert manhole_count == 1
예제 #7
0
def test_fk_check_null_fk(session):
    """A NULL manhole_indicator is not reported by the ForeignKeyCheck, and
    neither are rows pointing at an existing ConnectionNode."""
    node = factories.ConnectionNodeFactory()
    factories.ManholeFactory.create_batch(5, manhole_indicator=node.id)
    factories.ManholeFactory(manhole_indicator=None)

    check = ForeignKeyCheck(
        models.ConnectionNode.id, models.Manhole.manhole_indicator
    )

    assert len(check.get_invalid(session)) == 0
예제 #8
0
def test_fk_check_missing_fk(session):
    """The only row reported by the ForeignKeyCheck is the manhole whose
    manhole_indicator (-1, presumably not an existing ConnectionNode id)
    has no matching ConnectionNode."""
    node = factories.ConnectionNodeFactory()
    factories.ManholeFactory.create_batch(5, manhole_indicator=node.id)
    dangling = factories.ManholeFactory(manhole_indicator=-1)

    check = ForeignKeyCheck(
        models.ConnectionNode.id, models.Manhole.manhole_indicator
    )
    reported_ids = [row.id for row in check.get_invalid(session)]

    assert reported_ids == [dangling.id]
예제 #9
0
def test_conditional_check_storage_area(session):
    """Only connection nodes that carry a manhole must have a positive
    storage_area; nodes without a manhole are never reported."""
    # Nodes without a manhole: ignored regardless of storage_area.
    factories.ConnectionNodeFactory(storage_area=5)
    factories.ConnectionNodeFactory(storage_area=-3)
    # Nodes with a manhole: the non-positive one is invalid.
    good_node = factories.ConnectionNodeFactory(storage_area=4)
    bad_node = factories.ConnectionNodeFactory(storage_area=-5)
    for node in (good_node, bad_node):
        factories.ManholeFactory(connection_node=node)

    node_model = models.ConnectionNode
    non_positive_with_manhole = (
        Query(node_model)
        .join(models.Manhole)
        .filter(node_model.storage_area <= 0)
    )
    query_check = QueryCheck(
        column=node_model.storage_area,
        invalid=non_positive_with_manhole,
        message="",
    )

    reported = query_check.get_invalid(session)
    assert [row.id for row in reported] == [bad_node.id]
예제 #10
0
def test_unique_check_duplicate_value(session):
    """UniqueCheck reports both manholes that share a zoom_category value,
    and only those two."""
    originals = factories.ManholeFactory.create_batch(
        5, zoom_category=factory.Sequence(lambda n: n)
    )
    first = originals[0]
    clone = factories.ManholeFactory(zoom_category=first.zoom_category)

    duplicates = UniqueCheck(models.Manhole.zoom_category).get_invalid(session)

    assert len(duplicates) == 2
    duplicate_ids = {row.id for row in duplicates}
    assert {first.id, clone.id} <= duplicate_ids
예제 #11
0
def test_query_check_manhole_drain_level_calc_type_2(session):
    """Check the drain_level rules for CONNECTED manholes.

    manhole.drain_level may be NULL in general. When
    manhole.calculation_type == 2 (CONNECTED), however, drain_level must not
    be NULL and must be >= manhole.bottom_level. There is one QueryCheck per
    rule, and each must report exactly the manhole that breaks it.
    """
    # Factory-default calculation_type. The assertions below require it not
    # to be CONNECTED, so these two rows are valid.
    factories.ManholeFactory(drain_level=None)
    factories.ManholeFactory(drain_level=1)
    # drain_level cannot be NULL when calculation_type is CONNECTED.
    m3_error = factories.ManholeFactory(
        drain_level=None,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # drain_level must be >= bottom_level when calculation_type is CONNECTED;
    # here drain_level (1) < bottom_level (2).
    m4_error = factories.ManholeFactory(
        drain_level=1,
        bottom_level=2,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # CONNECTED with drain_level >= bottom_level: valid.
    factories.ManholeFactory(
        drain_level=1,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.CONNECTED,
    )
    # EMBEDDED: the NULL drain_level rule only applies to CONNECTED.
    factories.ManholeFactory(
        drain_level=None,
        bottom_level=0,
        calculation_type=constants.CalculationTypeNode.EMBEDDED,
    )

    query_drn_lvl_st_bttm_lvl = Query(models.Manhole).filter(
        models.Manhole.drain_level < models.Manhole.bottom_level,
        models.Manhole.calculation_type ==
        constants.CalculationTypeNode.CONNECTED,
    )
    query_invalid_not_null = Query(models.Manhole).filter(
        models.Manhole.calculation_type ==
        constants.CalculationTypeNode.CONNECTED,
        # .is_(None) renders "IS NULL" explicitly. "== None" does the same in
        # SQLAlchemy but trips linters (E711) and reads like a Python bug.
        models.Manhole.drain_level.is_(None),
    )
    check_drn_lvl_gt_bttm_lvl = QueryCheck(
        column=models.Manhole.bottom_level,
        invalid=query_drn_lvl_st_bttm_lvl,
        message="Manhole.drain_level >= Manhole.bottom_level when "
        "Manhole.calculation_type is CONNECTED",
    )
    check_invalid_not_null = QueryCheck(
        column=models.Manhole.drain_level,
        invalid=query_invalid_not_null,
        message=
        "Manhole.drain_level cannot be null when Manhole.calculation_type is "
        "CONNECTED",
    )
    errors1 = check_drn_lvl_gt_bttm_lvl.get_invalid(session)
    errors2 = check_invalid_not_null.get_invalid(session)
    assert len(errors1) == 1
    assert len(errors2) == 1
    assert m3_error.id == errors2[0].id
    assert m4_error.id == errors1[0].id