Example #1
class CloudburstConnection():
    def __init__(self, func_addr, ip, tid=0, local=False):
        '''
        func_addr: The address of the Cloudburst interface, either localhost or
        the address of an AWS ELB in cluster mode.
        ip: The IP address of the client machine -- used to send and receive
        responses.
        tid: If multiple clients are running on the same machine, they will
        need to use unique IDs.
        local: A boolean representing whether the client is interacting with the
        cluster in local or cluster mode.
        '''

        self.service_addr = 'tcp://' + func_addr + ':%d'
        self.context = zmq.Context(1)
        kvs_addr = self._connect()

        # Uses a fixed offset of 10 (plus the client's thread ID), mostly to
        # alleviate port conflicts when running in local mode.
        self.kvs_client = AnnaTcpClient(kvs_addr,
                                        ip,
                                        local=local,
                                        offset=tid + 10)

        self.func_create_sock = self.context.socket(zmq.REQ)
        self.func_create_sock.connect(self.service_addr % FUNC_CREATE_PORT)

        self.func_call_sock = self.context.socket(zmq.REQ)
        self.func_call_sock.connect(self.service_addr % FUNC_CALL_PORT)

        self.list_sock = self.context.socket(zmq.REQ)
        self.list_sock.connect(self.service_addr % LIST_PORT)

        self.dag_create_sock = self.context.socket(zmq.REQ)
        self.dag_create_sock.connect(self.service_addr % DAG_CREATE_PORT)

        self.dag_call_sock = self.context.socket(zmq.REQ)
        self.dag_call_sock.connect(self.service_addr % DAG_CALL_PORT)

        self.dag_delete_sock = self.context.socket(zmq.REQ)
        self.dag_delete_sock.connect(self.service_addr % DAG_DELETE_PORT)

        self.response_sock = self.context.socket(zmq.PULL)
        response_port = 9000 + tid
        self.response_sock.setsockopt(zmq.RCVTIMEO, 1000)
        self.response_sock.bind('tcp://*:' + str(response_port))

        self.response_address = 'tcp://' + ip + ':' + str(response_port)

        self.rid = 0

    def list(self, prefix=None):
        '''
        Prints the names of all the functions registered in the system.

        prefix: An optional argument which, if specified, prunes the list of
        returned functions to match the provided prefix.
        '''

        for fname in self._get_func_list(prefix):
            print(fname)

    def get_function(self, name):
        '''
        Retrieves a handle for an individual function. Returns None if the
        function cannot be found in the system. The returned object can be
        called like a regular Python function, which returns a CloudburstFuture.

        name: The name of the function to retrieve.
        '''
        if name not in self._get_func_list():
            print(f'No function found with name {name}. To view all '
                  'functions, use the `list` method.')
            return None

        return CloudburstFunction(name, self, self.kvs_client)

    def register(self, function, name):
        '''
        Registers a new function or class with the system. The returned object
        can be called like a regular Python function, which returns a Cloudburst
        Future. If the input is a class, the class is expected to have a run
        method, which is what is invoked at runtime.

        function: The function object that we are registering.
        name: A unique name under which the function is stored in the system.
        '''

        func = Function()
        func.name = name
        func.body = serializer.dump(function)

        self.func_create_sock.send(func.SerializeToString())

        resp = GenericResponse()
        resp.ParseFromString(self.func_create_sock.recv())

        if resp.success:
            return CloudburstFunction(name, self, self.kvs_client)
        else:
            raise RuntimeError(
                f'Unexpected error while registering function: {resp}.')
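
    # A hypothetical usage sketch (assumes a running Cloudburst deployment;
    # `incr` is an illustrative name, not part of the API):
    #
    #   cloud = CloudburstConnection('127.0.0.1', '127.0.0.1', local=True)
    #   incr = cloud.register(lambda _, x: x + 1, 'incr')
    #   future = incr(1)  # invoking the handle returns a CloudburstFuture
    #   assert future.get() == 2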

    def register_dag(self, name, functions, connections):
        '''
        Registers a new DAG with the system. This operation will fail if any of
        the functions provided cannot be identified in the system.

        name: A unique name for this DAG.
        functions: A list of names of functions to be included in this DAG.
        connections: A list of ordered pairs of function names that represent
        the edges in this DAG.
        '''

        flist = self._get_func_list()
        for fname in functions:
            if isinstance(fname, tuple):
                fname = fname[0]

            if fname not in flist:
                raise RuntimeError(
                    f'Function {fname} not registered. Please register before '
                    + 'including it in a DAG.')

        dag = Dag()
        dag.name = name
        for function in functions:
            ref = dag.functions.add()

            if isinstance(function, tuple):
                fname = function[0]
                invalids = function[1]
                ref.type = MULTIEXEC
            else:
                fname = function
                invalids = []

            ref.name = fname
            for invalid in invalids:
                ref.invalid_results.append(serializer.dump(invalid))

        for pair in connections:
            conn = dag.connections.add()
            conn.source = pair[0]
            conn.sink = pair[1]

        self.dag_create_sock.send(dag.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_create_sock.recv())

        return r.success, r.error
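
    # A hypothetical sketch of registering a two-function chain (assumes
    # 'incr' and 'square' were registered earlier; names are illustrative):
    #
    #   success, error = cloud.register_dag('incr_square',
    #                                       ['incr', 'square'],
    #                                       [('incr', 'square')])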

    def call_dag(self,
                 dname,
                 arg_map,
                 direct_response=False,
                 consistency=NORMAL,
                 output_key=None,
                 client_id=None):
        '''
        Issues a new request to execute the DAG. Returns a CloudburstFuture that
        tracks the result of the execution.

        dname: The name of the DAG to execute.
        arg_map: A map from function names to lists of arguments for each of
        the functions in the DAG.
        direct_response: If True, the response will be synchronously received
        by the client; otherwise, the result will be stored in the KVS.
        consistency: The consistency mode to use with this function: either
        NORMAL or MULTI.
        output_key: The KVS key in which to store the result of this DAG.
        client_id: An optional ID associated with an individual client across
        requests; this is used for causal metadata.
        '''
        dc = DagCall()
        dc.name = dname
        dc.consistency = consistency

        if output_key:
            dc.output_key = output_key

        if client_id:
            dc.client_id = client_id

        for fname in arg_map:
            fname_args = arg_map[fname]
            if not isinstance(fname_args, list):
                fname_args = [fname_args]
            args = [
                serializer.dump(arg, serialize=False) for arg in fname_args
            ]
            al = dc.function_args[fname]
            al.values.extend(args)

        if direct_response:
            dc.response_address = self.response_address

        self.dag_call_sock.send(dc.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.dag_call_sock.recv())

        if direct_response:
            try:
                result = self.response_sock.recv()
                return serializer.load(result)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    return None
                else:
                    raise e
        else:
            if r.success:
                return CloudburstFuture(r.response_id, self.kvs_client,
                                        serializer)
            else:
                return None
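
    # A hypothetical call sketch (assumes the 'incr_square' DAG above; only
    # the DAG's source function takes arguments here):
    #
    #   result = cloud.call_dag('incr_square', {'incr': [1]},
    #                           direct_response=True)
    #   # result == 4, or None if the response timed out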

    def delete_dag(self, dname):
        '''
        Removes the specified DAG from the system.

        dname: The name of the DAG to delete.
        '''
        self.dag_delete_sock.send_string(dname)

        r = GenericResponse()
        r.ParseFromString(self.dag_delete_sock.recv())

        return r.success, r.error

    def get_object(self, key):
        '''
        Retrieves an arbitrary key from the KVS, automatically deserializes it,
        and returns the value to the user.
        '''
        lattice = self.kvs_client.get(key)[key]
        return serializer.load_lattice(lattice)

    def put_object(self, key, value):
        '''
        Automatically wraps an object in a lattice and puts it into the
        key-value store at the desired key.
        '''
        lattice = serializer.dump_lattice(value)
        return self.kvs_client.put(key, lattice)
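
    # A hypothetical KVS round trip (assumes a reachable Anna KVS):
    #
    #   cloud.put_object('my_key', {'a': 1})
    #   assert cloud.get_object('my_key') == {'a': 1}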

    def exec_func(self, name, args):
        '''
        Issues a request to execute a single registered function with the
        given arguments. Returns the ID under which the function's result
        will be stored in the KVS.
        '''
        call = FunctionCall()
        call.name = name
        call.request_id = self.rid

        for arg in args:
            argobj = call.arguments.values.add()
            serializer.dump(arg, argobj)

        self.func_call_sock.send(call.SerializeToString())

        r = GenericResponse()
        r.ParseFromString(self.func_call_sock.recv())

        self.rid += 1
        return r.response_id

    def _connect(self):
        sckt = self.context.socket(zmq.REQ)
        sckt.connect(self.service_addr % CONNECT_PORT)
        sckt.send_string('')

        return sckt.recv_string()

    def _get_func_list(self, prefix=None):
        msg = prefix if prefix else ''
        self.list_sock.send_string(msg)

        flist = StringSet()
        flist.ParseFromString(self.list_sock.recv())
        return flist.keys
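
Putting the client API together, here is a minimal end-to-end sketch. It is
not part of the example above: the import path and scheduler address are
assumptions for a local-mode deployment, and the function and DAG names are
illustrative.

from cloudburst.client.client import CloudburstConnection

cloud = CloudburstConnection('127.0.0.1', '127.0.0.1', local=True)

# Registered functions receive the Cloudburst user-library handle as their
# first argument, hence the leading underscore parameter.
incr = cloud.register(lambda _, x: x + 1, 'incr')
square = cloud.register(lambda _, x: x * x, 'square')

print(incr(1).get())    # 2 -- direct calls return a CloudburstFuture
print(square(2).get())  # 4

# Chain the two functions into a DAG: incr's output feeds square.
cloud.register_dag('incr_square', ['incr', 'square'], [('incr', 'square')])
print(cloud.call_dag('incr_square', {'incr': [1]}).get())  # 4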
Example #2
def scheduler(ip, mgmt_ip, route_addr):

    # If the management IP is not set, we are running in local mode.
    local = (mgmt_ip is None)
    kvs = AnnaTcpClient(route_addr, ip, local=local)

    scheduler_id = str(uuid.uuid4())

    context = zmq.Context(1)

    # A mapping from a DAG's name to its protobuf representation.
    dags = {}

    # Tracks how often a request for each function is received.
    call_frequency = {}

    # Tracks the time interval between successive requests for a particular
    # DAG.
    interarrivals = {}

    # Tracks the most recent arrival for each DAG -- used to calculate
    # interarrival times.
    last_arrivals = {}

    # Maintains a list of all other schedulers in the system, so we can
    # propagate metadata to them.
    schedulers = set()

    connect_socket = context.socket(zmq.REP)
    connect_socket.bind(sutils.BIND_ADDR_TEMPLATE % (CONNECT_PORT))

    func_create_socket = context.socket(zmq.REP)
    func_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CREATE_PORT))

    func_call_socket = context.socket(zmq.REP)
    func_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (FUNC_CALL_PORT))

    dag_create_socket = context.socket(zmq.REP)
    dag_create_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CREATE_PORT))

    dag_call_socket = context.socket(zmq.REP)
    dag_call_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_CALL_PORT))

    dag_delete_socket = context.socket(zmq.REP)
    dag_delete_socket.bind(sutils.BIND_ADDR_TEMPLATE % (DAG_DELETE_PORT))

    list_socket = context.socket(zmq.REP)
    list_socket.bind(sutils.BIND_ADDR_TEMPLATE % (LIST_PORT))

    exec_status_socket = context.socket(zmq.PULL)
    exec_status_socket.bind(sutils.BIND_ADDR_TEMPLATE % (sutils.STATUS_PORT))

    sched_update_socket = context.socket(zmq.PULL)
    sched_update_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.SCHED_UPDATE_PORT))

    pin_accept_socket = context.socket(zmq.PULL)
    pin_accept_socket.setsockopt(zmq.RCVTIMEO, 500)
    pin_accept_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                           (sutils.PIN_ACCEPT_PORT))

    continuation_socket = context.socket(zmq.PULL)
    continuation_socket.bind(sutils.BIND_ADDR_TEMPLATE %
                             (sutils.CONTINUATION_PORT))

    if not local:
        management_request_socket = context.socket(zmq.REQ)
        management_request_socket.setsockopt(zmq.RCVTIMEO, 500)
        # By setting this flag, zmq matches replies with requests.
        management_request_socket.setsockopt(zmq.REQ_CORRELATE, 1)
        # Relax strict alternation between request and reply.
        # For detailed explanation, see here: http://api.zeromq.org/4-1:zmq-setsockopt
        management_request_socket.setsockopt(zmq.REQ_RELAXED, 1)
        management_request_socket.connect(
            sched_utils.get_scheduler_list_address(mgmt_ip))

    pusher_cache = SocketCache(context, zmq.PUSH)

    poller = zmq.Poller()
    poller.register(connect_socket, zmq.POLLIN)
    poller.register(func_create_socket, zmq.POLLIN)
    poller.register(func_call_socket, zmq.POLLIN)
    poller.register(dag_create_socket, zmq.POLLIN)
    poller.register(dag_call_socket, zmq.POLLIN)
    poller.register(dag_delete_socket, zmq.POLLIN)
    poller.register(list_socket, zmq.POLLIN)
    poller.register(exec_status_socket, zmq.POLLIN)
    poller.register(sched_update_socket, zmq.POLLIN)
    poller.register(continuation_socket, zmq.POLLIN)

    # Start the policy engine.
    policy = DefaultCloudburstSchedulerPolicy(pin_accept_socket,
                                              pusher_cache,
                                              kvs,
                                              ip,
                                              local=local)
    policy.update()

    start = time.time()

    while True:
        socks = dict(poller.poll(timeout=1000))

        if connect_socket in socks and socks[connect_socket] == zmq.POLLIN:
            msg = connect_socket.recv_string()
            connect_socket.send_string(route_addr)

        if (func_create_socket in socks
                and socks[func_create_socket] == zmq.POLLIN):
            create_function(func_create_socket, kvs)

        if func_call_socket in socks and socks[func_call_socket] == zmq.POLLIN:
            call_function(func_call_socket, pusher_cache, policy)

        if (dag_create_socket in socks
                and socks[dag_create_socket] == zmq.POLLIN):
            create_dag(dag_create_socket, pusher_cache, kvs, dags, policy,
                       call_frequency)

        if dag_call_socket in socks and socks[dag_call_socket] == zmq.POLLIN:
            call = DagCall()
            call.ParseFromString(dag_call_socket.recv())

            name = call.name

            t = time.time()
            if name in last_arrivals:
                if name not in interarrivals:
                    interarrivals[name] = []

                interarrivals[name].append(t - last_arrivals[name])

            last_arrivals[name] = t

            if name not in dags:
                resp = GenericResponse()
                resp.success = False
                resp.error = NO_SUCH_DAG

                dag_call_socket.send(resp.SerializeToString())
                continue

            dag = dags[name]
            for fname in dag[0].functions:
                call_frequency[fname.name] += 1

            response = call_dag(call, pusher_cache, dags, policy)
            dag_call_socket.send(response.SerializeToString())

        if (dag_delete_socket in socks
                and socks[dag_delete_socket] == zmq.POLLIN):
            delete_dag(dag_delete_socket, dags, policy, call_frequency)

        if list_socket in socks and socks[list_socket] == zmq.POLLIN:
            msg = list_socket.recv_string()
            prefix = msg if msg else ''

            resp = StringSet()
            resp.keys.extend(sched_utils.get_func_list(kvs, prefix))

            list_socket.send(resp.SerializeToString())

        if exec_status_socket in socks and socks[exec_status_socket] == \
                zmq.POLLIN:
            status = ThreadStatus()
            status.ParseFromString(exec_status_socket.recv())

            policy.process_status(status)

        if sched_update_socket in socks and socks[sched_update_socket] == \
                zmq.POLLIN:
            status = SchedulerStatus()
            status.ParseFromString(sched_update_socket.recv())

            # Retrieve any DAGs that some other scheduler knows about that we
            # do not yet know about.
            for dname in status.dags:
                if dname not in dags:
                    payload = kvs.get(dname)
                    # Retry until the DAG's payload is available in the KVS.
                    while payload[dname] is None:
                        payload = kvs.get(dname)

                    dag = Dag()
                    dag.ParseFromString(payload[dname].reveal())
                    dags[dag.name] = (dag, sched_utils.find_dag_source(dag))

                    for fname in dag.functions:
                        if fname.name not in call_frequency:
                            call_frequency[fname.name] = 0

            policy.update_function_locations(status.function_locations)

        if continuation_socket in socks and socks[continuation_socket] == \
                zmq.POLLIN:
            continuation = Continuation()
            continuation.ParseFromString(continuation_socket.recv())

            call = continuation.call
            call.name = continuation.name

            result = Value()
            result.ParseFromString(continuation.result)

            dag, sources = dags[call.name]
            for source in sources:
                call.function_args[source].values.extend([result])

            call_dag(call, pusher_cache, dags, policy, continuation.id)

        end = time.time()

        if end - start > METADATA_THRESHOLD:
            # Update the scheduler policy-related metadata.
            policy.update()

            # If the management IP is None, that means we are running in
            # local mode, so there is no need to deal with caches and other
            # schedulers.
            if not local:
                latest_schedulers = sched_utils.get_ip_set(
                    management_request_socket, False)
                if latest_schedulers:
                    schedulers = latest_schedulers

        if end - start > REPORT_THRESHOLD:
            status = SchedulerStatus()
            for name in dags.keys():
                status.dags.append(name)

            for fname in policy.function_locations:
                for loc in policy.function_locations[fname]:
                    floc = status.function_locations.add()
                    floc.name = fname
                    floc.ip = loc[0]
                    floc.tid = loc[1]

            msg = status.SerializeToString()

            for sched_ip in schedulers:
                if sched_ip != ip:
                    sckt = pusher_cache.get(
                        sched_utils.get_scheduler_update_address(sched_ip))
                    sckt.send(msg)

            stats = ExecutorStatistics()
            for fname in call_frequency:
                fstats = stats.functions.add()
                fstats.name = fname
                fstats.call_count = call_frequency[fname]
                logging.info('Reporting %d calls for function %s.' %
                             (call_frequency[fname], fname))

                call_frequency[fname] = 0

            for dname in interarrivals:
                dstats = stats.dags.add()
                dstats.name = dname
                dstats.call_count = len(interarrivals[dname]) + 1
                dstats.interarrival.extend(interarrivals[dname])

                interarrivals[dname].clear()

            # We only attempt to send the statistics if we are running in
            # cluster mode. If we are running in local mode, we write them to
            # the local log file.
            if mgmt_ip:
                sckt = pusher_cache.get(
                    sutils.get_statistics_report_address(mgmt_ip))
                sckt.send(stats.SerializeToString())

            start = time.time()
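
The connect handshake between the client and this scheduler is a plain zmq
REQ/REP exchange: the client sends an empty string to CONNECT_PORT and gets
the KVS routing address back (see CloudburstConnection._connect in Example
#1). Below is a stripped-down sketch of both ends; the port number and
addresses are illustrative stand-ins, not the real configuration values.

import zmq

CONNECT_PORT = 5000  # illustrative stand-in for the shared constant
context = zmq.Context(1)

# Scheduler side: answer every connect request with the routing address.
rep = context.socket(zmq.REP)
rep.bind('tcp://*:%d' % CONNECT_PORT)
# Inside the poll loop above, this is just:
#     _ = rep.recv_string()
#     rep.send_string(route_addr)

# Client side, mirroring CloudburstConnection._connect:
req = context.socket(zmq.REQ)
req.connect('tcp://127.0.0.1:%d' % CONNECT_PORT)
req.send_string('')
kvs_addr = req.recv_string()  # blocks until the scheduler replies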
Example #3
"""
for temp_string_iter in pickle_arr:
	key_string = base_string + "#" + str(i)
	value_bytes = LWWPairLattice(int(time.time()), temp_string_iter.encode())
	client.put(key_string, value_bytes)
	print(key_string)
	# print(temp_string_iter)
	# print("")
	i += 1
i_string = str(i)
i_bytes = LWWPairLattice(int(time.time()), i_string.encode())
client.put(bash_bytes_string, i_bytes)
"""

# Read back the chunk count, then fetch and reassemble each chunk, checking
# each one against the in-memory original.
i_two = int(client.get(bash_bytes_string)[bash_bytes_string].reveal().decode())
temp_pickle_ret = ""
for index in range(i_two):
    temp_string_iter_two = pickle_arr[index]
    key_string_two = base_string + "#" + str(index)
    ret = client.get(key_string_two)[key_string_two].reveal().decode()
    temp_pickle_ret = temp_pickle_ret + ret
    print(temp_string_iter_two == ret)

# temp_pickle_string is assumed to be defined in earlier, elided code as the
# original pickled string.
print(type(temp_pickle_string))
print(temp_pickle_ret == temp_pickle_string)
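
For reference, here is a self-contained sketch of the write phase that the
disabled block above performs: chunking data, storing each chunk as an
LWWPairLattice, and recording the chunk count under a separate key. The key
names mirror the snippet; the client construction and import paths are
assumptions about the Anna Python client.

import time

from anna.client import AnnaTcpClient
from anna.lattices import LWWPairLattice

client = AnnaTcpClient('127.0.0.1', '127.0.0.1', local=True, offset=1)

base_string = 'pickle-chunks'            # illustrative key prefix
bash_bytes_string = base_string + '#count'
chunks = ['chunk-one', 'chunk-two']      # stands in for pickle_arr

# Store each chunk under base_string + "#" + <index>.
for i, chunk in enumerate(chunks):
    key_string = base_string + '#' + str(i)
    client.put(key_string, LWWPairLattice(int(time.time()), chunk.encode()))

# Record how many chunks were written so readers know how many to fetch.
count = LWWPairLattice(int(time.time()), str(len(chunks)).encode())
client.put(bash_bytes_string, count)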