コード例 #1
0
  def connect(self):
    """
    Initialize routing keys, publishers, and subscribers.

    Calls _validate_metrics() first. Then, for the input metrics (absent for
    sources) it builds one routing key per metric and opens one registered
    PikaSubscriber per metric, stored in self.subscribers. For the output
    metrics (absent for sinks) it does the same with PikaPublisher, stored
    in self.publishers.

    Input and output routing keys share self.routing_keys. An output metric
    key equal to an input metric key therefore overwrites that entry, but
    only after the input subscriber has already been registered.
    """
    self._validate_metrics()

    if self._input_metrics is not None:  # Sources have no input metrics
      self._open_endpoints(self._input_metrics, PikaSubscriber,
                           self.subscribers)

    if self._output_metrics is not None:  # Sinks have no output metrics
      self._open_endpoints(self._output_metrics, PikaPublisher,
                           self.publishers)

  def _open_endpoints(self, metrics, endpoint_cls, registry):
    """
    Build routing keys for `metrics`, then connect one endpoint per metric.

    All routing keys are computed before any connection is opened.

    @param metrics: (dict) metric key -> metric name.
    @param endpoint_cls: class to instantiate (PikaSubscriber/PikaPublisher).
    @param registry: (dict) receives each connected endpoint by metric key.
    """
    for metric_key, metric_name in metrics.items():
      self.routing_keys[metric_key] = _ROUTING_KEY % (
        self.user_id, self.device_type, metric_name)

    for metric_key in metrics:
      endpoint = endpoint_cls(self.rmq_address, self.rmq_user, self.rmq_pwd)
      endpoint.connect()
      endpoint.register(self.routing_keys[metric_key])
      registry[metric_key] = endpoint
コード例 #2
0
  def initialize(self):
    """
    Build the per-hemisphere classifiers and open the RabbitMQ endpoints.

    A "left" and a "right" HTMClassifier are created, then each classifier
    is initialized (and trained when _PRE_TRAIN is set). Afterwards the tag
    subscriber/publisher, the mu subscriber and the classification
    publisher are created, connected and bound to their routing keys.
    """
    for side in ("left", "right"):
      self.classifiers[side] = HTMClassifier(network_config, _TRAINING_DATA,
                                             _CATEGORIES)

    # NOTE(review): network_config and partitions are free names here;
    # presumably module-level globals — confirm.
    for clf in self.classifiers.values():
      clf.initialize()
      if _PRE_TRAIN:
        clf.train(_TRAIN_SET_SIZE, partitions)

    credentials = (self.rmq_address, self.rmq_user, self.rmq_pwd)
    self.tag_subscriber = PikaSubscriber(*credentials)
    self.tag_publisher = PikaPublisher(*credentials)
    self.mu_subscriber = PikaSubscriber(*credentials)
    self.classification_publisher = PikaPublisher(*credentials)

    for endpoint in (self.tag_subscriber, self.tag_publisher,
                     self.mu_subscriber, self.classification_publisher):
      endpoint.connect()

    tag_key = self.routing_keys["tag"]
    self.tag_subscriber.subscribe(tag_key)
    self.tag_publisher.register(tag_key)
    self.mu_subscriber.subscribe(self.routing_keys["mu"])
    self.classification_publisher.register(self.routing_keys["classification"])
コード例 #3
0
  def connect(self):
    """
    Create the classifier's RabbitMQ endpoints and bind them.

    One routing key is built per metric (input, output, tag). Then the tag
    subscriber, classification publisher, input-metric subscriber and
    output-metric publisher are created, connected and registered.
    """
    self.routing_keys = {
      metric: _ROUTING_KEY % (self.user_id, self.device_type, metric)
      for metric in (self.input_metric, self.output_metric, self.tag_metric)
    }

    address, login, password = self.rmq_address, self.rmq_user, self.rmq_pwd
    self.tag_subscriber = PikaSubscriber(address, login, password)
    self.classification_publisher = PikaPublisher(address, login, password)
    self.input_metric_subscriber = PikaSubscriber(address, login, password)
    self.output_metric_publisher = PikaPublisher(address, login, password)

    for endpoint in (self.tag_subscriber, self.classification_publisher,
                     self.input_metric_subscriber,
                     self.output_metric_publisher):
      endpoint.connect()

    tag_key = self.routing_keys[self.tag_metric]
    self.tag_subscriber.register(tag_key)
    # NOTE(review): the classification publisher is bound to the *tag* key;
    # only output_metric_publisher is bound to the output key. Confirm this
    # is intended.
    self.classification_publisher.register(tag_key)
    self.input_metric_subscriber.register(self.routing_keys[self.input_metric])
    self.output_metric_publisher.register(self.routing_keys[self.output_metric])
