Ejemplo n.º 1
0
    def register_dag(self, name, functions, connections):
        '''
        Registers a new DAG with the system. This operation will fail if any of
        the functions provided cannot be identified in the system.

        name: A unique name for this DAG.
        functions: A list of names of functions to be included in this DAG.
        connections: A list of ordered pairs of function names that represent
        the edges in this DAG.

        Returns a (success, error) pair. If one of the function names is not
        in the list returned by _get_func_list(), no request is sent and
        (False, None) is returned; otherwise the pair is read from the
        GenericResponse received on dag_create_sock.
        '''

        flist = self._get_func_list()
        for fname in functions:
            if fname not in flist:
                # Adjacent string literals are used instead of a backslash
                # continuation inside the literal, which would embed the
                # source indentation in the logged message.
                logging.info(
                    'Function %s not registered. Please register before '
                    'including it in a DAG.', fname)
                return False, None

        dag = Dag()
        dag.name = name
        dag.functions.extend(functions)
        for pair in connections:
            # Each pair is (source, sink); only the first two items are read.
            conn = dag.connections.add()
            conn.source = pair[0]
            conn.sink = pair[1]

        self.dag_create_sock.send(dag.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_create_sock.recv())

        return r.success, r.error
Ejemplo n.º 2
0
    def test_succesful_pin(self):
        '''
        This test executes a pin operation that is supposed to be successful,
        and it checks to make sure that the correct metadata for execution
        and reporting is generated.
        '''
        # Create a new function in the KVS.
        fname = 'incr'

        def func(_, x): return x + 1
        create_function(func, self.kvs_client, fname)

        # Create a pin message and put it into the socket.
        msg = self.ip + ':' + fname
        self.socket.inbox.append(msg)

        # Execute the pin operation.
        pin(self.socket, self.pusher_cache, self.kvs_client, self.status,
            self.pinned_functions, self.runtimes, self.exec_counts)

        # Check that exactly one message was sent and that it reports success.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.pusher_cache.socket.outbox[0])
        self.assertTrue(response.success)

        # Assert membership before indexing into pinned_functions, so that a
        # missing entry is reported as an assertion failure rather than as a
        # KeyError raised by the lookup below.
        self.assertIn(fname, self.pinned_functions)
        self.assertIn(fname, self.runtimes)
        self.assertIn(fname, self.exec_counts)
        self.assertIn(fname, self.status.functions)

        # The pinned callable must behave like the function we registered.
        self.assertEqual(func('', 1), self.pinned_functions[fname]('', 1))
Ejemplo n.º 3
0
    def test_exec_function_nonexistent(self):
        '''
        Calls a function named 'bad_func' that was never created and verifies
        that nothing is sent on the socket while a FUNC_NOT_FOUND error is
        stored in the KVS under the response key.
        '''
        # Enqueue a call to a function that does not exist, then process it.
        missing_call = self._create_function_call('bad_func', [1], NORMAL)
        self.socket.inbox.append(missing_call.SerializeToString())
        exec_function(self.socket, self.kvs_client, self.user_library, {})

        # No message should have gone out over the socket.
        self.assertEqual(len(self.socket.outbox), 0)

        # The outcome is stored in the KVS as a lattice; check its type before
        # deserializing it.
        key = self.response_key
        stored = self.kvs_client.get(key)[key]
        self.assertEqual(type(stored), LWWPairLattice)
        payload = serializer.load_lattice(stored)

        # The payload is a tuple whose first item marks it as an error.
        self.assertEqual(type(payload), tuple)
        self.assertEqual(payload[0], 'ERROR')

        # The second item is a serialized GenericResponse with the details.
        details = GenericResponse()
        details.ParseFromString(payload[1])
        self.assertEqual(details.success, False)
        self.assertEqual(details.error, FUNC_NOT_FOUND)
