def index_tar_to_file(tar_filename, index_filename, progress=False):
    """
    Write a plain-text index of every member of a tar archive.

    Each line of the index has the form ``"<member name> <data offset> <data size>"``: the byte offset
    within the archive at which the member's data begins and the number of bytes of data. The archive is
    opened in streaming mode (``'r|'``) and read once from start to end.

    :param tar_filename: path of the tar archive to index
    :param index_filename: path of the text index to write (overwritten if it exists)
    :param progress: if true, show a progress bar measured in bytes of the archive
    """
    total_bytes = os.path.getsize(tar_filename)
    pbar = get_progress_bar(total_bytes, title="Indexing tarfile") if progress else None

    with tarfile.open(tar_filename, 'r|') as archive:
        with open(index_filename, 'w') as index_file:
            for num_seen, member in enumerate(archive, 1):
                data_start = member.offset_data
                index_file.write("%s %d %d\n" % (member.name, data_start, member.size))

                if not num_seen % 1000:
                    # The TarFile keeps every TarInfo it has read in ``members``; drop them periodically
                    # so memory use does not grow with the size of the archive
                    archive.members = []
                if pbar and not num_seen % 100:
                    pbar.update(data_start)

    if pbar:
        pbar.finish()
Beispiel #2
0
    def _score_choices(self, entities, contexts, choice_lists, progress=False):
        """
        Vector space models have a common implementation of the scoring function that makes use of the
        projection defined by the specific model.

        """
        pairwise = self.model_options["pairwise"]

        pbar = None
        if progress:
            pbar = get_progress_bar(len(contexts), title="Scoring choices")

        scores = numpy.zeros(
            (len(contexts), max(len(choices) for choices in choice_lists)),
            dtype=numpy.float64)

        if pairwise:
            for i, (entity, choice_list,
                    context) in enumerate(zip(entities, choice_lists,
                                              contexts)):
                # Project each of the context events
                context_projections = self.project_events(
                    [(entity, event) for event in context], progress=progress)

                if numpy.all(context_projections == 0.):
                    # We have no representation of the chain, so can't score the candidates
                    # Give them all equal scores
                    scores[i, :] = 1.
                    continue

                scores[i, :] = numpy.sum([
                    self._candidates_similarity(context_projections[context_i],
                                                entity, choice_list)
                    for context_i in range(context_projections.shape[0])
                ],
                                         axis=0)

                if pbar:
                    pbar.update(i)
        else:
            # Perform projection of context chains into vector space
            projection = self.project_chains(list(zip(entities, contexts)),
                                             progress=progress)
            for i, (entity,
                    choice_list) in enumerate(zip(entities, choice_lists)):
                if numpy.all(projection[i] == 0.):
                    # We have no representation of the chain, so can't score the candidates
                    # Give them all equal scores
                    scores[i, :] = 1.
                    continue
                scores[i, :] = self._candidates_similarity(
                    projection[i], entity, choice_list)

                if pbar:
                    pbar.update(i)

        if pbar:
            pbar.finish()

        return scores
Beispiel #3
0
    def _score_choices(self, entities, contexts, choice_lists, progress=False):
        pbar = None
        if progress:
            pbar = get_progress_bar(len(entities), title="Scoring choices")

        scores = numpy.zeros((len(entities), max(len(choices) for choices in choice_lists)), dtype=numpy.float32)

        for test_num, (entity, context_events, choices) in enumerate(zip(entities, contexts, choice_lists)):
            context_size = len(context_events)
            # Project each pair of context event and candidate to get a coherence score
            pairs = [
                (entity, [context_event, choice])
                for choice in choices for context_event in context_events
                ]
            coherences = list(self.get_pair_coherences(pairs, unknown_value=0.))
            # Mean the coherences across the context events for each candidate
            choice_scores = [sum(coherences[choice_num*context_size:(choice_num+1)*context_size], 0.) / context_size
                             for choice_num in range(len(choices))]
            scores[test_num, :] = choice_scores

            if pbar:
                pbar.update(test_num)

        if pbar:
            pbar.finish()

        return scores
Beispiel #4
0
    def _score_choices(self, entities, contexts, choice_lists, progress=False):
        """
        Language model-based models have a common implementation of the scoring function that makes use of the
        scoring function defined by the specific model.

        """
        pbar = None
        if progress:
            pbar = get_progress_bar(len(contexts), title="Scoring choices")

        scores = numpy.zeros(
            (len(contexts), max(len(choices) for choices in choice_lists)),
            dtype=numpy.float64)
        # Process each context in turn
        for i, (entity, context,
                choice_list) in enumerate(zip(entities, contexts,
                                              choice_lists)):
            # Make up full chains by adding the completion event onto each context
            chains = [(entity, context + [completion])
                      for completion in choice_list]
            # Get the model to score the completed chains
            scores[i] = self.score_chains(chains)

            if pbar:
                pbar.update(i)

        if pbar:
            pbar.finish()
        return scores
