class Agent:
    def __init__(self, carX, carY, window, best=False, scale: float = 1.48):
        """Create an agent that controls a fresh ``Car`` with its own network.

        Args:
            carX, carY: starting position forwarded to ``Car``.
            window: object exposing ``x``/``y`` extents; ``update`` marks the
                agent dead when its car gets within 2 units of that border.
            best: stored as ``bestCar`` and forwarded to ``Car.draw``.
            scale: scale factor forwarded to ``Car`` (also kept on the agent).
        """
        self.car = Car(carX, carY, scale)
        self.scale = scale
        self.window = window
        self.bestCar = best

        # Life-cycle / scoring state; ``dead`` starts out mirroring the car.
        self.dead = self.car.dead
        self.fitness = 0
        self.index = 0

        # Step budget. ``update`` grants +1000 once when currentCheckpoint
        # first exceeds 1 (``reset``) and once when it first exceeds 2
        # (``reset2``).
        self.step = 0
        self.maxStep = 500
        self.reset = False
        self.reset2 = False

        self.nn = NeuralNetwork()

    def draw(self, batch, foreground, background, vertices, show):
        car = self.car.draw(batch, foreground, self.bestCar)
        # intersectEyes = self.car.intersectEyes(batch, vertices, background)
        # if show or not self.dead:
        # eyes = self.car.eyes(batch, background)
        # hitbox = self.car.hitbox(batch, background)
        # return car, intersectEyes, eyes, hitbox
        # return car, intersectEyes, eyes
        # return car, intersectEyes
        return car

    def generateHitbox(self):
        hitbox = self.car.generateHitbox()
        return hitbox

    def move(self, dt, vertices):
        if self.step < self.maxStep:
            inputnn = []
            self.car.mathIntersect(vertices)
            inputnn = self.car.observe()
            if len(inputnn) > 7:
                instruction = self.nn.feedforward(inputnn)
            else:
                instruction = 3
                print(f"Input: {inputnn}")
            self.step += 1

            self.car.updateWithInstruction(dt, instruction)
        else:
            self.dead = True

    def update(self, dt, vertices):
        if not self.dead:
            if self.car.currentCheckpoint > 1 and self.reset == False:
                self.maxStep += 1000
                self.reset = True
            if self.car.currentCheckpoint > 2 and self.reset2 == False:
                self.maxStep += 1000
                self.reset2 = True
            self.move(dt, vertices)

            pos = self.car.position
            if (pos.x < 2 or pos.y < 2 or pos.x > self.window.x - 2
                    or pos.y > self.window.y - 2):
                self.dead = True
        else:
            self.car.dead = self.dead

    def calcFitness(self, skeletonLines, checkpoints, blindSpot, blindIndex):
        """Set ``self.fitness`` from how far the car has progressed.

        Progress is measured along ``skeletonLines`` (presumably the track's
        centre-line segments, in driving order). It has three parts:
        - the summed length of every skeleton line before the car's nearest
          line;
        - the car's offset along that nearest line from its start point;
        - the number of completed laps.

        Each part is scaled by 100, the parts are summed and squared, and the
        result is divided by ``self.step``. An agent that covers the same
        distance in fewer steps therefore scores higher.

        Args:
            skeletonLines: ordered sequence of line objects that provide
                ``distance(point)`` (may return None), ``getEndPoints()`` and
                ``n`` (presumably the line normal; TODO confirm).
            checkpoints: checkpoints per lap. Reaching this count increments
                ``car.currentLap`` and resets ``car.currentCheckpoint`` to 0.
            blindSpot: fallback points, used when no skeleton line is within
                80 units of the car.
            blindIndex: ``blindIndex[i]`` is the skeleton-line index to use
                when ``blindSpot[i]`` is the nearest fallback point.

        Side effects:
            May update ``car.currentLap`` and ``car.currentCheckpoint``.
            Sets ``self.fitness`` and ``self.index``.

        NOTE(review):
            ``self.step == 0`` raises ZeroDivisionError.
            An empty ``blindSpot`` leaves ``minSpot`` as None, so
            ``blindSpot.index`` raises ValueError.
        """
        # Only score cars that have reached a checkpoint or finished a lap.
        if self.car.currentCheckpoint > 0 or self.car.currentLap > 0:
            # Wrap the checkpoint counter into a completed lap.
            if self.car.currentCheckpoint >= checkpoints:
                self.car.currentLap += 1
                self.car.currentCheckpoint = 0
            # Find the skeleton line nearest the car; a None distance is
            # treated as "very far" (99999).
            minimum = 100000
            minLine = None
            for line in skeletonLines:
                distance = d if (d := line.distance(
                    self.car.position)) != None else 99999
                if distance < minimum:
                    minimum = distance
                    minLine = line
            if minimum > 80:
                # No skeleton line is close enough: use the nearest blind
                # spot and map it to its skeleton-line index.
                minDistance = 100000
                minSpot = None
                for spot in blindSpot:
                    # presumably a vector difference whose abs() is its
                    # length — TODO confirm
                    distance = abs(self.car.position - spot)
                    if distance < minDistance:
                        minDistance = distance
                        minSpot = spot
                indexSpot = blindSpot.index(minSpot)
                index = blindIndex[indexSpot]
                minLine = skeletonLines[index]
            else:
                index = skeletonLines.index(minLine)

            # Total length of all skeleton lines before the nearest one.
            linesToIndex = skeletonLines[:index]
            distanceToIndex = 0
            for line in linesToIndex:
                startPointLine, endPointLine = line.getEndPoints()
                distanceToIndex += abs(endPointLine - startPointLine)

            # Project the car onto the nearest line along ``minLine.n``. The
            # along-line offset is the distance from the line's start point
            # (A) to the projection (B).
            lineToOutside = linline.fromVector(minLine.n,
                                               self.car.position)  #BC
            intersection = lineToOutside.intersect(minLine)  #B
            startPoint, _ = minLine.getEndPoints()  #A
            # Degenerate geometry contributes no along-line progress.
            if startPoint.x == None:
                distanceLine = 0
            elif startPoint.y == None:
                distanceLine = 0
            elif intersection == None:
                distanceLine = 0
            else:
                distanceLine = intersection - startPoint  #AB

            # Zero fitness for a car whose nearest line index is past 10
            # while it has passed fewer than 3 checkpoints. Presumably this
            # catches cars that drove backwards past the start; confirm.
            if index > 10 and self.car.currentCheckpoint < 3:
                self.fitness = 0
            else:
                self.fitness = ((self.car.currentLap * 100) +
                                (abs(distanceLine) * 100) +
                                distanceToIndex * 100)**2 / self.step
            self.index = index
        else:
            # NOTE(review): this branch has no body in the visible source.
            # The next line starts a new class, so as shown this is a
            # syntax error. TODO restore the missing statement(s).
