Exemplo n.º 1
0
def evaluate_main():
    """Load a saved model state and evaluate it on 'label_needsmoderation'.

    Builds an ``m.Posts`` data set for the label, configures it with the
    'translate' balance method (sampling strategy 0.9) and hands it, together
    with the state-dict name, to ``load_model_and_evaluate``.
    """
    # Forwarded as the last positional argument of load_model_and_evaluate
    # (presumably the positive-class loss weight -- confirm against its signature).
    positive_class_weight = 2.0
    label = 'label_needsmoderation'
    logger.info(f'Label: {label}')

    data = m.Posts()
    data.set_label(label=label)
    data.set_balance_method(balance_method='translate', sampling_strategy=0.9)

    # Name without extension; the log line below refers to ./models/<name>.bin.
    state_dict = 'model_gbertbase_label_needsmoderation_210423_014254'
    logger.info(f"Loading model from state_dict ./models/{state_dict}.bin")

    # Pass the named value instead of repeating the literal 2.0, so the
    # variable above is actually the single source of truth.
    load_model_and_evaluate(state_dict, data, label, positive_class_weight)
Exemplo n.º 2
0
def grid_search_main():
    """Call ``make_model`` once for every point of a fixed hyper-parameter grid.

    The grid is the cross product of positive-class weights, learning rates,
    target labels and (balance method, sampling strategy) pairs, visited in
    exactly that nesting order (weights outermost, strategies innermost).
    A fresh ``m.Posts`` data set is configured for each combination and the
    combination is logged before ``make_model`` is invoked.
    """
    target_labels = [
        'label_discriminating',
        'label_inappropriate',
        'label_sentimentnegative',
        'label_needsmoderation',
    ]
    # To search a single label only, shrink the list above,
    # e.g. to ['label_needsmoderation'].
    balance_options = {'translate': [0.9], 'oversample': [0.9]}
    learning_rates = [1e-5, 5e-6, 25e-7, 125e-8]
    positive_weights = [2., 3.]

    # Flatten {method: [strategies]} into ordered (method, strategy) pairs.
    balance_pairs = [
        (balance_method, sampling_strategy)
        for balance_method, strategies in balance_options.items()
        for sampling_strategy in strategies
    ]
    separator = '-' * 50

    for pos_weight in positive_weights:
        for lr in learning_rates:
            for target in target_labels:
                for balance_method, sampling_strategy in balance_pairs:
                    posts = m.Posts()
                    posts.set_label(label=target)
                    posts.set_balance_method(
                        balance_method=balance_method,
                        sampling_strategy=sampling_strategy,
                    )

                    logger.info(separator)
                    logger.info(f'positive_class_weight: {pos_weight}, learning-rate: {lr}')
                    logger.info(f'Label: {target}')
                    logger.info(f'Balance-method: {balance_method}, Balance-strategy: {sampling_strategy}')
                    logger.info(separator)

                    make_model(posts, target, lr, pos_weight)
Exemplo n.º 3
0
import mlflow
from modeling.config import TRACKING_URI, EXPERIMENT_NAME  #, TRACKING_URI_DEV
import logging

from time import time

# Logging setup: timestamped messages on the root handler.
logging.basicConfig(format="%(asctime)s: %(message)s")

# Module logger reports INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Only let CRITICAL records through from "pyhive" to avoid excessive logs.
logging.getLogger("pyhive").setLevel(logging.CRITICAL)

if __name__ == "__main__":

    data = m.Posts()
    #embedding_dict_glove = transformers.load_embedding_vectors(embedding_style='glove')
    #embedding_dict_w2v = transformers.load_embedding_vectors(embedding_style='word2vec')

    trans_os = {
        None: [None],
        'translate': [0.8, 0.9, 1.0],
        'oversample': [0.8, 0.9, 1.0]
    }

    #vecs = {CountVectorizer(): 'count',
    #        TfidfVectorizer(): 'tfidf',
    #        transformers.MeanEmbeddingVectorizer(embedding_dict=embedding_dict_glove): 'glove',
    #       transformers.MeanEmbeddingVectorizer(embedding_dict=embedding_dict_w2v): 'word2vec',
    #       }
    vecs = {