def trainer_helper(configFile, dataSetFile, tempModel):
    """Train a TextCNN model, reload it from disk and classify the training data.

    The training config is read from ``configFile``, the data set is built
    from ``dataSetFile`` and a ``TextCNNModelTrainer`` is asked to train with
    ``tempModel`` as its target path. That file is then unpickled, the
    resulting model classifies the same sentences, and ``tempModel`` is
    deleted.

    Args:
        configFile: Path handed to ``get_training_config_from_json``.
        dataSetFile: Path handed to ``build_data`` (second argument ``True``).
        tempModel: Path given to ``trainer.train``; the file is removed
            before this function returns, including when loading or
            classification raises.
    """
    # Each print passes ONE pre-formatted argument, which produces the same
    # text as the comma-separated Python 2 print statement while also being
    # valid call syntax.
    print("%s %s %s" % ("Training model on ", configFile, dataSetFile))

    config = get_training_config_from_json(configFile)
    sentences, vocab, labels = build_data(dataSetFile, True)
    word_vecs = wordvecs.load_wordvecs(config.word2vec, vocab)

    trainer = TextCNNModelTrainer(config, word_vecs, sentences, labels)
    trainer.train(tempModel)

    print("%s %s %s %s %s" % ("Succesfully trained model on ", configFile,
                              dataSetFile, " and model is at ", tempModel))
    print("Will proceed at testing the model on same data. If everything is correct, you should see the same accuracy")

    try:
        # Close the handle deterministically before the file is deleted;
        # an open handle can make os.remove fail on some platforms.
        # NOTE: unpickling executes code from the file. That is acceptable
        # here only because the file was produced by trainer.train() above.
        with open(tempModel, "rb") as model_file:
            model = cPickle.load(model_file)
        # The return value is not used by this helper.
        model.classify(sentences)
    finally:
        # Do not leave the temporary model behind when classification fails.
        # The existence check keeps a failed open() from being masked by a
        # second error raised from os.remove.
        if os.path.exists(tempModel):
            os.remove(tempModel)
__author__ = 'devashish.shankar'

# Command-line entry point: train a classifier from a JSON config and a
# training data file, then store the model at the requested path.
if __name__ == "__main__":
    if len(sys.argv) < 5:
        print("Usage: training.py")
        print("\t<model config file path>")
        print("\t<training data file path>")
        print("\t<file path to store classifier model>")
        print("\t<true/false(preprocessing flag)>")
        # sys.exit instead of the site-provided exit(): the latter is meant
        # for the interactive shell and is missing under `python -S`.
        # Non-zero status so callers can detect the failed invocation.
        sys.exit(1)

    # Positional command-line arguments.
    config_file = sys.argv[1]
    train_data_file = sys.argv[2]
    model_output_file = sys.argv[3]
    # Forwarded as the lower-cased *string* that was typed, not as a bool.
    # NOTE(review): confirm datasets.build_data compares against "true";
    # a plain truthiness test would treat "false" as enabled.
    preprocess = sys.argv[4].lower()

    training_config = config.get_training_config_from_json(config_file)
    sentences, vocab, labels = datasets.build_data(train_data_file, preprocess)
    word_vecs = wordvecs.load_wordvecs(training_config.word2vec, vocab)

    # "multichannel" mode gets its own trainer; every other mode uses the
    # plain text-CNN trainer.
    if training_config.mode == "multichannel":
        nntrainer = MultiChannelTrainer(training_config, word_vecs, sentences, labels)
    else:
        nntrainer = TextCNNModelTrainer(training_config, word_vecs, sentences, labels)
    nntrainer.train(model_output_file)
def test_config_reader():
    """Smoke-test the JSON config reader against ``testConfig.json``.

    Loads the file with ``get_training_config_from_json``, checks that the
    resulting ``mode`` is ``"static"`` and prints the config object.
    """
    # TODO improve this test case, probably check if values are actually
    # getting correctly parsed from config
    parsed_config = get_training_config_from_json("testConfig.json")
    assert parsed_config.mode == "static"
    # Single-argument call form: same output as the Python 2 print statement.
    print(parsed_config)
# Command-line entry point: train a classifier from a JSON config and a
# training data file, reporting progress, then store the model on disk.
if __name__ == "__main__":
    if len(sys.argv) < 5:
        print("Usage: training.py")
        print("\t<model config file path>")
        print("\t<training data file path>")
        print("\t<file path to store classifier model>")
        print("\t<true/false(preprocessing flag)>")
        # sys.exit instead of the site-provided exit(): the latter is meant
        # for the interactive shell and is missing under `python -S`.
        # Non-zero status so callers can detect the failed invocation.
        sys.exit(1)

    # Positional command-line arguments.
    config_file = sys.argv[1]
    train_data_file = sys.argv[2]
    model_output_file = sys.argv[3]
    # Forwarded as the lower-cased *string* that was typed, not as a bool.
    # NOTE(review): confirm datasets.build_data compares against "true";
    # a plain truthiness test would treat "false" as enabled.
    preprocess = sys.argv[4].lower()

    training_config = config.get_training_config_from_json(config_file)

    sentences, vocab, labels = datasets.build_data(train_data_file, preprocess)
    print("Dataset loaded")

    word_vecs = wordvecs.load_wordvecs(training_config.word2vec, vocab)
    print("Loaded word vecs from file")

    # "multichannel" mode gets its own trainer; every other mode uses the
    # plain text-CNN trainer.
    if training_config.mode == "multichannel":
        nntrainer = MultiChannelTrainer(training_config, word_vecs, sentences, labels)
    else:
        nntrainer = TextCNNModelTrainer(training_config, word_vecs, sentences, labels)
    nntrainer.train(model_output_file)