예제 #1
0
def create_hmm(qsr_seq, filename):
    """Ask the HMM service for a QTCh model of ``qsr_seq`` and save it to disk.

    A new ``ROSClient`` sends a ``HMMRepRequestCreate`` request with
    ``qsr_type="qtch"``. The service reply is unpacked into two values; only
    the second one is kept and written verbatim to ``<filename>.hmm``.

    :param qsr_seq: state sequences forwarded unchanged to the request
    :param filename: output path without the ``.hmm`` extension
    """
    client = ROSClient()
    request = HMMRepRequestCreate(qsr_seq=qsr_seq, qsr_type="qtch")
    _, hmm_repr = client.call_service(request)

    out_path = filename + ".hmm"
    with open(out_path, 'w') as out_file:
        out_file.write(hmm_repr)
예제 #2
0
    def __init__(self, name):
        """Load prediction models and rules, then wire up ROS topics.

        Every ``*.hmm`` file in the directory given by the ``~model_dir``
        parameter becomes a prediction model. Every ``*.rules`` file is loaded
        as JSON into ``self.rules``. In both cases the key is the file name up
        to its first dot.

        Afterwards the visualisation colours are read from ``~visualisation``,
        and the prediction/marker publishers and the QTC/people subscribers
        are created.

        :param name: node name, used only in the start-up log message
        """
        self.isDebug = False
        # ``self.rules`` is written to below but was never created here. Create
        # it unless the class already supplies one, to avoid AttributeError.
        if getattr(self, "rules", None) is None:
            self.rules = {}

        rospy.loginfo("Starting %s ...", name)
        self.client = ROSClient()
        qmc = QTCModelCreation()
        obs = qmc.create_observation_model(qtc_type=qmc.qtc_types.qtch, start_end=True)
        self.lookup = qmc.create_states(qtc_type=qmc.qtc_types.qtch, start_end=True)
        hmc = HMMModelCreation()
        m = PfModel()

        # Read the parameter once instead of once per directory entry.
        model_dir = rospy.get_param("~model_dir")
        # os.listdir() order is arbitrary; sort so loading is deterministic.
        for f in sorted(os.listdir(model_dir)):
            filename = os.path.join(model_dir, f)
            key = f.split('.')[0]
            if f.endswith(".hmm"):
                rospy.loginfo("Creating prediction model from: %s", filename)
                pred = hmc.create_prediction_model(input=filename)
                m.add_model(key, pred, obs)
            elif f.endswith(".rules"):
                with open(filename) as fd:
                    rospy.loginfo("Reading rules from: %s", filename)
                    # Necessary due to old rule format:
                    self.rules[key] = self._create_proper_qtc_keys(json.load(fd))

        self.model = m.get()

        visualisation = rospy.get_param("~visualisation")
        self.visualisation_colors = visualisation['models']
        # Colour components are stored as 0-255 and scaled to 0.0-1.0 here.
        self.default_color = ColorRGBA(
            a=1.0,
            r=visualisation['default']['r']/255.,
            g=visualisation['default']['g']/255.,
            b=visualisation['default']['b']/255.
        )

        self.pub = rospy.Publisher("~prediction_array", QTCPredictionArray, queue_size=10)
        self.markpub = rospy.Publisher("~marker_array", MarkerArray, queue_size=10)
        rospy.Subscriber(rospy.get_param("~qtc_topic", "/online_qtc_creator/qtc_array"), QTCArray, self.callback, queue_size=1)
        rospy.Subscriber(rospy.get_param("~ppl_topic", "/people_tracker/positions"), PeopleTracker, self.people_callback, queue_size=1)
        rospy.loginfo("... all done")
    # NOTE(review): ``--max_length`` is parsed with ``type=str``. If the service
    # expects a number this should probably be ``type=int`` -- confirm. The help
    # text also reads "will be ensure" (likely "will be ensured"); it is left
    # unchanged here because it is a runtime string.
    sample_parse.add_argument('-l', '--max_length', help="the maximum length of samples which will be ensure if at all possible", type=str, required=True)
    sample_parse.add_argument('--lookup', help="the lookup table json file", type=str, default="")
    # The "sample" sub-command combines the shared ``general`` and ``qtc_parse``
    # option sets with the sample-specific options above.
    subparsers.add_parser('sample', parents=[general, sample_parse, qtc_parse])

    # Parsers for loglikelihood function
    log_parse = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,add_help=False)
    log_parse.add_argument('-i', '--input', help="the xml file containing the HMM", type=str, required=True)
    log_parse.add_argument('-q', '--qsr_seq', help="reads a file containing state chains", type=str, required=True)
    log_parse.add_argument('--lookup', help="the lookup table json file", type=str, default="")
    subparsers.add_parser('loglikelihood', parents=[general, log_parse, qtc_parse])

    # Parse arguments
    args = parser.parse_args()

    # Initialise the ROS node before creating the service client.
    rospy.init_node("ros_client")
    r = ROSClient()

    if args.action == "create":
        # Each optional JSON input (lookup table, transition matrix, emission
        # matrix) is loaded only when its path argument is non-empty;
        # otherwise ``None`` is sent in the request.
        qsr_seq = load_files(args.input)
        d = r.call_service(
            HMMRepRequestCreate(
                qsr_seq=qsr_seq,
                qsr_type=args.qsr_type,
                pseudo_transitions=args.pseudo_transitions,
                lookup_table=load_json_file(args.lookup) if args.lookup != "" else None,
                transition_matrix=load_json_file(args.trans) if args.trans != "" else None,
                emission_matrix=load_json_file(args.emi) if args.emi != "" else None,
                start_at_zero=args.start_at_zero
            )
        )
        # Write the service result as JSON to the requested output path.
        with open(args.output, 'w') as f: json.dump(d, f)
예제 #4
0
    def __init__(self, *test_args):
        """Set up the test object, start the ROS node and create a client.

        Positional arguments go unchanged to the parent ``__init__``. Then a
        node named ``NAME`` is initialised and a ``ROSClient`` is kept as
        ``self.r`` for the test methods.
        """
        super(TestHMM, self).__init__(*test_args)
        rospy.init_node(NAME)
        self.r = ROSClient()
예제 #5
0
    def __init__(self, *args):
        """Set up the object, start the ROS node and create a client.

        Positional arguments are forwarded to the parent ``__init__``. Then a
        node named ``NAME`` is initialised and a ``ROSClient`` is stored as
        ``self.r``.
        """
        # NOTE(review): ``super(self.__class__, self)`` recurses forever if this
        # class is ever subclassed, because ``self.__class__`` is then the
        # subclass and the lookup resolves back to this same ``__init__``.
        # Replace it with ``super(<ThisClassName>, self)``. The class name is not
        # visible in this snippet, so the call is left unchanged.
        super(self.__class__, self).__init__(*args)

        rospy.init_node(NAME)

        self.r = ROSClient()