class HTMMotorImageryModule(object):
  """
  Motor imagery classification module.

  Consumes mu messages, labels each one with the most recent tag read from
  the tag queue, runs the "left" and "right" mu values through one
  HTMClassifier each, and publishes the reconciled classification.
  """

  # Number of consumed records after which classifier learning is turned off.
  _MAX_LEARNING_RECORDS = 1000

  def __init__(self,
               user_id,
               module_id,
               device_type,
               rmq_address,
               rmq_user,
               rmq_pwd):
    """
    Store the connection settings and build the routing keys.

    No connection is opened and no classifier is created here; call
    `initialize()` before `start()`.
    """
    # Running min / max of the raw (un-clipped) mu value per hemisphere.
    self.stats = {
      "left": {"min": None, "max": None},
      "right": {"min": None, "max": None}
    }
    self.module_id = module_id
    self.user_id = user_id
    self.device_type = device_type
    self.rmq_address = rmq_address
    self.rmq_user = rmq_user
    self.rmq_pwd = rmq_pwd

    # Created in initialize().
    self.classification_publisher = None
    self.mu_subscriber = None
    self.tag_subscriber = None
    self.tag_publisher = None

    # mu is keyed by device type; tag and classification by module id.
    self.routing_keys = {
      "mu": _ROUTING_KEY % (user_id, device_type, _MU),
      "tag": _ROUTING_KEY % (user_id, module_id, _TAG),
      "classification": _ROUTING_KEY % (user_id, module_id, _CLASSIFICATION)
    }

    self.start_time = int(time.time() * 1000)  # in ms
    # Tag used until a first tag message is read from the tag queue.
    self.last_tag = {"timestamp": self.start_time, "value": _CATEGORIES[1]}

    self.classifiers = {"left": None, "right": None}
    self.numRecords = 0
    self.learning_mode = True

  def initialize(self):
    """
    Initialize classifiers, publishers (tag and classification), and
    subscribers (mu and tag).
    """
    self.classifiers["left"] = HTMClassifier(network_config, _TRAINING_DATA,
                                             _CATEGORIES)
    self.classifiers["right"] = HTMClassifier(network_config, _TRAINING_DATA,
                                              _CATEGORIES)
    for classifier in self.classifiers.values():
      classifier.initialize()
      if _PRE_TRAIN:
        classifier.train(_TRAIN_SET_SIZE, partitions)

    credentials = (self.rmq_address, self.rmq_user, self.rmq_pwd)
    self.tag_subscriber = PikaSubscriber(*credentials)
    self.tag_publisher = PikaPublisher(*credentials)
    self.mu_subscriber = PikaSubscriber(*credentials)
    self.classification_publisher = PikaPublisher(*credentials)

    self.tag_subscriber.connect()
    self.tag_publisher.connect()
    self.mu_subscriber.connect()
    self.classification_publisher.connect()

    self.tag_subscriber.subscribe(self.routing_keys["tag"])
    self.tag_publisher.register(self.routing_keys["tag"])
    self.mu_subscriber.subscribe(self.routing_keys["mu"])
    self.classification_publisher.register(self.routing_keys["classification"])

  def start(self):
    """Start consuming mu messages; each is passed to `_tag_and_classify`."""
    _LOGGER.info("[Module %s] Starting Motor Imagery module. Routing keys: %s",
                 self.module_id, self.routing_keys)

    self.mu_subscriber.consume_messages(self.routing_keys["mu"],
                                        self._tag_and_classify)

  def _update_last_tag(self, last_tag):
    """
    Consume all tags in the queue and keep the last one (i.e. the most up
    to date).

    Returns `last_tag` unchanged when no tag message is waiting.
    """
    while True:
      _, _, body = self.tag_subscriber.get_one_message(
        self.routing_keys["tag"])
      if not body:
        return last_tag
      last_tag = json.loads(body)

  def _tag_and_classify(self, ch, method, properties, body):
    """
    Tag one mu message and run it through the classifiers.

    `body` is a JSON object read for the keys "timestamp", "left" and
    "right". `ch`, `method` and `properties` are accepted because the
    subscriber passes them to its callback; they are not used.
    """
    self.numRecords += 1
    # Single-argument print(...) behaves the same as the print statement.
    print(self.numRecords)
    # Announce the switch once, on the transition, instead of on every record.
    if self.learning_mode and self.numRecords > self._MAX_LEARNING_RECORDS:
      self.learning_mode = False
      print("=======LEARNING DISABLED!!!=========")

    self.last_tag = self._update_last_tag(self.last_tag)
    _LOGGER.debug("[Module %s] mu: %s | last_tag: %s",
                  self.module_id, body, self.last_tag)

    mu = json.loads(body)
    mu_timestamp = mu["timestamp"]
    # The same tag labels both hemispheres, so look it up once.
    tag_value = self.last_tag["value"]

    results = {}
    for hemisphere, classifier in self.classifiers.items():
      mu_value = mu[hemisphere]
      # The classifier sees the clipped value; stats track the raw one.
      results[hemisphere] = classifier.classify(
        input_data=np.clip(mu_value, _MU_MIN, _MU_MAX),
        target=tag_value,
        learning_is_on=self.learning_mode)
      self._update_stats(hemisphere, mu_value)

    _LOGGER.debug("Raw results: %s", results)

    classification_result = _reconcile_results(results["left"],
                                               results["right"])
    payload = [{"timestamp": mu_timestamp, "value": classification_result}]
    self.classification_publisher.publish(self.routing_keys["classification"],
                                          payload)

  def _update_stats(self, hemisphere, mu_value):
    """
    Fold `mu_value` into the running min / max kept for `hemisphere`.

    An unset extreme is stored as None. It is tested with `is None` so that
    a real extreme of 0 is not mistaken for "unset", and so that no
    comparison against None is ever made.
    """
    stats = self.stats[hemisphere]
    if stats["min"] is None or mu_value < stats["min"]:
      stats["min"] = mu_value
    if stats["max"] is None or mu_value > stats["max"]:
      stats["max"] = mu_value
