Пример #1
0
 def __init__(self, name, app, model, client_ttl=300):
     """Store the experiment's name, app and model, then wire up handlers.

     ``client_ttl`` (default 300) is forwarded unchanged to ClientManager,
     which defines its meaning.
     """
     self.name, self.model, self.app = name, model, app
     self.client_manager = ClientManager(name, app, client_ttl)
     self.update_manager = UpdateManager(name)
     # Called last, once every attribute above has been set.
     self.register_handlers()
Пример #2
0
 def test_client_disconnection(self):
     """
     With a FakeTimer built from TIMEOUT + 1, a freshly registered client
     must no longer be retrievable from the manager.
     """
     expired = FakeTimer(TIMEOUT + 1)
     manager = ClientManager(expired)
     registration = manager.register("rosetta")
     rosetta_id = registration[1]
     self.assertIsNone(manager.get_client(rosetta_id))
Пример #3
0
 def test_empty_manager(self):
     """
     A brand-new ClientManager must answer None for ids it never issued.
     """
     fresh = ClientManager()
     for unknown_id in range(1, 10):
         looked_up = fresh.get_client(unknown_id)
         self.assertIsNone(looked_up)
Пример #4
0
 def listen(self):
     """Accept client connections forever and add each one to self.clients.

     Each accepted socket is wrapped in a ClientManager (flagged daemon and
     then started, presumably a threading.Thread subclass). Never returns.
     """
     self.clients = []
     while True:
         conn, peer = self._serverSocket.accept()
         handler = ClientManager(conn, peer, self)
         handler.daemon = True
         self.clients.append(handler)
         handler.start()
Пример #5
0
def main():
    """Run the ATEM server until Ctrl-C.

    Command-line options select the listening address and UDP port, the
    ATEM config XML file and a debug level. Each incoming datagram is parsed
    into a ``Packet`` and handed to the client returned by ``ClientManager``;
    every loop iteration also lets the clients run their periodic work.
    """
    # Parse the input arguments
    ap = argparse.ArgumentParser()

    ap.add_argument("--address", "-a", required=False, default="0.0.0.0", help="listening IP address, default=\"0.0.0.0\"")
    # type=int: without it a port given on the command line stays a str and
    # socket.bind() rejects the address tuple with TypeError.
    ap.add_argument("--port", "-p", required=False, type=int, default=9910, help="listening UDP Port, default=9910")
    ap.add_argument("--config", required=False, default="default_config.xml", help="config XML file from ATEM software (default=default_config.xml)")
    # NOTE(review): --debug is parsed but not used anywhere in this function.
    ap.add_argument("--debug", "-d", required=False, default="INFO", help="debug level (in quotes): NONE, INFO (default), WARNING, DEBUG")

    args = ap.parse_args()
    host = args.address
    port = args.port
    config_file = args.config

    print("ATEM Server Starting...")

    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.bind((host, port))

    client_mgr = ClientManager()
    atem_config.config_init(config_file)

    print("ATEM Server Running...Hit ctrl-c to exit")

    try:
        while True:
            # Process incoming packets but time out after 50 ms so the clients
            # can perform cleanup and resend unanswered packets.
            readers, _, _ = select.select([s], [], [], 0.050)
            if readers:
                try:
                    data, addr = s.recvfrom(2048)
                    packet = Packet(addr, data)
                    packet.parse_packet()
                    client = client_mgr.get_client(packet.ip_and_port, packet.session_id)
                    client.process_inbound_packet(packet)
                except ConnectionResetError:
                    # Recreate and rebind the socket, then keep serving.
                    # NOTE(review): presumably the Windows UDP WSAECONNRESET
                    # quirk (ICMP port-unreachable from a departed peer).
                    print("connection reset!")
                    s.close()
                    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                    s.bind((host, port))
                    continue

            # Perform regularly regardless of incoming packets
            client_mgr.run_clients(s)
    except KeyboardInterrupt:
        # quit
        sys.exit()
    finally:
        # Release the UDP port on every exit path, including sys.exit().
        s.close()
Пример #6
0
    def test_add_clients_and_get_them(self):
        """
        Every registered client can be looked up again by the id that
        register() handed back for it.
        """
        mission_names = ["curiosity", "venera", "spirit"]
        manager = ClientManager()
        issued = []

        for mission in mission_names:
            outcome = manager.register(mission)
            self.assertTrue(outcome[0])
            issued.append((mission, outcome[1]))

        for mission, mission_id in issued:
            found = manager.get_client(mission_id)
            self.assertEqual(mission, found.get_name())