Beispiel #5
0
    def feature_iter(self):
        """
        Yield, for every entity chain of every document in ``self.corpus``, the list of
        ``predicate_relation`` features of that chain's events.

        If ``self.progress`` is set it is used as the title of a progress bar that advances once per
        document. The bar is finished when iteration stops, including when the generator is closed early.
        """
        pbar = get_progress_bar(len(self.corpus), title=self.progress) if self.progress else None

        try:
            for doc_index, doc in enumerate(self.corpus):
                if pbar:
                    pbar.update(doc_index)
                for chain_entity, chain_events in doc.get_chains():
                    yield [predicate_relation(chain_entity, event) for event in chain_events]
        finally:
            if pbar:
                pbar.finish()
Beispiel #6
0
    def build_for_corpus(corpus, min_length, progress=True, limit=None):
        """
        Build an index from verb lemma to the events with that lemma, counting only events that belong to an
        entity chain of at least ``min_length`` events, and pickle it into the corpus directory.

        The index is written to ``<corpus dir>/verb_chains_<min_length>[-<limit>].<ext>``, where ``ext`` is
        ``tar.index`` for a tarred corpus and ``index`` otherwise. The pickled object is a dict with keys
        ``"index"`` (verb lemma -> list of ``(archive_name, filename, event_position)``, where
        ``event_position`` is the event's position in ``doc.events``) and ``"min_length"``.

        :param corpus: corpus to index
        :param min_length: length of the shortest chain whose events are indexed
        :param progress: print status messages and show a progress bar
        :param limit: if given and non-zero, only the first ``limit`` documents are indexed (a limit of 0
            still adds ``-0`` to the file name but indexes every document)
        """
        import itertools

        doc_dir = corpus.directory
        ext = "tar.index" if corpus.tarred else "index"
        if limit is not None:
            output_filename = os.path.join(doc_dir, "verb_chains_%d-%d.%s" % (min_length, limit, ext))
        else:
            output_filename = os.path.join(doc_dir, "verb_chains_%d.%s" % (min_length, ext))

        # Read in each rich document in the directory in turn
        num_docs = len(corpus)
        if limit and num_docs > limit:
            num_docs = limit
        pbar = None
        if progress:
            print("Indexing %d documents" % num_docs)
            pbar = get_progress_bar(num_docs, title="Indexing documents", counter=True)

        # Go through every event in all the documents, indexing the verb lemmas
        verb_index = {}
        for doc_num, (archive_name, filename, doc) in enumerate(corpus.archive_iter()):
            if limit and doc_num >= limit:
                break

            # Build an index of the event chains in this document
            chains = [doc.find_events_for_entity(entity) for entity in doc.entities]
            chains = [chain for chain in chains if len(chain) >= min_length]
            # Each distinct event once, even if it appears in several chains
            events = list(set(itertools.chain.from_iterable(chains)))

            # Position of the first occurrence of each event in doc.events (what list.index returns),
            # computed in one pass instead of a linear search per event
            event_positions = {}
            for event_pos, doc_event in enumerate(doc.events):
                event_positions.setdefault(doc_event, event_pos)

            # Add an entry to the index for each event
            for event in events:
                if event in event_positions:
                    position = event_positions[event]
                else:
                    # Not found: fall back to list.index so the same ValueError is raised as before
                    position = doc.events.index(event)
                verb_index.setdefault(event.verb_lemma, []).append((archive_name, filename, position))

            if pbar:
                pbar.update(doc_num)

        if pbar:
            pbar.finish()

        if progress:
            print("Indexed %d verb types" % len(verb_index))
            print("Outputting index to %s" % output_filename)
        # Pickle data is binary: a text-mode handle fails on Python 3 and translates newlines on Windows
        with open(output_filename, 'wb') as output_file:
            pickle.dump({
                "index": verb_index,
                "min_length": min_length
            }, output_file)
Beispiel #7
0
    def build_for_corpus(corpus, progress=True, limit=None):
        """
        Build an index from verb lemma to every event with that lemma in the corpus, and pickle it into the
        corpus directory.

        The index is written to ``<corpus dir>/verbs[-<limit>].<ext>``, where ``ext`` is ``tar.index`` for a
        tarred corpus and ``index`` otherwise. The pickled object is a dict mapping each verb lemma to a list
        of ``(archive_name, filename, event_position)``, ``event_position`` being the event's position in
        ``doc.events``.

        :param corpus: corpus to index
        :param progress: print status messages and show a progress bar
        :param limit: if given and non-zero, only the first ``limit`` documents are indexed (a limit of 0
            still adds ``-0`` to the file name but indexes every document)
        """
        doc_dir = corpus.directory
        ext = "tar.index" if corpus.tarred else "index"
        if limit is not None:
            output_filename = os.path.join(doc_dir, "verbs-%d.%s" % (limit, ext))
        else:
            output_filename = os.path.join(doc_dir, "verbs.%s" % ext)

        # Read in each rich document in the directory in turn
        num_docs = len(corpus)
        if limit and num_docs > limit:
            num_docs = limit
        pbar = None
        if progress:
            print("Indexing %d documents" % num_docs)
            pbar = get_progress_bar(num_docs, title="Indexing documents", counter=True)

        # Go through every event in all the documents, indexing the verb lemmas
        verb_index = {}
        for doc_num, (archive_name, filename, doc) in enumerate(corpus.archive_iter()):
            if limit and doc_num >= limit:
                break

            # Add an entry to the index for each event
            for event_num, event in enumerate(doc.events):
                verb_index.setdefault(event.verb_lemma, []).append((archive_name, filename, event_num))

            if pbar:
                pbar.update(doc_num)

        if pbar:
            pbar.finish()

        if progress:
            print("Indexed %d verb types" % len(verb_index))
            print("Outputting index to %s" % output_filename)
        # Pickle data is binary: a text-mode handle fails on Python 3 and translates newlines on Windows
        with open(output_filename, 'wb') as output_file:
            pickle.dump(verb_index, output_file)
