Example #1
    def test_serializer_returns_invalid_on_empty_prompts(self):
        data = {"prompt": "", "temperature": 1}

        serializer = TextAlgorithmPromptSerializer(data=data)
        valid = serializer.is_valid()

        self.assertFalse(valid)
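The assertion relies on DRF's CharField rejecting blank input by default. A minimal sketch of what the serializer under test might look like, assuming standard DRF fields; only the field names prompt, temperature, top_k, and model_name appear in these examples, and the defaults are guesses based on the expected cache key in Example #3:

from rest_framework import serializers


class TextAlgorithmPromptSerializer(serializers.Serializer):
    # hypothetical sketch, not the actual implementation: CharField rejects
    # blank strings by default, which is exactly what this test asserts
    prompt = serializers.CharField()
    temperature = serializers.FloatField(default=0.7, required=False)
    top_k = serializers.IntegerField(default=40, required=False)
    model_name = serializers.CharField(default="gpt2-medium", required=False)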
Example #2
    async def _receive_new_request(self, data):
        """
        this function is kind of overwhelming (sorry), but what i'm doing is
        putting a few caches because running inference even with
        p100 gpus is still slow for transformer architectures

        the first cache checks if this request has been made before
        with the specific settings of word length, temp, etc

        the second cache sees if this request is already running,
        in most circumstances, that's overengineering, but some requests
        can take over ten seconds to run, so the worst case would be if
        it duplicated this request
        """
        serializer = TextAlgorithmPromptSerializer(data=data)

        # don't use the usual raise_exception=True pattern here; all
        # exceptions need to be handled explicitly when using Channels
        valid = serializer.is_valid()

        if not valid:
            return await self.return_invalid_data_prompt(data)

        prompt_serialized = serializer.validated_data

        cache_key = get_cache_key_for_text_algo_parameter(**prompt_serialized)
        cached_results = await get_cached_results(cache_key)
        if cached_results:
            return await self.send_serialized_data(cached_results)

        # technically a race can occur if separate users try the exact same
        # phrase within the 180-second window, but if that happens the servers
        # are probably being crushed by too many requests anyway, RIP
        duplicate_request = await check_if_cache_key_for_parameters_is_running(
            cache_key)
        if duplicate_request:
            print("Duplicate request already running.")
            return

        # if it doesn't exist, set a state flag saying this request is running,
        # so the result is automatically broadcast back if the frontend makes a
        # duplicate request
        await set_request_flag_that_request_is_running_in_cache(cache_key)

        # switch auth styles: passing the api_key in the payload is a little more
        # interoperable, since aiohttp doesn't pass headers the same way as the
        # requests library and custom middleware for one endpoint isn't worth it.
        # the ML endpoints are protected via an api_key to prevent abuse
        prompt_serialized["api_key"] = settings.ML_SERVICE_ENDPOINT_API_KEY

        # pass the websocket_uuid for the ML endpoints to know how to communicate
        prompt_serialized["websocket_uuid"] = self.group_name
        prompt_serialized["cache_key"] = cache_key

        model_name = prompt_serialized["model_name"]
        url = get_api_endpoint_from_model_name(model_name)

        await self.post_to_microservice(url, prompt_serialized)
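The two flag helpers above aren't shown on this page; a rough sketch of how they could work against Django's cache, assuming the 180-second window mentioned in the comment is the flag's TTL (the names and signatures follow the calls above, the bodies are assumptions):

from django.core.cache import cache


async def check_if_cache_key_for_parameters_is_running(cache_key):
    # True if another consumer already kicked off inference for these parameters
    return bool(cache.get(f"{cache_key}_running"))


async def set_request_flag_that_request_is_running_in_cache(cache_key):
    # the flag expires on its own, so a crashed worker can't block retries forever
    cache.set(f"{cache_key}_running", True, timeout=180)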
Example #3
    def test_cache_key_with_serializer(self):
        post_message = {"prompt": "Hello"}

        serializer = TextAlgorithmPromptSerializer(data=post_message)
        serializer.is_valid(raise_exception=False)

        cache_key = get_cache_key_for_text_algo_parameter(**serializer.validated_data)

        expected_cache_key = "writeup_8b1a9953c4611296a827abf8c47804d7_5_40_0.7_10_0_english_gpt2-medium"
        self.assertEqual(cache_key, expected_cache_key)
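The hash in the expected key, 8b1a9953c4611296a827abf8c47804d7, is md5("Hello"), so the helper presumably hashes the prompt and appends the serializer's default generation parameters. A rough sketch under that assumption (the real parameter order and prefix handling may differ):

import hashlib


def get_cache_key_for_text_algo_parameter(prompt, **params):
    # hash the prompt so arbitrary-length text still yields a fixed-size key,
    # then append the remaining parameters in the order the serializer emits them
    prompt_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()
    suffix = "_".join(str(value) for value in params.values())
    return f"writeup_{prompt_hash}_{suffix}"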
Example #4
def run():
    data = {
        "text":
        "Today I Saw A Village, what if i had a lot of text" + COMMON_P,
        "temperature": 1,
        "top_k": 20,
    }

    start = time.time()
    serializer = TextAlgorithmPromptSerializer(data=data)
    serializer.is_valid()
    end = time.time()

    difference = end - start

    print(difference)
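Side note: for a micro-benchmark like this, time.perf_counter() has higher resolution than time.time(); a variant of the timing block, keeping the same serializer call:

    start = time.perf_counter()
    serializer = TextAlgorithmPromptSerializer(data=data)
    serializer.is_valid()
    print(time.perf_counter() - start)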