Exemplo n.º 1
0
    def dataReceived(self, data):
        """Dispatch one message received from a training client.

        The raw ``data`` is decoded with ``q_protocol.parse_message`` and
        handled according to its ``'type'`` field:

        * ``login`` -- if the sender name is already in
          ``self.factory.clients``, reply with a redundant-connection message
          and drop this connection; otherwise send the sender a new network.
        * ``net_trained`` -- pass the reported accuracies/iterations, epsilon,
          the client's ``iters_sampled`` and the sender name to
          ``self.factory.incorporate_trained_net`` via
          ``self.factory.new_net_lock.run``, then send the sender another
          network.
        * ``net_too_large`` -- send the sender another network.

        Messages of any other type are ignored.

        NOTE(review): this appears to assume each call carries exactly one
        complete message; stream transports do not guarantee message framing.
        """
        message = q_protocol.parse_message(data)
        kind = message['type']

        if kind == 'login':
            client = message['sender']
            if client not in self.factory.clients:
                print(bcolors.OKGREEN + client + ' has connected.' +
                      bcolors.ENDC)
                self.send_new_net(client)
                return
            # A client with this name is already connected: notify the
            # newcomer, then kill this second connection.
            self.transport.write(
                q_protocol.construct_redundant_connection_message(
                    socket.gethostname()))
            print(bcolors.FAIL + client +
                  ' tried to connect again. Killing second connection.' +
                  bcolors.ENDC)
            self.transport.loseConnection()
            return

        if kind == 'net_trained':
            client = message['sender']
            iters_sampled = self.factory.clients[client]['iters_sampled']
            # Results are incorporated under the factory's lock.
            self.factory.new_net_lock.run(
                self.factory.incorporate_trained_net,
                message['net_string'],
                float(message['acc_best_val']),
                int(message['iter_best_val']),
                float(message['acc_last_val']),
                int(message['iter_last_val']),
                float(message['epsilon']),
                iters_sampled,
                client)
            self.send_new_net(client)
        elif kind == 'net_too_large':
            self.send_new_net(message['sender'])
Exemplo n.º 2
0
    def dataReceived(self, data):
        """Handle one message received from the server.

        The raw ``data`` is decoded with ``q_protocol.parse_message`` and
        handled according to its ``'type'`` field:

        * ``redundant_connection`` -- only a notice is printed.
        * ``new_net`` -- train the network described by ``net_string``. The
          reply is ``net_too_large`` on ``OUT_OF_MEMORY``; otherwise it is
          ``net_trained`` with the best and last test accuracies. In
          ``self.factory.debug`` mode, fixed fake results are sent after a
          5 second sleep instead of training.

        Other message types are ignored.

        NOTE(review): the sleep and the training run execute synchronously
        inside this callback. If it runs on an event-loop thread, the loop is
        blocked for the whole training run.
        """
        out = q_protocol.parse_message(data)
        msg_type = out['type']

        if msg_type == 'redundant_connection':
            print('Redundancy in connect name')
        elif msg_type == 'new_net':
            print('Ready to train ' + out['net_string'])

            if self.factory.debug:
                # Debug mode: skip training and report fixed placeholder results.
                time.sleep(5)
                self.transport.write(q_protocol.construct_net_trained_message(
                    self.factory.clientname,
                    out['net_string'],
                    86.0,
                    100,
                    85.5,
                    10000,
                    float(out['epsilon']),
                    int(out['iteration_number'])))
                return

            hyper_parameters = self.factory.hyper_parameters
            checkpoint_dir = hyper_parameters.CHECKPOINT_DIR
            net_string = out['net_string']

            model_dir = get_model_dir(checkpoint_dir, net_string)
            trainer = ModelExec(model_dir, hyper_parameters, self.factory.state_space_parameters)
            train_out = trainer.run_one_model(net_string, gpu_to_use=self.factory.gpu_to_use)
            print('OUT', train_out)

            status = train_out['status']
            # If OUT OF MEMORY or FAIL, delete files
            if status in ('OUT_OF_MEMORY', 'FAIL'):
                rm_model_dir(checkpoint_dir, net_string)

            if status == 'OUT_OF_MEMORY':
                self.transport.write(q_protocol.construct_net_too_large_message(self.factory.clientname))
                return

            (iter_best, acc_best), (iter_last, acc_last) = self._best_and_last_accuracy(
                train_out, hyper_parameters)

            # Clear out model files
            clear_redundant_logs_caffe(checkpoint_dir,
                                       pd.DataFrame({'net': [net_string],
                                                     'iter_best_val': [iter_best],
                                                     'iter_last_val': [iter_last]}))

            self.transport.write(q_protocol.construct_net_trained_message(
                self.factory.clientname,
                net_string,
                acc_best,
                iter_best,
                acc_last,
                iter_last,
                float(out['epsilon']),
                int(out['iteration_number'])))

    @staticmethod
    def _best_and_last_accuracy(train_out, hyper_parameters):
        """Return ``((iter_best, acc_best), (iter_last, acc_last))`` for a run.

        ``train_out['test_accs']`` is read as a mapping from iteration to
        accuracy. The "best" pair is the one with the highest accuracy; the
        "last" pair is the one with the highest iteration number.

        If the run's status is ``'FAIL'``, or it recorded no test accuracies,
        both pairs fall back to ``(0, 1.0 / hyper_parameters.NUM_CLASSES)``.
        Previously an empty mapping made ``max`` raise ``ValueError``, so no
        result was ever reported.
        """
        test_accs = {} if train_out['status'] == 'FAIL' else train_out['test_accs']
        if not test_accs:
            # Presumably chance-level accuracy for NUM_CLASSES classes.
            fallback = (0, 1.0 / hyper_parameters.NUM_CLASSES)
            return fallback, fallback
        entries = list(test_accs.items())
        best = max(entries, key=lambda entry: entry[1])
        last = max(entries, key=lambda entry: entry[0])
        return best, last
