示例#1
0
文件: train.py 项目: yushu-liu/JLM
def train_RNNLM(_run):
    """Train an RNNLM_Model for one sacred run and report test perplexity.

    The sacred ``_run.config`` is wrapped in an ``ExperimentConfig`` so the
    sacred config and the experiment config stay consistent.

    For up to ``config.max_epochs`` epochs the model is trained on
    ``model.encoded_train`` and evaluated on ``model.encoded_dev``. Whenever
    validation perplexity improves, the weights are saved to
    ``<experiment_path>/<run id>/tf_dump/rnnlm.weights``. Training stops early
    once the current epoch is more than ``config.early_stopping`` epochs past
    the best one.

    The best checkpoint is restored before perplexity on
    ``model.encoded_test`` is reported.

    Args:
        _run: sacred Run object; its ``config`` and ``_id`` are read.
    """
    # maintain consistency between sacred config and experiment config
    config = ExperimentConfig(**_run.config)

    experiment_dump_path = os.path.join(experiment_path, str(_run._id), "tf_dump")
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(experiment_dump_path, exist_ok=True)
    checkpoint_path = os.path.join(experiment_dump_path, 'rnnlm.weights')

    if config.debug:
        print("Running in debug mode...")

    with tf.Graph().as_default():
        # set random seed before the graph is built
        tf.set_random_seed(config.tf_random_seed)

        model = RNNLM_Model(config, load_corpus=True)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        tf_config = tf.ConfigProto()
        tf_config.gpu_options.allow_growth = True

        with tf.Session(config=tf_config) as session:
            best_val_pp = float('inf')
            best_val_epoch = 0
            # tracks whether checkpoint_path holds weights from this run
            saved_checkpoint = False

            # Only the graph is written to this log. Close the writer right away
            # so the event file is flushed and its file handle is not leaked.
            graph_writer = tf.summary.FileWriter(os.path.join(train_path, "TensorLog"), session.graph)
            graph_writer.close()

            session.run(init)
            for epoch in range(config.max_epochs):
                print('Epoch {}'.format(epoch))
                start = time.time()
                train_pp = model.run_epoch(
                    session, model.encoded_train,
                    train_op=model.train_step)
                print('Training perplexity: {}'.format(train_pp))
                print('Total Training time: {}'.format(time.time() - start))
                valid_pp = model.run_epoch(session, model.encoded_dev)
                print('Validation perplexity: {}'.format(valid_pp))
                if valid_pp < best_val_pp:
                    best_val_pp = valid_pp
                    best_val_epoch = epoch
                    saver.save(session, checkpoint_path)
                    saved_checkpoint = True
                # early stopping: too many epochs since the last improvement
                if epoch - best_val_epoch > config.early_stopping:
                    break
                print('Total time: {}'.format(time.time() - start))

            # Evaluate the weights of the best validation epoch, not whatever the
            # last (possibly worse) epoch left in the session.
            if saved_checkpoint:
                saver.restore(session, checkpoint_path)
            test_pp = model.run_epoch(session, model.encoded_test)
            print('Test perplexity: {}'.format(test_pp))