Beispiel #8
0
    def from_counts(elements, counts, progress=False):
        """
        Greedily build a PartialOrder over ``elements`` from counted pairs.

        ``counts`` maps ``(first, second)`` pairs to a count. Pairs are visited from the highest count to the
        lowest, and every pair whose two members are distinct and both in ``elements`` is passed to
        ``order.add(first, second)``. A pair that ``add`` rejects with ``OrderConflict`` is skipped.

        :param elements: the items to order
        :param counts: mapping of ``(first, second)`` pair -> count
        :param progress: show a progress bar over the pairs
        :return: the resulting PartialOrder
        """
        order = PartialOrder(elements)
        pbar = get_progress_bar(len(counts)) if progress else None

        # Highest count first. reversed(sorted(...)) is deliberate: it orders tied counts differently
        # from sorted(..., reverse=True)
        ranked_pairs = reversed(sorted(counts.items(), key=itemgetter(1)))
        for pair_num, ((first, second), _count) in enumerate(ranked_pairs):
            if first in elements and second in elements and first != second:
                try:
                    order.add(first, second)
                except OrderConflict:
                    # order.add rejected this pair: skip it and carry on with the rest
                    pass

            if pbar:
                pbar.update(pair_num)

        if pbar:
            pbar.finish()

        return order
    # Score every entity chain in a corpus with a CoherenceScorer model and write one
    # "archive_name, doc_name, chain_num, score" line per scored chain to opts.output_file.

    # Load the model
    model_options = cmd_line_model_options(opts.model_type, opts.opts)
    model_cls = NarrativeChainModel.load_type(opts.model_type)
    # Check this is a suitable type of model: only CoherenceScorer subclasses are accepted; anything
    # else is reported on stderr and the script exits with status 1
    if not issubclass(model_cls, CoherenceScorer):
        print >>sys.stderr, "Model type %s is not a coherence scorer (does not override CoherenceScorer)" % \
                            opts.model_type
        sys.exit(1)
    model = model_cls.load(opts.model_name, **(model_options or {}))

    # Load the corpus
    tarred = detect_tarred_corpus(opts.corpus_dir)
    corpus = RichEventDocumentCorpus(opts.corpus_dir, tarred=tarred)

    # Progress is counted in documents: the bar wraps the document iterator passed to the model below
    pbar = get_progress_bar(len(corpus), title="Scoring chains", counter=True)
    # NOTE(review): next_doc_num_report is never read in the visible code; possibly a leftover
    next_doc_num_report = 0
    # Open output file
    with open(opts.output_file, 'w') as output_file:
        # Every chain of every document is streamed to model.chain_coherences (with batch_size=opts.batch)
        # together with its (archive_name, doc_num, doc_name, chain_num) key, and each result comes back
        # as a (score, key) pair
        for score, (
                archive_name, doc_num, doc_name, chain_num
        ) in model.chain_coherences(
            (((entity, events),
              (archive_name, doc_num, doc.doc_name, chain_num))
             for doc_num, (archive_name, filename,
                           doc) in pbar(enumerate(corpus.archive_iter()))
             for chain_num, (entity, events) in enumerate(doc.get_chains())),
                batch_size=opts.batch):
            # Chains for which the model returns None are left out of the output
            if score is not None:
                print >> output_file, "%s, %s, %d, %g" % (
                    archive_name, doc_name, chain_num, score)
    # NOTE(review): not reached if scoring raises, so the bar is then left unfinished
    pbar.finish()
                event = copy.copy(event)
                for event_entity in event.get_entities():
                    # Don't replace the chain entity
                    if event_entity is not chain_entity:
                        # Others get replaced by their headword (or None if a headword can't be found)
                        event.substitute_entity(event_entity, event_entity.get_head_word())
                yield event
    else:
        _filter_events = lambda entity, events: events

    if opts.threshold is not None:
        # Run over the dataset to count up predicates and arg words so we know what to filter out
        predicates = Counter()
        arguments = Counter()
        log.info("Counting event slot words to apply threshold")
        pbar = get_progress_bar(len(corpus), title="Counting")
        for doc in pbar(corpus):
            for entity, events in doc.get_chains():
                events = list(_filter_events(entity, events))
                # Collect the predicate of each event
                predicates.update([predicate_relation(entity, event) for event in events])
                # Collect all np args from the events
                args = sum([event.get_np_argument_words() for event in events], [])
                arguments.update(args)
        pbar.finish()
        # Get just the most common words
        predicates = [p for (p, cnt) in predicates.items() if cnt >= opts.threshold]
        arguments = [a for (a, cnt) in arguments.items() if cnt >= opts.threshold]
        log.info("Predicate set of %d, argument set of %d" % (len(predicates), len(arguments)))

        # Prepare a filter to get rid of any events with rare words