Пример #7
0
 def __init__(self, game_code):
     """Create the manager for the game identified by ``game_code``.

     Builds a fresh CodenameGame and ClientManager, then attaches a state
     machine (with this object as its model) starting in GameState.IN_LOBBY.
     """
     self.game_code = game_code
     self.game = CodenameGame()
     self.client_manager = ClientManager()  # tracks clients and their players
     # Display handling is currently disabled:
     # self.display_manager = DisplayManager()
     machine_options = {
         'states': GameManager.states,
         'initial': GameState.IN_LOBBY,
         'transitions': GameManager.transitions,
     }
     self.machine = Machine(model=self, **machine_options)
Пример #8
0
def main():
    """Run a simulated federated-learning experiment configured by ``ARGS``.

    Sets up CSV metric logging and RNG seeds, builds the clients and the
    server (optionally loading a saved model), then runs ``num_rounds``
    rounds of client selection, local training and aggregation. It evaluates
    every ``eval_every`` rounds and after the final round. The final model is
    saved to ``<metrics csv path>_model``.
    """
    start_time = time.time()
    args = ARGS
    ray.init(include_dashboard=False, num_gpus=args.num_gpus)
    log_filename = os.path.join(args.metrics_dir, args.metrics_name + '.csv')
    # Start every run with a fresh metrics file.
    if os.path.exists(log_filename):
        os.remove(log_filename)
    logging.basicConfig(filename=log_filename,
                        level=logging.INFO,
                        format='%(message)s')
    # Seed Python's and NumPy's RNGs (affects client sampling, and batching)
    random.seed(1 + args.seed)
    np.random.seed(12 + args.seed)

    # Per-dataset defaults, indexed as (num_rounds, eval_every,
    # clients_per_round); a command-line value of -1 selects the default.
    tup = MAIN_PARAMS[args.dataset][args.t]
    num_rounds = args.num_rounds if args.num_rounds != -1 else tup[0]
    eval_every = args.eval_every if args.eval_every != -1 else tup[1]
    # Resolve the default once and write it back to args, because args is
    # also handed to ClientManager and Server below.
    if args.clients_per_round == -1:
        args.clients_per_round = tup[2]
    clients_per_round = args.clients_per_round

    manager = ClientManager(args)
    clients = manager.setup_clients(args.setup_clients)
    clients.sort(key=lambda x: x.num_train_samples)

    manager.corrupt_clients()

    print('Clients in Total: %d' % len(clients))

    # Create server
    server = Server(clients, manager, args)
    if args.loadmodel:
        model = tf.keras.models.load_model(log_filename + "_model")
        server.set_model(model)
    # Client ids and groups are not needed here.
    _, _, num_train_samples, num_test_samples = manager.get_clients_info()
    total_train_samples = np.sum(list(num_train_samples.values()))
    # Weight each client by its share of all training samples.
    # NOTE(review): this pairs the *sorted* `clients` list positionally with
    # the values of num_train_samples; it is only correct if
    # get_clients_info() yields them in the same order. Confirm.
    for c, n in zip(clients, num_train_samples.values()):
        c.set_weight(float(n) / total_train_samples)

    # Initial status
    print('--- Random Initialization ---')
    server.test_model(0, set_to_use='train', log=False)
    # Simulate training
    for i in range(num_rounds):
        # Select clients to train this round
        server.select_clients(i, num_clients=clients_per_round)

        server.train_model(num_epochs=args.num_epochs,
                           batch_size=args.batch_size,
                           round=i)
        aggregation_start = time.time()
        server.aggregate(args.method)

        if args.method == "arfl":
            server.update_alpha(i)

        # One timestamp for both figures; "Aggregation time" covers
        # aggregate() plus, for arfl, update_alpha().
        now = time.time()
        print(
            datetime.datetime.now(),
            '--- Round %d of %d: Training %d Clients. Time cost in total %s. Aggregation time %s --- '
            % (i + 1, num_rounds, clients_per_round, now - start_time,
               now - aggregation_start))
        # Test model
        if (i + 1) % eval_every == 0 or (i + 1) == num_rounds:
            test_stat_metrics = server.test_model(
                i, set_to_use='train')  # Evaluate training loss
            print_metrics(test_stat_metrics,
                          num_train_samples,
                          prefix='{}_'.format('train'))
            test_stat_metrics = server.test_model(i, set_to_use='test')
            print_metrics(test_stat_metrics,
                          num_test_samples,
                          prefix='{}_'.format('test'))

    # Save model when training ends
    server.save_model(log_filename + "_model")