Ejemplo n.º 4
0
    def test_function_call_no_resources(self):
        '''
        Empties the policy engine's executor metadata, issues a function call,
        and verifies that the user receives a NO_RESOURCES error while nothing
        is forwarded through the pusher cache.
        '''
        # Leave the policy engine with no executors to choose from.
        for executors in (self.policy.thread_statuses,
                          self.policy.unpinned_executors):
            executors.clear()

        # Build a call with a single integer argument and enqueue it.
        request = FunctionCall()
        request.name = 'function'
        request.request_id = 12
        serializer.dump(2, request.arguments.values.add())
        self.socket.inbox.append(request.SerializeToString())

        # Run the scheduling policy on the queued request.
        call_function(self.socket, self.pusher_cache, self.policy)

        # Exactly one reply to the user, and nothing sent to any executor.
        user_outbox = self.socket.outbox
        self.assertEqual(len(user_outbox), 1)
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)

        # The reply must be a failure carrying the NO_RESOURCES error.
        reply = GenericResponse()
        reply.ParseFromString(user_outbox[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, NO_RESOURCES)
Ejemplo n.º 5
0
    def test_exec_function_with_error(self):
        '''
        Creates a function that always raises, executes it, and verifies that
        the value stored in the KVS under the response key is a tuple holding
        text that contains the exception message together with a serialized
        EXECUTION_ERROR response.
        '''
        e_msg = 'This is a broken_function!'
        fname = 'func'

        def func(_, x):
            raise ValueError(e_msg)

        # Put the function into the KVS, then enqueue a call to it.
        create_function(func, self.kvs_client, fname)
        request = self._create_function_call(fname, [''], NORMAL)
        self.socket.inbox.append(request.SerializeToString())

        # Process the queued call.
        exec_function(self.socket, self.kvs_client, self.user_library, {})

        # The outcome is stored in the KVS as a lattice; check its type before
        # deserializing it.
        key = self.response_key
        stored = self.kvs_client.get(key)[key]
        self.assertEqual(type(stored), LWWPairLattice)
        payload = serializer.load_lattice(stored)

        # The first item of the tuple must mention the raised message.
        self.assertEqual(type(payload), tuple)
        self.assertIn(e_msg, payload[0])

        # The second item is a serialized GenericResponse with the details.
        details = GenericResponse()
        details.ParseFromString(payload[1])
        self.assertEqual(details.success, False)
        self.assertEqual(details.error, EXECUTION_ERROR)
Ejemplo n.º 6
0
    def register(self, function, name):
        '''
        Registers a new function or class with the system. The returned object
        can be called like a regular Python function, which returns a Droplet
        Future. If the input is a class, the class is expected to have a run
        method, which is what is invoked at runtime.

        function: The function object that we are registering.
        name: A unique name for the function to be stored with in the system.

        Returns a DropletFunction when the response reports success. On
        failure, the response is printed and None is returned.
        '''

        # Package the serialized function body together with its name.
        request = Function()
        request.name = name
        request.body = serializer.dump(function)
        self.func_create_sock.send(request.SerializeToString())

        # Wait for the reply on the same socket.
        reply = GenericResponse()
        reply.ParseFromString(self.func_create_sock.recv())

        if not reply.success:
            print('Unexpected error while registering function: \n\t%s.' %
                  (reply))
            return None

        return DropletFunction(name, self, self.kvs_client)
Ejemplo n.º 7
0
    def delete_dag(self, dname):
        '''
        Removes the specified DAG from the system.

        dname: The name of the DAG to delete.

        Returns the (success, error) pair read from the GenericResponse that
        is received on dag_delete_sock.
        '''
        sock = self.dag_delete_sock
        sock.send_string(dname)

        reply = GenericResponse()
        reply.ParseFromString(sock.recv())
        return reply.success, reply.error
Ejemplo n.º 8
0
    def test_delete_dag(self):
        '''
        Deletes a DAG that is already present in the metadata and checks that
        unpin messages are sent to the executors holding its functions, that
        the metadata is emptied, and that the user receives a success
        response.
        '''
        # Build a simple two-function DAG and add it to the system metadata.
        source, sink = 'source', 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)

        dags = {dag.name: (dag, {source})}
        call_frequency = {source: 100, sink: 100}

        # Tell the policy engine where each function is located.
        source_location = (self.ip, 1)
        sink_location = (self.ip, 2)
        self.policy.function_locations[source] = {source_location}
        self.policy.function_locations[sink] = {sink_location}

        # Queue the delete request and process it.
        self.socket.inbox.append(dag.name)
        delete_dag(self.socket, dags, self.policy, call_frequency)

        # Two unpin messages are expected: the source first, then the sink,
        # each sent to the unpin address of the matching location.
        unpin_messages = self.pusher_cache.socket.outbox
        unpin_addresses = self.pusher_cache.addresses
        self.assertEqual(len(unpin_messages), 2)
        self.assertEqual(len(unpin_addresses), 2)

        expected = [(source, source_location), (sink, sink_location)]
        for index, (fname, location) in enumerate(expected):
            self.assertEqual(unpin_messages[index], fname)
            self.assertEqual(unpin_addresses[index],
                             get_unpin_address(*location))

        # Both metadata maps must now be empty.
        self.assertEqual(len(dags), 0)
        self.assertEqual(len(call_frequency), 0)

        # The user must have received exactly one success response.
        self.assertEqual(len(self.socket.outbox), 1)
        reply = GenericResponse()
        reply.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(reply.success)
