Ejemplo n.º 1
0
def pin(pin_socket, pusher_cache, kvs, status, function_cache, runtimes,
        exec_counts, user_library, local, batching):
    """Handle one pin request read from ``pin_socket``.

    Parses a ``PinFunction`` message, retrieves the named function, records
    it in the local cache and metadata structures, and replies on the socket
    obtained from ``pusher_cache`` for the requester's pin-accept address.

    Args:
        pin_socket: Socket whose ``recv()`` yields a serialized
            ``PinFunction`` message.
        pusher_cache: Cache whose ``get(address)`` returns the socket used
            to send the response.
        kvs: KVS handle, passed through to ``utils.retrieve_function``.
        status: Object exposing a ``functions`` list of pinned names.
        function_cache: Mapping from function name to the loaded function.
        runtimes: Mapping from function name to a list; reset to ``[]`` for
            the pinned function.
        exec_counts: Mapping from function name to a counter; reset to ``0``
            for the pinned function.
        user_library: Passed through to ``utils.retrieve_function``.
        local: Truthy when running in local mode; the one-function limit is
            only enforced when this is falsy.
        batching: Current batching flag; returned unchanged when the pin
            request is rejected.

    Returns:
        ``pin_msg.batching`` when the function was pinned, otherwise the
        incoming ``batching`` value.

    Raises:
        RuntimeError: If the request enables batching while more than one
            function is pinned. Note that this is raised after the caches
            and ``status.functions`` have been updated and before any
            response is sent to the requester.
    """
    serialized = pin_socket.recv()
    pin_msg = PinFunction()
    pin_msg.ParseFromString(serialized)

    sckt = pusher_cache.get(
        sutils.get_pin_accept_port(pin_msg.response_address))
    name = pin_msg.name

    # We currently only allow one pinned function per container in non-local
    # mode, so reject a request for a different function than the one that
    # is already cached.
    if not local and len(function_cache) > 0 and name not in function_cache:
        sckt.send(sutils.error.SerializeToString())
        return batching

    func = utils.retrieve_function(name, kvs, user_library)

    # The function must exist -- because otherwise the DAG couldn't be
    # registered -- so we keep trying to retrieve it.
    while not func:
        func = utils.retrieve_function(name, kvs, user_library)

    if name not in function_cache:
        print(
            f"writing function cache for entry {name}, it's a type {type(func)}"
        )
        import cloudpickle
        if isinstance(func, bytes):
            # NOTE(security): this unpickles bytes fetched from the KVS;
            # unpickling executes arbitrary code, so the KVS contents must
            # be trusted.
            func = cloudpickle.loads(func)
        function_cache[name] = func

    if name not in status.functions:
        status.functions.append(name)

    # Add metadata tracking for the newly pinned functions.
    runtimes[name] = []
    exec_counts[name] = 0
    logging.info('Adding function %s to my local pinned functions.', name)

    if pin_msg.batching and len(status.functions) > 1:
        raise RuntimeError(
            'There is more than one pinned function (we are'
            ' operating in local mode), and the function'
            ' attempting to be pinned has batching enabled. This'
            ' is not allowed -- you can only use batching in'
            ' cluster mode or in local mode with one function.')

    sckt.send(sutils.ok_resp)

    return pin_msg.batching
