def test_get_meetings_3_robots(env3_robots):
    """Valid meetings in the 3-robot environment at selected timestamps."""
    metadata = env3_robots.get_env_metadata()
    lengths = [len(cycle) for cycle in metadata['cycles']]
    meetings = metadata['meetings']

    assert get_valid_meetings(0, meetings, lengths) == []
    assert get_valid_meetings(1, meetings, lengths) == [
        Meeting(r1=0, r2=1, first_time=1)
    ]
    # Times 5 and 23 land on the same phase of the cycles, so both
    # produce the same single meeting.
    expected = [Meeting(r1=1, r2=2, first_time=5)]
    assert get_valid_meetings(23, meetings, lengths) == expected
    assert get_valid_meetings(5, meetings, lengths) == expected
def test_get_meetings_4_robots(env4_robots):
    """Valid meetings in the 4-robot environment at selected timestamps."""
    metadata = env4_robots.get_env_metadata()
    lengths = [len(cycle) for cycle in metadata['cycles']]
    meetings = metadata['meetings']

    assert get_valid_meetings(0, meetings, lengths) == []

    assert get_valid_meetings(2, meetings, lengths) == [
        Meeting(r1=2, r2=3, first_time=2),
        Meeting(r1=0, r2=1, first_time=2),
    ]

    assert get_valid_meetings(5, meetings, lengths) == []

    # Compare as a set where we don't rely on the result order.
    assert set(get_valid_meetings(8, meetings, lengths)) == {
        Meeting(r1=1, r2=2, first_time=4),
        Meeting(r1=0, r2=1, first_time=2),
        Meeting(r1=0, r2=2, first_time=4),
    }

    assert get_valid_meetings(9, meetings, lengths) == []

    assert set(get_valid_meetings(14, meetings, lengths)) == {
        Meeting(r1=2, r2=3, first_time=2),
        Meeting(r1=0, r2=1, first_time=2),
    }
    def give_reward(self, state: State, interpreted_action: List[Exchange],
                    meetings: List[Meeting], cycles_lengths: List[int],
                    max_memory: int) -> int:
        """Score ``interpreted_action`` taken from ``state``.

        Returns ``settings.REWARD_FOR_INVALID_ACTION`` as soon as any single
        exchange is invalid, ``settings.END_ENVIRONMENT`` when the resulting
        state is terminal (per ``check_if_done``), and otherwise the
        accumulated reward for the applied exchanges.

        :param state: current state; never mutated — a deep copy is updated
        :param interpreted_action: decoded exchanges (r1, r2, quant)
        :param meetings: all meetings between robot pairs
        :param cycles_lengths: length of each robot's cycle
        :param max_memory: per-robot data capacity
        """
        state_cpy = State(deepcopy(state.robots_data), state.time,
                          deepcopy(state.positions))
        exchanges = [x.quant for x in interpreted_action]
        valid_meetings = get_valid_meetings(state.time, meetings,
                                            cycles_lengths)

        # A "do nothing" action (no exchanges at all, or every quantity zero)
        # gets a reduced penalty.  The emptiness check also guards the
        # max()/min() calls, which raise ValueError on an empty sequence.
        if not exchanges or max(exchanges) == 0 == min(exchanges):
            return int(settings.REWARD_FOR_INVALID_ACTION * 2 / 3)
        reward = 0
        r1_r2_meeting_list = [[x.r1, x.r2] for x in valid_meetings]
        for exchange in interpreted_action:
            r1 = exchange.r1
            r2 = exchange.r2
            if [r1, r2] not in r1_r2_meeting_list and exchange.quant != 0:
                # Non-zero exchange between robots that do not meet right now.
                return settings.REWARD_FOR_INVALID_ACTION
            elif r1 == 0 and exchange.quant > 0:
                # Positive quant pushes data into robot 0's partner side —
                # presumably robot 0 is a sink/base station; TODO confirm.
                return settings.REWARD_FOR_INVALID_ACTION
            else:
                if r1 == 0 and exchange.quant < 0 and state_cpy.robots_data[
                        r2] + exchange.quant >= 0:
                    # Delivery into robot 0: reward the delivered amount.
                    reward += -exchange.quant
                    state_cpy.robots_data[r1] -= exchange.quant
                    state_cpy.robots_data[r2] += exchange.quant
                    continue
                if exchange.quant < 0 and (
                        state_cpy.robots_data[r1] - exchange.quant > max_memory
                        or state_cpy.robots_data[r2] + exchange.quant < 0):
                    # r2 -> r1 transfer would overflow r1 or underflow r2.
                    return settings.REWARD_FOR_INVALID_ACTION
                elif exchange.quant > 0 and (
                        state_cpy.robots_data[r2] + exchange.quant > max_memory
                        or state_cpy.robots_data[r1] - exchange.quant < 0):
                    # r1 -> r2 transfer would overflow r2 or underflow r1.
                    return settings.REWARD_FOR_INVALID_ACTION
                elif r1 == 0 and exchange.quant < 0:
                    # NOTE(review): unreachable — the capacity-ok case exits
                    # via the `continue` above and the underflow case via the
                    # first guard.  Kept for safety; consider removing.
                    state_cpy.robots_data[r1] -= exchange.quant
                    state_cpy.robots_data[r2] += exchange.quant
                    reward += abs(exchange.quant) * 1000
                elif exchange.quant != 0:
                    # Ordinary robot-to-robot transfer: flat bonus.
                    state_cpy.robots_data[r1] -= exchange.quant
                    state_cpy.robots_data[r2] += exchange.quant
                    reward += 1000

        if check_if_done(state_cpy):
            return settings.END_ENVIRONMENT
        return reward
    def give_reward(self, state: State, interpreted_action: List[Exchange],
                    meetings: List[Meeting], cycles_lengths: List[int],
                    max_memory: int):
        """Accumulate a per-exchange reward for ``interpreted_action``.

        Every exchange contributes its own (possibly negative) term and the
        sum is returned.  A working copy of the state is updated as valid
        exchanges are applied, so later exchanges see the effect of earlier
        ones within the same action.
        """
        working_state = State(deepcopy(state.robots_data), state.time,
                              deepcopy(state.positions))
        total = 0
        current_meetings = get_valid_meetings(state.time, meetings,
                                              cycles_lengths)
        meeting_pairs = [[m.r1, m.r2] for m in current_meetings]

        for move in interpreted_action:
            a, b = move.r1, move.r2
            amount = move.quant

            # Non-zero exchange outside any currently-valid meeting.
            if amount != 0 and [a, b] not in meeting_pairs:
                total += settings.REWARD_FOR_INVALID_MEETING
                continue
            # Positive quantities may not target robot 0's pairing this way.
            if a == 0 and amount > 0:
                total += settings.REWARD_FOR_INVALID_MEETING
                continue

            data = working_state.robots_data  # alias: mutations hit the copy
            if a == 0 and amount < 0 and data[b] + amount >= 0:
                # Delivery into robot 0 within capacity: small positive term.
                total += 1
                data[a] -= amount
                data[b] += amount
            elif amount < 0 and (data[a] - amount > max_memory
                                 or data[b] + amount < 0):
                total += settings.REWARD_FOR_INVALID_TRANSFER
            elif amount > 0 and (data[b] + amount > max_memory
                                 or data[a] - amount < 0):
                total += settings.REWARD_FOR_INVALID_TRANSFER
            elif a == 0 and amount < 0:
                data[a] -= amount
                data[b] += amount
                total += 1
            elif amount != 0:
                data[a] -= amount
                data[b] += amount
                total += -1
        return total
    def select_valid_action(self) -> List[int]:
        """
        Draw a random action that is valid for the current state.
        :return: list of encoded per-pair exchange values
        """
        current = get_state_from_observation(self.state)

        # Per robot pair, the encodable values: [0, data(r1)] encodes a
        # transfer from r1, and the offset band past max_memory encodes a
        # transfer from r2.
        candidate_ranges = []
        pair_count = self.__num_robots * (self.__num_robots - 1) // 2
        for pair_idx in range(pair_count):
            r1, r2 = self.__action_mapping[pair_idx]
            forward = list(range(current.robots_data[r1] + 1))
            backward = list(range(
                self.__max_memory + 1,
                self.__max_memory + 1 + current.robots_data[r2]))
            candidate_ranges.append(forward + backward)

        product = LazyCartesianProduct(candidate_ranges)
        action = product.get_nth_element(
            np.random.randint(0, product.max_size))

        active_pairs = {
            (m.r1, m.r2)
            for m in get_valid_meetings(current.time, self.__meetings,
                                        self.__cycles_lengths)
        }
        # Zero out pairs that are not meeting now; force robot-0 pairs into
        # the "send to robot 0" band when the partner has data to send.
        for idx in range(len(self.__action_mapping)):
            r1, r2 = self.__action_mapping[idx]
            if (r1, r2) not in active_pairs:
                action[idx] = 0
            elif r1 == 0:
                if current.robots_data[r2] == 0:
                    action[idx] = 0
                else:
                    action[idx] = np.random.randint(
                        self.__max_memory + 1, self.__max_memory + 1 +
                        current.robots_data[r2])
        return action