示例#1
0
def create_hmm(qsr_seq, filename):
    """Create a QTC_H hidden Markov model from *qsr_seq* and save it to disk.

    The HMM service is called with ``qsr_type="qtch"``; the second element of
    the two-element service reply is written verbatim to ``filename + ".hmm"``.
    """
    client = ROSClient()
    request = HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type="qtch")
    _, hmm_description = client.call_service(request)
    with open(filename + ".hmm", 'w') as hmm_file:
        hmm_file.write(hmm_description)
示例#2
0
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM representation service (XML HMM API).

    For every QSR type the tests check model creation (md5 of the returned HMM
    text), sampling (an exact expected sequence, or only a non-empty list for
    rcc3) and the log-likelihood of a QSR file under a stored HMM (rounded to
    5 decimal places).
    """
    QTCB_SAMPLE_TEST_HMM  = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM  = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM  = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM  = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM         = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR              = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR              = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR             = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR              = find_resource(PKG, 'rcc3.qsr')[0]

    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---']],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }

    correct_hashsum = {
        "qtcb": "3fb65b50d0f7631a300132e8bca9ca13",
        "qtcc": "dbf1529cb0b0c90aaebbe7eafe0e9b05",
        "qtcbc": "0a3acf7b48c4c1155931442c86317ce4",
        "rcc3": "6d5bcc6c44d9b1120c738efa1994a40a"
    }

    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    # ------------------------------------------------------------------
    # Service wrappers
    # ------------------------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type):
        """Create an HMM from the JSON QSR sequence in *qsr_file*; return the HMM text."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        _, d = self.r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=qsr_type
            )
        )
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample (max length 10) from the HMM text stored in *hmm_file*."""
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, s = self.r.call_service(
            HMMRepRequestSample(
                qsr_type=qsr_type,
                xml=hmm,
                max_length=10,
                num_samples=1
            )
        )
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the log-likelihood of *qsr_file* under *hmm_file*, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, loglikelihood = self.r.call_service(
            HMMRepRequestLogLikelihood(
                qsr_type=qsr_type,
                xml=hmm,
                qsr_seq=qsr_seq
            )
        )
        return round(loglikelihood, 5)

    def _to_strings(self, array):
        """Return one value per dict in *array* (the first in iteration order)."""
        # next(iter(...)) works on Python 2 value lists and Python 3 dict views.
        return [next(iter(x.values())) for x in array]

    def _md5_hexdigest(self, data):
        """Return the hex md5 digest of *data*, utf-8 encoding text first."""
        # hashlib only accepts bytes on Python 3; byte strings pass through as-is.
        if isinstance(data, type(u'')):
            data = data.encode('utf-8')
        return hashlib.md5(data).hexdigest()

    # ------------------------------------------------------------------
    # Shared checks
    # ------------------------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        """Assert the HMM created from *qsr_file* has the recorded md5."""
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(self._md5_hexdigest(res), self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        """Assert sampling *hmm_file* yields the recorded sample."""
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Assert the rounded log-likelihood equals the recorded value."""
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # rcc3 has no recorded sample; only require a non-empty list.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertTrue(res)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
    subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse])

    # Parse arguments
    args = parser.parse_args()

    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        qsr_seq = load_files(args.input)
        d = r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=args.qsr_type,
                pseudo_transitions=args.pseudo_transitions,
                lookup_table=load_json_file(args.lookup) if args.lookup != "" else None,
                transition_matrix=load_json_file(args.trans) if args.trans != "" else None,
                emission_matrix=load_json_file(args.emi) if args.emi != "" else None,
                start_at_zero=args.start_at_zero
            )
        )
        with open(args.output, 'w') as f: json.dump(d, f)

    elif args.action == "sample":
        with open(args.input, 'r') as f: hmm = json.load(f)
        s = r.call_service(
            HMMRepRequestSample(
                qsr_type=args.qsr_type,
                dictionary=hmm,
                max_length=args.max_length,
                num_samples=args.num_samples,
class StatePredictor(object):
    """Predict the next QTC state of tracked humans with per-human particle filters.

    On construction it does the following:
    - Loads ``*.hmm`` prediction models and ``*.rules`` files from the directory
      given by the ``~model_dir`` parameter.
    - Creates the ``~prediction_array`` and ``~marker_array`` publishers.
    - Subscribes to the QTC and people-tracker topics.

    Each QTC message does the following:
    - Updates one particle filter per human UUID.
    - Looks up the most probable successor state in the rules of the classified
      model.
    - Publishes the predictions.

    People-tracker messages are republished as markers coloured by each
    human's latest classification.
    """
    __interaction_types = ["passby", "pathcrossing"]

    # Class-level defaults kept for compatibility; __init__ shadows both with
    # per-instance values so instances never share (or mutate) the same dict.
    model = None
    rules = {}

    def __init__(self, name):
        self.isDebug = False

        # Per-instance state. This must be set before the subscribers below,
        # whose callbacks may fire as soon as they are registered.
        self.__filters = {}  # uuid -> time (sec) of the last message for that uuid
        self.__classification_results = {}  # uuid -> classification reported by the filter
        self.rules = {}  # model name -> rule table loaded from "<model>.rules"

        rospy.loginfo("Starting %s ..." % name)
        self.client = ROSClient()
        qmc = QTCModelCreation()
        obs = qmc.create_observation_model(qtc_type=qmc.qtc_types.qtch, start_end=True)
        self.lookup = qmc.create_states(qtc_type=qmc.qtc_types.qtch, start_end=True)
        hmc = HMMModelCreation()
        m = PfModel()
        model_dir = rospy.get_param("~model_dir")
        for f in os.listdir(model_dir):
            filename = os.path.join(model_dir, f)
            model_name = f.split('.')[0]
            if f.endswith(".hmm"):
                rospy.loginfo("Creating prediction model from: %s", filename)
                pred = hmc.create_prediction_model(input=filename)
                m.add_model(model_name, pred, obs)
            elif f.endswith(".rules"):
                with open(filename) as fd:
                    rospy.loginfo("Reading rules from: %s", filename)
                    # Necessary due to old rule format:
                    self.rules[model_name] = self._create_proper_qtc_keys(json.load(fd))

        self.model = m.get()

        visualisation = rospy.get_param("~visualisation")
        self.visualisation_colors = visualisation['models']
        self.default_color = ColorRGBA(
            a=1.0,
            r=visualisation['default']['r']/255.,
            g=visualisation['default']['g']/255.,
            b=visualisation['default']['b']/255.
        )

        self.pub = rospy.Publisher("~prediction_array", QTCPredictionArray, queue_size=10)
        self.markpub = rospy.Publisher("~marker_array", MarkerArray, queue_size=10)
        rospy.Subscriber(rospy.get_param("~qtc_topic", "/online_qtc_creator/qtc_array"), QTCArray, self.callback, queue_size=1)
        rospy.Subscriber(rospy.get_param("~ppl_topic", "/people_tracker/positions"), PeopleTracker, self.people_callback, queue_size=1)
        rospy.loginfo("... all done")

    def _create_proper_qtc_keys(self, dictionary):
        """Recursively rewrite old-format rule keys into comma-separated QTC symbols.

        Each key is converted as follows: ``-1`` becomes ``-``, ``1`` becomes
        ``+``, and ``9`` and commas are dropped. The remaining characters are
        then joined with commas, e.g. ``"-1,0,1"`` becomes ``"-,0,+"``. Nested
        dict values are converted too; other values are kept unchanged.
        """
        ret = {}
        for k, v in dictionary.items():
            if isinstance(v, dict):
                v = self._create_proper_qtc_keys(v)
            ret[','.join(k.replace('-1', '-').replace('1', '+').replace('9', '').replace(',', ''))] = v
        return ret

    def callback(self, msg):
        """Update each human's particle filter and publish next-state predictions.

        A missing filter state or a state without rules aborts the whole
        message: nothing is published and stale filters are not decayed.
        """
        out = QTCPredictionArray()
        out.header = msg.header
        for q in msg.qtc:
            m = QTCPrediction()
            m.uuid = q.uuid
            if q.uuid not in self.__filters:
                self.client.call_service(
                    PfRepRequestCreate(
                        num_particles=1000,
                        models=self.model,
                        state_lookup_table=self.lookup,
                        uuid=q.uuid,
                        ensure_particle_per_state=True,
                        debug=self.isDebug,
                        starvation_factor=0.1
                    )
                )
            self.__filters[q.uuid] = msg.header.stamp.to_sec()

            qtc_robot = json.loads(q.qtc_robot_human)[-1].split(',')
            qtc_goal = json.loads(q.qtc_goal_human)[-1].split(',')

            # Observation: components 1 and 3 of the goal relation, then the
            # corresponding components of the robot relation (3 only if present).
            qtc = [qtc_goal[1], qtc_goal[3], qtc_robot[1]]
            if len(qtc_robot) == 4:
                qtc.append(qtc_robot[3])
            qtc = ','.join(qtc)
            qtc_state = self.client.call_service(
                PfRepRequestUpdate(
                    uuid=q.uuid,
                    observation=qtc,
                    debug=self.isDebug
                )
            )
            if qtc_state is None:
                rospy.logwarn("[" + rospy.get_name() + "]: " + str(qtc_state) + " state reported, aborting")
                return

            # qtc_state[0] is used as the current state and qtc_state[2] as the
            # model name (it indexes both self.rules and the colour table).
            self.__classification_results[q.uuid] = qtc_state[2]

            try:
                transitions = self.rules[qtc_state[2]][qtc_state[0]]
            except KeyError as e:
                rospy.logwarn("%s not in rules" % e)
                return
            if not transitions:
                # max() below would raise on an empty rule entry.
                rospy.logwarn("%s has no successor states in rules" % qtc_state[0])
                return

            current = qtc_state[0].split(',')
            # Most probable successor; ties resolve to the first key in
            # iteration order, as the former values().index(max(...)) did.
            best = max(transitions, key=transitions.get).split(',')
            rospy.logdebug("Current state: %s, most likely successor: %s" % (current, best))
            if len(current) < 4:
                prediction = [best[0], current[2]]
            else:
                prediction = [best[0], current[2], best[1], current[3]]
            prediction = ','.join(prediction)
            rospy.logdebug("Prediction: %s" % prediction)
            m.qtc_serialised = json.dumps(prediction)
            out.qtc.append(m)
        self.pub.publish(out)
        self.__decay(self.__filters, decay_time=10.)

    def _create_qtc_string(self, qtc):
        """Render a numeric QTC vector as symbols, dropping 9 entries (e.g. [-1, 0, 1, 9] -> "-,0,+")."""
        qtc = np.array(qtc)
        qtc = qtc[qtc != 9.]
        return ','.join(map(str, qtc.astype(int))).replace('-1', '-').replace('1', '+')

    def __decay(self, filters, decay_time=60.):
        """Remove particle filters whose last update is older than *decay_time* seconds."""
        now = rospy.Time.now().to_sec()
        # Iterate over a snapshot: entries are deleted inside the loop, which a
        # live dict view (Python 3) does not allow.
        for k in list(filters.keys()):
            if filters[k] + decay_time < now:
                rospy.logdebug("[" + rospy.get_name() + "]: " + "Deleting particle filter: " + str(k) + " last seen " + str(filters[k]))
                self.client.call_service(PfRepRequestRemove(uuid=k))
                del filters[k]

    def people_callback(self, msg):
        """Publish a marker per tracked person, coloured by its latest classification."""
        poses = []
        colors = []
        for uuid, pose in zip(msg.uuids, msg.poses):
            poses.append(pose)
            try:
                rgb = self.visualisation_colors[self.__classification_results[uuid]]['color']
                color = ColorRGBA(
                    a=1.0,
                    r=rgb['r']/255.,
                    g=rgb['g']/255.,
                    b=rgb['b']/255.
                )
            except KeyError:
                # Unclassified person or no configured colour for its model.
                color = self.default_color
            colors.append(color)

        self.markpub.publish(
            mc.marker_array_from_people_tracker_msg(
                poses=poses,
                target_frame=msg.header.frame_id,
                color=colors
            )
        )
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM representation service (dictionary HMM API).

    For every QSR type the tests check model creation (md5 of the JSON-dumped
    reply), sampling (an exact expected sequence, or only a non-empty list for
    rcc3) and the log-likelihood of a QSR file under a stored HMM (rounded to
    5 decimal places).
    """
    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[
            u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--',
            u'---'
        ]],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }

    correct_hashsum = {
        "qtcb": "de1ea75d1b0d6c9ff8249a24583fedb9",
        "qtcc": "5d74fb27e53ba00f84014d3be70e2740",
        "qtcbc": "f73e2c85f1447a5109d6ab3d5201fb76",
        "rcc3": "402c53a1cc004a5f518d0b607eb7ac38"
    }

    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    # ------------------------------------------------------------------
    # Service wrappers
    # ------------------------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type, start_at_zero=True):
        """Create an HMM from the JSON QSR sequence in *qsr_file*; return the reply."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        d = self.r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq,
                                qsr_type=qsr_type,
                                start_at_zero=start_at_zero))
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample (max length 10) from the JSON HMM stored in *hmm_file*."""
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        s = self.r.call_service(
            HMMRepRequestSample(qsr_type=qsr_type,
                                dictionary=hmm,
                                max_length=10,
                                num_samples=1))
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the log-likelihood of *qsr_file* under *hmm_file*, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        loglikelihood = self.r.call_service(
            HMMRepRequestLogLikelihood(qsr_type=qsr_type,
                                       dictionary=hmm,
                                       qsr_seq=qsr_seq))
        return round(loglikelihood, 5)

    def _to_strings(self, array):
        """Return one value per dict in *array* (the first in iteration order)."""
        # next(iter(...)) works on Python 2 value lists and Python 3 dict views.
        return [next(iter(x.values())) for x in array]

    def _md5_hexdigest(self, data):
        """Return the hex md5 digest of *data*, utf-8 encoding text first."""
        # hashlib only accepts bytes on Python 3; byte strings pass through as-is.
        if isinstance(data, type(u'')):
            data = data.encode('utf-8')
        return hashlib.md5(data).hexdigest()

    # ------------------------------------------------------------------
    # Shared checks
    # ------------------------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        """Assert the JSON dump of the created HMM has the recorded md5."""
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(
            self._md5_hexdigest(json.dumps(res)),
            self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        """Assert sampling *hmm_file* yields the recorded sample."""
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Assert the rounded log-likelihood equals the recorded value."""
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM,
                                  self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM,
                                  self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM,
                                  self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # rcc3 has no recorded sample; only require a non-empty list.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertTrue(res)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
    log_parse = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,add_help=False)
    log_parse.add_argument('-i', '--input', help="the xml file containing the HMM", type=str, required=True)
    log_parse.add_argument('-q', '--qsr_seq', help="reads a file containing state chains", type=str, required=True)
    subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse])

    # Parse arguments
    args = parser.parse_args()

    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        qsr_seq = load_files(args.input)
        q, d = r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=args.qsr_type
            )
        )
        with open(args.output, 'w') as f: f.write(d)

    elif args.action == "sample":
        with open(args.input, 'r') as f: hmm = f.read()
        q, s = r.call_service(
            HMMRepRequestSample(
                qsr_type=args.qsr_type,
                xml=hmm,
                max_length=args.max_length,
                num_samples=args.num_samples
            )
        )
        try:
    # Parse arguments
    args = parser.parse_args()

    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        models = load_files(args.model)
        states = load_lookup(args.states)
        d = r.call_service(
            PfRepRequestCreate(
                num_particles=args.num_particles,
                models=models,
                state_lookup_table=states,
                starvation_factor=args.starvation_factor,
                ensure_particle_per_state=args.ensure_particle_per_state,
                debug=args.debug,
                uuid=args.uuid if args.uuid != "" else None
            )
        )
        print d

    elif args.action == "predict":
        p = r.call_service(
            PfRepRequestPredict(
                uuid=args.uuid,
                num_steps=args.num_steps,
                debug=args.debug
            )
        )
