Example #1
0
    def test_dag_call_no_refs(self):
        '''
        Schedules a two-function DAG through a call that carries no
        references and checks every message the scheduler pushes out. There is
        no separate test for DAG calls with references: the default policy
        uses the same reference logic for DAGs as for individual functions,
        which are tested on their own.
        '''

        # Build a minimal source -> sink DAG and record where each function
        # is located.
        fn_source, fn_sink = 'source', 'sink'
        dag, source_address, sink_address = \
            self._construct_dag_with_locations(fn_source, fn_sink)

        # Describe an invocation of that DAG.
        dag_call = DagCall()
        dag_call.name = dag.name
        dag_call.consistency = NORMAL
        dag_call.output_key = 'output_key'
        dag_call.client_id = '0'

        # Run the scheduling policy over the call.
        known_dags = {dag.name: (dag, {fn_source})}
        call_dag(dag_call, self.pusher_cache, known_dags, self.policy)

        # Two schedules plus one trigger should have been pushed.
        outbox = self.pusher_cache.socket.outbox
        self.assertEqual(len(outbox), 3)

        # The first two messages are the schedules: the source is triggered
        # by 'BEGIN' and the sink is triggered by the source.
        expected_schedules = ((fn_source, 'BEGIN'), (fn_sink, fn_source))
        schedules = []
        for position, (target, upstream) in enumerate(expected_schedules):
            schedule = DagSchedule()
            schedule.ParseFromString(outbox[position])
            self._verify_dag_schedule(target, upstream, schedule, dag,
                                      dag_call)
            schedules.append(schedule)

        # The final message must be the single trigger, aimed at the DAG
        # source.
        trigger = DagTrigger()
        trigger.ParseFromString(outbox[2])
        self.assertEqual(trigger.id, schedules[0].id)
        self.assertEqual(trigger.target_function, fn_source)
        self.assertEqual(trigger.source, 'BEGIN')
        self.assertEqual(len(trigger.version_locations), 0)
        self.assertEqual(len(trigger.dependencies), 0)

        # Every message must have gone to the address we expect, in order.
        addresses = self.pusher_cache.addresses
        self.assertEqual(len(addresses), 3)
        self.assertEqual(addresses[0],
                         utils.get_queue_address(*source_address))
        self.assertEqual(addresses[1],
                         utils.get_queue_address(*sink_address))

        trigger_location = ':'.join(str(part) for part in source_address)
        self.assertEqual(addresses[2],
                         sutils.get_dag_trigger_address(trigger_location))
Example #2
0
def scheduler(ip, mgmt_ip, route_addr):
    '''
    Runs the scheduler's event loop. The loop has no exit condition, so this
    function only stops if an exception propagates out of it.

    One ZeroMQ socket is bound per request type. Each iteration polls those
    sockets (for up to one second), dispatches whatever arrived to the
    matching handler, and then performs the periodic metadata refresh and
    status/statistics reporting once the respective thresholds have elapsed.

    Args:
        ip: This scheduler's own IP address. It is handed to the KVS client
            and the policy, and is used to avoid pushing status updates to
            ourselves.
        mgmt_ip: IP address of the management node, or None to run in local
            mode (no peer schedulers, no statistics reporting).
        route_addr: Routing address given to the KVS client and returned to
            every client that connects.
    '''

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to a (protobuf DAG, DAG sources) tuple.
    dags = {}

    # Tracks how often a request for each function is received, keyed by the
    # function's name.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = []

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # The pin-accept socket is not polled below; it is handed to the policy,
    # so it gets a receive timeout of its own.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    requestor_cache = SocketCache(context, zmq.REQ)
    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache,
                                              kvs,
                                              ip,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request payload is ignored; every connect request is
            # answered with the routing address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Record the gap since the previous request for this DAG. The
            # first request for a DAG only seeds last_arrivals.
            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # Skips the remaining handlers and the periodic work for
                # this iteration of the loop.
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # NOTE(review): `None in payload` tests the keys of the
                    # returned mapping, not its values -- confirm this is the
                    # intended retry condition for the KVS client.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    # Key the counters by function *name*, which is how the
                    # DAG-call handler above and the statistics report below
                    # index call_frequency.
                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        end = time.time()

        # `start` is only reset at the end of the reporting branch below, so
        # both thresholds are measured from the last report.
        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if mgmt_ip:
                schedulers = sched_utils.get_ip_set(
                    sched_utils.get_scheduler_list_address(mgmt_ip),
                    requestor_cache, False)

        if end - start > REPORT_THRESHOLD:
            # The returned count is not used here; the call is retained in
            # case it has side effects on the policy -- TODO confirm.
            policy.get_unique_executors()

            # Tell every other scheduler which DAGs and function locations
            # we know about.
            status = SchedulerStatus()
            status.dags.extend(dags.keys())

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Collect per-function call counts and per-DAG interarrival
            # times, resetting each as it is reported.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.',
                             call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
