예제 #1
0
    def dataReceived(self, data):
        """Parse an incoming message and act on its ``type`` field.

        * ``'redundant_connection'``: print a notice.
        * ``'new_net'``: train the network described by ``net_string`` (or, when
          ``self.factory.debug`` is set, wait five seconds and answer with fixed
          placeholder metrics) and write the resulting message to
          ``self.transport``.

        Any other message type is ignored. All work, including the training
        run, happens synchronously inside this call.

        NOTE(review): the method name and ``self.transport`` suggest a Twisted
        ``Protocol`` -- confirm against the enclosing class.
        """
        out = q_protocol.parse_message(data)
        msg_type = out['type']

        if msg_type == 'redundant_connection':
            print('Redundancy in connect name')

        # Only 'new_net' requests need further handling.
        if msg_type != 'new_net':
            return

        net_string = out['net_string']
        print('Ready to train ' + net_string)

        factory = self.factory

        if factory.debug:
            # Debug mode: skip training and reply with canned numbers after a
            # short delay.
            time.sleep(5)
            reply = q_protocol.construct_net_trained_message(
                factory.clientname,
                net_string,
                86.0,
                100,
                85.5,
                10000,
                float(out['epsilon']),
                int(out['iteration_number']))
            self.transport.write(reply)
            return

        hyper = factory.hyper_parameters
        model_dir = get_model_dir(hyper.CHECKPOINT_DIR, net_string)
        trainer = ModelExec(model_dir, hyper, factory.state_space_parameters)

        train_out = trainer.run_one_model(net_string, gpu_to_use=factory.gpu_to_use)
        print('OUT', train_out)

        status = train_out['status']

        # An out-of-memory or failed run leaves nothing worth keeping on disk.
        if status in ('OUT_OF_MEMORY', 'FAIL'):
            rm_model_dir(hyper.CHECKPOINT_DIR, net_string)

        if status == 'OUT_OF_MEMORY':
            self.transport.write(q_protocol.construct_net_too_large_message(factory.clientname))
            return

        if status == 'FAIL':
            # No usable results: report 1 / NUM_CLASSES at iteration 0 for both
            # the best and the last entry (presumably chance-level accuracy).
            fallback_acc = 1.0 / hyper.NUM_CLASSES
            iter_best, acc_best = 0, fallback_acc
            iter_last, acc_last = 0, fallback_acc
        else:
            acc_items = list(train_out['test_accs'].items())
            # Best: the entry with the highest value; last: the entry with the
            # largest key.
            iter_best, acc_best = max(acc_items, key=lambda pair: pair[1])
            iter_last, acc_last = max(acc_items, key=lambda pair: pair[0])

        # Clear out model files
        kept_iterations = pd.DataFrame({'net': [net_string],
                                        'iter_best_val': [iter_best],
                                        'iter_last_val': [iter_last]})
        clear_redundant_logs_caffe(hyper.CHECKPOINT_DIR, kept_iterations)

        reply = q_protocol.construct_net_trained_message(
            factory.clientname,
            net_string,
            acc_best,
            iter_best,
            acc_last,
            iter_last,
            float(out['epsilon']),
            int(out['iteration_number']))
        self.transport.write(reply)
예제 #2
0
    def test_construct_net_trained(self):
        """Round-trip a net-trained message through ``q_protocol`` and verify it.

        Builds a message with ``construct_net_trained_message``, parses it back
        with ``parse_message`` and compares every parsed field with the value
        the message was built from.

        Returns:
            True when all fields match, False otherwise. On a mismatch the raw
            message and the parsed result are printed.
        """
        net = '[C(120,1,1), P(5,1), GAP(10), SM(10)]'

        msg = q_protocol.construct_net_trained_message('luna', net, 0.1, 1000, 0.2, 2000, 0.7, 3000)
        out = q_protocol.parse_message(msg)

        def unchanged(value):
            return value

        # (parsed field, conversion applied before comparing, expected value);
        # checked in this order, stopping at the first mismatch.
        expectations = (
            ('sender', unchanged, 'luna'),
            ('type', unchanged, 'net_trained'),
            ('net_string', unchanged, net),
            ('acc_best_val', float, 0.1),
            ('iter_best_val', int, 1000),
            ('acc_last_val', float, 0.2),
            ('iter_last_val', int, 2000),
            ('epsilon', float, 0.7),
            ('iteration_number', int, 3000),
        )

        test = all(convert(out[field]) == expected
                   for field, convert, expected in expectations)

        if not test:
            print(msg)
            print(out)

        return test