コード例 #1
0
    def load(self, dirpath):
        """Load a trained model from local disk, given the dirpath.

        Parameters
        ----------
        dirpath : str or os.PathLike
            Directory containing the saved model artifacts
            ``weights.h5``, ``params.json`` and ``preprocessor.pkl``.

        Returns
        -------
        self
            With ``model_``, ``preprocessor_`` and ``tagger_`` populated.

        Raises
        ------
        ValueError
            If ``dirpath`` is not an existing directory, or if any one of
            the three artifact files is missing from it.
        """
        # isdir (not exists): a regular file at dirpath cannot hold the
        # artifacts. "{}" (not "{:s}") so path-like objects format too;
        # for plain str input the message is unchanged.
        if not os.path.isdir(dirpath):
            raise ValueError("Model directory not found: {}".format(dirpath))

        weights_file = os.path.join(dirpath, "weights.h5")
        params_file = os.path.join(dirpath, "params.json")
        preprocessor_file = os.path.join(dirpath, "preprocessor.pkl")

        # All three artifacts are needed below, so a single missing file is
        # already an error -- fail here with a clear message rather than
        # deep inside load_model / IndexTransformer.load.
        required = (weights_file, params_file, preprocessor_file)
        if not all(os.path.exists(path) for path in required):
            raise ValueError("Model files may be corrupted, exiting")

        self.model_ = load_model(weights_file, params_file)
        self.preprocessor_ = IndexTransformer.load(preprocessor_file)
        self.tagger_ = Tagger(self.model_, preprocessor=self.preprocessor_)

        return self
コード例 #2
0
ファイル: wrapper.py プロジェクト: sonvx/anago
    def load(cls, weights_file, params_file, preprocessor_file):
        """Restore an instance of ``cls`` from saved artifacts.

        Parameters
        ----------
        weights_file : str
            Weights path, passed to ``load_model``.
        params_file : str
            Parameters path, passed to ``load_model``.
        preprocessor_file : str
            Path passed to ``IndexTransformer.load``.

        Returns
        -------
        A new ``cls()`` instance with ``p`` (the preprocessor) and
        ``model`` set.
        """
        instance = cls()
        instance.p = IndexTransformer.load(preprocessor_file)
        instance.model = load_model(weights_file, params_file)
        # Added by Sonvx on Jan 14, 2021: fix issue ("<tensor> is not an
        # element of this graph." when loading model) by building the
        # predict function eagerly.
        # ``_make_predict_function`` is a private graph-mode Keras API that
        # TF2-era tf.keras models do not have; only call it when present so
        # loading does not crash with AttributeError on those versions.
        make_predict = getattr(instance.model, "_make_predict_function", None)
        if make_predict is not None:
            make_predict()

        return instance
コード例 #3
0
 def __init__(self, process_proper_nouns=False):
     """Set up ``self.pos_tagger`` from the artifacts under ``ELMO_TAGGER_PATH``.

     Parameters
     ----------
     process_proper_nouns : bool, default False
         Forwarded unchanged to the parent class constructor.
     """
     super().__init__(process_proper_nouns)

     weights_path = os.path.join(ELMO_TAGGER_PATH, 'weights.h5')
     params_path = os.path.join(ELMO_TAGGER_PATH, 'params.json')
     preprocessor_path = os.path.join(ELMO_TAGGER_PATH, 'preprocessor.pkl')

     tagger_model = load_model(weights_path, params_path)
     transformer = IndexTransformer.load(preprocessor_path)

     # Sentences are split with NLTK-style wordpunct_tokenize before tagging.
     self.pos_tagger = Tagger(
         tagger_model,
         preprocessor=transformer,
         tokenizer=wordpunct_tokenize,
     )
コード例 #4
0
ファイル: tagger_example.py プロジェクト: chokolet/anago
def main(args):
    """Tag ``args.sent`` with a restored model and pretty-print the result.

    ``args`` must expose ``weights_file``, ``params_file``,
    ``preprocessor_file`` and ``sent`` attributes.
    """
    print('Loading objects...')
    # Arguments are evaluated left to right: the model is loaded before the
    # preprocessor.
    tagger = Tagger(
        load_model(args.weights_file, args.params_file),
        preprocessor=IndexTransformer.load(args.preprocessor_file),
    )

    print('Tagging a sentence...')
    pprint(tagger.analyze(args.sent))
コード例 #5
0
    def setUpClass(cls):
        """Build one shared tagger (and sample sentence) from ``SAVE_ROOT``."""
        def artifact(name):
            # Resolve an artifact filename inside the shared save directory.
            return os.path.join(SAVE_ROOT, name)

        # NOTE(review): other loaders use 'preprocessor.pkl'; confirm this
        # '.pickle' name matches what the save step actually writes.
        preprocessor = IndexTransformer.load(artifact('preprocessor.pickle'))
        model = load_model(artifact('weights.h5'), artifact('params.json'))

        cls.tagger = anago.Tagger(model, preprocessor=preprocessor)
        cls.sent = 'President Obama is speaking at the White House.'
コード例 #6
0
    def test_save_and_load(self):
        """Round-trip a freshly built BiLSTMCRF through save_model/load_model."""
        builder = BiLSTMCRF(char_vocab_size=100,
                            word_vocab_size=10000,
                            num_labels=10)
        model, _loss = builder.build()

        artifact_paths = (self.weights_file, self.params_file)

        # Nothing may be on disk before saving...
        for path in artifact_paths:
            self.assertFalse(os.path.exists(path))

        save_model(model, self.weights_file, self.params_file)

        # ...and both artifacts must exist afterwards.
        for path in artifact_paths:
            self.assertTrue(os.path.exists(path))

        # NOTE(review): the reloaded model is not inspected further here.
        model = load_model(self.weights_file, self.params_file)
コード例 #7
0
    def load(cls, weights_file, params_file, preprocessor_file):
        """Restore an instance of ``cls`` from saved artifacts.

        The preprocessor is read with ``IndexTransformer.load`` and stored
        as ``p``; the model is read with ``load_model`` and stored as
        ``model``. Returns the new instance.
        """
        instance = cls()
        instance.p = IndexTransformer.load(preprocessor_file)
        instance.model = load_model(weights_file, params_file)
        return instance