Ejemplo n.º 9
0
def call_function(func_call_socket, pusher_cache, policy):
    '''
    Receives one FunctionCall from func_call_socket, asks the policy engine
    for an executor, forwards the call there, and replies to the sender.

    func_call_socket: Socket the request is read from and the reply sent to.
    pusher_cache: Object whose get(address) returns a socket to send on.
    policy: Object whose pick_executor(refs) returns an (ip, tid) pair, or
    None when no executor is available.

    The reply is a GenericResponse: on success it carries the call's response
    key as response_id; otherwise its error is NO_RESOURCES and nothing is
    forwarded.
    '''
    request = FunctionCall()
    request.ParseFromString(func_call_socket.recv())

    # Requests without a response key get a freshly generated UUID.
    if not request.response_key:
        request.response_key = str(uuid.uuid4())

    # Only arguments that deserialize to exactly a DropletReference are
    # passed to the policy engine.
    refs = [
        loaded for loaded in map(serializer.load, request.arguments.values)
        if type(loaded) == DropletReference
    ]
    target = policy.pick_executor(refs)

    reply = GenericResponse()
    if target is None:
        reply.success = False
        reply.error = NO_RESOURCES
    else:
        # Forward the request to the chosen executor thread.
        ip, tid = target
        executor = pusher_cache.get(utils.get_exec_address(ip, tid))
        executor.send(request.SerializeToString())

        reply.success = True
        reply.response_id = request.response_key

    func_call_socket.send(reply.SerializeToString())
Ejemplo n.º 10
0
    def pin_function(self, dag_name, function_name):
        '''
        Attempts to pin function_name onto one of the unpinned executors and
        records the result under dag_name in pending_dags.

        dag_name: Name of the DAG the function is being pinned for.
        function_name: Name of the function to pin.

        Returns True once an executor accepts the pin. Returns False if there
        are no unpinned executors to begin with, or if every candidate
        rejected the request.
        '''
        # If there are no executors left to choose from, then we return False,
        # indicating that we ran out of resources to use.
        if not self.unpinned_executors:
            return False

        if dag_name not in self.pending_dags:
            self.pending_dags[dag_name] = []

        # Work on a copy so that iteration state is independent of the
        # system's metadata, which is updated explicitly below.
        candidates = set(self.unpinned_executors)

        # Stop when the candidates are exhausted: sampling from an empty
        # population raises ValueError.
        while candidates:
            # Pick a random executor from the set of candidates and attempt to
            # pin this function there. The set is converted to a list because
            # random.Random.sample() no longer accepts sets on newer Pythons
            # (assuming sys_random is a random.Random/SystemRandom instance).
            node, tid = sys_random.sample(list(candidates), 1)[0]

            sckt = self.pusher_cache.get(get_pin_address(node, tid))
            msg = self.ip + ':' + function_name
            sckt.send_string(msg)

            response = GenericResponse()
            try:
                response.ParseFromString(self.pin_accept_socket.recv())
            except zmq.ZMQError:
                # NOTE(review): a timed-out executor stays in the candidate
                # set, so it may be picked again on the next iteration.
                logging.error('Pin operation to %s:%d timed out. Retrying.' %
                              (node, tid))
                continue

            # Do not use this executor either way: If it rejected, it has
            # something else pinned, and if it accepted, it has pinned what we
            # just asked it to pin.
            self.unpinned_executors.discard((node, tid))
            candidates.discard((node, tid))

            if response.success:
                # The pin operation succeeded, so we record the node and
                # thread ID for this DAG and report success.
                self.pending_dags[dag_name].append(
                    (function_name, (node, tid)))
                return True

            # The pin operation was rejected; try another candidate.
            logging.error('Node %s:%d rejected pin for %s. Retrying.' %
                          (node, tid, function_name))

        # Every candidate rejected the request.
        logging.error('No executor accepted pin for %s.' % (function_name))
        return False
