Example #1
File: train.py Project: thangduong/grammar
from word_classifier.data import ClassifierData
from framework.utils.data.text_indexer import TextIndexer  # assumed module path; TextIndexer is used below
import framework.subgraph.losses as losses
import framework.utils.common as utils
from framework.trainer import Trainer, _default_train_iteration_done
from time import time
import model
import os
import shutil

param_file = 'params.py'
params = utils.load_param_file(param_file)
if not utils.get_dict_value(params, 'ignore_negative_data', False):
    params['num_classes'] = len(params['keywords']) + 1
else:
    params['num_classes'] = len(params['keywords'])
indexer = TextIndexer.from_txt_file(utils.get_dict_value(params, 'vocab_file'))
indexer.add_token('<pad>')
indexer.add_token('unk')
os.makedirs(utils.get_dict_value(params, 'output_location'), exist_ok=True)
indexer.save_vocab_as_pkl(
    os.path.join(utils.get_dict_value(params, 'output_location'), 'vocab.pkl'))
shutil.copyfile(
    param_file,
    os.path.join(utils.get_dict_value(params, 'output_location'), param_file))

params['vocab_size'] = indexer.vocab_size()
training_data = ClassifierData.get_monolingual_training(
    base_dir=params['monolingual_dir'], indexer=indexer, params=params)


def on_checkpoint_saved(trainer, params, save_path):
    msg = 'saved checkpoint: ' + save_path
    print(msg)
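
This example (and the two below) reads configuration through framework.utils.common.get_dict_value. Its implementation is not part of these excerpts, but the call sites (a two-argument form, and a three-argument form with a fallback) suggest a plain dict lookup with an optional default. A minimal sketch under that assumption:

def get_dict_value(d, key, default=None):
    # Dict lookup with an optional fallback value.
    # Sketch inferred from the call sites in these examples;
    # the real framework.utils.common implementation may differ.
    if d is not None and key in d:
        return d[key]
    return default

params = {'vocab_file': 'vocab.txt'}
get_dict_value(params, 'vocab_file')                   # -> 'vocab.txt'
get_dict_value(params, 'ignore_negative_data', False)  # -> False
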
Example #2
import framework.subgraph.losses as losses
import framework.utils.common as utils
from data import ClassifierData  # ClassifierData is referenced unqualified below
from framework.utils.data.text_indexer import TextIndexer  # assumed module path; TextIndexer is used below
from framework.trainer import Trainer, _default_train_iteration_done
from time import time
import pickle
import model
import os
import shutil
import copy
import numpy as np

param_file = 'params.py'
params = utils.load_param_file(param_file)
params['num_classes'] = len(params['keywords']) + 1
indexer = TextIndexer.from_txt_file(
    utils.get_dict_value(params, 'vocab_file'),
    max_size=utils.get_dict_value(params, 'max_vocab_size', -1))
indexer.add_token('<pad>')
indexer.add_token('unk')
output_indexer = copy.deepcopy(indexer)
output_indexer.add_token('<blank>')
os.makedirs(utils.get_dict_value(params, 'output_location'), exist_ok=True)
indexer.save_vocab_as_pkl(
    os.path.join(utils.get_dict_value(params, 'output_location'), 'vocab.pkl'))

files_to_copy = [param_file]
for file in files_to_copy:
    shutil.copyfile(
        file,
        os.path.join(utils.get_dict_value(params, 'output_location'), file))

params['vocab_size'] = indexer.vocab_size()

if 'training_data_dir' in params:
    training_data = ClassifierData.get_training_data(
        base_dir=params['training_data_dir'], indexer=indexer, params=params)
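
All three examples drive vocabulary construction through the same TextIndexer interface: build from a text file (optionally capped by max_size), append special tokens such as '<pad>' and 'unk', report vocab_size(), and pickle the result. The class itself is not included in these excerpts; the stand-in below sketches only that interface, assuming a token-per-line vocabulary file, so the framework's real class may behave differently:

import pickle

class TextIndexer:
    # Illustrative stand-in covering only the methods the examples call;
    # the framework's actual TextIndexer may differ.
    def __init__(self, tokens):
        self._tokens = list(tokens)

    @classmethod
    def from_txt_file(cls, path, max_size=-1):
        # Assumes one token (optionally followed by a count) per line;
        # max_size=-1 means no limit.
        with open(path, encoding='utf-8') as f:
            tokens = [line.split()[0] for line in f if line.strip()]
        if max_size > 0:
            tokens = tokens[:max_size]
        return cls(tokens)

    def add_token(self, token):
        if token not in self._tokens:
            self._tokens.append(token)

    def vocab_size(self):
        return len(self._tokens)

    def save_vocab_as_pkl(self, path):
        # Pickle a token -> index mapping.
        with open(path, 'wb') as f:
            pickle.dump({t: i for i, t in enumerate(self._tokens)}, f)
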
Example #3
# Assumed imports: this excerpt uses os, TextIndexer, and
# ClassifierData without importing them.
import os
from framework.utils.data.text_indexer import TextIndexer
from word_classifier.data import ClassifierData

MODEL_NAME = "v0"

params = {
    'num_words_before': 10,
    'num_words_after': 10,
    'embedding_size': 300,
    'vocab_size': 100000,
    'embedding_device': None,
    'batch_size': 128,
    'num_classes': 4,
    'keywords': ['a', 'an', 'the'],
    'mini_batches_between_checkpoint': 100,
    'monolingual_dir': '/mnt/work/1-billion-word-language-modeling-benchmark'
}

indexer = TextIndexer.from_txt_file(
    os.path.join(params['monolingual_dir'], '1b_word_vocab.txt'))
indexer.add_token('<pad>')
indexer.add_token('unk')
indexer.save_vocab_as_pkl('vocab.pkl')
params['vocab_size'] = indexer.vocab_size()
training_data = ClassifierData.get_monolingual_training(
    base_dir=params['monolingual_dir'], indexer=indexer, params=params)


#print(training_data.next_batch(10))
#exit(0)
#print(training_data.next_batch())
def on_checkpoint_saved(trainer, params, save_path):
    msg = 'saved checkpoint: ' + save_path
    print(msg)
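
Examples #1 and #2 both begin with utils.load_param_file('params.py'), and Example #3 shows what such a file contains: a plain Python dict named params. The loader itself is not shown here, but one plausible reading, consistent with params.py being copied verbatim into the output directory, is that it executes the file and returns its params dict. A hypothetical sketch:

import runpy

def load_param_file(path):
    # Execute the .py file and return the 'params' dict it defines.
    # Hypothetical sketch; the framework's actual loader may differ.
    return runpy.run_path(path)['params']

params = load_param_file('params.py')
print(params['num_classes'], params['keywords'])
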