コード例 #4
0
  def connect(self):
    """
    Validate the configuration, then open the EEG subscriber and mu publisher.

    Settings are checked in order; the first one that is None aborts.

    @raise ValueError: if step_size, electrodes_placement, input_metric or
      output_metric is None (configure() has not set it).
    """
    required = (("step_size", "Step size"),
                ("electrodes_placement", "Electrode placement"),
                ("input_metric", "Input metric"),
                ("output_metric", "Output metric"))
    for attr_name, label in required:
      # getattr keeps the lookup lazy: later settings are only read once the
      # earlier ones have passed.
      if getattr(self, attr_name) is None:
        raise ValueError("%s can't be none. "
                         "Use configure() to set it." % label)

    self.routing_keys = {
      metric: _ROUTING_KEY % (self.user_id, self.device_type, metric)
      for metric in (self.input_metric, self.output_metric)
    }

    credentials = (self.rmq_address, self.rmq_user, self.rmq_pwd)
    self.mu_publisher = PikaPublisher(*credentials)
    self.eeg_subscriber = PikaSubscriber(*credentials)

    for endpoint in (self.eeg_subscriber, self.mu_publisher):
      endpoint.connect()

    self.mu_publisher.register(self.routing_keys[self.output_metric])
    self.eeg_subscriber.register(self.routing_keys[self.input_metric])
コード例 #5
0
    def initialize(self):
        """
        Create, connect and bind the mu publisher and the EEG subscriber.

        The publisher is registered on the _MU routing key; the subscriber
        subscribes to the _EEG routing key.
        """
        credentials = (self.rmq_address, self.rmq_user, self.rmq_pwd)
        self.mu_publisher = PikaPublisher(*credentials)
        self.eeg_subscriber = PikaSubscriber(*credentials)

        for endpoint in (self.eeg_subscriber, self.mu_publisher):
            endpoint.connect()

        self.mu_publisher.register(self.routing_keys[_MU])
        self.eeg_subscriber.subscribe(self.routing_keys[_EEG])
コード例 #6
0
import json
from brainsquared.subscribers.PikaSubscriber import PikaSubscriber

# NOTE(review): RabbitMQ credentials are hard-coded in this debug script.
host = "rabbitmq.cloudbrain.rocks"
username = "******"
pwd = "cloudbrain"

user = "******"
device = "openbci"
metric = "classification"
routing_key = "%s:%s:%s" % (user, device, metric)

sub = PikaSubscriber(host, username, pwd)
sub.connect()
sub.subscribe(routing_key)


def _print_message(ch, method, properties, body):
  """Print a raw message body (consumer-callback signature)."""
  print(body)


# Poll the queue and print every decoded message until Ctrl-C. After
# disconnecting, break out so the loop does not keep polling a closed
# connection.
while True:
  try:
    (_method_frame, _header_frame, body) = sub.get_one_message(routing_key)
    if body:
      print(json.loads(body))
  except KeyboardInterrupt:
    sub.disconnect()
    break

