Example #1
def main(args):

    model_dir = CHECKPOINT_DIR.format(args.model, args.train_algo)
    dict_file = DICTIONARY_DIR.format(args.data)
    data_file = './data/{}/corpus_new.txt'.format(args.data)

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'TRAIN_PARAMS')
    RUN_CONFIG = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'RUN_CONFIG')
    MySpecialToken = getattr(importlib.import_module('config.{}_config'.format(args.data)), 'MySpecialToken')

    # Init dataset
    input_pipe = Word2VecDataset(data_file=data_file,
                                 dict_file=dict_file,
                                 window_size=TRAIN_PARAMS['window_size'],
                                 epochs=TRAIN_PARAMS['epochs'],
                                 batch_size=TRAIN_PARAMS['batch_size'],
                                 buffer_size=TRAIN_PARAMS['buffer_size'],
                                 special_token=MySpecialToken,
                                 min_count=TRAIN_PARAMS['min_count'],
                                 max_count=TRAIN_PARAMS['max_count'],
                                 sample_rate=TRAIN_PARAMS['sample_rate'],
                                 model=args.model)

    input_pipe.build_dictionary()

    TRAIN_PARAMS.update(
        {
            'vocab_size': input_pipe.total_size,
            'freq_dict': input_pipe.dictionary,
            'pad_index': input_pipe.pad_index,
            'train_algo': args.train_algo,
            'loss': args.loss,
            'model': args.model
        }
    )

    # Init Estimator
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu, RUN_CONFIG)

    train_spec = tf.estimator.TrainSpec(input_fn=input_pipe.build_dataset())

    eval_spec = tf.estimator.EvalSpec(input_fn=input_pipe.build_dataset(is_predict=1),
                                      steps=1000,
                                      throttle_secs=60)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
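
The helpers referenced above (CHECKPOINT_DIR, DICTIONARY_DIR, clear_model) are defined elsewhere in the project. A minimal sketch of what they plausibly look like follows; the directory templates and the rmtree-based cleanup are assumptions for illustration, not the project's actual definitions.

import os
import shutil

# Assumed directory templates, filled with .format(...) exactly as in main() above.
CHECKPOINT_DIR = './checkpoint/{}_{}'
DICTIONARY_DIR = './data/{}/dictionary.pkl'

def clear_model(model_dir):
    # Hypothetical cleanup: drop stale checkpoints so training restarts from scratch.
    if os.path.isdir(model_dir):
        shutil.rmtree(model_dir)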
Example #2
def main(args):
    # Init directory
    model_dir = CHECKPOINT_DIR.format(args.data, MODEL)
    dict_file = DICTIONARY_DIR.format(args.data)

    # Init config
    TRAIN_PARAMS = ALL_TRAIN_PARAMS[args.data]
    RUN_CONFIG = ALL_RUN_CONFIG[args.data]
    MySpecialToken = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'MySpecialToken')

    if args.step == 'train':
        data_file = {
            'encoder': './data/{}/train_encoder_source.txt'.format(args.data),
            'decoder': './data/{}/train_decoder_source.txt'.format(args.data)
        }
    else:
        data_file = {
            'encoder': './data/{}/dev_encoder_source.txt'.format(args.data),
            'decoder': './data/{}/dev_decoder_source.txt'.format(
                args.data)  # for predict, this can be same as encoder
        }

    if args.clear_model:
        clear_model(model_dir)

    # Init dataset
    input_pipe = Seq2SeqDataset(
        data_file=data_file,
        dict_file=dict_file,
        epochs=TRAIN_PARAMS['epochs'],
        batch_size=TRAIN_PARAMS['batch_size'],
        min_count=TRAIN_PARAMS['min_count'],
        max_count=TRAIN_PARAMS['max_count'],
        buffer_size=TRAIN_PARAMS['buffer_size'],
        special_token=MySpecialToken,
        max_len=TRAIN_PARAMS['max_len'],
        min_len=TRAIN_PARAMS['min_len'],
        pretrain_model_list=TRAIN_PARAMS['pretrain_model_list'])
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir,
        'start_index': input_pipe.start_index,
        'end_index': input_pipe.end_index,
        'pretrain_embedding': input_pipe.load_pretrain_embedding()
    })

    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu,
                                RUN_CONFIG)

    if args.step == 'train':
        estimator.train(input_fn=input_pipe.build_dataset())

    if args.step == 'predict':
        # Disable the GPU during prediction to avoid resource-exhausted (OOM) errors
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(
            is_predict=1))
        res = list(prediction)  # materialize the prediction generator before pickling
        with open('./data/{}/{}_predict.pkl'.format(args.data, MODEL),
                  'wb') as f:
            pickle.dump(res, f)
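
The predict branch pickles the list of per-example prediction dicts. A minimal reader might look like the sketch below, assuming the same data name and MODEL value used at dump time; 'my_data' and 'seq2seq' are placeholders, and the dict keys are whatever model_fn yields.

import pickle

# Placeholders: 'my_data' stands for args.data and 'seq2seq' for MODEL at dump time.
with open('./data/my_data/seq2seq_predict.pkl', 'rb') as f:
    predictions = pickle.load(f)

print(len(predictions))       # one entry per predicted example
print(predictions[0].keys())  # whatever fields model_fn put in its predictions dict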
Example #3
def main(args):

    model_dir = CHECKPOINT_DIR.format(args.data, args.model)
    dict_file = DICTIONARY_DIR.format(args.data)
    data_file = './data/{}/train.tfrecords'.format(args.data)

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'TRAIN_PARAMS')
    RUN_CONFIG = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'RUN_CONFIG')
    MySpecialToken = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'MySpecialToken')
    TF_PROTO = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'TF_PROTO')

    # Init dataset
    input_pipe = FasttextDataset(data_file=data_file,
                                 dict_file=dict_file,
                                 epochs=TRAIN_PARAMS['epochs'],
                                 batch_size=TRAIN_PARAMS['batch_size'],
                                 min_count=TRAIN_PARAMS['min_count'],
                                 max_count=TRAIN_PARAMS['max_count'],
                                 buffer_size=TRAIN_PARAMS['buffer_size'],
                                 special_token=MySpecialToken,
                                 ngram=TRAIN_PARAMS['ngram'],
                                 tf_proto=TF_PROTO)
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir
    })

    # Init Estimator
    model_fn = getattr(importlib.import_module('model_{}'.format(args.model)),
                       'model_fn')
    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu,
                                RUN_CONFIG)

    if args.step == 'train':
        early_stopping = tf.estimator.experimental.stop_if_no_decrease_hook(
            estimator,
            metric_name='loss',
            max_steps_without_decrease=100 * 1000)

        train_spec = tf.estimator.TrainSpec(
            input_fn=input_pipe.build_dataset(), hooks=[early_stopping])

        eval_spec = tf.estimator.EvalSpec(
            input_fn=input_pipe.build_dataset(is_predict=1),
            steps=500,
            throttle_secs=60)

        tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

    if args.step == 'predict':
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(
            is_predict=1))
        res = list(prediction)  # estimator.predict returns a generator, which cannot be pickled directly
        with open('prediction.pkl', 'wb') as f:
            pickle.dump(res, f)
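
main() only consumes a handful of attributes from args (step, data, model, clear_model, gpu). A hypothetical argparse front end matching those attribute names is sketched below; the flag defaults and choices are assumptions, not taken from the original script.

import argparse

if __name__ == '__main__':
    # Hypothetical CLI; only the attribute names are dictated by the code above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=str, default='train', choices=['train', 'predict'])
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--model', type=str, default='fasttext')
    parser.add_argument('--clear_model', action='store_true')
    parser.add_argument('--gpu', type=int, default=0)
    main(parser.parse_args())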
Example #4
def main(args):

    model_dir = CHECKPOINT_DIR.format(args.data, args.model)
    dict_file = DICTIONARY_DIR.format(args.data)

    # For the quick-thought model, encoder source = decoder source = all consecutive sentences
    if args.model == 'quick_thought':
        data_file = {
            'encoder': './data/{}/all_sentences.txt'.format(args.data),
            'decoder': './data/{}/all_sentences.txt'.format(args.data)
        }
    else:
        data_file = {
            'encoder': './data/{}/encoder_source.txt'.format(args.data),
            'decoder': './data/{}/decoder_source.txt'.format(args.data)
        }

    if args.clear_model:
        clear_model(model_dir)

    # Init config
    TRAIN_PARAMS = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'TRAIN_PARAMS')
    RUN_CONFIG = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'RUN_CONFIG')
    MySpecialToken = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'MySpecialToken')
    ED_PARAMS = getattr(
        importlib.import_module('config.{}_config'.format(args.data)),
        'ED_PARAMS')

    # Init dataset
    input_pipe = SkipThoughtDataset(data_file=data_file,
                                    dict_file=dict_file,
                                    epochs=TRAIN_PARAMS['epochs'],
                                    batch_size=TRAIN_PARAMS['batch_size'],
                                    min_count=TRAIN_PARAMS['min_count'],
                                    max_count=TRAIN_PARAMS['max_count'],
                                    buffer_size=TRAIN_PARAMS['buffer_size'],
                                    special_token=MySpecialToken,
                                    max_len=TRAIN_PARAMS['max_decode_iter'],
                                    min_len=TRAIN_PARAMS['min_len'])
    input_pipe.build_dictionary()

    TRAIN_PARAMS.update({
        'vocab_size': input_pipe.total_size,
        'freq_dict': input_pipe.dictionary,
        'pad_index': input_pipe.pad_index,
        'model_dir': model_dir,
        'start_token': input_pipe.start_token,
        'end_token': input_pipe.end_token,
        'pretrain_embedding': input_pipe.load_pretrain_embedding()
    })

    TRAIN_PARAMS = set_encoder_decoder_params(args.cell_type, TRAIN_PARAMS,
                                              ED_PARAMS)

    model_fn = getattr(importlib.import_module('skip_thought_archived.model'),
                       '{}_model'.format(args.model))

    estimator = build_estimator(TRAIN_PARAMS, model_dir, model_fn, args.gpu,
                                RUN_CONFIG)

    if args.step == 'train':
        estimator.train(input_fn=input_pipe.build_dataset())

    if args.step == 'predict':
        # Disable the GPU during prediction to avoid resource-exhausted (OOM) errors
        prediction = estimator.predict(input_fn=input_pipe.build_dataset(
            is_predict=1))
        res = {}
        for item in prediction:
            key = ' '.join(token.decode('utf-8') for token in item['input_token'])
            res[key] = item['encoder_state']

        with open('./data/{}/{}_predict_embedding.pkl'.format(args.data, args.model), 'wb') as f:
            pickle.dump(res, f)
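
The predict branch produces a sentence-to-encoder_state mapping. A small consumer sketch is shown below; the file path placeholders and the assumption that encoder_state is a flat numpy vector (it may be a state tuple depending on the cell type) are mine, not the original code's.

import pickle

import numpy as np

# Placeholders for the args.data / args.model values used at dump time.
with open('./data/my_data/skip_thought_predict_embedding.pkl', 'rb') as f:
    sentence2vec = pickle.load(f)

def cosine(u, v):
    # Assumes u and v are numpy vectors of equal length.
    u, v = np.ravel(u), np.ravel(v)
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))

sentences = list(sentence2vec)
print(cosine(sentence2vec[sentences[0]], sentence2vec[sentences[1]]))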