Beispiel #11
0
    def train(self, model_name, corpus, log, opts, chain_features=None):
        """
        Train and save a CandjNarrativeChainModel from event and event-pair counts.

        Events, and unordered pairs of events occurring in the same chain, are counted over every chain of
        every document in ``corpus``. Events seen fewer than ``opts.event_threshold`` times are then removed
        together with every pair involving them, and pairs seen fewer than ``opts.pair_threshold`` times are
        removed (each threshold applies only when it is set and positive).

        :param model_name: name the trained model is saved under
        :param corpus: corpus of documents providing ``get_chains()``
        :param log: logger for progress messages
        :param opts: training options (``event_threshold``, ``pair_threshold``)
        :param chain_features: not used by this trainer
        :return: the trained, saved model
        """
        from whim.entity_narrative import CandjNarrativeChainModel
        import itertools
        log.info("Training C&J model")

        def _remove_keys(counter, keys, title):
            # Delete the given keys from a counter, showing a progress bar. Nothing is shown when there
            # is nothing to remove, so no empty progress bar is ever created
            if not keys:
                return
            pbar = get_progress_bar(len(keys), title=title)
            for key_num, key in enumerate(keys):
                del counter[key]
                pbar.update(key_num)
            pbar.finish()

        log.info("Extracting event counts")
        pbar = get_progress_bar(len(corpus), title="Event feature extraction")
        # Loop over all the chains again to collect events and pairs of events
        event_counts = Counter()
        pair_counts = Counter()
        for doc_num, document in enumerate(corpus):
            chains = document.get_chains()
            if len(chains):
                event_chains = list(CandjNarrativeChainModel.extract_chain_feature_dicts(chains))
                for chain in event_chains:
                    # Count all the events
                    event_counts.update(chain)
                    # Also count every pair of events in the chain (i < j). Each pair is sorted so the
                    # same two events always give the same key, whichever comes first in the chain
                    pair_counts.update(tuple(sorted(pair)) for pair in itertools.combinations(chain, 2))

            pbar.update(doc_num)
        pbar.finish()

        if opts.event_threshold is not None and opts.event_threshold > 0:
            log.info("Applying event threshold")
            # Apply a threshold event count
            to_remove = [
                event for (event, count) in event_counts.items()
                if count < opts.event_threshold
            ]
            _remove_keys(event_counts, to_remove, "Filtering counts")
            # Also remove any pairs involving these events. Membership is tested against a set: testing
            # against the list made this quadratic in the number of event types
            removed_events = set(to_remove)
            pairs_to_remove = [(event0, event1)
                               for (event0, event1) in pair_counts
                               if event0 in removed_events or event1 in removed_events]
            _remove_keys(pair_counts, pairs_to_remove, "Filtering pair counts")
        if opts.pair_threshold is not None and opts.pair_threshold > 0:
            log.info("Apply pair threshold")
            # Apply a threshold pair count
            to_remove = [
                pair for (pair, count) in pair_counts.items()
                if count < opts.pair_threshold
            ]
            _remove_keys(pair_counts, to_remove, "Filtering pair counts")

        log.info("Saving model: %s" % model_name)
        model = CandjNarrativeChainModel(event_counts, pair_counts)
        model.save(model_name)
        return model
