Пример #1
0
 def store_hash_values(self, hash_keys: list, number: int, hbase_manager: HBaseManager, index: int):
     """Write ``number`` under every key in ``hash_keys`` with one batched insert.

     :param hash_keys: row keys to insert; one ``HBaseRow`` is built per key.
     :param number: value stored in column ``self.COLUMN_NAME`` of each row.
     :param hbase_manager: manager whose ``batch_insert`` performs the write.
     :param index: not used by this method.
     :return: always ``True``; the result of ``batch_insert`` is not inspected.
     """
     batch_insert_rows = [
         HBaseRow(row_key=hash_key, row_values={self.COLUMN_NAME: number}, family_name=HBaseManager.FAMILY_NAME)
         for hash_key in hash_keys
     ]
     # The insert result is intentionally not bound: nothing reads it and the
     # method reports success unconditionally.
     hbase_manager.batch_insert(self.TABLE_NAME, batch_insert_rows)
     return True
Пример #2
0
    def predict_from_image_batch(self, mnist_batch, index):
        """Produce one prediction per image in ``mnist_batch``.

        Hash keys are extracted for every image in a worker pool through
        ``self.extract_keys``; each key set is then resolved sequentially with
        ``self.predict_hash_values``.

        :param mnist_batch: sequence of images; each image is handed to
            ``self.extract_keys`` together with its position in the batch.
        :param index: batch identifier, only used in the timing message.
        :return: list of predictions, in batch order.
        """
        t0 = time.time()
        connection_pool = ConnectionPool(size=self.CONNECTION_POOL_SIZE,
                                         host=HBaseManager.HOST,
                                         port=HBaseManager.PORT)
        hbase_manager = HBaseManager(connection_pool)

        positions = list(range(len(mnist_batch)))

        process_pool = Pool(self.POOL_SIZE)
        try:
            extract_process = process_pool.starmap_async(
                self.extract_keys, zip(mnist_batch, positions))
            extracted_keys = extract_process.get()
        finally:
            # Release the workers even when extraction raises; the pool is
            # only needed for key extraction, so it is shut down right away.
            process_pool.close()
            process_pool.join()

        predictions = [
            self.predict_hash_values(keys, hbase_manager, position)
            for keys, position in zip(extracted_keys, positions)
        ]

        t1 = time.time()
        print("Mnist Batch {} predicted in: {} Seconds, For Node: {}".format(
            str(index), str(t1 - t0), str(self)))

        return predictions
Пример #3
0
    def train_batch(self, mnist_batch, index):
        """Store the hash keys of every observation in ``mnist_batch``.

        Hash keys are extracted for every image in a worker pool through
        ``self.extract_keys`` and then written, one observation at a time,
        with ``self.store_hash_values``.

        :param mnist_batch: batch of observations; the numbers come from
            ``MnistHelper.extract_numbers_images`` and each image is read from
            position ``MnistModel.PREDICTOR_INDEX`` of its observation.
        :param index: batch identifier, only used in the timing message.
        :rtype: None
        """
        t0 = time.time()

        connection_pool = ConnectionPool(size=self.CONNECTION_POOL_SIZE, host=HBaseManager.HOST, port=HBaseManager.PORT)
        hbase_manager = HBaseManager(connection_pool)

        # Only the numbers returned by the helper are used; the images are
        # taken from PREDICTOR_INDEX, exactly as this method did before.
        numbers, _ = MnistHelper.extract_numbers_images(mnist_batch)
        mnist_images = [mnist_obs[MnistModel.PREDICTOR_INDEX] for mnist_obs in mnist_batch]
        positions = list(range(len(mnist_batch)))

        process_pool = Pool(self.POOL_SIZE)
        try:
            extract_process = process_pool.starmap_async(self.extract_keys, zip(mnist_images, positions))
            extracted_keys = extract_process.get()
        finally:
            # Release the workers even when extraction raises.
            process_pool.close()
            process_pool.join()

        # Plain loop: the calls are made for their side effect only.
        for keys, number, position in zip(extracted_keys, numbers, positions):
            self.store_hash_values(keys, number, hbase_manager, position)

        t1 = time.time()
        print("Time taken to train batch {} : {} Seconds".format(str(index),str(t1 - t0)))
