#!/usr/bin/env python

import fileinput
import random
from kdtree import KDTree

# Read the points from the file(s) named on the command line, e.g.:
# $ ./kdtree_test.py ../../../data/sim_waypoints.csv
# Each point is stored as (x, y, index), where index is the point's position
# in `points`.
points = []
for line in fileinput.input():
    if not line.strip():
        continue  # tolerate blank lines such as a trailing newline
    line_parts = line.split(',')
    points.append((float(line_parts[0]), float(line_parts[1]), len(points)))

if not points:
    raise SystemExit("no points were read; nothing to search")

# generate the K-D Tree from the points
kdtree = KDTree(points)

# Pick a random point. randrange excludes its upper bound, so the index is
# always valid (randint(0, len(points)) could return len(points)).
rand_index = random.randrange(len(points))

# Nudge the chosen point and look up the closest stored point to the result.
point = points[rand_index]
print("randomly chose point {} at index {}".format(point, rand_index))
new_point = (point[0] + 2.0, point[1] + 7.0)
print("tweaked x,y to be {}".format(new_point))

closest = kdtree.closest_point(new_point)

print("closest point to {} is {}".format(new_point, closest))
class TLDetector(object):
    """Track the nearest upcoming traffic light and its state.

    The detector listens to the vehicle pose, the base waypoints, the
    traffic-light list and the camera stream. Every camera frame triggers a
    search for the closest light ahead of the car, and the debounced result is
    published on /traffic_waypoint. Camera frames can optionally be dumped to
    disk together with the current light state to build a training data set.
    """

    def __init__(self):
        """Initialise detector state, the optional classifier and ROS wiring.

        NOTE(review): rospy.get_time() is called below, so the ROS node is
        assumed to be initialised before this constructor runs - confirm.
        """
        self.pose = None
        self.waypoints = None
        self.camera_image = None
        # Set to True by image_cb; read by get_light_state.
        self.has_image = False
        self.lights = []

        config_string = rospy.get_param("/traffic_light_config")
        # NOTE(review): yaml.load can construct arbitrary objects; consider
        # yaml.safe_load if this parameter can come from an untrusted source.
        self.config = yaml.load(config_string)

        # NOTE(review): the message type and queue size were missing from the
        # source; Int32 (std_msgs.msg) and queue_size=1 are assumed - confirm
        # and make sure Int32 is imported next to the other message types.
        self.upcoming_red_light_pub = rospy.Publisher(
            '/traffic_waypoint', Int32, queue_size=1)

        self.bridge = CvBridge()
        self.listener = tf.TransformListener()

        self.state = TrafficLight.UNKNOWN
        self.last_state = TrafficLight.UNKNOWN
        self.last_wp = -1
        self.state_count = 0

        # Parameters for collecting frames from the camera
        self.should_collect_data = False
        self.dump_images_dir = create_dir_if_nonexistent(
            join(expanduser('~'), 'traffic_light_dataset', 'raw_images'))
        self.dump_images_counter = len(os.listdir(self.dump_images_dir))
        self.last_dump_tstamp = rospy.get_time()

        # Used to find the closest waypoint; built lazily on first lookup.
        self.kdtree = None

        # Data file to store the image name and light state in the image.
        # NOTE(review): "w+" truncates labels written by earlier runs while
        # the image counter keeps counting from the existing files - confirm
        # this is intended.
        self.datafile = open(self.dump_images_dir + "/lightsData.csv", "w+")

        self.lightState = None

        # Both stay None when ground truth is preferred; get_light_state
        # checks for that instead of failing on a missing attribute.
        self.session = None
        self.light_classifier = None
        if not PREFER_GROUND_TRUTH:
            # Create tensorflow session
            self.session = tensorflow.Session()

            # Import classifier and restore pre-trained weights
            self.light_classifier = TrafficLightClassifier(
                input_shape=[64, 64], learning_rate=1e-4)
            # NOTE(review): the call that consumed these two arguments was
            # missing from the source; a Saver restore is assumed - confirm.
            tensorflow.train.Saver().restore(
                self.session, TrafficLightClassifier.checkpoint_path)

        # Subscribers are registered last so that a callback firing right
        # away only ever sees fully initialised attributes.
        rospy.Subscriber('/current_pose', PoseStamped, self.pose_cb)
        rospy.Subscriber('/base_waypoints', Lane, self.waypoints_cb)
        # /vehicle/traffic_lights provides you with the location of the
        # traffic light in 3D map space and helps you acquire an accurate
        # ground truth data source for the traffic light classifier by sending
        # the current color state of all traffic lights in the simulator. When
        # testing on the vehicle, the color state will not be available.
        # You'll need to rely on the position of the light and the camera
        # image to predict it.
        rospy.Subscriber('/vehicle/traffic_lights', TrafficLightArray,
                         self.traffic_cb)
        rospy.Subscriber('/image_color', Image, self.image_cb)
        rospy.Subscriber('/image_color', Image, self.collect_images_callback)

    def collect_images_callback(self, msg):
        """Save camera images (currently once per second).

        Does nothing unless self.should_collect_data is set and more than one
        second has passed since the previous dump.

        Args:
            msg (Image): image from car-mounted camera
        """
        if not self.should_collect_data:
            return
        if rospy.get_time() - self.last_dump_tstamp <= 1:
            return

        # The same name is used for the image file and for its CSV row.
        image_name = '{:06d}.jpg'.format(self.dump_images_counter)

        # Convert image message to actual numpy data
        image_data = self.bridge.imgmsg_to_cv2(msg)
        image_data = cv2.cvtColor(
            image_data, cv2.COLOR_RGB2BGR)  # opencv uses BGR convention

        # Dump image to dump directory
        cv2.imwrite(join(self.dump_images_dir, image_name), image_data)

        # Write the state of the light and the image name to the csv file.
        # lightState is None until a light has been processed, so fall back
        # to a printable label rather than failing on the concatenation.
        print("Writing to datafile")
        light_label = "UNKNOWN" if self.lightState is None else self.lightState
        self.datafile.write('{} , {}\n'.format(image_name, light_label))
        # The file is never closed explicitly, so flush each row.
        self.datafile.flush()

        # Update counter and timestamp
        self.dump_images_counter += 1
        self.last_dump_tstamp = rospy.get_time()

    def pose_cb(self, msg):
        """Store the latest pose message."""
        self.pose = msg

    def waypoints_cb(self, waypoints):
        """Store the waypoint list carried by the incoming message."""
        self.waypoints = waypoints.waypoints

    def traffic_cb(self, msg):
        """Store the traffic-light list carried by the incoming message."""
        self.lights = msg.lights

    def image_cb(self, msg):
        """Identifies red lights in the incoming camera image and publishes the
        index of the waypoint closest to the red light's stop line to
        /traffic_waypoint

        Args:
            msg (Image): image from car-mounted camera
        """
        self.has_image = True
        self.camera_image = msg
        light_wp, state = self.process_traffic_lights()

        # Publish upcoming red lights at camera frequency.
        # Each predicted state has to occur `STATE_COUNT_THRESHOLD` number
        # of times till we start using it. Otherwise the previous stable state
        # is used.
        # NOTE(review): the publish calls were missing from the source and are
        # reconstructed from the docstring and the comment above - confirm.
        if self.state != state:
            self.state_count = 0
            self.state = state
        elif self.state_count >= STATE_COUNT_THRESHOLD:
            self.last_state = self.state
            light_wp = light_wp if state == TrafficLight.RED else -1
            self.last_wp = light_wp
            self.upcoming_red_light_pub.publish(Int32(light_wp))
        else:
            self.upcoming_red_light_pub.publish(Int32(self.last_wp))
        self.state_count += 1

    def get_closest_waypoint(self, pose):
        """Identifies the closest path waypoint to the given position

        The KD-tree is built on the first call made after waypoints have
        arrived and is reused afterwards.

        Args:
            pose (Pose): position to match a waypoint to

        Returns:
            int: index of the closest waypoint in self.waypoints (0 when no
                waypoints are available yet)
        """
        if self.kdtree is None and self.waypoints:
            if VERBOSE:
                print('tl_detector: g_cl_wp: initializing kdtree')
            # Each entry is (x, y, index) so a lookup can report the index.
            points = [
                (float(waypoint.pose.pose.position.x),
                 float(waypoint.pose.pose.position.y), i)
                for i, waypoint in enumerate(self.waypoints)]
            self.kdtree = KDTree(points)

        if self.kdtree is None:
            return 0

        current_position = (pose.position.x, pose.position.y)
        closest = self.kdtree.closest_point(current_position)
        if VERBOSE:
            print('tl_detector: g_cl_wp: closest point to {} is {}'.format(
                current_position, closest))
        return closest[2]

    def get_light_state(self, light):
        """Determines the current color of the traffic light

        Args:
            light (TrafficLight): light to classify

        Returns:
            int: ID of traffic light color (specified in
                styx_msgs/TrafficLight); TrafficLight.UNKNOWN when there is no
                camera image or no classifier to run
        """
        if not self.has_image:
            self.prev_light_loc = None
            return TrafficLight.UNKNOWN

        if self.light_classifier is None:
            # No classifier was created (ground truth preferred).
            return TrafficLight.UNKNOWN

        cv_image = self.bridge.imgmsg_to_cv2(self.camera_image, "bgr8")

        return self.light_classifier.get_classification(
            self.session, cv_image)

    def process_traffic_lights(self):
        """Finds closest visible traffic light, if one exists, and determines
        its location and color

        The stored base waypoints are left untouched so the KD-tree can still
        be built on a later call.

        Returns:
            int: index of waypoint closest to the upcoming stop line for a
                traffic light (-1 if none exists)
            int: ID of traffic light color (specified in
                styx_msgs/TrafficLight)
        """
        if self.pose is None:
            # Without a pose there is nothing to measure distances against.
            return -1, TrafficLight.UNKNOWN
        car_pose = self.pose.pose

        # TODO find the closest visible traffic light (if one exists)
        if VERBOSE:
            print("tl_detector: p_tl: There are {} traffic lights to analyze.".
                  format(len(self.lights)))

        light = None
        min_distance = float("Infinity")
        for current_light in self.lights:
            # Only lights ahead of the car are candidates.
            if not is_ahead(current_light, car_pose):
                continue

            # Simplified Euclidean distance (no sqrt) between light and car
            light_distance = get_simple_distance_from_waypoint(
                current_light, car_pose)

            # If the light is closer, remember it
            if light_distance < min_distance:
                min_distance = light_distance
                light = current_light

        if light is None:
            return -1, TrafficLight.UNKNOWN

        self.lightState = self._light_color(light.state)

        # Calculate the actual distance of the light.
        light_distance = math.sqrt(min_distance)

        if VERBOSE:
            print(
                "tl_detector: p_tl: closest light to {} is at {} (Distance: {})."
                .format(
                    (car_pose.position.x, car_pose.position.y),
                    (light.pose.pose.position.x,
                     light.pose.pose.position.y), light_distance))

        # Look up the closest waypoint to it
        # TODO: [brahm] Can we assume self.kdtree is initialized?
        light_wp = self.get_closest_waypoint(light.pose.pose)

        # Determine the state of the light; -1 means "not known yet".
        state = -1
        if PREFER_GROUND_TRUTH:
            if VERBOSE:
                print("tl_detector: p_tl: Ground truth light color: {}".
                      format(self.lightState))

            # TODO: [brahm] Determine what light.state is when not available
            # (e.g. not in the simulator)
            if light.state is not None:
                state = light.state

        if state == -1:
            # This is where we classify the light; the result is the state
            # that gets reported to the caller.
            state = self.get_light_state(light)

        # If the traffic light is close, let us know
        if light_distance < TL_NEARNESS_THRESHOLD:
            if VERBOSE:
                print("tl_detector: p_tl: light is close: {} meters away.".
                      format(light_distance))

        return light_wp, state

    # Helper
    def _light_color(self, state):
        """Return the printable name of a traffic light state ID.

        Args:
            state (int): ID of traffic light color

        Returns:
            str: "RED", "YELLOW" or "GREEN", or "UNKNOWN" for any other value
        """
        if state == TrafficLight.RED:
            return "RED"
        if state == TrafficLight.YELLOW:
            return "YELLOW"
        if state == TrafficLight.GREEN:
            return "GREEN"
        return "UNKNOWN"