示例#8
0
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM representation service (XML HMM API).

    For every QSR type the tests check model creation (md5 of the returned HMM
    text), sampling (an exact expected sequence, or only a non-empty list for
    rcc3) and the log-likelihood of a QSR file under a stored HMM (rounded to
    5 decimal places).
    """
    QTCB_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR = find_resource(PKG, 'rcc3.qsr')[0]

    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[
            u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--',
            u'---'
        ]],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }

    correct_hashsum = {
        "qtcb": "3fb65b50d0f7631a300132e8bca9ca13",
        "qtcc": "dbf1529cb0b0c90aaebbe7eafe0e9b05",
        "qtcbc": "0a3acf7b48c4c1155931442c86317ce4",
        "rcc3": "6d5bcc6c44d9b1120c738efa1994a40a"
    }

    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    # ------------------------------------------------------------------
    # Service wrappers
    # ------------------------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type):
        """Create an HMM from the JSON QSR sequence in *qsr_file*; return the HMM text."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        _, d = self.r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=qsr_type))
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample (max length 10) from the HMM text stored in *hmm_file*."""
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, s = self.r.call_service(
            HMMRepRequestSample(qsr_type=qsr_type,
                                xml=hmm,
                                max_length=10,
                                num_samples=1))
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the log-likelihood of *qsr_file* under *hmm_file*, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = f.read()
        _, loglikelihood = self.r.call_service(
            HMMRepRequestLogLikelihood(qsr_type=qsr_type,
                                       xml=hmm,
                                       qsr_seq=qsr_seq))
        return round(loglikelihood, 5)

    def _to_strings(self, array):
        """Return one value per dict in *array* (the first in iteration order)."""
        # next(iter(...)) works on Python 2 value lists and Python 3 dict views.
        return [next(iter(x.values())) for x in array]

    def _md5_hexdigest(self, data):
        """Return the hex md5 digest of *data*, utf-8 encoding text first."""
        # hashlib only accepts bytes on Python 3; byte strings pass through as-is.
        if isinstance(data, type(u'')):
            data = data.encode('utf-8')
        return hashlib.md5(data).hexdigest()

    # ------------------------------------------------------------------
    # Shared checks
    # ------------------------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        """Assert the HMM created from *qsr_file* has the recorded md5."""
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(
            self._md5_hexdigest(res), self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        """Assert sampling *hmm_file* yields the recorded sample."""
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Assert the rounded log-likelihood equals the recorded value."""
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM,
                                  self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM,
                                  self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM,
                                  self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # rcc3 has no recorded sample; only require a non-empty list.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertTrue(res)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
