Exemplo n.º 1
0
    def build_from_document_corpus(corpus, model_type, model_name,
                                   progress=False, project_events=False, include_events=False, hash_size=50,
                                   log=None, redis_port=6379, filter_chains=None):
        if log is None:
            log = get_console_logger("neighbour indexing")

        log.info("Loading model %s/%s" % (model_type, model_name))
        model = NarrativeChainModel.load_by_type(model_type, model_name)
        vector_size = model.vector_size

        db_filename = "vectors.rdb"
        # Make sure the model directory exists, so we can get the Redis server pointing there
        model_dir = model.get_model_directory(model_name)
        # If the Redis stored db already exists, remove it, so that we don't end up adding to old data
        if os.path.exists(os.path.join(model_dir, db_filename)):
            os.remove(os.path.join(model_dir, db_filename))
        log.info("Storing vectors in %s" % os.path.join(model_dir, db_filename))

        log.info("Preparing neighbour search hash")
        # Create binary hash
        binary_hash = RandomBinaryProjections("%s:%s_binary_hash" % (model_type, model_name), hash_size)

        log.info("Connecting to Redis server on port %d" % redis_port)
        # Prepare an engine for storing the vectors in
        try:
            redis = Redis(host='localhost', port=redis_port, db=0)
        except ConnectionError, e:
            raise RuntimeError("could not connect to redis server on port %s. Is it running? (%s)" % (redis_port, e))
Exemplo n.º 2
0
    def train_from_cmd_line(self):
        """
        Use Argparse to process command line args, using the model class' own parameters and
        begin model training.

        """
        parser = argparse.ArgumentParser(description="Train a %s model" %
                                         self.model_cls.MODEL_TYPE_NAME)
        # Add some standard options that all trainers use
        parser.add_argument("model_type", help="Type of model to train")
        data_grp = parser.add_argument_group("Event chain data")
        data_grp.add_argument(
            "corpus_dir",
            help="Directory to read in chains from for training data")
        data_grp.add_argument("model_name",
                              help="Name under which to store the model")
        # Add the model type's arguments
        self.prepare_arguments(parser)

        # Parse cmd line args
        opts = parser.parse_args()

        log = get_console_logger("%s train" % self.model_cls.MODEL_TYPE_NAME)

        # Load event chain data
        tarred = detect_tarred_corpus(opts.corpus_dir)
        if tarred:
            log.info("Loading tarred dataset")
        else:
            log.info("Loading raw (untarred) dataset")
        corpus = RichEventDocumentCorpus(opts.corpus_dir, tarred=tarred)

        log.info("Counting corpus size")
        num_docs = len(corpus)
        if num_docs == 0:
            log.error("No documents in corpus")
            sys.exit(1)

        log.info("Training model '%s' on %d documents" %
                 (opts.model_name, num_docs))
        try:
            self.train(opts.model_name, corpus, log, opts)
        except ModelTrainingError, e:
            log.error("Error training model: %s" % e)
Exemplo n.º 3
0
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>
#
from __future__ import absolute_import
import tarfile
import os
import codecs

from whim_common.utils.logging import get_console_logger
from whim_common.utils.progress import get_progress_bar


log = get_console_logger("Tar indexer")


class IndexedTarfile(object):
    def __init__(self, tar_filename, index_filename):
        self.index_filename = index_filename
        self.tar_filename = tar_filename

        self._cache = {}

    def clear_cache(self):
        self._cache = {}

    def fill_cache(self, paths):
        """
        To speed up loading lots of files from the same archive, fill the cache with a whole list of
    def train(self,
              xs,
              ys,
              iterations=10000,
              iteration_callback=None,
              learning_rate=None,
              regularization=None,
              batch_size=20,
              batch_callback=None,
              validation_set=None,
              stopping_iterations=10,
              log=None,
              class_weights=None,
              cost_plot_filename=None,
              training_cost_prop_change_threshold=None,
              undersample=None,
              print_predictions=False):
        """
        Train on data stored in Theano tensors. Uses minibatch training.

        E.g.
        xs = rng.randn(N, num_features)
        ys = rng.randint(size=N, low=0, high=2)

        iteration_callback is called after each iteration with args (iteration, error array).

        If a validation set (xs, ys) is given, it is used to compute an error after each iteration
        and to enforce a stopping criterion. The algorithm will terminate if it goes stopping_iterations
        iterations without an improvement in validation error.

        Updates for each target class can be weighted by giving a vector class_weights. Alternatively,
        give the string 'freq' to weight them by inverse class frequency, or leave as None to apply
        no weighting.

        If compute_error_frequency > 1 (default=5), this number of iterations are performed between each time
        the error is computed on the training set.

        The algorithm will assume it has converged and stop early if the proportional change between successive
        training costs drops below training_cost_prop_change_threshold for five iterations in a row.
        If threshold is given as None, this stopping condition will not be used.

        If undersample is given it should be a float. The training data will be randomly undersampled to produce
        a set in which the expected number of instances of each class is undersample*min_freq, where min_freq
        is the number of instances of the least common observed class. A value of 1.0 will produce a roughly
        balanced set. Every class that is observed at all will be included at least once. The sampling is
        performed once at the beginning of training.

        """
        if log is None:
            log = get_console_logger("MLP train")

        if cost_plot_filename is not None:
            _fname, __, _ext = cost_plot_filename.rpartition(".")
            balanced_cost_plot_filename = "%s_balanced.%s" % (_fname, _ext)
            log.info("Outputting balanced costs to: %s" %
                     balanced_cost_plot_filename)
        else:
            balanced_cost_plot_filename = None

        kwargs = {}
        cost_kwargs = {
            "reg_coef": 0.,  # Always compute the cost without regularization
        }
        if learning_rate is not None:
            kwargs["learning_rate"] = learning_rate
        if regularization is not None:
            kwargs["reg_coef"] = regularization
        log.info("Training params: learning rate=%s, reg coef=%s" %
                 (learning_rate, regularization))
        log.info("Training with %s, batch size=%d" %
                 (self.optimization, batch_size))
        if undersample is not None and undersample > 0.0:
            log.info("Undersampling the dataset with a ratio of %s" %
                     undersample)

        # Work out how many batches to do
        if batch_size is None or batch_size == 0:
            num_batches = 1
        else:
            num_batches = xs.shape[0] / batch_size
            if xs.shape[0] % batch_size != 0:
                num_batches += 1

        if undersample is not None and undersample > 0.0:
            # Undersample the training data to produce a (more) balanced set
            balanced_indices = balanced_array_sample(ys,
                                                     balance_ratio=undersample,
                                                     min_inclusion=1)
            # Copy the data so we're not dealing with a view
            xs = numpy.copy(xs[balanced_indices])
            ys = numpy.copy(ys[balanced_indices])
            # Also sample the validation set similarly
            balanced_validation_indices = balanced_array_sample(
                validation_set[1], balance_ratio=undersample, min_inclusion=1)
            validation_set = (
                numpy.copy(validation_set[0][balanced_validation_indices]),
                numpy.copy(validation_set[1][balanced_validation_indices]))
            log.info("Sampled %d training and %d validation instances" %
                     (xs.shape[0], validation_set[0].shape[0]))

        # Work out class weighting
        # Do this after undersampling: if both are used, we only want the weights to account for any imbalance
        #  left after undersampling
        if class_weights is not None:
            if class_weights == 'freq':
                # Use inverse frequency to weight class updates
                # This procedure is modelled directly on what liblinear does
                class_counts = self.get_class_counts(ys).astype(numpy.float64)
                # Replace zero-counts with 1s
                class_counts = numpy.maximum(class_counts, 1.0)
                class_weights = 1.0 / class_counts
                class_weights *= self.network.num_classes / class_weights.sum()
                log.info("Inverse-frequency class weighting")
            elif class_weights == 'log':
                # Use a different scheme, inversely proportional to the log of the class frequencies
                class_counts = self.get_class_counts(ys).astype(numpy.float64)
                class_counts = numpy.maximum(class_counts, 1.0)
                class_weights = 1.0 / (1.0 + numpy.log(class_counts))
                class_weights *= self.network.num_classes / class_weights.sum()
                log.info("Log-inverse-frequency class weighting")
            else:
                log.info("Custom vector class weighting")
            kwargs["class_weights"] = class_weights
            cost_kwargs["class_weights"] = class_weights
        else:
            log.info("No class weighting")

        # Keep a record of costs, so we can plot them
        val_costs = []
        training_costs = []
        # The costs using the balanced metric
        bal_val_costs = []
        bal_training_costs = []

        # Compute costs using the initialized network
        training_cost = self.compute_cost(xs, ys, **cost_kwargs)
        training_costs.append(training_cost)
        if validation_set is not None:
            val_cost = self.compute_cost(validation_set[0], validation_set[1],
                                         **cost_kwargs)
            val_costs.append(val_cost)
        else:
            val_cost = None

        log.info("Computing initial validation set metrics:")
        class_accuracies = self.network.per_class_accuracy(
            validation_set[0], validation_set[1])
        class_accuracies = class_accuracies[numpy.where(
            numpy.logical_not(numpy.isnan(class_accuracies)))]
        mean_class_accuracy = class_accuracies.mean()
        log.info("Per-class accuracy: %.4f%% (mean over %d classes)" %
                 (mean_class_accuracy, class_accuracies.shape[0]))
        # Also compute mean log prob of targets over val set
        mean_log_prob = self.network.mean_log_prob(validation_set[0],
                                                   validation_set[1])
        log.info("Mean target log prob: %.4f" % mean_log_prob)
        mean_per_class_log_prob = self.network.mean_per_class_target_log_prob(
            validation_set[0], validation_set[1])
        log.info("Mean per-class mean target log prob: %.4f" %
                 mean_per_class_log_prob)

        # Keep a copy of the best weights so far
        best_weights = best_iter = best_val_cost = None
        if validation_set is not None:
            best_weights = self.network.get_weights()
            best_iter = -1
            best_val_cost = val_cost

        below_threshold_its = 0

        # Count the instances we're learning from to give an idea of how hard a time the model's got
        training_class_counts = numpy.bincount(ys)
        training_class_counts = training_class_counts[
            training_class_counts.nonzero()]
        log.info(
            "Training instances per class: min=%d, max=%d (%d unseen classes)"
            % (int(
                training_class_counts.min()), int(training_class_counts.max()),
               self.network.num_classes - training_class_counts.shape[0]))

        for i in range(iterations):
            # Shuffle the training data between iterations, as one should with SGD
            shuffle = numpy.random.permutation(xs.shape[0])
            xs[:] = xs[shuffle]
            ys[:] = ys[shuffle]

            err = 0.0
            if num_batches > 1:
                for batch in range(num_batches):
                    # Update the model with this batch's data
                    batch_err = self._train_fn(
                        xs[batch * batch_size:(batch + 1) * batch_size],
                        ys[batch * batch_size:(batch + 1) * batch_size],
                        **kwargs)
                    err += batch_err

                    if batch_callback is not None:
                        batch_callback(batch, num_batches, batch_err)
            else:
                # Batch training: no need to loop
                err = self._train_fn(xs, ys, **kwargs)

            # Go back and compute training cost
            training_cost = self.compute_cost(xs, ys, **cost_kwargs)
            training_costs.append(training_cost)
            # Training set error
            train_error = self.network.error(xs, ys)
            bal_training_costs.append(
                -self.network.mean_per_class_target_log_prob(xs, ys))

            if validation_set is not None:
                if print_predictions:
                    # Perform some predictions on a random sample of the val set
                    for randind in numpy.random.randint(
                            validation_set[0].shape[0], size=5):
                        # Get the network's predictions
                        predictions = self.network.predict(
                            validation_set[0][None, randind, :])
                        predictions = predictions[0, None]
                        log.info("Input: %s. Predictions: %s" %
                                 (list(
                                     numpy.where(
                                         validation_set[0][randind] > 0)[0]),
                                  list(predictions)))
                # Compute the cost function on the validation set
                val_cost = self.compute_cost(validation_set[0],
                                             validation_set[1], **cost_kwargs)
                val_costs.append(val_cost)
                if val_cost <= best_val_cost:
                    # We assume that, if the validation error remains the same, it's better to use the new set of
                    # weights (with, presumably, a better training error)
                    # Update our best estimate
                    best_weights = self.network.get_weights()
                    best_iter = i
                    best_val_cost = val_cost

                if i - best_iter >= stopping_iterations:
                    # We've gone on long enough without improving validation error
                    # Time to call a halt and use the best validation error we got
                    log.info(
                        "Stopping after %d iterations without improving validation cost"
                        % stopping_iterations)
                    break

                # Compute various metrics
                # Per-class accuracy on val set
                class_accuracies = self.network.per_class_accuracy(
                    validation_set[0], validation_set[1])
                class_accuracies = class_accuracies[numpy.where(
                    numpy.logical_not(numpy.isnan(class_accuracies)))]
                mean_class_accuracy = class_accuracies.mean()
                # Mean log prob of targets over val set
                mean_log_prob = self.network.mean_log_prob(
                    validation_set[0], validation_set[1])
                mean_per_class_log_prob = self.network.mean_per_class_target_log_prob(
                    validation_set[0], validation_set[1])
                log.info(
                    "Completed iteration %d, training cost=%.5f, val cost=%.5f, training error=%.2f%%, "
                    "per-class accuracy: %.4f%%, mean tgt logprob: %.4f, per-class tgt logprob: %.4f"
                    % (i, training_cost, val_cost, train_error * 100.0,
                       mean_class_accuracy, mean_log_prob,
                       mean_per_class_log_prob))
                bal_val_costs.append(-mean_per_class_log_prob)

                if best_iter < i:
                    log.info("No improvement in validation cost")
            else:
                log.info(
                    "Completed iteration %d, training cost=%.5f, training error=%.2f%%"
                    % (i, training_cost, train_error * 100.0))

            if cost_plot_filename:
                # Plot the cost function as we train
                columns = [(training_costs, "Train cost")]
                if validation_set is not None:
                    columns.append((val_costs, "Val cost"))
                ax = plot_costs(None, *columns)
                # Add a line at the most recent best val cost
                ax.axvline(float(best_iter + 1), color="b")
                ax.text(float(best_iter + 1) + 0.1,
                        best_val_cost * 1.1,
                        "Best val cost",
                        color="b")
                plt.savefig(cost_plot_filename)

                bal_columns = [(bal_training_costs, "Train cost (balanced)")]
                if validation_set is not None:
                    bal_columns.append((bal_val_costs, "Val cost (balanced)"))
                plot_costs(balanced_cost_plot_filename, *bal_columns)

            if iteration_callback is not None:
                iteration_callback(i, training_cost, val_cost, train_error,
                                   best_iter)

            # Check the proportional change between this iteration's training cost and the last
            if len(training_costs
                   ) > 2 and training_cost_prop_change_threshold is not None:
                training_cost_prop_change = abs(
                    (training_costs[-2] - training_costs[-1]) /
                    training_costs[-2])
                if training_cost_prop_change < training_cost_prop_change_threshold:
                    # Very small change in training cost - maybe we've converged
                    below_threshold_its += 1
                    if below_threshold_its >= 5:
                        # We've had enough iterations with very small changes: we've converged
                        log.info(
                            "Proportional change in training cost (%g) below %g for five successive iterations: "
                            "converged" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold))
                        break
                    else:
                        log.info(
                            "Proportional change in training cost (%g) below %g for %d successive iterations: "
                            "waiting until it's been low for five iterations" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold,
                             below_threshold_its))
                else:
                    # Reset the below threshold counter
                    below_threshold_its = 0

        if best_weights is not None:
            # Use the weights that gave us the best error on the validation set
            self.network.set_weights(best_weights)
    parser.add_argument("corpus_dir", help="Path to rich event chain corpus to project into vector space")
    parser.add_argument("model_type", help="Model type to load for projection")
    parser.add_argument("model_name", help="Name of model to load for projection")
    parser.add_argument("--tarred", action="store_true", help="The corpus is tarred")
    parser.add_argument("--events", action="store_true", help="Project each individual event, not whole chains")
    parser.add_argument("--redis-port", type=int, default=6379, help="Port that Redis server is running on")
    parser.add_argument("--hash", type=int, default=10, help="Number of binary hash bits to use")
    parser.add_argument("--threshold", type=int,
                        help="Threshold to apply to counts of predicates and arguments. Events with rare predicates "
                             "or arguments are simply filtered out and not projected")
    parser.add_argument("--replace-entities", action="store_true",
                        help="Replace all entities with their headword (other than the chain protagonist) and treat "
                             "them as NP args")
    opts = parser.parse_args()

    log = get_console_logger("project")

    project_events = opts.events

    log.info("Loading corpus from %s" % opts.corpus_dir)
    corpus = RichEventDocumentCorpus(opts.corpus_dir, tarred=opts.tarred)
    # Doing this caches the corpus length, which we're going to need anyway
    num_docs = len(corpus)
    if project_events:
        log.info("Projecting events from %d documents" % num_docs)
    else:
        log.info("Projecting chains from %d documents" % num_docs)

    # Prepare a filter to apply to each event
    if opts.replace_entities:
        log.info("Replacing entities other than protagonist with NP headwords")
Exemplo n.º 6
0
    def train(self,
              batch_iterator,
              total_samples,
              iterations=10000,
              validation_set=None,
              stopping_iterations=10,
              cost_plot_filename=None,
              iteration_callback=None,
              log=None,
              training_cost_prop_change_threshold=0.0005,
              batch_callback=None,
              first_it_last_layer=False):
        if log is None:
            log = get_console_logger("Autoencoder tune")

        log.info(
            "Tuning params: learning rate=%s (->%s), regularization=%s" %
            (self.learning_rate, self.min_learning_rate, self.regularization))
        if self.update_empty_vectors:
            log.info("Training empty vectors")
        if self.update_input_vectors:
            log.info("Updating basic word representations")

        ######## Compile functions
        network = self.model.pair_projection_model
        # Prepare cost/update functions for training
        cost, updates = self.get_triple_cost_updates()
        cost_without_reg, __ = self.get_triple_cost_updates(regularization=0.)
        # Prepare training functions
        cost_fn = theano.function(
            inputs=network.triple_inputs,
            outputs=cost_without_reg,
        )
        train_fn = theano.function(
            inputs=network.triple_inputs + [
                # Allow the learning rate to be set per update
                theano.Param(self.learning_rate_var,
                             default=self.learning_rate),
            ],
            outputs=cost,
            updates=updates,
        )
        # Doesn't do anything now: used to do something different
        first_pass_train_fn = train_fn
        ###########

        # Keep a record of costs, so we can plot them
        val_costs = []
        training_costs = []

        # Keep a copy of the best weights so far
        val_cost = 0.
        best_weights = best_iter = best_val_cost = None
        if validation_set is not None:
            best_weights = self.network.get_weights()
            best_iter = -1
            best_val_cost = cost_fn(validation_set)

        below_threshold_its = 0

        for i in range(iterations):
            err = 0.0
            batch_num = 0
            learning_rate = self.learning_rate
            seen_samples = 0
            tenth_progress = -1

            if i == 0 and first_it_last_layer:
                # On the first iteration, use the training function that only updates the final layer
                log.info(
                    "First pass: only updating final layer (logistic regression)"
                )
                train = first_pass_train_fn
            else:
                train = train_fn

            for batch_num, batch_inputs in enumerate(batch_iterator):
                # Shuffle the training data between iterations, as one should with SGD
                # Just shuffle within batches
                shuffle = numpy.random.permutation(batch_inputs[0].shape[0])
                for batch_data in batch_inputs:
                    batch_data[:] = batch_data[shuffle]

                # Update the model with this batch's data
                err += train(*batch_inputs, learning_rate=learning_rate)

                seen_samples += batch_inputs[0].shape[0]
                # Update the learning rate, so it falls away as we go through
                # Do this only on the first iteration. After that, LR should just stay at the min
                if i == 0:
                    learning_rate = max(
                        self.min_learning_rate,
                        self.learning_rate *
                        (1. - float(seen_samples) / total_samples))

                current_tenth_progress = int(
                    math.floor(10. * float(seen_samples) / total_samples))
                if current_tenth_progress > tenth_progress:
                    tenth_progress = current_tenth_progress
                    mean_cost_so_far = err / (batch_num + 1)
                    log.info("%d%% of iteration: training cost so far = %.5g" %
                             (current_tenth_progress * 10, mean_cost_so_far))
                    if i == 0:
                        log.info("Learning rate updated to %g" % learning_rate)

                if batch_callback is not None:
                    batch_callback(i, batch_num)

            if batch_num == 0:
                raise ModelTrainingError(
                    "zero batches returned by training data iterator")
            training_costs.append(err / (batch_num + 1))

            if validation_set is not None:
                # Compute the cost function on the validation set
                val_cost = cost_fn(validation_set) / validation_set.shape[0]
                val_costs.append(val_cost)
                if val_cost <= best_val_cost:
                    # We assume that, if the validation error remains the same, it's better to use the new set of
                    # weights (with, presumably, a better training error)
                    if val_cost == best_val_cost:
                        log.info(
                            "Same validation cost: %.4f, using new weights" %
                            val_cost)
                    else:
                        log.info("New best validation cost: %.4f" % val_cost)
                    # Update our best estimate
                    best_weights = self.network.get_weights()
                    best_iter = i
                    best_val_cost = val_cost
                if val_cost >= best_val_cost and i - best_iter >= stopping_iterations:
                    # We've gone on long enough without improving validation error
                    # Time to call a halt and use the best validation error we got
                    log.info(
                        "Stopping after %d iterations of increasing validation cost"
                        % stopping_iterations)
                    break

            log.info(
                "COMPLETED ITERATION %d: training cost=%.5g, val cost=%.5g" %
                (i, training_costs[-1], val_cost))

            if cost_plot_filename:
                # Plot the cost function as we train
                # Skip the first costs, as they're usually so much higher than others that the rest is indistinguishable
                columns = [(training_costs[1:], "Train cost")]
                if validation_set is not None:
                    columns.append((val_costs[1:], "Val cost"))
                ax = plot_costs(None, *columns)
                # Add a line at the most recent best val cost
                ax.axvline(float(best_iter), color="b")
                ax.text(float(best_iter + 1) + 0.1,
                        best_val_cost * 1.1,
                        "Best val cost",
                        color="b")
                plt.savefig(cost_plot_filename)

            if iteration_callback is not None:
                # Not computing training error at the moment
                iteration_callback(i, training_costs[-1], val_cost, 0.0,
                                   best_iter)

            # Check the proportional change between this iteration's training cost and the last
            if len(training_costs) > 2:
                training_cost_prop_change = abs(
                    (training_costs[-2] - training_costs[-1]) /
                    training_costs[-2])
                if training_cost_prop_change < training_cost_prop_change_threshold:
                    # Very small change in training cost - maybe we've converged
                    below_threshold_its += 1
                    if below_threshold_its >= 5:
                        # We've had enough iterations with very small changes: we've converged
                        log.info(
                            "Proportional change in training cost (%g) below %g for five successive iterations: "
                            "converged" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold))
                        break
                    else:
                        log.info(
                            "Proportional change in training cost (%g) below %g for %d successive iterations: "
                            "waiting until it's been low for five iterations" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold,
                             below_threshold_its))
                else:
                    # Reset the below threshold counter
                    below_threshold_its = 0

        if best_weights is not None:
            # Use the weights that gave us the best error on the validation set
            self.network.set_weights(best_weights)
def generate_questions(corpus,
                       output_dir,
                       min_context=4,
                       truncate_context=None,
                       samples=1000,
                       alternatives=5,
                       log=None,
                       unbalanced_sample_rate=None,
                       stoplist=None):
    if log is None:
        log = get_console_logger("Multiple choice questions")

    # Prepare the output directory
    log.info("Outputting to %s" % output_dir)
    if os.path.exists(output_dir):
        nfs_rmtree(output_dir)
    os.makedirs(output_dir)

    # Draw samples samples to evaluate on
    if isinstance(corpus, RichEventDocumentCorpus):
        log.info("Generating %d test samples (unbalanced)" % samples)
        if unbalanced_sample_rate is None:
            log.info("Not subsampling")
            if samples > len(corpus):
                log.warn(
                    "Trying to generate %d samples, but only %d docs in corpus"
                    % (samples, len(corpus)))
        else:
            log.info("Subsampling at a rate of %.3g" % unbalanced_sample_rate)
            if samples > int(
                    float(len(corpus)) * unbalanced_sample_rate * 0.9):
                log.warn(
                    "Trying to generate %d samples, but likely to run out by %d"
                    % float(len(corpus)) * unbalanced_sample_rate)

        questions = MultipleChoiceQuestion.generate_random_unbalanced(
            corpus,
            min_context=min_context,
            truncate_context=truncate_context,
            choices=alternatives,
            subsample=unbalanced_sample_rate,
            stoplist=stoplist)
    else:
        log.info("Generating %d test samples (balanced on verb)" % samples)
        questions = MultipleChoiceQuestion.generate_random_balanced_on_verb(
            corpus,
            min_context=min_context,
            truncate_context=truncate_context,
            choices=alternatives)

    pbar = get_progress_bar(samples, "Generating")
    filename_fmt = "question%%0%dd.txt" % len("%d" % (samples - 1))
    q = 0
    for q, question in enumerate(questions):
        pbar.update(q)
        with open(os.path.join(output_dir, filename_fmt % q), 'w') as q_file:
            q_file.write(question.to_text())
        if q == samples - 1:
            # Got enough samples: stop here
            break
    else:
        log.info("Question generation finished after %d samples" % q)
    pbar.finish()
Exemplo n.º 8
0
    def train(self,
              batch_iterator,
              iterations=10000,
              iteration_callback=None,
              validation_set=None,
              stopping_iterations=10,
              log=None,
              cost_plot_filename=None,
              training_cost_prop_change_threshold=0.0005,
              learning_rate=0.1,
              regularization=0.,
              class_weights_vector=None,
              corruption_level=0.,
              continuous_corruption=False,
              loss="xent"):
        """
        Train on data stored in Theano tensors. Uses minibatch training.

        batch_iterator should be a repeatable iterator producing batches.

        iteration_callback is called after each iteration with args (iteration, error array).

        If a validation set (matrix) is given, it is used to compute an error after each iteration
        and to enforce a stopping criterion. The algorithm will terminate if it goes stopping_iterations
        iterations without an improvement in validation error.

        If compute_error_frequency > 1 (default=5), this number of iterations are performed between each time
        the error is computed on the training set.

        The algorithm will assume it has converged and stop early if the proportional change between successive
        training costs drops below training_cost_prop_change_threshold for five iterations in a row.

        Uses L2 regularization.

        """
        if log is None:
            log = get_console_logger("Autoencoder train")

        log.info(
            "Training params: learning rate=%s, noise ratio=%.1f%% (%s), regularization=%s"
            % (learning_rate, corruption_level * 100.0, "continuous corruption"
               if continuous_corruption else "zeroing corruption",
               regularization))
        log.info("Training with SGD")

        ######## Compile functions
        # Prepare cost/update functions for training
        cost, updates = self.network.get_cost_updates(
            self.learning_rate,
            self.regularization,
            class_cost_weights=class_weights_vector,
            corruption_level=corruption_level,
            continuous_corruption=continuous_corruption,
            loss=loss)
        # Prepare training functions
        cost_fn = theano.function(
            inputs=[self.network.x,
                    Param(self.regularization, default=0.0)],
            outputs=cost,
        )
        train_fn = theano.function(
            inputs=[
                self.network.x,
                Param(self.learning_rate, default=0.1),
                Param(self.regularization, default=0.0)
            ],
            outputs=cost,
            updates=updates,
        )
        # Prepare a function to test how close to the identity function the learned mapping is
        # A lower value indicates that it's generalizing more (though not necessarily better)
        identity_ratio = T.mean(
            T.sum(self.network.get_prediction_dist() * (self.network.x > 0),
                  axis=1))
        identity_ratio_fn = theano.function(inputs=[self.network.x],
                                            outputs=identity_ratio)
        ###########

        # Keep a record of costs, so we can plot them
        val_costs = []
        training_costs = []

        # Keep a copy of the best weights so far
        val_cost = 0.
        best_weights = best_iter = best_val_cost = None
        if validation_set is not None:
            best_weights = self.network.get_weights()
            best_iter = -1
            best_val_cost = cost_fn(validation_set)

            log.info("Computing initial validation scores")
            f_score, precision, recall, f_score_classes = self.compute_f_scores(
                validation_set)
            log.info(
                "F-score: %.4f%% (mean over %d classes), P=%.4f%%, R=%.4f%%" %
                (f_score * 100.0, f_score_classes, precision * 100.0,
                 recall * 100.0))
            identity_ratio = identity_ratio_fn(validation_set)
            log.info("Identity ratio = %.4g" % identity_ratio)

        below_threshold_its = 0

        for i in range(iterations):
            err = 0.0
            batch_num = 0
            for batch_num, batch in enumerate(batch_iterator):
                # Shuffle the training data between iterations, as one should with SGD
                # Just shuffle within batches
                shuffle = numpy.random.permutation(batch.shape[0])
                batch[:] = batch[shuffle]

                # Update the model with this batch's data
                err += train_fn(batch,
                                learning_rate=learning_rate,
                                regularization=regularization)

            training_costs.append(err / batch_num)

            if validation_set is not None:
                # Compute the cost function on the validation set
                val_cost = cost_fn(validation_set) / validation_set.shape[0]
                val_costs.append(val_cost)
                if val_cost <= best_val_cost:
                    # We assume that, if the validation error remains the same, it's better to use the new set of
                    # weights (with, presumably, a better training error)
                    if val_cost == best_val_cost:
                        log.info(
                            "Same validation cost: %.4f, using new weights" %
                            val_cost)
                    else:
                        log.info("New best validation cost: %.4f" % val_cost)
                    # Update our best estimate
                    best_weights = self.network.get_weights()
                    best_iter = i
                    best_val_cost = val_cost
                if val_cost >= best_val_cost and i - best_iter >= stopping_iterations:
                    # We've gone on long enough without improving validation error
                    # Time to call a halt and use the best validation error we got
                    log.info(
                        "Stopping after %d iterations of increasing validation cost"
                        % stopping_iterations)
                    break

            log.info(
                "COMPLETED ITERATION %d: training cost=%.5g, val cost=%.5g" %
                (i, training_costs[-1], val_cost))

            if cost_plot_filename:
                # Plot the cost function as we train
                # Skip the first costs, as they're usually so much higher than others that the rest is indistinguishable
                columns = [(training_costs[1:], "Train cost")]
                if validation_set is not None:
                    columns.append((val_costs[1:], "Val cost"))
                ax = plot_costs(None, *columns)
                # Add a line at the most recent best val cost
                ax.axvline(float(best_iter), color="b")
                ax.text(float(best_iter + 1) + 0.1,
                        best_val_cost * 1.1,
                        "Best val cost",
                        color="b")
                from matplotlib import pyplot as plt
                plt.savefig(cost_plot_filename)

            if validation_set is not None:
                f_score, precision, recall, f_score_classes = self.compute_f_scores(
                    validation_set)
                log.info(
                    "Validation f-score: %.4f%% (mean over %d classes), P=%.4f%%, R=%.4f%%"
                    % (f_score * 100.0, f_score_classes, precision * 100.0,
                       recall * 100.0))
                identity_ratio = identity_ratio_fn(validation_set)
                log.info("Validation identity ratio = %.4g" % identity_ratio)

            if iteration_callback is not None:
                # Not computing training error at the moment
                iteration_callback(i, training_costs[-1], val_cost, 0.0,
                                   best_iter)

            # Check the proportional change between this iteration's training cost and the last
            if len(training_costs) > 2:
                training_cost_prop_change = abs(
                    (training_costs[-2] - training_costs[-1]) /
                    training_costs[-2])
                if training_cost_prop_change < training_cost_prop_change_threshold:
                    # Very small change in training cost - maybe we've converged
                    below_threshold_its += 1
                    if below_threshold_its >= 5:
                        # We've had enough iterations with very small changes: we've converged
                        log.info(
                            "Proportional change in training cost (%g) below %g for five successive iterations: "
                            "converged" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold))
                        break
                    else:
                        log.info(
                            "Proportional change in training cost (%g) below %g for %d successive iterations: "
                            "waiting until it's been low for five iterations" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold,
                             below_threshold_its))
                else:
                    # Reset the below threshold counter
                    below_threshold_its = 0

        if best_weights is not None:
            # Use the weights that gave us the best error on the validation set
            self.network.set_weights(best_weights)
Exemplo n.º 9
0
    def train(self,
              batch_iterator,
              iterations=10000,
              iteration_callback=None,
              validation_set=None,
              stopping_iterations=10,
              log=None,
              cost_plot_filename=None,
              training_cost_prop_change_threshold=0.0005,
              learning_rate=0.1,
              regularization=0.,
              class_weights_vector=None,
              corruption_level=0.,
              continuous_corruption=False,
              loss="xent"):
        """
        See autoencoder trainer: uses the same training for each layer in turn, then rolls out and
        trains the whole thing together.

        """
        if log is None:
            log = get_console_logger("Autoencoder train")

        # Because the layers are all already properly stacked, when we get the cost/updates for a layer,
        # it's already a function of the original input, but only updates the layer itself
        for layer_num, layer in enumerate(self.network.layers):
            log.info("TRAINING LAYER %d" % layer_num)
            ## Compile functions
            # Prepare cost/update functions for training
            cost, updates = layer.get_cost_updates(
                self.learning_rate,
                self.regularization,
                class_cost_weights=class_weights_vector,
                corruption_level=corruption_level,
                continuous_corruption=continuous_corruption,
                loss=loss)
            # Prepare training functions
            # Note that these use the initial input, not the layer input
            cost_fn = theano.function(
                inputs=[self.input,
                        Param(self.regularization, default=0.0)],
                outputs=cost,
            )
            train_fn = theano.function(
                inputs=[
                    self.input,
                    Param(self.learning_rate, default=0.1),
                    Param(self.regularization, default=0.0)
                ],
                outputs=cost,
                updates=updates,
            )
            # Prepare a function to test how close to the identity function the learned mapping is
            # A lower value indicates that it's generalizing more (though not necessarily better)
            identity_ratio = T.mean(
                T.sum(layer.get_prediction_dist() * (layer.x > 0), axis=1))
            identity_ratio_fn = theano.function(inputs=[self.input],
                                                outputs=identity_ratio)

            # Keep a record of costs, so we can plot them
            val_costs = []
            training_costs = []

            # Keep a copy of the best weights so far
            val_cost = 0.
            best_weights = best_iter = best_val_cost = None
            if validation_set is not None:
                best_weights = layer.get_weights()
                best_iter = -1
                best_val_cost = cost_fn(validation_set)

                log.info("Computing initial validation scores")
                identity_ratio = identity_ratio_fn(validation_set)
                log.info("Identity ratio = %.4g" % identity_ratio)

            log.info("Computing initial training cost")
            batch_costs = [cost_fn(batch) for batch in batch_iterator]
            initial_cost = sum(batch_costs) / len(batch_costs)
            log.info("Cost = %g (%d batches)" %
                     (initial_cost, len(batch_costs)))

            below_threshold_its = 0

            for i in range(iterations):
                err = 0.0
                batch_num = 0
                for batch_num, batch in enumerate(batch_iterator):
                    # Shuffle the training data between iterations, as one should with SGD
                    # Just shuffle within batches
                    shuffle = numpy.random.permutation(batch.shape[0])
                    batch[:] = batch[shuffle]

                    # Update the model with this batch's data
                    err += train_fn(batch,
                                    learning_rate=learning_rate,
                                    regularization=regularization)

                training_costs.append(err / batch_num)

                if validation_set is not None:
                    # Compute the cost function on the validation set
                    val_cost = cost_fn(
                        validation_set) / validation_set.shape[0]
                    val_costs.append(val_cost)
                    if val_cost <= best_val_cost:
                        # We assume that, if the validation error remains the same, it's better to use the new set of
                        # weights (with, presumably, a better training error)
                        if val_cost == best_val_cost:
                            log.info(
                                "Same validation cost: %.4f, using new weights"
                                % val_cost)
                        else:
                            log.info("New best validation cost: %.4f" %
                                     val_cost)
                        # Update our best estimate
                        best_weights = layer.get_weights()
                        best_iter = i
                        best_val_cost = val_cost
                    if val_cost >= best_val_cost and i - best_iter >= stopping_iterations:
                        # We've gone on long enough without improving validation error
                        # Time to call a halt and use the best validation error we got
                        log.info(
                            "Stopping after %d iterations of increasing validation cost"
                            % stopping_iterations)
                        break

                    log.info(
                        "COMPLETED ITERATION %d: training cost=%.5g, val cost=%.5g"
                        % (i, training_costs[-1], val_cost))
                else:
                    log.info("COMPLETED ITERATION %d: training cost=%.5g" %
                             (i, training_costs[-1]))

                if cost_plot_filename:
                    # Plot the cost function as we train
                    # Skip the first costs, as they're usually so much higher that the rest is indistinguishable
                    columns = [(training_costs[1:], "Train cost")]
                    if validation_set is not None:
                        columns.append((val_costs[1:], "Val cost"))
                    ax = plot_costs(None, *columns)
                    # Add a line at the most recent best val cost
                    ax.axvline(float(best_iter), color="b")
                    ax.text(float(best_iter + 1) + 0.1,
                            best_val_cost * 1.1,
                            "Best val cost",
                            color="b")
                    from matplotlib import pyplot as plt
                    plt.savefig(cost_plot_filename)

                if validation_set is not None:
                    identity_ratio = identity_ratio_fn(validation_set)
                    log.info("Validation identity ratio = %.4g" %
                             identity_ratio)

                if iteration_callback is not None:
                    # Not computing training error at the moment
                    iteration_callback(i, training_costs[-1], val_cost, 0.0,
                                       best_iter)

                # Check the proportional change between this iteration's training cost and the last
                if len(training_costs) > 2:
                    training_cost_prop_change = abs(
                        (training_costs[-2] - training_costs[-1]) /
                        training_costs[-2])
                    if training_cost_prop_change < training_cost_prop_change_threshold:
                        # Very small change in training cost - maybe we've converged
                        below_threshold_its += 1
                        if below_threshold_its >= 5:
                            # We've had enough iterations with very small changes: we've converged
                            log.info(
                                "Proportional change in training cost (%g) below %g for five successive iterations: "
                                "converged" %
                                (training_cost_prop_change,
                                 training_cost_prop_change_threshold))
                            break
                        else:
                            log.info(
                                "Proportional change in training cost (%g) below %g for %d successive iterations: "
                                "waiting until it's been low for five iterations"
                                % (training_cost_prop_change,
                                   training_cost_prop_change_threshold,
                                   below_threshold_its))
                    else:
                        # Reset the below threshold counter
                        below_threshold_its = 0

            if best_weights is not None:
                # Use the weights that gave us the best error on the validation set
                layer.set_weights(best_weights)
Exemplo n.º 10
0
    def train(self, batch_iterator, iterations=10000, iteration_callback=None, learning_rate=None, regularization=None,
              batch_callback=None, validation_set=None, stopping_iterations=10, log=None,
              cost_plot_filename=None, training_cost_prop_change_threshold=None):
        """
        Train on data stored in Theano tensors. Uses minibatch training.

        The input is given as an iterator over batches that should produce (x, y) pairs.

        E.g.
        xs = rng.randn(N, num_features)
        ys = rng.randint(size=N, low=0, high=2)

        iteration_callback is called after each iteration with args (iteration, error array).

        If a validation set (xs, ys) is given, it is used to compute an error after each iteration
        and to enforce a stopping criterion. The algorithm will terminate if it goes stopping_iterations
        iterations without an improvement in validation error.

        Updates for each target class can be weighted by giving a vector class_weights. Alternatively,
        give the string 'freq' to weight them by inverse class frequency, or leave as None to apply
        no weighting.

        If compute_error_frequency > 1 (default=5), this number of iterations are performed between each time
        the error is computed on the training set.

        The algorithm will assume it has converged and stop early if the proportional change between successive
        training costs drops below training_cost_prop_change_threshold for five iterations in a row.
        If threshold is given as None, this stopping condition will not be used.

        """
        if log is None:
            log = get_console_logger("MLP train")

        if plot_costs is None and cost_plot_filename is not None:
            warnings.warn("disabling plotting, since matplotlib couldn't be loaded")
            cost_plot_filename = None
        elif cost_plot_filename is not None:
            log.info("Plotting costs to %s" % cost_plot_filename)

        kwargs = {}
        if learning_rate is not None:
            kwargs["learning_rate"] = learning_rate
        if regularization is not None:
            kwargs["reg_coef"] = regularization
        log.info("Training params: learning rate=%s, reg coef=%s, algorithm=%s" %
                 (learning_rate, regularization, self.optimization))

        # Keep a record of costs, so we can plot them
        val_costs = []
        training_costs = []

        # Compute costs using the initialized network
        initial_batch_costs = [self.compute_cost(xs, ys) for (xs, ys) in batch_iterator]
        training_cost = sum(initial_batch_costs) / len(initial_batch_costs)
        log.info("Initial training cost: %g" % training_cost)
        training_costs.append(training_cost)
        if validation_set is not None:
            val_cost = self.compute_cost(validation_set[0], validation_set[1])
            val_costs.append(val_cost)
        else:
            val_cost = None
        log.info("Training on %d batches" % len(initial_batch_costs))

        # Keep a copy of the best weights so far
        best_weights = best_iter = best_val_cost = None
        if validation_set is not None:
            best_weights = self.network.get_weights()
            best_iter = -1
            best_val_cost = val_cost

        below_threshold_its = 0

        for i in range(iterations):
            err = 0.0
            batch_num = 0
            for batch_num, (xs, ys) in enumerate(batch_iterator):
                # Shuffle the training data between iterations, as one should with SGD
                # We only do it within batches
                shuffle = numpy.random.permutation(xs.shape[0])
                xs[:] = xs[shuffle]
                ys[:] = ys[shuffle]
                # Update the model with this batch's data
                batch_err = self._train_fn(xs, ys, **kwargs)
                err += batch_err

                if batch_callback is not None:
                    batch_callback(batch_num, batch_err)

            # Go back and compute training cost
            training_cost = err / batch_num
            training_costs.append(training_cost)

            if validation_set is not None:
                # Compute the cost function on the validation set
                val_cost = self.compute_cost(validation_set[0], validation_set[1])
                val_costs.append(val_cost)
                if val_cost <= best_val_cost:
                    # We assume that, if the validation error remains the same, it's better to use the new set of
                    # weights (with, presumably, a better training error)
                    # Update our best estimate
                    best_weights = self.network.get_weights()
                    best_iter = i
                    best_val_cost = val_cost

                if i - best_iter >= stopping_iterations:
                    # We've gone on long enough without improving validation error
                    # Time to call a halt and use the best validation error we got
                    log.info("Stopping after %d iterations without improving validation cost" %
                             stopping_iterations)
                    break

                log.info("Completed iteration %d, training cost=%.5f, val cost=%.5f" % (i, training_cost, val_cost))

                if best_iter < i:
                    log.info("No improvement in validation cost")
            else:
                log.info("Completed iteration %d, training cost=%.5f" % (i, training_cost))

            if cost_plot_filename:
                # Plot the cost function as we train
                # Skip the initial cost: it's usually so much higher than the rest that it dwarfs the plot
                columns = [(training_costs[1:], "Train cost")]
                if validation_set is not None:
                    columns.append((val_costs[1:], "Val cost"))
                ax, fig = plot_costs(None, *columns, return_figure=True)
                if best_iter is not None:
                    # Add a line at the most recent best val cost
                    ax.axvline(float(best_iter+1), color="b")
                    ax.text(float(best_iter+1)+0.1, best_val_cost*1.1, "Best val cost", color="b")
                # Write out to a file
                from matplotlib import pyplot as plt
                plt.savefig(cost_plot_filename)
                plt.close(fig)

            if iteration_callback is not None:
                iteration_callback(i, training_cost, val_cost, best_iter)

            # Check the proportional change between this iteration's training cost and the last
            if len(training_costs) > 2 and training_cost_prop_change_threshold is not None:
                training_cost_prop_change = abs((training_costs[-2] - training_costs[-1]) / training_costs[-2])
                if training_cost_prop_change < training_cost_prop_change_threshold:
                    # Very small change in training cost - maybe we've converged
                    below_threshold_its += 1
                    if below_threshold_its >= 5:
                        # We've had enough iterations with very small changes: we've converged
                        log.info("Proportional change in training cost (%g) below %g for five successive iterations: "
                                 "converged" % (training_cost_prop_change, training_cost_prop_change_threshold))
                        break
                    else:
                        log.info("Proportional change in training cost (%g) below %g for %d successive iterations: "
                                 "waiting until it's been low for five iterations" %
                                 (training_cost_prop_change, training_cost_prop_change_threshold, below_threshold_its))
                else:
                    # Reset the below threshold counter
                    below_threshold_its = 0

        if best_weights is not None:
            # Use the weights that gave us the best error on the validation set
            # If val set wasn't given, the network just has the latest weights
            self.network.set_weights(best_weights)
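
The early-stopping scheme implemented above is the standard patience pattern: snapshot the weights whenever the validation cost improves (ties go to the newer weights), and stop once stopping_iterations iterations pass without an improvement. Below is a minimal, self-contained sketch of that pattern; the train_epoch, val_cost_fn, get_weights and set_weights callables are hypothetical stand-ins for the trainer's compiled functions, not part of the code above.

def train_with_patience(train_epoch, val_cost_fn, get_weights, set_weights,
                        iterations=100, stopping_iterations=10):
    # train_epoch() performs one pass over the training data (hypothetical callable)
    # val_cost_fn() returns the current validation cost (hypothetical callable)
    best_weights = get_weights()
    best_iter = -1
    best_val_cost = val_cost_fn()
    for i in range(iterations):
        train_epoch()
        val_cost = val_cost_fn()
        if val_cost <= best_val_cost:
            # On a tie, prefer the newer weights: their training cost is presumably lower
            best_weights, best_iter, best_val_cost = get_weights(), i, val_cost
        if i - best_iter >= stopping_iterations:
            # Patience exhausted: no validation improvement for stopping_iterations iterations
            break
    # Restore the weights that gave the best validation cost
    set_weights(best_weights)
    return best_iter, best_val_cost
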
Exemplo n.º 11
0
    def train(self,
              xs,
              iterations=10000,
              iteration_callback=None,
              batch_size=20,
              batch_callback=None,
              validation_set=None,
              stopping_iterations=10,
              log=None,
              cost_plot_filename=None,
              training_cost_prop_change_threshold=0.0005,
              learning_rate=0.1,
              regularization=None,
              class_weights=None,
              corruption_level=0.,
              continuous_corruption=False,
              loss="xent"):
        """
        Train on data stored in Theano tensors. Uses minibatch training.

        xs are the vectors to train on. Targets needn't be given, since the input and output are the
        same in an autoencoder.

        iteration_callback is called after each iteration with args (iteration, training cost,
        validation cost, training error, best iteration so far).

        A validation set must be supplied, as an (xs, ys) pair for interface compatibility
        (the ys are ignored, since an autoencoder's target is its input). It is used to compute
        an error after each iteration and to enforce a stopping criterion: the algorithm will
        terminate if it goes stopping_iterations iterations without an improvement in validation error.

        class_weights selects a per-class weighting of the cost function: None applies no
        weighting, 'freq' weights classes by inverse frequency and 'log' by inverse log frequency.

        The algorithm will assume it has converged and stop early if the proportional change between successive
        training costs drops below training_cost_prop_change_threshold for five iterations in a row.

        Uses L2 regularization.

        Several params are included just to implement the same interface as single_hidden_layer.
        Might want to change this later to be a bit neater.

        """
        if log is None:
            log = get_console_logger("Autoencoder train")

        # regularization may be None, so format it with %s; log the corruption settings
        # actually passed to get_cost_updates below
        log.info(
            "Training params: learning rate=%s, noise ratio=%.1f%% (%s), regularization=%s"
            % (learning_rate, corruption_level * 100.0,
               "continuous corruption" if continuous_corruption
               else "zeroing corruption", regularization))
        log.info("Training with SGD, batch size=%d" % batch_size)

        if class_weights is None:
            # Don't apply any weighting
            class_weights_vector = None
        elif class_weights == "freq":
            # Apply inverse frequency weighting
            class_counts = numpy.maximum(xs.sum(axis=0), 1.0)
            class_weights_vector = 1. / class_counts
            class_weights_vector *= xs.shape[1] / class_weights_vector.sum()
            log.info(
                "Using inverse frequency class weighting in cost function")
        elif class_weights == "log":
            class_counts = numpy.maximum(xs.sum(axis=0), 1.0)
            class_weights_vector = 1. / (numpy.log(class_counts) + 1.)
            class_weights_vector *= xs.shape[1] / class_weights_vector.sum()
            log.info(
                "Using inverse log frequency class weighting in cost function")
        else:
            raise ValueError("invalid class weighting '%s'" % class_weights)

        ######## Compile functions
        # Prepare cost/update functions for training
        cost, updates = self.network.get_cost_updates(
            self.learning_rate,
            self.regularization,
            class_cost_weights=class_weights_vector,
            corruption_level=corruption_level,
            continuous_corruption=continuous_corruption,
            loss=loss)
        # Prepare training functions
        cost_fn = theano.function(
            inputs=[self.network.x,
                    Param(self.regularization, default=0.0)],
            outputs=cost,
        )
        train_fn = theano.function(
            inputs=[
                self.network.x,
                Param(self.learning_rate, default=0.1),
                Param(self.regularization, default=0.0)
            ],
            outputs=cost,
            updates=updates,
        )
        # Prepare a function to test how close to the identity function the learned mapping is
        # A lower value indicates that it's generalizing more (though not necessarily better)
        identity_ratio = T.mean(
            T.sum(self.network.get_prediction_dist() * (self.network.x > 0),
                  axis=1))
        identity_ratio_fn = theano.function(inputs=[self.network.x],
                                            outputs=identity_ratio)
        ###########

        # A validation set is required: the rest of the method indexes into it unconditionally
        if validation_set is None:
            raise ValueError("autoencoder training requires a validation set")
        # Throw away ys in the validation set: input and output are the same for an autoencoder
        validation_set = validation_set[0]

        # Prepare a prediction validation set by holding one event out of every chain in the val set
        prediction_targets = numpy.array(
            [random.choice(numpy.where(x_row > 0)[0]) for x_row in validation_set],
            dtype=numpy.int16)
        prediction_contexts = validation_set.copy()
        # Zero out the held-out event in each context row
        prediction_contexts[numpy.arange(prediction_contexts.shape[0]),
                            prediction_targets] = 0.
        prediction_balanced_sample = balanced_array_sample(prediction_targets,
                                                           balance_ratio=4.,
                                                           min_inclusion=1)
        prediction_targets = prediction_targets[prediction_balanced_sample]
        prediction_contexts = prediction_contexts[prediction_balanced_sample]
        log.info(
            "Prepared roughly balanced prediction set from validation set with %d examples"
            % prediction_contexts.shape[0])

        # Work out how many batches to do
        if batch_size is None or batch_size == 0:
            num_batches = 1
        else:
            num_batches = xs.shape[0] / batch_size
            if xs.shape[0] % batch_size != 0:
                num_batches += 1

        # Keep a record of costs, so we can plot them
        val_costs = []
        training_costs = []

        # Compute costs using the initialized network
        training_cost = cost_fn(xs)
        training_costs.append(training_cost)
        if validation_set is not None:
            val_cost = cost_fn(validation_set)
            val_costs.append(val_cost)
        else:
            val_cost = None

        log.info("Computing initial validation scores")
        f_score, precision, recall, f_score_classes = self.compute_f_scores(
            validation_set)
        log.info("F-score: %.4f%% (mean over %d classes), P=%.4f%%, R=%.4f%%" %
                 (f_score * 100.0, f_score_classes, precision * 100.0,
                  recall * 100.0))
        log_prob = self.network.prediction_log_prob(prediction_contexts,
                                                    prediction_targets)
        log.info("Logprob = %.4g" % log_prob)
        gen_log_prob = self.network.generalization_log_prob(
            prediction_contexts, prediction_targets)
        log.info("Generalization logprob = %.4g" % gen_log_prob)
        identity_ratio = identity_ratio_fn(validation_set)
        log.info("Identity ratio = %.4g" % identity_ratio)

        # Keep a copy of the best weights so far
        best_weights = best_iter = best_val_cost = None
        if validation_set is not None:
            best_weights = self.network.get_weights()
            best_iter = -1
            best_val_cost = val_cost

        below_threshold_its = 0

        for i in range(iterations):
            # Shuffle the training data between iterations, as one should with SGD
            shuffle = numpy.random.permutation(xs.shape[0])
            xs[:] = xs[shuffle]

            err = 0.0
            if num_batches > 1:
                for batch in range(num_batches):
                    # Update the model with this batch's data
                    batch_err = train_fn(xs[batch * batch_size:(batch + 1) *
                                            batch_size],
                                         learning_rate=learning_rate,
                                         regularization=regularization)
                    err += batch_err

                    if batch_callback is not None:
                        batch_callback(batch, num_batches, batch_err)
            else:
                # Full-batch training: a single update over all the data, no need to loop
                train_fn(xs,
                         learning_rate=learning_rate,
                         regularization=regularization)

            # Go back and compute training cost
            training_cost = cost_fn(xs)
            training_costs.append(training_cost)

            if validation_set is not None:
                # Compute the cost function on the validation set
                val_cost = cost_fn(validation_set)
                val_costs.append(val_cost)
                if val_cost <= best_val_cost:
                    # We assume that, if the validation error remains the same, it's better to use the new set of
                    # weights (with, presumably, a better training error)
                    if val_cost == best_val_cost:
                        log.info(
                            "Same validation cost: %.4f, using new weights" %
                            val_cost)
                    else:
                        log.info("New best validation cost: %.4f" % val_cost)
                    # Update our best estimate
                    best_weights = self.network.get_weights()
                    best_iter = i
                    best_val_cost = val_cost
                if val_cost >= best_val_cost and i - best_iter >= stopping_iterations:
                    # We've gone on long enough without improving validation error
                    # Time to call a halt and use the best validation error we got
                    log.info(
                        "Stopping after %d iterations without improving validation cost"
                        % stopping_iterations)
                    break

            log.info(
                "COMPLETED ITERATION %d: training cost=%.5f, val cost=%.5f" %
                (i, training_cost, val_cost))

            if cost_plot_filename:
                # Plot the cost function as we train
                # Skip the first costs, as they're usually so much higher than others that the rest is indistinguishable
                columns = [(training_costs[1:], "Train cost")]
                if validation_set is not None:
                    columns.append((val_costs[1:], "Val cost"))
                ax = plot_costs(None, *columns)
                # Add a line at the most recent best val cost (offset by 1 to account for
                # the truncated initial cost, matching the marker's label position)
                ax.axvline(float(best_iter + 1), color="b")
                ax.text(float(best_iter + 1) + 0.1,
                        best_val_cost * 1.1,
                        "Best val cost",
                        color="b")
                plt.savefig(cost_plot_filename)

            f_score, precision, recall, f_score_classes = self.compute_f_scores(
                validation_set)
            log.info(
                "Validation f-score: %.4f%% (mean over %d classes), P=%.4f%%, R=%.4f%%"
                % (f_score * 100.0, f_score_classes, precision * 100.0,
                   recall * 100.0))
            #log_prob = self.network.prediction_log_prob(prediction_contexts, prediction_targets)
            #log.info("Prediction logprob = %.4g" % log_prob)
            gen_log_prob = self.network.generalization_log_prob(
                prediction_contexts, prediction_targets)
            log.info("Generalization logprob = %.4g" % gen_log_prob)
            identity_ratio = identity_ratio_fn(validation_set)
            log.info("Validation identity ratio = %.4g" % identity_ratio)

            if iteration_callback is not None:
                # Not computing training error at the moment
                iteration_callback(i, training_cost, val_cost, 0.0, best_iter)

            # Check the proportional change between this iteration's training cost and the last
            if len(training_costs) > 2:
                training_cost_prop_change = abs(
                    (training_costs[-2] - training_costs[-1]) /
                    training_costs[-2])
                if training_cost_prop_change < training_cost_prop_change_threshold:
                    # Very small change in training cost - maybe we've converged
                    below_threshold_its += 1
                    if below_threshold_its >= 5:
                        # We've had enough iterations with very small changes: we've converged
                        log.info(
                            "Proportional change in training cost (%g) below %g for five successive iterations: "
                            "converged" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold))
                        break
                    else:
                        log.info(
                            "Proportional change in training cost (%g) below %g for %d successive iterations: "
                            "waiting until it's been low for five iterations" %
                            (training_cost_prop_change,
                             training_cost_prop_change_threshold,
                             below_threshold_its))
                else:
                    # Reset the below threshold counter
                    below_threshold_its = 0

        if best_weights is not None:
            # Use the weights that gave us the best error on the validation set
            self.network.set_weights(best_weights)
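
A note on the class weighting options above: 'freq' weights each class (column of xs) by inverse frequency and 'log' by inverse log frequency, in both cases rescaled so the weights sum to the number of classes (i.e. a mean weight of 1). The following is a standalone numpy sketch of the same computation, under the assumption that xs is a (num examples, num classes) matrix of nonnegative counts or indicators; the function name is illustrative only.

import numpy

def class_weight_vector(xs, scheme="freq"):
    # Clamp counts at 1 to avoid dividing by zero for unseen classes
    class_counts = numpy.maximum(xs.sum(axis=0), 1.0)
    if scheme == "freq":
        weights = 1. / class_counts
    elif scheme == "log":
        weights = 1. / (numpy.log(class_counts) + 1.)
    else:
        raise ValueError("invalid class weighting '%s'" % scheme)
    # Rescale so the weights sum to the number of classes (mean weight = 1)
    weights *= xs.shape[1] / weights.sum()
    return weights
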