class circuitEnv(py_environment.PyEnvironment):
    """tf_agents environment in which a single ``Car`` drives a circuit.

    Action:
        Scalar int32 in ``[0, 3]``, forwarded to
        ``Car.updateWithInstruction``.
    Observation:
        float32 vector of shape ``(8,)`` (minimum 0), built from
        ``Car.observe()``.
    Reward:
        -500 with episode termination when ``circuit.collidedWithCar``
        reports a hitbox collision.
        ``300 * discount ** steps`` when ``circuit.carCollidedWithCheckpoint``
        fires, so checkpoints reached sooner pay more.
        0 otherwise.
    """

    def __init__(self) -> None:
        # Initialise the PyEnvironment base class. In tf_agents it sets the
        # current time step read by the public ``step()`` and the auto-reset
        # flag. Without this call, ``step()`` before ``reset()`` raises
        # AttributeError instead of resetting the environment.
        super().__init__()

        self._action_spec = array_spec.BoundedArraySpec(shape=(),
                                                        dtype=np.int32,
                                                        minimum=0,
                                                        maximum=3,
                                                        name="action")
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(8, ), dtype=np.float32, minimum=0, name="observation")

        self.circuit = circuit.fromJSON("circuits/BONK_CIRCUIT.json")
        self.agent = Car(self.circuit.startingPoint.x,
                         self.circuit.startingPoint.y)
        self._episode_ended = False
        # Per-step decay applied to the checkpoint reward.
        self.discount = 0.9925
        # Steps taken in the current episode.
        self.stepCountingCounter = 0

        # Rendering window, created lazily on the first render() call.
        self.viewer = None

    def action_spec(self):
        """Return the bounded int32 action spec (values 0..3)."""
        return self._action_spec

    def observation_spec(self):
        """Return the bounded float32 observation spec of shape (8,)."""
        return self._observation_spec

    def _reset(self):
        """Reset the car and circuit, then return the first ``TimeStep``."""
        self.agent.reset()
        self.circuit.reset()

        # Apply one update with a ``None`` instruction and recompute
        # intersections, so the first observation reflects the reset state.
        # NOTE(review): ``dt`` is not defined in this class. It is presumably
        # a module-level time step defined elsewhere; confirm.
        self.agent.updateWithInstruction(dt, None)
        self.agent.mathIntersect(self.circuit.vertices)

        self.stepCountingCounter = 0
        self._episode_ended = False
        return ts.restart(self._observe())

    def _step(self, action):
        """Apply ``action`` for one physics step and return the ``TimeStep``."""
        if self._episode_ended:
            # The previous step terminated the episode: start a new one.
            print('episode ended')
            return self.reset()

        # Run physics, then refresh the hitbox and intersections.
        self.agent.updateWithInstruction(dt, action)
        hitbox = self.agent.generateHitbox()
        self.agent.mathIntersect(self.circuit.vertices)

        self.stepCountingCounter += 1
        if self.circuit.collidedWithCar(hitbox):
            self._episode_ended = True
            return ts.termination(self._observe(), reward=-500.0)
        elif self.circuit.carCollidedWithCheckpoint(self.agent):
            # Decay the checkpoint bonus with episode length, so reaching a
            # checkpoint sooner is worth more.
            reward = 300 * self.discount**self.stepCountingCounter
            return ts.transition(self._observe(), reward=reward)
        else:
            return ts.transition(self._observe(), reward=0)

    def _observe(self):
        """Return the car's current observation as a float32 array."""
        return np.array(self.agent.observe(), dtype=np.float32)

    # --- Drawing ---
    def render(self, mode="human"):
        """Draw the current state in a 1920x1080 ``Viewer``.

        ``mode`` is accepted for API compatibility but ignored.
        """
        if self.viewer is None:
            self.viewer = Viewer(1920, 1080, self.agent, self.circuit)
        self.viewer.render()