Пример #9
0
class Experiment(object):
    """Federated-averaging experiment served through routes on ``app``.

    Routes, all prefixed with ``/<name>``: POST ``/update``, GET
    ``/start_round``, GET ``/end_round`` and GET ``/loss_history``.
    Each round pickles the current ``model.state_dict()`` to every known
    client. When the round ends, the returned state dicts are averaged,
    weighted by each client's ``n_samples``.
    """

    def __init__(self, name, app, model, client_ttl=300):
        self.name = name
        self.model = model
        self.app = app
        self.client_manager = ClientManager(name, app, client_ttl)
        self.update_manager = UpdateManager(name)
        self.register_handlers()

    def register_handlers(self):
        """Attach this experiment's request handlers to ``app.router``."""
        self.app.router.add_post(
            '/{}/update'.format(self.name),
            self.update,
        )
        self.app.router.add_get(
            '/{}/start_round'.format(self.name),
            self.trigger_start_round,
        )
        self.app.router.add_get(
            '/{}/end_round'.format(self.name),
            self.trigger_end_round,
        )
        self.app.router.add_get(
            '/{}/loss_history'.format(self.name),
            self.get_loss_history,
        )

    async def get_loss_history(self, request):
        """Return the per-epoch averaged losses recorded by end_round()."""
        # `_update_loss_history` is never assigned in this class (every
        # request raised AttributeError). end_round() appends the averaged
        # losses to update_manager.loss_history, so serve that list.
        return web.json_response(self.update_manager.loss_history)

    async def trigger_start_round(self, request):
        """Start a round; the ``n_epoch`` query parameter defaults to 32.

        Responds 400 for a non-integer ``n_epoch`` and 423 if an update is
        already in progress.
        """
        try:
            n_epoch = int(request.query['n_epoch'])
        except KeyError:
            n_epoch = 32
        except ValueError:
            return web.json_response({"err": "Invalid Epoch Value"},
                                     status=400)
        try:
            status = await self.start_round(n_epoch)
        except UpdateException:
            return web.json_response({'err': "Update already in progress"},
                                     status=423)
        return web.json_response(status)

    async def trigger_end_round(self, request):
        """Force the current round to end and report the update state."""
        self.end_round()
        # NOTE(review): `_update_state` is not assigned anywhere in this
        # class. Unless a subclass or external code sets it, this raises
        # AttributeError. Confirm the intended response payload.
        return web.json_response(json_clean(self._update_state))

    async def start_round(self, n_epoch):
        """Open an update round and send the current model to every client.

        Returns ``[]`` when there are no clients. Otherwise returns a dict
        mapping each client id to its ``notify_clients`` response. Clients
        with a truthy response are registered as working on the round.
        """
        await self.update_manager.start_update(n_epoch=n_epoch)
        update_name = self.update_manager.update_name
        print("Starting update:", update_name)
        if not len(self.client_manager):
            print("No clients. Aborting update.")
            # Close the round opened above, as the "no clients working" path
            # below does. Otherwise it presumably stays in progress and every
            # later start_round is refused with UpdateException (HTTP 423).
            self.end_round()
            return []
        data = {
            'state_dict': self.model.state_dict(),
            'update_name': update_name,
            'n_epoch': n_epoch,
        }
        result = await self.client_manager.notify_clients(
            'round_start', http_method='POST', data=pickle.dumps(data))
        for client_id, response in result:
            if response:
                self.update_manager.client_start(client_id)
        if not self.update_manager:
            print("No clients working on round... ending")
            self.end_round()
        return dict(result)

    async def update(self, request):
        """Accept one client's pickled round result.

        Responds 410 if no round is in progress or if the payload belongs to
        a different round. When the last outstanding client reports, the
        round is ended.
        """
        client_id = self.client_manager.verify_request(request)
        body = await request.read()
        # SECURITY NOTE(review): pickle.loads on a request body can execute
        # arbitrary code. It is only as safe as verify_request() above makes
        # it; prefer a non-executable serialization format.
        data = pickle.loads(body)
        update_name = data['update_name']

        if (not self.update_manager.in_progress
                or update_name != self.update_manager.update_name):
            return web.json_response({'error': "Wrong Update"}, status=410)

        self.update_manager.client_end(client_id, data)
        self.client_manager[client_id]['last_update'] = update_name
        self.client_manager[client_id]['num_updates'] += 1

        if not self.update_manager.clients_left:
            self.end_round()
        return web.json_response("OK")

    def end_round(self):
        """Close the current round and fold the client results into the model.

        This is a no-op if no round is in progress. Each model parameter is
        overwritten with the ``n_samples``-weighted mean of the clients'
        values. One weighted loss per epoch is appended to
        ``update_manager.loss_history``.
        """
        if not self.update_manager.in_progress:
            return
        update_name = self.update_manager.update_name
        print("Finishing update:", update_name)
        datas = self.update_manager.end_update()
        N = sum(d['n_samples'] for d in datas.values())
        if not N:
            print("No responses for update:", update_name)
            return
        for key, value in self.model.state_dict().items():
            weight_sum = (d['state_dict'][key] * d['n_samples']
                          for d in datas.values())
            # Slice assignment writes into the existing tensor, which
            # presumably shares storage with the model's parameter.
            value[:] = sum(weight_sum) / N
        for epoch in range(self.update_manager.update_meta['n_epoch']):
            epoch_loss = sum(d['loss_history'][epoch] * d['n_samples']
                             for d in datas.values())
            self.update_manager.loss_history.append(epoch_loss / N)
        print("Finished update:", update_name)
        # trigger_start_round accepts n_epoch=0, in which case no loss is
        # recorded. Avoid an IndexError after the model was already updated.
        if self.update_manager.loss_history:
            print("Final Loss:", self.update_manager.loss_history[-1])