コード例 #7
0
class HTMMotorImageryModule(object):
  """
  Motor-imagery module with one HTM classifier per hemisphere.

  For every mu message it drains the tag queue to get the latest tag, feeds
  the clipped "left" and "right" mu values (with that tag as target) to the
  matching classifier, reconciles both results with _reconcile_results() and
  publishes the outcome on the classification routing key.
  """

  # After this many mu records the classifiers switch to inference only.
  _NUM_LEARNING_RECORDS = 1000

  def __init__(self,
               user_id,
               module_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd):
    """
    @param user_id: (string) ID of the user.
    @param module_id: (string) ID of this module; part of the tag and
      classification routing keys.
    @param device_type: (string) device type; part of the mu routing key.
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for the RabbitMQ connection.
    @param rmq_pwd: (string) password for the RabbitMQ connection.
    """
    # Running min/max of the raw (pre-clipping) mu value per hemisphere.
    self.stats = {
      "left": {"min": None, "max": None},
      "right": {"min": None, "max": None}
    }
    self.module_id = module_id
    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd

    # RabbitMQ endpoints; created in initialize().
    self.classification_publisher = None
    self.mu_subscriber = None
    self.tag_subscriber = None
    self.tag_publisher = None

    # mu is keyed by device type; tags and results are keyed by module ID.
    self.routing_keys = {
      "mu": _ROUTING_KEY % (user_id, device_type, _MU),
      "tag": _ROUTING_KEY % (user_id, module_id, _TAG),
      "classification": _ROUTING_KEY % (user_id, module_id, _CLASSIFICATION)
    }

    self.start_time = int(time.time() * 1000)  # in ms
    # Default tag used until one is read from the tag queue.
    self.last_tag = {"timestamp": self.start_time, "value": _CATEGORIES[1]}

    self.classifiers = {"left": None, "right": None}
    self.numRecords = 0
    self.learning_mode = True

  def initialize(self):
    """
    Initialize classifiers, publishers (tag, classification), and subscribers
    (mu and tag).

    Each classifier is initialized and, when _PRE_TRAIN is set, trained.
    """
    # NOTE(review): network_config and partitions are not defined in this
    # class; presumably module-level globals — confirm.
    self.classifiers["left"] = HTMClassifier(network_config, _TRAINING_DATA,
                                             _CATEGORIES)
    self.classifiers["right"] = HTMClassifier(network_config, _TRAINING_DATA,
                                              _CATEGORIES)
    for classifier in self.classifiers.values():
      classifier.initialize()
      if _PRE_TRAIN:
        classifier.train(_TRAIN_SET_SIZE, partitions)

    self.tag_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)
    self.tag_publisher = PikaPublisher(self.rmq_address,
                                       self.rmq_user, self.rmq_pwd)
    self.mu_subscriber = PikaSubscriber(self.rmq_address,
                                        self.rmq_user, self.rmq_pwd)
    self.classification_publisher = PikaPublisher(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)

    self.tag_subscriber.connect()
    self.tag_publisher.connect()
    self.mu_subscriber.connect()
    self.classification_publisher.connect()

    self.tag_subscriber.subscribe(self.routing_keys["tag"])
    self.tag_publisher.register(self.routing_keys["tag"])
    self.mu_subscriber.subscribe(self.routing_keys["mu"])
    self.classification_publisher.register(self.routing_keys["classification"])

  def start(self):
    """Consume mu messages, handing each one to _tag_and_classify()."""
    _LOGGER.info("[Module %s] Starting Motor Imagery module. Routing keys: %s"
                 % (self.module_id, self.routing_keys))

    self.mu_subscriber.consume_messages(self.routing_keys["mu"],
                                        self._tag_and_classify)

  def _update_last_tag(self, last_tag):
    """
    Consume all tags in the queue and return the last one (i.e. the most up
    to date), or `last_tag` unchanged if the queue is empty.
    """
    while True:
      (_method_frame, _header_frame, body) = (
        self.tag_subscriber.get_one_message(self.routing_keys["tag"]))
      if not body:
        return last_tag
      last_tag = json.loads(body)

  def _tag_and_classify(self, ch, method, properties, body):
    """
    Tag a mu message with the latest tag, classify it and publish the result.

    Consumer-callback signature; only `body` is used. It is JSON holding
    "timestamp" plus one value per classifier hemisphere ("left", "right").
    """
    self.numRecords += 1
    _LOGGER.debug("[Module %s] record #%s" % (self.module_id, self.numRecords))
    # Switch to inference only once, after the learning budget is spent.
    if self.learning_mode and self.numRecords > self._NUM_LEARNING_RECORDS:
      self.learning_mode = False
      _LOGGER.info("=======LEARNING DISABLED!!!=========")

    self.last_tag = self._update_last_tag(self.last_tag)
    _LOGGER.debug("[Module %s] mu: %s | last_tag: %s" % (self.module_id, body,
                                                         self.last_tag))

    mu = json.loads(body)
    mu_timestamp = mu["timestamp"]
    tag_value = self.last_tag["value"]

    results = {}
    for hemisphere, classifier in self.classifiers.items():
      mu_value = mu[hemisphere]
      # The classifier sees the clipped value; stats track the raw one.
      mu_clipped = np.clip(mu_value, _MU_MIN, _MU_MAX)
      results[hemisphere] = classifier.classify(
        input_data=mu_clipped,
        target=tag_value,
        learning_is_on=self.learning_mode)
      self._update_stats(hemisphere, mu_value)

    _LOGGER.debug("Raw results: %s" % results)

    classification_result = _reconcile_results(results["left"],
                                               results["right"])

    payload = [{"timestamp": mu_timestamp, "value": classification_result}]
    self.classification_publisher.publish(self.routing_keys["classification"],
                                          payload)

  def _update_stats(self, hemisphere, mu_value):
    """
    Update the running min/max of the raw mu value for one hemisphere.

    @param hemisphere: (string) key into self.stats ("left" or "right").
    @param mu_value: raw mu value.
    """
    hemisphere_stats = self.stats[hemisphere]
    # Compare against None explicitly. A legitimate 0.0 extreme must not be
    # treated as "unset", and number-vs-None comparisons raise on Python 3.
    if hemisphere_stats["min"] is None or mu_value < hemisphere_stats["min"]:
      hemisphere_stats["min"] = mu_value
    if hemisphere_stats["max"] is None or mu_value > hemisphere_stats["max"]:
      hemisphere_stats["max"] = mu_value