Beispiel #12
0
                    if event_entity is not chain_entity:
                        # Others get replaced by their headword (or None if a headword can't be found)
                        event.substitute_entity(event_entity, event_entity.get_head_word())
                yield event
    else:
        _filter_events = lambda entity, events: events

    if threshold is not None:
        # Run over the dataset to count up predicates and arg words so we know what to filter out
        predicates = Counter()
        arguments = Counter()
        log.info("Counting event slot words to apply threshold")
        max_events = opts.threshold_sample
        if max_events is not None:
            log.info("Limiting to first %d events in corpus" % max_events)
            pbar = get_progress_bar(max_events, title="Counting")
            events_seen = 0
            for doc in corpus:
                for entity, events in doc.get_chains():
                    events = list(_filter_events(entity, events))
                    events_seen += len(events)
                    # Collect the predicate of each event
                    predicates.update([predicate_relation(entity, event) for event in events])
                    # Collect all np args from the events
                    args = sum([event.get_np_argument_words() for event in events], [])
                    arguments.update(args)
                # Stop once we've seen enough
                if events_seen >= max_events:
                    break
                pbar.update(events_seen)
            pbar.finish()
        os.makedirs(output_dir)

        included_verbs = [v for (v, c) in verb_counts if c >= opts.threshold]
        print "Filter will remove %d verb types, leaving %d" % (
            len(verb_counts) - len(included_verbs), len(included_verbs))
        # Output a list of the verbs we're keeping
        print "Outputting verb list to %s" % os.path.join(
            output_dir, "predicates.meta")
        with open(os.path.join(output_dir, "predicates.meta"),
                  'w') as predicates_file:
            predicates_file.write("\n".join(included_verbs))

        print "Counting docs"
        num_docs = len(index.corpus)
        print "Processing %d documents" % num_docs
        pbar = get_progress_bar(num_docs, title="Filtering", counter=True)
        # Do the filtering: read in each document
        current_archive = None
        archives = []
        try:
            for i, (archive_name, filename,
                    doc) in enumerate(index.corpus.archive_iter()):
                pbar.update(i)

                if archive_name != current_archive:
                    # We've moved onto a new archive: create a new archive in the output
                    current_archive = archive_name
                    new_archive = tarfile.open(
                        os.path.join(output_dir, archive_name), 'w')
                    # Add it to the list of archives that will get closed
                    archives.append(new_archive)