Ejemplo n.º 2
0
def pin(pin_socket, pusher_cache, kvs, status, function_cache, runtimes,
        exec_counts, user_library, local):
    """Handle one pin request read from ``pin_socket``.

    Parses a ``PinFunction`` message, replies to the requester on the socket
    obtained from ``pusher_cache``, and -- if the request is accepted --
    retrieves the named function and records it in the local cache and
    metadata structures.

    Args:
        pin_socket: Socket whose ``recv()`` yields a serialized
            ``PinFunction`` message.
        pusher_cache: Cache whose ``get(address)`` returns the socket used
            to send the response.
        kvs: KVS handle, passed through to ``utils.retrieve_function``.
        status: Object exposing a ``running`` flag and a ``functions`` list
            of pinned names.
        function_cache: Mapping from function name to the loaded function.
        runtimes: Mapping from function name to a list; reset to ``[]`` for
            the pinned function.
        exec_counts: Mapping from function name to a counter; reset to ``0``
            for the pinned function.
        user_library: Passed through to ``utils.retrieve_function``.
        local: Truthy when running in local mode; the rejection checks are
            only applied when this is falsy.

    Returns:
        None. The outcome is reported to the requester over the socket.
    """
    serialized = pin_socket.recv()
    pin_msg = PinFunction()
    pin_msg.ParseFromString(serialized)

    sckt = pusher_cache.get(
        sutils.get_pin_accept_port(pin_msg.response_address))
    name = pin_msg.name

    # We currently only allow one pinned function per container in non-local
    # mode; requests are also rejected when status.running is falsy.
    if (not local and ((len(function_cache) > 0 and name not in function_cache)
                       or not status.running)):
        sckt.send(sutils.error.SerializeToString())
        return

    # The acknowledgement is sent before the function is retrieved.
    sckt.send(sutils.ok_resp)

    func = utils.retrieve_function(name, kvs, user_library)

    # The function must exist -- because otherwise the DAG couldn't be
    # registered -- so we keep trying to retrieve it.
    while not func:
        func = utils.retrieve_function(name, kvs, user_library)

    if name not in function_cache:
        function_cache[name] = func

    if name not in status.functions:
        status.functions.append(name)

    # Add metadata tracking for the newly pinned functions.
    runtimes[name] = []
    exec_counts[name] = 0
    logging.info('Adding function %s to my local pinned functions.', name)
Ejemplo n.º 3
0
    def test_create_gpu_dag(self):
        '''
        This test creates a single-function DAG whose function has its gpu
        flag set, with one unpinned GPU executor registered in the policy
        engine. It checks the created metadata, the DAG stored in the KVS,
        the response to the user, and the pin message that was sent.
        '''
        # Create a simple single-function DAG, flag its function as needing
        # a GPU, and add it to the inbound socket.
        dag_name = 'dag'
        fn = 'fn'

        dag = create_linear_dag([None], [fn], self.kvs_client, dag_name)
        dag.functions[0].gpu = True
        self.socket.inbox.append(dag.SerializeToString())

        dags = {}
        call_frequency = {}

        # Register a single unpinned GPU executor with the policy engine.
        address_set = {(self.ip, 1)}
        self.policy.unpinned_gpu_executors.update(address_set)

        # Prepopulate the pin_accept socket with a success message.
        self.pin_socket.inbox.append(sutils.ok_resp)

        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Test that the correct metadata was created.
        self.assertIn(dag_name, dags)
        created, dag_source = dags[dag_name]
        self.assertEqual(created, dag)
        self.assertEqual(len(dag_source), 1)
        self.assertEqual(list(dag_source)[0], fn)
        self.assertIn(fn, call_frequency)
        self.assertEqual(call_frequency[fn], 0)

        # Test that the DAG is stored in the KVS correctly.
        result = self.kvs_client.get(dag_name)[dag_name]
        created = Dag()
        created.ParseFromString(result.reveal())
        self.assertEqual(created, dag)

        # Test that the correct response was returned to the user. This must
        # be assertEqual: assertTrue(x, 1) treats 1 as the failure message
        # and never compares the length.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(response.success)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 1)
        messages = self.pusher_cache.socket.outbox
        function_set = {fn}
        for message in messages:
            pin_msg = PinFunction()
            pin_msg.ParseFromString(message)
            self.assertEqual(pin_msg.response_address, self.ip)
            self.assertIn(pin_msg.name, function_set)
            # Discard so that a duplicate pin message for the same function
            # fails the membership check above.
            function_set.discard(pin_msg.name)

        self.assertEqual(len(function_set), 0)

        for address in address_set:
            self.assertIn(get_pin_address(*address),
                          self.pusher_cache.addresses)

        # Test that the policy engine has the correct metadata stored.
        # NOTE(review): this checks the CPU executor pool, which this test
        # never populates; presumably unpinned_gpu_executors was intended --
        # confirm against the policy implementation.
        self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
        self.assertIn(fn, self.policy.function_locations)

        self.assertEqual(len(self.policy.function_locations[fn]), 1)
