def create_hmm(qsr_seq, filename, qsr_type="qtch"):
    """Create an HMM from ``qsr_seq`` via the HMM service and save it to disk.

    The service reply is written verbatim to ``<filename>.hmm``.

    :param qsr_seq: state chains, passed unchanged to ``HMMRepRequestCreate``.
    :param filename: output path *without* the ``.hmm`` extension.
    :param qsr_type: QSR type of ``qsr_seq``. Defaults to ``"qtch"``, the value
        that used to be hard-coded, so existing callers behave exactly as before.
    """
    r = ROSClient()
    # The service returns a pair; only the second element (the HMM) is needed.
    _, d = r.call_service(
        HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=qsr_type)
    )
    with open(filename + ".hmm", 'w') as f:
        f.write(d)
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM service: create, sample, loglikelihood.

    In this API version the HMM is sent to the service as raw file text via
    the ``xml`` request field, and ``call_service`` returns a pair whose second
    element is the payload.
    """

    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    # Expected outputs, keyed by qsr_type.
    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---']],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }
    correct_hashsum = {
        "qtcb": "3fb65b50d0f7631a300132e8bca9ca13",
        "qtcc": "dbf1529cb0b0c90aaebbe7eafe0e9b05",
        "qtcbc": "0a3acf7b48c4c1155931442c86317ce4",
        "rcc3": "6d5bcc6c44d9b1120c738efa1994a40a"
    }
    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    # ----- service wrappers -------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type):
        """Create an HMM from the JSON state chains in ``qsr_file``."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        _, d = self.r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=qsr_type)
        )
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample of at most 10 states from the HMM in ``hmm_file``."""
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, s = self.r.call_service(
            HMMRepRequestSample(qsr_type=qsr_type, xml=hmm, max_length=10, num_samples=1)
        )
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the loglikelihood of ``qsr_file`` under ``hmm_file``, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, loglik = self.r.call_service(
            HMMRepRequestLogLikelihood(qsr_type=qsr_type, xml=hmm, qsr_seq=qsr_seq)
        )
        return round(loglik, 5)

    def _to_strings(self, array):
        """Return the single value of each one-entry dict in ``array``."""
        # next(iter(...)) works on both Python 2 lists and Python 3 views.
        return [next(iter(x.values())) for x in array]

    # ----- shared assertions ------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(hashlib.md5(res).hexdigest(), self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ----- tests ------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # RCC3 samples are random, so only check the shape of the result.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertGreater(len(res), 0)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse]) # Parse arguments args = parser.parse_args() rospy.init_node("ros_client") r = ROSClient() if args.action == "create": qsr_seq = load_files(args.input) d = r.call_service( HMMRepRequestCreate( qsr_seq=qsr_seq, qsr_type=args.qsr_type, pseudo_transitions=args.pseudo_transitions, lookup_table=load_json_file(args.lookup) if args.lookup != "" else None, transition_matrix=load_json_file(args.trans) if args.trans != "" else None, emission_matrix=load_json_file(args.emi) if args.emi != "" else None, start_at_zero=args.start_at_zero ) ) with open(args.output, 'w') as f: json.dump(d, f) elif args.action == "sample": with open(args.input, 'r') as f: hmm = json.load(f) s = r.call_service( HMMRepRequestSample( qsr_type=args.qsr_type, dictionary=hmm, max_length=args.max_length, num_samples=args.num_samples,
class StatePredictor(object):
    """ROS node that predicts the next QTC state of tracked humans.

    For every human uuid seen on the QTC topic, a particle filter is created
    through ``ROSClient``. Each new observation updates that filter, and the
    most probable successor state from the loaded ``.rules`` files is
    published as a ``QTCPredictionArray``. People-tracker poses are
    republished as markers coloured by each person's last classified model.
    """

    # NOTE(review): not referenced anywhere in this class.
    __interaction_types = ["passby", "pathcrossing"]
    # Class-level defaults kept for backward compatibility. __init__ gives each
    # instance its own containers so instances no longer share mutable state.
    model = None
    rules = {}

    def __init__(self, name):
        """Load models and rules from ``~model_dir`` and set up ROS I/O.

        Each ``*.hmm`` file in ``~model_dir`` becomes a prediction model, added
        to a PfModel together with the QTCh observation model. Each ``*.rules``
        JSON file is loaded into ``self.rules`` under the file's base name (the
        part before the first dot).

        :param name: node name, used only in the start-up log message.
        """
        self.isDebug = False
        # Per-instance state (previously mutable class attributes).
        self.__filters = {}                  # uuid -> last-seen time in seconds
        self.__classification_results = {}   # uuid -> model name from the filter
        self.rules = {}
        rospy.loginfo("Starting %s ..." % name)
        self.client = ROSClient()
        qmc = QTCModelCreation()
        obs = qmc.create_observation_model(qtc_type=qmc.qtc_types.qtch, start_end=True)
        self.lookup = qmc.create_states(qtc_type=qmc.qtc_types.qtch, start_end=True)
        hmc = HMMModelCreation()
        m = PfModel()
        model_dir = rospy.get_param("~model_dir")
        for f in os.listdir(model_dir):
            filename = model_dir + '/' + f
            model_name = f.split('.')[0]
            if f.endswith(".hmm"):
                rospy.loginfo("Creating prediction model from: %s", filename)
                pred = hmc.create_prediction_model(input=filename)
                m.add_model(model_name, pred, obs)
            elif f.endswith(".rules"):
                with open(filename) as fd:
                    rospy.loginfo("Reading rules from: %s", filename)
                    # Necessary due to old rule format:
                    self.rules[model_name] = self._create_proper_qtc_keys(json.load(fd))
        self.model = m.get()

        visualisation = rospy.get_param("~visualisation")
        self.visualisation_colors = visualisation['models']
        default = visualisation['default']
        self.default_color = ColorRGBA(
            a=1.0,
            r=default['r'] / 255.,
            g=default['g'] / 255.,
            b=default['b'] / 255.
        )

        self.pub = rospy.Publisher("~prediction_array", QTCPredictionArray, queue_size=10)
        self.markpub = rospy.Publisher("~marker_array", MarkerArray, queue_size=10)
        rospy.Subscriber(
            rospy.get_param("~qtc_topic", "/online_qtc_creator/qtc_array"),
            QTCArray, self.callback, queue_size=1
        )
        rospy.Subscriber(
            rospy.get_param("~ppl_topic", "/people_tracker/positions"),
            PeopleTracker, self.people_callback, queue_size=1
        )
        rospy.loginfo("... all done")

    def _create_proper_qtc_keys(self, dictionary):
        """Recursively convert old-format numeric QTC keys to symbolic ones.

        Example: ``'-1,0,1,9'`` becomes ``'-,0,+'``. ``-1`` maps to ``-``,
        ``1`` maps to ``+``, and ``9`` is dropped. Presumably 9 marks an unused
        component; ``_create_qtc_string`` also filters it out. Nested dict
        values are converted the same way.
        """
        ret = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                v = self._create_proper_qtc_keys(v)
            symbols = k.replace('-1', '-').replace('1', '+').replace('9', '').replace(',', '')
            # After stripping commas, each character is one QTC symbol; re-join them.
            ret[','.join(symbols)] = v
        return ret

    def callback(self, msg):
        """Update each human's particle filter and publish the predictions.

        If any filter update returns no state, or the state has no entry in the
        rules, the whole message is abandoned and nothing is published for it.
        This matches the original "aborting" semantics.
        """
        out = QTCPredictionArray()
        out.header = msg.header
        for q in msg.qtc:
            m = QTCPrediction()
            m.uuid = q.uuid
            if q.uuid not in self.__filters:
                self.client.call_service(
                    PfRepRequestCreate(
                        num_particles=1000,
                        models=self.model,
                        state_lookup_table=self.lookup,
                        uuid=q.uuid,
                        ensure_particle_per_state=True,
                        debug=self.isDebug,
                        starvation_factor=0.1
                    )
                )
            # Refresh the last-seen time on every observation. __decay should
            # only remove filters for humans that have disappeared, not every
            # filter older than decay_time.
            self.__filters[q.uuid] = msg.header.stamp.to_sec()

            # Build the observation from the latest goal/robot QTC states:
            # goal components 1 and 3, robot component 1, and robot component 3
            # when the robot state has four components.
            qtc_robot = json.loads(q.qtc_robot_human)[-1].split(',')
            qtc_goal = json.loads(q.qtc_goal_human)[-1].split(',')
            qtc = [qtc_goal[1], qtc_goal[3], qtc_robot[1]]
            if len(qtc_robot) == 4:
                qtc.append(qtc_robot[3])
            qtc = ','.join(qtc)

            qtc_state = self.client.call_service(
                PfRepRequestUpdate(uuid=q.uuid, observation=qtc, debug=self.isDebug)
            )
            if qtc_state is None:
                rospy.logwarn("[" + rospy.get_name() + "]: " + str(qtc_state) + " state reported, aborting" )
                return
            # qtc_state[0] is the current state string, qtc_state[2] the model name.
            self.__classification_results[q.uuid] = qtc_state[2]

            try:
                successors = self.rules[qtc_state[2]][qtc_state[0]]
            except KeyError as e:
                rospy.logwarn("%s not in rules" % e)
                return

            pred = qtc_state[0].split(',')
            rospy.logdebug("Current state: %s" % pred)
            # Most probable successor. On ties, max() keeps the first key in
            # iteration order, just like the previous values().index() lookup.
            prediction = max(successors, key=successors.get).split(',')
            rospy.logdebug("Most likely successor: %s" % prediction)
            prediction = [prediction[0], pred[2]] \
                if len(pred) < 4 else [prediction[0], pred[2], prediction[1], pred[3]]
            prediction = ','.join(prediction)
            rospy.logdebug("Prediction: %s" % prediction)
            m.qtc_serialised = json.dumps(prediction)
            out.qtc.append(m)
        self.pub.publish(out)
        self.__decay(self.__filters, decay_time=10.)

    def _create_qtc_string(self, qtc):
        """Render a numeric QTC vector as a symbolic string, dropping 9 entries.

        Example: ``[-1, 0, 1, 9]`` becomes ``'-,0,+'``.
        """
        qtc = np.array(qtc)
        qtc = qtc[qtc != 9.]
        return ','.join(map(str, qtc.astype(int))).replace('-1', '-').replace('1', '+')

    def __decay(self, filters, decay_time=60.):
        """Remove particle filters whose uuid was last seen over ``decay_time`` seconds ago.

        :param filters: dict mapping uuid to last-seen time in seconds; expired
            entries are deleted in place and the remote filter is removed.
        :param decay_time: maximum age in seconds.
        """
        now = rospy.Time.now().to_sec()
        # Iterate over a snapshot because entries are deleted during the loop.
        for k in list(filters):
            if filters[k] + decay_time < now:
                rospy.logdebug("[" + rospy.get_name() + "]: " + "Deleting particle filter: " + str(k) + " last seen " + str(filters[k]))
                self.client.call_service(PfRepRequestRemove(uuid=k))
                del filters[k]
                # Drop the stale classification too, so this dict does not grow
                # without bound. The person falls back to the default colour.
                self.__classification_results.pop(k, None)

    def people_callback(self, msg):
        """Republish people-tracker poses as markers coloured by classification."""
        poses, colors = [], []
        for uuid, pose in zip(msg.uuids, msg.poses):
            poses.append(pose)
            try:
                c = self.visualisation_colors[self.__classification_results[uuid]]['color']
                colors.append(ColorRGBA(a=1.0, r=c['r'] / 255., g=c['g'] / 255., b=c['b'] / 255.))
            except KeyError:
                # Unclassified person, or a model without a configured colour.
                colors.append(self.default_color)
        self.markpub.publish(
            mc.marker_array_from_people_tracker_msg(
                poses=poses,
                target_frame=msg.header.frame_id,
                color=colors
            )
        )
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM service: create, sample, loglikelihood.

    In this API version, HMMs are exchanged as JSON-decoded objects through the
    ``dictionary`` request field, and ``call_service`` returns the payload
    directly.
    """

    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    # Expected outputs, keyed by qsr_type.
    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[
            u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---'
        ]],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }
    correct_hashsum = {
        "qtcb": "de1ea75d1b0d6c9ff8249a24583fedb9",
        "qtcc": "5d74fb27e53ba00f84014d3be70e2740",
        "qtcbc": "f73e2c85f1447a5109d6ab3d5201fb76",
        "rcc3": "402c53a1cc004a5f518d0b607eb7ac38"
    }
    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    # ----- service wrappers -------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type, start_at_zero=True):
        """Create an HMM from the JSON state chains in ``qsr_file``."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        return self.r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=qsr_type,
                                start_at_zero=start_at_zero))

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample of at most 10 states from the JSON HMM in ``hmm_file``."""
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        return self.r.call_service(
            HMMRepRequestSample(qsr_type=qsr_type, dictionary=hmm,
                                max_length=10, num_samples=1))

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the loglikelihood of ``qsr_file`` under ``hmm_file``, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        loglik = self.r.call_service(
            HMMRepRequestLogLikelihood(qsr_type=qsr_type, dictionary=hmm,
                                       qsr_seq=qsr_seq))
        return round(loglik, 5)

    def _to_strings(self, array):
        """Return the single value of each one-entry dict in ``array``."""
        # next(iter(...)) works on both Python 2 lists and Python 3 views.
        return [next(iter(x.values())) for x in array]

    # ----- shared assertions ------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        res = self._create_hmm(qsr_file, qsr_type)
        # json.dumps output is pure ASCII, so encoding yields the same bytes on
        # Python 2 and makes md5 accept it on Python 3.
        digest = hashlib.md5(json.dumps(res).encode('utf-8')).hexdigest()
        self.assertEqual(digest, self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ----- tests ------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # RCC3 samples are random, so only check the shape of the result.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertGreater(len(res), 0)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
log_parse = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,add_help=False) log_parse.add_argument('-i', '--input', help="the xml file containing the HMM", type=str, required=True) log_parse.add_argument('-q', '--qsr_seq', help="reads a file containing state chains", type=str, required=True) subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse]) # Parse arguments args = parser.parse_args() rospy.init_node("ros_client") r = ROSClient() if args.action == "create": qsr_seq = load_files(args.input) q, d = r.call_service( HMMRepRequestCreate( qsr_seq=qsr_seq, qsr_type=args.qsr_type ) ) with open(args.output, 'w') as f: f.write(d) elif args.action == "sample": with open(args.input, 'r') as f: hmm = f.read() q, s = r.call_service( HMMRepRequestSample( qsr_type=args.qsr_type, xml=hmm, max_length=args.max_length, num_samples=args.num_samples ) ) try:
# Parse arguments args = parser.parse_args() rospy.init_node("ros_client") r = ROSClient() if args.action == "create": models = load_files(args.model) states = load_lookup(args.states) d = r.call_service( PfRepRequestCreate( num_particles=args.num_particles, models=models, state_lookup_table=states, starvation_factor=args.starvation_factor, ensure_particle_per_state=args.ensure_particle_per_state, debug=args.debug, uuid=args.uuid if args.uuid != "" else None ) ) print d elif args.action == "predict": p = r.call_service( PfRepRequestPredict( uuid=args.uuid, num_steps=args.num_steps, debug=args.debug ) )
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM service: create, sample, loglikelihood.

    In this API version the HMM is sent to the service as raw file text via
    the ``xml`` request field, and ``call_service`` returns a pair whose second
    element is the payload.
    """

    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    # Expected outputs, keyed by qsr_type.
    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[
            u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---'
        ]],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }
    correct_hashsum = {
        "qtcb": "3fb65b50d0f7631a300132e8bca9ca13",
        "qtcc": "dbf1529cb0b0c90aaebbe7eafe0e9b05",
        "qtcbc": "0a3acf7b48c4c1155931442c86317ce4",
        "rcc3": "6d5bcc6c44d9b1120c738efa1994a40a"
    }
    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    # ----- service wrappers -------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type):
        """Create an HMM from the JSON state chains in ``qsr_file``."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        _, d = self.r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=qsr_type))
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample of at most 10 states from the HMM in ``hmm_file``."""
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, s = self.r.call_service(
            HMMRepRequestSample(qsr_type=qsr_type, xml=hmm,
                                max_length=10, num_samples=1))
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the loglikelihood of ``qsr_file`` under ``hmm_file``, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, loglik = self.r.call_service(
            HMMRepRequestLogLikelihood(qsr_type=qsr_type, xml=hmm,
                                       qsr_seq=qsr_seq))
        return round(loglik, 5)

    def _to_strings(self, array):
        """Return the single value of each one-entry dict in ``array``."""
        # next(iter(...)) works on both Python 2 lists and Python 3 views.
        return [next(iter(x.values())) for x in array]

    # ----- shared assertions ------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(hashlib.md5(res).hexdigest(), self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ----- tests ------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # RCC3 samples are random, so only check the shape of the result.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertGreater(len(res), 0)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
class TestPf(unittest.TestCase):
    """Integration tests for the particle-filter service.

    The numbered test names make unittest's default alphabetical ordering run
    create, predict, update, list and remove in sequence. Tests 2-5 operate on
    the filters created in ``test1_pf_create``.
    """

    CROSSING_OBS = find_resource(PKG, 'crossing.obs')[0]
    CROSSING_PRED = find_resource(PKG, 'crossing.pred')[0]
    PASSBY_OBS = find_resource(PKG, 'passby.obs')[0]
    PASSBY_PRED = find_resource(PKG, 'passby.pred')[0]
    STATES = find_resource(PKG, 'qtch.states')[0]
    # Filter uuids "0".."4". A real list, not a lazy map, so it can be
    # re-iterated and compared against lists on Python 3 too.
    filters = list(map(str, range(5)))

    def __init__(self, *args):
        # Name the class explicitly: super(self.__class__, ...) recurses
        # forever if this class is ever subclassed.
        super(TestPf, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    def _create_model(self, obs, pred):
        """Build a PfModel from observation (``obs``) and prediction (``pred``) matrix files.

        Each model's name is its file's base name up to the first dot.
        """
        m = PfModel()
        for f in pred:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_prediction_matrix(name, np.loadtxt(a))
        for f in obs:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_observation_matrix(name, np.loadtxt(a))
        return m.get()

    def _load_lookup(self, filename):
        """Load the JSON state lookup table from ``filename``."""
        with open(filename, 'r') as f:
            return json.load(f)

    def _create_pf(self, pred, obs, states, num_particles=1000, uuid="",
                   starvation_factor=0.1, ensure_particle_per_state=True, debug=True):
        """Create a particle filter on the service and return its reply."""
        # Keyword arguments: _create_model's positional order is (obs, pred),
        # the reverse of this method's, and a positional call swaps them.
        models = self._create_model(obs=obs, pred=pred)
        states = self._load_lookup(states)
        return self.r.call_service(
            PfRepRequestCreate(
                num_particles=num_particles,
                models=models,
                state_lookup_table=states,
                starvation_factor=starvation_factor,
                ensure_particle_per_state=ensure_particle_per_state,
                debug=debug,
                uuid=uuid if uuid != "" else None
            )
        )

    def _predict(self, uuid, num_steps=5, debug=True):
        return self.r.call_service(
            PfRepRequestPredict(uuid=uuid, num_steps=num_steps, debug=debug))

    def _update(self, uuid, obs, debug=True):
        return self.r.call_service(
            PfRepRequestUpdate(uuid=uuid, observation=obs, debug=debug))

    def _list(self):
        return self.r.call_service(PfRepRequestList())

    def _remove(self, uuid):
        return self.r.call_service(PfRepRequestRemove(uuid=uuid))

    def test1_pf_create(self):
        res = [
            self._create_pf(
                pred=[self.CROSSING_PRED, self.PASSBY_PRED],
                obs=[self.CROSSING_OBS, self.PASSBY_OBS],
                states=self.STATES,
                uuid=uuid
            )
            for uuid in self.filters
        ]
        self.assertEqual(list(map(str, res)), self.filters)

    def test2_pf_predict(self):
        res = [self._predict(uuid=uuid, num_steps=1)[0] for uuid in self.filters]
        self.assertEqual(len(res), len(self.filters))

    def test3_pf_update(self):
        res = [self._update(uuid=uuid, obs='-,-,-') for uuid in self.filters]
        self.assertEqual(len(res), len(self.filters))

    def test4_pf_list(self):
        res = self._list()
        self.assertEqual(len(res), len(self.filters))

    def test5_pf_remove(self):
        for uuid in self.filters:
            self._remove(uuid)
        self.assertFalse(len(self._list()))
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM service: create, sample, loglikelihood.

    In this API version, HMMs are exchanged as JSON-decoded objects through the
    ``dictionary`` request field, and ``call_service`` returns the payload
    directly.
    """

    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    # Expected outputs, keyed by qsr_type.
    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---']],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }
    correct_hashsum = {
        "qtcb": "de1ea75d1b0d6c9ff8249a24583fedb9",
        "qtcc": "5d74fb27e53ba00f84014d3be70e2740",
        "qtcbc": "f73e2c85f1447a5109d6ab3d5201fb76",
        "rcc3": "402c53a1cc004a5f518d0b607eb7ac38"
    }
    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    # ----- service wrappers -------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type, start_at_zero=True):
        """Create an HMM from the JSON state chains in ``qsr_file``."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        return self.r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=qsr_type,
                start_at_zero=start_at_zero
            )
        )

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample of at most 10 states from the JSON HMM in ``hmm_file``."""
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        return self.r.call_service(
            HMMRepRequestSample(
                qsr_type=qsr_type,
                dictionary=hmm,
                max_length=10,
                num_samples=1
            )
        )

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the loglikelihood of ``qsr_file`` under ``hmm_file``, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        loglik = self.r.call_service(
            HMMRepRequestLogLikelihood(
                qsr_type=qsr_type,
                dictionary=hmm,
                qsr_seq=qsr_seq
            )
        )
        return round(loglik, 5)

    def _to_strings(self, array):
        """Return the single value of each one-entry dict in ``array``."""
        # next(iter(...)) works on both Python 2 lists and Python 3 views.
        return [next(iter(x.values())) for x in array]

    # ----- shared assertions ------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        res = self._create_hmm(qsr_file, qsr_type)
        # json.dumps output is pure ASCII, so encoding yields the same bytes on
        # Python 2 and makes md5 accept it on Python 3.
        digest = hashlib.md5(json.dumps(res).encode('utf-8')).hexdigest()
        self.assertEqual(digest, self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ----- tests ------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # RCC3 samples are random, so only check the shape of the result.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertGreater(len(res), 0)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
class TestPf(unittest.TestCase):
    """Integration tests for the particle-filter service.

    The numbered test names make unittest's default alphabetical ordering run
    create, predict, update, list and remove in sequence. Tests 2-5 operate on
    the filters created in ``test1_pf_create``.
    """

    CROSSING_OBS = find_resource(PKG, 'crossing.obs')[0]
    CROSSING_PRED = find_resource(PKG, 'crossing.pred')[0]
    PASSBY_OBS = find_resource(PKG, 'passby.obs')[0]
    PASSBY_PRED = find_resource(PKG, 'passby.pred')[0]
    STATES = find_resource(PKG, 'qtch.states')[0]
    # Filter uuids "0".."4". A real list, not a lazy map, so it can be
    # re-iterated and compared against lists on Python 3 too.
    filters = list(map(str, range(5)))

    def __init__(self, *args):
        # Name the class explicitly: super(self.__class__, ...) recurses
        # forever if this class is ever subclassed.
        super(TestPf, self).__init__(*args)
        rospy.init_node(NAME)
        self.r = ROSClient()

    def _create_model(self, obs, pred):
        """Build a PfModel from observation (``obs``) and prediction (``pred``) matrix files.

        Each model's name is its file's base name up to the first dot.
        """
        m = PfModel()
        for f in pred:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_prediction_matrix(name, np.loadtxt(a))
        for f in obs:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_observation_matrix(name, np.loadtxt(a))
        return m.get()

    def _load_lookup(self, filename):
        """Load the JSON state lookup table from ``filename``."""
        with open(filename, 'r') as f:
            return json.load(f)

    def _create_pf(self, pred, obs, states, num_particles=1000, uuid="",
                   starvation_factor=0.1, ensure_particle_per_state=True, debug=True):
        """Create a particle filter on the service and return its reply."""
        # Keyword arguments: _create_model's positional order is (obs, pred),
        # the reverse of this method's, and a positional call swaps them.
        models = self._create_model(obs=obs, pred=pred)
        states = self._load_lookup(states)
        return self.r.call_service(
            PfRepRequestCreate(
                num_particles=num_particles,
                models=models,
                state_lookup_table=states,
                starvation_factor=starvation_factor,
                ensure_particle_per_state=ensure_particle_per_state,
                debug=debug,
                uuid=uuid if uuid != "" else None))

    def _predict(self, uuid, num_steps=5, debug=True):
        return self.r.call_service(
            PfRepRequestPredict(uuid=uuid, num_steps=num_steps, debug=debug))

    def _update(self, uuid, obs, debug=True):
        return self.r.call_service(
            PfRepRequestUpdate(uuid=uuid, observation=obs, debug=debug))

    def _list(self):
        return self.r.call_service(PfRepRequestList())

    def _remove(self, uuid):
        return self.r.call_service(PfRepRequestRemove(uuid=uuid))

    def test1_pf_create(self):
        res = [
            self._create_pf(pred=[self.CROSSING_PRED, self.PASSBY_PRED],
                            obs=[self.CROSSING_OBS, self.PASSBY_OBS],
                            states=self.STATES,
                            uuid=uuid)
            for uuid in self.filters
        ]
        self.assertEqual(list(map(str, res)), self.filters)

    def test2_pf_predict(self):
        res = [self._predict(uuid=uuid, num_steps=1)[0] for uuid in self.filters]
        self.assertEqual(len(res), len(self.filters))

    def test3_pf_update(self):
        res = [self._update(uuid=uuid, obs='-,-,-') for uuid in self.filters]
        self.assertEqual(len(res), len(self.filters))

    def test4_pf_list(self):
        res = self._list()
        self.assertEqual(len(res), len(self.filters))

    def test5_pf_remove(self):
        for uuid in self.filters:
            self._remove(uuid)
        self.assertFalse(len(self._list()))
'--qsr_seq', help="reads a file containing state chains", type=str, required=True) subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse]) # Parse arguments args = parser.parse_args() rospy.init_node("ros_client") r = ROSClient() if args.action == "create": qsr_seq = load_files(args.input) q, d = r.call_service( HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=args.qsr_type)) with open(args.output, 'w') as f: f.write(d) elif args.action == "sample": with open(args.input, 'r') as f: hmm = f.read() q, s = r.call_service( HMMRepRequestSample(qsr_type=args.qsr_type, xml=hmm, max_length=args.max_length, num_samples=args.num_samples)) try: with open(args.output, 'w') as f: json.dump(s, f) except TypeError:
subparsers.add_parser('remove', parents=[general, remove_parse]) # Parse arguments args = parser.parse_args() rospy.init_node("ros_client") r = ROSClient() if args.action == "create": models = load_files(args.model) states = load_lookup(args.states) d = r.call_service( PfRepRequestCreate( num_particles=args.num_particles, models=models, state_lookup_table=states, starvation_factor=args.starvation_factor, ensure_particle_per_state=args.ensure_particle_per_state, debug=args.debug, uuid=args.uuid if args.uuid != "" else None)) print d elif args.action == "predict": p = r.call_service( PfRepRequestPredict(uuid=args.uuid, num_steps=args.num_steps, debug=args.debug)) print p elif args.action == "update": p = r.call_service( PfRepRequestUpdate(uuid=args.uuid,