# Stand-alone launcher: prints a separator, waits, makes the client package
# importable, then creates a ClientManager and calls its run() method.
import sys
import time
print("- " * 10)
# Fixed 15-second start-up delay before the client manager is imported and
# started. NOTE(review): presumably this gives other services time to come up
# after boot; confirm. The delay is hard-coded, not configurable.
time.sleep(15)
# time.sleep(1)
# Hard-coded absolute path to the client package; the import below needs it.
sys.path.append('/home/pi/Documents/3CB101-Pi/project/client')
from client_manager import ClientManager

manager = ClientManager()
# NOTE(review): run() is defined in client_manager, outside this file; whether
# it blocks is not visible here.
manager.run()
Пример #11
0
class GameManager(object):
    ''' Owns one Codenames game: the game itself, its connected clients and
        a lobby / in-game / end-screen state machine.
    '''
    states = [GameState.IN_LOBBY, GameState.IN_GAME, GameState.IN_ENDSCREEN]

    transitions = [
        ['start_game', GameState.IN_LOBBY, GameState.IN_GAME],
        ['pause_game', GameState.IN_GAME, GameState.IN_LOBBY],
        ['game_over', GameState.IN_GAME, GameState.IN_ENDSCREEN],
    ]

    def __init__(self, game_code):
        # Game code of game being managed.
        self.game_code = game_code
        # Generated codename game.
        self.game = CodenameGame()
        # Manager for clients (and players)
        self.client_manager = ClientManager()
        # Manager for displaying the right data to a client
        # self.display_manager = DisplayManager()
        self.machine = Machine(
            model=self,
            states=GameManager.states,
            initial=GameState.IN_LOBBY,
            transitions=GameManager.transitions,
        )

    def get_game(self):
        ''' Returns contained CodenamesGame object. '''
        return self.game

    def get_lobby_update_event(self):
        ''' Constructs a client UPDATE event based upon the current state. '''
        lobby_bundle = self.client_manager.get_lobby_bundle()
        return EmitEvent(ClientEvent.UPDATE.value,
                         lobby_bundle,
                         room=self.game_code,
                         broadcast=True)

    def handle_client_event(self, client_id, client_event, data):
        ''' Passes client events down to the client manager to deal with and
            appends an UPDATE event.
        '''
        client_manager = self.client_manager
        client, events = client_manager.handle_event(self.game_code, client_id,
                                                     client_event, data)
        events.append(self.get_lobby_update_event())
        return client, events

    def validate_client_has_current_role(self, client_id):
        ''' True if the client holds the role whose turn it currently is. '''
        team, role = self.game.get_current_turn()
        return self.client_manager.client_has_role(client_id, team, role)

    def get_game_update_event(self):
        ''' Constructs a game UPDATE event based upon the current state. '''
        game_bundle = self.game.serialize()
        return EmitEvent(GameEvent.UPDATE.value,
                         game_bundle,
                         room=self.game_code,
                         broadcast=True)

    @staticmethod
    def _parse_clue(data):
        ''' Returns (word, number) from a clue payload, or None if the payload
            lacks 'word'/'number', is not subscriptable by those keys, or has
            a number that int() cannot convert.
        '''
        try:
            return data['word'], int(data['number'])
        except (KeyError, TypeError, ValueError):
            return None

    def handle_game_event(self, client_id, game_event, data):
        ''' Applies a game event and returns (game, events).

            CHOOSE_WORD treats data as the guessed word. SUBMIT_CLUE expects
            a mapping with 'word' and 'number' keys; a malformed clue leaves
            the game unchanged instead of raising. The returned events always
            end with a game UPDATE broadcast.
        '''
        events = []
        game = self.game
        # TODO: validate that the move came from the expected client
        # self.validate_client_has_current_role(client_id)
        if game_event is GameEvent.CHOOSE_WORD:
            word = data
            game.make_guess(word)
        elif game_event is GameEvent.SUBMIT_CLUE:
            # TODO: validate that we're expecting a submit clue using game turn state
            # self.validate_client_has_current_role(client_id)
            clue = self._parse_clue(data)
            # Malformed clues used to fall through to clue['word'] and crash
            # with KeyError/ValueError. They are now dropped.
            # TODO: report the error back to the submitting client.
            if clue is not None:
                # TODO validate clue
                game.set_current_clue(*clue)

        events.append(self.get_game_update_event())
        return game, events

    def get_num_clients(self):
        ''' Returns the number of clients known to the client manager. '''
        return self.client_manager.get_num_clients()

    def serialize_game(self):
        ''' Serializes contained game to JSON object along with game code. '''
        game_bundle = self.game.serialize()
        game_bundle['gameCode'] = str(self.game_code)
        return game_bundle

    def serialize_players(self):
        ''' Serializes players to JSON object. '''
        return self.client_manager.serialize_players()

    def serialize_players_mapping(self):
        ''' Serializes playerid to player mapping to JSON object. '''
        return self.client_manager.serialize_players_mapping()
