import json

from brainsquared.subscribers.PikaSubscriber import PikaSubscriber

host = "rabbitmq.cloudbrain.rocks"
username = "******"
pwd = "cloudbrain"
user = "******"
device = "openbci"
metric = "classification"

# Routing key convention: <user>:<device>:<metric>
routing_key = "%s:%s:%s" % (user, device, metric)

sub = PikaSubscriber(host, username, pwd)
sub.connect()
sub.subscribe(routing_key)


def _print_message(ch, method, properties, body):
  """Pika-style consume callback: print the raw message body."""
  print(body)


# Poll the queue one message at a time until the user interrupts.
while 1:
  try:
    # get_one_message returns (method_frame, header_frame, body).
    (meth_frame, header_frame, body) = sub.get_one_message(routing_key)
    if body:
      print(json.loads(body))
  except KeyboardInterrupt:
    sub.disconnect()
    # BUG FIX: without this break the loop re-entered get_one_message on a
    # closed connection after Ctrl-C instead of terminating.
    break
class HTMClassifier(object):
  """HTM-based motor imagery classifier wired to RabbitMQ.

  Subscribes to a metric stream and a label (tag) stream, runs each incoming
  sample through an HTM network, and publishes the inferred category.
  """

  def __init__(self, user_id, device_type, rmq_address, rmq_user, rmq_pwd,
               input_metrics, output_metrics):
    """
    Motor Imagery Module.

    Metrics conventions:
    - Data to classify: {"timestamp": <int>, "channel_0": <float>}
    - Data label: {"timestamp": <int>, "channel_0": <int>}
    - Classification result: {"timestamp": <int>, "channel_0": <int>}

    @param user_id: (string) ID of the user using the device.
    @param device_type: (string) type of the device publishing to this module.
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for RabbitMQ connection.
    @param rmq_pwd: (string) password for RabbitMQ connection.
    @param input_metrics: (dict) input metric names, keyed by role
        ("metric_to_classify", "label_metric") -- see _validate_metrics().
    @param output_metrics: (dict) output metric names, keyed by role
        ("result_metric").
    """
    self.module_id = HTMClassifier.__name__
    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd
    self.input_metrics = input_metrics
    self.output_metrics = output_metrics

    # Concrete metric names, resolved by _validate_metrics().
    self.input_metric = None
    self.output_metric = None
    self.tag_metric = None

    # Most recent label received on the tag stream.
    self.last_tag = {"timestamp": None, "channel_0": None}

    # RabbitMQ endpoints, created by connect().
    self.output_metric_publisher = None
    self.input_metric_subscriber = None
    self.tag_subscriber = None
    # FIX: declared here for consistency with the other endpoints --
    # previously this attribute was created implicitly inside connect().
    self.classification_publisher = None
    self.routing_keys = None

    # Set when configure() is called.
    self.categories = None
    self.network_config = None
    self.trained_network_path = None
    self.minval = None
    self.maxval = None

    # Module specific
    self._network = None

  def _validate_metrics(self):
    """
    Validate input and output metrics and initialize them accordingly.

    This module must have the following signature for input and output
    metrics:
      input_metrics = {"metric_to_classify": <string>,
                       "label_metric": <string>}
      output_metrics = {"result_metric": <string>}

    @raise KeyError: if any of the required metric roles is missing.
    """
    if "label_metric" in self.input_metrics:
      self.tag_metric = self.input_metrics["label_metric"]
    else:
      raise KeyError("The input metric 'label_metric' is not set!")

    if "metric_to_classify" in self.input_metrics:
      self.input_metric = self.input_metrics["metric_to_classify"]
    else:
      raise KeyError("The input metric 'metric_to_classify' is not set!")

    if "result_metric" in self.output_metrics:
      self.output_metric = self.output_metrics["result_metric"]
    else:
      raise KeyError("The output metric 'result_metric' is not set!")

  def configure(self, categories, network_config, trained_network_path,
                minval, maxval):
    """Configure the module.

    @param categories: (list) classification categories; the first one seeds
        the initial tag.
    @param network_config: (dict) HTM network configuration.
    @param trained_network_path: (string) path used to save/load the network.
    @param minval: (number) lower bound for the scalar encoder.
    @param maxval: (number) upper bound for the scalar encoder.
    """
    self._validate_metrics()
    self.categories = categories
    self.network_config = network_config
    self.trained_network_path = trained_network_path
    self.minval = minval
    self.maxval = maxval

    encoder_config = self.network_config["sensorRegionConfig"][
      "encoders"]["scalarEncoder"]
    encoder_config["minval"] = minval
    # BUG FIX: maxval was stored on self but never written into the encoder
    # config, leaving the scalar encoder's upper bound stale/unset.
    encoder_config["maxval"] = maxval

    # Init tag with first category
    self.last_tag["channel_0"] = self.categories[0]
    self.last_tag["timestamp"] = int(time.time() * 1000)

  def connect(self):
    """Initialize publisher and subscribers."""
    self.routing_keys = {
      self.input_metric: _ROUTING_KEY % (self.user_id,
                                         self.device_type,
                                         self.input_metric),
      self.output_metric: _ROUTING_KEY % (self.user_id,
                                          self.device_type,
                                          self.output_metric),
      self.tag_metric: _ROUTING_KEY % (self.user_id,
                                       self.device_type,
                                       self.tag_metric),
    }

    self.tag_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)
    self.classification_publisher = PikaPublisher(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.input_metric_subscriber = PikaSubscriber(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.output_metric_publisher = PikaPublisher(self.rmq_address,
                                                 self.rmq_user, self.rmq_pwd)

    self.tag_subscriber.connect()
    self.classification_publisher.connect()
    self.input_metric_subscriber.connect()
    self.output_metric_publisher.connect()

    self.tag_subscriber.register(self.routing_keys[self.tag_metric])
    # NOTE(review): this registers the TAG routing key on the classification
    # publisher -- looks like a copy-paste slip (expected output_metric?).
    # Left unchanged pending confirmation against the consumers.
    self.classification_publisher.register(self.routing_keys[self.tag_metric])
    self.input_metric_subscriber.register(self.routing_keys[self.input_metric])
    self.output_metric_publisher.register(
      self.routing_keys[self.output_metric])

  def train(self, training_file, num_records):
    """Create a network and train it on a CSV data source.

    @param training_file: (string) path to the CSV training data.
    @param num_records: (int) number of records to feed the network.
    """
    dataSource = FileRecordStream(streamID=training_file)
    dataSource.setAutoRewind(True)
    self._network = configureNetwork(dataSource, self.network_config)
    for _ in range(num_records):  # Equivalent to: network.run(num_records)
      self._network.run(1)
    self._network.save(self.trained_network_path)

  def start(self):
    """Get data from RabbitMQ and classify input data."""
    if self._network is None:
      # No in-memory network: load the previously trained one from disk.
      self._network = Network(self.trained_network_path)

    # Inference only: disable learning in every region.
    regionNames = self._get_all_regions_names()
    setNetworkLearningMode(self._network, regionNames, False)

    _LOGGER.info("[Module %s] Starting Motor Imagery module. "
                 "Routing keys: %s" % (self.module_id, self.routing_keys))

    self.input_metric_subscriber.subscribe(
      self.routing_keys[self.input_metric],
      self._tag_and_classify)

  def _get_all_regions_names(self):
    """Return the regionName of every region in the network config."""
    region_names = []
    for region_config_key, region_config in self.network_config.items():
      region_names.append(region_config["regionName"])
    return region_names

  def _tag_and_classify(self, ch, method, properties, body):
    """Tag data and run it through the classifier, publishing the result."""
    self._update_last_tag()

    input_data = simplejson.loads(body)
    timestamp = input_data["timestamp"]
    if self.maxval is not None and self.minval is not None:
      # Clamp to the encoder's configured range before classification.
      value = np.clip(input_data["channel_0"], self.minval, self.maxval)
    else:
      value = input_data["channel_0"]

    classificationResults = classifyNextRecord(self._network,
                                               self.network_config,
                                               timestamp,
                                               value,
                                               self.last_tag["channel_0"])
    inferredCategory = classificationResults["bestInference"]
    _LOGGER.debug("Raw results: %s" % classificationResults)

    # Renamed from `buffer` to avoid shadowing the builtin.
    output_buffer = [{"timestamp": timestamp, "channel_0": inferredCategory}]
    self.output_metric_publisher.publish(
      self.routing_keys[self.output_metric], output_buffer)

  def _update_last_tag(self):
    """
    Consume all tags in the queue and keep the last one (i.e. the most up
    to date).

    A tag is a dict with the following format:
      tag = {"timestamp": <int>, "channel_0": <float>}
    """
    while 1:
      (meth_frame, header_frame, body) = self.tag_subscriber.get_one_message(
        self.routing_keys[self.tag_metric])
      if body:
        self.last_tag = simplejson.loads(body)
      else:
        # Queue drained; self.last_tag now holds the freshest tag.
        _LOGGER.info("Last tag: {}".format(self.last_tag))
        return
class HTMMotorImageryModule(object):
  """Motor imagery module running one HTM classifier per hemisphere.

  Consumes mu-rhythm values and tags from RabbitMQ, classifies left/right
  hemisphere data, reconciles the two results, and publishes the outcome.
  """

  def __init__(self, user_id, module_id, device_type,
               rmq_address, rmq_user, rmq_pwd):
    """
    @param user_id: (string) ID of the user using the device.
    @param module_id: (string) ID of this module instance.
    @param device_type: (string) type of the device publishing to this module.
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for RabbitMQ connection.
    @param rmq_pwd: (string) password for RabbitMQ connection.
    """
    # Running min/max of mu values seen per hemisphere (debug statistics).
    self.stats = {
      "left": {"min": None, "max": None},
      "right": {"min": None, "max": None}
    }

    self.module_id = module_id
    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd

    self.classification_publisher = None
    self.mu_subscriber = None
    self.tag_subscriber = None
    self.tag_publisher = None

    self.routing_keys = {
      "mu": _ROUTING_KEY % (user_id, device_type, _MU),
      "tag": _ROUTING_KEY % (user_id, module_id, _TAG),
      "classification": _ROUTING_KEY % (user_id, module_id, _CLASSIFICATION)
    }

    self.start_time = int(time.time() * 1000)  # in ms
    self.last_tag = {"timestamp": self.start_time, "value": _CATEGORIES[1]}
    self.classifiers = {"left": None, "right": None}
    self.numRecords = 0
    self.learning_mode = True

  def initialize(self):
    """
    Initialize classifiers, publisher (classification), and subscribers
    (mu and tag).
    """
    # NOTE(review): `network_config`, `_TRAINING_DATA`, `_CATEGORIES`,
    # `_PRE_TRAIN`, `_TRAIN_SET_SIZE` and `partitions` are module globals not
    # visible in this chunk; in particular `partitions` looks undefined here
    # and this HTMClassifier constructor signature differs from the class
    # defined above -- confirm against the module-level definitions.
    self.classifiers["left"] = HTMClassifier(network_config,
                                             _TRAINING_DATA,
                                             _CATEGORIES)
    self.classifiers["right"] = HTMClassifier(network_config,
                                              _TRAINING_DATA,
                                              _CATEGORIES)
    for classifier in self.classifiers.values():
      classifier.initialize()
      if _PRE_TRAIN:
        classifier.train(_TRAIN_SET_SIZE, partitions)

    self.tag_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)
    self.tag_publisher = PikaPublisher(self.rmq_address,
                                       self.rmq_user, self.rmq_pwd)
    self.mu_subscriber = PikaSubscriber(self.rmq_address,
                                        self.rmq_user, self.rmq_pwd)
    self.classification_publisher = PikaPublisher(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)

    self.tag_subscriber.connect()
    self.tag_publisher.connect()
    self.mu_subscriber.connect()
    self.classification_publisher.connect()

    self.tag_subscriber.subscribe(self.routing_keys["tag"])
    self.tag_publisher.register(self.routing_keys["tag"])
    self.mu_subscriber.subscribe(self.routing_keys["mu"])
    self.classification_publisher.register(self.routing_keys["classification"])

  def start(self):
    """Begin consuming mu data; each message is tagged and classified."""
    _LOGGER.info("[Module %s] Starting Motor Imagery module. "
                 "Routing keys: %s" % (self.module_id, self.routing_keys))
    self.mu_subscriber.consume_messages(self.routing_keys["mu"],
                                        self._tag_and_classify)

  def _update_last_tag(self, last_tag):
    """Consume all tags in the queue and keep the last one (i.e. the most
    up to date).

    @param last_tag: (dict) fallback tag if the queue is empty.
    @return: (dict) the most recent tag.
    """
    while 1:
      (meth_frame, header_frame, body) = self.tag_subscriber.get_one_message(
        self.routing_keys["tag"])
      if body:
        last_tag = json.loads(body)
      else:
        # Queue drained: return the freshest tag seen.
        return last_tag

  def _tag_and_classify(self, ch, method, properties, body):
    """Tag data and run it through both hemisphere classifiers."""
    self.numRecords += 1
    print(self.numRecords)
    if self.numRecords > 1000:
      # Freeze learning after the first 1000 records.
      self.learning_mode = False
      print("=======LEARNING DISABLED!!!=========")

    self.last_tag = self._update_last_tag(self.last_tag)
    _LOGGER.debug("[Module %s] mu: %s | last_tag: %s"
                  % (self.module_id, body, self.last_tag))

    mu = json.loads(body)
    mu_timestamp = mu["timestamp"]
    tag_timestamp = self.last_tag["timestamp"]

    results = {}
    for (hemisphere, classifier) in self.classifiers.items():
      mu_value = mu[hemisphere]
      tag_value = self.last_tag["value"]
      mu_clipped = np.clip(mu_value, _MU_MIN, _MU_MAX)
      results[hemisphere] = classifier.classify(
        input_data=mu_clipped,
        target=tag_value,
        learning_is_on=self.learning_mode)
      self._update_stats(hemisphere, mu_value)
      #_LOGGER.debug(self.stats)

    _LOGGER.debug("Raw results: %s" % results)
    classification_result = _reconcile_results(results['left'],
                                               results['right'])
    buffer = [{"timestamp": mu_timestamp, "value": classification_result}]
    self.classification_publisher.publish(self.routing_keys["classification"],
                                          buffer)

  def _update_stats(self, hemisphere, mu_value):
    """Update the running min/max stats for one hemisphere.

    self.stats layout:
      {"left": {"min": ..., "max": ...}, "right": {"min": ..., "max": ...}}
    """
    # BUG FIX: the previous truthiness checks (`if not min_val`) treated a
    # legitimate mu value of 0/0.0 as "unset", and the follow-up </> tests
    # compared against None when the stats were uninitialized. Explicit
    # `is None` guards fix both problems.
    if (self.stats[hemisphere]["min"] is None
        or mu_value < self.stats[hemisphere]["min"]):
      self.stats[hemisphere]["min"] = mu_value
    if (self.stats[hemisphere]["max"] is None
        or mu_value > self.stats[hemisphere]["max"]):
      self.stats[hemisphere]["max"] = mu_value
if __name__ == "__main__":
  # host = "rabbitmq.cloudbrain.rocks"
  # username = "******"
  # pwd = "cloudbrain"

  host = "localhost"
  username = "******"
  pwd = "guest"

  user = "******"
  device = "neurosky"
  metric = "attention"

  # Routing key convention: <user>:<device>:<metric>
  routing_key = "%s:%s:%s" % (user, device, metric)

  sub = PikaSubscriber(host, username, pwd)
  sub.connect()
  sub.register(routing_key)

  # Sanity check: pull a single message before entering the consume loop.
  msg = sub.get_one_message(routing_key)
  print("[DEBUG] de-queued one message: %s" % str(msg))

  print("[DEBUG] Consuming messages from queue '%s' at '%s'"
        % (routing_key, host))
  while 1:
    try:
      sub.subscribe(routing_key, _print_message)
    except KeyboardInterrupt:
      sub.disconnect()
      # BUG FIX: without this break, Ctrl-C disconnected and then the loop
      # re-entered subscribe() on a closed connection instead of exiting.
      break