示例#2
0
文件: weights.py 项目: nonva/JLM
def dump_trained_weights(experiment, verbose):
    """Restore an experiment's trained weights and dump them to disk.

    The graph is not saved with the checkpoint, so the model is rebuilt from
    ``get_configs(experiment)``. The checkpoint
    ``<experiment_path>/<experiment>/tf_dump/rnnlm.weights`` is then restored.
    The selected variables are fetched with ``tf_weights_to_np_weights_dict``
    and written by ``dump_weights`` into ``<experiment_path>/<experiment>/weights``.

    Args:
        experiment: experiment id; passed to ``get_configs`` and used (via
            ``str``) as the experiment's directory name.
        verbose: forwarded unchanged to ``dump_weights``.
    """
    config = get_configs(experiment)
    experiment_dir = os.path.join(experiment_path, str(experiment))

    # Build in a fresh graph (as train_RNNLM does) so that calling this more
    # than once in a process does not collide with previously built variables.
    with tf.Graph().as_default():
        # Still need to load the model to build graph
        # Graph is not saved
        RNNLM_Model(config)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as session:
            session.run(init)
            saver.restore(session, os.path.join(experiment_dir, "tf_dump", 'rnnlm.weights'))

            # Each variable is listed once ('HMo' used to appear twice).
            dump_vars = [
                'HMi', 'HMf', 'HMo', 'HMg', 'IMi', 'IMf', 'IMo', 'IMg',
                'LM', 'bi', 'bf', 'bo', 'bg', 'b2',
            ]

            # 'PM' when embeddings are shared, otherwise 'UM'
            dump_vars.append('PM' if config.share_embedding else 'UM')

            if config.V_table:
                # Replace the single 'LM' with one 'LM{i}' per embedding segment,
                # preceded by 'VT{i}' for every segment except the first.
                dump_vars.remove('LM')
                for i, _ in enumerate(config.embedding_seg):
                    if i != 0:
                        dump_vars.append('VT{}'.format(i))
                    dump_vars.append('LM{}'.format(i))

            weight_dict = tf_weights_to_np_weights_dict(session, dump_vars)

            # NOTE(review): assumes D_softmax and V_table are not both enabled;
            # with V_table on, 'LM' is not fetched and the lookup below would fail.
            if config.D_softmax:
                # Instead of saving the full patched "LM" embedding, split it into
                # a list of blocks. Segment (size, s, e) covers rows s:e (e=None
                # means through the last row) and the next `size` columns.
                full_lm = weight_dict['LM']
                n_rows = full_lm.shape[0]
                blocks = []
                col_s = 0
                for size, s, e in config.embedding_seg:
                    if e is None:
                        e = n_rows
                    blocks.append(full_lm[s:e, col_s:col_s + size])
                    col_s += size
                weight_dict['LM'] = blocks

            weight_dump_dir = os.path.join(experiment_dir, "weights")
            dump_weights(weight_dict, weight_dump_dir, verbose)
示例#3
0
def auto_generate_sentence(experiment=1):
    """Interactively generate sentences from a trained experiment's model.

    The model is rebuilt from ``get_configs(experiment)`` with
    ``batch_size`` and ``num_steps`` set to 1. Weights are restored from
    ``<experiment_path>/<experiment>/tf_dump/rnnlm.weights``.

    The first sentence is generated from ``'<eos>'``. Each later sentence is
    generated from the text typed at the ``> `` prompt. An empty line or
    end-of-input (e.g. Ctrl-D) ends the session.

    Args:
        experiment: experiment id; passed to ``get_configs`` and used (via
            ``str``) as the experiment's directory name.
    """
    gen_config = get_configs(experiment)
    gen_config.batch_size = gen_config.num_steps = 1

    # Fresh graph, as in train_RNNLM, so repeated calls don't clash on variables.
    with tf.Graph().as_default():
        gen_model = RNNLM_Model(gen_config)

        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.Session() as session:
            session.run(init)
            saver.restore(
                session,
                os.path.join(experiment_path, str(experiment), "tf_dump",
                             'rnnlm.weights'))
            starting_text = '<eos>'
            while starting_text:
                sen = generate_sentence(session,
                                        gen_model,
                                        gen_config,
                                        starting_text=starting_text,
                                        temp=1.0)
                # print only the part of each token before its first '/'
                print(' '.join([w.split('/')[0] for w in sen]))
                try:
                    starting_text = input('> ')
                except EOFError:
                    # closed stdin ends the session like an empty line would
                    break
示例#4
0
文件: train.py 项目: nonva/JLM
def create_RNNLM(config):
    """Construct an ``RNNLM_Model`` from ``config`` with ``load_corpus=True``.

    Args:
        config: model/experiment configuration handed to ``RNNLM_Model``.

    Returns:
        The newly constructed ``RNNLM_Model`` instance.
    """
    model = RNNLM_Model(config, load_corpus=True)
    return model