示例#1
0
    def register(self, function, name):
        '''
        Uploads a function or class to the system under the given name and
        hands back a handle for it. The handle can be called like a regular
        Python function, which returns a Cloudburst Future. If a class is
        passed, it is expected to have a run method, which is what is invoked
        at runtime.

        The object is serialized with serializer.dump, sent over the
        function-creation socket inside a Function message, and the reply is
        parsed as a GenericResponse.

        function: The function object (or class) that is being registered.
        name: A unique name under which the function is stored in the system.

        Returns a CloudburstFunction bound to this client and its KVS client
        if the reply reports success; raises a RuntimeError otherwise.
        '''
        request = Function()
        request.name = name
        request.body = serializer.dump(function)
        self.func_create_sock.send(request.SerializeToString())

        reply = GenericResponse()
        reply.ParseFromString(self.func_create_sock.recv())

        # Handle the failure first so the success path reads straight through.
        if not reply.success:
            raise RuntimeError(
                f'Unexpected error while registering function: {reply}.')

        return CloudburstFunction(name, self, self.kvs_client)
示例#2
0
    def register_dag(self, name, functions, connections):
        '''
        Registers a new DAG with the system. This operation will fail if any of
        the functions provided cannot be identified in the system.

        name: A unique name for this DAG.
        functions: A list of names of functions to be included in this DAG.
        connections: A list of ordered pairs of function names that represent
        the edges in this DAG.

        Returns a (success, error) pair. If a function name is missing from
        the list returned by _get_func_list, nothing is sent and the result is
        (False, None); otherwise both values are taken from the
        GenericResponse received on the DAG-creation socket.
        '''

        flist = self._get_func_list()
        for fname in functions:
            if fname not in flist:
                # Adjacent literals are joined by the parser, so the message
                # carries no stray indentation (a backslash continuation
                # inside a single literal would embed it in the log line).
                logging.info(
                    'Function %s not registered. Please register before '
                    'including it in a DAG.', fname)
                return False, None

        dag = Dag()
        dag.name = name
        dag.functions.extend(functions)
        for source, sink in connections:
            conn = dag.connections.add()
            conn.source = source
            conn.sink = sink

        self.dag_create_sock.send(dag.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_create_sock.recv())

        return r.success, r.error
示例#3
0
    def test_succesful_pin(self):
        '''
        Runs a pin request that is expected to succeed, then verifies the
        acknowledgement that was sent out and the metadata that was created
        for execution and reporting.
        '''
        def func(_, x):
            return x + 1

        # Create a new function in the KVS.
        fname = 'incr'
        create_function(func, self.kvs_client, fname)

        # Queue a serialized pin request on the inbound socket.
        request = PinFunction(name=fname, response_address=self.ip)
        self.socket.inbox.append(request.SerializeToString())

        # Run the pin operation.
        pin(self.socket, self.pusher_cache, self.kvs_client, self.status,
            self.pinned_functions, self.runtimes, self.exec_counts,
            self.user_library, False, False)

        # Exactly one message should have gone out, reporting success.
        sent = self.pusher_cache.socket.outbox
        self.assertEqual(len(sent), 1)
        ack = GenericResponse()
        ack.ParseFromString(sent[0])
        self.assertTrue(ack.success)

        # The pinned copy must behave like the local function, and every
        # piece of metadata must now know about the function.
        self.assertEqual(func('', 1), self.pinned_functions[fname]('', 1))
        for registry in (self.pinned_functions, self.runtimes,
                         self.exec_counts, self.status.functions):
            self.assertTrue(fname in registry)
