def store_hash_values(self, hash_keys: list, number: int, hbase_manager: HBaseManager, index: int):
    """Insert one HBase row per hash key, each storing ``number``.

    :param hash_keys: row keys to write; one ``HBaseRow`` is built per key.
    :param number: value stored under ``self.COLUMN_NAME`` in every row.
    :param hbase_manager: manager whose ``batch_insert`` performs the write
        against ``self.TABLE_NAME``.
    :param index: position of the observation within its batch; not used by
        this method.
    :return: always ``True``.
    """
    batch_insert_rows = [
        HBaseRow(
            row_key=hash_key,
            row_values={self.COLUMN_NAME: number},
            family_name=HBaseManager.FAMILY_NAME,
        )
        for hash_key in hash_keys
    ]
    # The value returned by batch_insert is not inspected; the method
    # reports True unconditionally.
    hbase_manager.batch_insert(self.TABLE_NAME, batch_insert_rows)
    return True
def predict_from_image_batch(self, mnist_batch, index):
    """Predict a value for every image in ``mnist_batch``.

    Hash keys are extracted for all images in a ``Pool`` sized by
    ``self.POOL_SIZE``; each image's keys are then passed, in order, to
    ``self.predict_hash_values`` together with a shared ``HBaseManager``.

    :param mnist_batch: sized sequence of images; each image is passed to
        ``self.extract_keys`` along with its position in the batch.
    :param index: batch identifier, used only in the timing message.
    :return: list of predictions, one per image, in batch order.
    """
    t0 = time.time()
    connection_pool = ConnectionPool(size=self.CONNECTION_POOL_SIZE,
                                     host=HBaseManager.HOST,
                                     port=HBaseManager.PORT)
    hbase_manager = HBaseManager(connection_pool)

    process_pool = Pool(self.POOL_SIZE)
    try:
        # starmap blocks until every (image, position) pair is processed.
        extracted_keys = process_pool.starmap(
            self.extract_keys, zip(mnist_batch, range(len(mnist_batch))))
    finally:
        # Release the workers even when extraction raises; join() must be
        # preceded by close().
        process_pool.close()
        process_pool.join()

    predictions = [
        self.predict_hash_values(keys, hbase_manager, i)
        for i, keys in enumerate(extracted_keys)
    ]

    t1 = time.time()
    print("Mnist Batch {} predicted in: {} Seconds, For Node: {}".format(
        str(index), str(t1 - t0), str(self)))
    return predictions
def train_batch(self, mnist_batch, index):
    """Extract hash keys for a batch and store each number under its keys.

    The numbers come from ``MnistHelper.extract_numbers_images``; the images
    are read from each observation at ``MnistModel.PREDICTOR_INDEX``. Keys
    are extracted in a ``Pool`` sized by ``self.POOL_SIZE`` and written with
    ``self.store_hash_values``.

    :param mnist_batch: batch of observations
    :type mnist_batch: list of tuple
    :param index: batch identifier, used only in the timing message
    :rtype: None
    """
    t0 = time.time()
    connection_pool = ConnectionPool(size=self.CONNECTION_POOL_SIZE,
                                     host=HBaseManager.HOST,
                                     port=HBaseManager.PORT)
    hbase_manager = HBaseManager(connection_pool)

    # Only the numbers are taken from the helper; the images are indexed
    # directly out of each observation below.
    numbers, _ = MnistHelper.extract_numbers_images(mnist_batch)
    mnist_images = [mnist_obs[MnistModel.PREDICTOR_INDEX] for mnist_obs in mnist_batch]

    process_pool = Pool(self.POOL_SIZE)
    try:
        # starmap blocks until every (image, position) pair is processed.
        extracted_keys = process_pool.starmap(
            self.extract_keys, zip(mnist_images, range(len(mnist_images))))
    finally:
        # Release the workers even when extraction raises; join() must be
        # preceded by close().
        process_pool.close()
        process_pool.join()

    for i, (keys, number) in enumerate(zip(extracted_keys, numbers)):
        self.store_hash_values(keys, number, hbase_manager, i)

    t1 = time.time()
    print("Time taken to train batch {} : {} Seconds".format(str(index), str(t1 - t0)))
