def train(self, emissions, max_iterations=None,
        convergence_logprob=None, logger=None, processes=1,
        save=True, save_intermediate=False):
    """
    Performs unsupervised training using Baum-Welch EM.
    
    This is an instance method, because it is performed on a model 
    that has already been initialized. You might, for example, 
    create such a model using C{initialize_chord_types}.
    
    This is based on the training procedure in NLTK for HMMs:
    C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
    
    @type emissions: list of lists of emissions
    @param emissions: training data. Each element is a list of 
        emissions representing a sequence in the training data. 
        Each emission is of the kind used by 
        L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability}, 
        i.e. a list of note observations
    @type max_iterations: int
    @param max_iterations: maximum number of iterations to allow 
        for EM (default 100). Overrides the corresponding module option
    @type convergence_logprob: float
    @param convergence_logprob: maximum change in log probability 
        to consider convergence to have been reached (default 1e-3). 
        Overrides the corresponding module option
    @type logger: logging.Logger
    @param logger: a logger to send progress logging to
    @type processes: int
    @param processes: number of processes to spawn. A pool of this 
        many processes will be used to compute distribution updates 
        for sequences in parallel during each iteration.
    @type save: bool
    @param save: save the model at the end of training
    @type save_intermediate: bool
    @param save_intermediate: save the model after each iteration. 
        Implies C{save}
    """
    from . import raphsto_d
    if logger is None:
        from jazzparser.utils.loggers import create_dummy_logger
        logger = create_dummy_logger()
    
    if save_intermediate:
        save = True
    
    # No point in creating more processes than there are sequences
    if processes > len(emissions):
        processes = len(emissions)
    
    self.model.add_history("Beginning Baum-Welch unigram training on %s" % \
        get_host_info_string())
    self.model.add_history("Training on %d sequences (with %s chords)" % \
        (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
    
    # Use kwargs if given, otherwise module options
    if max_iterations is None:
        max_iterations = self.options['max_iterations']
    if convergence_logprob is None:
        convergence_logprob = self.options['convergence_logprob']
    
    # Enumerate the states
    state_ids = dict((state, num) for (num, state) in \
                                    enumerate(self.model.label_dom))
    # Enumerate the beat values (they're probably consecutive ints, 
    #  but let's not rely on it)
    beat_ids = dict((beat, num) for (num, beat) in \
                                    enumerate(self.model.beat_dom))
    num_beats = len(beat_ids)
    # Enumerate the d-values (d-function's domain)
    d_ids = dict((d, num) for (num, d) in \
                                    enumerate(self.model.emission_dist_dom))
    num_ds = len(d_ids)
    
    # Make a mutable distribution for the emission distribution we'll 
    #  be updating
    emission_mdist = DictionaryConditionalProbDist(
                dict((s, MutableProbDist(self.model.emission_dist[s], 
                                         self.model.emission_dist_dom))
                     for s in self.model.emission_dist.conditions()))
    # Create dummy distributions to fill the places of the transition 
    #  distribution components
    key_mdist = DictionaryConditionalProbDist({})
    chord_mdist = DictionaryConditionalProbDist({})
    chord_uni_mdist = MutableProbDist({}, [])
    
    # Construct a model using these mutable distributions so we can 
    #  evaluate using them
    model = self.model_cls(key_mdist, chord_mdist, emission_mdist, 
                           chord_uni_mdist, chord_set=self.model.chord_set)
    
    iteration = 0
    last_logprob = None
    while iteration < max_iterations:
        logger.info("Beginning iteration %d" % iteration)
        current_logprob = 0.0
        
        # ems contains the new emission numerator probabilities
        # ems[r][d] = Sum_{d(y_n^k, x_n)=d, r_n^k=r} (
        #                 alpha(x_n).beta(x_n) / 
        #                     Sum_{x'_n} (alpha(x'_n).beta(x'_n)))
        ems = zeros((num_beats, num_ds), float64)
        # And these are the denominators
        ems_denom = zeros(num_beats, float64)
        
        def _training_callback(result):
            """
            Callback for the _sequence_updates_uni processes that takes 
            the updates from a single sequence and adds them onto the 
            global update accumulators.
            """
            # _sequence_updates_uni() returns all of this as a tuple
            (ems_local, ems_denom_local, seq_logprob) = result
            
            # Add these probabilities from this sequence to the 
            #  global matrices
            # Emission numerator
            array_add(ems, ems_local, ems)
            # Denominators
            array_add(ems_denom, ems_denom_local, ems_denom)
        ## End of _training_callback
        
        # Only use a process pool if there's more than one sequence
        if processes > 1:
            # Create a process pool to use for training
            logger.info("Creating a pool of %d processes" % processes)
            pool = Pool(processes=processes)
            
            async_results = []
            for seq_i, sequence in enumerate(emissions):
                logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                T = len(sequence)
                if T == 0:
                    continue
                
                # Fire off a new call to the process pool for every sequence
                async_results.append(
                        pool.apply_async(_sequence_updates_uni, 
                                         (sequence, model, 
                                          self.model.label_dom, 
                                          state_ids, beat_ids, d_ids, 
                                          raphsto_d), 
                                         callback=_training_callback))
            pool.close()
            # Wait for all the workers to complete
            pool.join()
            
            # Call get() on every AsyncResult so that any exceptions in 
            #  workers get raised
            for res in async_results:
                # If there was an exception in _sequence_updates_uni, it 
                #  will get raised here
                res_tuple = res.get()
                # Add this sequence's logprob into the total for all 
                #  sequences
                current_logprob += res_tuple[2]
        else:
            # Not using a process pool: train on each sequence in turn
            logger.info("Not using a process pool: training %d "\
                "sequence(s) sequentially" % len(emissions))
            for sequence in emissions:
                if len(sequence) > 0:
                    updates = _sequence_updates_uni(sequence, model, 
                                                    self.model.label_dom, 
                                                    state_ids, beat_ids, 
                                                    d_ids, raphsto_d)
                    _training_callback(updates)
                    # Update the overall logprob
                    current_logprob += updates[2]
        
        # Update the model's probabilities from the accumulated values
        for beat in self.model.beat_dom:
            denom = ems_denom[beat_ids[beat]]
            for d in self.model.emission_dist_dom:
                if denom == 0.0:
                    # Zero denominator: use a uniform distribution
                    prob = - logprob(len(d_ids))
                else:
                    # Add-small smoothing, so no d-value gets zero 
                    #  probability
                    prob = logprob(ems[beat_ids[beat]][d_ids[d]] + ADD_SMALL) \
                            - logprob(denom + len(d_ids)*ADD_SMALL)
                model.emission_dist[beat].update(d, prob)
        
        # Clear the model's cache so we get the new probabilities
        model.clear_cache()
        
        logger.info("Training data log prob: %s" % current_logprob)
        if last_logprob is not None and current_logprob < last_logprob:
            logger.error("Log probability dropped by %s" % \
                            (last_logprob - current_logprob))
        if last_logprob is not None:
            logger.info("Log prob change: %s" % \
                            (current_logprob - last_logprob))
        # Check whether the log probability has converged
        if iteration > 0 and \
                abs(current_logprob - last_logprob) < convergence_logprob:
            # Don't iterate any more
            logger.info("Distribution has converged: ceasing training")
            break
        
        iteration += 1
        last_logprob = current_logprob
        
        # Update the main model
        # Only save if we've been asked to save between iterations
        self.update_model(model, save=save_intermediate)
    
    self.model.add_history("Completed Baum-Welch unigram training")
    # Update the distribution's parameters with those we've trained
    self.update_model(model, save=save)
    return
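# The M-step above re-estimates each beat's emission distribution from the 
#  accumulated expected counts, applying add-small smoothing so that no 
#  d-value ends up with zero probability. Below is a minimal standalone 
#  sketch of that update on toy accumulators. It assumes the module's 
#  logprob helper is a base-2 log, as in NLTK's probability classes; 
#  ADD_SMALL_DEMO and reestimate are illustrative stand-ins, not part of 
#  the trainer's API.
from math import log

ADD_SMALL_DEMO = 1e-10   # hypothetical smoothing constant

def reestimate(ems_row, denom, num_ds):
    """Smoothed base-2 log probabilities for one beat's emission dist."""
    if denom == 0.0:
        # Zero denominator: fall back to a uniform distribution
        return [-log(num_ds, 2)] * num_ds
    return [log(num + ADD_SMALL_DEMO, 2) - 
                log(denom + num_ds*ADD_SMALL_DEMO, 2) 
                    for num in ems_row]

# Toy expected counts for one beat over a three-value d-domain: the 
#  smoothed probabilities still sum to one
row = [4.0, 1.0, 0.0]
probs = [2**lp for lp in reestimate(row, sum(row), len(row))]
assert abs(sum(probs) - 1.0) < 1e-6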
def train(self, emissions, logger=None, save_callback=None):
    """
    Performs unsupervised training using Baum-Welch EM.
    
    This is performed on a model that has already been initialized. 
    You might, for example, create such a model using 
    L{jazzparser.taggers.segmidi.chordclass.hmm.ChordClassHmm.initialize_chord_classes}.
    
    This is based on the training procedure in NLTK for HMMs:
    C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
    
    @type emissions: L{jazzparser.data.input.MidiTaggerTrainingBulkInput} or 
        list of L{jazzparser.data.input.Input}s
    @param emissions: training MIDI data
    @type logger: logging.Logger
    @param logger: a logger to send progress logging to
    """
    if logger is None:
        from jazzparser.utils.loggers import create_dummy_logger
        logger = create_dummy_logger()
    
    self.model.add_history("Beginning Baum-Welch training on %s" % \
        get_host_info_string())
    self.model.add_history("Training on %d MIDI sequences (with %s segments)" % \
        (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
    logger.info("Beginning Baum-Welch training on %s" % get_host_info_string())
    
    # Get some options out of the module options
    max_iterations = self.options['max_iterations']
    convergence_logprob = self.options['convergence_logprob']
    split_length = self.options['split']
    truncate_length = self.options['truncate']
    save_intermediate = self.options['save_intermediate']
    processes = self.options['trainprocs']
    
    # Make a mutable distribution for each of the distributions 
    #  we'll be updating
    emission_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                            self.model.emission_dist, mutable=True)
    schema_trans_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                            self.model.schema_transition_dist, mutable=True)
    root_trans_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                            self.model.root_transition_dist, mutable=True)
    init_state_mdist = prob_dist_to_dictionary_prob_dist(
                            self.model.initial_state_dist, mutable=True)
    
    # Get the sizes we'll need for the matrices
    num_schemata = len(self.model.schemata)
    num_root_changes = 12
    num_chord_classes = len(self.model.chord_classes)
    if self.model.metric:
        num_emission_conds = num_chord_classes * 4
    else:
        num_emission_conds = num_chord_classes
    num_emissions = 12
    
    # Enumerations to use for the matrices, so we know what they mean
    schema_ids = dict([(sch, i) for (i, sch) in \
                            enumerate(self.model.schemata + [None])])
    if self.model.metric:
        rs = range(4)
    else:
        rs = [0]
    emission_cond_ids = dict([(cc, i) for (i, cc) in enumerate(\
                            sum([[(str(cclass.name), r) for r in rs] 
                                 for cclass in self.model.chord_classes], 
                                []))])
    
    # Construct a model using these mutable distributions so we can 
    #  evaluate using them
    model = ChordClassHmm(schema_trans_mdist, 
                          root_trans_mdist, 
                          emission_mdist, 
                          self.model.emission_number_dist, 
                          init_state_mdist, 
                          self.model.schemata, 
                          self.model.chord_class_mapping, 
                          self.model.chord_classes, 
                          metric=self.model.metric, 
                          illegal_transitions=self.model.illegal_transitions, 
                          fixed_root_transitions=self.model.fixed_root_transitions)
    
    def _save():
        if save_callback is None:
            logger.error("Could not save model, as no callback was given")
        else:
            # If the writing fails, wait till I've had a chance to sort it 
            #  out and then try again. This happens when my AFS token runs 
            #  out
            while True:
                try:
                    save_callback()
                except (IOError, OSError), err:
                    print "Error writing model to disk: %s. " % err
                    raw_input("Press <enter> to try again... ")
                else:
                    break
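# A hedged sketch of how a caller might wire up the save_callback parameter 
#  of the method above: the trainer never writes to disk itself, it just 
#  invokes the callback whenever a save is wanted, and _save retries if the 
#  write raises an IOError/OSError. The save_to_file method used here is 
#  hypothetical; substitute whatever persistence the model class actually 
#  provides.
def make_save_callback(model, filename):
    def _callback():
        # Hypothetical persistence call standing in for the real one
        model.save_to_file(filename)
    return _callback

# Hypothetical usage:
#   trainer.train(midi_inputs, logger=logger, 
#                 save_callback=make_save_callback(trainer.model, 
#                                                  "chordclass.mdl"))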
def train(self, emissions, logger=None):
    """
    Performs unsupervised training using Baum-Welch EM.
    
    This is performed as a retraining step on a model that has already 
    been initialized.
    
    This is based on the training procedure in NLTK for HMMs:
    C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
    
    @type emissions: list of lists of emissions
    @param emissions: training data. Each element is a list of 
        emissions representing a sequence in the training data. 
        Each emission is of the kind used by 
        C{emission_log_probability} on the model
    @type logger: logging.Logger
    @param logger: a logger to send progress logging to
    """
    if logger is None:
        from jazzparser.utils.loggers import create_dummy_logger
        logger = create_dummy_logger()
    
    self.record_history("Beginning Baum-Welch training on %s" % \
        get_host_info_string())
    self.record_history("Training on %d inputs (with %s segments)" % \
        (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
    logger.info("Beginning Baum-Welch training on %s" % get_host_info_string())
    
    # Get some options out of the module options
    max_iterations = self.options['max_iterations']
    convergence_logprob = self.options['convergence_logprob']
    split_length = self.options['split']
    truncate_length = self.options['truncate']
    save_intermediate = self.options['save_intermediate']
    processes = self.options['trainprocs']
    
    # Make a mutable version of the model that we can update each iteration
    self.model = self.create_mutable_model(self.model)
    # Get the array id mappings
    array_ids = self.get_array_indices()
    
    ########## Data preprocessing
    logger.info("%d input sequences" % len(emissions))
    # Truncate long streams
    if truncate_length is not None:
        logger.info("Truncating sequences to max %d timesteps" % \
                                                        truncate_length)
        emissions = [stream[:truncate_length] for stream in emissions]
    # Split up long streams if requested
    # After this, each stream is a tuple (first,stream), where first 
    #  indicates whether the stream segment begins a song
    if split_length is not None:
        logger.info("Splitting sequences into max %d-sized chunks" % \
                                                        split_length)
        split_emissions = []
        # Split each stream
        for emstream in emissions:
            input_ems = list(emstream)
            splits = []
            first = True
            # Take bits of length split_length until we're under the max
            while len(input_ems) >= split_length:
                # Overlap the splits by one so we get all transitions
                splits.append((first, input_ems[:split_length]))
                input_ems = input_ems[split_length-1:]
                first = False
            # Get the last short one
            if len(input_ems):
                # Try to avoid having a small bit that's split off at 
                #  the end
                if len(splits) and len(input_ems) <= split_length / 5:
                    # Add these to the end of the last split
                    # This will make it slightly longer than requested
                    splits[-1][1].extend(input_ems)
                else:
                    splits.append((first, input_ems))
            split_emissions.extend(splits)
    else:
        # All streams begin a song
        split_emissions = [(True, stream) for stream in emissions]
    logger.info("Sequence lengths after preprocessing: %s" % 
            " ".join([str(len(em[1])) for em in split_emissions]))
    ##########
    
    # Special case of -1 for the number of processes
    # No point in creating more processes than there are sequences
    if processes == -1 or processes > len(split_emissions):
        processes = len(split_emissions)
    
    iteration = 0
    last_logprob = None
    while iteration < max_iterations:
        logger.info("Beginning iteration %d" % iteration)
        current_logprob = 0.0
        
        # Build a tuple of the arrays that will be updated by each sequence
        self.global_arrays = self.get_empty_arrays()
        
        # Only use a process pool if there's more than one sequence
        if processes > 1:
            # Create a process pool to use for training
            logger.info("Creating a pool of %d processes" % processes)
            # KeyboardInterrupts in the workers are caught at this level
            pool = Pool(processes=processes)
            async_results = []
            try:
                for seq_i, (first, sequence) in enumerate(split_emissions):
                    logger.info("Iteration %d, sequence %d" % \
                                                    (iteration, seq_i))
                    T = len(sequence)
                    if T == 0:
                        continue
                    
                    def _notifier_closure(seq_index):
                        def _notifier(res):
                            logger.info("Sequence %d finished" % seq_index)
                        return _notifier
                    # Create some empty arrays for the updates to go into
                    empty_arrays = self.get_empty_arrays()
                    # Fire off a new call to the process pool for every 
                    #  sequence
                    async_results.append(
                            pool.apply_async(self.sequence_updates, 
                                             (sequence, self.model, 
                                              empty_arrays, array_ids), 
                                             {'update_initial': first}, 
                                             _notifier_closure(seq_i)))
                pool.close()
                # Wait for all the workers to complete
                pool.join()
            except KeyboardInterrupt:
                # If Ctrl+C is fired during the processing, we exit here
                logger.info("Keyboard interrupt was received during EM "\
                    "updates")
                raise
            
            # Call get() on every AsyncResult so that any exceptions in 
            #  workers get raised
            for res in async_results:
                # If there was an exception in sequence_updates, it 
                #  will get raised here
                res_tuple = res.get()
                # Run the callback on the results from this process
                # It might seem sensible to do this using the callback 
                #  arg to apply_async, but then the callback must be 
                #  picklable and it doesn't buy us anything really
                self.sequence_updates_callback(res_tuple)
                # Add this sequence's logprob into the total for all 
                #  sequences
                current_logprob += res_tuple[-1]
        else:
            if len(split_emissions) == 1:
                logger.info("One sequence: not using a process pool")
            else:
                logger.info("Not using a process pool: training %d "\
                    "emission sequences sequentially" % \
                    len(split_emissions))
            
            for seq_i, (first, sequence) in enumerate(split_emissions):
                if len(sequence) > 0:
                    logger.info("Iteration %d, sequence %d" % \
                                                    (iteration, seq_i))
                    # Create some empty arrays for the updates to go into
                    empty_arrays = self.get_empty_arrays()
                    updates = self.sequence_updates(sequence, self.model, 
                                                    empty_arrays, array_ids, 
                                                    update_initial=first)
                    self.sequence_updates_callback(updates)
                    # Update the overall logprob
                    current_logprob += updates[-1]
        
        ######## Model updates
        # Update the main model
        self.update_model(self.global_arrays, array_ids)
        
        # Clear the model's cache so we get the new probabilities
        self.model.clear_cache()
        
        logger.info("Training data log prob: %s" % current_logprob)
        if last_logprob is not None and current_logprob < last_logprob:
            logger.error("Log probability dropped by %s" % \
                            (last_logprob - current_logprob))
        if last_logprob is not None:
            logger.info("Log prob change: %s" % \
                            (current_logprob - last_logprob))
        # Check whether the log probability has converged
        if iteration > 0 and \
                abs(current_logprob - last_logprob) < convergence_logprob:
            # Don't iterate any more
            logger.info("Distribution has converged: ceasing training")
            break
        
        iteration += 1
        last_logprob = current_logprob
        
        # Only save if we've been asked to save between iterations
        if save_intermediate:
            self.save()
    
    self.record_history("Completed Baum-Welch training")
    # Always save the model now that we're done
    self.save()
    return self.model
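# Standalone sketch of the sequence-splitting scheme in the preprocessing 
#  step above: chunks overlap by one element so every adjacent transition 
#  still appears in some chunk, and a trailing fragment no longer than a 
#  fifth of split_length is merged into the previous chunk instead of 
#  being kept as a tiny sequence. split_stream is an illustrative name, 
#  not part of the trainer's API.
def split_stream(stream, split_length):
    input_ems = list(stream)
    splits = []
    first = True
    while len(input_ems) >= split_length:
        # Overlap consecutive chunks by one element
        splits.append((first, input_ems[:split_length]))
        input_ems = input_ems[split_length-1:]
        first = False
    if len(input_ems):
        if len(splits) and len(input_ems) <= split_length / 5:
            # Tack a short tail onto the previous chunk
            splits[-1][1].extend(input_ems)
        else:
            splits.append((first, input_ems))
    return splits

# Ten timesteps in chunks of five: note the shared boundary elements 4 and 8
assert [s for (f, s) in split_stream(range(10), 5)] == \
        [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8], [8, 9]]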
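# The trainers above share the same stopping rule: halt when the absolute 
#  change in training-data log probability falls below the convergence 
#  threshold, but never on the very first iteration, when there is no 
#  previous value to compare against. A minimal sketch, with an 
#  illustrative function name:
def converged(iteration, last_logprob, current_logprob, threshold):
    return (iteration > 0 and last_logprob is not None and 
                abs(current_logprob - last_logprob) < threshold)

assert not converged(0, None, -1000.0, 1e-3)
assert not converged(1, -1000.0, -975.0, 1e-3)
assert converged(5, -950.0004, -950.0001, 1e-3)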