class TaskNavigationOptimizer(object):
    """Turn a navigation TaskList into an ActionList of drive actions.

    The input tasks are first ordered through TaskSorter.sort_list_by_distance_dest.
    Then one drive-with-orientation action is emitted per task, followed by a
    final drive to EXIT_KEY.
    """

    def __init__(self, yaml_path, verbose=False):
        # Tasks to navigate, actions produced for them, and the sorter used to order them.
        self.input_list = TaskList()
        self.output_list = ActionList()
        self.task_sorter = TaskSorter(yaml_path)
        self.verbose = verbose

    def reinitialize_input(self, input):
        """Use ``input`` as the new task list and drop any previously built actions."""
        self.input_list = input
        self.output_list.clear_task()

    def prioritize_task_list(self, sort_list):
        """Ask the TaskSorter to order ``sort_list`` by destination distance.

        Returns whatever TaskSorter.sort_list_by_distance_dest returns; optimize()
        treats a falsy result as a failed sort.
        """
        if self.verbose:
            rospy.loginfo("Prioritizing task list...")
        result = self.task_sorter.sort_list_by_distance_dest(sort_list)
        return result

    def optimize(self):
        """Append one drive per task plus a final exit drive to ``self.output_list``.

        Always returns True; a failed sort is only logged.
        """
        sorted_ok = self.prioritize_task_list(self.input_list)
        if not sorted_ok:
            # Log and keep going with whatever order the list currently has.
            rospy.logerr("Failed to sort by distance. Check 'get_distance' service.")

        if self.verbose:
            separator = '----------------------------------------------------------------------------'
            print(separator)
            print(self.input_list)
            print(separator)

        actions = self.output_list.task_list
        for task in self.input_list.task_list:
            TaskFormatter.format_drive_with_orientation(task.destination, task.orientation, actions)
        TaskFormatter.format_drive(EXIT_KEY, actions)
        return True