Beispiel #14
0
    def train(self, model_name, corpus, log, opts, chain_features=None):
        """
        Train and save a DistributionalVectorsNarrativeChainModel from event co-occurrence counts.

        A first pass over ``corpus`` counts events, optionally dropping those seen fewer than
        ``opts.event_threshold`` times. A second pass counts unordered pairs of surviving events that share
        a chain, optionally dropping pairs seen fewer than ``opts.pair_threshold`` times. The pair counts
        fill a symmetric vocabulary x vocabulary matrix, which is optionally transformed to PMI/PPMI
        (``opts.pmi`` / ``opts.ppmi``) and reduced with truncated SVD to ``opts.svd`` dimensions.

        :param model_name: name the trained model is saved under
        :param corpus: corpus of documents providing ``get_chains()``
        :param log: logger for progress messages
        :param opts: training options (the above, plus ``only_verb`` and ``adj`` for feature extraction)
        :param chain_features: not used by this trainer
        :return: the trained, saved model
        """
        from whim.entity_narrative import DistributionalVectorsNarrativeChainModel
        import itertools
        log.info("Training context vectors model")

        training_metadata = {
            "data": corpus.directory,
            "pmi": opts.pmi or opts.ppmi,
            "ppmi": opts.ppmi,
        }

        def _extract_event_chains(document):
            # Feature lists for every chain of a document (empty when it has no chains)
            chains = document.get_chains()
            if not len(chains):
                return []
            return list(
                DistributionalVectorsNarrativeChainModel.
                extract_chain_feature_lists(chains,
                                            only_verb=opts.only_verb,
                                            adjectives=opts.adj))

        def _remove_keys(counter, keys, title):
            # Delete the given keys from a counter, showing a progress bar. Nothing is shown when there
            # is nothing to remove, so no empty progress bar is ever created
            if not keys:
                return
            pbar = get_progress_bar(len(keys), title=title)
            for key_num, key in enumerate(keys):
                del counter[key]
                pbar.update(key_num)
            pbar.finish()

        log.info("Extracting event counts")
        pbar = get_progress_bar(len(corpus), title="Event feature extraction")
        # Loop over all the chains to collect events
        event_counts = Counter()
        for doc_num, document in enumerate(corpus):
            # Count all the events
            for chain in _extract_event_chains(document):
                event_counts.update(chain)

            pbar.update(doc_num)
        pbar.finish()

        if opts.event_threshold is not None and opts.event_threshold > 0:
            log.info("Applying event threshold")
            # Apply a threshold event count
            to_remove = [
                event for (event, count) in event_counts.items()
                if count < opts.event_threshold
            ]
            _remove_keys(event_counts, to_remove, "Filtering counts")

        log.info("Extracting pair counts")
        pbar = get_progress_bar(len(corpus), title="Pair feature extraction")
        # Loop over all the chains again to collect pairs of events
        pair_counts = Counter()
        for doc_num, document in enumerate(corpus):
            for chain in _extract_event_chains(document):
                # Only pairs of events that survived the event threshold are counted. Each pair is
                # sorted so the same two events always give the same key
                known_events = [event for event in chain if event in event_counts]
                pair_counts.update(
                    tuple(sorted(pair)) for pair in itertools.combinations(known_events, 2))

            pbar.update(doc_num)
        pbar.finish()

        if opts.pair_threshold is not None and opts.pair_threshold > 0:
            log.info("Applying pair threshold")
            # Apply a threshold pair count
            to_remove = [
                pair for (pair, count) in pair_counts.items()
                if count < opts.pair_threshold
            ]
            if to_remove:
                _remove_keys(pair_counts, to_remove, "Filtering pair counts")
            else:
                log.info("No counts removed")

        # Create a dictionary of the remaining vocabulary
        log.info("Building dictionary")
        dictionary = Dictionary([[event] for event in event_counts.keys()])
        # Put all the co-occurrence counts into a big matrix
        log.info("Building counts matrix: vocab size %d" % len(dictionary))
        vectors = numpy.zeros((len(dictionary), len(dictionary)),
                              dtype=numpy.float64)
        # Fill the matrix with raw counts
        for (event0, event1), count in pair_counts.items():
            if event0 in dictionary.token2id and event1 in dictionary.token2id:
                e0, e1 = dictionary.token2id[event0], dictionary.token2id[event1]
                vectors[e0, e1] = count
                # Add the count both ways (it's only stored once above)
                vectors[e1, e0] = count

        # Now there are many things we could do to these counts
        if opts.pmi or opts.ppmi:
            # The conditional must be parenthesised: "%" binds tighter than "if/else", so without the
            # parentheses plain PMI logged an empty message instead of "Applying PMI"
            log.info("Applying %sPMI" % ("P" if opts.ppmi else ""))
            # Apply PMI to the matrix: log(c_ij * N) - log(c_i) - log(c_j), with N the matrix total.
            # Compute the total counts for each event (note row and col totals are the same).
            # The masked log leaves zero cells masked; they are filled with 0 at the end
            log_totals = numpy.ma.log(vectors.sum(axis=0))
            vectors = numpy.ma.log(vectors * vectors.sum()) - log_totals
            vectors = (vectors.T - log_totals).T
            vectors = vectors.filled(0.)

            if opts.ppmi:
                # Threshold the PMIs at zero
                vectors[vectors < 0.] = 0.

        # Convert to sparse for SVD and storage
        vectors = csr_matrix(vectors)

        if opts.svd:
            log.info("Fitting SVD with %d dimensions" % opts.svd)
            training_metadata["svd from"] = vectors.shape[1]
            training_metadata["svd"] = opts.svd
            vector_svd = TruncatedSVD(opts.svd)
            vectors = vector_svd.fit_transform(vectors)

        log.info("Saving model: %s" % model_name)
        model = DistributionalVectorsNarrativeChainModel(
            dictionary,
            vectors,
            only_verb=opts.only_verb,
            training_metadata=training_metadata,
            adjectives=opts.adj)
        model.save(model_name)
        return model
    # Count how often each predicate_relation(entity, event) occurs over all entity chains of the
    # corpus in opts.doc_dir and pickle the counts (as a plain dict) to <doc_dir>/predicate_counts.
    parser.add_argument("--tarred",
                        action="store_true",
                        help="The input corpus is tarred")
    opts = parser.parse_args()

    # Prepare a corpus for the input documents
    print "Loading corpus"
    corpus = RichEventDocumentCorpus(opts.doc_dir,
                                     tarred=opts.tarred,
                                     index_tars=False)

    num_docs = len(corpus)
    print "%d documents" % num_docs
    output_filename = os.path.join(opts.doc_dir, "predicate_counts")
    print "Outputting to %s" % output_filename
    pbar = get_progress_bar(num_docs, title="Counting")

    # One count per (entity, event) occurrence in every chain of every document
    predicate_counter = Counter()
    for doc_num, doc in enumerate(corpus):
        for entity, events in doc.get_chains():
            predicate_counter.update(
                [predicate_relation(entity, event) for event in events])
        pbar.update(doc_num)

    pbar.finish()

    # Convert to a dict of counts
    count_dict = dict(predicate_counter)

    # NOTE(review): the pickle goes through a text-mode handle ('w'); this works on Python 2 on POSIX,
    # but Python 3 (and Windows newline translation) would need 'wb'
    with open(output_filename, 'w') as outfile:
        pickle.dump(count_dict, outfile)
        # Prepare an output tarball for each month
        year = tarball_filename[4:8]
        print "  Year: %s" % year
        month_tarballs = dict([
            (
                "%.2d" % month,
                tarfile.open(os.path.join(output_dir, "nyt_%s%.2d.tar" % (year, month)), 'w')
            ) for month in range(1, 13)
        ])
        # Open up the input tarball so we can iterate over its members
        tarball = tarfile.open(tarball_path, 'r')

        num_files = len(tarball.getnames())

        try:
            pbar = get_progress_bar(num_files, "Splitting tarball")
            for i, tarinfo in enumerate(tarball):
                pbar.update(i)
                # Extract the file to the tmp dir
                f = tarball.extractfile(tarinfo)
                filename = tarinfo.name
                if not filename.startswith("NYT_ENG_"):
                    print "Non-gigaword filename: %s. Skipping" % filename
                else:
                    # The two chars after the year are the month number
                    file_month = filename[12:14]
                    # We should have a tarball for the month
                    if file_month not in month_tarballs:
                        print "No month tarball for month '%s'" % file_month
                    else:
                        month_tarballs[file_month].addfile(tarinfo, f)
            corpus_filename for archive_name, corpus_filename in
            rich_corpus.list_archive_iter()
        ]
        print "Skipping %d files already found in output corpus" % len(
            skip_files)
    else:
        skip_files = []

    total_docs = len(text_corpus) - len(skip_files)
    print "Processing %d documents" % total_docs
    if total_docs < 1:
        # If we've counted 0 or negative docs, go on anyway
        # The counting doesn't actually check the same docs exist and we want to be sure of this
        print "  Strange number of docs, but continuing anyway to check none have been missed"
        total_docs = 1
    pbar = get_progress_bar(total_docs, title="Extracting docs", counter=True)
    for doc_num, document in enumerate(
            RichEventDocument.build_documents(text_corpus,
                                              opts.coref_dir,
                                              opts.deps_dir,
                                              opts.pos_dir,
                                              skip_files=skip_files)):
        if type(document) is tuple:
            # There was an error in extraction
            doc_name, err = document
            output_text = doc_name
        else:
            doc_name = document.doc_name
            output_text = document.to_text()

        if doc_name.startswith("NYT_ENG"):
                                                 "data! Splits 80:10:10")
    parser.add_argument("altavilla_dir", help="Directory containing the Altavilla corpus")
    parser.add_argument("output_dir", help="Directory to output the chapter lists to")
    opts = parser.parse_args()

    chapters_dir = os.path.join(opts.altavilla_dir, "chapters")
    output_dir = os.path.abspath(opts.output_dir)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Get a list of all the chapters from each tarball.
    # os.walk(chapters_dir) already yields every dirpath with chapters_dir as its prefix, so filenames are
    # joined onto dirpath alone. Joining chapters_dir on again doubled the prefix (pointing at files that
    # don't exist) whenever chapters_dir was a relative path
    chapters = []
    print("Reading chapter files")
    chapter_tarballs = [os.path.join(dirpath, filename)
                        for dirpath, _dirnames, filenames in os.walk(chapters_dir)
                        for filename in filenames]
    pbar = get_progress_bar(len(chapter_tarballs))

    for i, chapter_filename in enumerate(chapter_tarballs):
        with tarfile.open(chapter_filename, 'r:gz') as tarball:
            # Collect the names of all the regular *.txt members
            chapters.extend([member.name for member in tarball.getmembers()
                             if member.isfile() and member.name.endswith(".txt")])

        pbar.update(i)
    pbar.finish()

    print "Read {} chapters".format(len(chapters))
    # Dev/test sets should be 10% of the data each
    subset_size = len(chapters) / 10

    # Choose a random sample to be held out as the dev/test sets
