示例#1
0
    def __init__(self, training_options, optimization_options, network,
                 vocabulary, scorer, training_files, sampling, validation_iter,
                 state, profile=False):
        """Builds the optimizer, the training batch iterator, and the stopping
        criterion, and then either restores the training state from ``state``
        or starts from a fresh state.

        :type training_options: dict
        :param training_options: training options; ``batch_size`` and
                                 ``sequence_length`` are read here, and the
                                 dictionary itself is stored in
                                 ``self.options``

        :type optimization_options: dict
        :param optimization_options: optimization options, passed on to
                                     ``create_optimizer()``

        :type network: Network
        :param network: the neural network that will be trained

        :type vocabulary: Vocabulary
        :param vocabulary: mapping between words and word IDs; also handed to
                           the training batch iterator

        :type scorer: TextScorer
        :param scorer: text scorer used for computing validation set
                       perplexity

        :type training_files: list of file objects
        :param training_files: files that contain the training data

        :type sampling: list of floats
        :param sampling: for each training file, the fraction to sample on
                         each epoch

        :type validation_iter: theanolm.BatchIterator
        :param validation_iter: iterator used for computing validation set
                                perplexity

        :type state: h5py.File
        :param state: HDF5 file from which the initial training state is read
                      if it contains a ``trainer`` entry, and where candidate
                      states will be saved

        :type profile: bool
        :param profile: if True, Theano profile objects are created

        :raises ValueError: if the training iterator reports less than one
                            update per epoch
        """

        self.network = network
        self.vocabulary = vocabulary
        self.scorer = scorer
        self.validation_iter = validation_iter

        self.optimizer = create_optimizer(
            optimization_options, self.network, profile)

        batch_size = training_options['batch_size']
        max_sequence_length = training_options['sequence_length']
        self.training_iter = ShufflingBatchIterator(
            training_files, sampling, vocabulary,
            batch_size=batch_size,
            max_sequence_length=max_sequence_length)

        print("Computing the number of training updates per epoch.")
        sys.stdout.flush()
        updates_per_epoch = len(self.training_iter)
        self.updates_per_epoch = updates_per_epoch
        if updates_per_epoch < 1:
            raise ValueError("Training data does not contain any sentences.")

        # The stopper receives this (still partially initialized) trainer, so
        # it has to be created only after the attributes above are in place.
        self.stopper = create_stopper(training_options, self)
        self.options = training_options

        # The state file is the current candidate for the minimum validation
        # cost state.
        self._candidate_state = state
        if 'trainer' not in state:
            # Nothing to restore; start from a fresh training state.
            # Index to the cost history that corresponds to the current
            # candidate state.
            self._candidate_index = None
            # Current training epoch.
            self.epoch_number = 1
            # Number of mini-batch updates performed in this epoch.
            self.update_number = 0
            # Validation set cost history.
            self._cost_history = numpy.asarray([], dtype=theano.config.floatX)
        else:
            # A saved trainer entry exists. _reset_state() is defined outside
            # this view; presumably it loads the fields that the other branch
            # initializes - TODO confirm.
            message = "Restoring initial network state from {}."
            print(message.format(state.filename))
            sys.stdout.flush()
            self._reset_state()

        # Number of mini-batch updates between log messages.
        self.log_update_interval = 0
        # Total number of mini-batch updates performed (after restart).
        self.total_updates = 0