Ejemplo n.º 11
0
    def exec_func(self, name, args):
        '''
        Sends a FunctionCall for `name` with the given arguments over
        func_call_sock and returns the response_id from the reply.

        name: Name placed in the FunctionCall.
        args: Iterable of argument values; each is serialized into the call.

        The request ID is taken from self.rid, which is incremented after the
        reply has been received.
        '''
        request = FunctionCall()
        request.name = name
        request.request_id = self.rid

        # Serialize each argument into a newly added value slot.
        for value in args:
            serializer.dump(value, request.arguments.values.add())

        sock = self.func_call_sock
        sock.send(request.SerializeToString())

        reply = GenericResponse()
        reply.ParseFromString(sock.recv())

        self.rid += 1
        return reply.response_id
Ejemplo n.º 12
0
    def test_call_function_with_refs(self):
        '''
        Creates a scenario where the policy should deterministically pick the
        same executor to run a request on: There is one reference, and it's
        cached only on the node we create in this test.
        '''
        # Add a new executor for which we will construct cached references.
        ip_address = '192.168.0.1'
        new_key = (ip_address, 2)
        self.policy.unpinned_executors.add(new_key)

        # Create a new reference and add its metadata.
        ref_name = 'reference'
        self.policy.key_locations[ref_name] = [ip_address]

        # Create a function call that asks for this reference.
        call = FunctionCall()
        call.name = 'function'
        call.request_id = 12
        val = call.arguments.values.add()
        serializer.dump(DropletReference(ref_name, True), val)
        # SerializeToString is called without arguments; a stray positional
        # argument was previously passed here.
        self.socket.inbox.append(call.SerializeToString())

        # Execute the scheduling policy.
        call_function(self.socket, self.pusher_cache, self.policy)

        # Check that the correct number of messages were sent.
        self.assertEqual(len(self.socket.outbox), 1)
        self.assertEqual(len(self.pusher_cache.socket.outbox), 1)

        # Extract and deserialize the messages.
        response = GenericResponse()
        forwarded = FunctionCall()
        response.ParseFromString(self.socket.outbox[0])
        forwarded.ParseFromString(self.pusher_cache.socket.outbox[0])

        # The user's response must point at the key carried by the forwarded
        # call, and the forwarded call must match what was submitted.
        self.assertTrue(response.success)
        self.assertEqual(response.response_id, forwarded.response_key)
        self.assertEqual(forwarded.name, call.name)
        self.assertEqual(forwarded.request_id, call.request_id)

        # Makes sure that the correct executor was chosen.
        self.assertEqual(len(self.pusher_cache.addresses), 1)
        self.assertEqual(self.pusher_cache.addresses[0],
                         utils.get_exec_address(*new_key))