def generate_questions(corpus,
                       output_dir,
                       min_context=4,
                       truncate_context=None,
                       samples=1000,
                       alternatives=5,
                       log=None,
                       unbalanced_sample_rate=None,
                       stoplist=None):
    """
    Generate multiple-choice questions from a corpus and write each one to its own text file.

    Any existing ``output_dir`` is deleted and recreated before writing.

    :param corpus: corpus to draw questions from. A RichEventDocumentCorpus is sampled with
        ``MultipleChoiceQuestion.generate_random_unbalanced``; any other corpus with
        ``MultipleChoiceQuestion.generate_random_balanced_on_verb``.
    :param output_dir: directory to write the question files to (wiped if it already exists)
    :param min_context: passed through to the question generator
    :param truncate_context: passed through to the question generator
    :param samples: maximum number of questions to write
    :param alternatives: passed to the question generator as ``choices``
    :param log: logger to report to; a console logger is created if None
    :param unbalanced_sample_rate: subsampling rate for unbalanced generation (None = no subsampling)
    :param stoplist: passed through to unbalanced generation only
    """
    if log is None:
        log = get_console_logger("Multiple choice questions")

    # Prepare the output directory
    log.info("Outputting to %s" % output_dir)
    if os.path.exists(output_dir):
        nfs_rmtree(output_dir)
    os.makedirs(output_dir)

    # Draw samples to evaluate on
    if isinstance(corpus, RichEventDocumentCorpus):
        log.info("Generating %d test samples (unbalanced)" % samples)
        if unbalanced_sample_rate is None:
            log.info("Not subsampling")
            if samples > len(corpus):
                log.warning(
                    "Trying to generate %d samples, but only %d docs in corpus"
                    % (samples, len(corpus)))
        else:
            log.info("Subsampling at a rate of %.3g" % unbalanced_sample_rate)
            expected_docs = float(len(corpus)) * unbalanced_sample_rate
            # Warn with some margin (90%) since subsampling is random
            if samples > int(expected_docs * 0.9):
                # Both format arguments must be supplied in a tuple: the original applied
                # `%` to a single float first, which raised TypeError
                log.warning(
                    "Trying to generate %d samples, but likely to run out by %d"
                    % (samples, int(expected_docs)))

        questions = MultipleChoiceQuestion.generate_random_unbalanced(
            corpus,
            min_context=min_context,
            truncate_context=truncate_context,
            choices=alternatives,
            subsample=unbalanced_sample_rate,
            stoplist=stoplist)
    else:
        log.info("Generating %d test samples (balanced on verb)" % samples)
        questions = MultipleChoiceQuestion.generate_random_balanced_on_verb(
            corpus,
            min_context=min_context,
            truncate_context=truncate_context,
            choices=alternatives)

    pbar = get_progress_bar(samples, "Generating")
    # Zero-pad file numbers to the width of the largest index
    filename_fmt = "question%%0%dd.txt" % len("%d" % (samples - 1))
    generated = 0
    for q, question in enumerate(questions):
        pbar.update(q)
        with open(os.path.join(output_dir, filename_fmt % q), 'w') as q_file:
            q_file.write(question.to_text())
        generated = q + 1
        if q == samples - 1:
            # Got enough samples: stop here
            break
    else:
        # The question generator ran out before the requested number of samples was reached.
        # Report the number actually written (not the last index)
        log.info("Question generation finished after %d samples" % generated)
    pbar.finish()
    def project_from_docs(corpus,
                          model_type,
                          model_name,
                          progress=False,
                          buffer_size=10000,
                          project_events=False,
                          filter_chains=None):
        """
        Project events or chains directly from a document corpus using a vector projection model.
        Yields vectors paired with their events/chains.

        Chains (or their individual events) are accumulated across documents and projected in
        batches, which is much faster than projecting them one by one or document by document.

        :param corpus: corpus providing ``archive_iter()``, which yields ``(archive, filename, doc)``
            triples; ``len(corpus)`` is used for the progress bar
        :param model_type: narrative chain model type
        :param model_name: name of the model, passed to ``NarrativeChainModel.load_by_type``
        :param progress: show progress while projecting
        :param buffer_size: batch size to project at once (a batch is projected once the buffer
            holds more than this many items)
        :param project_events: project individual events instead of whole chains
        :param filter_chains: optional callable applied to each document's list of chains before
            they are buffered
        :return: generator of ``(vector, source, item)`` triples. ``source`` is
            ``(archive, filename, chain_num)`` for chains or
            ``(archive, filename, chain_num, event_num)`` for events; ``item`` is the chain, or an
            ``(entity, event)`` pair when projecting events.
        :raises ValueError: if the model can neither be used as a vector space model nor provides
            the needed projection method. Since this is a generator, this is raised on the first
            iteration, not at call time.
        """
        model = NarrativeChainModel.load_by_type(model_type, model_name)
        # Allow models that are full vector space models, or ones that implement the projection function we need
        if project_events:
            has_projection = hasattr(model, "project_events")
        else:
            has_projection = hasattr(model, "project_chains")
        if not (isinstance(model, VectorSpaceNarrativeChainModel) or has_projection):
            raise ValueError(
                "can only build a vector corpus using a vector space model or one that provides "
                "projection functions, not model type %s" % model_type)

        def _project_buffer(chains, sources):
            # Project one batch in a single model call; return (vector, source, item) triples
            if project_events:
                vectors = model.project_events(chains)
            else:
                vectors = model.project_chains(chains)
            return [(vectors[j], source, chain)
                    for j, (source, chain) in enumerate(zip(sources, chains))]

        if progress:
            pbar = get_progress_bar(len(corpus), title="Projecting corpus")
        else:
            pbar = None

        # Instead of projecting each chain one by one, or doc by doc, buffer lots and do them in a batch
        # This is hugely faster!
        chain_buffer = []
        source_buffer = []
        try:
            for i, (archive, filename, doc) in enumerate(corpus.archive_iter()):
                if pbar:
                    pbar.update(i)

                chains = doc.get_chains()
                if filter_chains is not None:
                    chains = filter_chains(chains)

                if len(chains):
                    if project_events:
                        # Add individual events to the buffers
                        chain_buffer.extend([(entity, event)
                                             for (entity, events) in chains
                                             for event in events])
                        source_buffer.extend([
                            (archive, filename, chain_num, event_num)
                            for chain_num in range(len(chains))
                            for event_num in range(len(chains[chain_num][1]))
                        ])
                    else:
                        # Project whole chains
                        chain_buffer.extend(chains)
                        source_buffer.extend([(archive, filename, chain_num)
                                              for chain_num in range(len(chains))])

                if len(chain_buffer) > buffer_size:
                    for result in _project_buffer(chain_buffer, source_buffer):
                        yield result
                    chain_buffer = []
                    source_buffer = []

            if chain_buffer:
                # Clear up remaining buffer
                for result in _project_buffer(chain_buffer, source_buffer):
                    yield result
        finally:
            # Also runs if the consumer stops iterating early (GeneratorExit at a yield),
            # so the progress bar is always closed
            if pbar:
                pbar.finish()