示例#4
0
    def test_function_call_no_resources(self):
        '''
        With every executor removed from the policy, a function call must be
        answered with a NO_RESOURCES error and nothing may be forwarded
        through the pusher cache.
        '''
        # Remove all executors from the policy's metadata.
        for pool in (self.policy.thread_statuses,
                     self.policy.unpinned_executors):
            pool.clear()

        # Build a call with a single argument and queue it on the socket.
        request = FunctionCall()
        request.name = 'function'
        request.request_id = 12
        serializer.dump(2, request.arguments.values.add())
        self.socket.inbox.append(request.SerializeToString())

        # Run the scheduling policy.
        call_function(self.socket, self.pusher_cache, self.policy)

        # One reply to the user, nothing sent via the pusher cache.
        replies = self.socket.outbox
        self.assertEqual(len(replies), 1)
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)

        # The reply must be a failure carrying the NO_RESOURCES error.
        reply = GenericResponse()
        reply.ParseFromString(replies[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, NO_RESOURCES)
示例#5
0
    def test_exec_function_with_error(self):
        '''
        Executes a function whose body raises an exception and checks that
        the error is reported back to the user through the KVS.
        '''
        e_msg = 'This is a broken_function!'
        fname = 'func'

        def func(_, x):
            raise ValueError(e_msg)

        # Put the function into the KVS and queue a call to it.
        create_function(func, self.kvs_client, fname)
        request = self._create_function_call(fname, [''], NORMAL)
        self.socket.inbox.append(request.SerializeToString())

        # Run the function call.
        exec_function(self.socket, self.kvs_client, self.user_library, {}, {})

        # The value stored under the response key must be a LWWPairLattice.
        key = self.response_key
        stored = self.kvs_client.get(key)[key]
        self.assertEqual(type(stored), LWWPairLattice)

        # Once loaded it is a tuple whose first element contains the message
        # of the raised exception.
        payload = serializer.load_lattice(stored)
        self.assertEqual(type(payload), tuple)
        self.assertTrue(e_msg in payload[0])

        # The second element is a serialized GenericResponse describing the
        # failure.
        reply = GenericResponse()
        reply.ParseFromString(payload[1])
        self.assertEqual(reply.success, False)
        self.assertEqual(reply.error, EXECUTION_ERROR)
示例#6
0
    def test_exec_function_nonexistent(self):
        '''
        Calls a function that was never created and verifies that an error,
        rather than a result, is stored in the KVS for the user.
        '''
        # Queue a call to a function that does not exist.
        request = self._create_function_call('bad_func', [1], NORMAL)
        self.socket.inbox.append(request.SerializeToString())

        # Try to execute it.
        exec_function(self.socket, self.kvs_client, self.user_library, {}, {})

        # No messages should have been sent on the socket.
        self.assertEqual(len(self.socket.outbox), 0)

        # The value stored under the response key must be a LWWPairLattice.
        key = self.response_key
        stored = self.kvs_client.get(key)[key]
        self.assertEqual(type(stored), LWWPairLattice)

        # Once loaded it is a tuple tagged with the 'ERROR' marker.
        payload = serializer.load_lattice(stored)
        self.assertEqual(type(payload), tuple)
        self.assertEqual(payload[0], 'ERROR')

        # The second element is a serialized GenericResponse describing the
        # failure.
        reply = GenericResponse()
        reply.ParseFromString(payload[1])
        self.assertEqual(reply.success, False)
        self.assertEqual(reply.error, FUNC_NOT_FOUND)
示例#7
0
    def test_create_gpu_dag_no_resources(self):
        '''
        Submits a DAG whose only listed function is flagged as a GPU function
        without registering any executors in this test, and checks that the
        request is refused with NO_RESOURCES and leaves no metadata behind.
        '''
        dag_name = 'dag'

        # Build the DAG, mark its function as GPU, and queue it on the socket.
        dag = create_linear_dag([None], ['fn'], self.kvs_client, dag_name)
        dag.functions[0].gpu = True
        self.socket.inbox.append(dag.SerializeToString())

        call_frequency = {}
        dags = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # The user must receive exactly one failure response.
        replies = self.socket.outbox
        self.assertEqual(len(replies), 1)
        reply = GenericResponse()
        reply.ParseFromString(replies[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, NO_RESOURCES)

        # No pin messages may have been sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)

        # The policy engine's metadata must be empty.
        self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)

        # The caller-side metadata must be empty as well.
        self.assertEqual(len(call_frequency), 0)
        self.assertEqual(len(dags), 0)
示例#8
0
    def call_dag(self, dname, arg_map, direct_response=False,
                 consistency=NORMAL, output_key=None, client_id=None):
        '''
        Issues a new request to execute the DAG.

        dname: The name of the DAG to execute.
        arg_map: A map from function names to lists of arguments for each of
        the functions in the DAG. A value that is not a list is treated as a
        single argument.
        direct_response: If True, the response will be synchronously received
        by the client; otherwise, the result will be stored in the KVS.
        consistency: The consistency mode to use with this function: either
        NORMAL or MULTI.
        output_key: The KVS key in which to store the result of this DAG.
        client_id: An optional ID associated with an individual client across
        requests; this is used for causal metadata.

        Returns None if the request is reported as unsuccessful. Otherwise,
        with direct_response, returns the deserialized result read from the
        response socket (or None if that read fails with EAGAIN); without
        direct_response, returns a CloudburstFuture for the response ID.
        Any other zmq.ZMQError raised while reading the result is propagated.
        '''
        dc = DagCall()
        dc.name = dname
        dc.consistency = consistency

        if output_key:
            dc.output_key = output_key

        if client_id:
            dc.client_id = client_id

        for fname in arg_map:
            fname_args = arg_map[fname]
            # Accept a bare value as shorthand for a one-element list.
            if not isinstance(fname_args, list):
                fname_args = [fname_args]
            args = [serializer.dump(arg, serialize=False) for arg in
                    fname_args]
            dc.function_args[fname].values.extend(args)

        if direct_response:
            dc.response_address = self.response_address

        self.dag_call_sock.send(dc.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_call_sock.recv())

        # A rejected request has nothing to wait for: return right away
        # instead of blocking on the response socket.
        if not r.success:
            return None

        if not direct_response:
            return CloudburstFuture(r.response_id, self.kvs_client,
                                    serializer)

        try:
            result = self.response_sock.recv()
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                return None
            # Re-raise with the original traceback intact.
            raise

        return serializer.load(result)
示例#9
0
    def pin_function(self, dag_name, function_ref):
        '''
        Tries to pin the given function onto one of the currently unpinned
        executors, picking candidates at random until one accepts.

        dag_name: The name of the DAG the function belongs to; successful
        pins are recorded under this name in self.pending_dags.
        function_ref: The function reference to pin; only its name is read.

        Returns True once an executor accepts the pin, and False if there are
        no unpinned executors to begin with or every candidate has rejected
        the request.
        '''
        # If there are no executors left to choose from, then we return False,
        # indicating that we ran out of resources to use.
        if not self.unpinned_executors:
            return False

        if dag_name not in self.pending_dags:
            self.pending_dags[dag_name] = []

        # Make a copy of the set of executors, so that iterating over the
        # candidates is independent of updates to the system's metadata.
        candidates = set(self.unpinned_executors)

        # Construct a PinFunction message to be sent to executors.
        pin_msg = PinFunction()
        pin_msg.name = function_ref.name
        pin_msg.response_address = self.ip

        serialized = pin_msg.SerializeToString()

        while candidates:
            # Pick a random executor from the set of candidates and attempt to
            # pin this function there. sample() needs a sequence (sets are
            # rejected by newer Python versions), hence the tuple copy.
            node, tid = sys_random.sample(tuple(candidates), 1)[0]

            sckt = self.pusher_cache.get(get_pin_address(node, tid))
            sckt.send(serialized)

            response = GenericResponse()
            try:
                response.ParseFromString(self.pin_accept_socket.recv())
            except zmq.ZMQError:
                logging.error('Pin operation to %s:%d timed out. Retrying.' %
                              (node, tid))
                continue

            # Do not use this executor either way: If it rejected, it has
            # something else pinned, and if it accepted, it has pinned what we
            # just asked it to pin.
            # In local mode, allow executors to have multiple functions
            # pinned.
            if not self.local:
                self.unpinned_executors.discard((node, tid))
                candidates.discard((node, tid))

            if response.success:
                # The pin operation succeeded, so we record where the function
                # was pinned and report success to the caller.
                self.pending_dags[dag_name].append((function_ref.name, (node,
                                                                        tid)))
                return True

            # The pin operation was rejected; try another candidate.
            logging.error('Node %s:%d rejected pin for %s. Retrying.'
                          % (node, tid, function_ref.name))

        # Every candidate rejected the request: we are out of resources.
        return False
示例#10
0
    def delete_dag(self, dname):
        '''
        Asks the system to remove the DAG with the given name.

        dname: The name of the DAG to delete.

        Returns the (success, error) pair taken from the GenericResponse
        received on the DAG-deletion socket.
        '''
        sock = self.dag_delete_sock
        sock.send_string(dname)

        reply = GenericResponse()
        reply.ParseFromString(sock.recv())
        return reply.success, reply.error
示例#11
0
    def test_delete_dag(self):
        '''
        Deletes a DAG that is already present in the metadata and verifies
        that the expected unpin messages go out to the executors and that the
        metadata is cleaned up.
        '''
        source, sink = 'source', 'sink'
        dag_name = 'dag'

        # Create a simple two-function DAG and add it to the system metadata.
        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        dags = {dag.name: (dag, {source})}
        call_frequency = {source: 100, sink: 100}

        # Tell the policy engine where each function is located.
        source_location = (self.ip, 1)
        sink_location = (self.ip, 2)
        self.policy.function_locations[source] = {source_location}
        self.policy.function_locations[sink] = {sink_location}

        # Request the deletion.
        self.socket.inbox.append(dag.name)
        delete_dag(self.socket, dags, self.policy, call_frequency)

        # One unpin message per function, in source-then-sink order.
        sent = self.pusher_cache.socket.outbox
        self.assertEqual(len(sent), 2)
        for message, expected in zip(sent, (source, sink)):
            self.assertEqual(message, expected)

        # Each message must have gone to the matching unpin address.
        targets = self.pusher_cache.addresses
        self.assertEqual(len(targets), 2)
        for target, location in zip(targets,
                                    (source_location, sink_location)):
            self.assertEqual(target, get_unpin_address(*location))

        # The server metadata must have been emptied.
        self.assertEqual(len(dags), 0)
        self.assertEqual(len(call_frequency), 0)

        # The user must have received a single success response.
        self.assertEqual(len(self.socket.outbox), 1)
        reply = GenericResponse()
        reply.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(reply.success)
示例#12
0
    def register_dag(self, name, functions, connections):
        '''
        Registers a new DAG with the system. This operation will fail if any of
        the functions provided cannot be identified in the system.

        name: A unique name for this DAG.
        functions: A list of functions to be included in this DAG. Each entry
        is either a function name or a tuple whose first element is the
        function name and whose second element is an iterable of invalid
        results; tuple entries are marked with the MULTIEXEC type and each
        invalid result is stored in serialized form.
        connections: A list of ordered pairs of function names that represent
        the edges in this DAG.

        Returns the (success, error) pair taken from the GenericResponse
        received on the DAG-creation socket. Raises a RuntimeError if any
        function name is missing from the list returned by _get_func_list.
        '''

        flist = self._get_func_list()
        for function in functions:
            fname = function[0] if isinstance(function, tuple) else function

            if fname not in flist:
                raise RuntimeError(
                    f'Function {fname} not registered. Please register before '
                    'including it in a DAG.')

        dag = Dag()
        dag.name = name
        for function in functions:
            ref = dag.functions.add()

            # Use the same isinstance check as the validation loop above, so
            # that tuple subclasses are handled the same way in both places.
            if isinstance(function, tuple):
                fname = function[0]
                invalids = function[1]
                ref.type = MULTIEXEC
            else:
                fname = function
                invalids = []

            ref.name = fname
            for invalid in invalids:
                ref.invalid_results.append(serializer.dump(invalid))

        for source, sink in connections:
            conn = dag.connections.add()
            conn.source = source
            conn.sink = sink

        self.dag_create_sock.send(dag.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_create_sock.recv())

        return r.success, r.error
示例#13
0
    def exec_func(self, name, args):
        '''
        Sends a FunctionCall for the named function over the function-call
        socket and returns the response_id of the GenericResponse received in
        reply.

        name: The name of the function to call.
        args: An iterable of arguments; each one is written into the call
        with serializer.dump.

        The call is tagged with the current value of self.rid, which is
        incremented after the reply has been parsed.
        '''
        request = FunctionCall()
        request.name = name
        request.request_id = self.rid

        values = request.arguments.values
        for arg in args:
            serializer.dump(arg, values.add())

        sock = self.func_call_sock
        sock.send(request.SerializeToString())

        reply = GenericResponse()
        reply.ParseFromString(sock.recv())

        self.rid += 1
        return reply.response_id
示例#14
0
    def test_call_function_with_refs(self):
        '''
        Creates a scenario where the policy should deterministically pick the
        same executor to run a request on: There is one reference, and it's
        cached only on the node we create in this test.
        '''
        # Add a new executor for which we will construct cached references.
        ip_address = '192.168.0.1'
        new_key = (ip_address, 2)
        self.policy.unpinned_executors.add(new_key)

        # Create a new reference and add its metadata.
        ref_name = 'reference'
        self.policy.key_locations[ref_name] = [ip_address]

        # Create a function call that asks for this reference.
        call = FunctionCall()
        call.name = 'function'
        call.request_id = 12
        val = call.arguments.values.add()
        serializer.dump(CloudburstReference(ref_name, True), val)
        # SerializeToString takes no positional arguments; call it the same
        # way as everywhere else in these tests.
        self.socket.inbox.append(call.SerializeToString())

        # Execute the scheduling policy.
        call_function(self.socket, self.pusher_cache, self.policy)

        # Check that the correct number of messages were sent.
        self.assertEqual(len(self.socket.outbox), 1)
        self.assertEqual(len(self.pusher_cache.socket.outbox), 1)

        # Extract and deserialize the messages.
        response = GenericResponse()
        forwarded = FunctionCall()
        response.ParseFromString(self.socket.outbox[0])
        forwarded.ParseFromString(self.pusher_cache.socket.outbox[0])

        # The reply to the user must succeed and must refer to the same
        # response key as the forwarded call.
        self.assertTrue(response.success)
        self.assertEqual(response.response_id, forwarded.response_key)
        self.assertEqual(forwarded.name, call.name)
        self.assertEqual(forwarded.request_id, call.request_id)

        # Makes sure that the correct executor was chosen.
        self.assertEqual(len(self.pusher_cache.addresses), 1)
        self.assertEqual(self.pusher_cache.addresses[0],
                         utils.get_exec_address(*new_key))
示例#15
0
    def test_occupied_pin(self):
        '''
        Tries to pin a function while another function is already recorded as
        pinned, and expects the request to be rejected without touching the
        existing metadata.
        '''
        def func(_, x):
            return x + 1

        # Create a new function in the KVS.
        fname = 'incr'
        create_function(func, self.kvs_client, fname)

        # Queue a serialized pin request on the inbound socket.
        request = PinFunction(name=fname, response_address=self.ip)
        self.socket.inbox.append(request.SerializeToString())

        # Record an already pinned function, so that the request is rejected.
        occupant = 'square'
        self.pinned_functions[occupant] = lambda _, x: x * x
        self.runtimes[occupant] = []
        self.exec_counts[occupant] = []
        self.status.functions.append(occupant)

        # Run the pin operation.
        pin(self.socket, self.pusher_cache, self.kvs_client, self.status,
            self.pinned_functions, self.runtimes, self.exec_counts,
            self.user_library, False, False)

        # Exactly one message should have gone out, reporting failure.
        sent = self.pusher_cache.socket.outbox
        self.assertEqual(len(sent), 1)
        reply = GenericResponse()
        reply.ParseFromString(sent[0])
        self.assertFalse(reply.success)

        # The failed attempt must not have added the new function to any of
        # the metadata.
        for registry in (self.pinned_functions, self.runtimes,
                         self.exec_counts, self.status.functions):
            self.assertTrue(fname not in registry)
示例#16
0
    def test_delete_nonexistent_dag(self):
        '''
        Requests the deletion of a DAG that is not known and verifies that
        the user gets a NO_SUCH_DAG error and that no metadata is affected.
        '''
        # Request deletion of an unknown DAG against empty metadata.
        self.socket.inbox.append('dag')
        delete_dag(self.socket, {}, self.policy, {})

        # The user must receive exactly one failure response.
        replies = self.socket.outbox
        self.assertEqual(len(replies), 1)
        reply = GenericResponse()
        reply.ParseFromString(replies[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, NO_SUCH_DAG)

        # Nothing else may have been sent, and the policy metadata must
        # still be empty.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.unpinned_executors), 0)
示例#17
0
    def test_create_dag_already_exists(self):
        '''
        Submits a DAG whose name is already present in the server metadata
        and verifies that the request is rejected with DAG_ALREADY_EXISTS.
        '''
        source, sink = 'source', 'sink'
        dag_name = 'dag'

        # Create a simple two-function DAG and queue it on the inbound socket.
        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Pretend the server already knows about this DAG.
        dags = {dag.name: (dag, {source})}

        # Give the policy engine two unpinned executors.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        # Try to create the DAG.
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # The user must receive exactly one failure response.
        replies = self.socket.outbox
        self.assertEqual(len(replies), 1)
        reply = GenericResponse()
        reply.ParseFromString(replies[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, DAG_ALREADY_EXISTS)

        # No pin messages were sent and the policy metadata is unchanged.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)
        self.assertEqual(len(self.policy.unpinned_executors), 2)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
示例#18
0
    def test_call_function_no_refs(self):
        '''
        A basic check that the call is forwarded to the executor stored in
        self.executor_key when the arguments contain no references.
        '''
        # Build a new function call.
        request = FunctionCall()
        request.name = 'function'
        request.request_id = 12

        # Give it a single argument and queue it on the inbound socket.
        serializer.dump(2, request.arguments.values.add())
        self.socket.inbox.append(request.SerializeToString())

        # Run the scheduling policy.
        call_function(self.socket, self.pusher_cache, self.policy)

        # One reply to the user and one forwarded call are expected.
        replies = self.socket.outbox
        pushed = self.pusher_cache.socket.outbox
        self.assertEqual(len(replies), 1)
        self.assertEqual(len(pushed), 1)

        # Deserialize both messages.
        reply = GenericResponse()
        forwarded = FunctionCall()
        reply.ParseFromString(replies[0])
        forwarded.ParseFromString(pushed[0])

        # The reply must succeed and must match the forwarded call.
        self.assertTrue(reply.success)
        self.assertEqual(reply.response_id, forwarded.response_key)
        self.assertEqual(forwarded.name, request.name)
        self.assertEqual(forwarded.request_id, request.request_id)

        # The call must have been sent to the expected executor's address.
        targets = self.pusher_cache.addresses
        self.assertEqual(len(targets), 1)
        self.assertEqual(targets[0],
                         utils.get_exec_address(*self.executor_key))
示例#19
0
    def pin_function(self, dag_name, function_ref, colocated):
        '''
        Tries to pin the given function onto an unpinned executor, picking
        candidates at random until one accepts.

        dag_name: The name of the DAG the function belongs to; successful
        pins are recorded under this name in self.pending_dags.
        function_ref: The function reference to pin; its name, gpu and
        batching fields are read.
        colocated: A collection of function names that this function should
        share a node with. When it is non-empty (and the function is not a
        GPU function), candidates are restricted to nodes that already host
        a pending pin for one of those functions or, if there is none yet,
        to nodes that have all NUM_EXECUTOR_THREADS executors unpinned.

        Returns True once an executor accepts the pin. If the colocation
        candidates are empty or all of them reject the request, the pin is
        retried without the colocation constraint. Returns False when no
        executor of the required kind is available or every candidate has
        rejected the request.
        '''
        # If there are no executors left to choose from, then we return False,
        # indicating that we ran out of resources to use.
        if function_ref.gpu and len(self.unpinned_gpu_executors) == 0:
            return False
        elif not function_ref.gpu and len(self.unpinned_cpu_executors) == 0:
            return False

        if dag_name not in self.pending_dags:
            self.pending_dags[dag_name] = []

        # Build the candidates as a separate set, so that iterating over them
        # is independent of updates to the system's metadata.
        if function_ref.gpu:
            candidates = set(self.unpinned_gpu_executors)
        elif len(colocated) == 0:
            # If this is not a GPU function, just look at all of the unpinned
            # executors.
            candidates = set(self.unpinned_cpu_executors)
        else:
            candidates = set()

            # IPs of the nodes on which a colocated function of this DAG has
            # already been pinned.
            candidate_nodes = {thread[0]
                               for fn, thread in self.pending_dags[dag_name]
                               if fn in colocated}

            if candidate_nodes:
                for node, tid in self.unpinned_cpu_executors:
                    if node in candidate_nodes:
                        candidates.add((node, tid))
            else:
                # If this is the first colocate to be pinned, try to assign to
                # an empty node, i.e. one with all of its threads unpinned.
                free_threads = {}
                for node, _ in self.unpinned_cpu_executors:
                    free_threads[node] = free_threads.get(node, 0) + 1

                for node, count in free_threads.items():
                    if count == NUM_EXECUTOR_THREADS:
                        for i in range(NUM_EXECUTOR_THREADS):
                            candidates.add((node, i))

        if len(candidates) == 0:  # There are no valid executors to colocate on.
            return self.pin_function(dag_name, function_ref, [])

        # Construct a PinFunction message to be sent to executors.
        pin_msg = PinFunction()
        pin_msg.name = function_ref.name
        pin_msg.batching = function_ref.batching
        pin_msg.response_address = self.ip

        serialized = pin_msg.SerializeToString()

        # The pool of unpinned executors that matches this function's kind.
        if function_ref.gpu:
            unpinned = self.unpinned_gpu_executors
        else:
            unpinned = self.unpinned_cpu_executors

        while candidates:
            # Pick a random executor from the set of candidates and attempt to
            # pin this function there. sample() needs a sequence (sets are
            # rejected by newer Python versions), hence the tuple copy.
            node, tid = sys_random.sample(tuple(candidates), 1)[0]

            sckt = self.pusher_cache.get(get_pin_address(node, tid))
            sckt.send(serialized)

            response = GenericResponse()
            try:
                response.ParseFromString(self.pin_accept_socket.recv())
            except zmq.ZMQError:
                logging.error('Pin operation to %s:%d timed out. Retrying.' %
                              (node, tid))
                continue

            # Do not use this executor either way: If it rejected, it has
            # something else pinned, and if it accepted, it has pinned what we
            # just asked it to pin. In local mode, however we allow executors
            # to have multiple functions pinned.
            if not self.local:
                unpinned.discard((node, tid))
                candidates.discard((node, tid))

            if response.success:
                # The pin operation succeeded, so we record where the function
                # was pinned and report success to the caller.
                self.pending_dags[dag_name].append((function_ref.name, (node,
                                                                        tid)))
                return True

            # The pin operation was rejected; try another candidate.
            logging.error('Node %s:%d rejected pin for %s. Retrying.'
                          % (node, tid, function_ref.name))

        if len(colocated) > 0:
            # Every colocation candidate rejected the request: try again
            # without colocation.
            return self.pin_function(dag_name, function_ref, [])

        # Every candidate rejected the request: we are out of resources.
        return False
示例#20
0
    def test_create_gpu_dag(self):
        '''
        Creates a DAG whose only listed function is flagged as a GPU function,
        with one unpinned GPU executor available, and checks that the pin
        message is sent, that the DAG is persisted in the KVS, and that the
        server and policy metadata are updated.
        '''
        # Create a DAG with a single GPU function and add it to the inbound
        # socket.
        dag_name = 'dag'
        fn = 'fn'

        dag = create_linear_dag([None], [fn], self.kvs_client, dag_name)
        dag.functions[0].gpu = True
        self.socket.inbox.append(dag.SerializeToString())

        dags = {}
        call_frequency = {}

        # Make one GPU executor available to the policy engine.
        address_set = {(self.ip, 1)}
        self.policy.unpinned_gpu_executors.update(address_set)

        # Prepopulate the pin_accept socket with a success message.
        self.pin_socket.inbox.append(sutils.ok_resp)

        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Test that the correct metadata was created.
        self.assertTrue(dag_name in dags)
        created, dag_source = dags[dag_name]
        self.assertEqual(created, dag)
        self.assertEqual(len(dag_source), 1)
        self.assertEqual(list(dag_source)[0], fn)
        self.assertTrue(fn in call_frequency)
        self.assertEqual(call_frequency[fn], 0)

        # Test that the DAG is stored in the KVS correctly.
        result = self.kvs_client.get(dag_name)[dag_name]
        created = Dag()
        created.ParseFromString(result.reveal())
        self.assertEqual(created, dag)

        # Test that the correct response was returned to the user. This has
        # to be assertEqual: assertTrue would treat the 1 as a failure message
        # and accept any non-empty outbox.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(response.success)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 1)
        messages = self.pusher_cache.socket.outbox
        function_set = {fn}
        for message in messages:
            pin_msg = PinFunction()
            pin_msg.ParseFromString(message)
            self.assertEqual(pin_msg.response_address, self.ip)
            self.assertTrue(pin_msg.name in function_set)
            # Each function may only be pinned once.
            function_set.discard(pin_msg.name)

        self.assertEqual(len(function_set), 0)

        for address in address_set:
            self.assertTrue(
                get_pin_address(*address) in self.pusher_cache.addresses)

        # Test that the policy engine has the correct metadata stored.
        self.assertEqual(len(self.policy.unpinned_cpu_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
        self.assertTrue(fn in self.policy.function_locations)

        self.assertEqual(len(self.policy.function_locations[fn]), 1)
示例#21
0
    def test_create_dag(self):
        '''
        This test creates a new DAG, checking that the correct pin messages are
        sent to executors and that it is persisted in the KVS correctly. It
        also checks that the server metadata was updated as expected.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine: one unpinned executor
        # per function in the DAG.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        # Prepopulate the pin_accept socket with sufficient success messages.
        self.pin_socket.inbox.append(sutils.ok_resp)
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Call the DAG creation method.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Test that the correct metadata was created.
        self.assertIn(dag_name, dags)
        created, dag_source = dags[dag_name]
        self.assertEqual(created, dag)
        self.assertEqual(len(dag_source), 1)
        self.assertEqual(list(dag_source)[0], source)
        self.assertIn(source, call_frequency)
        self.assertIn(sink, call_frequency)
        self.assertEqual(call_frequency[source], 0)
        self.assertEqual(call_frequency[sink], 0)

        # Test that the DAG is stored in the KVS correctly.
        result = self.kvs_client.get(dag_name)[dag_name]
        created = Dag()
        created.ParseFromString(result.reveal())
        self.assertEqual(created, dag)

        # Test that the correct response was returned to the user. This must
        # be assertEqual: assertTrue(len(...), 1) would treat 1 as the failure
        # message and pass for any non-empty outbox.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(response.success)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox
        function_set = {source, sink}
        for message in messages:
            # Each pin message is checked as an '<ip>:<function name>' string.
            self.assertIn(':', message)
            ip, fname = message.split(':')
            self.assertEqual(ip, self.ip)
            self.assertIn(fname, function_set)
            # Discard each name once seen so that a duplicate pin message for
            # the same function fails the membership check above.
            function_set.discard(fname)

        # Every expected function must have been pinned.
        self.assertEqual(len(function_set), 0)

        for address in address_set:
            self.assertIn(get_pin_address(*address),
                          self.pusher_cache.addresses)

        # Test that the policy engine has the correct metadata stored.
        self.assertEqual(len(self.policy.unpinned_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
        self.assertIn(source, self.policy.function_locations)
        self.assertIn(sink, self.policy.function_locations)

        self.assertEqual(len(self.policy.function_locations[source]), 1)
        self.assertEqual(len(self.policy.function_locations[sink]), 1)
示例#22
0
    def test_create_dag_insufficient_resources(self):
        '''
        This test attempts to create a DAG even though there are not enough
        free executors in the system. It checks that a pin message is attempted
        to be sent, we run out of resources, and then the request is rejected.
        We check that the metadata is properly restored back to its original
        state.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine, but set the number of
        # executors to fewer than needed.
        address_set = {(self.ip, 1)}
        self.policy.unpinned_executors.update(address_set)

        # Prepopulate the pin_accept socket with a single success message,
        # one per available executor.
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Attempt to create the DAG.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Check that an error was returned to the user.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox[0])
        self.assertFalse(response.success)
        self.assertEqual(response.error, NO_RESOURCES)

        # Test that the correct pin messages were sent: one pin followed by
        # one unpin.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox

        # Checks for the pin message.
        self.assertIn(':', messages[0])
        ip, fname = messages[0].split(':')
        self.assertEqual(ip, self.ip)
        self.assertEqual(source, fname)

        # Checks for the unpin message.
        self.assertEqual(messages[1], source)

        # address_set holds exactly one executor. Take it by iteration rather
        # than random.sample(), which rejects sets on Python 3.11+.
        address = next(iter(address_set))
        addresses = self.pusher_cache.addresses
        self.assertEqual(get_pin_address(*address), addresses[0])
        self.assertEqual(get_unpin_address(*address), addresses[1])

        # Check the policy engine's metadata after the rejected request.
        self.assertEqual(len(self.policy.unpinned_executors), 0)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)

        # Check that no additional metadata was created or sent.
        self.assertEqual(len(call_frequency), 0)
        self.assertEqual(len(dags), 0)