Example #3
0
def scheduler(ip, mgmt_ip, route_addr, policy_type):
    '''
    Runs the scheduler's event loop. The loop has no exit condition, so this
    function only stops if an exception propagates out of it.

    One ZeroMQ socket is bound per request type. Each iteration polls those
    sockets (for up to one second), dispatches whatever arrived to the
    matching handler, and then performs the periodic metadata refresh and
    status/statistics reporting once the respective thresholds have elapsed.

    Args:
        ip: This scheduler's own IP address. It is handed to the KVS client
            and the policy, and is used to avoid pushing status updates to
            ourselves.
        mgmt_ip: IP address of the management node, or None to run in local
            mode (no peer schedulers, no statistics reporting).
        route_addr: Routing address given to the KVS client and returned to
            every client that connects.
        policy_type: Forwarded unchanged to DefaultCloudburstSchedulerPolicy.
    '''

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)
    context.set(zmq.MAX_SOCKETS, 10000)

    # A mapping from a DAG's name to a (protobuf DAG, DAG sources) tuple.
    dags = {}

    # Tracks how often a request for each function is received, keyed by the
    # function's name.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    # Handles function invocations that arrive through a queue (mainly
    # storage events).
    func_call_queue_socket = context.socket(zmq.PULL)
    func_call_queue_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                                (FUNC_CALL_QUEUE_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    # The pin-accept socket is not polled below; it is handed to the policy,
    # so it gets a receive timeout of its own.
    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 10000)  # 10 seconds.
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    # The management socket only exists in cluster mode; every use below is
    # guarded by the same `not local` check.
    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(func_call_queue_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache,
                                              kvs,
                                              ip,
                                              policy_type,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            # The request payload is ignored; every connect request is
            # answered with the routing address.
            connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if func_call_queue_socket in socks and socks[
                func_call_queue_socket] == zmq.POLLIN:
            call_function_from_queue(func_call_queue_socket, pusher_cache,
                                     policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            # Microsecond timestamp taken before the request is read; logged
            # together with the time at which scheduling finished.
            start_t = int(time.time() * 1000000)
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            # Record the gap since the previous request for this DAG. The
            # first request for a DAG only seeds last_arrivals.
            t = time.time()
            if name in last_arrivals:
                interarrivals.setdefault(name, []).append(
                    t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                # Skips the remaining handlers and the periodic work for
                # this iteration of the loop.
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            sched_t = int(time.time() * 1000000)
            logging.info(
                f'App function {name} recv: {start_t}, scheduled: {sched_t}')
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    # NOTE(review): `None in payload` tests the keys of the
                    # returned mapping, not its values -- confirm this is the
                    # intended retry condition for the KVS client.
                    payload = kvs.get(dname)
                    while None in payload:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            start_t = int(time.time() * 1000000)

            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            if call.name not in dags:
                # Mirror the guard in the DAG-call handler: an unknown DAG
                # must not take down the event loop with a KeyError. This
                # socket is a PULL socket, so there is no reply to send.
                logging.error('Dropping continuation for unknown DAG %s.',
                              call.name)
            else:
                result = Value()
                result.ParseFromString(continuation.result)

                # Feed the continuation's result to every source function of
                # the DAG as an additional argument.
                dag, sources = dags[call.name]
                for source in sources:
                    call.function_args[source].values.extend([result])

                call_dag(call, pusher_cache, dags, policy, continuation.id)
                sched_t = int(time.time() * 1000000)
                logging.info(
                    f'App function {call.name} recv: {start_t}, scheduled: {sched_t}'
                )
                for fname in dag.functions:
                    call_frequency[fname.name] += 1

        end = time.time()

        # `start` is only reset at the end of the reporting branch below, so
        # both thresholds are measured from the last report.
        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                # Keep the previous set when the lookup yields nothing.
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            # Tell every other scheduler which DAGs and function locations
            # we know about.
            status = SchedulerStatus()
            status.dags.extend(dags.keys())

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            # Collect per-function call counts and per-DAG interarrival
            # times, resetting each as it is reported.
            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.debug('Reporting %d calls for function %s.',
                              call_frequency[fname], fname)

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()