class StraightRoadBuilder:
    """Factory for straight roads made of a single ``pyodrx.Line`` geometry.

    Lane layouts are produced by ``LaneBuilder``; this class picks the lane
    parameters, builds the plan view and wraps the result in an
    ``ExtendedRoad``.
    """

    def __init__(self, country=CountryCodes.US):
        self.STD_ROADMARK = pyodrx.RoadMark(pyodrx.RoadMarkType.solid,
                                            0.2,
                                            rule=pyodrx.MarkRule.no_passing)
        self.STD_START_CLOTH = 1 / 1000000000
        self.country = country
        self.laneBuilder = LaneBuilder()
        self.configuration = Configuration()
        # Prefix used in the messages of exceptions raised by this builder.
        self.name = 'StraightRoadBuilder'

    def createPVForLine(self, length):
        """Return an ``ExtendedPlanview`` containing one straight line.

        Args:
            length: length of the line geometry.
        """
        line1 = pyodrx.Line(length)
        pv = extensions.ExtendedPlanview()
        pv.add_geometry(line1)
        return pv

    def createRandom(self,
                     roadId,
                     randomState=None,
                     length=20,
                     junction=-1,
                     lane_offset=None,
                     maxLanePerSide=2,
                     minLanePerSide=0,
                     turns=False,
                     merges=False,
                     medianType=None,
                     medianWidth=3,
                     skipEndpoint=None,
                     force3Section=False):
        """Create a straight road with a randomly chosen lane configuration.

        The lane count of each side is drawn with ``np.random.choice`` from
        ``minLanePerSide..maxLanePerSide`` (inclusive). Draws in which both
        sides are empty are rejected and redrawn.

        Args:
            roadId: id of the new road.
            randomState: optional state for ``np.random.set_state``. It is
                applied once, before the first draw.
            length: road length.
            junction: forwarded to ``ExtendedRoad`` as ``road_type``.
            lane_offset: forwarded to the lane builder as ``lane_offset``;
                defaults to the ``default_lane_width`` configuration value.
            maxLanePerSide: inclusive upper bound of lanes per side (>= 1).
            minLanePerSide: inclusive lower bound of lanes per side.
            turns: if True, each turn-lane count is drawn from {0, 1}.
            merges: if True and ``turns`` is False, each merge-lane count is
                drawn from {0, 1}.
            medianType: None for no median; ``'partial'`` for a median on two
                of three lane sections (requires ``skipEndpoint``); any other
                value adds medians to all sections.
            medianWidth: ``width`` passed to the median-island builder.
            skipEndpoint: road end whose lane section gets no median (only
                used for ``'partial'`` medians).
            force3Section: forwarded to ``create`` when ``medianType`` is None.

        Returns:
            The new ``ExtendedRoad``.

        Raises:
            Exception: if ``maxLanePerSide`` is less than 1.
            ValueError: if ``minLanePerSide`` is greater than ``maxLanePerSide``.
        """
        if randomState is not None:
            # Apply the state only once: re-applying it on every retry would
            # reproduce the same rejected draw forever.
            np.random.set_state(randomState)

        if lane_offset is None:
            lane_offset = self.configuration.get("default_lane_width")

        if maxLanePerSide < 1:
            raise Exception(
                f"{self.name}: createRandom: maxLanePerSide cannot be less than 1"
            )
        if minLanePerSide > maxLanePerSide:
            raise ValueError(
                f"{self.name}: createRandom: minLanePerSide cannot be greater than maxLanePerSide"
            )

        laneRange = np.arange(minLanePerSide, maxLanePerSide + 1)
        # Reject layouts without any lane. laneRange contains
        # maxLanePerSide >= 1, so a non-empty layout is always reachable.
        while True:
            n_lanes_left = np.random.choice(laneRange)
            n_lanes_right = np.random.choice(laneRange)
            if n_lanes_left != 0 or n_lanes_right != 0:
                break

        numLeftTurnsOnLeft = 0
        numRightTurnsOnRight = 0
        numLeftMergeOnLeft = 0
        numRightMergeOnRight = 0
        numberOfLeftTurnLanesOnRight = 0
        numberOfRightTurnLanesOnLeft = 0
        # Drawn unconditionally, even when `turns` is False.
        mergeLaneOnTheOppositeSideForInternalTurn = np.random.choice(
            [True, False])
        if turns:
            numLeftTurnsOnLeft = np.random.choice([0, 1])
            numRightTurnsOnRight = np.random.choice([0, 1])
            numberOfLeftTurnLanesOnRight = np.random.choice([0, 1])
            numberOfRightTurnLanesOnLeft = np.random.choice([0, 1])
        elif merges:
            numLeftMergeOnLeft = np.random.choice([0, 1])
            numRightMergeOnRight = np.random.choice([0, 1])

        # Arguments shared by both construction paths below.
        laneArgs = dict(
            n_lanes_left=n_lanes_left,
            n_lanes_right=n_lanes_right,
            length=length,
            junction=junction,
            lane_offset=lane_offset,
            laneSides=LaneSides.BOTH,
            numLeftTurnsOnLeft=numLeftTurnsOnLeft,
            numRightTurnsOnRight=numRightTurnsOnRight,
            numLeftMergeOnLeft=numLeftMergeOnLeft,
            numRightMergeOnRight=numRightMergeOnRight,
            numberOfLeftTurnLanesOnRight=numberOfLeftTurnLanesOnRight,
            numberOfRightTurnLanesOnLeft=numberOfRightTurnLanesOnLeft,
            mergeLaneOnTheOppositeSideForInternalTurn=
            mergeLaneOnTheOppositeSideForInternalTurn,
        )

        if medianType is None:
            return self.create(roadId, force3Section=force3Section, **laneArgs)
        return self.createWithMedianRestrictedLane(roadId,
                                                   medianType=medianType,
                                                   medianWidth=medianWidth,
                                                   skipEndpoint=skipEndpoint,
                                                   **laneArgs)

    def create(self,
               roadId,
               n_lanes_left=1,
               n_lanes_right=1,
               length=20,
               junction=-1,
               lane_offset=3,
               laneSides=LaneSides.BOTH,
               numLeftTurnsOnLeft=0,
               numRightTurnsOnRight=0,
               numLeftMergeOnLeft=0,
               numRightMergeOnRight=0,
               numberOfLeftTurnLanesOnRight=0,
               numberOfRightTurnLanesOnLeft=0,
               mergeLaneOnTheOppositeSideForInternalTurn=True,
               force3Section=False):
        """Create a straight road with an explicit lane configuration.

        The lane sections come from ``LaneBuilder.getLanes``; every lane
        argument is forwarded to it unchanged.

        Args:
            roadId: id of the new road.
            n_lanes_left: number of lanes on the left side.
            n_lanes_right: number of lanes on the right side.
            length: road length (also passed as ``roadLength``).
            junction: forwarded to ``ExtendedRoad`` as ``road_type``.
            lane_offset: forwarded to ``getLanes`` as ``lane_offset``.
            laneSides: anything other than ``LaneSides.BOTH`` requests a
                single-sided road (``singleSide=True``).
            force3Section: forwarded to ``getLanes``.

        Returns:
            The new ``ExtendedRoad``.
        """
        pv = self.createPVForLine(length)

        singleSide = laneSides != LaneSides.BOTH
        laneSections = self.laneBuilder.getLanes(
            n_lanes_left,
            n_lanes_right,
            lane_offset=lane_offset,
            singleSide=singleSide,
            roadLength=length,
            numLeftTurnsOnLeft=numLeftTurnsOnLeft,
            numRightTurnsOnRight=numRightTurnsOnRight,
            numLeftMergeOnLeft=numLeftMergeOnLeft,
            numRightMergeOnRight=numRightMergeOnRight,
            numberOfLeftTurnLanesOnRight=numberOfLeftTurnLanesOnRight,
            numberOfRightTurnLanesOnLeft=numberOfRightTurnLanesOnLeft,
            mergeLaneOnTheOppositeSideForInternalTurn=
            mergeLaneOnTheOppositeSideForInternalTurn,
            force3Section=force3Section)

        return ExtendedRoad(roadId, pv, laneSections, road_type=junction)

    def createWithMedianRestrictedLane(
            self,
            roadId,
            n_lanes_left=1,
            n_lanes_right=1,
            length=20,
            junction=-1,
            lane_offset=3,
            laneSides=LaneSides.BOTH,
            numLeftTurnsOnLeft=0,
            numRightTurnsOnRight=0,
            numLeftMergeOnLeft=0,
            numRightMergeOnRight=0,
            numberOfLeftTurnLanesOnRight=0,
            numberOfRightTurnLanesOnLeft=0,
            mergeLaneOnTheOppositeSideForInternalTurn=True,
            medianType='partial',
            medianWidth=3,
            skipEndpoint=None):
        """Create a straight road (always with ``force3Section=True``) plus medians.

        Args:
            medianType: ``'partial'`` adds median islands to two of the three
                lane sections, skipping the one at ``skipEndpoint``; any other
                value adds median islands to all sections.
            medianWidth: ``width`` passed to the median-island builder.
            skipEndpoint: required when ``medianType`` is ``'partial'``.
            Other arguments: as for ``create``.

        Returns:
            The new ``ExtendedRoad``.

        Raises:
            Exception: if ``medianType`` is ``'partial'`` and ``skipEndpoint``
                is None.
        """
        # Validate before doing any construction work.
        if medianType == 'partial' and skipEndpoint is None:
            raise Exception(
                f"{self.name}: createWithMedianRestrictedLane skipEndpoint cannot be None for partial median lanes."
            )

        road = self.create(
            roadId,
            n_lanes_left=n_lanes_left,
            n_lanes_right=n_lanes_right,
            length=length,
            junction=junction,
            lane_offset=lane_offset,
            laneSides=laneSides,
            numLeftTurnsOnLeft=numLeftTurnsOnLeft,
            numRightTurnsOnRight=numRightTurnsOnRight,
            numLeftMergeOnLeft=numLeftMergeOnLeft,
            numRightMergeOnRight=numRightMergeOnRight,
            numberOfLeftTurnLanesOnRight=numberOfLeftTurnLanesOnRight,
            numberOfRightTurnLanesOnLeft=numberOfRightTurnLanesOnLeft,
            mergeLaneOnTheOppositeSideForInternalTurn=
            mergeLaneOnTheOppositeSideForInternalTurn,
            force3Section=True)

        if medianType == 'partial':
            self.laneBuilder.addMedianIslandsTo2Of3Sections(
                road,
                roadLength=length,
                skipEndpoint=skipEndpoint,
                width=medianWidth)
        else:
            self.laneBuilder.addMedianIslandsToAllSections(road,
                                                           width=medianWidth)

        return road

    def createWithRightTurnLanesOnLeft(
            self,
            roadId,
            length=100,
            junction=-1,
            n_lanes=1,
            lane_offset=3,
            laneSides=LaneSides.BOTH,
            isLeftTurnLane=False,
            isRightTurnLane=False,
            isLeftMergeLane=False,
            isRightMergeLane=False,
            numberOfRightTurnLanesOnLeft=1,
            mergeLaneOnTheOppositeSideForInternalTurn=True):
        """Create a straight road with right-turn lanes on the left side.

        Lane sections come from
        ``LaneBuilder.getStandardLanesWithInternalTurns``; all lane arguments
        are forwarded unchanged.

        Args:
            roadId: id of the new road.
            length: road length (also passed as ``roadLength``).
            junction: forwarded to ``ExtendedRoad`` as ``road_type``.
            n_lanes, lane_offset, laneSides: positional arguments of the lane
                builder.
            numberOfRightTurnLanesOnLeft: number of right-turn lanes requested
                on the left side.

        Returns:
            The new ``ExtendedRoad``.
        """
        pv = self.createPVForLine(length)

        laneSections = self.laneBuilder.getStandardLanesWithInternalTurns(
            n_lanes,
            lane_offset,
            laneSides,
            roadLength=length,
            isLeftTurnLane=isLeftTurnLane,
            isRightTurnLane=isRightTurnLane,
            isLeftMergeLane=isLeftMergeLane,
            isRightMergeLane=isRightMergeLane,
            numberOfRightTurnLanesOnLeft=numberOfRightTurnLanesOnLeft,
            mergeLaneOnTheOppositeSideForInternalTurn=
            mergeLaneOnTheOppositeSideForInternalTurn)

        return ExtendedRoad(roadId, pv, laneSections, road_type=junction)

    def createWithLeftTurnLanesOnRight(
            self,
            roadId,
            length=100,
            junction=-1,
            n_lanes=1,
            lane_offset=3,
            laneSides=LaneSides.BOTH,
            isLeftTurnLane=False,
            isRightTurnLane=False,
            isLeftMergeLane=False,
            isRightMergeLane=False,
            numberOfLeftTurnLanesOnRight=1,
            mergeLaneOnTheOppositeSideForInternalTurn=True):
        """Create a straight road with left-turn lanes on the right side.

        Requests ``numberOfLeftTurnLanesOnRight`` left-turn lanes on the right
        side of the center line. According to the original author, an equal
        number of merge lanes is created on the left side as well (done by
        ``LaneBuilder.getStandardLanesWithInternalTurns``; not verified here).

        Args:
            roadId: id of the new road.
            length: road length (also passed as ``roadLength``).
            junction: forwarded to ``ExtendedRoad`` as ``road_type``.
            n_lanes, lane_offset, laneSides: positional arguments of the lane
                builder.
            isLeftTurnLane, isRightTurnLane, isLeftMergeLane, isRightMergeLane:
                forwarded to the lane builder.
            numberOfLeftTurnLanesOnRight: number of left-turn lanes requested
                on the right side.
            mergeLaneOnTheOppositeSideForInternalTurn: forwarded to the lane
                builder.

        Returns:
            The new ``ExtendedRoad``.
        """
        pv = self.createPVForLine(length)

        laneSections = self.laneBuilder.getStandardLanesWithInternalTurns(
            n_lanes,
            lane_offset,
            laneSides,
            roadLength=length,
            isLeftTurnLane=isLeftTurnLane,
            isRightTurnLane=isRightTurnLane,
            isLeftMergeLane=isLeftMergeLane,
            isRightMergeLane=isRightMergeLane,
            numberOfLeftTurnLanesOnRight=numberOfLeftTurnLanesOnRight,
            mergeLaneOnTheOppositeSideForInternalTurn=
            mergeLaneOnTheOppositeSideForInternalTurn)

        return ExtendedRoad(roadId, pv, laneSections, road_type=junction)

    def createWithDifferentLanes(self,
                                 roadId,
                                 length=100,
                                 junction=-1,
                                 n_lanes_left=1,
                                 n_lanes_right=1,
                                 lane_offset=3,
                                 force3Section=False):
        """Shortcut for ``create`` with independent left/right lane counts.

        Note that the default ``length`` here is 100 (``create`` uses 20).
        """
        return self.create(roadId,
                           n_lanes_left=n_lanes_left,
                           n_lanes_right=n_lanes_right,
                           length=length,
                           junction=junction,
                           lane_offset=lane_offset,
                           force3Section=force3Section)
