Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int,
                        default=1)

    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int,
                        default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int,
                        default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path,
                        default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int,
                        default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int,
                        default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true',
                        default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio,
                        default=0.01)

    parser.add_argument('--ignore_weights',
                        action='store_true',
                        default=False)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

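    # The label matrix has one column per entity.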
    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
           not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])

        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate, ) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])

        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remainder of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

        vocabulary_size = len(words)

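    # Randomly initialize word representations; entries may be overwritten
    # below from a pre-trained embedding file.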
    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info(
            'Initialized representations from '
            'pre-learned collection for %d words (%.2f%%).',
            representation_init_count,
            (representation_init_count / float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model,
          args.iterations,
          args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])
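
The parsers in these examples rely on validators from a project-specific argparse_utils module (and assume the usual module-level imports: argparse, collections, logging, pickle, numpy as np, scipy, and so on). A minimal sketch of what such validators could look like follows; the function names match the calls above and below, but the implementations are assumptions, not the project's actual code.

import argparse
import os


def positive_int(value):
    # Accept strictly positive integers (e.g. --iterations, --batch_size).
    value = int(value)

    if value <= 0:
        raise argparse.ArgumentTypeError(
            '%r is not a positive integer.' % value)

    return value


def ratio(value):
    # Accept a float in the closed interval [0, 1]
    # (e.g. --regularization_lambda).
    value = float(value)

    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(
            '%r is not a ratio in [0, 1].' % value)

    return value


def existing_file_path(value):
    # Accept a path that already exists on disk.
    if not os.path.exists(value):
        raise argparse.ArgumentTypeError('%s does not exist.' % value)

    return value


def nonexisting_file_path(value):
    # Accept a path that does not exist yet (e.g. --out_base, --data_output).
    if os.path.exists(value):
        raise argparse.ArgumentTypeError('%s already exists.' % value)

    return value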
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('index', type=argparse_utils.existing_file_path)

    parser.add_argument('session_file', type=argparse_utils.existing_file_path)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--harvested_links_file',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--qrel',
                        type=argparse_utils.existing_file_path,
                        nargs='*',
                        default=[])

    parser.add_argument('--configuration', type=str, nargs='+')

    parser.add_argument('--top_sessions',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--out_base',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(np, scipy)

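    # Parse the text-format ScoreSessionsConfig protobuf passed
    # via --configuration.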
    configuration = sesh_pb2.ScoreSessionsConfig()
    pb.text_format.Merge(' '.join(args.configuration), configuration)

    if not configuration.modifier:
        configuration.modifier.add()  # Create an empty modifier.
    elif len(configuration.modifier) > 1:
        modifier_identifiers = [
            modifier.identifier for modifier in configuration.modifier]

        assert all(modifier_identifiers), \
            'All session modifiers should have an identifier.'

        assert len(modifier_identifiers) == len(set(modifier_identifiers)), \
            'All session modifier identifiers should be unique: {}.'.format(
                modifier_identifiers)

    logging.info('Configuration: %s', configuration)

    logging.info('Loading index.')
    index = pyndri.Index(args.index)

    num_documents = index.document_count()
    logging.debug('Index contains %d documents.', num_documents)

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(index)

    logging.info('Loading background corpus.')
    background_prob_dist = pyndri_utils.extract_background_prob_dist(index)

    for modifier in configuration.modifier:
        out_base = os.path.join(args.out_base, modifier.identifier)
        assert not os.path.exists(out_base)

        os.makedirs(out_base)

        logging.info('Loading sessions using %s and outputting to %s.',
                     modifier or 'no modifier', out_base)

        with codecs.open(args.session_file, 'r', 'utf8') as f_xml:
            track_edition, _, sessions, session_id_to_topic_id = \
                domain.construct_sessions(
                    f_xml, args.top_sessions, dictionary)

        logging.info('Discovered %d sessions.', len(sessions))

        sessions = domain.alter_sessions(sessions, modifier)
        documents = domain.get_document_set(sessions.values())

        logging.info('Retained %d sessions (%d SERP documents) '
                     'after filtering.',
                     len(sessions), len(documents))

        # Load QRels for debugging and oracle runs.
        qrels_per_session = []

        for qrel_path in args.qrel:
            with open(qrel_path, 'r') as f_qrel:
                qrels_per_session.append(sesh.parse_qrel(f_qrel, None))

        scorer_impls = {}

        for scorer_desc in configuration.scorer:
            assert scorer_desc.type in scorers.SESSION_SCORERS

            identifier = scorer_desc.identifier or scorer_desc.type
            assert identifier not in scorer_impls

            scorer = scorers.create_scorer(scorer_desc, qrels_per_session)
            logging.info('Scoring using %s.', repr(scorer))

            scorer_impls[identifier] = scorer

        anchor_texts = None

        if args.harvested_links_file is not None:
            urls = set(document.url for document in documents)

            logging.info('Loading anchor texts for session SERPs (%d URLs).',
                         len(urls))

            with codecs.open(args.harvested_links_file, 'r', 'latin1') \
                    as f_harvested_links:
                anchor_texts = load_anchor_texts(f_harvested_links, urls)

            logging.info('Discovered anchor texts for %d URLs (%d total).',
                         len(anchor_texts), len(urls))
        else:
            logging.info('No anchor texts loaded.')

        # The following will hold all the rankings.
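        # Maps scorer identifier -> session -> document -> score.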
        document_assessments_per_session_per_scorer = collections.defaultdict(
            lambda: collections.defaultdict(
                lambda: collections.defaultdict(float)))

        assert configuration.candidate_generator in \
            DOCUMENT_CANDIDATE_GENERATORS

        # Document candidate generation.
        candidate_generator = DOCUMENT_CANDIDATE_GENERATORS[
            configuration.candidate_generator](**locals())

        logging.info('Using %s for document candidate generation.',
                     candidate_generator)

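        # Score sessions in parallel; each worker pushes
        # (scorer_name, session_id, ranking) tuples onto result_queue.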
        result_queue = multiprocessing.Queue()

        initargs = [
            result_queue,
            args,
            configuration,
            out_base,
            background_prob_dist,
            candidate_generator,
            scorer_impls,
            index,
            dictionary,
            anchor_texts]

        pool = multiprocessing.Pool(
            args.num_workers,
            initializer=score_session_initializer,
            initargs=initargs)

        worker_result = pool.map_async(
            score_session_worker,
            sessions.values())

        # We will not submit any more tasks to the pool.
        pool.close()

        it = multiprocessing_utils.QueueIterator(
            pool, worker_result, result_queue)

        while True:
            try:
                result = next(it)
            except StopIteration:
                break

            scorer_name, session_id, ranking = result

            document_assessments_per_session_per_scorer[
                scorer_name][session_id] = ranking

        for scorer_name in document_assessments_per_session_per_scorer:
            # Switch object assessments to lists.
            for topic_id, object_assessments in \
                    document_assessments_per_session_per_scorer[
                        scorer_name].items():
                document_assessments_per_session_per_scorer[
                    scorer_name][topic_id] = [
                    (score, document_id)
                    for document_id, score in object_assessments.items()]

        # Write the runs.
        for scorer_name in document_assessments_per_session_per_scorer:
            run_out_path = os.path.join(
                out_base, '{0}.run'.format(scorer_name))

            with io.open(run_out_path, 'w', encoding='utf8') as f_run_out:
                trec_utils.write_run(
                    scorer_name,
                    document_assessments_per_session_per_scorer[scorer_name],
                    f_run_out)
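
Examples #2 and #3 share the same parallelization pattern: shared state is handed to a multiprocessing.Pool through an initializer, workers push results onto a multiprocessing.Queue, and the parent drains that queue while map_async tracks completion; the examples wrap the draining step in multiprocessing_utils.QueueIterator. Below is a minimal, self-contained sketch of the underlying pattern with illustrative names; it is not the project's actual helper code.

import multiprocessing

_result_queue = None


def _initializer(result_queue):
    # Store the shared queue in a module global inside each worker process.
    global _result_queue
    _result_queue = result_queue


def _worker(item):
    # Do some work and report the result back through the queue.
    _result_queue.put((item, item * item))


if __name__ == '__main__':
    result_queue = multiprocessing.Queue()

    pool = multiprocessing.Pool(
        2,
        initializer=_initializer,
        initargs=[result_queue])

    worker_result = pool.map_async(_worker, range(8))

    # We will not submit any more tasks to the pool.
    pool.close()

    # Drain the queue; one result is expected per submitted task.
    for _ in range(8):
        item, square = result_queue.get()
        print(item, square)

    worker_result.wait()
    pool.join()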
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--seed',
                        type=argparse_utils.positive_int,
                        required=True)

    parser.add_argument('document_paths',
                        type=argparse_utils.existing_file_path, nargs='+')

    parser.add_argument('--encoding', type=str, default='latin1')

    parser.add_argument('--assoc_path',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--vocabulary_min_count', type=int, default=2)
    parser.add_argument('--vocabulary_min_word_size', type=int, default=2)
    parser.add_argument('--vocabulary_max_size', type=int, default=65536)

    parser.add_argument('--remove_stopwords', type=str, default='nltk')

    parser.add_argument('--validation_set_ratio',
                        type=argparse_utils.ratio, default=0.01)

    parser.add_argument('--window_size', type=int, default=10)
    parser.add_argument('--overlapping', action='store_true', default=False)
    parser.add_argument('--stride',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--resample', action='store_true', default=False)

    parser.add_argument('--no_shuffle', action='store_true', default=False)
    parser.add_argument('--no_padding', action='store_true', default=False)

    parser.add_argument('--no_instance_weights',
                        action='store_true',
                        default=False)

    parser.add_argument('--meta_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)
    parser.add_argument('--data_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    # Seed RNG.
    np.random.seed(args.seed)

    logging_utils.log_module_info(np, scipy, sklearn)

    ignore_words = ['<doc>', '</doc>', '<docno>', '<text>', '</text>']
    if args.remove_stopwords == 'none':
        logging.info('Stopwords will be included in instances.')
    elif args.remove_stopwords == 'nltk':
        logging.info('Using NLTK stopword list.')

        ignore_words.extend(stopwords.words('english'))
        ignore_words.extend(stopwords.words('dutch'))
    elif os.path.exists(args.remove_stopwords):
        logging.info('Using custom stopword list (%s).', args.remove_stopwords)

        with open(args.remove_stopwords, 'r') as f:
            ignore_words.extend(
                filter(len, map(str.strip, map(str.lower, f.readlines()))))
    else:
        logging.error('Invalid stopword removal strategy "%s".',
                      args.remove_stopwords)

        return -1

    logging.info('Ignoring words: %s.', ignore_words)

    # TODO(cvangysel): maybe switch by encapsulated call?
    words, tokens = io_utils.extract_vocabulary(
        args.document_paths,
        min_count=args.vocabulary_min_count,
        max_vocab_size=args.vocabulary_max_size,
        min_word_size=args.vocabulary_min_word_size,
        num_workers=args.num_workers,
        ignore_tokens=ignore_words,
        encoding=args.encoding)

    logging.info('Loading document identifiers.')

    reader = trec_utils.TRECTextReader(args.document_paths,
                                       encoding=args.encoding)

    document_ids = reader.iter_document_ids(num_workers=args.num_workers)

    with open(args.assoc_path, 'r') as f_assocs:
        assocs = trec_utils.EntityDocumentAssociations(
            f_assocs,
            document_ids=document_ids)

    logging.info('Found %d unique entities.', len(assocs.entities))

    logging.info('Document-per-entity stats: %s',
                 list(map(lambda kv: (kv[0], len(kv[1])),
                          sorted(assocs.documents_per_entity.items()))))

    logging.info(
        'Entity-per-document association stats: %s',
        collections.Counter(
            map(len, assocs.entities_per_document.values())).items())

    # Estimate the position in authorship distribution.
    num_associations_distribution = np.zeros(
        assocs.max_entities_per_document, dtype=np.int32)

    for association_length in (
            len(associated_entities)
            for associated_entities in
            assocs.entities_per_document.values()):
        num_associations_distribution[association_length - 1] += 1

    logging.info('Number of associations distribution: %s',
                 num_associations_distribution)

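    # Reverse cumulative sum: entry i counts documents with at least
    # i + 1 associated entities.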
    position_in_associations_distribution = np.cumsum(
        num_associations_distribution[::-1])[::-1]

    logging.info('Position in associations distribution: %s',
                 position_in_associations_distribution)

    instances_and_labels = []

    num_documents = 0
    num_non_associated_documents = 0

    documents_per_entity = collections.defaultdict(int)
    instances_per_entity = collections.defaultdict(int)
    instances_per_document = {}

    global_label_distribution = collections.defaultdict(float)

    if args.overlapping and args.stride is None:
        args.stride = 1
    elif args.overlapping and args.stride is not None:
        logging.error('Option --overlapping passed '
                      'concurrently with --stride.')

        return -1
    elif args.stride is None:
        logging.info('Defaulting stride to window size.')

        args.stride = args.window_size

    logging.info('Generating instances with stride %d.', args.stride)

    result_q = multiprocessing.Queue()

    pool = multiprocessing.Pool(
        args.num_workers,
        initializer=prepare_initializer,
        initargs=[result_q,
                  args, assocs.entities, assocs.entities_per_document,
                  position_in_associations_distribution,
                  tokens, words,
                  args.encoding])

    max_document_length = 0

    worker_result = pool.map_async(
        prepare_worker, args.document_paths)

    # We will not submit any more tasks to the pool.
    pool.close()

    it = multiprocessing_utils.QueueIterator(
        pool, worker_result, result_q)

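    # A document's label is the set of its associated entity ids; a sorted
    # tuple serves as a hashable dictionary key.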
    def _extract_key(obj):
        return tuple(sorted(obj))

    num_labels = 0
    num_instances = 0

    instances_per_label = collections.defaultdict(list)

    while True:
        try:
            result = next(it)
        except StopIteration:
            break

        num_documents += 1

        if result:
            document_id, \
                document_instances_and_labels, \
                document_label = result

            assert document_id not in instances_per_document

            num_instances_for_doc = len(document_instances_and_labels)
            instances_per_document[document_id] = num_instances_for_doc

            max_document_length = max(max_document_length,
                                      num_instances_for_doc)

            # For statistical purposes.
            for entity_id in assocs.entities_per_document[document_id]:
                documents_per_entity[entity_id] += 1
                num_labels += num_instances_for_doc

            # Aggregate.
            instances_per_label[
                _extract_key(document_label.keys())].extend(
                document_instances_and_labels)

            num_instances += len(document_instances_and_labels)

            # Some more accounting.
            for entity_id, mass in document_label.items():
                global_label_distribution[entity_id] += \
                    num_instances_for_doc * mass
        else:
            num_non_associated_documents += 1

    # assert result_q.empty()

    logging.info('Global unnormalized distribution: %s',
                 global_label_distribution)

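    # When resampling, every label is upsampled below (with replacement)
    # to at least the average number of instances per label.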
    if args.resample:
        num_distinct_labels = len(instances_per_label)
        avg_instances_per_label = (
            float(num_instances) / float(num_distinct_labels))

        min_instances_per_entity = int(avg_instances_per_label)

        logging.info('Setting number of sampled instances '
                     'to %d per label.',
                     min_instances_per_entity)

    for label in list(instances_per_label.keys()):
        label_instances_and_labels = instances_per_label.pop(label)

        if not label_instances_and_labels:
            logging.warning('Label %s has no instances; skipping.', label)

            continue

        label_num_instances = len(label_instances_and_labels)

        if args.resample:
            assert min_instances_per_entity > 0

            label_instances_and_labels_sample = []
        else:
            # Without resampling, retain all instances for this label.
            min_instances_per_entity = 0

            label_instances_and_labels_sample = \
                label_instances_and_labels

        while (len(label_instances_and_labels_sample) <
               min_instances_per_entity):
            label_instances_and_labels_sample.append(
                label_instances_and_labels[
                    np.random.randint(label_num_instances)])

        for entity_id in label:
            instances_per_entity[entity_id] += len(
                label_instances_and_labels_sample)

        instances_and_labels.extend(label_instances_and_labels_sample)

    logging.info(
        'Documents-per-indexed-entity stats (mean=%.2f, std_dev=%.2f): %s',
        np.mean(list(documents_per_entity.values())),
        np.std(list(documents_per_entity.values())),
        sorted(documents_per_entity.items()))

    logging.info(
        'Instances-per-indexed-entity stats '
        '(mean=%.2f, std_dev=%.2f, min=%d, max=%d): %s',
        np.mean(list(instances_per_entity.values())),
        np.std(list(instances_per_entity.values())),
        np.min(list(instances_per_entity.values())),
        np.max(list(instances_per_entity.values())),
        sorted(instances_per_entity.items()))

    logging.info(
        'Instances-per-document stats (mean=%.2f, std_dev=%.2f, max=%d).',
        np.mean(list(instances_per_document.values())),
        np.std(list(instances_per_document.values())),
        max_document_length)

    logging.info('Observed %d documents of which %d (ratio=%.2f) '
                 'are not associated with an entity.',
                 num_documents, num_non_associated_documents,
                 (float(num_non_associated_documents) / num_documents))

    training_instances_and_labels, validation_instances_and_labels = \
        sklearn.cross_validation.train_test_split(
            instances_and_labels, test_size=args.validation_set_ratio)

    num_training_instances = len(training_instances_and_labels)
    num_validation_instances = len(validation_instances_and_labels)

    num_instances = num_training_instances + num_validation_instances

    logging.info(
        'Processed %d instances; training=%d, validation=%d (ratio=%.2f).',
        num_instances,
        num_training_instances, num_validation_instances,
        (float(num_validation_instances) /
         (num_training_instances + num_validation_instances)))

    # Figure out if there are any entities with no instances, and
    # do not consider them during training.
    entity_indices = {}
    entity_indices_inv = {}

    for entity_id, num_instances in instances_per_entity.items():
        if not num_instances:
            continue

        entity_index = len(entity_indices)

        entity_indices[entity_id] = entity_index
        entity_indices_inv[entity_index] = entity_id

    logging.info('Retained %d entities after instance creation.',
                 len(entity_indices))

    directories = list(map(os.path.dirname,
                           [args.meta_output, args.data_output]))

    # Create those directories.
    for directory in directories:
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

    # Dump vocabulary.
    with open(args.meta_output, 'wb') as f:
        for obj in (args, words, tokens,
                    entity_indices_inv, assocs.documents_per_entity):
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

        logging.info('Saved vocabulary to "%s".', args.meta_output)

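    # Use the smallest integer dtype that can hold any vocabulary index.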
    instance_dtype = np.min_scalar_type(len(words) - 1)
    logging.info('Instance elements will be stored using %s.', instance_dtype)

    data = {}

    x_train, y_train = instances_and_labels_to_arrays(
        training_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_train'] = x_train
    data['y_train'] = y_train

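    # Weight each instance inversely to the number of instances in its
    # source document, so that every document carries equal total weight.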
    if not args.no_instance_weights:
        w_train = np.fromiter(
            (float(max_document_length) / instances_per_document[doc_id]
             for doc_id, _, _ in training_instances_and_labels),
            np.float32, len(training_instances_and_labels))

        assert w_train.shape == (x_train.shape[0],)

        data['w_train'] = w_train

    x_validate, y_validate = instances_and_labels_to_arrays(
        validation_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_validate'] = x_validate
    data['y_validate'] = y_validate

    with open(args.data_output, 'wb') as f:
        np.savez(f, **data)

    logging.info('Saved data sets.')

    logging.info('Entity-per-document association stats: %s',
                 collections.Counter(
                     map(len, assocs.entities_per_document.values())).items())

    logging.info('Documents-per-entity stats: %s',
                 collections.Counter(
                     map(len, assocs.documents_per_entity.values())).items())

    logging.info('Done.')
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path, required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int, default=1)

    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int, default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int, default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path, default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true',
                        default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio, default=0.01)

    parser.add_argument('--ignore_weights',
                        action='store_true', default=False)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

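    # The label matrix has one column per entity.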
    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
           not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])

        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate,) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])

        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remainder of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

        vocabulary_size = len(words)

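    # Randomly initialize word representations; entries may be overwritten
    # below from a pre-trained embedding file.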
    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info('Initialized representations from '
                     'pre-learned collection for %d words (%.2f%%).',
                     representation_init_count,
                     (representation_init_count /
                      float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(
            output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model, args.iterations, args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])