def train(self):
    """Train HMMs.

    For every file returned by ``self.get_initial_hmm_files()``:

    * the character code is parsed from the file's basename (the integer
      before the first ``"."``);
    * the initial model is loaded with ``MultivariateHmm.from_file``;
    * the sequence set ``<char_code>.sset`` is loaded from
      ``self.TRAIN_FEATURES_ROOT`` via ``self.get_sequence_set``;
    * the model is trained with Viterbi training, Baum-Welch training, or
      both, depending on ``self.TRAINING``;
    * the model is written to ``self.TRAIN_HMM_ROOT/<char_code>.xml``
      (created if missing).

    If ``self.TRAINING`` matches none of the training modes, the initial
    model is written out unchanged.

    Raises:
        ModelException: if no initial HMM files are found.
    """
    initial_hmm_files = self.get_initial_hmm_files()
    if not initial_hmm_files:
        # Call form is valid on both Python 2 and 3; the old
        # ``raise E, "msg"`` form is a SyntaxError on Python 3.
        raise ModelException("No initial HMM files found.")

    if not os.path.exists(self.TRAIN_HMM_ROOT):
        os.makedirs(self.TRAIN_HMM_ROOT)

    trainer = self.Trainer()
    viterbi_trainer = ViterbiTrainer(self.ViterbiCalculator(),
                                     non_diagonal=self.NON_DIAGONAL)

    # Decide the training mode once instead of re-testing it per file.
    use_viterbi = self.TRAINING in (self.TRAINING_VITERBI,
                                    self.TRAINING_BOTH)
    use_baum_welch = self.TRAINING in (self.TRAINING_BAUM_WELCH,
                                       self.TRAINING_BOTH)

    for hmm_file in initial_hmm_files:
        # File names are expected to look like "<char_code>.<ext>".
        char_code = int(os.path.basename(hmm_file).split(".")[0])
        hmm = MultivariateHmm.from_file(hmm_file)

        sset_file = os.path.join(self.TRAIN_FEATURES_ROOT,
                                 str(char_code) + ".sset")
        sset = self.get_sequence_set(sset_file)

        output_file = os.path.join(self.TRAIN_HMM_ROOT, "%d.xml" % char_code)

        # With TRAINING_BOTH, Viterbi runs first and Baum-Welch is applied to
        # the same ``hmm`` object afterwards (presumably both trainers update
        # ``hmm`` in place, since the result is not reassigned; confirm).
        if use_viterbi:
            self.print_verbose("Viterbi training: " + output_file)
            viterbi_trainer.train(hmm, sset)

        if use_baum_welch:
            self.print_verbose("Baum-Welch training: " + output_file)
            trainer.train(hmm, sset)

        hmm.write(output_file)
def get_hmms_from_files(self, files):
    """Load a MultivariateHmm from each path in *files*.

    Every loaded model gets a ``char_code`` attribute holding the integer
    parsed from its file name (the part of the basename before the first
    ``"."``). Models are returned as a list, in the same order as *files*.
    """
    def _load(path):
        # Parse the code before loading so a badly named file fails fast.
        code = int(os.path.basename(path).split(".")[0])
        model = MultivariateHmm.from_file(path)
        model.char_code = code
        return model

    return [_load(path) for path in files]