# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Train an SDPParser on the SemEval-15 PSD data, then evaluate it.

The script trains, reloads the saved model and evaluates it on the
``en.id.psd`` file. All three steps share one model directory so the
evaluated model is always the one that was just trained.
"""
from bertsota.parser.dep_parser import SDPParser

if __name__ == '__main__':
    # Single source of truth for the model directory. Previously training
    # wrote to this directory while ``load`` read from 'data/model/psd-id',
    # so the evaluation did not use the freshly trained model.
    save_dir = 'data/model/psd-id-eval'
    # NOTE(review): the same file is used as dev and test set below;
    # confirm that this is intended.
    id_file = 'data/semeval15/en.id.psd.conllu'

    parser = SDPParser()
    parser.train(train_file='data/semeval15/en.psd.conllu',
                 dev_file=id_file,
                 test_file=id_file,
                 save_dir=save_dir,
                 pretrained_embeddings_file='data/embedding/glove.6B.100d.txt',
                 validate_every=1000)
    # Reload from the directory training just wrote to.
    parser.load(save_dir)
    parser.evaluate(test_file=id_file,
                    save_dir=save_dir,
                    num_buckets_test=10)
# coding: utf-8
"""Train and evaluate BERT-base PAS parsers on several training subsets.

For each value in ``range(20, 100, 20)`` a fresh SDPParser is trained on
``en.pas.train.<value>.conllu``, reloaded from its save directory and
evaluated on the ``id`` and ``ood`` PAS files. Both evaluations log through
the logger created for ``test.log`` in the same save directory.
"""
from bertsota.common.data import ParserVocabulary, DataLoader
from bertsota.common.utils import init_logger
from bertsota.parser.dep_parser import SDPParser

for p in range(20, 100, 20):
    save_dir = 'data/model/bert-base-pas{}'.format(p)
    train_bert_paths = [
        'data/embedding/bert_base_sum/en.train.bert',
        'data/embedding/bert_base_sum/en.dev.bert',
    ]

    parser = SDPParser()
    parser.train(
        train_file='data/semeval15/en.pas.train.{}.conllu'.format(p),
        dev_file='data/semeval15/en.pas.dev.conllu',
        save_dir=save_dir,
        pretrained_embeddings_file='data/embedding/glove.6B.100d.shrinked.txt',
        bert_path=train_bert_paths,
    )
    parser.load(save_dir)

    logger = init_logger(save_dir, 'test.log')
    # Evaluation sets as (test file, matching BERT features) pairs,
    # in the original order: "id" first, then "ood".
    evaluation_sets = (
        ('data/semeval15/en.id.pas.conllu',
         'data/embedding/bert_base_sum/en.id.bert'),
        ('data/semeval15/en.ood.pas.conllu',
         'data/embedding/bert_base_sum/en.ood.bert'),
    )
    for test_path, bert_features in evaluation_sets:
        parser.evaluate(test_file=test_path,
                        bert_path=bert_features,
                        save_dir=save_dir,
                        logger=logger)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Train an SDPParser with one PSD file as train/dev/test, then evaluate.

The same file is passed as training, dev and test set. The trained model
is then reloaded and evaluated on the ``en.id.psd`` file, with evaluation
output going to a separate ``over-fitting-psd-id`` directory.
"""
from bertsota.parser.dep_parser import SDPParser

if __name__ == '__main__':
    # Directory the trained model is written to and reloaded from.
    # Previously ``load`` read 'data/model/dep', an unrelated directory,
    # so the evaluation did not use the model trained here.
    model_dir = 'data/model/psd-id'
    # The same file serves as train, dev and test set.
    psd_file = 'data/semeval15/en.psd.conll'

    parser = SDPParser()
    parser.train(train_file=psd_file,
                 dev_file=psd_file,
                 test_file=psd_file,
                 save_dir=model_dir,
                 pretrained_embeddings_file='data/embedding/glove.6B.100d.txt')
    # Reload the model that was just trained.
    parser.load(model_dir)
    parser.evaluate(test_file='data/semeval15/en.id.psd.conll',
                    save_dir='data/model/over-fitting-psd-id',
                    num_buckets_test=10)