def main(args):
    model_dir = CHECKPOINT_DIR.format(args.model, args.train_algo)
    dict_file = DICTIONARY_DIR.format(args.data)
    data_file = './data/{}/corpus_new.txt'.format(args.data)

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'TRAIN_PARAMS')
    RUN_CONFIG = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'RUN_CONFIG')
    MySpecialToken = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'MySpecialToken')

    # Init dataset
    input_pipe = Word2VecDataset(data_file=data_file,
                                 dict_file=dict_file,
                                 window_size=TRAIN_PARAMS['window_size'],
                                 epochs=TRAIN_PARAMS['epochs'],
                                 batch_size=TRAIN_PARAMS['batch_size'],
                                 buffer_size=TRAIN_PARAMS['buffer_size'],
                                 special_token=MySpecialToken,
                                 min_count=TRAIN_PARAMS['min_count'],
                                 max_count=TRAIN_PARAMS['max_count'],
                                 sample_rate=TRAIN_PARAMS['sample_rate'],
                                 model=args.model)
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'train_algo': args.train_algo,
        'loss': args.loss,
        'model': args.model
    })

    # Init Estimator
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    train_spec = tf.estimator.TrainSpec(input_fn=input_pipe.build_dataset())
    eval_spec = tf.estimator.EvalSpec(input_fn=input_pipe.build_dataset(is_predict=1),
                                      steps=1000,
                                      throttle_secs=60)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
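
# Hypothetical CLI entry point for the word2vec runner above: a minimal sketch
# assuming argparse is used. The flag names mirror the attributes main() reads
# (model, train_algo, data, loss, clear_model, gpu); the choices, defaults, and
# types shown here are illustrative assumptions, not the repo's actual interface.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True, help='dataset name used to locate config and corpus')
    parser.add_argument('--model', type=str, default='skip_gram', help='e.g. skip_gram or cbow (assumed)')
    parser.add_argument('--train_algo', type=str, default='negative_sampling', help='training algorithm (assumed)')
    parser.add_argument('--loss', type=str, default='nce', help='loss name merged into TRAIN_PARAMS (assumed)')
    parser.add_argument('--clear_model', type=int, default=0, help='1 to remove the existing checkpoint dir first')
    parser.add_argument('--gpu', type=int, default=0, help='GPU flag passed to build_estimator (assumed 0/1)')
    main(parser.parse_args())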
def main(args):
    # Init directory
    model_dir = CHECKPOINT_DIR.format(args.data, MODEL)
    dict_file = DICTIONARY_DIR.format(args.data)

    # Init config
    TRAIN_PARAMS = ALL_TRAIN_PARAMS[args.data]
    RUN_CONFIG = ALL_RUN_CONFIG[args.data]
    MySpecialToken = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'MySpecialToken')

    if args.step == 'train':
        data_file = {
            'encoder': './data/{}/train_encoder_source.txt'.format(args.data),
            'decoder': './data/{}/train_decoder_source.txt'.format(args.data)
        }
    else:
        data_file = {
            'encoder': './data/{}/dev_encoder_source.txt'.format(args.data),
            'decoder': './data/{}/dev_decoder_source.txt'.format(args.data)  # for predict, this can be the same as the encoder source
        }

    if args.clear_model:
        clear_model(model_dir)

    # Init dataset
    input_pipe = Seq2SeqDataset(data_file=data_file,
                                dict_file=dict_file,
                                epochs=TRAIN_PARAMS['epochs'],
                                batch_size=TRAIN_PARAMS['batch_size'],
                                min_count=TRAIN_PARAMS['min_count'],
                                max_count=TRAIN_PARAMS['max_count'],
                                buffer_size=TRAIN_PARAMS['buffer_size'],
                                special_token=MySpecialToken,
                                max_len=TRAIN_PARAMS['max_len'],
                                min_len=TRAIN_PARAMS['min_len'],
                                pretrain_model_list=TRAIN_PARAMS['pretrain_model_list'])
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir,
        'start_index': input_pipe.start_index,
        'end_index': input_pipe.end_index,
        'pretrain_embedding': input_pipe.load_pretrain_embedding()
    })

    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    if args.step == 'train':
        estimator.train(input_fn=input_pipe.build_dataset())

    if args.step == 'predict':
        # Disable GPU in prediction to avoid resource-exhausted (OOM) errors
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(is_predict=1))
        res = list(prediction)
        with open('./data/{}/{}_predict.pkl'.format(args.data, MODEL), 'wb') as f:
            pickle.dump(res, f)
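
# Hypothetical follow-up for the seq2seq runner above: a minimal sketch of how the
# pickled predictions might be inspected offline. The exact keys in each prediction
# dict depend on what model_fn returns in its EstimatorSpec(predictions=...), so no
# specific field names are assumed here.
def inspect_predictions(data, model=MODEL, n=3):
    with open('./data/{}/{}_predict.pkl'.format(data, model), 'rb') as f:
        res = pickle.load(f)
    print('loaded {} predictions'.format(len(res)))
    for item in res[:n]:
        # each item is one per-example dict yielded by estimator.predict
        print(sorted(item.keys()))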
def main(args):
    model_dir = CHECKPOINT_DIR.format(args.data, args.model)
    dict_file = DICTIONARY_DIR.format(args.data)
    data_file = './data/{}/train.tfrecords'.format(args.data)

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'TRAIN_PARAMS')
    RUN_CONFIG = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'RUN_CONFIG')
    MySpecialToken = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'MySpecialToken')
    TF_PROTO = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'TF_PROTO')

    # Init dataset
    input_pipe = FasttextDataset(data_file=data_file,
                                 dict_file=dict_file,
                                 epochs=TRAIN_PARAMS['epochs'],
                                 batch_size=TRAIN_PARAMS['batch_size'],
                                 min_count=TRAIN_PARAMS['min_count'],
                                 max_count=TRAIN_PARAMS['max_count'],
                                 buffer_size=TRAIN_PARAMS['buffer_size'],
                                 special_token=MySpecialToken,
                                 ngram=TRAIN_PARAMS['ngram'],
                                 tf_proto=TF_PROTO)
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir
    })

    # Init Estimator
    model_fn = getattr(importlib.import_module('model_{}'.format(args.model)), 'model_fn')
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    if args.step == 'train':
        early_stopping = tf.estimator.experimental.stop_if_no_decrease_hook(
            estimator, metric_name='loss', max_steps_without_decrease=100 * 1000)

        train_spec = tf.estimator.TrainSpec(input_fn=input_pipe.build_dataset(),
                                            hooks=[early_stopping])
        eval_spec = tf.estimator.EvalSpec(input_fn=input_pipe.build_dataset(is_predict=1),
                                          steps=500,
                                          throttle_secs=60)
        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    if args.step == 'predict':
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(is_predict=1))
        # estimator.predict returns a generator, which cannot be pickled directly;
        # materialize it into a list before dumping
        res = list(prediction)
        with open('prediction.pkl', 'wb') as f:
            pickle.dump(res, f)
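
# Illustration of the character n-gram idea behind the `ngram` parameter above.
# FasttextDataset presumably builds these features internally; this standalone
# helper is only a sketch of the standard fastText-style subword extraction and
# is not the repo's implementation.
def char_ngrams(token, n=3):
    # wrap the token in boundary markers so prefixes/suffixes stay distinguishable
    padded = '<' + token + '>'
    if len(padded) <= n:
        return [padded]
    return [padded[i:i + n] for i in range(len(padded) - n + 1)]

# char_ngrams('where') -> ['<wh', 'whe', 'her', 'ere', 're>']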
def main(args):
    model_dir = CHECKPOINT_DIR.format(args.data, args.model)
    dict_file = DICTIONARY_DIR.format(args.data)

    # For the quick-thought model, encoder source = decoder source = all continuous sentences
    if args.model == 'quick_thought':
        data_file = {
            'encoder': './data/{}/all_sentences.txt'.format(args.data),
            'decoder': './data/{}/all_sentences.txt'.format(args.data)
        }
    else:
        data_file = {
            'encoder': './data/{}/encoder_source.txt'.format(args.data),
            'decoder': './data/{}/decoder_source.txt'.format(args.data)
        }

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'TRAIN_PARAMS')
    RUN_CONFIG = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'RUN_CONFIG')
    MySpecialToken = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'MySpecialToken')
    ED_PARAMS = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'ED_PARAMS')

    # Init dataset
    input_pipe = SkipThoughtDataset(data_file=data_file,
                                    dict_file=dict_file,
                                    epochs=TRAIN_PARAMS['epochs'],
                                    batch_size=TRAIN_PARAMS['batch_size'],
                                    min_count=TRAIN_PARAMS['min_count'],
                                    max_count=TRAIN_PARAMS['max_count'],
                                    buffer_size=TRAIN_PARAMS['buffer_size'],
                                    special_token=MySpecialToken,
                                    max_len=TRAIN_PARAMS['max_decode_iter'],
                                    min_len=TRAIN_PARAMS['min_len'])
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir,
        'start_token': input_pipe.start_token,
        'end_token': input_pipe.end_token,
        'pretrain_embedding': input_pipe.load_pretrain_embedding()
    })
    TRAIN_PARAMS = set_encoder_decoder_params(args.cell_type, TRAIN_PARAMS, ED_PARAMS)

    model_fn = getattr(importlib.import_module('skip_thought_archived.model'),
                       '{}_model'.format(args.model))
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    if args.step == 'train':
        estimator.train(input_fn=input_pipe.build_dataset())

    if args.step == 'predict':
        # Disable GPU in prediction to avoid resource-exhausted (OOM) errors
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(is_predict=1))
        res = {}
        for item in prediction:
            key = ' '.join(i.decode('utf-8') for i in item['input_token'])
            res[key] = item['encoder_state']
        with open('./data/{}/{}_predict_embedding.pkl'.format(args.data, args.model), 'wb') as f:
            pickle.dump(res, f)
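
# Hypothetical downstream use of the embedding pickle written above: a minimal
# sketch that loads the sentence -> encoder_state mapping and ranks neighbours by
# cosine similarity. It assumes each encoder_state is a 1-D numpy vector; for an
# LSTM cell the state may instead be a (c, h) tuple and would need unpacking first.
import numpy as np

def nearest_sentences(query, embedding_file, topn=5):
    with open(embedding_file, 'rb') as f:
        res = pickle.load(f)  # {sentence string: encoder_state}
    q = np.asarray(res[query], dtype=np.float32)
    q /= (np.linalg.norm(q) + 1e-8)
    scores = []
    for sent, vec in res.items():
        v = np.asarray(vec, dtype=np.float32)
        v /= (np.linalg.norm(v) + 1e-8)
        scores.append((sent, float(np.dot(q, v))))
    return sorted(scores, key=lambda x: x[1], reverse=True)[:topn]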