Пример #12
0
class Experiment(object):
    """Federated-learning experiment exposed as routes on ``app``.

    Routes, all prefixed with ``/<name>``: POST ``/update``, GET
    ``/start_round``, GET ``/end_round``, GET ``/loss_history`` and GET
    ``/get_client_updates``.
    """

    def __init__(self, name, app, model, client_ttl=300):
        self.name = name
        self.model = model
        self.app = app
        self.client_manager = ClientManager(name, app, client_ttl)
        self.update_manager = UpdateManager(name)
        self.register_handlers()

    def register_handlers(self):
        """Attach this experiment's request handlers to ``app.router``."""
        self.app.router.add_post(
            '/{}/update'.format(self.name),
            self.update,
        )
        self.app.router.add_get(
            '/{}/start_round'.format(self.name),
            self.trigger_start_round,
        )
        self.app.router.add_get(
            '/{}/end_round'.format(self.name),
            self.trigger_end_round,
        )
        self.app.router.add_get(
            '/{}/loss_history'.format(self.name),
            self.get_loss_history,
        )
        self.app.router.add_get(
            '/{}/get_client_updates'.format(self.name),
            self.get_client_updates,
        )

    async def get_client_updates(self, request):
        """Load the unweighted mean of the clients' latest state dicts.

        Only clients that have already posted an update (and therefore carry
        a 'state_dict') take part. Previously a single client without one
        made the whole request fail with KeyError. Unlike end_round(), every
        participating client has equal weight.
        """
        print("inside get_client_updates ")

        all_clients = self.client_manager.clients
        # Assumes per-client records are dicts (update() assigns into them
        # by key).
        state_dicts = [all_clients[c]['state_dict'] for c in all_clients
                       if 'state_dict' in all_clients[c]]

        if state_dicts:
            model_all = {
                k: sum(sd[k] for sd in state_dicts) / len(state_dicts)
                for k in state_dicts[0]
            }
            if model_all:
                self.model.load_state_dict(model_all)
                print(
                    "done with the merge of model parameters from all clients "
                )
            else:
                print(
                    "done with no merge of model parameters from all clients ")

        return web.json_response("OK")

    async def get_loss_history(self, request):
        """Return ``update_manager.loss_history`` as JSON."""
        # `_update_loss_history` is never assigned in this class (every
        # request raised AttributeError). The loss list lives on the update
        # manager. NOTE: end_round() in this version does not append
        # per-epoch losses, so the list holds only what UpdateManager records.
        return web.json_response(self.update_manager.loss_history)

    async def trigger_start_round(self, request):
        """Start a round; the ``n_epoch`` query parameter defaults to 1.

        Responds 400 for a non-integer ``n_epoch`` and 423 if an update is
        already in progress.
        """
        print("get a request trigger start round")

        try:
            n_epoch = int(request.query['n_epoch'])
        except KeyError:
            n_epoch = 1
        except ValueError:
            return web.json_response({"err": "Invalid Epoch Value"},
                                     status=400)

        try:
            status = await self.start_round(n_epoch)
        except UpdateException:
            return web.json_response({'err': "Update already in progress"},
                                     status=423)
        return web.json_response(status)

    async def trigger_end_round(self, request):
        """Force the current round to end; always answers "OK"."""
        self.end_round()
        # TODO: return real round state once it exists. An earlier version
        # returned json_clean(self._update_state), but that attribute is never
        # assigned in this class.
        return web.json_response('OK')

    async def start_round(self, n_epoch):
        """Open an update round and send the current model to every client.

        Returns ``[]`` when there are no clients. Otherwise returns a dict
        mapping each client id to its ``notify_clients`` response. Clients
        with a truthy response are registered as working on the round.
        """
        await self.update_manager.start_update(n_epoch=n_epoch)
        update_name = self.update_manager.update_name
        print("Starting update:", update_name, n_epoch)

        if not len(self.client_manager):
            print("No clients. Aborting update.")
            # Close the round opened above, as the "no clients working" path
            # below does. Otherwise it presumably stays in progress and every
            # later start_round is refused with UpdateException (HTTP 423).
            self.end_round()
            return []

        data = {
            'state_dict': self.model.state_dict(),
            'update_name': update_name,
            'n_epoch': n_epoch,
        }

        print("data is stuff inside my start_round")

        result = await self.client_manager.notify_clients(
            'round_start', http_method='POST', data=pickle.dumps(data))

        print(" start update result is ", result)

        for client_id, response in result:
            if response:
                self.update_manager.client_start(client_id)

        if not self.update_manager:
            print("No clients working on round... ending")
            self.end_round()

        print("end start update result is ", result)

        return dict(result)

    async def update(self, request):
        """Accept one client's pickled round result.

        Stores the client's 'state_dict' on its record (read later by
        get_client_updates) and ends the round when no clients remain.
        """
        client_id = self.client_manager.verify_request(request)
        print("receive an update from ", client_id)

        body = await request.read()
        # SECURITY NOTE(review): pickle.loads on a request body can execute
        # arbitrary code. It is only as safe as verify_request() above makes
        # it; prefer a non-executable serialization format.
        data = pickle.loads(body)

        update_name = data['update_name']

        # NOTE(review): the "Wrong Update" check (HTTP 410 when no round is in
        # progress or update_name differs from update_manager.update_name) is
        # disabled here, so stale or out-of-round updates are accepted.

        self.update_manager.client_end(client_id, data)
        self.client_manager[client_id]['last_update'] = update_name
        self.client_manager[client_id]['num_updates'] += 1
        self.client_manager[client_id]['state_dict'] = data['state_dict']

        print("about to end update in manager ")

        if not self.update_manager.clients_left:
            self.end_round()
        print("report update to end update in manager ")

        return web.json_response("OK")

    def end_round(self):
        """Close the current round and federated-average the results.

        This is a no-op if no round is in progress. Each model parameter is
        overwritten with the ``n_samples``-weighted mean of the clients'
        values.
        """
        if not self.update_manager.in_progress:
            return

        update_name = self.update_manager.update_name

        datas = self.update_manager.end_update()

        print("Finishing 1 update:", datas)

        # Federated averaging: weight each client's parameters by the number
        # of samples it trained on.
        N = sum(d['n_samples'] for d in datas.values())
        if not N:
            print("No responses for update:", update_name)
            return

        for key, value in self.model.state_dict().items():
            weight_sum = (d['state_dict'][key] * d['n_samples']
                          for d in datas.values())
            # Slice assignment writes into the existing tensor, which
            # presumably shares storage with the model's parameter.
            value[:] = sum(weight_sum) / N

        # Per-epoch loss averaging (from each client's 'loss_history') is
        # disabled in this version.

        print("Finished update:", update_name)