示例#9
0
class TestPf(unittest.TestCase):
    """Integration tests for the particle-filter service.

    The numbered tests create, predict, update, list and remove the filters
    named in ``filters``. They depend on each other and on running in that
    order; unittest sorts test method names, which the ``testN_`` prefixes
    rely on.
    """
    CROSSING_OBS   = find_resource(PKG, 'crossing.obs')[0]
    CROSSING_PRED  = find_resource(PKG, 'crossing.pred')[0]
    PASSBY_OBS     = find_resource(PKG, 'passby.obs')[0]
    PASSBY_PRED    = find_resource(PKG, 'passby.pred')[0]
    STATES         = find_resource(PKG, 'qtch.states')[0]

    # A real list, not a lazy ``map``: on Python 3 a map object would be
    # exhausted after the first test and never compare equal to a list.
    filters = [str(i) for i in range(5)]

    def __init__(self, *args):
        # Name the class explicitly: super(self.__class__, ...) recurses
        # forever if this test case is ever subclassed.
        super(TestPf, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    def _create_model(self, obs, pred):
        """Build a PfModel from prediction and observation matrix files.

        Each file is parsed with ``np.loadtxt``. The model name is the file's
        basename up to its first '.'.
        """
        m = PfModel()
        for f in pred:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_prediction_matrix(name, np.loadtxt(a))

        for f in obs:
            name = f.split('/')[-1].split('.')[0]
            with open(f, 'r') as a:
                m.add_observation_matrix(name, np.loadtxt(a))

        return m.get()

    def _load_lookup(self, filename):
        """Return the parsed JSON state lookup table stored in *filename*."""
        with open(filename, 'r') as f:
            return json.load(f)

    def _create_pf(self, pred, obs, states, num_particles=1000, uuid="", starvation_factor=0.1, ensure_particle_per_state=True, debug=True):
        """Create a particle filter via the service and return the service reply.

        *pred* and *obs* are lists of matrix files, and *states* is the path of
        the JSON lookup table. An empty *uuid* is sent as ``None``.
        """
        # Pass by keyword: _create_model takes (obs, pred), so the former
        # positional (pred, obs) call loaded each file type as the other.
        models = self._create_model(obs=obs, pred=pred)
        states = self._load_lookup(states)
        d = self.r.call_service(
            PfRepRequestCreate(
                num_particles=num_particles,
                models=models,
                state_lookup_table=states,
                starvation_factor=starvation_factor,
                ensure_particle_per_state=ensure_particle_per_state,
                debug=debug,
                uuid=uuid if uuid != "" else None
            )
        )
        return d

    def _predict(self, uuid, num_steps=5, debug=True):
        """Request a *num_steps* prediction from filter *uuid*."""
        p = self.r.call_service(
            PfRepRequestPredict(
                uuid=uuid,
                num_steps=num_steps,
                debug=debug
            )
        )
        return p

    def _update(self, uuid, obs, debug=True):
        """Feed observation *obs* to filter *uuid*."""
        p = self.r.call_service(
            PfRepRequestUpdate(
                uuid=uuid,
                observation=obs,
                debug=debug
            )
        )
        return p

    def _list(self):
        """Return the service's list of existing filters."""
        p = self.r.call_service(
            PfRepRequestList()
        )
        return p

    def _remove(self, uuid):
        """Remove filter *uuid*."""
        p = self.r.call_service(
            PfRepRequestRemove(
                uuid=uuid
            )
        )
        return p

    def test1_pf_create(self):
        res = []
        for uuid in self.filters:
            res.append(self._create_pf(
                pred=[self.CROSSING_PRED, self.PASSBY_PRED],
                obs=[self.CROSSING_OBS, self.PASSBY_OBS],
                states=self.STATES,
                uuid=uuid
            ))
        self.assertEqual([str(x) for x in res], self.filters)

    def test2_pf_predict(self):
        res = []
        for uuid in self.filters:
            res.append(self._predict(
                uuid=uuid,
                num_steps=1
            )[0])
        self.assertEqual(len(res), len(self.filters))

    def test3_pf_update(self):
        res = []
        for uuid in self.filters:
            res.append(self._update(
                uuid=uuid,
                obs='-,-,-'
            ))
        self.assertEqual(len(res), len(self.filters))

    def test4_pf_list(self):
        res = self._list()
        self.assertEqual(len(res), len(self.filters))

    def test5_pf_remove(self):
        for uuid in self.filters:
            self._remove(uuid)
        self.assertFalse(len(self._list()))
示例#10
0
class TestHMM(unittest.TestCase):
    """Integration tests for the HMM representation service (dictionary HMM API).

    For every QSR type the tests check model creation (md5 of the JSON-dumped
    reply), sampling (an exact expected sequence, or only a non-empty list for
    rcc3) and the log-likelihood of a QSR file under a stored HMM (rounded to
    5 decimal places).
    """
    QTCB_SAMPLE_TEST_HMM  = find_resource(PKG, 'qtcb_sample_test.hmm')[0]
    QTCC_SAMPLE_TEST_HMM  = find_resource(PKG, 'qtcc_sample_test.hmm')[0]
    QTCBC_SAMPLE_TEST_HMM = find_resource(PKG, 'qtcbc_sample_test.hmm')[0]
    QTCB_PASSBY_LEFT_HMM  = find_resource(PKG, 'qtcb_passby_left.hmm')[0]
    QTCC_PASSBY_LEFT_HMM  = find_resource(PKG, 'qtcc_passby_left.hmm')[0]
    QTCBC_PASSBY_LEFT_HMM = find_resource(PKG, 'qtcbc_passby_left.hmm')[0]
    RCC3_TEST_HMM         = find_resource(PKG, 'rcc3_test.hmm')[0]
    QTCB_QSR              = find_resource(PKG, 'qtcb.qsr')[0]
    QTCC_QSR              = find_resource(PKG, 'qtcc.qsr')[0]
    QTCBC_QSR             = find_resource(PKG, 'qtcbc.qsr')[0]
    RCC3_QSR              = find_resource(PKG, 'rcc3.qsr')[0]

    correct_samples = {
        "qtcb": [[u'--', u'0-', u'+-', u'+0', u'++']],
        "qtcc": [[u'--+-', u'--0-', u'----', u'0---', u'+---', u'+0--', u'++--', u'---']],
        "qtcbc": [[u'--', u'----', u'0---', u'+---', u'+0--', u'++--', u'++']]
    }

    correct_hashsum = {
        "qtcb": "de1ea75d1b0d6c9ff8249a24583fedb9",
        "qtcc": "5d74fb27e53ba00f84014d3be70e2740",
        "qtcbc": "f73e2c85f1447a5109d6ab3d5201fb76",
        "rcc3": "402c53a1cc004a5f518d0b607eb7ac38"
    }

    correct_loglikelihoods = {
        "qtcb": -2.16887,
        "qtcc": -6.07188,
        "qtcbc": -2.23475,
        "rcc3": -4.15914
    }

    def __init__(self, *args):
        super(TestHMM, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    # ------------------------------------------------------------------
    # Service wrappers
    # ------------------------------------------------------------------

    def _create_hmm(self, qsr_file, qsr_type, start_at_zero=True):
        """Create an HMM from the JSON QSR sequence in *qsr_file*; return the reply."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        d = self.r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=qsr_type,
                start_at_zero=start_at_zero
            )
        )
        return d

    def _create_sample(self, hmm_file, qsr_type):
        """Draw one sample (max length 10) from the JSON HMM stored in *hmm_file*."""
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        s = self.r.call_service(
            HMMRepRequestSample(
                qsr_type=qsr_type,
                dictionary=hmm,
                max_length=10,
                num_samples=1
            )
        )
        return s

    def _calculate_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Return the log-likelihood of *qsr_file* under *hmm_file*, rounded to 5 places."""
        with open(qsr_file, 'r') as f:
            qsr_seq = json.load(f)
        with open(hmm_file, 'r') as f:
            hmm = json.load(f)
        loglikelihood = self.r.call_service(
            HMMRepRequestLogLikelihood(
                qsr_type=qsr_type,
                dictionary=hmm,
                qsr_seq=qsr_seq
            )
        )
        return round(loglikelihood, 5)

    def _to_strings(self, array):
        """Return one value per dict in *array* (the first in iteration order)."""
        # next(iter(...)) works on Python 2 value lists and Python 3 dict views.
        return [next(iter(x.values())) for x in array]

    def _md5_hexdigest(self, data):
        """Return the hex md5 digest of *data*, utf-8 encoding text first."""
        # hashlib only accepts bytes on Python 3; byte strings pass through as-is.
        if isinstance(data, type(u'')):
            data = data.encode('utf-8')
        return hashlib.md5(data).hexdigest()

    # ------------------------------------------------------------------
    # Shared checks
    # ------------------------------------------------------------------

    def _check_create(self, qsr_file, qsr_type):
        """Assert the JSON dump of the created HMM has the recorded md5."""
        res = self._create_hmm(qsr_file, qsr_type)
        self.assertEqual(self._md5_hexdigest(json.dumps(res)), self.correct_hashsum[qsr_type])

    def _check_sample(self, hmm_file, qsr_type):
        """Assert sampling *hmm_file* yields the recorded sample."""
        res = self._create_sample(hmm_file, qsr_type)
        self.assertEqual(res, self.correct_samples[qsr_type])

    def _check_loglikelihood(self, hmm_file, qsr_file, qsr_type):
        """Assert the rounded log-likelihood equals the recorded value."""
        res = self._calculate_loglikelihood(hmm_file, qsr_file, qsr_type)
        self.assertEqual(res, self.correct_loglikelihoods[qsr_type])

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_qtcb_create(self):
        self._check_create(self.QTCB_QSR, 'qtcb')

    def test_qtcb_sample(self):
        self._check_sample(self.QTCB_SAMPLE_TEST_HMM, 'qtcb')

    def test_qtcb_loglikelihood(self):
        self._check_loglikelihood(self.QTCB_PASSBY_LEFT_HMM, self.QTCB_QSR, 'qtcb')

    def test_qtcc_create(self):
        self._check_create(self.QTCC_QSR, 'qtcc')

    def test_qtcc_sample(self):
        self._check_sample(self.QTCC_SAMPLE_TEST_HMM, 'qtcc')

    def test_qtcc_loglikelihood(self):
        self._check_loglikelihood(self.QTCC_PASSBY_LEFT_HMM, self.QTCC_QSR, 'qtcc')

    def test_qtcbc_create(self):
        self._check_create(self.QTCBC_QSR, 'qtcbc')

    def test_qtcbc_sample(self):
        self._check_sample(self.QTCBC_SAMPLE_TEST_HMM, 'qtcbc')

    def test_qtcbc_loglikelihood(self):
        self._check_loglikelihood(self.QTCBC_PASSBY_LEFT_HMM, self.QTCBC_QSR, 'qtcbc')

    def test_rcc3_create(self):
        self._check_create(self.RCC3_QSR, 'rcc3')

    def test_rcc3_sample(self):
        # rcc3 has no recorded sample; only require a non-empty list.
        res = self._create_sample(self.RCC3_TEST_HMM, 'rcc3')
        self.assertIsInstance(res, list)
        self.assertTrue(res)

    def test_rcc3_loglikelihood(self):
        self._check_loglikelihood(self.RCC3_TEST_HMM, self.RCC3_QSR, 'rcc3')
示例#11
0
class TestPf(unittest.TestCase):
    """Tests for the particle-filter (Pf) requests sent through ``ROSClient``.

    The numbered method names (``test1_`` ... ``test5_``) rely on unittest
    running test methods in alphabetical order. ``test1_pf_create`` creates
    one filter per UUID in ``filters``, and the later tests predict, update,
    list and finally remove those same UUIDs.
    """

    CROSSING_OBS = find_resource(PKG, 'crossing.obs')[0]
    CROSSING_PRED = find_resource(PKG, 'crossing.pred')[0]
    PASSBY_OBS = find_resource(PKG, 'passby.obs')[0]
    PASSBY_PRED = find_resource(PKG, 'passby.pred')[0]
    STATES = find_resource(PKG, 'qtch.states')[0]

    # UUIDs of the filters used by every test: "0" .. "4".
    # Materialised as a list so it can be iterated by several tests and
    # compared with assertEqual (a bare ``map`` is a one-shot iterator on
    # Python 3).
    filters = list(map(str, range(5)))

    def __init__(self, *args):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(TestPf, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()

    @staticmethod
    def _model_name(path):
        """Return the file name of *path* up to its first ``.``."""
        return path.split('/')[-1].split('.')[0]

    def _create_model(self, obs, pred):
        """Build a ``PfModel`` from prediction and observation matrix files.

        Each file is parsed with ``np.loadtxt`` and registered under its base
        file name up to the first ``.`` (e.g. ``crossing`` for
        ``.../crossing.pred``).

        :param obs: paths of observation-matrix files.
        :param pred: paths of prediction-matrix files.
        :return: the value of ``PfModel.get()`` after all matrices are added.
        """
        m = PfModel()
        for f in pred:
            with open(f, 'r') as a:
                m.add_prediction_matrix(self._model_name(f), np.loadtxt(a))

        for f in obs:
            with open(f, 'r') as a:
                m.add_observation_matrix(self._model_name(f), np.loadtxt(a))

        return m.get()

    def _load_lookup(self, filename):
        """Load and return the JSON content of *filename* (state lookup table)."""
        with open(filename, 'r') as f:
            return json.load(f)

    def _create_pf(self,
                   pred,
                   obs,
                   states,
                   num_particles=1000,
                   uuid="",
                   starvation_factor=0.1,
                   ensure_particle_per_state=True,
                   debug=True):
        """Send a ``PfRepRequestCreate`` and return the service response.

        :param pred: paths of prediction-matrix files.
        :param obs: paths of observation-matrix files.
        :param states: path of the JSON state lookup table.
        :param uuid: filter id; an empty string is sent as ``None``.
        :return: whatever ``ROSClient.call_service`` returns for the request.
        """
        # Keyword arguments: ``_create_model`` takes (obs, pred), the reverse
        # of this method's order, so a positional call would swap the
        # prediction and observation matrices.
        models = self._create_model(obs=obs, pred=pred)
        lookup = self._load_lookup(states)
        return self.r.call_service(
            PfRepRequestCreate(
                num_particles=num_particles,
                models=models,
                state_lookup_table=lookup,
                starvation_factor=starvation_factor,
                ensure_particle_per_state=ensure_particle_per_state,
                debug=debug,
                uuid=uuid if uuid != "" else None))

    def _predict(self, uuid, num_steps=5, debug=True):
        """Send a ``PfRepRequestPredict`` for filter *uuid*; return the response."""
        return self.r.call_service(
            PfRepRequestPredict(uuid=uuid, num_steps=num_steps, debug=debug))

    def _update(self, uuid, obs, debug=True):
        """Send a ``PfRepRequestUpdate`` with observation *obs*; return the response."""
        return self.r.call_service(
            PfRepRequestUpdate(uuid=uuid, observation=obs, debug=debug))

    def _list(self):
        """Send a ``PfRepRequestList`` and return the response."""
        return self.r.call_service(PfRepRequestList())

    def _remove(self, uuid):
        """Send a ``PfRepRequestRemove`` for filter *uuid*; return the response."""
        return self.r.call_service(PfRepRequestRemove(uuid=uuid))

    def test1_pf_create(self):
        """Creating a filter per UUID returns that UUID back."""
        res = []
        for uuid in self.filters:
            res.append(
                self._create_pf(pred=[self.CROSSING_PRED, self.PASSBY_PRED],
                                obs=[self.CROSSING_OBS, self.PASSBY_OBS],
                                states=self.STATES,
                                uuid=uuid))
        self.assertEqual(list(map(str, res)), self.filters)

    def test2_pf_predict(self):
        """A one-step prediction succeeds for every created filter."""
        res = []
        for uuid in self.filters:
            res.append(self._predict(uuid=uuid, num_steps=1)[0])
        self.assertEqual(len(res), len(self.filters))

    def test3_pf_update(self):
        """An observation update succeeds for every created filter."""
        res = []
        for uuid in self.filters:
            res.append(self._update(uuid=uuid, obs='-,-,-'))
        self.assertEqual(len(res), len(self.filters))

    def test4_pf_list(self):
        """Listing reports as many filters as were created."""
        res = self._list()
        self.assertEqual(len(res), len(self.filters))

    def test5_pf_remove(self):
        """After removing every created filter, the listing is empty."""
        for uuid in self.filters:
            self._remove(uuid)
        self.assertEqual(len(self._list()), 0)
示例#12
0
                           '--qsr_seq',
                           help="reads a file containing state chains",
                           type=str,
                           required=True)
    subparsers.add_parser('loglikelihood',
                          parents=[general, log_parse, qtc_parse])

    # Parse arguments
    args = parser.parse_args()

    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        qsr_seq = load_files(args.input)
        q, d = r.call_service(
            HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type=args.qsr_type))
        with open(args.output, 'w') as f:
            f.write(d)

    elif args.action == "sample":
        with open(args.input, 'r') as f:
            hmm = f.read()
        q, s = r.call_service(
            HMMRepRequestSample(qsr_type=args.qsr_type,
                                xml=hmm,
                                max_length=args.max_length,
                                num_samples=args.num_samples))
        try:
            with open(args.output, 'w') as f:
                json.dump(s, f)
        except TypeError:
    subparsers.add_parser('remove', parents=[general, remove_parse])

    # Parse arguments
    args = parser.parse_args()

    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        models = load_files(args.model)
        states = load_lookup(args.states)
        d = r.call_service(
            PfRepRequestCreate(
                num_particles=args.num_particles,
                models=models,
                state_lookup_table=states,
                starvation_factor=args.starvation_factor,
                ensure_particle_per_state=args.ensure_particle_per_state,
                debug=args.debug,
                uuid=args.uuid if args.uuid != "" else None))
        print d

    elif args.action == "predict":
        p = r.call_service(
            PfRepRequestPredict(uuid=args.uuid,
                                num_steps=args.num_steps,
                                debug=args.debug))
        print p
    elif args.action == "update":
        p = r.call_service(
            PfRepRequestUpdate(uuid=args.uuid,