Ejemplo n.º 13
0
    def test_occupied_pin(self):
        '''
        Attempts to pin a function while another function ('square') is
        already recorded as pinned, and checks that the request is rejected
        and that no metadata is added for the new function.
        '''
        fname = 'incr'

        def func(_, x):
            return x + 1

        # Create the new function in the KVS and queue a pin message for it.
        create_function(func, self.kvs_client, fname)
        self.socket.inbox.append(self.ip + ':' + fname)

        # Record an already pinned function, so that we reject the request.
        occupant = 'square'
        self.pinned_functions[occupant] = lambda _, x: x * x
        self.runtimes[occupant] = []
        self.exec_counts[occupant] = []
        self.status.functions.append(occupant)

        # Execute the pin operation.
        pin(self.socket, self.pusher_cache, self.kvs_client, self.status,
            self.pinned_functions, self.runtimes, self.exec_counts,
            self.user_library)

        # Exactly one reply must have been sent, and it must be a rejection.
        replies = self.pusher_cache.socket.outbox
        self.assertEqual(len(replies), 1)
        reply = GenericResponse()
        reply.ParseFromString(replies[0])
        self.assertFalse(reply.success)

        # Make sure that none of the metadata was touched by this failed pin
        # attempt.
        self.assertNotIn(fname, self.pinned_functions)
        self.assertNotIn(fname, self.runtimes)
        self.assertNotIn(fname, self.exec_counts)
        self.assertNotIn(fname, self.status.functions)
