def __init__(self, config):
        """Initialise the generator state from an experiment configuration.

        Reads the agent count, map width, map density, map type, solver name
        and the solution/save directories from ``config``, derives the
        experiment label and the directory names built from it, and finally
        calls ``self.set_up()``.
        """
        cfg = config
        self.config = cfg
        self.PROCESS_NUMBER = 4
        self.num_agents = cfg.num_agents
        # NOTE(review): both entries are taken from ``map_w``; confirm that
        # only square maps are expected here (``map_h`` is never read).
        self.size_map = [cfg.map_w, cfg.map_w]
        # Keep only the text after the last '.' of the density's string form.
        self.label_density = str(cfg.map_density).rpartition('.')[2]
        self.AgentState = AgentState(self.num_agents)
        self.communicationRadius = 5  # communicationRadius
        self.zeroTolerance = 1e-9
        # (row, column) offset applied for each action index.
        self.delta = [
            [-1, 0],  # go up
            [0, -1],  # go left
            [1, 0],  # go down
            [0, 1],  # go right
            [0, 0],  # stop
        ]
        self.num_actions = 5

        # One independent, initially empty file list per data split.
        for list_name in ('list_seqtrain_file', 'list_train_file',
                          'list_seqvalid_file', 'list_validStep_file',
                          'list_valid_file', 'list_test_file'):
            setattr(self, list_name, [])

        self.hashids = Hashids(alphabet='01234567789abcdef', min_length=5)
        self.pathtransformer = self.pathtransformer_RelativeCoordinate

        first_extent, second_extent = self.size_map
        self.label_setup = '{}{:02d}x{:02d}_density_p{}/{}_Agent'.format(
            cfg.loadmap_TYPE, first_extent, second_extent,
            self.label_density, self.num_agents)

        case_root = os.path.join(cfg.solCases_dir, self.label_setup)
        self.dirName_parent = case_root
        self.dirName_Store = os.path.join(cfg.dir_SaveData, self.label_setup)
        self.dirName_input = os.path.join(case_root, 'input')
        self.dirName_output = os.path.join(case_root, 'output_{}'.format(cfg.chosen_solver))
        self.set_up()
    def __init__(self, config):
        """Create a simulator whose per-episode fields are all unset.

        Stores ``config``, builds the ``AgentState`` helper for
        ``config.num_agents`` agents, prepares the action-offset table both
        as a nested list and as a float tensor on ``config.device``, and
        initialises every per-episode attribute to ``None``.
        """
        self.config = config

        self.AgentState = AgentState(self.config.num_agents)
        # (row, column) offset applied for each action index.
        self.delta_list = [
            [-1, 0],  # go up
            [0, -1],  # go left
            [1, 0],  # go down
            [0, 1],  # go right
            [0, 0],  # stop
        ]
        self.delta = torch.FloatTensor(self.delta_list).to(self.config.device)

        # Per-episode attributes start out as None, created in this order.
        for attr_name in (
                'List_MultiAgent_ActionVec_target', 'store_MultiAgent', 'channel_map',
                'size_map', 'maxstep',
                'posObstacle', 'numObstacle', 'posStart', 'posGoal',
                'currentState_predict',
                'makespanTarget', 'flowtimeTarget', 'makespanPredict', 'flowtimePredict',
                'count_reachgoal', 'count_reachgoalTarget', 'fun_Softmax'):
            setattr(self, attr_name, None)

        self.zeroTolerance = 1e-9
        print("run on multirobotsim with collision shielding")
    def __init__(self, config, mode):
        """Resolve the data directory and the sample list for one dataset split.

        Args:
            config: experiment configuration; read for ``map_type``, ``map_w``,
                ``map_h``, ``map_density``, ``num_agents``, ``data_root`` and
                the per-split sample counts (``num_test_trainingSet``,
                ``num_validset``, ``num_testset``).
            mode: one of ``"train"``, ``"test_trainingSet"``, ``"valid"``,
                ``"validStep"`` or ``"test"``.

        Raises:
            ValueError: if ``mode`` is not one of the supported splits.
                Previously such a value fell through every branch and failed
                later with an unrelated ``AttributeError`` on
                ``self.data_paths``.
        """
        self.config = config
        self.datapath_exp = '{}{:02d}x{:02d}_density_p{}/{}_Agent/'.format(
            self.config.map_type, self.config.map_w, self.config.map_h,
            self.config.map_density, self.config.num_agents)

        self.dirName = os.path.join(self.config.data_root, self.datapath_exp)
        self.AgentState = AgentState(self.config.num_agents)

        if mode == "train":
            self.dir_data = os.path.join(self.dirName, 'train')
            self.search_files = self.search_target_files_withStep
            self.data_paths, self.id_stepdata = self.update_data_path_trainingset(
                self.dir_data)
            self.load_data = self.load_train_data
        elif mode == "test_trainingSet":
            # Evaluate on a random subset of the training cases.
            self.dir_data = os.path.join(self.dirName, 'train')
            data_paths, id_stepdata = self.search_target_files(self.dir_data)
            # Shuffle paths and step ids together so the pairs stay aligned.
            paths_total = list(zip(data_paths, id_stepdata))
            random.shuffle(paths_total)
            data_paths, id_stepdata = zip(*paths_total)
            self.data_paths = data_paths[:self.config.num_test_trainingSet]
            self.id_stepdata = id_stepdata[:self.config.num_test_trainingSet]
            self.load_data = self.load_data_during_training
        elif mode == "valid":
            self.dir_data = os.path.join(self.dirName, 'valid')
            self.data_paths, self.id_stepdata = self.obtain_data_path_validset(
                self.dir_data, self.config.num_validset)
            self.load_data = self.load_data_during_training
        elif mode == "validStep":
            self.dir_data = os.path.join(self.dirName, 'valid')
            self.data_paths, self.id_stepdata = self.search_valid_files_withStep(
                self.dir_data, self.config.num_validset)
            self.load_data = self.load_train_data
        elif mode == "test":
            self.dir_data = os.path.join(self.dirName, mode)
            self.data_paths, self.id_stepdata = self.obtain_data_path_validset(
                self.dir_data, self.config.num_testset)
            self.load_data = self.load_test_data
        else:
            raise ValueError(
                "Unsupported dataset mode {!r}; expected 'train', "
                "'test_trainingSet', 'valid', 'validStep' or 'test'.".format(mode))

        self.data_size = len(self.data_paths)
# Example no. 4
# 0
 def __init__(self, config):
     """Initialise generator state from ``config``.

     Reads the agent count, the map width and height and the
     failure-case directory from ``config``; everything else is set to
     fixed defaults.
     """
     cfg = config
     self.config = cfg
     self.PROCESS_NUMBER = 4
     self.num_agents = cfg.num_agents
     self.size_map = [cfg.map_w, cfg.map_h]
     self.AgentState = AgentState(self.num_agents)
     self.communicationRadius = 5  # communicationRadius
     self.zeroTolerance = 1e-9
     # (row, column) offsets, in order: up, left, down, right, stop.
     self.delta = [[-1, 0], [0, -1], [1, 0], [0, 1], [0, 0]]
     self.num_actions = 5
     self.root_path_save = cfg.failCases_dir
     # Two independent, initially empty file lists.
     self.list_seqtrain_file, self.list_train_file = [], []
     self.pathtransformer = self.pathtransformer_RelativeCoordinate