def predict_hash_values(self, hash_keys: list, hbase_manager: HBaseManager, index):
    """Predict by majority vote over the values stored for ``hash_keys``.

    :param hash_keys: row keys to look up in ``self.TABLE_NAME``.
    :param hbase_manager: manager whose ``batch_get_rows`` fetches the rows.
    :param index: position of the observation within its batch; not used by
        this method.
    :return: the most frequent ``self.COLUMN_NAME`` value among the fetched
        rows, or a random member of ``self.ALL_DIGITS`` when there are no
        keys or no rows come back.
    """
    if len(hash_keys) == 0:
        print("no good hash keys")
        return random.choice(self.ALL_DIGITS)

    hash_rows = hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys)
    hashed_predictions = [hash_row.row_values[self.COLUMN_NAME] for hash_row in hash_rows]
    if not hashed_predictions:
        print("no collision predictions")
        return random.choice(self.ALL_DIGITS)

    # max() returns the first maximal item, so a tie goes to the value that
    # appears earliest in hashed_predictions.
    return max(hashed_predictions, key=hashed_predictions.count)
def predict_hash_values(self, hash_keys: list, hbase_manager: HBaseManager, index):
    """Pick the most common stored value for the given hash keys.

    Rows are fetched from ``self.table_name`` via
    ``hbase_manager.batch_get_rows`` and the ``self.COLUMN_NAME`` value of
    each row counts as one vote.

    :param hash_keys: row keys to look up.
    :param hbase_manager: manager used to fetch the rows.
    :param index: position of the observation within its batch; not used by
        this method.
    :return: the value with the most votes (the earliest one in the vote
        list on a tie), or a random member of ``self.ALL_DIGITS`` when there
        are no keys or no votes.
    """
    if len(hash_keys) == 0:
        print("no good hash keys")
        return random.choice(self.ALL_DIGITS)

    votes = []
    for row in hbase_manager.batch_get_rows(self.table_name, hash_keys):
        votes.append(row.row_values[self.COLUMN_NAME])

    if not votes:
        print("no collision predictions")
        return random.choice(self.ALL_DIGITS)

    winner = max(votes, key=votes.count)
    return winner
def predict_probs(self, hash_keys: list, hbase_manager: HBaseManager, index):
    """Return the share of stored values per digit for the given hash keys.

    Rows are fetched from ``self.TABLE_NAME`` and the ``self.COLUMN_NAME``
    value of each row counts as one vote.

    :param hash_keys: row keys to look up.
    :param hbase_manager: manager whose ``batch_get_rows`` fetches the rows.
    :param index: position of the observation within its batch; not used by
        this method.
    :return: dict mapping each observed value to ``count / total votes``,
        with every member of ``self.ALL_DIGITS`` that received no vote added
        at ``0.0``; ``self.get_default_probs()`` when there are no keys or
        no votes.
    """
    if len(hash_keys) == 0:
        print("no good hash keys")
        return self.get_default_probs()

    votes = []
    for row in hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys):
        votes.append(row.row_values[self.COLUMN_NAME])

    if not votes:
        print("no collision predictions")
        return self.get_default_probs()

    total = len(votes)
    probs = {}
    for value, count in Counter(votes).items():
        probs[value] = count / total

    # Digits that never showed up still get an explicit zero entry.
    for digit in self.ALL_DIGITS:
        probs.setdefault(digit, 0.0)
    return probs
def get_hbase_hash_values(self, hash_keys: list, hbase_manager: HBaseManager, index):
    """Fetch the values stored for ``hash_keys`` in ``self.TABLE_NAME``.

    :param hash_keys: row keys to look up.
    :param hbase_manager: manager whose ``batch_get_rows`` fetches the rows.
    :param index: position of the observation within its batch; not used by
        this method.
    :return: list holding the ``self.COLUMN_NAME`` value of every row that
        ``batch_get_rows`` returned, in the order returned.
    """
    hash_rows = hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys)
    return [hash_row.row_values[self.COLUMN_NAME] for hash_row in hash_rows]
def setup(self):
    """Create ``self.TABLE_NAME`` through a single-connection ``HBaseManager``.

    ``delete=True`` is forwarded to ``create_table``; presumably this drops
    an existing table of the same name first -- confirm against
    ``HBaseManager.create_table``.
    """
    single_connection_pool = ConnectionPool(
        size=1,
        host=HBaseManager.HOST,
        port=HBaseManager.PORT,
    )
    manager = HBaseManager(single_connection_pool)
    manager.create_table(table_name=self.TABLE_NAME, delete=True)