Пример #4
0
    def predict_hash_values(self, hash_keys: list, hbase_manager: HBaseManager, index):
        """Return the value most frequently stored under ``hash_keys``.

        :param hash_keys: row keys to look up in ``self.TABLE_NAME``.
        :param hbase_manager: manager whose ``batch_get_rows`` does the lookup.
        :param index: not used by this method.
        :return: the most common value found in column ``self.COLUMN_NAME`` of
            the fetched rows, or a random member of ``self.ALL_DIGITS`` when
            there are no keys or no rows come back.
        """
        from collections import Counter

        if len(hash_keys) == 0:
            print("no good hash keys")
            return random.choice(self.ALL_DIGITS)

        hash_rows = hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys)
        hashed_predictions = [hash_row.row_values[self.COLUMN_NAME] for hash_row in hash_rows]

        if len(hashed_predictions) == 0:
            print("no collision predictions")
            return random.choice(self.ALL_DIGITS)

        # Count once instead of calling list.count per element. Counter keeps
        # first-seen order and max() returns the first maximal key, so ties
        # still go to the value that appears earliest in the list.
        prediction_counts = Counter(hashed_predictions)
        return max(prediction_counts, key=prediction_counts.get)
Пример #5
0
    def predict_hash_values(self, hash_keys: list, hbase_manager: HBaseManager,
                            index):
        """Pick the value that occurs most often among the rows for ``hash_keys``.

        A random member of ``self.ALL_DIGITS`` is returned instead when there
        are no keys to look up or when the lookup yields no rows. ``index`` is
        not used by this method.
        """
        if len(hash_keys) == 0:
            print("no good hash keys")
            return random.choice(self.ALL_DIGITS)

        votes = []
        for row in hbase_manager.batch_get_rows(self.table_name, hash_keys):
            votes.append(row.row_values[self.COLUMN_NAME])

        if not votes:
            print("no collision predictions")
            return random.choice(self.ALL_DIGITS)

        # max() keeps the first maximal element, so ties go to the value that
        # appears earliest in ``votes``.
        return max(votes, key=votes.count)
Пример #6
0
    def predict_probs(self, hash_keys: list, hbase_manager: HBaseManager, index):
        """Return the relative frequency of each value stored under ``hash_keys``.

        The result maps every value found in column ``self.COLUMN_NAME`` of the
        fetched rows to its share of those rows; members of ``self.ALL_DIGITS``
        that were not found are added with ``0.0``. ``self.get_default_probs()``
        is returned when there are no keys or no rows. ``index`` is not used.
        """
        if len(hash_keys) == 0:
            print("no good hash keys")
            return self.get_default_probs()

        votes = []
        for row in hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys):
            votes.append(row.row_values[self.COLUMN_NAME])

        if not votes:
            print("no collision predictions")
            return self.get_default_probs()

        total = len(votes)
        prediction_probs = {}
        for value, occurrences in Counter(votes).items():
            prediction_probs[value] = occurrences / total

        # Digits that never showed up still get an explicit zero entry.
        for digit in self.ALL_DIGITS:
            prediction_probs.setdefault(digit, 0.0)

        return prediction_probs
Пример #7
0
 def get_hbase_hash_values(self, hash_keys: list, hbase_manager: HBaseManager, index):
     """Fetch the values stored under ``hash_keys``.

     :param hash_keys: row keys to look up in ``self.TABLE_NAME``.
     :param hbase_manager: manager whose ``batch_get_rows`` does the lookup.
     :param index: not used by this method.
     :return: list with the ``self.COLUMN_NAME`` value of every fetched row.
     """
     hash_rows = hbase_manager.batch_get_rows(self.TABLE_NAME, hash_keys)
     # The collected values used to be discarded, so the getter returned None.
     return [hash_row.row_values[self.COLUMN_NAME] for hash_row in hash_rows]
Пример #8
0
 def setup(self):
     """Create the table ``self.TABLE_NAME``, passing ``delete=True`` to ``create_table``.

     A dedicated single-connection pool is opened for this one call.
     """
     single_connection = ConnectionPool(size=1, host=HBaseManager.HOST, port=HBaseManager.PORT)
     manager = HBaseManager(single_connection)
     manager.create_table(table_name=self.TABLE_NAME, delete=True)