def train(self, weighted_ink_groups, verbose=True):
    """Fit one ``PrototypeDTW`` per sufficiently large group of weighted inks.

    ``self.trained_prototypes`` is reset to an empty list and then filled
    with one trained prototype for every group that passes the size check.

    Args:
        weighted_ink_groups: Iterable of groups; each group is a sized
            collection of ``(ink, weight)`` pairs.
        verbose: Accepted for interface compatibility; not used here.

    Returns:
        None. The result is stored on ``self.trained_prototypes``.
    """
    self.trained_prototypes = []

    for group in weighted_ink_groups:
        # Small clusters are ignored: a group must hold strictly more than
        # ``min_cluster_size`` items to produce a prototype.
        if len(group) <= self.min_cluster_size:
            continue

        # Split the (ink, weight) pairs into two parallel tuples.
        inks, ink_weights = zip(*group)

        prototype = PrototypeDTW(self.label, alpha=self.alpha)
        # The value returned by ``PrototypeDTW.train`` is not needed here.
        prototype.train(
            inks,
            obs_weights=ink_weights,
            center_type=self.center_type,
            state_reduction=self.state_reduction)

        self.trained_prototypes.append(prototype)
def closest_prototype(con, user_id, label, user_data):
    """Return the stored prototype of other users that best fits ``user_data``.

    For the given ``label``, the protoset with the highest ``protoset_id``
    of every user other than ``user_id`` is loaded. Each ink in
    ``user_data`` votes for the prototype that scores it highest, and the
    prototype with the most votes wins (ties go to the earliest prototype,
    as does the case where ``user_data`` is empty).

    Args:
        con: DB-API connection whose driver uses the ``%s`` parameter style.
        user_id: Id of the user whose own protosets are excluded.
        label: Label of the protosets to search.
        user_data: Iterable of inks, each passed to ``PrototypeDTW.score``.

    Returns:
        A one-element list holding the winning prototype's ``model`` array
        with its ``_PU_IDX`` column rounded, or ``[]`` when no prototype is
        available.
    """
    cur = con.cursor()
    try:
        cur.execute("""
            SELECT t1.protoset_json
            FROM protosets as t1
            JOIN (SELECT MAX(protoset_id) as pid
                  FROM protosets
                  WHERE label=%s AND user_id!=%s
                  GROUP BY user_id) as t2
            ON t1.protoset_id = t2.pid
            """, (label, user_id,))
        rows = cur.fetchall()
    finally:
        # Release the cursor even if the query fails.
        cur.close()

    # no prototypes found, return empty
    if not rows:
        return []

    # create a list of prototypes
    proto_list = []
    for row in rows:
        protoset_json = json.loads(row[0])
        for prototype_json in protoset_json['prototypes']:
            p = PrototypeDTW(label)
            p.fromJSON(prototype_json)
            proto_list.append(p)

    # Protosets may exist yet contain no prototypes; argmax on an empty
    # array would raise, so treat this like the "nothing found" case.
    if not proto_list:
        return []

    # select the best prototype: one vote per ink for its top scorer
    count = np.zeros(len(proto_list))
    for ink in user_data:
        scores = np.zeros(len(proto_list))
        for i, p in enumerate(proto_list):
            scores[i] = p.score(ink)
        count[scores.argmax()] += 1

    best_proto_ink = proto_list[count.argmax()].model

    # make sure the penup is binary (rounds the column in place on the
    # prototype's model array)
    best_proto_ink[:, _PU_IDX] = best_proto_ink[:, _PU_IDX].round()

    return [best_proto_ink]