class TaskTransportOptimizer(object):
    """Turn a transportation TaskList into an ActionList.

    Tasks are grouped into batches: a holding list of at most
    ``holding_capacity`` tasks, all sharing one source. Each batch is
    scheduled as follows:
    - drive to the source and pick every task;
    - drive to each destination and drop the tasks off.
    A final drive to EXIT_KEY ends the action list.
    """

    def __init__(self, holding_capacity, yaml_path, verbose=False):
        self.holding_capacity = holding_capacity    # max tasks per batch (holding list capacity)
        self.input_list = TaskList()
        self.output_list = ActionList()
        self.task_sorter = TaskSorter(yaml_path)
        self.verbose = verbose

    def reinitialize_input(self, input):
        """Replace the input task list and clear previously generated actions."""
        self.input_list = input
        self.output_list.clear_task()

    def prioritize_task_list(self, sort_list):
        """Order ``sort_list`` via TaskSorter.sort_list_by_distance_src_dest.

        Returns the sorter's result, which was previously discarded, so that
        optimize() can report a failed sort.
        """
        if self.verbose:
            rospy.loginfo("Prioritizing task list...")
        return self.task_sorter.sort_list_by_distance_src_dest(sort_list)

    def add_task_to_capacity(self, holding_list):
        """Try to move the next pickable task from the input list into ``holding_list``.

        Returns:
            False if a task was added (keep filling the batch).
            True if the batch should be scheduled now. That is the case when
            ``holding_list`` is full, when nothing is left to pick, or when the
            next task comes from a different source than the batch.
        """
        if holding_list.is_full():
            return True
        status, index, task = self.input_list.get_next_obj_to_pick()
        if not status:  # nothing left to pick
            return True
        # A batch only ever contains tasks from a single source.
        if holding_list.task_list and task.source != holding_list.task_list[0].source:
            return True
        holding_list.add_task(task)
        # Mark it so get_next_obj_to_pick() does not return it again.
        self.input_list.task_list[index].status_to_scheduled()
        return False

    def pick_from_sources(self, task_list):
        """Schedule drive + pick actions for every source in ``task_list``.

        Tasks whose destination equals their source are picked and placed
        straight back. They are removed from ``self.input_list`` AND from
        ``task_list``. The second removal is the fix: the original only
        removed them from the input list, so drop_off_at_destinations(task_list)
        still scheduled a pick-from-robot + place for objects that were never
        put on the robot.
        """
        for source in task_list.get_unique_source():
            # Drive to the location (includes the empty arm-move call).
            TaskFormatter.format_drive(source, self.output_list.task_list)
            pick_up = task_list.get_tasks_by_source(source)
            # Iterate over a copy because task_list may be modified below.
            for task in list(pick_up.task_list):
                TaskFormatter.format_pick_task(task, self.output_list.task_list)
                if task.is_dest_same_as_src():
                    # Pick and place right back; no drop-off must follow.
                    TaskFormatter.format_place_task(task, self.output_list.task_list)
                    self.input_list.remove_task(task)
                    task_list.remove_task(task)
                else:
                    # Pick and place onto the robot for a later drop-off.
                    TaskFormatter.format_place_to_robot(task, self.output_list.task_list)

    def drop_off_item(self, task, scan_precision=False, scan_container=False):
        """Schedule moving ``task``'s object from the robot to the current location.

        ``scan_precision`` prepends a find-hole action. ``scan_container``
        prepends a find-container action.
        """
        if scan_precision:
            TaskFormatter.format_find_hole(task, self.output_list.task_list)
        if scan_container:
            TaskFormatter.format_find_container(task, self.output_list.task_list)
        TaskFormatter.format_pick_from_robot(task, self.output_list.task_list)
        TaskFormatter.format_place_task(task, self.output_list.task_list)

    def drop_off(self, drop_off_list):
        """Schedule a drop-off for every task in ``drop_off_list``.

        Each task gets one of three treatments:
        - precision-platform destinations (LocationIdentifierType.PP) get a hole scan;
        - otherwise, tasks with ``container != -1`` (presumably: a container
          is assigned) get a container scan;
        - all other tasks get no scan.
        """
        for task in drop_off_list.task_list:
            if task.destination_str == LocationIdentifierType.PP.fullname:
                self.drop_off_item(task, scan_precision=True, scan_container=False)
            elif task.container != -1:
                self.drop_off_item(task, scan_precision=False, scan_container=True)
            else:
                self.drop_off_item(task, scan_precision=False, scan_container=False)

    def drop_off_at_destinations(self, task_list):
        """Schedule drive + drop-off actions for every destination in ``task_list``."""
        for destination in task_list.get_unique_destination():
            # Drive to the location (includes the empty arm-move call).
            TaskFormatter.format_drive(destination, self.output_list.task_list)
            self.drop_off(task_list.get_tasks_by_destination(destination))

    def _schedule_batch(self, holding_list):
        """Schedule picks and drop-offs for one batch, then retire its tasks."""
        self.pick_from_sources(holding_list)
        self.drop_off_at_destinations(holding_list)
        self.input_list.remove_task_in_list(holding_list.task_list)
        holding_list.clear_task()

    def optimize(self):
        """Build ``self.output_list`` from ``self.input_list``.

        Always returns True. Sorting failures and unschedulable leftovers are
        logged, not raised.
        """
        holding_list = TaskList(capacity=self.holding_capacity)

        # NOTE(review): only an explicit False is treated as failure. It is not
        # visible here whether sort_list_by_distance_src_dest returns a status,
        # and a None result should not produce a spurious error.
        if self.prioritize_task_list(self.input_list) is False:
            rospy.logerr("Failed to sort by distance. Check 'get_distance' service.")

        if self.verbose:
            print('----------------------------------------------------------------------------')
            print(self.input_list)
            print('----------------------------------------------------------------------------')

        while not self.input_list.is_empty():
            if not self.add_task_to_capacity(holding_list):
                continue  # batch still being filled
            if holding_list.is_empty():
                # Tasks remain but none could be added (e.g. nothing pickable,
                # or zero capacity). The state would never change, so the
                # original looped forever here.
                rospy.logwarn("Remaining tasks cannot be added to the holding list; stopping transport scheduling.")
                break
            self._schedule_batch(holding_list)

        # Defensive: schedule anything still held.
        if not holding_list.is_empty():
            self._schedule_batch(holding_list)

        TaskFormatter.format_drive(EXIT_KEY, self.output_list.task_list)
        return True