Exemplo n.º 3
0
    def test_construct_new_net(self):
        """Round-trip a ``new_net`` message through q_protocol.

        Builds a message with known field values, parses it back and checks
        every field. ``epsilon`` and ``iteration_number`` are compared after
        ``float``/``int`` conversion of the parsed values.

        Returns:
            bool: True if all fields match. Otherwise False, after printing the
            raw and parsed message for debugging.
        """
        net_string = '[C(120,1,1), P(5,1), GAP(10), SM(10)]'
        msg = q_protocol.construct_new_net_message('luna', net_string, 0.7, 2000)
        out = q_protocol.parse_message(msg)

        test = (out['sender'] == 'luna'
                and out['type'] == 'new_net'
                and out['net_string'] == net_string
                and float(out['epsilon']) == 0.7
                and int(out['iteration_number']) == 2000)

        if not test:
            print(msg)
            print(out)

        return test
Exemplo n.º 4
0
    def test_construct_net_trained(self):
        """Round-trip a ``net_trained`` message through q_protocol.

        Builds a message with known accuracies, iterations, epsilon and
        iteration number, parses it back and checks every field. Numeric
        fields are compared after ``float``/``int`` conversion of the parsed
        values.

        Returns:
            bool: True if all fields match. Otherwise False, after printing the
            raw and parsed message for debugging.
        """
        net_string = '[C(120,1,1), P(5,1), GAP(10), SM(10)]'
        msg = q_protocol.construct_net_trained_message('luna', net_string, 0.1, 1000, 0.2, 2000, 0.7, 3000)
        out = q_protocol.parse_message(msg)

        test = (out['sender'] == 'luna'
                and out['type'] == 'net_trained'
                and out['net_string'] == net_string
                and float(out['acc_best_val']) == 0.1
                and int(out['iter_best_val']) == 1000
                and float(out['acc_last_val']) == 0.2
                and int(out['iter_last_val']) == 2000
                and float(out['epsilon']) == 0.7
                and int(out['iteration_number']) == 3000)

        if not test:
            print(msg)
            print(out)

        return test
Exemplo n.º 5
0
    def test_construct_login(self):
        """Round-trip a login message through q_protocol.

        Returns:
            bool: True when the parsed message keeps the sender ``'luna'``
            and has type ``'login'``.
        """
        parsed = q_protocol.parse_message(
            q_protocol.construct_login_message('luna'))
        sender_ok = parsed['sender'] == 'luna'
        return sender_ok and parsed['type'] == 'login'