class multiRobotSim:
    def __init__(self, config):
        """Build an idle simulator; per-episode fields are filled by ``setup``.

        Stores ``config``, creates the ``AgentState`` helper for
        ``config.num_agents`` agents and the action-offset table (as a list
        and as a float tensor on ``config.device``), then sets every
        per-episode attribute to ``None``.
        """
        self.config = config

        self.AgentState = AgentState(self.config.num_agents)
        # (row, column) offset applied for each action index.
        self.delta_list = [
            [-1, 0],  # go up
            [0, -1],  # go left
            [1, 0],  # go down
            [0, 1],  # go right
            [0, 0],  # stop
        ]
        self.delta = torch.FloatTensor(self.delta_list).to(self.config.device)
        # self.onlineExpert = ComputeCBSSolution(self.config)

        # Expert targets and map.
        self.List_MultiAgent_ActionVec_target = self.store_MultiAgent = self.channel_map = None

        # Map extent and step budget.
        self.size_map = self.maxstep = None

        # Obstacle, start and goal positions.
        self.posObstacle = self.numObstacle = self.posStart = self.posGoal = None

        self.currentState_predict = None

        # Quality metrics for the target and the predicted solution.
        self.makespanTarget = self.flowtimeTarget = None
        self.makespanPredict = self.flowtimePredict = None

        # Goal-reached bookkeeping and the action normaliser.
        self.count_reachgoal = self.count_reachgoalTarget = self.fun_Softmax = None

        self.zeroTolerance = 1e-9
        print("run on multirobotsim with collision shielding")

    def setup(self, loadInput, loadTarget, makespanTarget, tensor_map, ID_dataset):
        """Reset the simulator for one test case and replay the target solution.

        Args:
            loadInput: tensor indexed as ``loadInput[:, 0, agent, :]`` for goal
                positions and ``loadInput[:, 1, agent, :]`` for start positions.
            loadTarget: 4-D tensor of target action vectors; after
                ``permute(1, 2, 3, 0)`` and dropping the last axis it is indexed
                ``[agent, step, action]`` (presumably one score per action -
                TODO confirm against the data loader).
            makespanTarget: tensor holding the target makespan; it is cast to
                ``int32`` and scaled to obtain the step budget ``self.maxstep``.
            tensor_map: tensor whose first entry ``tensor_map[0]`` is the map;
                cells equal to 1 are treated as obstacles by ``reachObstacle``.
            ID_dataset: identifier of the case, stored as ``self.ID_dataset``.

        Side effects:
            Rebuilds ``self.status_MultiAgent`` (one dict per agent, keyed
            ``"agent<i>"``), resets all counters/flags, and finally calls
            ``self.getPathTarget()``.
        """

        # self.fun_Softmax = nn.Softmax(dim=-1)
        # NOTE: despite the attribute name this is a *log*-softmax.
        self.fun_Softmax = nn.LogSoftmax(dim=-1)
        self.ID_dataset = ID_dataset

        self.store_GSO = []
        self.store_communication_radius = []
        self.status_MultiAgent = {}
        # setupState = loadInput.permute(3, 4, 2, 1, 0)
        # Move the first axis of loadTarget to the end, then keep its entry 0.
        target = loadTarget.permute(1, 2, 3, 0)
        self.List_MultiAgent_ActionVec_target = target[:, :, :,0]
        # self.List_MultiAgent_ActionVec_target = target[:, :, 0]

        self.channel_map = tensor_map[0] # setupState[:, :, 0, 0, 0]
        self.AgentState.setmap(self.channel_map)
        # Coordinates of all non-zero map cells, moved to the configured device.
        self.posObstacle = self.findpos(self.channel_map).to(self.config.device)
        self.numObstacle = self.posObstacle.shape[0]
        self.size_map = self.channel_map.shape

        # self.communicationRadius = 5 #self.size_map[0] * 0.5
        # self.maxstep = self.size_map[0] * self.size_map[1]
        # Teams of 20 or more agents use a fixed multiplier of 3; smaller
        # teams use the configured one.
        if self.config.num_agents >=20:
            self.rate_maxstep = 3
        else:
            self.rate_maxstep = self.config.rate_maxstep

        # Step budget = target makespan scaled by the multiplier above.
        self.maxstep = int(makespanTarget.type(torch.int32) * self.rate_maxstep)

        self.check_predictCollsion = False
        self.check_moveCollision = True
        self.check_predictEdgeCollsion = [False] * self.config.num_agents
        self.count_reachgoal = [False] * self.config.num_agents
        self.count_reachgoalTarget = [False] * self.config.num_agents
        self.allReachGoal_Target = False
        self.makespanTarget = 0
        self.flowtimeTarget = 0
        # Predicted metrics start at their worst-case values (full step budget).
        self.makespanPredict = self.maxstep
        self.flowtimePredict = self.maxstep * self.config.num_agents #0
        # used for determining flowtimes of non-rogue agents
        self.nonRogueFlowtimePredict = None

        # Action key 4 is "stop" (see the offset table built in __init__).
        self.stopKeyValue = torch.tensor(4).to(self.config.device)
        self.reset_disabled_action = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]).float().to(self.config.device)

        self.store_goalAgents = loadInput[0, 0, :,:]
        self.store_stateAgents = loadInput[0, 1, :, :]
        for id_agent in range(self.config.num_agents):

            # NOTE(review): this dict is never used; the one stored below is
            # ``status_CurrentAgents`` (plural).
            status_CurrentAgent = {}

            posGoal = loadInput[:, 0,id_agent,:] #self.findpos(goal_CurrentAgent)
            posStart = loadInput[:, 1,id_agent,:] #self.findpos(start_CurrentAgent)


            # Paths are dicts mapping step index -> position; step 0 is the start.
            path_predict = {0:posStart}
            path_target = {0:posStart}
            len_action_predict  = 0
            list_actionKey_predict = []
            actionVec_target_CurrentAgents = self.List_MultiAgent_ActionVec_target[id_agent, :, :]
            # Index of the largest entry along dim 1 gives the action key per step.
            actionKeyList_target_CurrentAgents = torch.max(actionVec_target_CurrentAgents, 1)[1]

            # NOTE: every agent receives a reference to the same tensor object.
            disabled_action_predict_currentAgent = self.reset_disabled_action
            startStep_action_currentAgent = None
            endStep_action_currentAgent = None


            len_action_target = actionKeyList_target_CurrentAgents.shape[0]

            # "start", "currentState" and "nextState_predict" initially refer to
            # the same ``posStart`` tensor.
            status_CurrentAgents = {"goal": posGoal,
                                    "start": posStart,#torch.FloatTensor([[0,0]]).to(self.config.device),
                                    "currentState": posStart,
                                    "path_target": path_target,
                                    "action_target": actionKeyList_target_CurrentAgents,
                                    "len_action_target": len_action_target,
                                    "startStep_action_target": startStep_action_currentAgent,
                                    "endStep_action_target": endStep_action_currentAgent,
                                    "path_predict": path_predict,
                                    "nextState_predict": posStart,
                                    "action_predict": list_actionKey_predict,
                                    "disabled_action_predict": disabled_action_predict_currentAgent,
                                    "len_action_predict": len_action_predict,
                                    "startStep_action_predict": startStep_action_currentAgent,
                                    "endStep_action_predict": endStep_action_currentAgent
                                    }
            # print("Agent{} - goal:{} - start:{} - currentState:{}".format(id_agent, posGoal,posStart,posStart))
            name_agent = "agent{}".format(id_agent)
            self.status_MultiAgent.update({name_agent: status_CurrentAgents})

        # Replay the target actions to obtain target paths, makespan and flowtime.
        self.getPathTarget()
        pass

    def findpos(self, channel):
        """Return the coordinates of every non-zero cell of ``channel``.

        Args:
            channel: tensor whose non-zero entries mark objects; the first two
                index dimensions of each entry are used as its coordinates.

        Returns:
            Tensor of shape ``(num_nonzero, 2)`` allocated with
            ``torch.zeros`` (default dtype and device), holding one
            ``[index0, index1]`` row per non-zero entry in the order reported
            by ``channel.nonzero()``.
        """
        pos_object = channel.nonzero()
        pos = torch.zeros(pos_object.shape[0], 2)
        # Two column-wise copies replace the former per-row Python loop; the
        # assignment casts the integer indices to the dtype of ``pos``.
        pos[:, 0] = pos_object[:, 0]
        pos[:, 1] = pos_object[:, 1]
        return pos


    def getPathTarget(self):
        """Replay each agent's target actions and derive the target metrics.

        For every agent the target action keys are applied step by step from
        the start position; the visited positions are stored in the agent's
        ``"path_target"`` dict (step index -> position), the first step with a
        non-stop action is stored as ``"startStep_action_target"`` and the
        first arrival at the goal as ``"endStep_action_target"``.

        As soon as every agent processed so far has been seen at its goal and
        all flags in ``self.count_reachgoalTarget`` are set, the method
        computes ``self.flowtimeTarget`` (sum of end - start over agents),
        overwrites each agent's ``"len_action_target"`` with its own
        end - start, sets ``self.makespanTarget`` to
        ``max(end) - min(start)`` and stops.
        """
        #todo check the length for ground truth, out of index

        # Longest target action sequence over all agents.
        list_len_action_target = []
        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)

            len_actionTarget_currentAgent = self.status_MultiAgent[name_agent]["len_action_target"]
            list_len_action_target.append(len_actionTarget_currentAgent)

        maxStep = max(list_len_action_target)

        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)

            pathTarget_currentAgent = self.status_MultiAgent[name_agent]["path_target"]
            currentState_target = self.status_MultiAgent[name_agent]['start']
            goal_currentAgent = self.status_MultiAgent[name_agent]['goal']

            nextState_target = currentState_target
            goalIndexX = int(goal_currentAgent[0][0])
            goalIndexY = int(goal_currentAgent[0][1])


            for step in range(maxStep):

                actionKey_target = self.status_MultiAgent[name_agent]['action_target'][step]

                check_move = (actionKey_target != self.stopKeyValue)
                check_startStep_action = self.status_MultiAgent[name_agent]["startStep_action_target"]

                # Record the first step at which the agent leaves "stop".
                if check_move == 1 and check_startStep_action is None:
                    self.status_MultiAgent[name_agent]["startStep_action_target"] = step

                else:
                    # Advance to the state reached in the previous step. On the
                    # first moving step this is skipped, but all earlier actions
                    # were "stop", so the position is unchanged anyway.
                    currentState_target = nextState_target

                action_target = self.delta[actionKey_target]
                nextState_target = torch.add(currentState_target, action_target)

                pathTarget_currentAgent.update({step+1: nextState_target})

                self.status_MultiAgent[name_agent]["path_target"] = pathTarget_currentAgent

                # Only the first arrival at the goal sets the end step.
                if nextState_target[0][0] == goalIndexX and nextState_target[0][1] == goalIndexY and not self.count_reachgoalTarget[id_agent]:
                    self.count_reachgoalTarget[id_agent] = True
                    self.status_MultiAgent[name_agent]["endStep_action_target"] = step + 1

                self.allReachGoal_Target = all(self.count_reachgoalTarget)

            # Checked after each agent's replay; it can only become true once
            # the last agent's flag has been set.
            if self.allReachGoal_Target:
                List_endStep_target = []
                List_startStep_target = []
                self.flowtimeTarget = 0
                # NOTE: this loop reuses ``id_agent``/``name_agent`` of the outer
                # loop; harmless because of the ``break`` below.
                # NOTE(review): an agent that reaches its goal without ever
                # issuing a non-stop action still has a start step of None, and
                # the subtractions below would raise TypeError - confirm this
                # cannot occur in the data.
                for id_agent in range(self.config.num_agents):
                    name_agent = "agent{}".format(id_agent)
                    List_endStep_target.append(self.status_MultiAgent[name_agent]["endStep_action_target"])
                    List_startStep_target.append(self.status_MultiAgent[name_agent]["startStep_action_target"])

                    self.flowtimeTarget += self.status_MultiAgent[name_agent]["endStep_action_target"] - \
                                            self.status_MultiAgent[name_agent]["startStep_action_target"]

                    # Despite the local name, this is the length of the *target* plan.
                    len_action_predict = self.status_MultiAgent[name_agent]["endStep_action_target"] - \
                                         self.status_MultiAgent[name_agent]["startStep_action_target"]
                    self.status_MultiAgent[name_agent]["len_action_target"] = len_action_predict

                self.makespanTarget = max(List_endStep_target) - min(List_startStep_target)

                # print("Makespane(target):{} \n Flowtime(target): {} \n ").format(self.makespanTarget, self.flowtimeTarget)
                break



    def getOptimalityMetrics(self):
        return [self.makespanPredict, self.makespanTarget], [self.flowtimePredict, self.flowtimeTarget]

    def getMaxstep(self):
        return self.maxstep

    def getMapsize(self):
        return self.size_map

    def initCommunicationRadius(self):
        self.communicationRadius = self.config.commR
        # self.communicationRadius = 5
        # self.communicationRadius = 6
        # self.communicationRadius = 7
        # self.communicationRadius = 8
        # self.communicationRadius = 9
        # self.communicationRadius = 10


    def reachObstacle(self, state):
        reach_obstacle = False

        # name_agent = "agent{}".format(id_agent)
        currentState_predict = state #self.status_MultiAgent[name_agent]["currentState"]
        currentStateIndexX = currentState_predict[0][0]
        currentStateIndexY = currentState_predict[0][1]

        # print(self.channel_map.shape)
        # print(self.channel_map)
        # time.sleep(10)

        if self.channel_map[int(currentStateIndexX)][int(currentStateIndexY)] == 1:
            # print('Reach obstacle.')
            reach_obstacle = True
        else:
            reach_obstacle = False


        # if reach_obstacle:
        #     break
        return reach_obstacle

    def reachEdge(self, state):
        reach_edge = False

        # name_agent = "agent{}".format(id_agent)
        currentState_predict = state #self.status_MultiAgent[name_agent]["currentState"]
        currentStateIndexX = currentState_predict[0][0]
        currentStateIndexY = currentState_predict[0][1]

        if currentStateIndexX >= self.size_map[0] or currentStateIndexX < 0 or currentStateIndexY >= self.size_map[1] or currentStateIndexY < 0:
            # print('Reach edge.')
            reach_edge = True
            # break
        else:
            reach_edge = False
        return reach_edge

    def computeAdjacencyMatrix_fixedCommRadius(self, step, agentPos, CommunicationRadius, graphConnected=False):
        """Build the normalised adjacency matrix for a fixed communication radius.

        Two agents are linked when the distance between their rows of
        ``agentPos[0]`` (``pdist`` default metric) is strictly smaller than
        ``self.communicationRadius``. Self-loops are removed and the matrix is
        normalised as ``D^{-1/2} W D^{-1/2}``, where isolated agents keep an
        all-zero row/column. Unlike ``computeAdjacencyMatrix`` the radius is
        never enlarged.

        Args:
            step: unused; kept so both adjacency builders share a signature.
            agentPos: array indexed ``[time, agent, coordinate]``; only the
                first time entry is used.
            CommunicationRadius: unused; the radius is read from
                ``self.communicationRadius``.
            graphConnected: ignored on input; recomputed from the graph.

        Returns:
            Tuple ``(W, radius, graphConnected)``: ``W`` has shape
            ``(time, agent, agent)`` with only ``W[0]`` filled, ``radius`` is
            ``self.communicationRadius`` and ``graphConnected`` is the result
            of ``graph.isConnected`` on the unnormalised ``W[0]``.
        """
        len_TimeSteps = agentPos.shape[0]  # length of timesteps
        nNodes = agentPos.shape[1]  # Number of nodes
        # Create the space to hold the adjacency matrices
        W = np.zeros([len_TimeSteps, nNodes, nNodes])

        # Pairwise distances, computed once (the original computed them twice).
        distances = squareform(pdist(agentPos[0]))  # nNodes x nNodes
        W[0] = (distances < self.communicationRadius).astype(agentPos.dtype)
        W[0] = W[0] - np.diag(np.diag(W[0]))  # drop self-loops
        graphConnected = graph.isConnected(W[0])

        # Symmetric degree normalisation; zero-degree nodes are mapped to 0
        # instead of dividing by zero.
        deg = np.sum(W[0], axis=1)  # nNodes (degree vector)
        zeroDeg = np.nonzero(np.abs(deg) < self.zeroTolerance)[0]
        deg[zeroDeg] = 1.
        invSqrtDeg = np.sqrt(1. / deg)
        invSqrtDeg[zeroDeg] = 0.
        Deg = np.diag(invSqrtDeg)
        W[0] = Deg @ W[0] @ Deg

        return W, self.communicationRadius, graphConnected


    def computeAdjacencyMatrix(self, step, agentPos, CommunicationRadius, graphConnected=False):
        """Build the normalised adjacency matrix, growing the radius at step 0.

        Two agents are linked when the distance between their rows of
        ``agentPos[0]`` (``pdist`` default metric) is strictly smaller than
        ``self.communicationRadius``. At ``step == 0`` the radius is enlarged
        by 10% at a time until ``graph.isConnected`` reports a connected
        graph, and the enlarged value is kept in ``self.communicationRadius``
        for all later steps. Self-loops are removed and the matrix is
        normalised as ``D^{-1/2} W D^{-1/2}``, where isolated agents keep an
        all-zero row/column.

        Args:
            step: current simulation step; ``0`` triggers the radius search.
            agentPos: array indexed ``[time, agent, coordinate]``; only the
                first time entry is used.
            CommunicationRadius: unused; the radius is read from (and at step
                0 written to) ``self.communicationRadius``.
            graphConnected: initial connectivity flag for the step-0 search;
                a truthy value skips the search and leaves ``W[0]`` empty.

        Returns:
            Tuple ``(W, radius, graphConnected)``: ``W`` has shape
            ``(time, agent, agent)`` with only ``W[0]`` filled, ``radius`` is
            the radius in use, and ``graphConnected`` is the last connectivity
            result.
        """
        len_TimeSteps = agentPos.shape[0]  # length of timesteps
        nNodes = agentPos.shape[1]  # Number of nodes
        # Create the space to hold the adjacency matrices
        W = np.zeros([len_TimeSteps, nNodes, nNodes])

        # Pairwise distances are needed by both branches; compute them once.
        distances = squareform(pdist(agentPos[0]))  # nNodes x nNodes

        if step == 0:
            # The loop multiplies by 1.1 before every test, so divide once up
            # front to make the first test use the configured radius.
            self.communicationRadius = self.communicationRadius / 1.1
            # Truthiness test instead of ``is False`` so that a numpy boolean
            # returned by the connectivity check cannot end the search early.
            while not graphConnected:
                self.communicationRadius = self.communicationRadius * 1.1
                W[0] = (distances < self.communicationRadius).astype(agentPos.dtype)
                W[0] = W[0] - np.diag(np.diag(W[0]))  # drop self-loops
                graphConnected = graph.isConnected(W[0])
        else:
            # Later steps reuse the radius found at step 0.
            W[0] = (distances < self.communicationRadius).astype(agentPos.dtype)
            W[0] = W[0] - np.diag(np.diag(W[0]))  # drop self-loops
            graphConnected = graph.isConnected(W[0])

        # Symmetric degree normalisation shared by both branches; zero-degree
        # nodes are mapped to 0 instead of dividing by zero.
        deg = np.sum(W[0], axis=1)  # nNodes (degree vector)
        zeroDeg = np.nonzero(np.abs(deg) < self.zeroTolerance)[0]
        deg[zeroDeg] = 1.
        invSqrtDeg = np.sqrt(1. / deg)
        invSqrtDeg[zeroDeg] = 0.
        Deg = np.diag(invSqrtDeg)
        W[0] = Deg @ W[0] @ Deg

        return W, self.communicationRadius, graphConnected

    def get_PosAgents(self):
        list_PosAgents = []
        action_CurrentAgents=[]
        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)
            currentState_predict = self.status_MultiAgent[name_agent]["currentState"]
            currentPredictIndexX = int(currentState_predict[0][0])
            currentPredictIndexY = int(currentState_predict[0][1])
            action_CurrentAgents.append([currentPredictIndexX, currentPredictIndexY])
        list_PosAgents.append(action_CurrentAgents)
        return np.asarray(list_PosAgents)
    
    def getGSO(self, step):
        """Compute, record and return the graph shift operator for ``step``.

        The agents' current positions are turned into a normalised adjacency
        matrix by ``computeAdjacencyMatrix``. At ``step == 0`` the
        communication radius is first reset from the configuration. The numpy
        matrix and the radius used are appended to ``self.store_GSO`` and
        ``self.store_communication_radius``.

        Returns:
            The adjacency matrix as a tensor created with ``torch.from_numpy``.
        """
        agent_positions = self.get_PosAgents()

        if step == 0:
            self.initCommunicationRadius()

        # Radius chosen so that the initial graph is connected; the fixed-radius
        # variant is computeAdjacencyMatrix_fixedCommRadius.
        adjacency, radius_used, _connected = self.computeAdjacencyMatrix(
            step, agent_positions, self.communicationRadius)
        adjacency_tensor = torch.from_numpy(adjacency)

        self.store_GSO.append(adjacency)
        self.store_communication_radius.append(radius_used)

        return adjacency_tensor

    def getCurrentState__(self):

        tensor_currentState = torch.zeros([1, self.config.num_agents, 3, self.size_map[0], self.size_map[1]])
        # tensor_currentState_all = torch.zeros([1, self.size_map[0], self.size_map[1]])
        for id_agent in range(self.config.num_agents):

            name_agent = "agent{}".format(id_agent)

            goal_CurrentAgent = self.status_MultiAgent[name_agent]["goal"]
            goalIndexX = int(goal_CurrentAgent[0][0])
            goalIndexY = int(goal_CurrentAgent[0][1])
            channel_goal = torch.zeros([self.size_map[0], self.size_map[1]])

            currentState_predict = self.status_MultiAgent[name_agent]["currentState"]
            currentPredictIndexX = int(currentState_predict[0][0])
            currentPredictIndexY = int(currentState_predict[0][1])
            channel_state = torch.zeros([self.size_map[0], self.size_map[1]])

            channel_goal[goalIndexX][goalIndexY] = 1
            channel_state[currentPredictIndexX][currentPredictIndexY] = 1

            tensor_currentState[0, id_agent, 0, :, :] = self.channel_map
            tensor_currentState[0, id_agent, 1, :, :] = channel_goal
            tensor_currentState[0, id_agent, 2, :, :] = channel_state
            # tensor_currentState_allagents = torch.add(tensor_currentState_all, channel_state)
        # print(tensor_currentState_allagents)
        return tensor_currentState


    def getCurrentState(self, return_GPos=False):


        store_goalAgents = torch.zeros([self.config.num_agents, 2])
        store_stateAgents = torch.zeros([self.config.num_agents, 2])

        for id_agent in range(self.config.num_agents):

            name_agent = "agent{}".format(id_agent)

            goal_CurrentAgent = self.status_MultiAgent[name_agent]["goal"]
            goalIndexX = int(goal_CurrentAgent[0][0])
            goalIndexY = int(goal_CurrentAgent[0][1])
            store_goalAgents[id_agent,:] = torch.FloatTensor([goalIndexX,goalIndexY])

            currentState_predict = self.status_MultiAgent[name_agent]["currentState"]
            currentPredictIndexX = int(currentState_predict[0][0])
            currentPredictIndexY = int(currentState_predict[0][1])

            store_stateAgents[id_agent, :] = torch.FloatTensor([currentPredictIndexX, currentPredictIndexY])

        tensor_currentState = self.AgentState.toInputTensor(store_goalAgents, store_stateAgents)
        tensor_currentState = tensor_currentState.unsqueeze(0)
        # print(tensor_currentState_allagents)

        if return_GPos:
            return tensor_currentState, store_stateAgents.unsqueeze(0)
        else:
            return tensor_currentState

    def getCurrentState_(self):

        tensor_currentState = self.AgentState.toInputTensor(self.store_goalAgents, self.store_stateAgents)
        tensor_currentState = tensor_currentState.unsqueeze(0)
        return tensor_currentState

    def interRobotCollision(self):
        """Resolve conflicts between the agents' proposed next positions.

        Two passes are made over ``self.status_MultiAgent``:

        1. Same-cell conflicts: when several agents propose the same next
           position, one randomly chosen agent may keep its move and the
           others are reset to their current position with their last
           predicted action replaced by "stop". If any agent of the group was
           already stopping, the whole group is reset instead.
        2. Position swaps: when two agents propose to exchange cells, both are
           reset to their current position with a "stop" action.

        Each agent's ``"action_predict"`` list must already contain the action
        proposed for this step, because its last element is read and
        overwritten (presumably appended by ``move`` - TODO confirm).

        Returns:
            ``True`` if any conflict of either kind was detected, else ``False``.
        """

        # collision = 0
        collision = False

        # Snapshot of the proposed next positions, by agent id and as a list.
        allagents_pos = {}
        list_pos = []
        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)

            nextstate_currrentAgent = self.status_MultiAgent[name_agent]["nextState_predict"].tolist()
            list_pos.append(nextstate_currrentAgent)
            allagents_pos.update({id_agent: nextstate_currrentAgent})

        # ---- pass 1: several agents heading for the same cell
        for i in range(self.config.num_agents):
            pos = list_pos[i]
            count_collision = list_pos.count(pos)
            if count_collision > 1:
                collision = True
                collided_agents = []

                # NOTE(review): ``allagents_pos`` is not refreshed when
                # ``list_pos`` is revised below, so agents are matched against
                # the positions proposed originally - confirm this is intended.
                for id_agent, pos_agent in allagents_pos.items():
                    if pos_agent == pos:
                        name_agent = "agent{}".format(id_agent)
                        collided_agents.append(name_agent)

                # id_agent2move = max(heuristic_agents.items(), key=operator.itemgetter(1))[0]
                # Despite its name this holds an agent *name* ("agent<i>").
                id_agent2move = random.choice(collided_agents)
                # print("In {}, {} need to move".format(collided_agents, id_agent2move))
                for name_agent in collided_agents:
                    # name_agent = "agent{}".format(id_agent)
                    # print("The action list of {}:\n{}".format(name_agent,self.status_MultiAgent[name_agent]["action_predict"]))
                    list_actionKey_predict = self.status_MultiAgent[name_agent]["action_predict"]

                    if list_actionKey_predict[-1] == self.stopKeyValue:
                        # print('##### one of the agent has stoppted.#####')
                        # One agent of the group is standing still, so nobody may
                        # enter its cell: reset the whole group. The inner loop
                        # rebinds ``name_agent``; the outer iteration is unaffected.
                        for name_agent in collided_agents:
                            list_actionKey_predict = self.status_MultiAgent[name_agent]["action_predict"]
                            list_actionKey_predict[-1] = self.stopKeyValue
                            self.status_MultiAgent[name_agent]["action_predict"] = list_actionKey_predict
                            self.status_MultiAgent[name_agent]["nextState_predict"] = self.status_MultiAgent[name_agent]["currentState"]
                            # print("All agents {} stops:\n{}\n".format(name_agent,self.status_MultiAgent[name_agent]["action_predict"]))
                            id_agent = int(name_agent.replace("agent",""))
                            list_pos[id_agent] =  self.status_MultiAgent[name_agent]["nextState_predict"].tolist()
                    else:

                        # Everyone except the randomly chosen agent stays put.
                        if name_agent != id_agent2move:

                            list_actionKey_predict = self.status_MultiAgent[name_agent]["action_predict"]
                            list_actionKey_predict[-1] = self.stopKeyValue
                            self.status_MultiAgent[name_agent]["action_predict"] = list_actionKey_predict
                            self.status_MultiAgent[name_agent]["nextState_predict"] = self.status_MultiAgent[name_agent]["currentState"]
                            id_agent = int(name_agent.replace("agent",""))
                            list_pos[id_agent] =  self.status_MultiAgent[name_agent]["nextState_predict"].tolist()
                        #     print('{} stop'.format(name_agent))
                        # print("The action list of {} after changed:\n{}\n".format(name_agent, self.status_MultiAgent[name_agent]["action_predict"]))

        ## position swap
        # ---- pass 2: two agents trading places
        # Re-read the next positions, which pass 1 may have revised.
        list_nextpos = []

        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)

            nextstate_currrentAgent = self.status_MultiAgent[name_agent]["nextState_predict"].tolist()
            list_nextpos.append(nextstate_currrentAgent)

        for id_agent in range(self.config.num_agents):
            name_agent = "agent{}".format(id_agent)
            currentstate_currrentAgent = self.status_MultiAgent[name_agent]["currentState"].tolist()
            if currentstate_currrentAgent in list_nextpos:
                # ``index`` returns only the first agent heading for this cell.
                id_agent_swap = list_nextpos.index(currentstate_currrentAgent)
                name_agent_swap = "agent{}".format(id_agent_swap)
                if name_agent_swap != name_agent:
                    # A swap: the other agent currently stands where this one wants to go.
                    if self.status_MultiAgent[name_agent_swap]["currentState"].tolist() == self.status_MultiAgent[name_agent]["nextState_predict"].tolist():
                        # print("In #{}case(test), {} and {} swap position happens.".format(self.ID_dataset,name_agent,name_agent_swap))
                        self.status_MultiAgent[name_agent]["nextState_predict"] = self.status_MultiAgent[name_agent]["currentState"]
                        self.status_MultiAgent[name_agent_swap]["nextState_predict"] = self.status_MultiAgent[name_agent_swap]["currentState"]

                        # ``list_pos`` is updated here but not read again in this method.
                        id_agent = int(name_agent.replace("agent",""))
                        list_pos[id_agent] =  self.status_MultiAgent[name_agent]["nextState_predict"].tolist()

                        id_agent_swap = int(name_agent_swap.replace("agent",""))
                        list_pos[id_agent_swap] =  self.status_MultiAgent[name_agent_swap]["nextState_predict"].tolist()

                        self.status_MultiAgent[name_agent]["action_predict"][-1] = self.stopKeyValue
                        self.status_MultiAgent[name_agent_swap]["action_predict"][-1] = self.stopKeyValue

                        collision = True


        return collision

    def heuristic(self, current_pos, goal):
        """Return the Manhattan (L1) distance between ``current_pos`` and ``goal``.

        Both arguments are indexed at positions 0 and 1.
        """
        first_gap = abs(goal[0] - current_pos[0])
        second_gap = abs(goal[1] - current_pos[1])
        return first_gap + second_gap

    def move(self, actionVec, currentstep, rouge_agent_count=None):
        """Advance every agent by one step from the predicted action scores.

        For each agent the scores ``actionVec[id_agent]`` are passed through
        ``self.fun_Softmax`` and the highest-scoring action is taken. A move
        that hits the map edge or an obstacle is replaced by the stop action.
        Inter-agent conflicts are then re-checked via
        ``self.interRobotCollision()`` before the predicted next states are
        committed as the agents' current states and appended to their
        predicted paths. Once all agents have reached their goal, or
        ``currentstep`` has reached ``self.maxstep``, the flowtime / makespan
        statistics are (re)computed.

        Args:
            actionVec: per-agent action scores, indexed by agent id.
                Presumably each entry is a 2-D tensor of shape
                (1, num_actions), since the argmax is taken along dim 1 --
                TODO confirm against the caller.
            currentstep: index of the step being executed; compared against
                ``self.maxstep`` and used as the key in ``path_predict``.
            rouge_agent_count: optional number of "rogue" agents. When
                truthy, agents with ``id_agent < rouge_agent_count`` pick a
                uniformly random action (0-4) on every odd ``currentstep``.

        Returns:
            Tuple ``(allReachGoal, check_moveCollision, check_predictCollsion)``.
            ``allReachGoal`` is evaluated on entry, i.e. *before* this step's
            updates are applied.
        """

        allReachGoal = all(self.count_reachgoal)
        # Never updated or returned below; only referenced by commented-out code.
        allReachGoal_withoutcollision = False

        self.check_predictCollsion = False
        self.check_moveCollision = False

        # Agents keep stepping unless all goals were reached AND the step budget
        # is exhausted. NOTE(review): the commented-out variant below uses `and`
        # instead of `or` -- confirm which condition is intended.
        if (not allReachGoal) or (currentstep < self.maxstep):
        # if not allReachGoal and currentstep < self.maxstep:
            for id_agent in range(self.config.num_agents):
                name_agent = "agent{}".format(id_agent)

                # disabled_actionPredict_currentAgent = self.status_MultiAgent[name_agent]["disabled_action_predict"]
                # if self.config.num_agents == 1:
                #     actionVec_predict_CurrentAgents = torch.mul(self.fun_Softmax(actionVec),
                #                                                 disabled_actionPredict_currentAgent)
                # else:
                #     # actionVec_tmp = actionVec[id_agent]
                #     actionVec_current = self.fun_Softmax(actionVec[id_agent])
                #     actionVec_predict_CurrentAgents = torch.mul(actionVec_current, disabled_actionPredict_currentAgent)

                actionVec_current = self.fun_Softmax(actionVec[id_agent])
                if rouge_agent_count and id_agent < rouge_agent_count and currentstep % 2 == 1:
                    # if this is a rogue agent, randomly select an action every other step
                    actionKey_predict = torch.tensor([np.random.randint(5)])
                else:
                    # index of the highest-scoring action along dim 1
                    actionKey_predict = torch.max(actionVec_current, 1)[1]

                # set flag of the timestep that agent start to move
                check_move = (actionKey_predict != self.stopKeyValue)

                startStep_action = self.status_MultiAgent[name_agent]["startStep_action_predict"]
                # Record the start step only the first time a non-stop action is chosen.
                if check_move == 1 and startStep_action is None:
                    self.status_MultiAgent[name_agent]["startStep_action_predict"] = currentstep - 1

                list_actionKey_predict = self.status_MultiAgent[name_agent]["action_predict"]

                currentState_predict = self.status_MultiAgent[name_agent]["currentState"]
                # Candidate next position = current position + displacement of the chosen action.
                nextState_predict = torch.add(currentState_predict, self.delta[actionKey_predict])

                # ----- check edge and obstacle
                checkEdge = self.reachEdge(nextState_predict)
                checkObstacle = False
                # The obstacle test is only run for positions that are not past the edge.
                if not checkEdge:
                    checkObstacle = self.reachObstacle(nextState_predict)

                if checkEdge or checkObstacle:
                    # print('Reach obstacle or edge.')
                    # break
                    # todo : remove the collision motion disabled
                    # disabled_actionPredict_currentAgent[actionKey_predict] = 0.0
                    # self.status_MultiAgent[name_agent]["disabled_action_predict"] = disabled_actionPredict_currentAgent
                    # self.move(actionVec, currentstep)
                    self.check_predictCollsion = True

                    # Invalid move: log a stop action and stay at the current position.
                    list_actionKey_predict.append(self.stopKeyValue)
                    self.status_MultiAgent[name_agent]["action_predict"] = list_actionKey_predict
                    self.status_MultiAgent[name_agent]["nextState_predict"] = currentState_predict
                    # self.check_predictEdgeCollsion
                else:
                    # self.status_MultiAgent[name_agent]["currentState"] = nextState_predict
                    self.status_MultiAgent[name_agent]["nextState_predict"] = nextState_predict
                    # self.status_MultiAgent[name_agent]["disabled_action_predict"] = self.reset_disabled_action

                    list_actionKey_predict.append(actionKey_predict[0])
                    self.status_MultiAgent[name_agent]["action_predict"] = list_actionKey_predict

                # if not self.check_predictCollsion:
            detect_interRobotCollision = self.interRobotCollision()

            # Re-run the inter-agent check while it keeps reporting a conflict,
            # at most num_agents times. Presumably each call reverts the
            # conflicting agents to their current positions, which can create
            # new conflicts -- verify against interRobotCollision.
            # while detect_interRobotCollision:
            for _ in range(self.config.num_agents):
                # print('Collision happens.')
                if detect_interRobotCollision:
                    detect_interRobotCollision = self.interRobotCollision()
                    self.check_predictCollsion = True
                    # print("Collision happens")
                else:
                    # print("Collision avoided by collision shielding")
                    break

            # Final check: truthy if a conflict is still reported after the retries above.
            self.check_moveCollision = self.interRobotCollision()

            # Commit the predicted next states and update goal / end-step bookkeeping.
            for id_agent in range(self.config.num_agents):
                name_agent = "agent{}".format(id_agent)
                nextState_predict = self.status_MultiAgent[name_agent]["nextState_predict"]

                self.status_MultiAgent[name_agent]["currentState"] = nextState_predict
                # self.store_stateAgents[id_agent,:] = nextState_predict
                path_predict = self.status_MultiAgent[name_agent]["path_predict"]
                path_predict.update({currentstep: nextState_predict})
                # print("################## Current Step:{}  - size of path_predict: {}".format(currentstep, len(path_predict)))
                self.status_MultiAgent[name_agent]["path_predict"] = path_predict

                goal_CurrentAgent = self.status_MultiAgent[name_agent]["goal"]
                goalIndexX = int(goal_CurrentAgent[0][0])
                goalIndexY = int(goal_CurrentAgent[0][1])

                # First arrival at the goal: latch the flag and record the arrival step.
                if nextState_predict[0][0] == goalIndexX and nextState_predict[0][1] == goalIndexY and not \
                self.count_reachgoal[id_agent]:
                    self.count_reachgoal[id_agent] = True
                    self.status_MultiAgent[name_agent]["endStep_action_predict"] = currentstep
                # Step budget exhausted without reaching the goal: close the episode for this agent.
                if currentstep >= (self.maxstep) and not self.count_reachgoal[id_agent]:
                    # self.count_reachgoal[id_agent] = False
                    self.status_MultiAgent[name_agent]["endStep_action_predict"] = currentstep
                    # print("\t \t {} - status(Reach Goal) - {}".format(name_agent, self.count_reachgoal[id_agent]))
                    # An agent that never chose a non-stop action gets start step 0.
                    if self.status_MultiAgent[name_agent]["startStep_action_predict"] is None:
                        self.status_MultiAgent[name_agent]["startStep_action_predict"] =  0 #currentstep #

        # Episode statistics. `allReachGoal` here is still the value computed on entry.
        if allReachGoal or (currentstep >= self.maxstep):
            List_endStep = []
            List_startStep = []
            self.flowtimePredict = 0
            if rouge_agent_count:
                self.nonRogueFlowtimePredict = 0
            for id_agent in range(self.config.num_agents):
                name_agent = "agent{}".format(id_agent)
                List_endStep.append(self.status_MultiAgent[name_agent]["endStep_action_predict"])
                List_startStep.append(self.status_MultiAgent[name_agent]["startStep_action_predict"])

                # Flowtime / per-agent length are only accumulated for the "default" distribution.
                # NOTE(review): if an agent reached its goal without ever choosing a
                # non-stop action, its start step may still be None here and the
                # subtraction would fail -- confirm whether that case can occur.
                if self.config.distribution == "default":
                    self.flowtimePredict += self.status_MultiAgent[name_agent]["endStep_action_predict"] - \
                                            self.status_MultiAgent[name_agent]["startStep_action_predict"]
                    # keep track of flowtime of non rogue agents
                    if rouge_agent_count and id_agent >= rouge_agent_count:
                        self.nonRogueFlowtimePredict += self.status_MultiAgent[name_agent]["endStep_action_predict"] - \
                                            self.status_MultiAgent[name_agent]["startStep_action_predict"]

                    len_action_predict = self.status_MultiAgent[name_agent]["endStep_action_predict"] - \
                                        self.status_MultiAgent[name_agent]["startStep_action_predict"]
                    self.status_MultiAgent[name_agent]["len_action_predict"] = len_action_predict

            if self.config.distribution == "default":
                self.makespanPredict = max(List_endStep) - min(List_startStep)
            # Average over the non-rogue agents.
            # NOTE(review): divides by zero when rouge_agent_count == num_agents.
            if rouge_agent_count:
                self.nonRogueFlowtimePredict /= (self.config.num_agents - rouge_agent_count)


            # if not self.check_predictCollsion:
            # if not self.check_moveCollision:
            #     print("##################### Find collisionfreeSol #############################")
            #
            #     allReachGoal_withoutcollision = True
            # print("Makespane(predict):{} \n Flowtime(predict): {} \n ".format(self.makespanPredict , self.flowtimePredict))

        return allReachGoal, self.check_moveCollision, self.check_predictCollsion

    def count_numAgents_ReachGoal(self):
        """Return how many entries of ``self.count_reachgoal`` equal ``True``."""
        reached_flags = self.count_reachgoal
        num_reached = reached_flags.count(True)
        return num_reached

    def count_GSO_communcationRadius(self, step):
        """Call ``self.getGSO(step)`` and return the stored GSO and radius.

        The return value of ``getGSO`` is discarded; the result is the pair
        ``(self.store_GSO, self.store_communication_radius)`` read afterwards.
        """
        self.getGSO(step)
        return (self.store_GSO, self.store_communication_radius)

    def save_failure_cases(self):
        """Write the current scenario to a YAML input file in ``self.failureCases_input``.

        The file ``failureCases_ID{ID_dataset:05d}.yaml`` lists the map
        dimensions, every obstacle position, and for each agent its goal
        together with its latest ``nextState_predict``, which is written as
        the agent's ``start``.
        """
        inputfile_name = os.path.join(self.failureCases_input,'failureCases_ID{:05d}.yaml'.format(self.ID_dataset))
        print('############## failureCases in training set ID{} ###############'.format(self.ID_dataset))

        # The context manager guarantees the handle is closed even when one of
        # the tensor conversions below raises part-way through the dump.
        with open(inputfile_name, 'w') as f:
            f.write("map:\n")
            f.write("    dimensions: {}\n".format([self.size_map[0], self.size_map[1]]))
            f.write("    obstacles:\n")
            for ID_obs in range(self.numObstacle):
                obstacleIndexX = int(self.posObstacle[ID_obs][0].cpu().detach().numpy())
                obstacleIndexY = int(self.posObstacle[ID_obs][1].cpu().detach().numpy())
                f.write("    - {}\n".format([obstacleIndexX, obstacleIndexY]))

            f.write("agents:\n")
            for id_agent in range(self.config.num_agents):
                status = self.status_MultiAgent["agent{}".format(id_agent)]
                goal_np = status["goal"].cpu().detach().numpy()
                state_np = status["nextState_predict"].cpu().detach().numpy()
                goal_currentAgent = [int(goal_np[0][0]), int(goal_np[0][1])]
                currentState_currentAgent = [int(state_np[0][0]), int(state_np[0][1])]
                f.write("  - name: agent{}\n    start: {}\n    goal: {}\n".format(id_agent, currentState_currentAgent, goal_currentAgent))

    def save_success_cases(self, mode):
        """Export the finished case (input, predicted and target schedules, GSO).

        Four files named ``{mode}Cases_ID{ID_dataset:05d}`` are written into
        the ``result_AnimeDemo_dir_*`` directories of ``self.config``:

        * ``.mat`` with ``self.store_GSO`` / ``self.store_communication_radius``,
        * input ``.yaml`` (map, obstacles, each agent's start and goal),
        * predicted-schedule ``.yaml`` (statistics plus ``path_predict``),
        * target-schedule ``.yaml`` (statistics plus ``path_target``).

        Args:
            mode: ``'success'`` routes the predicted schedule to the success
                directory and writes ``succeed: 1``; any other value routes it
                to the failure directory and writes ``succeed: 0``.
        """
        case_filename = '{}Cases_ID{:05d}.yaml'.format(mode, self.ID_dataset)

        inputfile_name = os.path.join(self.config.result_AnimeDemo_dir_input, case_filename)
        if mode == 'success':
            outputfile_name = os.path.join(self.config.result_AnimeDemo_dir_predict_success, case_filename)
            checkSuccess = 1
        else:
            outputfile_name = os.path.join(self.config.result_AnimeDemo_dir_predict_failure, case_filename)
            checkSuccess = 0

        targetfile_name = os.path.join(self.config.result_AnimeDemo_dir_target, case_filename)

        gsofile_name = os.path.join(self.config.result_AnimeDemo_dir_GSO,
                                       '{}Cases_ID{:05d}.mat'.format(mode, self.ID_dataset))

        save_statistics_GSO = {'gso':self.store_GSO, 'commRadius': self.store_communication_radius}
        sio.savemat(gsofile_name, save_statistics_GSO)

        # Every file below is written inside a `with` block so the handle is
        # released even if a tensor conversion or a missing path step raises.
        with open(inputfile_name, 'w') as f:
            f.write("map:\n")
            f.write("    dimensions: {}\n".format([self.size_map[0], self.size_map[1]]))
            f.write("    obstacles:\n")
            for ID_obs in range(self.numObstacle):
                obstacleIndexX = int(self.posObstacle[ID_obs][0].cpu().detach().numpy())
                obstacleIndexY = int(self.posObstacle[ID_obs][1].cpu().detach().numpy())
                f.write("    - {}\n".format([obstacleIndexX, obstacleIndexY]))
            f.write("agents:\n")
            for id_agent in range(self.config.num_agents):
                status = self.status_MultiAgent["agent{}".format(id_agent)]
                goal_np = status["goal"].cpu().detach().numpy()
                start_np = status["start"].cpu().detach().numpy()
                goal_currentAgent = [int(goal_np[0][0]), int(goal_np[0][1])]
                start_currentAgent = [int(start_np[0][0]), int(start_np[0][1])]
                f.write("  - name: agent{}\n    start: {}\n    goal: {}\n".format(id_agent, start_currentAgent,
                                                                                  goal_currentAgent))

        with open(outputfile_name, 'w') as f_sol:
            f_sol.write("statistics:\n")
            f_sol.write("    cost: {}\n".format(self.flowtimePredict))
            f_sol.write("    makespan: {}\n".format(self.makespanPredict))
            f_sol.write("    succeed: {}\n".format(int(checkSuccess)))
            f_sol.write("schedule:\n")

            for id_agent in range(self.config.num_agents):
                path = self.status_MultiAgent["agent{}".format(id_agent)]["path_predict"]
                f_sol.write("    agent{}:\n".format(id_agent))
                # The path is looked up by consecutive step index 0..len-1.
                for step in range(len(path)):
                    pathIndexX = int(path[step][0][0].cpu().detach().numpy())
                    pathIndexY = int(path[step][0][1].cpu().detach().numpy())
                    f_sol.write("       - x: {}\n         y: {}\n         t: {}\n".format(pathIndexX,pathIndexY, step))

        with open(targetfile_name, 'w') as f_target:
            f_target.write("statistics:\n")
            f_target.write("    cost: {}\n".format(self.flowtimeTarget))
            f_target.write("    makespan: {}\n".format(self.makespanTarget))
            f_target.write("schedule:\n")

            for id_agent in range(self.config.num_agents):
                path = self.status_MultiAgent["agent{}".format(id_agent)]["path_target"]
                f_target.write("    agent{}:\n".format(id_agent))
                for step in range(len(path)):
                    pathIndexX = int(path[step][0][0].cpu().detach().numpy())
                    pathIndexY = int(path[step][0][1].cpu().detach().numpy())
                    f_target.write("       - x: {}\n         y: {}\n         t: {}\n".format(pathIndexX, pathIndexY, step))

    def createfolder_failure_cases(self):
        """Reset the failure-case folders under ``self.config.failCases_dir``.

        Sets ``self.failureCases_input`` (``<failCases_dir>/input/``) and
        ``self.dir_sol`` (``<failCases_dir>/output_ECBS/``), deletes both
        directories if they exist, and re-creates the input directory only.
        """
        # Build both paths with os.path.join so a failCases_dir without a
        # trailing separator does not produce "<dir>input/".
        self.failureCases_input = os.path.join(self.config.failCases_dir, 'input/')
        self.dir_sol = os.path.join(self.config.failCases_dir, "output_ECBS/")

        for stale_dir in (self.failureCases_input, self.dir_sol):
            if os.path.isdir(stale_dir):
                shutil.rmtree(stale_dir)

        try:
            # Only the input directory is re-created here; dir_sol is left absent.
            os.makedirs(self.failureCases_input)
        except FileExistsError:
            # Best effort, as before: the path already exists (e.g. as a non-directory).
            pass


    def checkOptimality(self, collisionFreeSol):
        """Compare the predicted makespan / flowtime against the target values.

        Returns:
            ``(findOptimalSolution, [makespanPredict, makespanTarget],
            [flowtimePredict, flowtimeTarget])`` where the flag is ``True``
            only when neither predicted metric exceeds its target and
            ``collisionFreeSol`` is truthy.
        """
        meets_makespan = self.makespanPredict <= self.makespanTarget
        meets_flowtime = self.flowtimePredict <= self.flowtimeTarget
        findOptimalSolution = bool(meets_makespan and meets_flowtime and collisionFreeSol)

        makespan_pair = [self.makespanPredict, self.makespanTarget]
        flowtime_pair = [self.flowtimePredict, self.flowtimeTarget]
        return findOptimalSolution, makespan_pair, flowtime_pair

    def draw(self, ID_dataset):
        """Render and save the target-vs-predicted paths for one case.

        Collects each agent's goal, start, path, actions and action count into
        a ``{"target": ..., "predict": ...}`` dictionary (each side also
        carrying ``makespan`` and ``flowtime``), hands it to
        ``DrawpathCombine`` and calls its ``draw(ID_dataset)`` and ``save()``.
        """
        target_view = {}
        predict_view = {}
        for agent_idx in range(self.config.num_agents):
            agent_key = "agent{}".format(agent_idx)
            agent_status = self.status_MultiAgent[agent_key]

            target_view[agent_key] = {
                "goal": agent_status["goal"],
                "start": agent_status["start"],
                "path": agent_status["path_target"],
                "action": agent_status["action_target"],
                "len_action": agent_status["len_action_target"],
            }
            predict_view[agent_key] = {
                "goal": agent_status["goal"],
                "start": agent_status["start"],
                "path": agent_status["path_predict"],
                "action": agent_status["action_predict"],
                "len_action": agent_status["len_action_predict"],
            }

        # Team-level statistics sit next to the per-agent entries.
        target_view.update({"makespan": self.makespanTarget, "flowtime": self.flowtimeTarget})
        predict_view.update({"makespan": self.makespanPredict, "flowtime": self.flowtimePredict})

        combined_status = {"target": target_view, "predict": predict_view}
        plotter = DrawpathCombine(self.config, self.channel_map, self.posObstacle, combined_status)

        plotter.draw(ID_dataset)
        plotter.save()
