Code Example #1
File: train.py    Project: johndpope/jazzparser
 def train(self, emissions, max_iterations=None, \
                 convergence_logprob=None, logger=None, processes=1,
                 save=True, save_intermediate=False):
     """
     Performs unsupervised training using Baum-Welch EM.
     
     This is an instance method, because it is performed on a model 
     that has already been initialized. You might, for example, 
     create such a model using C{initialize_chord_types}.
     
     This is based on the training procedure in NLTK for HMMs:
     C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
     
     @type emissions: list of lists of emissions
     @param emissions: training data. Each element is a list of 
         emissions representing a sequence in the training data.
         Each emission is an emission like those used for 
         L{jazzparser.misc.raphsto.RaphstoHmm.emission_log_probability}, 
         i.e. a list of note 
         observations
     @type max_iterations: int
     @param max_iterations: maximum number of iterations to allow 
         for EM (default 100). Overrides the corresponding 
         module option
     @type convergence_logprob: float
     @param convergence_logprob: maximum change in log probability 
         to consider convergence to have been reached (default 1e-3). 
         Overrides the corresponding module option
     @type logger: logging.Logger
     @param logger: a logger to send progress logging to
     @type processes: int
     @param processes: number of processes to spawn. A pool of this 
         many processes will be used to compute distribution updates 
         for sequences in parallel during each iteration.
     @type save: bool
     @param save: save the model at the end of training
     @type save_intermediate: bool
     @param save_intermediate: save the model after each iteration. Implies 
         C{save}
     
     """
     from . import raphsto_d
     if logger is None:
         from jazzparser.utils.loggers import create_dummy_logger
         logger = create_dummy_logger()
     
     if save_intermediate:
         save = True
         
     # No point in creating more processes than there are sequences
     if processes > len(emissions):
         processes = len(emissions)
     
     self.model.add_history("Beginning Baum-Welch unigram training on %s" % get_host_info_string())
     self.model.add_history("Training on %d sequences (with %s chords)" % \
         (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
     
     # Use kwargs if given, otherwise module options
     if max_iterations is None:
         max_iterations = self.options['max_iterations']
     if convergence_logprob is None:
         convergence_logprob = self.options['convergence_logprob']
     
     # Enumerate the states
     state_ids = dict((state,num) for (num,state) in \
                                 enumerate(self.model.label_dom))
     
     # Enumerate the beat values (they're probably consecutive ints, but 
     #  let's not rely on it)
     beat_ids = dict((beat,num) for (num,beat) in \
                                 enumerate(self.model.beat_dom))
     num_beats = len(beat_ids)
     # Enumerate the d-values (d-function's domain)
     d_ids = dict((d,num) for (num,d) in \
                                 enumerate(self.model.emission_dist_dom))
     num_ds = len(d_ids)
     
     # Make a mutable distribution for the emission distribution we'll 
     #  be updating
     emission_mdist = DictionaryConditionalProbDist(
                 dict((s, MutableProbDist(self.model.emission_dist[s], 
                                          self.model.emission_dist_dom))
                     for s in self.model.emission_dist.conditions()))
     # Create dummy distributions to fill the places of the transition 
     #  distribution components
     key_mdist = DictionaryConditionalProbDist({})
     chord_mdist = DictionaryConditionalProbDist({})
     chord_uni_mdist = MutableProbDist({}, [])
     
     # Construct a model using these mutable distributions so we can 
     #  evaluate using them
     model = self.model_cls(key_mdist, 
                            chord_mdist,
                            emission_mdist, 
                            chord_uni_mdist,
                            chord_set=self.model.chord_set)
     
     iteration = 0
     last_logprob = None
     while iteration < max_iterations:
         logger.info("Beginning iteration %d" % iteration)
         current_logprob = 0.0
         
         # ems contains the new emission numerator probabilities
         # ems[r][d] = Sum_{d(y_n^k, x_n)=d, r_n^k=r}
         #                  alpha(x_n).beta(x_n) / 
         #                    Sum_{x'_n} (alpha(x'_n).beta(x'_n))
         ems = zeros((num_beats,num_ds), float64)
         # And these are the denominators
         ems_denom = zeros(num_beats, float64)
         
         def _training_callback(result):
             """
             Callback for the _sequence_updates processes that takes 
             the updates from a single sequence and adds them onto 
             the global update accumulators.
             
             """
             # _sequence_updates() returns all of this as a tuple
             (ems_local, ems_denom_local, seq_logprob) = result
             
             # Add these probabilities from this sequence to the 
             #  global matrices
             # Emission numerator
             array_add(ems, ems_local, ems)
             # Denominators
             array_add(ems_denom, ems_denom_local, ems_denom)
         ## End of _training_callback
         
         
         # Only use a process pool if there's more than one sequence
         if processes > 1:
             # Create a process pool to use for training
             logger.info("Creating a pool of %d processes" % processes)
             pool = Pool(processes=processes)
             
             async_results = []
             for seq_i,sequence in enumerate(emissions):
                 logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                 T = len(sequence)
                 if T == 0:
                     continue
                 
                 # Fire off a new call to the process pool for every sequence
                 async_results.append(
                         pool.apply_async(_sequence_updates_uni, 
                                             (sequence, model, 
                                                 self.model.label_dom, 
                                                 state_ids, 
                                                 beat_ids, d_ids, raphsto_d), 
                                             callback=_training_callback) )
             pool.close()
             # Wait for all the workers to complete
             pool.join()
             
             # Call get() on every AsyncResult so that any exceptions in 
             #  workers get raised
             for res in async_results:
                  # If there was an exception in _sequence_updates_uni, it 
                 #  will get raised here
                 res_tuple = res.get()
                 # Add this sequence's logprob into the total for all sequences
                 current_logprob += res_tuple[2]
          else:
              if len(emissions) == 1:
                  logger.info("One sequence: not using a process pool")
              else:
                  logger.info("Not using a process pool: training %d "\
                      "sequences sequentially" % len(emissions))
              
              for seq_i,sequence in enumerate(emissions):
                  if len(sequence) > 0:
                      logger.info("Iteration %d, sequence %d" % (iteration, seq_i))
                      updates = _sequence_updates_uni(
                                          sequence, model,
                                          self.model.label_dom,
                                          state_ids, 
                                          beat_ids, d_ids, raphsto_d)
                      _training_callback(updates)
                      # Update the overall logprob
                      current_logprob += updates[2]
         
         # Update the model's probabilities from the accumulated values
         for beat in self.model.beat_dom:
             denom = ems_denom[beat_ids[beat]]
             for d in self.model.emission_dist_dom:
                 if denom == 0.0:
                     # Zero denominator
                     prob = - logprob(len(d_ids))
                 else:
                     prob = logprob(ems[beat_ids[beat]][d_ids[d]] + ADD_SMALL) - logprob(denom + len(d_ids)*ADD_SMALL)
                 model.emission_dist[beat].update(d, prob)
         
         # Clear the model's cache so we get the new probabilities
         model.clear_cache()
         
         logger.info("Training data log prob: %s" % current_logprob)
         if last_logprob is not None and current_logprob < last_logprob:
             logger.error("Log probability dropped by %s" % \
                             (last_logprob - current_logprob))
         if last_logprob is not None:
             logger.info("Log prob change: %s" % \
                             (current_logprob - last_logprob))
         # Check whether the log probability has converged
         if iteration > 0 and \
                 abs(current_logprob - last_logprob) < convergence_logprob:
             # Don't iterate any more
             logger.info("Distribution has converged: ceasing training")
             break
         
         iteration += 1
         last_logprob = current_logprob
         
         # Update the main model
         # Only save if we've been asked to save between iterations
         self.update_model(model, save=save_intermediate)
     
     self.model.add_history("Completed Baum-Welch unigram training")
     # Update the distribution's parameters with those we've trained
     self.update_model(model, save=save)
     return
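
A minimal standalone sketch of the per-beat emission update performed at the end of each iteration above: the accumulated expected counts (ems, ems_denom) are turned into add-small smoothed probabilities, stored in log space, with a uniform fallback for beats that received no mass. The smoothing constant, the log base and the toy counts below are illustrative assumptions rather than jazzparser's actual values, and the real code writes the results into a MutableProbDist instead of returning an array.

    import numpy as np

    ADD_SMALL = 1e-10   # assumed smoothing constant; the real value is defined elsewhere

    def smoothed_log_emission_dist(ems, ems_denom):
        """Turn accumulated expected counts into smoothed log probabilities.

        ems[r][d]    -- expected count of d-value d under beat r
        ems_denom[r] -- total expected count for beat r
        """
        num_beats, num_ds = ems.shape
        log_dist = np.empty_like(ems)
        for r in range(num_beats):
            if ems_denom[r] == 0.0:
                # Zero denominator: fall back to a uniform distribution
                log_dist[r, :] = -np.log(num_ds)
            else:
                # Add-small smoothing of the counts, converted to log space
                log_dist[r, :] = (np.log(ems[r, :] + ADD_SMALL)
                                  - np.log(ems_denom[r] + num_ds * ADD_SMALL))
        return log_dist

    # Toy usage: two beats, three d-values; the second beat received no counts
    counts = np.array([[4.0, 1.0, 0.0],
                       [0.0, 0.0, 0.0]])
    print(smoothed_log_emission_dist(counts, counts.sum(axis=1)))
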
Code Example #2
File: train.py    Project: johndpope/jazzparser
 def train(self, emissions, logger=None, save_callback=None):
     """
     Performs unsupervised training using Baum-Welch EM.
     
     This is performed on a model that has already been initialized. 
     You might, for example, create such a model using 
     L{jazzparser.taggers.segmidi.chordclass.hmm.ChordClassHmm.initialize_chord_classes}.
     
     This is based on the training procedure in NLTK for HMMs:
     C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
     
     @type emissions: L{jazzparser.data.input.MidiTaggerTrainingBulkInput} or 
         list of L{jazzparser.data.input.Input}s
     @param emissions: training MIDI data
      @type logger: logging.Logger
      @param logger: a logger to send progress logging to
      @type save_callback: callable
      @param save_callback: called (with no arguments) to write the model 
          to disk. If not given, the model cannot be saved and an error is 
          logged instead
     
     """
     if logger is None:
         from jazzparser.utils.loggers import create_dummy_logger
         logger = create_dummy_logger()
         
     self.model.add_history("Beginning Baum-Welch training on %s" % get_host_info_string())
     self.model.add_history("Training on %d MIDI sequences (with %s segments)" % \
         (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
     logger.info("Beginning Baum-Welch training on %s" % get_host_info_string())
     
     # Get some options out of the module options
     max_iterations = self.options['max_iterations']
     convergence_logprob = self.options['convergence_logprob']
     split_length = self.options['split']
     truncate_length = self.options['truncate']
     save_intermediate = self.options['save_intermediate']
     processes = self.options['trainprocs']
     
     # Make a mutable distribution for each of the distributions 
     #  we'll be updating
     emission_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                                 self.model.emission_dist, mutable=True)
     schema_trans_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                                 self.model.schema_transition_dist, mutable=True)
     root_trans_mdist = cond_prob_dist_to_dictionary_cond_prob_dist(
                                 self.model.root_transition_dist, mutable=True)
     init_state_mdist = prob_dist_to_dictionary_prob_dist(
                                 self.model.initial_state_dist, mutable=True)
     
     # Get the sizes we'll need for the matrices
     num_schemata = len(self.model.schemata)
     num_root_changes = 12
     num_chord_classes = len(self.model.chord_classes)
     if self.model.metric:
         num_emission_conds = num_chord_classes * 4
     else:
         num_emission_conds = num_chord_classes
     num_emissions = 12
     
     # Enumerations to use for the matrices, so we know what they mean
     schema_ids = dict([(sch,i) for (i,sch) in enumerate(self.model.schemata+[None])])
     if self.model.metric:
         rs = range(4)
     else:
         rs = [0]
     emission_cond_ids = dict([(cc,i) for (i,cc) in enumerate(\
             sum([[
                 (str(cclass.name),r) for r in rs] for cclass in self.model.chord_classes], 
             []))])
     
     # Construct a model using these mutable distributions so we can 
     #  evaluate using them
     model = ChordClassHmm(schema_trans_mdist, 
                        root_trans_mdist, 
                        emission_mdist, 
                        self.model.emission_number_dist, 
                        init_state_mdist, 
                        self.model.schemata, 
                        self.model.chord_class_mapping,
                        self.model.chord_classes, 
                        metric=self.model.metric,
                        illegal_transitions=self.model.illegal_transitions,
                        fixed_root_transitions=self.model.fixed_root_transitions)
     
     def _save():
         if save_callback is None:
             logger.error("Could not save model, as no callback was given")
         else:
             # If the writing fails, wait till I've had a chance to sort it 
             #  out and then try again. This happens when my AFS token runs 
             #  out
             while True:
                 try:
                     save_callback()
                 except (IOError, OSError), err:
                     print "Error writing model to disk: %s. " % err
                     raw_input("Press <enter> to try again... ")
                 else:
                     break
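
The emission_cond_ids dictionary built above packs a nested comprehension into a few lines. The sketch below unpacks the same construction: it maps (chord class name, metrical position r) pairs to consecutive integer indices, using four metrical positions when the model is metric and a single dummy position otherwise. The chord class names here are made-up placeholders, not jazzparser's real chord classes.

    # Hypothetical chord class names, standing in for self.model.chord_classes
    chord_class_names = ["maj", "min", "dom7"]
    metric = True

    rs = range(4) if metric else [0]

    emission_cond_ids = {}
    for cc_name in chord_class_names:       # outer loop: chord classes
        for r in rs:                        # inner loop: metrical positions
            emission_cond_ids[(cc_name, r)] = len(emission_cond_ids)

    # emission_cond_ids == {('maj', 0): 0, ('maj', 1): 1, ..., ('dom7', 3): 11}
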
Code Example #3
File: baumwelch.py    Project: johndpope/jazzparser
    def train(self, emissions, logger=None):
        """
        Performs unsupervised training using Baum-Welch EM.
        
        This is performed as a retraining step on a model that has already 
        been initialized. 
        
        This is based on the training procedure in NLTK for HMMs:
        C{nltk.tag.hmm.HiddenMarkovModelTrainer.train_unsupervised}.
        
        @type emissions: list of lists of emissions
        @param emissions: training data. Each element is a list of 
            emissions representing a sequence in the training data.
            Each emission is an emission like those used for 
            C{emission_log_probability} on the model
        @type logger: logging.Logger
        @param logger: a logger to send progress logging to
        
        """
        if logger is None:
            from jazzparser.utils.loggers import create_dummy_logger
            logger = create_dummy_logger()

        self.record_history("Beginning Baum-Welch training on %s" %
                            get_host_info_string())
        self.record_history("Training on %d inputs (with %s segments)" % \
            (len(emissions), ", ".join("%d" % len(seq) for seq in emissions)))
        logger.info("Beginning Baum-Welch training on %s" %
                    get_host_info_string())

        # Get some options out of the module options
        max_iterations = self.options['max_iterations']
        convergence_logprob = self.options['convergence_logprob']
        split_length = self.options['split']
        truncate_length = self.options['truncate']
        save_intermediate = self.options['save_intermediate']
        processes = self.options['trainprocs']

        # Make a mutable version of the model that we can update each iteration
        self.model = self.create_mutable_model(self.model)
        # Getting the array id mappings
        array_ids = self.get_array_indices()

        ########## Data preprocessing
        logger.info("%d input sequences" % len(emissions))
        # Truncate long streams
        if truncate_length is not None:
            logger.info("Truncating sequences to max %d timesteps" % \
                                                            truncate_length)
            emissions = [stream[:truncate_length] for stream in emissions]
        # Split up long streams if requested
        # After this, each stream is a tuple (first,stream), where first
        #  indicates whether the stream segment begins a song
        if split_length is not None:
            logger.info("Splitting sequences into max %d-sized chunks" % \
                                                                split_length)
            split_emissions = []
            # Split each stream
            for emstream in emissions:
                input_ems = list(emstream)
                splits = []
                first = True
                # Take bits of length split_length until we're under the max
                while len(input_ems) >= split_length:
                    # Overlap the splits by one so we get all transitions
                    splits.append((first, input_ems[:split_length]))
                    input_ems = input_ems[split_length - 1:]
                    first = False
                # Get the last short one
                if len(input_ems):
                    # Try to avoid having a small bit that's split off at the end
                    if len(splits) and len(input_ems) <= split_length / 5:
                        # Add these to the end of the last split
                        # This will make it slightly longer than requested
                        splits[-1][1].extend(input_ems)
                    else:
                        splits.append((first, input_ems))
                split_emissions.extend(splits)
        else:
            # All streams begin a song
            split_emissions = [(True, stream) for stream in emissions]
        logger.info("Sequence lengths after preprocessing: %s" %
                    " ".join([str(len(em[1])) for em in split_emissions]))
        ##########

        # A process count of -1 means one process per sequence
        # No point in creating more processes than there are sequences
        if processes == -1 or processes > len(split_emissions):
            processes = len(split_emissions)

        iteration = 0
        last_logprob = None
        while iteration < max_iterations:
            logger.info("Beginning iteration %d" % iteration)
            current_logprob = 0.0

            # Build a tuple of the arrays that will be updated by each sequence
            self.global_arrays = self.get_empty_arrays()

            # Only use a process pool if there's more than one sequence
            if processes > 1:
                # Create a process pool to use for training
                logger.info("Creating a pool of %d processes" % processes)
                # If the workers raise exceptions, they propagate back to this
                #  process: catch them at this level
                pool = Pool(processes=processes)

                async_results = []
                try:
                    for seq_i, (first, sequence) in enumerate(split_emissions):
                        logger.info("Iteration %d, sequence %d" %
                                    (iteration, seq_i))
                        T = len(sequence)
                        if T == 0:
                            continue

                        def _notifier_closure(seq_index):
                            def _notifier(res):
                                logger.info("Sequence %d finished" % seq_index)

                            return _notifier

                        # Create some empty arrays for the updates to go into
                        empty_arrays = self.get_empty_arrays()
                        # Fire off a new call to the process pool for every sequence
                        async_results.append(
                            pool.apply_async(self.sequence_updates,
                                             (sequence, self.model,
                                              empty_arrays, array_ids),
                                             {'update_initial': first},
                                             _notifier_closure(seq_i)))
                    pool.close()
                    # Wait for all the workers to complete
                    pool.join()
                except KeyboardInterrupt:
                    # If Ctrl+C is fired during the processing, we exit here
                    logger.info("Keyboard interrupt was received during EM "\
                        "updates")
                    raise

                # Call get() on every AsyncResult so that any exceptions in
                #  workers get raised
                for res in async_results:
                    # If there was an exception in sequence_updates, it
                    #  will get raised here
                    res_tuple = res.get()
                    # Run the callback on the results from this process
                    # It might seem sensible to do this using the callback
                    #  arg to apply_async, but then the callback must be
                    #  picklable and it doesn't buy us anything really
                    self.sequence_updates_callback(res_tuple)
                    # Add this sequence's logprob into the total for all sequences
                    current_logprob += res_tuple[-1]
            else:
                if len(split_emissions) == 1:
                    logger.info("One sequence: not using a process pool")
                else:
                    logger.info("Not using a process pool: training %d "\
                        "emission sequences sequentially" % \
                        len(split_emissions))

                for seq_i, (first, sequence) in enumerate(split_emissions):
                    if len(sequence) > 0:
                        logger.info("Iteration %d, sequence %d" %
                                    (iteration, seq_i))
                        # Create some empty arrays for the updates to go into
                        empty_arrays = self.get_empty_arrays()
                        updates = self.sequence_updates(sequence,
                                                        self.model,
                                                        empty_arrays,
                                                        array_ids,
                                                        update_initial=first)
                        self.sequence_updates_callback(updates)
                        # Update the overall logprob
                        current_logprob += updates[-1]

            ######## Model updates
            # Update the main model
            self.update_model(self.global_arrays, array_ids)

            # Clear the model's cache so we get the new probabilities
            self.model.clear_cache()

            logger.info("Training data log prob: %s" % current_logprob)
            if last_logprob is not None and current_logprob < last_logprob:
                logger.error("Log probability dropped by %s" % \
                                (last_logprob - current_logprob))
            if last_logprob is not None:
                logger.info("Log prob change: %s" % \
                                (current_logprob - last_logprob))
            # Check whether the log probability has converged
            if iteration > 0 and \
                    abs(current_logprob - last_logprob) < convergence_logprob:
                # Don't iterate any more
                logger.info("Distribution has converged: ceasing training")
                break

            iteration += 1
            last_logprob = current_logprob

            # Only save if we've been asked to save between iterations
            if save_intermediate:
                self.save()

        self.record_history("Completed Baum-Welch training")
        # Always save the model now that we're done
        self.save()
        return self.model
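
The preprocessing step above splits each long emission stream into chunks of at most split_length, overlapping consecutive chunks by one element so that no transition is lost, tagging each chunk with whether it begins a song, and folding a very short tail back into the previous chunk. The following self-contained sketch re-expresses that logic as a standalone function for illustration; the helper name split_stream is made up, and integer division is written explicitly as // (the original relies on Python 2's / for ints).

    def split_stream(stream, split_length):
        """Split one emission stream into (first, chunk) pairs.

        first is True only for the chunk that begins the stream. Chunks
        overlap by one element; a short tail (at most split_length // 5
        elements) is appended to the previous chunk rather than becoming
        a chunk of its own.
        """
        items = list(stream)
        splits = []
        first = True
        while len(items) >= split_length:
            splits.append((first, items[:split_length]))
            # Overlap by one element so every transition appears in some chunk
            items = items[split_length - 1:]
            first = False
        if items:
            if splits and len(items) <= split_length // 5:
                # Fold a small remainder into the last chunk
                splits[-1][1].extend(items)
            else:
                splits.append((first, items))
        return splits

    # Toy usage: a 10-element stream split into chunks of at most 4
    print(split_stream(range(10), 4))
    # -> [(True, [0, 1, 2, 3]), (False, [3, 4, 5, 6]),
    #     (False, [6, 7, 8, 9]), (False, [9])]
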