Ejemplo n.º 14
0
    def test_delete_nonexistent_dag(self):
        '''
        Requests deletion of a DAG name that is absent from the (empty) DAG
        metadata and checks that a NO_SUCH_DAG error is returned while no
        other messages are sent and the policy metadata stays empty.
        '''

        # Queue a delete request for an unknown DAG and process it.
        self.socket.inbox.append('dag')
        delete_dag(self.socket, {}, self.policy, {})

        # The user must receive exactly one error response.
        user_outbox = self.socket.outbox
        self.assertEqual(len(user_outbox), 1)
        reply = GenericResponse()
        reply.ParseFromString(user_outbox[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, NO_SUCH_DAG)

        # Nothing else was sent, and the policy metadata is still empty.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)
        for metadata in (self.policy.function_locations,
                         self.policy.unpinned_executors):
            self.assertEqual(len(metadata), 0)
Ejemplo n.º 15
0
    def test_create_dag_already_exists(self):
        '''
        Submits a create request for a DAG whose name is already present in
        the server metadata and checks that the request is rejected with
        DAG_ALREADY_EXISTS and that no other state is changed.
        '''
        # Build a simple two-function DAG and queue it on the inbound socket.
        source, sink = 'source', 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # The same DAG is already part of the existing server metadata.
        dags = {dag.name: (dag, {source})}
        call_frequency = {}

        # Give the policy engine two unpinned executors.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        # Attempt to create the DAG.
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # The user must receive exactly one error response.
        user_outbox = self.socket.outbox
        self.assertEqual(len(user_outbox), 1)
        reply = GenericResponse()
        reply.ParseFromString(user_outbox[0])
        self.assertFalse(reply.success)
        self.assertEqual(reply.error, DAG_ALREADY_EXISTS)

        # Nothing was sent to executors, and the policy metadata is unchanged.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 0)
        self.assertEqual(len(self.policy.unpinned_executors), 2)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
Ejemplo n.º 16
0
    def test_call_function_no_refs(self):
        '''
        Issues a function call with a single non-reference argument and
        checks that it is forwarded to the executor identified by
        self.executor_key and that the user receives a success response.
        '''
        # Build a call to 'function' with one integer argument and queue it.
        request = FunctionCall()
        request.name = 'function'
        request.request_id = 12
        serializer.dump(2, request.arguments.values.add())
        self.socket.inbox.append(request.SerializeToString())

        # Execute the scheduling policy.
        call_function(self.socket, self.pusher_cache, self.policy)

        # One reply to the user and one message forwarded to an executor.
        user_outbox = self.socket.outbox
        executor_outbox = self.pusher_cache.socket.outbox
        self.assertEqual(len(user_outbox), 1)
        self.assertEqual(len(executor_outbox), 1)

        # Deserialize both messages.
        reply = GenericResponse()
        reply.ParseFromString(user_outbox[0])
        forwarded = FunctionCall()
        forwarded.ParseFromString(executor_outbox[0])

        # The reply must reference the forwarded call's response key, and the
        # forwarded call must match the one that was submitted.
        self.assertTrue(reply.success)
        self.assertEqual(reply.response_id, forwarded.response_key)
        self.assertEqual(forwarded.name, request.name)
        self.assertEqual(forwarded.request_id, request.request_id)

        # Makes sure that the correct executor was chosen.
        contacted = self.pusher_cache.addresses
        self.assertEqual(len(contacted), 1)
        self.assertEqual(contacted[0],
                         utils.get_exec_address(*self.executor_key))
Ejemplo n.º 17
0
def call_dag(call, pusher_cache, dags, policy):
    '''
    Builds a DagSchedule for the DAG named by call.name, picks an executor
    for every function, sends the schedule to each chosen executor, and then
    sends a 'BEGIN' trigger to every source function.

    call: The DAG call request; its name, consistency, response_address,
    output_key, client_id and function_args fields are read.
    pusher_cache: Object whose get(address) returns a socket to send on.
    dags: Mapping from DAG name to a (dag, sources) pair.
    policy: Object whose pick_executor(refs, fname) returns an (ip, tid)
    pair, or None when no executor is available.

    Returns a GenericResponse. If any function cannot be placed, the response
    is a NO_RESOURCES failure and nothing is sent. Otherwise it is a success
    whose response_id is the schedule's output key when set, else its ID.
    '''
    dag, sources = dags[call.name]

    plan = DagSchedule()
    plan.id = str(uuid.uuid4())
    plan.dag.CopyFrom(dag)
    plan.start_time = time.time()
    plan.consistency = call.consistency

    # Optional fields are only copied over when the caller set them.
    for field in ('response_address', 'output_key', 'client_id'):
        value = getattr(call, field)
        if value:
            setattr(plan, field, value)

    # Phase 1: choose a location for every function in the DAG.
    for fname in dag.functions:
        fn_args = call.function_args[fname].values

        # Only arguments that deserialize to exactly a DropletReference are
        # passed to the policy engine.
        refs = [
            loaded for loaded in map(serializer.load, fn_args)
            if type(loaded) == DropletReference
        ]

        chosen = policy.pick_executor(refs, fname)
        if chosen is None:
            failure = GenericResponse()
            failure.success = False
            failure.error = NO_RESOURCES
            return failure

        ip, tid = chosen
        plan.locations[fname] = ip + ':' + str(tid)

        # Copy this function's arguments into the schedule.
        plan.arguments[fname].values.extend(fn_args)

    # Phase 2: send the schedule to each function's executor. The target
    # function and trigger list are rewritten on the same schedule object
    # before every send.
    for fname in dag.functions:
        parts = plan.locations[fname].split(':')
        queue_address = utils.get_queue_address(parts[0], parts[1])
        plan.target_function = fname

        triggers = sutils.get_dag_predecessors(dag, fname)
        if not triggers:
            # Functions without predecessors are triggered by 'BEGIN'.
            triggers.append('BEGIN')

        plan.ClearField('triggers')
        plan.triggers.extend(triggers)

        pusher_cache.get(queue_address).send(plan.SerializeToString())

    # Phase 3: send the 'BEGIN' trigger to every source function.
    for source in sources:
        begin = DagTrigger()
        begin.id = plan.id
        begin.source = 'BEGIN'
        begin.target_function = source

        trigger_address = sutils.get_dag_trigger_address(
            plan.locations[source])
        pusher_cache.get(trigger_address).send(begin.SerializeToString())

    reply = GenericResponse()
    reply.success = True
    reply.response_id = plan.output_key if plan.output_key else plan.id
    return reply
Ejemplo n.º 18
0
# Port number constants.
# NOTE(review): what each port is used for is not visible in this part of the
# file; confirm against the code that binds or connects to them.
DAG_QUEUE_PORT = 4030
DAG_EXEC_PORT = 4040
SELF_DEPART_PORT = 4050

STATUS_PORT = 5007
SCHED_UPDATE_PORT = 5008
BACKOFF_PORT = 5009
PIN_ACCEPT_PORT = 5010

# For message sending via the user library.
RECV_INBOX_PORT = 5500

STATISTICS_REPORT_PORT = 7006

# Create a generic error response protobuf. This is a single module-level
# object with success set to False; no error code is assigned here.
error = GenericResponse()
error.success = False

# Create a generic success response protobuf, and keep its serialized form in
# ok_resp so it can be sent as-is.
ok = GenericResponse()
ok.success = True
ok_resp = ok.SerializeToString()

# Create a default vector clock for keys that have no dependencies.
DEFAULT_VC = VectorClock({'base': MaxIntLattice(1)})


def get_func_kvs_name(fname):
    '''
    Returns the given function name with FUNC_PREFIX prepended.

    fname: The function name to prefix.
    '''
    kvs_name = FUNC_PREFIX + fname
    return kvs_name

Ejemplo n.º 19
0
    def test_create_dag_insufficient_resources(self):
        '''
        This test attempts to create a DAG even though there are not enough
        free executors in the system. It checks that a pin message is attempted
        to be sent, we run out of resources, and then the request is rejected.
        We check that the metadata is properly restored back to its original
        state.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine, but set the number of
        # executors to fewer than needed.
        address_set = {(self.ip, 1)}
        self.policy.unpinned_executors.update(address_set)

        # Prepopulate the pin_accept socket with a single success message.
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Attempt to create the DAG.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Check that an error was returned to the user.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox[0])
        self.assertFalse(response.success)
        self.assertEqual(response.error, NO_RESOURCES)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox

        # Checks for the pin message.
        self.assertIn(':', messages[0])
        ip, fname = messages[0].split(':')
        self.assertEqual(ip, self.ip)
        self.assertEqual(source, fname)

        # Checks for the unpin message.
        self.assertEqual(messages[1], source)

        # The set holds exactly one executor, so take it directly.
        # random.sample() on a set is deprecated since Python 3.9 and raises
        # TypeError from 3.11.
        address = next(iter(address_set))
        addresses = self.pusher_cache.addresses
        self.assertEqual(get_pin_address(*address), addresses[0])
        self.assertEqual(get_unpin_address(*address), addresses[1])

        # Check the policy engine metadata after the failed request.
        self.assertEqual(len(self.policy.unpinned_executors), 0)
        self.assertEqual(len(self.policy.function_locations), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)

        # Check that no server metadata was created.
        self.assertEqual(len(call_frequency), 0)
        self.assertEqual(len(dags), 0)
Ejemplo n.º 20
0
    def call_dag(self,
                 dname,
                 arg_map,
                 direct_response=False,
                 consistency=NORMAL,
                 output_key=None,
                 client_id=None):
        '''
        Issues a new request to execute the DAG. By default, returns a
        DropletFuture constructed from the response ID in the reply. If
        direct_response is True, returns the deserialized result read from
        response_sock instead. Returns None if the reply reports failure, or
        if receiving the direct response fails with EAGAIN.

        dname: The name of the DAG to execute.
        arg_map: A map from function names to lists of arguments for each of
        the functions in the DAG.
        direct_response: If True, the response will be synchronously received
        by the client; otherwise, the result will be stored in the KVS.
        consistency: The consistency mode to use with this function: either
        NORMAL or MULTI.
        output_key: The KVS key in which to store the result of this DAG.
        client_id: An optional ID associated with an individual client across
        requests; this is used for causal metadata.

        Raises zmq.ZMQError if receiving the direct response fails for any
        reason other than EAGAIN.
        '''
        dc = DagCall()
        dc.name = dname
        dc.consistency = consistency

        if output_key:
            dc.output_key = output_key

        if client_id:
            dc.client_id = client_id

        for fname in arg_map:
            args = [
                serializer.dump(arg, serialize=False) for arg in arg_map[fname]
            ]
            al = dc.function_args[fname]
            al.values.extend(args)

        if direct_response:
            dc.response_address = self.response_address

        self.dag_call_sock.send(dc.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_call_sock.recv())

        # Check the reply before waiting for any result, so that a rejected
        # request returns immediately in both modes.
        # NOTE(review): assumes that no direct response is delivered for a
        # request whose reply reports failure; confirm against the scheduler.
        if not r.success:
            return None

        if direct_response:
            # Only the receive can raise ZMQError, so only it is guarded.
            try:
                result = self.response_sock.recv()
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    return None
                raise
            return serializer.load(result)

        return DropletFuture(r.response_id, self.kvs_client, serializer)
Ejemplo n.º 21
0
    def test_create_dag(self):
        '''
        This test creates a new DAG, checking that the correct pin messages are
        sent to executors and that it is persisted in the KVS correctly. It
        also checks that the server metadata was updated as expected.
        '''
        # Create a simple two-function DAG and add it to the inbound socket.
        source = 'source'
        sink = 'sink'
        dag_name = 'dag'

        dag = create_linear_dag([None, None], [source, sink], self.kvs_client,
                                dag_name)
        self.socket.inbox.append(dag.SerializeToString())

        # Add relevant metadata to the policy engine.
        address_set = {(self.ip, 1), (self.ip, 2)}
        self.policy.unpinned_executors.update(address_set)

        # Prepopulate the pin_accept socket with sufficient success messages.
        self.pin_socket.inbox.append(sutils.ok_resp)
        self.pin_socket.inbox.append(sutils.ok_resp)

        # Call the DAG creation method.
        dags = {}
        call_frequency = {}
        create_dag(self.socket, self.pusher_cache, self.kvs_client, dags,
                   self.policy, call_frequency)

        # Test that the correct metadata was created.
        self.assertIn(dag_name, dags)
        created, dag_source = dags[dag_name]
        self.assertEqual(created, dag)
        self.assertEqual(len(dag_source), 1)
        self.assertEqual(list(dag_source)[0], source)
        self.assertIn(source, call_frequency)
        self.assertIn(sink, call_frequency)
        self.assertEqual(call_frequency[source], 0)
        self.assertEqual(call_frequency[sink], 0)

        # Test that the DAG is stored in the KVS correctly.
        result = self.kvs_client.get(dag_name)[dag_name]
        created = Dag()
        created.ParseFromString(result.reveal())
        self.assertEqual(created, dag)

        # Test that the correct response was returned to the user. This must
        # be assertEqual: assertTrue(len(...), 1) treats the 1 as the failure
        # message and passes for any non-empty outbox.
        self.assertEqual(len(self.socket.outbox), 1)
        response = GenericResponse()
        response.ParseFromString(self.socket.outbox.pop())
        self.assertTrue(response.success)

        # Test that the correct pin messages were sent.
        self.assertEqual(len(self.pusher_cache.socket.outbox), 2)
        messages = self.pusher_cache.socket.outbox
        function_set = {source, sink}
        for message in messages:
            self.assertIn(':', message)
            ip, fname = message.split(':')
            self.assertEqual(ip, self.ip)
            self.assertIn(fname, function_set)
            # Discard each name once seen so that a duplicate pin for the
            # same function fails the membership check above.
            function_set.discard(fname)

        self.assertEqual(len(function_set), 0)

        for address in address_set:
            self.assertIn(
                get_pin_address(*address), self.pusher_cache.addresses)

        # Test that the policy engine has the correct metadata stored.
        self.assertEqual(len(self.policy.unpinned_executors), 0)
        self.assertEqual(len(self.policy.pending_dags), 0)
        self.assertIn(source, self.policy.function_locations)
        self.assertIn(sink, self.policy.function_locations)

        self.assertEqual(len(self.policy.function_locations[source]), 1)
        self.assertEqual(len(self.policy.function_locations[sink]), 1)