コード例 #1
0
def load_snorkel(filename='snorkel_model', num_models=6):
    """Load a numbered series of saved Snorkel generative models.

    Models are loaded from the names ``filename + "0"`` through
    ``filename + str(num_models - 1)``, in that order.

    Args:
        filename: Common prefix of the saved model names. Defaults to the
            previously hard-coded ``'snorkel_model'``.
        num_models: Number of consecutively numbered models to load.
            Defaults to the previously hard-coded 6.

    Returns:
        A list of ``GenerativeModel`` instances, one per index, in index order.
    """
    models = []
    for index in range(num_models):
        model = GenerativeModel()
        # Only the model name is passed, so every other load() option keeps
        # its default value.
        model.load(filename + str(index))
        models.append(model)
    return models
コード例 #2
0
 def __init__(self, *args, **kwargs):
     """Initialise the agent, then attach six loaded Snorkel generative models.

     All positional and keyword arguments are forwarded unchanged to the parent
     class constructor. Afterwards ``self.models`` holds the models loaded from
     ``filename + "0"`` .. ``filename + "5"``, in index order.
     """
     super(SnorkelAgent, self).__init__(*args, **kwargs)

     # TODO: settle on a loading strategy. An earlier version read a single
     # archive instead, via np.load(filename)['m'].item().

     # NOTE(review): `filename` is not bound in this method; it must resolve to
     # a module-level name, otherwise this raises NameError. Confirm.
     def _load_one(index):
         # Build one model and restore its saved state by name.
         model = GenerativeModel()
         model.load(filename + str(index))
         return model

     self.models = [_load_one(index) for index in range(6)]
コード例 #3
0
def score_gen_model(predicate_resume,
                    session,
                    gen_model_name=None,
                    parallelism=16):
    """Load a saved generative model, retrain it on the training labels, and report LF stats.

    Args:
        predicate_resume: Mapping describing the predicate. Its ``"predicate_name"``
            entry builds the default model name. The mapping is also passed to
            ``get_train_cids_with_span`` and ``load_ltrain``.
        session: Session object forwarded to the helpers that build the training data.
        gen_model_name: Name of the saved model to load. When None, defaults to
            ``"G" + predicate_resume["predicate_name"] + "Latest"``.
        parallelism: Not used by this function; kept for interface compatibility.

    Returns:
        The value returned by ``gen_model.learned_lf_stats()``. The original
        computed this and discarded it, returning None.
    """
    # Fix: the original bound `model_name` only in the default branch. An explicit
    # `gen_model_name` therefore raised UnboundLocalError at `gen_model.load`.
    if gen_model_name is None:
        model_name = "G" + predicate_resume["predicate_name"] + "Latest"
    else:
        model_name = gen_model_name

    logging.info("Stats logging")
    # NOTE(review): the returned query is never used. The call is kept only in
    # case it has side effects; confirm, then remove if it does not.
    get_train_cids_with_span(predicate_resume, session)
    L_train = load_ltrain(predicate_resume, session)

    gen_model = GenerativeModel()
    gen_model.load(model_name)
    # NOTE(review): train() runs right after load(). Confirm whether the loaded
    # weights are meant as a warm start or are simply replaced by a fresh fit.
    gen_model.train(L_train,
                    epochs=100,
                    decay=0.95,
                    # Step size is scaled down by the number of training rows.
                    step_size=0.1 / L_train.shape[0],
                    reg_param=1e-6)

    logging.info(gen_model.weights.lf_accuracy)
    print(gen_model.weights.lf_accuracy)
    # The unused `plt.figure()` was removed: it created a figure that was never
    # closed, leaking one figure per call.
    return gen_model.learned_lf_stats()
コード例 #4
0
import os

from snorkel.annotations import LabelAnnotator
# Apply the labeling functions to the evaluation split to build its label matrix.
labeler = LabelAnnotator(lfs=LFs)
L_eval = labeler.apply(split=eval_split, parallelism=parallelism)

# defining model
from snorkel.learning import GenerativeModel
# Creating generative model
gen_model = GenerativeModel()

# defining saved weights directory and name
model_name = 'Price_Gen_20K'  # this was provided when the model was saved!
save_dir = '/dfs/scratch0/jdunnmon/data/memex-data/extractor_checkpoints/Price_Gen_20K'  # this was provided when the model was saved!

# loading
print("Loading generative model...")
gen_model.load(model_name=model_name, save_dir=save_dir, verbose=True)

# Evaluating the generative model (not an LSTM): marginals for the eval label matrix
print("Evaluating marginals...")
eval_marginals = gen_model.marginals(L_eval)

# Geocoding
from gm_utils import create_extractions_dict
# Google Maps API key used to geocode extractions. It is read from the
# environment so that no credential is committed to source. When the variable is
# unset or blank, the key is None and only the extracted locations are used.
# SECURITY: a literal API key was previously committed in a comment here;
# it should be revoked/rotated.
geocode_key = os.environ.get("GOOGLE_MAPS_API_KEY") or None
print("Creating extractions dictionary...")
doc_extractions = create_extractions_dict(session,
                                          L_eval,
                                          eval_marginals,
                                          extractions=[extraction_type],