Example #1
  def connect(self):
    """
    Initialize routing keys, publisher, and subscriber
    """

    self._validate_metrics()

    if self._input_metrics is not None:  # Sources have no input metrics

      for input_metric_key, input_metric_name in self._input_metrics.items():
        self.routing_keys[input_metric_key] = _ROUTING_KEY % (
          self.user_id, self.device_type, input_metric_name)

      for input_metric_key in self._input_metrics.keys():
        sub = PikaSubscriber(self.rmq_address, self.rmq_user, self.rmq_pwd)
        sub.connect()
        sub.register(self.routing_keys[input_metric_key])
        self.subscribers[input_metric_key] = sub

    if self._output_metrics is not None:  # Sinks have no output metrics

      for output_metric_key, output_metric_name in self._output_metrics.items():
        self.routing_keys[output_metric_key] = _ROUTING_KEY % (
          self.user_id, self.device_type, output_metric_name)

      for output_metric_key in self._output_metrics.keys():
        pub = PikaPublisher(self.rmq_address, self.rmq_user, self.rmq_pwd)
        pub.connect()
        pub.register(self.routing_keys[output_metric_key])
        self.publishers[output_metric_key] = pub
        self.row_counter += 1

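A minimal sketch of how a module built around this connect() pattern might be wired up. The class name ExampleModule and the metric dictionaries are placeholder assumptions for illustration; only connect() itself is shown in the snippet above.

module = ExampleModule(user_id="user0",            # hypothetical module class
                       device_type="openbci",
                       rmq_address="localhost",
                       rmq_user="guest",
                       rmq_pwd="guest",
                       input_metrics={"eeg": "eeg"},
                       output_metrics={"mu": "mu"})
module.connect()  # one PikaSubscriber per input metric, one PikaPublisher per output metric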



if __name__ == "__main__":
  opts = get_opts()
  host = opts.server
  username = opts.login
  pwd = opts.pwd
  user = opts.user
  device = opts.device
  metric = opts.metric
  channels = json.loads(opts.channels)
  tag = int(opts.tag)
  convert_to_nupic = bool(opts.nupic)
  routing_key = _ROUTING_KEY % (user, device, metric)

  sub = PikaSubscriber(host, username, pwd)
  sub.connect()
  sub.register(routing_key)

  _LOGGER.info("Consuming messages from queue '%s' at '%s'"
               % (routing_key, host))

  csv_writer = CSVWriter(tag, channels, convert_to_nupic)
  sub.subscribe(routing_key, csv_writer.write_csv_files)
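For reference, callbacks passed to PikaSubscriber.subscribe() follow the standard Pika consumer signature, as in the class-based examples below. A minimal standalone callback might look like this; print_message is a hypothetical name, and the JSON payload assumption matches the other examples:

def print_message(ch, method, properties, body):
  """Minimal consumer callback with the signature subscribe() expects."""
  data = json.loads(body)  # payloads in these examples are JSON-encoded
  print("Received: %s" % data)

# sub.subscribe(routing_key, print_message)  # same pattern as csv_writer.write_csv_files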
Example #3
class HTMClassifier(object):
  def __init__(self,
               user_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd,
               input_metrics,
               output_metrics):

    """
    Motor Imagery Module.
    
    Metrics conventions:
    - Data to classify: {"timestamp": <int>, "channel_0": <float>}
    - Data label: {"timestamp": <int>, "channel_0": <int>}
    - Classification result: {"timestamp": <int>, "channel_0": <int>}
    
    @param user_id: (string) ID of the user using the device. 
    @param device_type: (string) type of the device publishing to this module. 
    @param rmq_address: (string) address of the RabbitMQ server.
    @param rmq_user: (string) login for RabbitMQ connection.
    @param rmq_pwd: (string) password for RabbitMQ connection.
    @param input_metrics: (dict) names of the input metrics.
    @param output_metrics: (dict) names of the output metrics.
    """
    self.module_id = HTMClassifier.__name__

    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd
    self.input_metrics = input_metrics
    self.output_metrics = output_metrics

    self.input_metric = None
    self.output_metric = None
    self.tag_metric = None
    self.last_tag = {"timestamp": None, "channel_0": None}

    self.output_metric_publisher = None
    self.input_metric_subscriber = None
    self.classification_publisher = None
    self.tag_subscriber = None
    self.routing_keys = None
    
    # Set when configure() is called.
    self.categories = None
    self.network_config = None
    self.trained_network_path = None
    self.minval = None
    self.maxval = None

    # Module specific
    self._network = None
    

  def _validate_metrics(self):
    """
    Validate input and output metrics and initialize them accordingly.
    
    This module must have the following signature for input and output metrics:
    
    input_metrics = {"metric_to_classify": <string>, "label_metric": <string>}
    output_metrics = {"result_metric": <string>}
    """

    if "label_metric" in self.input_metrics:
      self.tag_metric = self.input_metrics["label_metric"]
    else:
      raise KeyError("The input metric 'label_metric' is not set!")

    if "metric_to_classify" in self.input_metrics:
      self.input_metric = self.input_metrics["metric_to_classify"]
    else:
      raise KeyError("The input metric 'metric_to_classify' is not set!")

    if "result_metric" in self.output_metrics:
      self.output_metric = self.output_metrics["result_metric"]
    else:
      raise KeyError("The output metric 'result_metric' is not set!")


  def configure(self, categories, network_config, trained_network_path,
                minval, maxval):
    """Configure the module"""

    self._validate_metrics()

    self.categories = categories
    self.network_config = network_config
    self.trained_network_path = trained_network_path
    self.minval = minval
    self.maxval = maxval
    self.network_config["sensorRegionConfig"]["encoders"]["scalarEncoder"][
      "minval"] = minval

    # Init tag with first category
    self.last_tag["channel_0"] = self.categories[0]
    self.last_tag["timestamp"] = int(time.time() * 1000)


  def connect(self):
    """Initialize publisher and subscribers"""

    self.routing_keys = {
      self.input_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                         self.input_metric),
      self.output_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                          self.output_metric),
      self.tag_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                       self.tag_metric),
    }

    self.tag_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)
    self.classification_publisher = PikaPublisher(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.input_metric_subscriber = PikaSubscriber(self.rmq_address,
                                                  self.rmq_user, self.rmq_pwd)
    self.output_metric_publisher = PikaPublisher(self.rmq_address,
                                                 self.rmq_user, self.rmq_pwd)

    self.tag_subscriber.connect()
    self.classification_publisher.connect()
    self.input_metric_subscriber.connect()
    self.output_metric_publisher.connect()

    self.tag_subscriber.register(self.routing_keys[self.tag_metric])
    self.classification_publisher.register(self.routing_keys[self.tag_metric])
    self.input_metric_subscriber.register(self.routing_keys[self.input_metric])
    self.output_metric_publisher.register(self.routing_keys[self.output_metric])


  def train(self, training_file, num_records):
    """Create a network and training it on a CSV data source"""

    dataSource = FileRecordStream(streamID=training_file)
    dataSource.setAutoRewind(True)
    self._network = configureNetwork(dataSource, self.network_config)
    for i in xrange(num_records):  # Equivalent to: network.run(num_records) 
      self._network.run(1)
    self._network.save(self.trained_network_path)


  def start(self):
    """Get data from rabbitMQ and classify input data"""

    if self._network is None:
      self._network = Network(self.trained_network_path)

    regionNames = self._get_all_regions_names()
    setNetworkLearningMode(self._network, regionNames, False)
    _LOGGER.info("[Module %s] Starting Motor Imagery module. Routing keys: %s"
                 % (self.module_id, self.routing_keys))

    self.input_metric_subscriber.subscribe(
        self.routing_keys[self.input_metric],
        self._tag_and_classify)


  def _get_all_regions_names(self):

    region_names = []
    for region_config in self.network_config.values():
      region_names.append(region_config["regionName"])

    return region_names


  def _tag_and_classify(self, ch, method, properties, body):
    """Tag data and runs it through the classifier"""

    self._update_last_tag()

    input_data = simplejson.loads(body)
    timestamp = input_data["timestamp"]

    if self.maxval is not None and self.minval is not None:
      value = np.clip(input_data["channel_0"], self.minval, self.maxval)
    else:
      value = input_data["channel_0"]

    classificationResults = classifyNextRecord(self._network,
                                               self.network_config,
                                               timestamp,
                                               value,
                                               self.last_tag["channel_0"])
    inferredCategory = classificationResults["bestInference"]
    _LOGGER.debug("Raw results: %s" % classificationResults)

    buffer = [{"timestamp": timestamp, "channel_0": inferredCategory}]

    self.output_metric_publisher.publish(self.routing_keys[self.output_metric],
                                         buffer)


  def _update_last_tag(self):
    """
    Consume all tags in the queue and keep the last one (i.e. the most up 
    to date)
    
    A tag is a dict with the following format:
     tag = {"timestamp": <int>, "channel_0": <float>}
    
    """
    while 1:
      (meth_frame, header_frame, body) = self.tag_subscriber.get_one_message(
          self.routing_keys[self.tag_metric])
      if body:
        self.last_tag = simplejson.loads(body)
      else:
        _LOGGER.info("Last tag: {}".format(self.last_tag))
        return
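A minimal end-to-end sketch of driving HTMClassifier, based only on the methods shown above. All concrete argument values (metric names, categories, encoder bounds, file paths) and the network_config variable are placeholder assumptions.

classifier = HTMClassifier(user_id="user0",
                           device_type="openbci",
                           rmq_address="localhost",
                           rmq_user="guest",
                           rmq_pwd="guest",
                           input_metrics={"metric_to_classify": "mu",
                                          "label_metric": "tag"},
                           output_metrics={"result_metric": "classification"})

classifier.configure(categories=[0, 1, 2],
                     network_config=network_config,  # assumed to be loaded elsewhere
                     trained_network_path="trained_network.nta",
                     minval=-1.0,
                     maxval=1.0)
classifier.connect()
classifier.train(training_file="training_data.csv", num_records=1000)  # optional offline step
classifier.start()  # blocks, consuming input data and publishing classification results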
Example #4
        "channel_3",
        "channel_4",
        "channel_5",
        "channel_6",
        "channel_7",
        "tag",
    ]

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    file_name = "{}/test_{}.csv".format(data_dir, int(time.time()))

    tag_subscriber = PikaSubscriber(host, username, pwd)
    tag_subscriber.connect()
    tag_subscriber.register(tag_routing_key)

    eeg_subscriber = PikaSubscriber(host, username, pwd)
    eeg_subscriber.connect()
    eeg_subscriber.register(eeg_routing_key)

    f_out = open(file_name, "w")
    writer = csv.DictWriter(f_out, fieldnames=fieldnames)
    writer.writeheader()

    tag = 0

    t1 = threading.Thread(target=tag_subscriber.subscribe, args=(tag_routing_key, consume_tag))
    t2 = threading.Thread(target=eeg_subscriber.subscribe, args=(eeg_routing_key, consume_eeg))

    t1.start()
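The consume_tag and consume_eeg callbacks referenced above are not part of this excerpt. A rough sketch of what they might look like, assuming the same Pika callback signature and JSON payloads as the other examples; the module-level tag variable and the row layout are assumptions:

def consume_tag(ch, method, properties, body):
    """Remember the most recent tag value published on the tag queue."""
    global tag                       # assumes `tag` is kept at module scope
    tag = json.loads(body)["channel_0"]


def consume_eeg(ch, method, properties, body):
    """Write each EEG sample to the CSV file, stamped with the current tag."""
    for sample in json.loads(body):  # payload assumed to be a list of samples
        sample["tag"] = tag
        writer.writerow(sample)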
Example #5
class PreprocessingModule(object):
  def __init__(self,
               user_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd,
               input_metrics,
               output_metrics):

    self.module_id = str(uuid.uuid4())
    
    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd
    self.input_metrics = input_metrics
    self.output_metrics = output_metrics

    self.input_metric = None
    self.output_metric = None

    self.eeg_subscriber = None
    self.mu_publisher = None
    self.routing_keys = None
    self.preprocessor = None

    self.num_channels = get_num_channels(self.device_type, "eeg")
    self.eeg_data = np.zeros((0, self.num_channels))
    self.count = 0

    self.eyeblinks_remover = EyeBlinksFilter()

    self.step_size = None
    self.electrodes_placement = None
    self.enable_ica = False
    
    self.started_fit = False



  def configure(self, step_size, electrodes_placement, input_metric,
                output_metric, enable_ica=False):
    """
    Module specific params.
    @param step_size: (int) STFT step size
    @param electrodes_placement: (dict) dict with the electrode placement 
      for optional Laplacian filtering. 
      
      E.g:
      {
        "channel_2": {
           "main": "channel_2", 
           "artifact": ["channel_0", "channel_3", "channel_5"]
        },
        "channel_4": {
           "main": "channel_4", 
           "artifact": ["channel_1", "channel_3", "channel_6"]
        },
      }
      
      If you don't want any Laplacian filtering then set this to:
      {
        "channel_2": {
           "main": "channel_2", 
           "artifact": []
        },
        "channel_4": {
           "main": "channel_4", 
           "artifact": []
        },
      }
      
      More about Laplacian filtering: http://sccn.ucsd.edu/wiki/Flt_laplace  
    @param input_metric: (string) name of the input metric.
    @param output_metric: (string) name of the output metric.
    @param enable_ica: (boolean) if True, enable ICA pre-processing. This will
      remove eye blinks.
    """
    self.step_size = step_size
    self.electrodes_placement = electrodes_placement
    self.input_metric = input_metric
    self.output_metric = output_metric
    self.enable_ica = enable_ica


  def connect(self):
    """
    Initialize EEG preprocessor, publisher, and subscriber
    """

    if self.step_size is None:
      raise ValueError("Step size can't be None. "
                       "Use configure() to set it.")
    if self.electrodes_placement is None:
      raise ValueError("Electrode placement can't be None. "
                       "Use configure() to set it.")
    if self.input_metric is None:
      raise ValueError("Input metric can't be None. "
                       "Use configure() to set it.")
    if self.output_metric is None:
      raise ValueError("Output metric can't be None. "
                       "Use configure() to set it.")

    self.routing_keys = {
      self.input_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                         self.input_metric),
      self.output_metric: _ROUTING_KEY % (self.user_id, self.device_type,
                                          self.output_metric)
    }

    self.mu_publisher = PikaPublisher(self.rmq_address,
                                      self.rmq_user, self.rmq_pwd)
    self.eeg_subscriber = PikaSubscriber(self.rmq_address,
                                         self.rmq_user, self.rmq_pwd)

    self.eeg_subscriber.connect()
    self.mu_publisher.connect()

    self.mu_publisher.register(self.routing_keys[self.output_metric])
    self.eeg_subscriber.register(self.routing_keys[self.input_metric])


  def start(self):
    _LOGGER.info("[Module %s] Starting Preprocessing. Routing "
                 "keys: %s" % (self.module_id, self.routing_keys))

    self.eeg_subscriber.subscribe(self.routing_keys[self.input_metric],
                                  self._preprocess)


  def refit_ica(self):
    t = Thread(target=self.eyeblinks_remover.fit, args=(self.eeg_data,))
    t.start()


  def _preprocess(self, ch, method, properties, body):
    eeg = json.loads(body)

    self.eeg_data = np.vstack([self.eeg_data, get_raw(eeg, self.num_channels)])

    self.count += self.step_size

    timestamp = eeg[-1]["timestamp"]

    if self.enable_ica:
      eeg = from_raw(self.eyeblinks_remover.transform(
          get_raw(eeg, self.num_channels)), self.num_channels)
      if ((self.count >= 5000 and not self.started_fit)
          or self.count % 10000 == 0):
        _LOGGER.info('refitting...')
        self.started_fit = True
        self.refit_ica()

    processed_data = preprocess_stft(eeg, self.electrodes_placement)

    data = {"timestamp": timestamp}
    for key, values in processed_data.items():
      data[key] = values[-1]

    _LOGGER.debug("--> output: %s" % data)
    self.mu_publisher.publish(self.routing_keys[self.output_metric], data)
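A minimal sketch of wiring up PreprocessingModule end to end, based only on the methods above. The metric names, step size, and electrode placement are placeholder assumptions; the configure() call uses the input_metric/output_metric parameters documented in its docstring.

preprocessor = PreprocessingModule(user_id="user0",
                                   device_type="openbci",
                                   rmq_address="localhost",
                                   rmq_user="guest",
                                   rmq_pwd="guest",
                                   input_metrics={"eeg": "eeg"},
                                   output_metrics={"mu": "mu"})

preprocessor.configure(step_size=32,
                       electrodes_placement={
                         "channel_2": {"main": "channel_2", "artifact": []},
                         "channel_4": {"main": "channel_4", "artifact": []},
                       },
                       input_metric="eeg",
                       output_metric="mu",
                       enable_ica=False)
preprocessor.connect()
preprocessor.start()  # blocks, consuming raw EEG and publishing preprocessed data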