コード例 #8
0
        self.row_counter += 1

      if now > _RECORDING_TIME + self.start_time:
        sys.exit(1)



if __name__ == "__main__":
  # Read connection, routing and CSV options from the command line.
  opts = get_opts()
  host, username, pwd = opts.server, opts.login, opts.pwd
  user, device, metric = opts.user, opts.device, opts.metric
  channels = json.loads(opts.channels)
  tag = int(opts.tag)
  # NOTE(review): if opts.nupic is a string, bool() is True for any
  # non-empty value (including "0"/"False") — confirm how get_opts parses it.
  convert_to_nupic = bool(opts.nupic)
  routing_key = _ROUTING_KEY % (user, device, metric)

  sub = PikaSubscriber(host, username, pwd)
  sub.connect()
  sub.register(routing_key)

  _LOGGER.info("Consuming messages from queue '%s' at '%s'"
               % (routing_key, host))

  # Hand every message on the queue to the CSV writer.
  writer = CSVWriter(tag, channels, convert_to_nupic)
  sub.subscribe(routing_key, writer.write_csv_files)
コード例 #9
0
class HTMClassifier(object):
  """
  Motor Imagery Module backed by a single HTM network.

  Reads the metric to classify and a label (tag) metric from RabbitMQ, passes
  each value plus the most recent label to classifyNextRecord(), and
  publishes the inferred category on the result metric.
  """

  def __init__(self,
               user_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd,
               input_metrics,
               output_metrics):
    """
    Metrics conventions:
    - Data to classify: {"timestamp": <int>, "channel_0": <float>}
    - Data label: {"timestamp": <int>, "channel_0": <int>}
    - Classification result: {"timestamp": <int>, "channel_0": <int>}

    @param user_id: (string) ID of the user using the device.
    @param device_type: (string) type of the device publishing to this module.
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for RabbitMQ connection.
    @param rmq_pwd: (string) password for RabbitMQ connection.
    @param input_metrics: (dict) input metric names under the keys
      "metric_to_classify" and "label_metric".
    @param output_metrics: (dict) output metric name under the key
      "result_metric".
    """
    self.module_id = HTMClassifier.__name__

    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd
    self.input_metrics = input_metrics
    self.output_metrics = output_metrics

    # Resolved from input_metrics/output_metrics by _validate_metrics().
    self.input_metric = None
    self.output_metric = None
    self.tag_metric = None
    self.last_tag = {"timestamp": None, "channel_0": None}

    # RabbitMQ endpoints; created in connect().
    self.output_metric_publisher = None
    self.input_metric_subscriber = None
    self.tag_subscriber = None
    self.classification_publisher = None
    self.routing_keys = None

    # Set when configure() is called.
    self.categories = None
    self.network_config = None
    self.trained_network_path = None
    self.minval = None
    self.maxval = None

    # Module specific
    self._network = None

  def _validate_metrics(self):
    """
    Validate input and output metrics and initialize them accordingly.

    This module must have the following signature for input and output metrics:

    input_metrics = {"metric_to_classify": <string>, "label_metric": <string>}
    output_metrics = {"result_metric": <string>}

    @raise KeyError: if one of the required metric keys is missing.
    """
    if "label_metric" not in self.input_metrics:
      raise KeyError("The input metric 'label_metric' is not set!")
    self.tag_metric = self.input_metrics["label_metric"]

    if "metric_to_classify" not in self.input_metrics:
      raise KeyError("The input metric 'metric_to_classify' is not set!")
    self.input_metric = self.input_metrics["metric_to_classify"]

    if "result_metric" not in self.output_metrics:
      raise KeyError("The output metric 'result_metric' is not set!")
    self.output_metric = self.output_metrics["result_metric"]

  def configure(self, categories, network_config, trained_network_path,
                minval, maxval):
    """
    Configure the module.

    @param categories: (list) label categories; the first one becomes the
      initial tag.
    @param network_config: (dict) network configuration. Its
      ["sensorRegionConfig"]["encoders"]["scalarEncoder"] entry is updated
      in place with minval and maxval.
    @param trained_network_path: (string) path train() saves the network to
      and start() loads it from.
    @param minval: lower bound of the scalar encoder. Also the lower
      clipping bound for incoming values when both bounds are set.
    @param maxval: upper bound of the scalar encoder. Also the upper
      clipping bound for incoming values when both bounds are set.
    @raise KeyError: if a required input/output metric is missing.
    """
    self._validate_metrics()

    self.categories = categories
    self.network_config = network_config
    self.trained_network_path = trained_network_path
    self.minval = minval
    self.maxval = maxval

    # Write both bounds so the encoder range matches the clipping range
    # applied in _tag_and_classify().
    encoder_config = self.network_config["sensorRegionConfig"]["encoders"][
      "scalarEncoder"]
    encoder_config["minval"] = minval
    encoder_config["maxval"] = maxval

    # Init tag with first category
    self.last_tag["channel_0"] = self.categories[0]
    self.last_tag["timestamp"] = int(time.time() * 1000)

  def connect(self):
    """Initialize publisher and subscribers"""
    self.routing_keys = {
      metric: _ROUTING_KEY % (self.user_id, self.device_type, metric)
      for metric in (self.input_metric, self.output_metric, self.tag_metric)
    }

    self.tag_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)
    self.classification_publisher = PikaPublisher(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.input_metric_subscriber = PikaSubscriber(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.output_metric_publisher = PikaPublisher(self.rmq_address,
                                                 self.rmq_user, self.rmq_pwd)

    self.tag_subscriber.connect()
    self.classification_publisher.connect()
    self.input_metric_subscriber.connect()
    self.output_metric_publisher.connect()

    self.tag_subscriber.register(self.routing_keys[self.tag_metric])
    # NOTE(review): this publisher is bound to the tag key and is never used
    # to publish in this class (results go through output_metric_publisher).
    # Kept in case registration declares the tag queue — confirm.
    self.classification_publisher.register(self.routing_keys[self.tag_metric])
    self.input_metric_subscriber.register(self.routing_keys[self.input_metric])
    self.output_metric_publisher.register(self.routing_keys[self.output_metric])

  def train(self, training_file, num_records):
    """Create a network and train it on a CSV data source, then save it."""
    data_source = FileRecordStream(streamID=training_file)
    data_source.setAutoRewind(True)
    self._network = configureNetwork(data_source, self.network_config)
    for _ in xrange(num_records):  # Equivalent to: network.run(num_records)
      self._network.run(1)
    self._network.save(self.trained_network_path)

  def start(self):
    """Get data from rabbitMQ and classify input data"""
    if self._network is None:
      self._network = Network(self.trained_network_path)

    region_names = self._get_all_regions_names()
    setNetworkLearningMode(self._network, region_names, False)
    _LOGGER.info("[Module %s] Starting Motor Imagery module. Routing keys: %s"
                 % (self.module_id, self.routing_keys))

    self.input_metric_subscriber.subscribe(
        self.routing_keys[self.input_metric],
        self._tag_and_classify)

  def _get_all_regions_names(self):
    """
    Return the "regionName" of every entry in self.network_config.

    NOTE(review): assumes every value in network_config is a region config
    dict carrying "regionName" — confirm.
    """
    return [region_config["regionName"]
            for region_config in self.network_config.values()]

  def _tag_and_classify(self, ch, method, properties, body):
    """Tag data and run it through the classifier, then publish the result."""
    self._update_last_tag()

    input_data = simplejson.loads(body)
    timestamp = input_data["timestamp"]

    value = input_data["channel_0"]
    if self.maxval is not None and self.minval is not None:
      value = np.clip(value, self.minval, self.maxval)

    classification_results = classifyNextRecord(self._network,
                                                self.network_config,
                                                timestamp,
                                                value,
                                                self.last_tag["channel_0"])
    inferred_category = classification_results["bestInference"]
    _LOGGER.debug("Raw results: %s" % classification_results)

    payload = [{"timestamp": timestamp, "channel_0": inferred_category}]

    self.output_metric_publisher.publish(self.routing_keys[self.output_metric],
                                         payload)

  def _update_last_tag(self):
    """
    Consume all tags in the queue and keep the last one (i.e. the most up
    to date) in self.last_tag.

    A tag is a dict with the following format:
     tag = {"timestamp": <int>, "channel_0": <float>}
    """
    while True:
      (_meth_frame, _header_frame, body) = self.tag_subscriber.get_one_message(
          self.routing_keys[self.tag_metric])
      if not body:
        _LOGGER.info("Last tag: {}".format(self.last_tag))
        return
      self.last_tag = simplejson.loads(body)
コード例 #10
0
ファイル: FileSink.py プロジェクト: ywcui1990/htm-challenge
        "channel_1",
        "channel_2",
        "channel_3",
        "channel_4",
        "channel_5",
        "channel_6",
        "channel_7",
        "tag",
    ]

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    file_name = "{}/test_{}.csv".format(data_dir, int(time.time()))

    tag_subscriber = PikaSubscriber(host, username, pwd)
    tag_subscriber.connect()
    tag_subscriber.register(tag_routing_key)

    eeg_subscriber = PikaSubscriber(host, username, pwd)
    eeg_subscriber.connect()
    eeg_subscriber.register(eeg_routing_key)

    f_out = open(file_name, "w")
    writer = csv.DictWriter(f_out, fieldnames=fieldnames)
    writer.writeheader()

    tag = 0

    t1 = threading.Thread(target=tag_subscriber.subscribe, args=(tag_routing_key, consume_tag))
    t2 = threading.Thread(target=eeg_subscriber.subscribe, args=(eeg_routing_key, consume_eeg))
コード例 #11
0

def _print_message(ch, method, properties, body):
  """
  Print each element of a JSON-array message body on its own line.

  Consumer-callback signature; only `body` is used. Single-argument print()
  prints identically on Python 2 and also parses on Python 3.
  """
  buffer_items = json.loads(body)
  for data in buffer_items:
    print(data)



if __name__ == "__main__":
  host = "127.0.0.1"
  username = "******"
  pwd = "guest"

  user = "******"
  device = "neurosky"
  metric = "attention"
  routing_key = "%s:%s:%s" % (user, device, metric)

  sub = PikaSubscriber(host, username, pwd)
  sub.connect()
  sub.register(routing_key)

  print("[DEBUG] Consuming messages from queue '%s' at '%s'" % (routing_key,
                                                                host))
  # Re-subscribe if consumption returns. On Ctrl-C, disconnect and leave the
  # loop instead of re-subscribing on the closed connection.
  while True:
    try:
      sub.subscribe(routing_key, _print_message)
    except KeyboardInterrupt:
      sub.disconnect()
      break
コード例 #12
0
class PreprocessingModule(object):
  """
  EEG preprocessing module.

  Subscribes to raw EEG on the input metric and optionally removes eye
  blinks with an ICA filter. It computes features with preprocess_stft() and
  publishes the last value of each feature on the output metric.
  """

  def __init__(self,
               user_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd,
               input_metrics,
               output_metrics):
    """
    @param user_id: (string) ID of the user using the device.
    @param device_type: (string) type of the device; also used to look up
      the number of EEG channels.
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for the RabbitMQ connection.
    @param rmq_pwd: (string) password for the RabbitMQ connection.
    @param input_metrics: input metric descriptor(s), stored as given.
    @param output_metrics: output metric descriptor(s), stored as given.
    """
    self.module_id = str(uuid.uuid4())

    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd
    self.input_metrics = input_metrics
    self.output_metrics = output_metrics
    # Individual metric names; set by configure() and required by connect().
    self.input_metric = None
    self.output_metric = None

    self.eeg_subscriber = None
    self.mu_publisher = None
    self.routing_keys = None
    self.preprocessor = None

    self.num_channels = get_num_channels(self.device_type, "eeg")
    # Raw EEG history (one column per channel); the ICA refit trains on it.
    self.eeg_data = np.zeros((0, self.num_channels))
    self.count = 0

    self.eyeblinks_remover = EyeBlinksFilter()

    self.step_size = None
    self.electrodes_placement = None
    self.enable_ica = False

    self.started_fit = False

  def configure(self, step_size, electrodes_placement, enable_ica=False,
                input_metric=None, output_metric=None):
    """
    Module specific params.
    @param step_size: (int) STFT step size
    @param electrodes_placement: (dict) dict with the electrode placement
      for optional Laplacian filtering.

      E.g:
      {
        "channel_2": {
           "main": "channel_2",
           "artifact": ["channel_0", "channel_3", "channel_5"]
        },
        "channel_4": {
           "main": "channel_4",
           "artifact": ["channel_1", "channel_3", "channel_6"]
        },
      }

      If you don't want any Laplacian filtering then set this to:
      {
        "channel_2": {
           "main": "channel_2",
           "artifact": []
        },
        "channel_4": {
           "main": "channel_4",
           "artifact": []
        },
      }

      More about Laplacian filtering: http://sccn.ucsd.edu/wiki/Flt_laplace
    @param enable_ica: (boolean) if 1, enable ICA pre-processing. This will
      remove eye blinks.
    @param input_metric: (string) name of the input metric. Must be set
      before connect().
    @param output_metric: (string) name of the output metric. Must be set
      before connect().
    """
    self.step_size = step_size
    self.electrodes_placement = electrodes_placement
    self.input_metric = input_metric
    self.output_metric = output_metric
    self.enable_ica = enable_ica

  def connect(self):
    """
    Initialize EEG preprocessor, publisher, and subscriber

    @raise ValueError: if step_size, electrodes_placement, input_metric or
      output_metric is None (configure() has not set it).
    """
    # Checked in order; getattr keeps the lookups lazy.
    for attr_name, label in (("step_size", "Step size"),
                             ("electrodes_placement", "Electrode placement"),
                             ("input_metric", "Input metric"),
                             ("output_metric", "Output metric")):
      if getattr(self, attr_name) is None:
        raise ValueError("%s can't be none. "
                         "Use configure() to set it." % label)

    self.routing_keys = {
      self.input_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                         self.input_metric),
      self.output_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                          self.output_metric)
    }

    self.mu_publisher = PikaPublisher(self.rmq_address,
                                      self.rmq_user, self.rmq_pwd)
    self.eeg_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)

    self.eeg_subscriber.connect()
    self.mu_publisher.connect()

    self.mu_publisher.register(self.routing_keys[self.output_metric])
    self.eeg_subscriber.register(self.routing_keys[self.input_metric])

  def start(self):
    """Consume EEG messages, handing each one to _preprocess()."""
    _LOGGER.info("[Module %s] Starting Preprocessing. Routing "
                 "keys: %s" % (self.module_id, self.routing_keys))

    self.eeg_subscriber.subscribe(self.routing_keys[self.input_metric],
                                  self._preprocess)

  def refit_ica(self):
    """
    Refit the eye-blink filter on the EEG history in a background thread.

    @return: (threading.Thread) the started thread; callers may join() it.
    """
    t = Thread(target=self.eyeblinks_remover.fit, args=(self.eeg_data,))
    t.start()
    return t

  def _preprocess(self, ch, method, properties, body):
    """
    Consume one EEG message and publish its preprocessed features.

    Consumer-callback signature; only `body` is used. It is a JSON array
    whose items carry a "timestamp" key. The timestamp of the last item is
    used for the published record.
    """
    eeg = json.loads(body)

    self.eeg_data = np.vstack([self.eeg_data, get_raw(eeg, self.num_channels)])

    # The counter advances by step_size per message, not by len(eeg).
    self.count += self.step_size

    timestamp = eeg[-1]["timestamp"]

    if self.enable_ica:
      eeg = from_raw(self.eyeblinks_remover.transform(
          get_raw(eeg, self.num_channels)), self.num_channels)
      # First fit at >= 5000 counted samples, then refit whenever the count
      # is a multiple of 10000 (a multiple of step_size as well).
      if ((self.count >= 5000 and not self.started_fit)
          or self.count % 10000 == 0):
        _LOGGER.info('refitting...')
        self.started_fit = True
        self.refit_ica()

    processed_data = preprocess_stft(eeg, self.electrodes_placement)

    # Publish only the most recent value of each feature.
    data = {"timestamp": timestamp}
    for key, values in processed_data.items():
      data[key] = values[-1]

    _LOGGER.debug("--> output: %s" % data)
    self.mu_publisher.publish(self.routing_keys[self.output_metric], data)
