def test_persists_save(model, save_file):
    """Check that the model's ``A`` tensor has the same value after a save/load round trip.

    The ``model`` fixture is saved to ``save_file`` and reloaded with
    ``load_tagger_model``; the value of ``A`` read from the reloaded model must
    match the value read from the original before saving.
    """
    model.save_using(tf.train.Saver())
    t1 = model.sess.run(model.A)
    model.save(save_file)
    m2 = load_tagger_model(save_file)
    # Evaluate the reloaded tensor in the reloaded model's own session: that is
    # the session which restored the checkpoint. Running ``m2.A`` through
    # ``model.sess`` may fail, or may just re-read the original tensor and make
    # the comparison vacuous.
    # NOTE(review): assumes the loaded model exposes ``sess`` like the fixture does.
    t2 = m2.sess.run(m2.A)
    np.testing.assert_allclose(t1, t2)
Example #2
0
def test_persists_save(model, save_file):
    """Check that the model's ``A`` tensor has the same value after a save/load round trip.

    The ``model`` fixture is saved to ``save_file`` and reloaded with
    ``load_tagger_model``; the value of ``A`` read from the reloaded model must
    match the value read from the original before saving.
    """
    model.save_using(tf.train.Saver())
    t1 = model.sess.run(model.A)
    model.save(save_file)
    m2 = load_tagger_model(save_file)
    # Evaluate the reloaded tensor in the reloaded model's own session: that is
    # the session which restored the checkpoint. Running ``m2.A`` through
    # ``model.sess`` may fail, or may just re-read the original tensor and make
    # the comparison vacuous.
    # NOTE(review): assumes the loaded model exposes ``sess`` like the fixture does.
    t2 = m2.sess.run(m2.A)
    np.testing.assert_allclose(t1, t2)
Example #3
0
 def _create_model(self, sess, basename, **kwargs):
     """Load a tagger and build the output tensors used to serve its predictions.

     :param sess: session forwarded to ``load_tagger_model``.
     :param basename: model path prefix; when ``self.return_labels`` is set the
         ``{label: id}`` map is read from ``basename + '.labels'``.
     :param kwargs: extra options forwarded to ``load_tagger_model``.
     :return: ``(model, predictions, values)``. ``predictions`` is ``model.best``
         (presumably predicted label ids), or, when ``self.return_labels`` is set,
         those ids looked up as label strings. ``values`` is the top-1 output of
         ``tf.nn.top_k`` over ``softmax(model.probs)``.
     """
     model = load_tagger_model(basename, sess=sess, **kwargs)
     softmax_output = tf.nn.softmax(model.probs)
     values, _ = tf.nn.top_k(softmax_output, 1)
     indices = model.best
     if not self.return_labels:
         return model, indices, values

     labels = read_json(basename + '.labels')
     # Invert {label: id} into an id-indexed list. Size it to cover the largest
     # id as well as the label count, so a map with gaps in its ids no longer
     # raises IndexError. Unused slots stay ''.
     size = len(labels)
     if labels:
         size = max(size, max(labels.values()) + 1)
     list_of_labels = [''] * size
     for label, idval in labels.items():
         list_of_labels[idval] = label
     class_tensor = tf.constant(list_of_labels)
     table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
     # tf.to_int64 is deprecated; tf.cast is the supported equivalent.
     classes = table.lookup(tf.cast(indices, tf.int64))
     return model, classes, values
Example #4
0
 def _create_model(self, sess, basename, **kwargs):
     """Load a tagger and build the output tensors used to serve its predictions.

     :param sess: session forwarded to ``load_tagger_model``.
     :param basename: model path prefix; when ``self.return_labels`` is set the
         ``{label: id}`` map is read from ``basename + '.labels'``.
     :param kwargs: extra options forwarded to ``load_tagger_model``.
     :return: ``(model, predictions, values)``. ``predictions`` is ``model.best``
         (presumably predicted label ids), or, when ``self.return_labels`` is set,
         those ids looked up as label strings. ``values`` is the top-1 output of
         ``tf.nn.top_k`` over ``softmax(model.probs)``.
     """
     model = load_tagger_model(basename, sess=sess, **kwargs)
     softmax_output = tf.nn.softmax(model.probs)
     values, _ = tf.nn.top_k(softmax_output, 1)
     indices = model.best
     if not self.return_labels:
         return model, indices, values

     labels = read_json(basename + '.labels')
     # Invert {label: id} into an id-indexed list. Size it to cover the largest
     # id as well as the label count, so a map with gaps in its ids no longer
     # raises IndexError. Unused slots stay ''.
     size = len(labels)
     if labels:
         size = max(size, max(labels.values()) + 1)
     list_of_labels = [''] * size
     for label, idval in labels.items():
         list_of_labels[idval] = label
     class_tensor = tf.constant(list_of_labels)
     table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
     # tf.to_int64 is deprecated; tf.cast is the supported equivalent.
     classes = table.lookup(tf.cast(indices, tf.int64))
     return model, classes, values
Example #5
0
def load_model(modelname, **kwargs):
    """Load the tagger named ``modelname`` through ``load_tagger_model``.

    ``BASELINE_TAGGER_LOADERS`` is supplied as the loader argument; any extra
    keyword arguments are passed through unchanged.
    """
    loader = BASELINE_TAGGER_LOADERS
    return load_tagger_model(loader, modelname, **kwargs)
Example #6
0
def load_model(modelname, **kwargs):
    """Load the tagger named ``modelname`` through ``load_tagger_model``.

    ``RNNTaggerModel.load`` is supplied as the loader argument; any extra
    keyword arguments are passed through unchanged.
    """
    loader = RNNTaggerModel.load
    return load_tagger_model(loader, modelname, **kwargs)