Ejemplo n.º 4
0
    def test_create_dag_insufficient_resources(self):
        '''
        This test attempts to create a DAG even though there are not enough
        free executors in the system. It checks that a pin message is attempted
        to be sent, we run out of resources, and then the request is rejected.
        We check that the metadata is properly restored back to its original
        state.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine, but set the number of
        # executors to fewer than needed.
        address_set = {(self.ip, 1)}
        self.policy.unpinned_cpu_executors.update(address_set)

        # Prepopulate the pin_accept socket with sufficient success messages.
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Attempt to create the DAG.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Check that an error was returned to the user.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox[0])
        self.assertFalse(response.success)
        self.assertEqual(response.error, NO_RESOURCES)

        # Test that exactly one pin message and one unpin message were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox

        # Checks for the pin message.
        pin_msg = PinFunction()
        pin_msg.ParseFromString(messages[0])
        self.assertEqual(pin_msg.response_address, self.ip)
        self.assertEqual(pin_msg.name, source)

        # Checks for the unpin message.
        self.assertEqual(messages[1], source)

        # The set holds exactly one executor address. Take it directly:
        # random.sample() no longer accepts a set (deprecated in Python 3.9,
        # TypeError from 3.11).
        address = next(iter(address_set))
        addresses = self.pusher_cache.addresses
        self.assertEqual(get_pin_address(*address), addresses[0])
        self.assertEqual(get_unpin_address(*address), addresses[1])

        # Check the policy engine's metadata after the rejected request.
        self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)

        # Check that no additional metadata was created or sent.
        self.assertEqual(len(call_frequency), 0)
        self.assertEqual(len(dags), 0)
Ejemplo n.º 5
0
    def test_create_dag(self):
        '''
        This test creates a new DAG, checking that the correct pin messages are
        sent to executors and that it is persisted in the KVS correctly. It
        also checks that the server metadata was updated as expected.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine: one unpinned CPU
        # executor per function in the DAG.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_cpu_executors.update(address_set)

        # Prepopulate the pin_accept socket with sufficient success messages.
        self.pin_socket.inbox.append(sutils.ok_resp)
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Call the DAG creation method.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Test that the correct metadata was created.
        self.assertIn(dag_name, dags)
        created, dag_source = dags[dag_name]
        self.assertEqual(created, dag)
        self.assertEqual(len(dag_source), 1)
        self.assertEqual(list(dag_source)[0], source)
        self.assertIn(source, call_frequency)
        self.assertIn(sink, call_frequency)
        self.assertEqual(call_frequency[source], 0)
        self.assertEqual(call_frequency[sink], 0)

        # Test that the DAG is stored in the KVS correctly.
        result = self.kvs_client.get(dag_name)[dag_name]
        created = Dag()
        created.ParseFromString(result.reveal())
        self.assertEqual(created, dag)

        # Test that the correct response was returned to the user. This must
        # be assertEqual: assertTrue(x, 1) treats 1 as the failure message
        # and never compares the length.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(response.success)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox
        function_set = {source, sink}
        for message in messages:
            pin_msg = PinFunction()
            pin_msg.ParseFromString(message)
            self.assertEqual(pin_msg.response_address, self.ip)
            self.assertIn(pin_msg.name, function_set)
            # Discard so that a duplicate pin message for the same function
            # fails the membership check above.
            function_set.discard(pin_msg.name)

        self.assertEqual(len(function_set), 0)

        for address in address_set:
            self.assertIn(get_pin_address(*address),
                          self.pusher_cache.addresses)

        # Test that the policy engine has the correct metadata stored.
        self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
        self.assertIn(source, self.policy.function_locations)
        self.assertIn(sink, self.policy.function_locations)

        self.assertEqual(len(self.policy.function_locations[source]), 1)
        self.assertEqual(len(self.policy.function_locations[sink]), 1)