Exemplo n.º 1
0
def train(
    trainset_dir,
    word2vec_path=WORD2VEC_MODELPATH,
    ontology_path=ONTOLOGY_PATH,
    model_path=MODEL_PATH,
    recreate_ontology=False,
    verbose=True,
):
    """
    Train the model on a given dataset in a single pass and save it to disk.

    All training documents are turned into one pair of matrices, the features
    are scaled, the classifier is fitted once and the resulting model is
    written to ``model_path`` (overwriting any existing file).

    :param trainset_dir: path to the directory with the training set
    :param word2vec_path: path to the gensim word2vec model
    :param ontology_path: path to the ontology file
    :param model_path: path where the model should be pickled
    :param recreate_ontology: boolean flag whether to recreate the ontology
    :param verbose: whether to print computation times and the matrix size

    :return: None. Nothing is caught here, so any exception raised by the
        helpers propagates to the caller.
    """
    # time.clock() was deprecated in Python 3.3 and removed in 3.8. Prefer
    # perf_counter() where it exists and fall back to clock() on old
    # interpreters so the function works on both.
    timer = getattr(time, "perf_counter", None) or time.clock

    ontology = get_ontology(path=ontology_path, recreate=recreate_ontology)
    docs = get_documents(trainset_dir)

    global_index = build_global_frequency_index(trainset_dir, verbose=verbose)
    word2vec_model = Word2Vec.load(word2vec_path)
    model = LearningModel(global_index, word2vec_model)

    matrices_start = timer()

    x, y = build_train_matrices(docs, model, trainset_dir, ontology)

    if verbose:
        print("Matrices built in: {0:.2f}s".format(timer() - matrices_start))
    fit_start = timer()

    if verbose:
        print("X size: {}".format(x.shape))

    # Normalize features
    x = model.maybe_fit_and_scale(x)

    # Train the model
    model.fit_classifier(x, y)

    if verbose:
        # The reported duration covers both the scaling and the fitting step.
        print("Fitting the model: {0:.2f}s".format(timer() - fit_start))

    # Pickle the model
    save_to_disk(model_path, model, overwrite=True)
Exemplo n.º 2
0
def batch_train(
    trainset_dir,
    testset_dir,
    nb_epochs=NB_EPOCHS,
    batch_size=BATCH_SIZE,
    ontology_path=ONTOLOGY_PATH,
    model_path=MODEL_PATH,
    recreate_ontology=False,
    word2vec_path=WORD2VEC_MODELPATH,
    verbose=True,
):
    """
    Train the model incrementally in mini-batches and save the best epoch.

    For every epoch the training documents are streamed (shuffled) in chunks
    of ``batch_size``, each chunk is converted to matrices, scaled and fed to
    the classifier's partial fit. After each epoch the model is evaluated on
    the test set, all metrics are printed, and the model is written to
    ``model_path`` only if its ``'map'`` metric beats the best value seen so
    far.

    :param trainset_dir: path to the directory with the training set
    :param testset_dir: path to the directory with the test set
    :param nb_epochs: number of passes over the training set
    :param batch_size: the size of a single batch (a positive integer)
    :param ontology_path: path to the ontology file
    :param model_path: path to the pickled LearningModel object
    :param word2vec_path: path to the gensim word2vec model
    :param recreate_ontology: boolean flag whether to recreate the ontology
    :param verbose: whether to print per-epoch progress and computation times
        (the evaluation metrics are printed regardless of this flag)

    :return: None. Nothing is caught here, so any exception raised by the
        helpers propagates to the caller.
    """
    # Local import: the file's top-level import block is not touched.
    from itertools import islice

    # time.clock() was deprecated in Python 3.3 and removed in 3.8. Prefer
    # perf_counter() where it exists and fall back to clock() on old
    # interpreters so the function works on both.
    timer = getattr(time, "perf_counter", None) or time.clock

    ontology = get_ontology(path=ontology_path, recreate=recreate_ontology, verbose=False)

    global_index = build_global_frequency_index(trainset_dir, verbose=False)
    word2vec_model = Word2Vec.load(word2vec_path)
    model = LearningModel(global_index, word2vec_model)

    # Starts below any score compared against it so that the first epoch's
    # model is saved (assuming the 'map' metric is non-negative).
    previous_best = -1

    for epoch in range(nb_epochs):
        doc_generator = get_documents(
            trainset_dir,
            as_generator=True,
            shuffle=True,
        )
        epoch_start = timer()

        if verbose:
            print("Epoch {}".format(epoch + 1), end=' ')

        while True:
            # Pull up to batch_size documents; an empty batch means the
            # generator is exhausted and the epoch is over.
            batch = list(islice(doc_generator, batch_size))
            if not batch:
                break

            x, y = build_train_matrices(batch, model, trainset_dir, ontology)

            # Normalize features
            x = model.maybe_fit_and_scale(x)

            # Train the model
            model.partial_fit_classifier(x, y)

            if verbose:
                # One dot per processed batch. A native string is accepted by
                # sys.stdout on both Python 2 and 3, unlike a bytes literal.
                sys.stdout.write('.')
                sys.stdout.flush()

        if verbose:
            print(" {0:.2f}s".format(timer() - epoch_start))

        metrics = test(
            testset_dir,
            model=model,
            ontology=ontology,
            verbose=False
        )

        for k, v in metrics.items():
            print("{0}: {1}".format(k, v))

        # Keep only the best-performing model on disk.
        if metrics['map'] > previous_best:
            previous_best = metrics['map']
            save_to_disk(model_path, model, overwrite=True)