class DataTransformer:
    def __init__(self, config):
        """Store the configuration, derive the dataset directories and scan them.

        Directory names are derived from the map type, map size, density label
        and agent count found in ``config``; the constructor finishes by
        calling ``self.set_up()``.
        """
        self.config = config
        self.PROCESS_NUMBER = 4

        # --- problem dimensions -------------------------------------------
        self.num_agents = config.num_agents
        map_side = config.map_w
        self.size_map = [map_side, map_side]  # both entries come from map_w
        # Text after the last '.' of the density, e.g. 0.1 -> "1".
        self.label_density = str(config.map_density).split('.')[-1]
        self.AgentState = AgentState(self.num_agents)
        self.communicationRadius = 5 # communicationRadius
        self.zeroTolerance = 1e-9

        # --- action set: displacement per action index ----------------------
        self.delta = [
            [-1, 0],  # go up
            [0, -1],  # go left
            [1, 0],  # go down
            [0, 1],  # go right
            [0, 0],  # stop
        ]
        self.num_actions = 5

        # --- bookkeeping lists for the generated files ----------------------
        self.list_seqtrain_file = []
        self.list_train_file = []
        self.list_seqvalid_file = []
        self.list_validStep_file = []
        self.list_valid_file = []
        self.list_test_file = []

        self.hashids = Hashids(alphabet='01234567789abcdef', min_length=5)
        self.pathtransformer = self.pathtransformer_RelativeCoordinate

        # --- directory layout ----------------------------------------------
        self.label_setup = '{}{:02d}x{:02d}_density_p{}/{}_Agent'.format(
            config.loadmap_TYPE, self.size_map[0], self.size_map[1],
            self.label_density, self.num_agents)
        case_root = os.path.join(config.solCases_dir, self.label_setup)
        self.dirName_parent = case_root
        self.dirName_Store = os.path.join(config.dir_SaveData, self.label_setup)
        self.dirName_input = os.path.join(case_root, 'input')
        self.dirName_output = os.path.join(case_root, 'output_{}'.format(config.chosen_solver))

        self.set_up()

    def set_up(self):

        self.list_failureCases_solution = self.search_failureCases(self.dirName_output)
        self.list_failureCases_input = self.search_failureCases(self.dirName_input)
        self.nameprefix_input = self.list_failureCases_input[0].split('input/')[-1].split('ID')[0]
        self.list_failureCases_solution = sorted(self.list_failureCases_solution)
        self.len_failureCases_solution = len(self.list_failureCases_solution)


    def reset(self):

        self.task_queue = Queue()
        dirpath = self.dirName_Store
        if os.path.exists(dirpath) and os.path.isdir(dirpath):
            shutil.rmtree(dirpath)
        self.path_save_solDATA = self.dirName_Store

        try:
            # Create target Directory
            os.makedirs(self.path_save_solDATA)
            os.makedirs(os.path.join(self.path_save_solDATA, 'train'))
            os.makedirs(os.path.join(self.path_save_solDATA, 'valid'))
            os.makedirs(os.path.join(self.path_save_solDATA, 'test'))
        except FileExistsError:
            # print("Directory ", dirName, " already exists")
            pass


    def solutionTransformer(self):

        # div_train = 21000
        # div_valid = 61
        # div_test = 4500

        # div_train = 0
        # div_valid = 0
        # div_test = 1500

        div_train = self.config.div_train
        div_valid = self.config.div_valid
        div_test = self.config.div_test
        # div_train = 5
        # div_valid = 2
        # div_test = 2

        num_used_data = div_train + div_valid + div_test

        num_data_loop = min(num_used_data, self.len_failureCases_solution)
        # for id_sol in range(num_data_loop):
        for id_sol in range(self.config.id_start, num_data_loop):
            if id_sol < div_train:
                mode = "train"
                case_config = (mode, id_sol)
                self.task_queue.put(case_config)
            elif id_sol < (div_train+div_valid):
                mode = "valid"
                case_config = (mode, id_sol)
                self.task_queue.put(case_config)
            elif id_sol <= num_used_data:
                mode = "test"
                case_config = (mode, id_sol)
                self.task_queue.put(case_config)

        time.sleep(0.3)
        processes = []
        for i in range(self.PROCESS_NUMBER):
            # Run Multiprocesses
            p = Process(target=self.compute_thread, args=(str(i)))

            processes.append(p)

        [x.start() for x in processes]


    def compute_thread(self,thread_id):
        while True:
            try:
                case_config = self.task_queue.get(block=False)
                (mode, id_sol) = case_config
                print('thread {} get task:{} - {}'.format(thread_id, mode, id_sol))
                self.pipeline(case_config)

            except:
                # print('thread {} no task, exit'.format(thread_id))
                return

    def pipeline(self, case_config):
        (mode, id_sol) = case_config
        agents_schedule, agents_goal, makespan, map_data, id_case = self.load_ExpertSolution(id_sol)
        # agents_schedule, agents_goal, makespan, map_data, id_case = self.load_ExpertSolution_(id_sol)
        log_str = 'Transform_failureCases_ID_#{} in MAP_ID{}'.format(id_case[1],id_case[0])
        print('############## {} ###############'.format(log_str))
        # print(agents_schedule)
        if mode == "train" or mode == "valid":
            self.pathtransformer(map_data, agents_schedule, agents_goal, makespan + 1, id_case, mode)
        else:
            self.pathtransformer_test(map_data, agents_schedule, agents_goal, makespan + 1, id_case, mode)

        

    def load_ExpertSolution(self, ID_case):
        """Load one solved case (input YAML + solution YAML) as numpy arrays.

        The solution file is ``self.list_failureCases_solution[ID_case]``; the
        map setup, map id and case id are parsed from its name and used to
        locate the matching ``input_{setup}_IDMap{map}_IDCase{case}.yaml`` in
        ``self.dirName_input``.

        Returns:
            ``(schedule_agents, goal_allagents, makespan, map_data, id_case)``
            where ``schedule_agents`` is ``[states, actions]`` with
            ``makespan + 1`` steps, ``goal_allagents`` has one ``[x, y]`` row
            per agent and ``id_case`` is ``(id_sol_map, id_sol_case, hash_ids)``.

        Raises:
            yaml.YAMLError: if either file cannot be parsed. The error is
                printed and re-raised instead of surfacing later as an
                unrelated unbound-variable error.
        """
        name_solution_file = self.list_failureCases_solution[ID_case]
        # id_solved_case = name_solution_file.split('_ID')[-1].split('.yaml')[0]
        map_setup = name_solution_file.split('output_')[-1].split('_IDMap')[0]
        id_sol_map = name_solution_file.split('_IDMap')[-1].split('_IDCase')[0]
        id_sol_case = name_solution_file.split('_IDCase')[-1].split('_')[0]

        name_inputfile = os.path.join(self.dirName_input,
                                      'input_{}_IDMap{}_IDCase{}.yaml'.format(map_setup, id_sol_map, id_sol_case))

        with open(name_inputfile, 'r') as stream:
            try:
                data_config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                raise
        with open(name_solution_file, 'r') as stream:
            try:
                data_output = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                raise

        agentsConfig = data_config['agents']
        num_agent = len(agentsConfig)
        list_posObstacle = data_config['map']['obstacles']

        # No obstacle list in the YAML means an empty map.
        if list_posObstacle is None:
            map_data = np.zeros(self.size_map, dtype=np.int64)
        else:
            map_data = self.setup_map(list_posObstacle)

        schedule = data_output['schedule']
        makespan = data_output['statistics']['makespan']

        goal_allagents = np.zeros([num_agent, 2])
        schedule_agentsState = np.zeros([makespan + 1, num_agent, 2])
        schedule_agentsActions = np.zeros([makespan + 1, num_agent, self.num_actions])
        schedule_agents = [schedule_agentsState, schedule_agentsActions]
        # NOTE(review): sized by the configured agent count while the loop below
        # runs over the agents found in the file -- confirm they always match.
        hash_ids = np.zeros(self.num_agents)
        for id_agent in range(num_agent):
            goalX = agentsConfig[id_agent]['goal'][0]
            goalY = agentsConfig[id_agent]['goal'][1]
            goal_allagents[id_agent][:] = [goalX, goalY]

            schedule_agents = self.obtainSchedule(id_agent, schedule, schedule_agents, goal_allagents, makespan + 1)

            # Per-agent id: SHA-256 of "map_case_agent", reduced modulo 10**5.
            str_id = '{}_{}_{}'.format(id_sol_map,id_sol_case,id_agent)
            int_id = int(hashlib.sha256(str_id.encode('utf-8')).hexdigest(), 16) % (10 ** 5)
            # hash_ids[id_agent]=np.divide(int_id,10**5)
            hash_ids[id_agent] = int_id

        # print(id_sol_map, id_sol_case, hash_ids)
        return schedule_agents, goal_allagents, makespan, map_data, (id_sol_map, id_sol_case, hash_ids)

    def load_ExpertSolution_(self, ID_case):
        """Variant of the expert-solution loader for ``failureCases_ID*`` files.

        The case id is parsed from the solution file name (text between the
        last ``'_ID'`` and ``'.yaml'``); the map id is fixed to ``'0'``. The
        matching input file is ``failureCases_ID{case}.yaml`` in
        ``self.dirName_input``.

        Returns:
            ``(schedule_agents, goal_allagents, makespan, map_data, id_case)``
            where ``schedule_agents`` is ``[states, actions]`` with
            ``makespan + 1`` steps and ``id_case`` is
            ``(id_sol_map, id_sol_case, hash_ids)``.

        Raises:
            yaml.YAMLError: if either file cannot be parsed. The error is
                printed and re-raised instead of surfacing later as an
                unrelated unbound-variable error.
        """
        name_solution_file = self.list_failureCases_solution[ID_case]
        id_sol_case = name_solution_file.split('_ID')[-1].split('.yaml')[0]

        id_sol_map = '0'

        name_inputfile = os.path.join(self.dirName_input,
                                      'failureCases_ID{}.yaml'.format(id_sol_case))

        with open(name_inputfile, 'r') as stream:
            try:
                data_config = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                raise
        with open(name_solution_file, 'r') as stream:
            try:
                data_output = yaml.safe_load(stream)
            except yaml.YAMLError as exc:
                print(exc)
                raise

        agentsConfig = data_config['agents']
        num_agent = len(agentsConfig)
        list_posObstacle = data_config['map']['obstacles']

        # No obstacle list in the YAML means an empty map.
        if list_posObstacle is None:
            map_data = np.zeros(self.size_map, dtype=np.int64)
        else:
            map_data = self.setup_map(list_posObstacle)

        schedule = data_output['schedule']
        makespan = data_output['statistics']['makespan']

        goal_allagents = np.zeros([num_agent, 2])
        schedule_agentsState = np.zeros([makespan + 1, num_agent, 2])
        schedule_agentsActions = np.zeros([makespan + 1, num_agent, self.num_actions])
        schedule_agents = [schedule_agentsState, schedule_agentsActions]
        # NOTE(review): sized by the configured agent count while the loop below
        # runs over the agents found in the file -- confirm they always match.
        hash_ids = np.zeros(self.num_agents)
        for id_agent in range(num_agent):
            goalX = agentsConfig[id_agent]['goal'][0]
            goalY = agentsConfig[id_agent]['goal'][1]
            goal_allagents[id_agent][:] = [goalX, goalY]

            schedule_agents = self.obtainSchedule(id_agent, schedule, schedule_agents, goal_allagents, makespan + 1)

            # Per-agent id: SHA-256 of "map_case_agent", reduced modulo 10**5.
            str_id = '{}_{}_{}'.format(id_sol_map, id_sol_case, id_agent)
            int_id = int(hashlib.sha256(str_id.encode('utf-8')).hexdigest(), 16) % (10 ** 5)
            # hash_ids[id_agent]=np.divide(int_id,10**5)
            hash_ids[id_agent] = int_id
        print(schedule_agents)
        # print(id_sol_map, id_sol_case, hash_ids)
        return schedule_agents, goal_allagents, makespan, map_data, (id_sol_map, id_sol_case, hash_ids)

    def obtainSchedule(self, id_agent, agentplan, schedule_agents, goal_allagents, teamMakeSpan):

        name_agent = "agent{}".format(id_agent)
        [schedule_agentsState, schedule_agentsActions] = schedule_agents
        
        planCurrentAgent = agentplan[name_agent]
        pathLengthCurrentAgent = len(planCurrentAgent)

        actionKeyListAgent = []

        for step in range(teamMakeSpan):
            if step < pathLengthCurrentAgent:
                currentX = planCurrentAgent[step]['x']
                currentY = planCurrentAgent[step]['y']
            else:
                currentX = goal_allagents[id_agent][0]
                currentY = goal_allagents[id_agent][1]
                
            schedule_agentsState[step][id_agent][:] = [currentX, currentY]
            # up left down right stop
            actionVectorTarget = [0, 0, 0, 0, 0]

            # map action with respect to the change of position of agent
            if step < (pathLengthCurrentAgent - 1):
                nextX = planCurrentAgent[step + 1]['x']
                nextY = planCurrentAgent[step + 1]['y']
                # actionCurrent = [nextX - currentX, nextY - currentY]

            elif step >= (pathLengthCurrentAgent - 1):
                nextX = goal_allagents[id_agent][0]
                nextY = goal_allagents[id_agent][1]

            actionCurrent = [nextX - currentX, nextY - currentY]


            actionKeyIndex = self.delta.index(actionCurrent)
            actionKeyListAgent.append(actionKeyIndex)

            actionVectorTarget[actionKeyIndex] = 1
            schedule_agentsActions[step][id_agent][:] = actionVectorTarget


        return [schedule_agentsState,schedule_agentsActions]

    def setup_map(self, list_posObstacle):
        num_obstacle = len(list_posObstacle)
        map_data = np.zeros(self.size_map)
        for ID_obs in range(num_obstacle):
            obstacleIndexX = list_posObstacle[ID_obs][0]
            obstacleIndexY = list_posObstacle[ID_obs][1]
            map_data[obstacleIndexX][obstacleIndexY] = 1

        return map_data



    def pathtransformer_RelativeCoordinate(self, map_data, agents_schedule, agents_goal, makespan, ID_case, mode):
        # input: start and goal position,
        # output: a set of file,
        #         each file consist of state (map. goal, state) and target (action for current state)
        [schedule_agentsState, schedule_agentsActions] = agents_schedule
        save_PairredData = {}
        # print(ID_case)
        # compute AdjacencyMatrix
        GSO, communicationRadius = self.computeAdjacencyMatrix(schedule_agentsState, self.communicationRadius)

        # transform into relative Coordinate, loop "makespan" times
        self.AgentState.setmap(map_data)
        input_seq_tensor = self.AgentState.toSeqInputTensor(agents_goal, schedule_agentsState, makespan)

        list_input = input_seq_tensor.cpu().detach().numpy()
        save_PairredData.update({'map': map_data, 'goal': agents_goal, 'inputState': schedule_agentsState,
                                 'inputTensor': list_input, 'target': schedule_agentsActions,
                                 'GSO': GSO,'makespan':makespan, 'HashIDs':ID_case[2]})
        # print(save_PairredData)
        self.save(mode, save_PairredData, ID_case, makespan)
        print("Save as  {}set_#{} from MAP ID_{}.".format(mode, ID_case[1], ID_case[0]))

    def pathtransformer_test(self, map_data, agents_schedule, agents_goal, makespan, ID_case, mode):
        """Store one solved case keeping only its first agent state.

        Unlike the relative-coordinate transformer, no input tensor and no
        GSO are computed: the record holds the map, the goals, the first
        entry of the state schedule, the full action schedule, the makespan
        and ``ID_case[2]`` (stored as ``'HashIDs'``).

        Args:
            map_data: map of the case.
            agents_schedule: pair ``[agent states, agent actions]``.
            agents_goal: goals of the agents.
            makespan: stored in the record and used in the file name.
            ID_case: indexable with at least three entries.
            mode: data split name; forwarded to ``self.save``.
        """
        states, actions = agents_schedule
        record = {
            'map': map_data,
            'goal': agents_goal,
            'inputState': states[0],
            'target': actions,
            'makespan': makespan,
            'HashIDs': ID_case[2],
        }
        self.save(mode, record, ID_case, makespan)
        print("Save as  {}set_#{} from MAP ID_{}.".format(mode, ID_case[1], ID_case[0]))

    def save(self, mode, save_PairredData, ID_case, makespan):
        """Write one record as a ``.mat`` file below ``path_save_solDATA/<mode>``.

        Args:
            mode: data split name; used as sub-directory and file-name prefix.
            save_PairredData: dict handed to ``scipy.io.savemat``.
            ID_case: 3-element sequence; the first two entries (map id, case
                id) go into the file name, the third is ignored here.
            makespan: appended to the file name as ``MP<makespan>``.
        """
        map_id, case_id, _ = ID_case
        mat_name = '{}_IDMap{}_IDCase{}_MP{}.mat'.format(mode, map_id, case_id, makespan)
        sio.savemat(os.path.join(self.path_save_solDATA, mode, mat_name), save_PairredData)

    def record_pathdata(self, mode, ID_case, makespan):
        """Register a stored case in the in-memory file lists of its split.

        Args:
            mode: ``"train"``, ``"validStep"``, ``"valid"`` or ``"test"``;
                any other value leaves all lists untouched.
            ID_case: sequence whose first two entries are the map id and the
                case id.  Additional entries (such as the hash id carried by
                the 3-element ``ID_case`` that ``save`` expects) are ignored,
                so the same tuple can be passed to both methods.
            makespan: number of steps of the case; for the step-wise lists one
                row per step in ``range(makespan)`` is added.
        """
        # Slice instead of strict 2-tuple unpacking so that both
        # (map, case) and (map, case, hash) are accepted.
        id_sol_map, id_sol_case = ID_case[:2]
        data_name_mat = '{}_IDMap{}_IDCase{}_MP{}.mat'.format(mode, id_sol_map, id_sol_case, makespan)

        sequence_row = [data_name_mat, makespan, 0]
        if mode == "train":
            self.list_seqtrain_file.append(sequence_row)
            self.list_train_file.extend(
                [data_name_mat, step, 0] for step in range(makespan))
        elif mode == 'validStep':
            self.list_seqvalid_file.append(sequence_row)
            self.list_validStep_file.extend(
                [data_name_mat, step, 0] for step in range(makespan))
        elif mode == "valid":
            self.list_valid_file.append(sequence_row)
        elif mode == "test":
            self.list_test_file.append(sequence_row)

    def save_filepath(self):
        """Dump every recorded file list to its CSV index in ``path_save_solDATA``.

        Six files are (over)written, in this order: ``trainseq_filename.csv``,
        ``train_filename.csv``, ``validseq_filename.csv``,
        ``validStep_filename.csv``, ``valid_filename.csv`` and
        ``test_filename.csv``.  The sequence-level training list is also
        echoed to stdout.
        """
        out_dir = self.path_save_solDATA

        # (file-name template, split label, attribute holding the rows,
        #  whether the rows are echoed to stdout before being written)
        index_files = (
            ('{}seq_filename.csv', 'train', 'list_seqtrain_file', True),
            ('{}_filename.csv', 'train', 'list_train_file', False),
            ('{}seq_filename.csv', 'valid', 'list_seqvalid_file', False),
            ('{}_filename.csv', 'validStep', 'list_validStep_file', False),
            ('{}_filename.csv', 'valid', 'list_valid_file', False),
            ('{}_filename.csv', 'test', 'list_test_file', False),
        )

        for template, label, rows_attr, echo in index_files:
            csv_path = os.path.join(out_dir, template.format(label))
            with open(csv_path, "w", newline="") as handle:
                out = csv.writer(handle)
                rows = getattr(self, rows_attr)
                if echo:
                    print("\n train hello --", rows)
                out.writerows(rows)

    def search_failureCases(self, dir):
        """Collect the paths of all target files found anywhere below *dir*.

        Directories are visited in sorted order of their path; which files
        count as targets is decided by ``self.is_target_file``.

        Args:
            dir: directory to scan recursively.

        Returns:
            List of file paths (each joined with its containing directory).

        Raises:
            AssertionError: if *dir* is not an existing directory.
        """
        assert os.path.isdir(dir), '%s is not a valid directory' % dir

        return [
            os.path.join(folder, name)
            for folder, _, names in sorted(os.walk(dir))
            for name in names
            if self.is_target_file(name)
        ]

    def is_target_file(self, filename):
        """Return True when *filename* ends with a recognised extension (``.yaml``)."""
        # str.endswith accepts a tuple of suffixes and matches any of them.
        recognised_suffixes = ('.yaml',)
        return filename.endswith(recognised_suffixes)

    def computeAdjacencyMatrix(self, pos, CommunicationRadius, connected=True):
        """Build degree-normalised adjacency matrices (GSOs) for a trajectory.

        Two agents are linked at a time step when their pairwise distance
        (``scipy`` ``pdist`` default metric) is strictly below the
        communication radius; self-loops are removed.  When ``connected`` is
        true, the radius is enlarged by 10% at a time until the graph of every
        time step is connected.  Because the radius may have grown along the
        way, all matrices are then rebuilt with the final, common radius and
        normalised as ``D^(-1/2) A D^(-1/2)``.

        Args:
            pos: array indexed as ``[time step, agent, coordinate]``.
            CommunicationRadius: initial distance threshold.
            connected: if true, grow the radius until every time step yields
                a connected graph (checked with ``graph.isConnected``).

        Returns:
            Tuple ``(W, threshold)``: ``W`` is a float array of shape
            ``(time steps, agents, agents)``; ``threshold`` is the radius that
            was finally used for all time steps.

        Raises:
            ValueError: if a graph is disconnected while the radius is not
                positive, because growing it by 10% could never terminate.
        """
        len_TimeSteps = pos.shape[0]  # length of timesteps
        nNodes = pos.shape[1]  # number of nodes
        threshold = CommunicationRadius

        # The distances are needed by both passes below, so compute them once.
        all_distances = [squareform(pdist(pos[t])) for t in range(len_TimeSteps)]

        def _adjacency(distances, radius):
            # Binary adjacency without self-loops, always as floats.
            adjacency = (distances < radius).astype(float)
            np.fill_diagonal(adjacency, 0.0)
            return adjacency

        # First pass: find a radius for which every time step is connected.
        if connected:
            for distances in all_distances:
                adjacency = _adjacency(distances, threshold)
                while not graph.isConnected(adjacency):
                    if threshold <= 0:
                        raise ValueError(
                            'CommunicationRadius must be positive to reach a '
                            'connected graph, got {}'.format(threshold))
                    threshold = threshold * 1.1  # Increase 10%
                    adjacency = _adjacency(distances, threshold)

        # Second pass: same radius for all time steps, then normalise by degree.
        W = np.zeros([len_TimeSteps, nNodes, nNodes])
        for t, distances in enumerate(all_distances):
            adjacency = _adjacency(distances, threshold)
            deg = np.sum(adjacency, axis=1)  # nNodes (degree vector)
            # Isolated nodes have degree 0; give them a zero scaling factor
            # instead of 1/0 = inf, which would turn their rows/columns into NaN.
            inv_sqrt_deg = np.zeros_like(deg)
            has_neighbours = deg > 0
            inv_sqrt_deg[has_neighbours] = np.sqrt(1. / deg[has_neighbours])
            Deg = np.diag(inv_sqrt_deg)
            W[t] = Deg @ adjacency @ Deg

        return W, threshold