class PreprocessingModule(object):
    """
    EEG preprocessing module.

    Consumes EEG messages, accumulates the raw samples, periodically refits
    the eye-blink remover on a background thread, and publishes the latest
    left / right mu values once the remover reports it is fitted.
    """

    # Running count at which the first fit of the eye-blink remover starts.
    _FIRST_FIT_COUNT = 5000
    # A refit is started each time the running count crosses a multiple of this.
    _REFIT_INTERVAL = 10000
    # Leading rows of the accumulated data left out of every fit
    # (presumably a warm-up period -- TODO confirm).
    _FIT_SKIP_ROWS = 1000

    def __init__(self, user_id, module_id, device_type, rmq_address, rmq_user, rmq_pwd, step_size):
        """
        Store the connection settings and build the routing keys.

        `step_size` is the amount added to the running count for every
        consumed message. No connection is opened here; call `initialize()`
        before `start()`.
        """
        self.user_id = user_id
        self.module_id = module_id
        self.device_type = device_type
        self.rmq_address = rmq_address
        self.rmq_user = rmq_user
        self.rmq_pwd = rmq_pwd

        # Created in initialize().
        self.eeg_subscriber = None
        self.mu_publisher = None

        self.routing_keys = {
            _EEG: _ROUTING_KEY % (user_id, device_type, _EEG),
            _MU: _ROUTING_KEY % (user_id, device_type, _MU),
        }

        self.preprocessor = None

        # Every raw sample received so far, one row per sample, 8 columns.
        # NOTE(review): this grows without bound for the life of the process.
        self.eeg_data = np.zeros((0, 8))
        self.count = 0

        self.eyeblinks_remover = EyeBlinksRemover()

        self.step_size = step_size

        self.started_fit = False

    def initialize(self):
        """Initialize the mu publisher and the EEG subscriber."""
        self.mu_publisher = PikaPublisher(self.rmq_address, self.rmq_user, self.rmq_pwd)
        self.eeg_subscriber = PikaSubscriber(self.rmq_address, self.rmq_user, self.rmq_pwd)

        self.eeg_subscriber.connect()
        self.mu_publisher.connect()

        self.mu_publisher.register(self.routing_keys[_MU])
        self.eeg_subscriber.subscribe(self.routing_keys[_EEG])

    def start(self):
        """Start consuming EEG messages; each is passed to `_preprocess`."""
        _LOGGER.info("[Module %s] Starting Preprocessing. Routing keys: %s", self.module_id, self.routing_keys)

        self.eeg_subscriber.consume_messages(self.routing_keys[_EEG], self._preprocess)

    def refit_ica(self):
        """
        Fit the eye-blink remover on the accumulated data in a new thread.

        The first `_FIT_SKIP_ROWS` rows are left out. `_preprocess` rebinds
        `self.eeg_data` to a new array on every message rather than mutating
        it, so the slice handed to the thread is not modified afterwards.
        """
        fit_data = self.eeg_data[self._FIT_SKIP_ROWS:]
        Thread(target=self.eyeblinks_remover.fit, args=(fit_data,)).start()

    def _should_refit(self):
        """
        Return True when a (re)fit should be started for the current count.

        A fit is due the first time the count reaches `_FIRST_FIT_COUNT`, and
        again whenever the latest increment carried the count across a
        multiple of `_REFIT_INTERVAL`. Testing for a crossing instead of
        `count % interval == 0` keeps the schedule intact when `step_size`
        does not divide the interval.
        """
        if not self.started_fit and self.count >= self._FIRST_FIT_COUNT:
            return True
        previous_count = self.count - self.step_size
        return self.count // self._REFIT_INTERVAL > previous_count // self._REFIT_INTERVAL

    def _preprocess(self, ch, method, properties, body):
        """
        Handle one EEG message.

        `body` is JSON that decodes to a sequence of samples; the last
        sample's "timestamp" is used for the published mu message. `ch`,
        `method` and `properties` are accepted because the subscriber passes
        them to its callback; they are not used.
        """
        eeg = json.loads(body)

        # Convert once and reuse for both the history and the transform
        # (assumes get_raw has no side effects -- TODO confirm).
        raw = get_raw(eeg)
        self.eeg_data = np.vstack([self.eeg_data, raw])

        self.count += self.step_size

        print(self.count)

        if self._should_refit():
            _LOGGER.info("refitting...")
            self.started_fit = True
            self.refit_ica()

        timestamp = eeg[-1]["timestamp"]

        eeg = from_raw(self.eyeblinks_remover.transform(raw))

        process = preprocess_stft(eeg, _METADATA)

        # Only the most recent value of each side is published.
        data = {
            "timestamp": timestamp,
            "left": process["left"][-1],
            "right": process["right"][-1],
        }

        _LOGGER.debug("--> mu: %s", data)
        # Nothing is published until the eye-blink remover reports it is fitted.
        if self.eyeblinks_remover.fitted:
            self.mu_publisher.publish(self.routing_keys[_MU], data)