コード例 #13
0
class PreprocessingModule(object):
    """
    EEG preprocessing module with fixed _EEG/_MU routing keys.

    Subscribes to EEG and removes eye blinks with an EyeBlinksRemover that
    is (re)fitted in a background thread. It computes left/right mu values
    with preprocess_stft() and publishes them once the remover reports
    itself fitted.
    """

    def __init__(self, user_id, module_id, device_type, rmq_address, rmq_user, rmq_pwd, step_size):
        """
        @param user_id: (string) ID of the user.
        @param module_id: (string) ID of this module (used in log messages).
        @param device_type: (string) device type; part of both routing keys.
        @param rmq_address: (string) address of the RabbitMQ server.
        @param rmq_user: (string) login for the RabbitMQ connection.
        @param rmq_pwd: (string) password for the RabbitMQ connection.
        @param step_size: (int) amount added to the sample counter per message.
        """
        self.user_id = user_id
        self.module_id = module_id
        self.device_type = device_type
        self.rmq_address = rmq_address
        self.rmq_user = rmq_user
        self.rmq_pwd = rmq_pwd

        # RabbitMQ endpoints; created in initialize().
        self.eeg_subscriber = None
        self.mu_publisher = None

        self.routing_keys = {
            _EEG: _ROUTING_KEY % (user_id, device_type, _EEG),
            _MU: _ROUTING_KEY % (user_id, device_type, _MU),
        }

        self.preprocessor = None

        # Raw EEG history with a hard-coded 8 columns.
        # NOTE(review): presumably must match the width get_raw() returns.
        self.eeg_data = np.zeros((0, 8))
        self.count = 0

        self.eyeblinks_remover = EyeBlinksRemover()

        self.step_size = step_size

        self.started_fit = False

    def initialize(self):
        """Create, connect and bind the mu publisher and the EEG subscriber."""
        self.mu_publisher = PikaPublisher(self.rmq_address, self.rmq_user, self.rmq_pwd)
        self.eeg_subscriber = PikaSubscriber(self.rmq_address, self.rmq_user, self.rmq_pwd)

        self.eeg_subscriber.connect()
        self.mu_publisher.connect()

        self.mu_publisher.register(self.routing_keys[_MU])
        self.eeg_subscriber.subscribe(self.routing_keys[_EEG])

    def start(self):
        """Consume EEG messages, handing each one to _preprocess()."""
        _LOGGER.info("[Module %s] Starting Preprocessing. Routing keys: %s" % (self.module_id, self.routing_keys))

        self.eeg_subscriber.consume_messages(self.routing_keys[_EEG], self._preprocess)

    def refit_ica(self):
        """
        Refit the eye-blink remover in a background thread.

        The first 1000 rows of the EEG history are skipped.

        @return: (threading.Thread) the started thread; callers may join() it.
        """
        t = Thread(target=self.eyeblinks_remover.fit, args=(self.eeg_data[1000:],))
        t.start()
        return t

    def _preprocess(self, ch, method, properties, body):
        """
        Consume one EEG message and publish the latest left/right mu values.

        Consumer-callback signature; only `body` is used. It is a JSON array
        whose items carry a "timestamp" key.
        """
        eeg = json.loads(body)

        self.eeg_data = np.vstack([self.eeg_data, get_raw(eeg)])

        # The counter advances by step_size per message, not by len(eeg).
        self.count += self.step_size
        _LOGGER.debug("[Module %s] sample count: %s" % (self.module_id, self.count))

        # First fit at >= 5000 counted samples, then refit whenever the count
        # is a multiple of 10000 (a multiple of step_size as well).
        if (self.count >= 5000 and not self.started_fit) or self.count % 10000 == 0:
            _LOGGER.info("refitting...")
            self.started_fit = True
            self.refit_ica()

        timestamp = eeg[-1]["timestamp"]

        eeg = from_raw(self.eyeblinks_remover.transform(get_raw(eeg)))

        process = preprocess_stft(eeg, _METADATA)

        data = {
            "timestamp": timestamp,
            "left": process["left"][-1],
            "right": process["right"][-1],
        }

        _LOGGER.debug("--> mu: %s" % data)
        # Nothing is published until the remover has been fitted.
        if self.eyeblinks_remover.fitted:
            self.mu_publisher.publish(self.routing_keys[_MU], data)
コード例 #14
0
import json
from brainsquared.subscribers.PikaSubscriber import PikaSubscriber

# NOTE(review): RabbitMQ credentials are hard-coded in this debug script.
host = "rabbitmq.cloudbrain.rocks"
username = "******"
pwd = "cloudbrain"

user = "******"
device = "openbci"
metric = "classification"
routing_key = "%s:%s:%s" % (user, device, metric)

sub = PikaSubscriber(host, username, pwd)
sub.connect()
sub.register(routing_key)


def _print_message(ch, method, properties, body):
  """Print a raw message body (consumer-callback signature)."""
  print(body)


# Poll the queue and print every decoded message until Ctrl-C. After
# disconnecting, break out so the loop does not keep polling a closed
# connection.
while True:
  try:
    (_method_frame, _header_frame, body) = sub.get_one_message(routing_key)
    if body:
      print(json.loads(body))
  except KeyboardInterrupt:
    sub.disconnect()
    break