예제 #1
0
 def train(self, weighted_ink_groups, verbose=True):
     """Fit one ``PrototypeDTW`` per sufficiently large ink group.

     Resets ``self.trained_prototypes`` and then appends a freshly trained
     prototype for every group holding more than ``self.min_cluster_size``
     entries.

     Args:
         weighted_ink_groups: iterable of sized groups, each made of
             ``(ink, weight)`` pairs.
         verbose: accepted for interface compatibility; this method does
             not read it.
     """
     self.trained_prototypes = []
     for group in weighted_ink_groups:
         # Groups at or below the size threshold are ignored.
         if len(group) <= self.min_cluster_size:
             continue

         # Split the (ink, weight) pairs into two parallel tuples.
         inks, ink_weights = zip(*group)
         prototype = PrototypeDTW(self.label, alpha=self.alpha)
         # The value returned by PrototypeDTW.train is not needed here.
         prototype.train(
             inks,
             obs_weights=ink_weights,
             center_type=self.center_type,
             state_reduction=self.state_reduction,
         )
         self.trained_prototypes.append(prototype)
예제 #2
0
def closest_prototype(con, user_id, label, user_data):
    """Pick the other-user prototype that best matches ``user_data``.

    For every user except ``user_id``, the protoset row with the highest
    ``protoset_id`` for ``label`` is loaded and each prototype stored in it
    is rebuilt.  Every ink in ``user_data`` then votes for the prototype
    that scores highest on it, and the prototype with the most votes wins.
    Ties (including the case of an empty ``user_data``) resolve to the
    earliest prototype in load order, because ``argmax`` returns the first
    maximum.

    Args:
        con: DB-API style connection exposing ``cursor()``; the query uses
            ``%s`` placeholders.
        user_id: user whose own protosets are excluded from the search.
        label: label to look up; also passed to ``PrototypeDTW``.
        user_data: iterable of ink samples accepted by
            ``PrototypeDTW.score``.

    Returns:
        ``[]`` when no prototype is available (no matching rows, or rows
        whose ``'prototypes'`` lists are all empty); otherwise a
        one-element list holding the winning prototype's ``model`` with
        its ``_PU_IDX`` column rounded.
    """
    cur = con.cursor()
    try:
        cur.execute("""
         SELECT t1.protoset_json
         FROM protosets as t1
         JOIN
            (SELECT MAX(protoset_id) as pid
             FROM protosets 
             WHERE label=%s AND user_id!=%s
             GROUP BY user_id) as t2
         ON t1.protoset_id = t2.pid
        """, (label, user_id))
        rows = cur.fetchall()
    finally:
        # Release the cursor whether or not the query succeeded.
        cur.close()

    # Rebuild every prototype found in the selected protosets.
    proto_list = []
    for row in rows:
        protoset_json = json.loads(row[0])
        for prototype_json in protoset_json['prototypes']:
            p = PrototypeDTW(label)
            p.fromJSON(prototype_json)
            proto_list.append(p)

    # Nothing to choose from: either no rows matched or every protoset was
    # empty.  argmax() on an empty array would raise ValueError.
    if not proto_list:
        return []

    # Each ink votes for the prototype that scores highest on it.
    count = np.zeros(len(proto_list))
    for ink in user_data:
        scores = np.zeros(len(proto_list))
        for i, p in enumerate(proto_list):
            scores[i] = p.score(ink)
        count[scores.argmax()] += 1
    best_proto_ink = proto_list[count.argmax()].model

    # Make sure the pen-up column is binary.
    # NOTE(review): assumes ``model`` is a 2-D numpy array whose column
    # ``_PU_IDX`` holds the pen-up flag - confirm against PrototypeDTW.
    best_proto_ink[:, _PU_IDX] = best_proto_ink[:, _PU_IDX].round()
    return [best_proto_ink]