class test_LaneBuilder(unittest.TestCase):
    """Tests for LaneBuilder and StraightRoadBuilder lane layouts.

    Most tests open the generated road in esmini (``extensions.view_road``)
    and write it as an .xodr file under ``output/``.
    """

    def setUp(self):
        self.configuration = Configuration()
        self.esminiPath = self.configuration.get("esminipath")
        self.roadBuilder = RoadBuilder()
        self.laneBuilder = LaneBuilder()
        self.laneLinker = LaneLinker()
        self.straightRoadBuilder = StraightRoadBuilder()

    def _viewAndWrite(self, odr, xmlPath):
        """Show ``odr`` in esmini, then write it to ``xmlPath``."""
        extensions.view_road(odr, os.path.join('..', self.esminiPath))
        odr.write_xml(xmlPath)

    def _assertLaneCounts(self, road, expected):
        """Assert the (left, right) lane counts of the leading lane sections.

        Args:
            road: road whose ``lanes.lanesections`` are inspected.
            expected: sequence of ``(nLeft, nRight)`` tuples, one per lane
                section starting at section 0.
        """
        sections = road.lanes.lanesections
        for index, (nLeft, nRight) in enumerate(expected):
            self.assertEqual(len(sections[index].leftlanes), nLeft,
                             f"left lanes in lane section {index}")
            self.assertEqual(len(sections[index].rightlanes), nRight,
                             f"right lanes in lane section {index}")

    def test_RightLane(self):
        """Add a US right-turn lane to a single straight road."""
        roads = [pyodrx.create_straight_road(0, 10)]

        odrName = "test_connectionRoad"
        odr = extensions.createOdrByPredecessor(odrName, roads, [])

        self.laneBuilder.addRightTurnLaneUS(roads[0], 3)

        odr.resetAndReadjust(byPredecessor=True)

        self._viewAndWrite(odr, "output/test-RightLane.xodr")

    def test_DifferentLaneConfigurations(self):
        """Link roads with 1/1, 2/2 and 1/2 lanes through a connection road."""
        roads = []
        roads.append(
            self.straightRoadBuilder.createWithDifferentLanes(0,
                                                              10,
                                                              n_lanes_left=1,
                                                              n_lanes_right=1))
        connectionRoad = self.straightRoadBuilder.createWithDifferentLanes(
            1, 10, n_lanes_left=2, n_lanes_right=2)
        roads.append(connectionRoad)
        roads.append(
            self.straightRoadBuilder.createWithDifferentLanes(2,
                                                              10,
                                                              n_lanes_left=1,
                                                              n_lanes_right=2))

        roads[0].addExtendedSuccessor(roads[1], 0, pyodrx.ContactPoint.start)

        roads[1].addExtendedPredecessor(roads[0], 0, pyodrx.ContactPoint.end)
        roads[1].addExtendedSuccessor(roads[2], 0, pyodrx.ContactPoint.start)

        roads[2].addExtendedPredecessor(roads[1], 0, pyodrx.ContactPoint.end)

        self.laneBuilder.createLanesForConnectionRoad(connectionRoad, roads[0],
                                                      roads[2])

        odrName = "test_DifferentLaneConfigurations"
        odr = extensions.createOdrByPredecessor(odrName, roads, [])

        self._viewAndWrite(odr, "output/test_DifferentLaneConfigurations.xodr")

    def test_addMedianIslandsToAllSections(self):
        """Add median islands to every lane section of a 1/1-lane road."""
        roads = [
            self.straightRoadBuilder.createWithDifferentLanes(0,
                                                              10,
                                                              n_lanes_left=1,
                                                              n_lanes_right=1)
        ]
        self.laneBuilder.addMedianIslandsToAllSections(
            roads[0], self.configuration.get('default_lane_width'))

        odrName = "test_addMedianIslandsToAllSections"
        odr = extensions.createOdrByPredecessor(odrName, roads, [])

        self._viewAndWrite(odr, "output/test_addMedianIslandsToAllSections.xodr")

    def test_addMedianIslandsTo3Sections(self):
        """addMedianIslandsTo2Of3Sections adds a lane on each side to every
        section except the one at ``skipEndpoint``."""
        road = self.straightRoadBuilder.create(1,
                                               n_lanes_left=1,
                                               n_lanes_right=1,
                                               length=20,
                                               force3Section=False)

        # Expected to fail without force3Section (presumably because the road
        # then lacks three lane sections). assertRaises actually fails the
        # test when nothing is raised, unlike the previous bare `except:`,
        # which also swallowed the `assert False` AssertionError.
        with self.assertRaises(Exception):
            self.laneBuilder.addMedianIslandsTo2Of3Sections(
                road, 20, skipEndpoint=pyodrx.ContactPoint.start, width=3)

        road = self.straightRoadBuilder.create(1,
                                               n_lanes_left=1,
                                               n_lanes_right=1,
                                               length=20,
                                               force3Section=True)
        self.laneBuilder.addMedianIslandsTo2Of3Sections(
            road, 20, skipEndpoint=pyodrx.ContactPoint.start, width=3)
        self._assertLaneCounts(road, [(1, 1), (2, 2), (2, 2)])

        road = self.straightRoadBuilder.create(1,
                                               n_lanes_left=1,
                                               n_lanes_right=1,
                                               length=20,
                                               force3Section=True)
        self.laneBuilder.addMedianIslandsTo2Of3Sections(
            road, 20, skipEndpoint=pyodrx.ContactPoint.end, width=3)
        self._assertLaneCounts(road, [(2, 2), (2, 2), (1, 1)])

        odrName = "test_addMedianIslandsTo3Sections"
        odr = extensions.createOdrByPredecessor(odrName, [road], [])

        self._viewAndWrite(odr, "output/test_addMedianIslandsTo3Sections.xodr")