class TaskManager(object):
    """Own the current task list and produce the ActionList sent to the mux.

    Transportation lists go through TaskTransportOptimizer and navigation lists
    through TaskNavigationOptimizer. On a mux error, TaskReplanner builds a new
    task list, which is then re-optimized.
    """
    MAX_HOLDING_CAPACITY = 3

    def __init__(self, holding_capacity=MAX_HOLDING_CAPACITY, yaml_path="", verbose=False):
        # Basic config
        self.holding_capacity = holding_capacity
        self.yaml_path = yaml_path
        self.verbose = verbose
        self.MAX_REPLAN_ATTEMPTS = 10               # retries after the first failed replan

        # Task lists
        self.task_list = TaskList()                 # The current one
        self.task_list_last_updated = TaskList()    # Last replanned task list

        self.replan_list = ActionList()             # Action list from mux
        self.error_index = -1                       # Index into replan_list of the failed action

        # Output
        self.output_list = ActionList()             # The output list to send to mux

        self.task_list_type = -1                    # Type of the list (TRANSPORTATION/NAVIGATION); -1 = unknown

        # Helpers
        self.transport_optimizer = TaskTransportOptimizer(holding_capacity, yaml_path, verbose)
        self.navigation_optimizer = TaskNavigationOptimizer(yaml_path, verbose)
        self.task_replanner = TaskReplanner(yaml_path, verbose)

    ## Call to initialize task_list when received data from refbox
    def initialize_list(self, task_list):
        """Store copies of ``task_list`` and take the list type from its first task.

        An empty list leaves the type at -1, the same "unknown" value
        __init__ uses, so optimize_list() returns False. The original
        raised IndexError in that case.
        """
        tasks = task_list.task_list
        self.task_list_type = tasks[0].type if tasks else -1

        # Save the original + set the current one
        self.task_list_last_updated = task_list.make_duplicate()
        self.task_list = task_list.make_duplicate()

    ## Initializing necessary data and then calls helpers
    def optimize_transport(self):
        """Run the transport optimizer on a copy of the current task list."""
        # Pass a duplicate so the optimizer cannot modify self.task_list.
        self.transport_optimizer.reinitialize_input(self.task_list.make_duplicate())
        self.transport_optimizer.optimize()
        # Share (not copy) the optimizer's output list.
        self.output_list = self.transport_optimizer.output_list

    def optimize_navigation(self):
        """Run the navigation optimizer on a copy of the current task list."""
        # Pass a duplicate so the optimizer cannot modify self.task_list.
        self.navigation_optimizer.reinitialize_input(self.task_list.make_duplicate())
        self.navigation_optimizer.optimize()
        # Share (not copy) the optimizer's output list.
        self.output_list = self.navigation_optimizer.output_list

    def call_replanner(self):
        """Run the replanner once.

        Returns:
            (success, additional_action). On success, ``self.task_list`` and
            ``self.task_list_last_updated`` are replaced by copies of the
            replanner's output.
        """
        self.task_replanner.reinitialize_input(self.replan_list, self.task_list_last_updated, self.error_index)
        success, additional_action = self.task_replanner.replan()
        if not success:
            return False, AdditionalAction.NONE

        # The new task_list
        self.task_list = self.task_replanner.output_list.make_duplicate()
        # Save the last replanned ones
        self.task_list_last_updated = self.task_replanner.output_list.make_duplicate()
        return True, additional_action

    def handle_additional_action(self, additional_action, error_task):
        """Build the recovery actions requested by the replanner.

        Returns:
            An ActionList for a known AdditionalAction, otherwise None.
        """
        if additional_action == AdditionalAction.DUMP_ITEM_ON_HAND:
            additional_actions = ActionList()
            rospy.logwarn("Dumping object [%s] to the current table" % (error_task.object_str))
            TaskFormatter.format_place_task(error_task, additional_actions.task_list)
        elif additional_action == AdditionalAction.DUMP_ITEM_ON_HAND_AND_PICK_TO_CAP:
            additional_actions = ActionList()
            rospy.logwarn("Dumping object [%s] to the current table" % (error_task.object_str))
            TaskFormatter.format_place_task(error_task, additional_actions.task_list)
            current_cap = self.holding_capacity - self.task_replanner.items_on_hand
            # Format the value into the message. Passing it as a separate
            # argument to a format string with no placeholder made logging's
            # "msg % args" fail, so the message was lost.
            # NOTE(review): current_cap is only logged; no "pick to cap"
            # actions are scheduled here.
            rospy.logwarn("Current capacity is: %s" % (current_cap,))
        elif additional_action == AdditionalAction.PICK_AND_DUMP_ON_TABLE:
            additional_actions = ActionList()
            rospy.logwarn("Picking object [%s] from ROBOT" % (error_task.object_str))
            TaskFormatter.format_pick_from_robot(error_task, additional_actions.task_list)
            rospy.logwarn("Dumpinggggg. Please ignore random destination.")
            # Set a random destination so mux doesn't call precision place
            # (this mutates error_task in place).
            error_task.set_destination(1)
            TaskFormatter.format_place_task(error_task, additional_actions.task_list)
        else:
            return None

        rospy.loginfo("Additional actions:")
        print(additional_actions)
        return additional_actions

    ## Interface that task_manager_handler calls to optimize list
    def optimize_list(self):
        """Optimize according to ``task_list_type``; return False for an unknown type."""
        if self.task_list_type == int(TaskType.TRANSPORTATION):
            self.optimize_transport()
            return True
        elif self.task_list_type == int(TaskType.NAVIGATION):
            self.optimize_navigation()
            return True
        return False

    ## Interface that task_manager_handler calls when received error from mux
    def replan(self):
        """Rebuild ``self.output_list`` after the mux reported an error.

        Reads ``self.replan_list`` and ``self.error_index``. The replanner is
        retried up to MAX_REPLAN_ATTEMPTS more times. Any additional recovery
        actions are prepended to the new output list.

        Returns:
            True on success, False if the replanner never succeeded.
        """
        rospy.logwarn("Error @ index [%d] with task: " % self.error_index)
        rospy.logwarn(self.replan_list.task_list[self.error_index])

        # Get new task list
        rospy.loginfo("Making new task list...")
        success, additional_action = self.call_replanner()
        attempt = 1

        while not success:
            if attempt > self.MAX_REPLAN_ATTEMPTS:
                rospy.logerr("Max attempt reached. Cannot replan.")
                rospy.logerr("Please check if there is error case and error handling behavior defined for this error")
                rospy.logfatal("Abort mission.")
                return False
            rospy.logerr("Cannot replan. Will retry: %d/%d times" % (attempt, self.MAX_REPLAN_ATTEMPTS))
            success, additional_action = self.call_replanner()
            attempt = attempt + 1

        rospy.loginfo("New list made successfully!")

        # Optimize
        if self.task_list_type == int(TaskType.TRANSPORTATION):
            rospy.loginfo("Sending new task list to transport optimizer...")
            self.optimize_transport()
        elif self.task_list_type == int(TaskType.NAVIGATION):
            rospy.loginfo("Sending new task list to navigation optimizer...")
            self.optimize_navigation()

        # Handle additional actions
        rospy.loginfo(additional_action)
        if additional_action != AdditionalAction.NONE:
            additional_action_list = self.handle_additional_action(additional_action, self.replan_list.task_list[self.error_index])
            if additional_action_list is not None:
                # Recovery actions run first, so prepend them.
                self.output_list.task_list = additional_action_list.task_list + self.output_list.task_list

        rospy.loginfo("Replan finished")
        rospy.loginfo("##################")
        print(self.output_list)
        return True

    ## Clear all the lists
    def clear(self):
        """Empty the current, last-replanned and output lists."""
        self.task_list.clear_task()
        self.task_list_last_updated.clear_task()
        self.output_list.clear_task()