class Classifier:
    """Classify one scaled glyph image with a pickled neural network.

    Calling an instance returns a list of ``(char, score)`` pairs. ``score``
    is the network's log-probability output for a class divided by
    ``logbase``. There is one pair per character that the label converter
    maps that class to.
    """

    def __init__(self, nnet_prms_file, labellings_file, logbase=2, only_top=5):
        """Load the network parameters and build a single-sample test model.

        Args:
            nnet_prms_file: Path to a pickled dict of NeuralNet keyword
                arguments. It must contain at least 'training_params' and
                'layers'.
            labellings_file: Passed unchanged to LabelToUnicodeConverter.
                That object maps a class index to an iterable of characters.
            logbase: Divisor applied to the network's log-probabilities.
            only_top: If truthy, return only this many best-scoring classes,
                best first. If 0 or None, return every class in index order.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # parameter files from trusted sources.
        with open(nnet_prms_file, 'rb') as nnet_prms_fp:
            nnet_prms = pickle.load(nnet_prms_fp)

        # __call__ feeds exactly one image per call (see its reshape), so the
        # compiled test model is built for a batch size of 1.
        nnet_prms['training_params']['BATCH_SZ'] = 1
        self.ntwk = NeuralNet(**nnet_prms)
        self.tester = self.ntwk.get_data_test_model()
        # Side length of the square input image expected by the first layer.
        self.ht = nnet_prms['layers'][0][1]['img_sz']
        self.logbase = logbase
        self.only_top = only_top
        self.unichars = LabelToUnicodeConverter(labellings_file)
        # Number of output classes of the last layer.
        self.nclasses = nnet_prms['layers'][-1][1]["n_out"]
        logi("Network {}".format(self.ntwk))
        logi("LogBase {}".format(self.logbase))
        logi("OnlyTop {}".format(self.only_top))

    def __call__(self, scaled_glp):
        """Score one glyph and return ``[(char, score), ...]``.

        Args:
            scaled_glp: Object whose ``pix`` holds ``ht * ht`` pixel values.
                If the network takes auxiliary input, it must also provide
                ``dtop`` and ``dbot``.

        Returns:
            With ``only_top`` set: the ``min(only_top, n_outputs)``
            best-scoring classes, ordered from highest to lowest score.
            Otherwise: all ``nclasses`` classes in index order. Each class
            contributes one pair per character in ``self.unichars[class]``.
        """
        img = scaled_glp.pix.astype('float32').reshape((1, 1, self.ht, self.ht))

        if self.ntwk.takes_aux():
            dtopbot = scaled_glp.dtop, scaled_glp.dbot
            aux_data = np.array([[dtopbot, dtopbot]], dtype='float32')
            logprobs, _preds = self.tester(img, aux_data)
        else:
            logprobs, _preds = self.tester(img)

        # NOTE(review): if the network emits natural-log probabilities, a
        # true change of base would divide by ln(logbase), not logbase.
        # Confirm the intended semantics before relying on these values as
        # base-`logbase` logs.
        logprobs = logprobs[0] / self.logbase

        if self.only_top:
            n_out = logprobs.size
            # Clamp k: argpartition raises ValueError when kth is out of range.
            k = min(self.only_top, n_out)
            if k < n_out:
                decent = np.argpartition(logprobs, -k)[-k:]
            else:
                decent = np.arange(n_out)
            # Always rank the (few) survivors best-first, so the output order
            # does not depend on logging configuration. A stable sort keeps
            # ties deterministic.
            decent = decent[np.argsort(-logprobs[decent], kind='stable')]
        else:
            decent = np.arange(self.nclasses)

        return [(ch, logprobs[i]) for i in decent for ch in self.unichars[i]]
with open(nnet_prms_file_name, 'rb') as nnet_prms_file: nnet_prms = pickle.load(nnet_prms_file) with open(labelings_file_name, encoding='utf-8') as labels_fp: labellings = ast.literal_eval(labels_fp.read()) # print(labellings) chars = LabelToUnicodeConverter(labellings).onecode ############################################# Init Network Bantry.scaler = ScalerFactory(scaler_prms) bf = BantryFile(banti_file_name) nnet_prms['training_params']['BATCH_SZ'] = 1 ntwk = NeuralNet(**nnet_prms) tester = ntwk.get_data_test_model(go_nuts=True) ############################################# Image saver dir_name = os.path.basename(nnet_prms_file_name)[:-7] + '/' if not os.path.exists(dir_name): os.makedirs(dir_name) namer = (dir_name + '{:03d}_{}_{:02d}.png').format print("Look for me in :", dir_name) def saver(outs, ch, debug=True): saver.index += 1 for i, out in enumerate(outs): global_normalize = False if out.ndim == 2: n_nodes = out.shape[1] w = n_nodes // int(np.sqrt(n_nodes))