class CreateDataset(data.Dataset):
    """Dataset over solved path-planning cases stored as ``.mat`` files.

    The ``mode`` passed to the constructor selects which sub-directory of the
    experiment folder is scanned, how many cases are kept, and which loader
    turns a file into the tensors returned by ``__getitem__``.
    """

    # Upper bound on random re-draws when perturbing start positions.
    MAX_RESAMPLE_TRIES = 10

    def __init__(self, config, mode):
        """Index the data files of the requested split.

        Args:
            config: configuration object; the attributes read here are
                ``map_type``, ``map_w``, ``map_h``, ``map_density``,
                ``num_agents``, ``data_root`` and, depending on ``mode``,
                ``failCases_dir``, ``num_test_trainingSet``, ``num_validset``
                or ``num_testset``.
            mode: one of ``"train"``, ``"test_trainingSet"``, ``"valid"``,
                ``"validStep"`` or ``"test"``.

        Raises:
            ValueError: if ``mode`` is not one of the supported values.
        """
        self.config = config
        self.datapath_exp = '{}{:02d}x{:02d}_density_p{}/{}_Agent/'.format(
            self.config.map_type, self.config.map_w, self.config.map_h,
            self.config.map_density, self.config.num_agents)

        self.dirName = os.path.join(self.config.data_root, self.datapath_exp)
        self.AgentState = AgentState(self.config.num_agents)

        if mode == "train":
            self.dir_data = os.path.join(self.dirName, 'train')
            self.search_files = self.search_target_files_withStep
            self.data_paths, self.id_stepdata = self.update_data_path_trainingset(
                self.dir_data)
            self.load_data = self.load_train_data
        elif mode == "test_trainingSet":
            # Random subset of whole training cases, evaluated from their
            # initial positions.
            self.dir_data = os.path.join(self.dirName, 'train')
            data_paths, id_stepdata = self.search_target_files(self.dir_data)
            paths_total = list(zip(data_paths, id_stepdata))
            random.shuffle(paths_total)
            data_paths, id_stepdata = self._unzip_pairs(paths_total)
            limit = self.config.num_test_trainingSet
            self.data_paths = data_paths[:limit]
            self.id_stepdata = id_stepdata[:limit]
            self.load_data = self.load_data_during_training
        elif mode == "valid":
            self.dir_data = os.path.join(self.dirName, 'valid')
            self.data_paths, self.id_stepdata = self.obtain_data_path_validset(
                self.dir_data, self.config.num_validset)
            self.load_data = self.load_data_during_training
        elif mode == "validStep":
            self.dir_data = os.path.join(self.dirName, 'valid')
            self.data_paths, self.id_stepdata = self.search_valid_files_withStep(
                self.dir_data, self.config.num_validset)
            self.load_data = self.load_train_data
        elif mode == "test":
            self.dir_data = os.path.join(self.dirName, mode)
            self.data_paths, self.id_stepdata = self.obtain_data_path_validset(
                self.dir_data, self.config.num_testset)
            self.load_data = self.load_test_data
        else:
            # Previously an unknown mode surfaced later as an AttributeError
            # on ``self.data_paths``.
            raise ValueError('Unsupported dataset mode: {!r}'.format(mode))

        self.data_size = len(self.data_paths)

    def __getitem__(self, index):
        """Return ``(input, target, id_step, GSO, map_tensor)`` for *index*.

        The index wraps around the number of indexed entries.
        """
        position = index % self.data_size
        path = self.data_paths[position]
        id_step = int(self.id_stepdata[position])
        input_tensor, target, GSO, map_tensor = self.load_data(path, id_step)
        return input_tensor, target, id_step, GSO, map_tensor

    @staticmethod
    def _unzip_pairs(pairs):
        """Split a list of 2-tuples into two tuples; empty input gives two empty tuples."""
        if not pairs:
            return (), ()
        first, second = zip(*pairs)
        return first, second

    @staticmethod
    def _cell_occupied(positions, cell):
        """Return True when some row of *positions* equals *cell* exactly.

        ``cell in positions`` on a NumPy array compares element-wise and is
        true as soon as a single coordinate matches, which is not an
        occupancy test; this helper compares whole rows.
        """
        positions = np.asarray(positions)
        return bool(np.any(np.all(positions == np.asarray(cell), axis=1)))

    def update_data_path_trainingset(self, dir_data):
        """Gather and shuffle the step-wise training entries.

        Entries come from *dir_data* and from ``config.failCases_dir``
        (cases added by the online expert), both scanned with
        ``self.search_files``.

        Returns:
            Tuple ``(data_paths, step_ids)`` of equally long tuples.
        """
        # only used for training set and online expert - training purpose
        data_paths_total = []
        step_paths_total = []
        # load common training set
        data_paths, step_paths = self.search_files(dir_data)
        data_paths_total.extend(data_paths)
        step_paths_total.extend(step_paths)
        # load training set from online expert based on failCases
        data_paths_failcases, step_paths_failcases = self.search_files(
            self.config.failCases_dir)
        data_paths_total.extend(data_paths_failcases)
        step_paths_total.extend(step_paths_failcases)

        paths_total = list(zip(data_paths_total, step_paths_total))
        random.shuffle(paths_total)
        return self._unzip_pairs(paths_total)

    def obtain_data_path_validset(self, dir_data, case_limit):
        """Return the first *case_limit* cases of *dir_data* in sorted order.

        Returns:
            Tuple ``(data_paths, makespans)`` of equally long tuples.
        """
        data_paths, id_stepdata = self.search_target_files(dir_data)
        paths_bundle = sorted(zip(data_paths, id_stepdata))
        data_paths, id_stepdata = self._unzip_pairs(paths_bundle)
        return data_paths[:case_limit], id_stepdata[:case_limit]

    def load_train_data(self, path, id_step):
        """Load the sample of a single time step from a stored case.

        Args:
            path: ``.mat`` file holding ``'map'``, ``'inputTensor'``,
                ``'target'`` and ``'GSO'``.
            id_step: index of the time step to extract.

        Returns:
            ``(step_input_tensor, step_target, step_input_GSO, tensor_map)``.
        """
        data_contents = sio.loadmat(path)
        map_channel = data_contents['map']  # W x H

        # Shapes below follow the original author's notes.
        input_tensor = data_contents['inputTensor']  # step x num_agent x 3 x 11 x 11
        target_sequence = data_contents['target']  # step x num_agent x 5
        input_GSO_sequence = data_contents['GSO']  # step x num_agent x num_agent

        tensor_map = torch.from_numpy(map_channel).float()

        step_input_tensor = torch.from_numpy(input_tensor[id_step][:]).float()
        step_input_GSO = torch.from_numpy(
            input_GSO_sequence[id_step, :, :]).float()
        step_target = torch.from_numpy(target_sequence[id_step, :, :]).long()

        return step_input_tensor, step_target, step_input_GSO, tensor_map

    def _initial_state_sample(self, map_channel, goal_allagents,
                              initial_positions, target_sequence):
        """Build the ``(input, target, GSO placeholder, map)`` tuple for one case."""
        self.AgentState.setmap(map_channel)
        step_input_tensor = self.AgentState.stackinfo(goal_allagents,
                                                      initial_positions)

        step_target = torch.from_numpy(target_sequence).long()
        # from step x num_agent x action (5) to  id_agent x step x action(5)
        step_target = step_target.permute(1, 0, 2)
        step_input_rs = step_input_tensor.squeeze(0)
        step_target_rs = step_target.squeeze(0)

        tensor_map = torch.from_numpy(map_channel).float()
        # No GSO is stored for these samples; a dummy tensor keeps the
        # return signature uniform across loaders.
        GSO_none = torch.zeros(1)
        return step_input_rs, step_target_rs, GSO_none, tensor_map

    def load_data_during_training(self, path, _):
        """Load a case for evaluation during training (initial positions only).

        Used when testing on the training set and on the validation set; the
        second argument is ignored.
        """
        data_contents = sio.loadmat(path)
        map_channel = data_contents['map']  # W x H
        goal_allagents = data_contents['goal']  # num_agent x 2

        # from step x num_agent x 2 to initial pos: num_agent x 2
        input_sequence = data_contents['inputState'][0]
        target_sequence = data_contents['target']  # step x num_agent x 5

        return self._initial_state_sample(map_channel, goal_allagents,
                                          input_sequence, target_sequence)

    def _redistribute_agents(self, map_channel, input_sequence):
        """Move ``config.distribution_num`` randomly chosen agents in place.

        With ``config.distribution == "LLQ"`` an agent is moved to a random
        free cell whose row and column indices are below half the map size;
        with ``"edge"`` it is projected onto row 0 or column 0 when that cell
        is free.  Agents for which no free cell is found stay where they are.
        """
        bots_to_change = int(self.config.distribution_num)
        already_moved = []
        for _ in range(bots_to_change):
            idx = np.random.choice(range(len(input_sequence)))
            attempts = 0
            # Prefer an agent that has not been moved yet (bounded retries).
            while (idx in already_moved) and attempts < self.MAX_RESAMPLE_TRIES:
                idx = np.random.choice(range(len(input_sequence)))
                attempts += 1
            already_moved.append(idx)

            if self.config.distribution == "LLQ":
                print("RUNNING LLQ")
                half_rows = int(map_channel.shape[0] / 2)
                half_cols = int(map_channel.shape[1] / 2)
                # One initial draw plus a bounded number of re-draws.
                for _attempt in range(self.MAX_RESAMPLE_TRIES + 1):
                    choice = [
                        np.random.choice(range(half_rows)),
                        np.random.choice(range(half_cols))
                    ]
                    blocked = map_channel[choice[0]][choice[1]] == 1
                    if not blocked and not self._cell_occupied(
                            input_sequence, choice):
                        input_sequence[idx] = np.array(choice)
                        break
            if self.config.distribution == "edge":
                print("RUNNING EDGE")
                edge = np.random.choice([0, 1])
                row = int(input_sequence[idx][0])
                col = int(input_sequence[idx][1])
                if (edge == 0 and map_channel[0][col] != 1
                        and not self._cell_occupied(input_sequence, [0, col])):
                    input_sequence[idx][0] = 0
                if (edge == 1 and map_channel[row][0] != 1
                        and not self._cell_occupied(input_sequence, [row, 0])):
                    input_sequence[idx][1] = 0

    def load_test_data(self, path, _):
        """Load a test case (initial positions only), optionally perturbed.

        When ``config.distribution_num`` is non-zero, that many agents have
        their start position changed according to ``config.distribution``
        before the input tensor is built.  The second argument is ignored.
        """
        data_contents = sio.loadmat(path)
        map_channel = data_contents['map']  # W x H
        goal_allagents = data_contents['goal']  # num_agent x 2

        input_sequence = data_contents['inputState']  # num_agent x 2
        target_sequence = data_contents['target']  # step x num_agent x 5

        self._redistribute_agents(map_channel, input_sequence)

        return self._initial_state_sample(map_channel, goal_allagents,
                                          input_sequence, target_sequence)

    def _iter_target_files(self, dir):
        """Yield ``(path, makespan)`` for each target file found below *dir*.

        The makespan is parsed from the ``_MP<makespan>.mat`` file-name suffix.

        Raises:
            AssertionError: on first iteration, if *dir* is not a directory.
        """
        assert os.path.isdir(dir), '%s is not a valid directory' % dir

        for root, _, fnames in sorted(os.walk(dir)):
            for fname in fnames:
                if self.is_target_file(fname):
                    makespan = int(fname.split('_MP')[-1].split('.mat')[0])
                    yield os.path.join(root, fname), makespan

    def search_target_files(self, dir):
        """List every case below *dir* once, paired with its makespan.

        Returns:
            ``(paths, makespans)`` as two equally long lists.
        """
        list_path = []
        list_path_stepdata = []
        for path, makespan in self._iter_target_files(dir):
            list_path.append(path)
            list_path_stepdata.append(makespan)
        return list_path, list_path_stepdata

    def search_target_files_withStep(self, dir):
        """List every case below *dir* once per time step.

        Returns:
            ``(paths, steps)``: each path is repeated ``makespan`` times and
            paired with the step indices ``0 .. makespan - 1``.
        """
        list_path = []
        list_path_stepdata = []
        for path, makespan in self._iter_target_files(dir):
            for step in range(makespan):
                list_path.append(path)
                list_path_stepdata.append(step)
        return list_path, list_path_stepdata

    def search_valid_files_withStep(self, dir, case_limit):
        """Like ``search_target_files_withStep`` but for at most *case_limit* cases.

        Returns:
            ``(paths, steps)`` as two equally long lists.
        """
        list_path = []
        list_path_stepdata = []
        count_num_cases = 0
        for path, makespan in self._iter_target_files(dir):
            # Stop once exactly ``case_limit`` cases were collected (the
            # former ``<=`` comparison let one extra case through).
            if count_num_cases >= case_limit:
                break
            for step in range(makespan):
                list_path.append(path)
                list_path_stepdata.append(step)
            count_num_cases += 1
        return list_path, list_path_stepdata

    def is_target_file(self, filename):
        """Return True when *filename* ends with a data extension (``.mat``)."""
        DATA_EXTENSIONS = ['.mat']
        return any(
            filename.endswith(extension) for extension in DATA_EXTENSIONS)

    def __len__(self):
        """Number of indexed entries."""
        return self.data_size