def disambiguation_preprocessing(inputs):
    """Translate ambiguous words into lists of allowable generalized ids.

    Args:
        inputs: iterable of sentences; every sentence is an iterable of
            strings, and every string holds one or more alternative
            meanings separated by ``"^"``.

    Returns:
        A list with one entry per sentence. Each entry is a list with one
        entry per word, and that entry is the list of ids obtained for the
        word's alternatives, in their original order.
    """
    glove = Glove300()
    input_pipe = LmInputDataPipeline(glove)

    # Graph: word string -> GloVe vocabulary id -> generalized id.
    t_words = tf.placeholder(dtype=tf.string)
    t_vocab_ids = glove.word_to_id_op()(t_words)
    to_generalized_id = input_pipe._vocab_generalized.vocab_id_to_generalized_id()
    t_generalized_ids = to_generalized_id(t_vocab_ids)

    # Every distinct meaning is resolved once, in a single session run.
    unique_meanings = list({
        meaning
        for sentence in inputs
        for word in sentence
        for meaning in word.split("^")
    })

    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(tf.global_variables_initializer())
        glove.after_session_created_hook_fn(sess)
        resolved_ids = sess.run(t_generalized_ids,
                                feed_dict={t_words: unique_meanings})

    meaning_to_id = dict(zip(unique_meanings, resolved_ids))

    return [[[meaning_to_id[meaning]
              for meaning in word.split("^")]
             for word in sentence]
            for sentence in inputs]
def test_two_way_words_id_transformation_glove_estimator():
    """Round-trip text -> token ids -> text through an Estimator input_fn.

    The mock model simply echoes its features as predictions, so the
    predictions must equal the tokens of the original sentences.
    """
    glove = Glove300()

    def get_input():
        tokens = get_test_dataset()
        as_ids = text_dataset_to_token_ids(tokens, glove.word_to_id_op)
        as_text = token_ids_to_text_dataset(as_ids, glove.id_to_word_op)
        return as_text.batch(1)

    def mock_model_fn(features, labels, mode, params):
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(mode=mode, predictions=features)
        # The train_op bumps the global step; without it "training" would
        # last forever regardless of the requested number of steps.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=features,
            train_op=tf.train.get_global_step().assign_add(1),
            loss=tf.constant(0))

    estimator = tf.estimator.Estimator(mock_model_fn)
    estimator.train(get_input, max_steps=1)
    predicted = estimator.predict(get_input)

    first, second, third = islice(predicted, 3)
    assert (first == s1.encode().split()).all()
    assert (second == s2.encode().split()).all()
    assert (third == s3.encode().split()).all()
def test_words_to_id_glove():
    """Check that the first three test sentences map to the expected ids."""
    expected_ids = [
        np.array([96, 21, 30, 40, 536, 21594]),
        np.array([
            42, 212, 2, 94, 8707, 1126, 2616, 127, 40, 1213, 3195, 15014, 28,
            2, 70717, 23749, 3660, 1149, 29571, 1630, 504, 122, 5, 21, 7, 2,
            1009, 1193, 21, 5333, 2470, 4, 767, 27, 401, 5, 10436
        ]),
        np.array([
            85, 3972, 35185, 2545, 149, 2, 28428, 1630, 4575, 2377, 58, 106,
            804, 12445, 216, 3644, 4, 970, 2, 1929, 8618, 7, 227, 5881, 3,
            14521
        ]),
    ]

    glove = Glove300()
    token_ids = text_dataset_to_token_ids(get_test_dataset(),
                                          glove.word_to_id_op)
    iterator = token_ids.make_initializable_iterator()
    next_element = iterator.get_next()

    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        sess.run(iterator.initializer)
        # Fetch all three elements first, then compare.
        fetched = [sess.run(next_element) for _ in expected_ids]

    for actual, expected in zip(fetched, expected_ids):
        assert actual == approx(expected)
def test_words_to_id_glove_estimator():
    """Check word -> id conversion when the dataset feeds an Estimator.

    The mock model echoes its features as predictions, so the first three
    predictions must equal the expected id sequences.
    """
    expected_ids = (
        np.array([96, 21, 30, 40, 536, 21594]),
        np.array([
            42, 212, 2, 94, 8707, 1126, 2616, 127, 40, 1213, 3195, 15014, 28,
            2, 70717, 23749, 3660, 1149, 29571, 1630, 504, 122, 5, 21, 7, 2,
            1009, 1193, 21, 5333, 2470, 4, 767, 27, 401, 5, 10436
        ]),
        np.array([
            85, 3972, 35185, 2545, 149, 2, 28428, 1630, 4575, 2377, 58, 106,
            804, 12445, 216, 3644, 4, 970, 2, 1929, 8618, 7, 227, 5881, 3,
            14521
        ]),
    )

    glove = Glove300()

    def get_input():
        token_ids = text_dataset_to_token_ids(get_test_dataset(),
                                              glove.word_to_id_op)
        return token_ids.batch(1)

    def mock_model_fn(features, labels, mode, params):
        if mode == tf.estimator.ModeKeys.PREDICT:
            return tf.estimator.EstimatorSpec(mode=mode, predictions=features)
        # The train_op bumps the global step; without it "training" would
        # last forever regardless of the requested number of steps.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=features,
            train_op=tf.train.get_global_step().assign_add(1),
            loss=tf.constant(0))

    estimator = tf.estimator.Estimator(mock_model_fn)
    estimator.train(get_input, max_steps=1)
    predicted = estimator.predict(get_input)

    first, second, third = islice(predicted, 3)
    for actual, expected in zip((first, second, third), expected_ids):
        assert actual == approx(expected)
def train_lm_on_simple_examples_with_glove(model_dir):
    """Train the language model for a few steps on the simple examples corpus.

    Builds an Estimator whose input comes straight from
    ``SimpleExamplesCorpus`` through ``LmInputDataPipeline``, trains it up to
    global step 4 and prints the start time, stop time and duration.

    Args:
        model_dir: passed to ``tf.estimator.Estimator`` as ``model_dir``.
    """
    glove = Glove300()

    def create_input():
        corpus = SimpleExamplesCorpus()
        tokens = corpus.get_tokens_dataset(DatasetType.TRAIN)
        tokens = tokens.repeat().shuffle(1000, seed=0)
        return LmInputDataPipeline(glove, 8).load_data(tokens)

    def model_function(features, labels, mode, params):
        input_pipe = LmInputDataPipeline(glove)
        vocab_size = glove.vocab_size()
        id_to_embeding_fn = input_pipe.get_id_to_embedding_mapping()

        with tf.device(device_assignment_function):
            build_spec = get_autoregressor_model_fn(vocab_size,
                                                    id_to_embeding_fn)
            base_spec = build_spec(features, labels, mode, params)

        # Re-wrap the spec so the vocabulary initialization hook is attached.
        hooks = [InitializeVocabularyHook(glove)]
        return tf.estimator.EstimatorSpec(
            mode=base_spec.mode,
            loss=base_spec.loss,
            train_op=base_spec.train_op,
            eval_metric_ops=base_spec.eval_metric_ops,
            predictions=base_spec.predictions,
            training_hooks=hooks)

    estimator = tf.estimator.Estimator(
        model_function,
        params={"learning_rate": 0.05, "number_of_alternatives": 1},
        model_dir=model_dir)

    started = datetime.datetime.now()
    estimator.train(create_input, max_steps=4)
    finished = datetime.datetime.now()

    print("start:", started)
    print("stop:", finished)
    print("duration:", finished - started)
def test_word_to_id_and_then_vector():
    """Check the word -> id -> vector lookup chain against fixed vectors.

    Six tokens are converted to ids with ``glove.word_to_id_op`` and then to
    vectors with ``glove.id_to_vector_op``; the result is compared (with
    ``approx``) to the hard-coded array below, which has one row per token.
    """
    tokens_input = tf.constant(["no", "it", "was", "n't", "black", "monday"],
                               dtype=tf.string)
    glove = Glove300()
    ids = glove.word_to_id_op()(tokens_input)
    vectors = glove.id_to_vector_op()(ids)
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        # Embeddings are initialized explicitly in this session before the
        # lookup is evaluated.
        glove.initialize_embeddings_in_graph(tf.get_default_graph(), sess)
        r_vectors = sess.run(vectors)
    # Expected vectors, in the same order as the tokens in tokens_input.
    assert r_vectors == approx(
        np.array([
            [
                -0.14121, 0.034641, -0.443, -0.093265, -0.010022, -0.069041,
                0.16335, -0.12964, 0.0045672, 2.3127, -0.12048, 0.054694,
                -0.22722, 0.059882, -0.28076, -0.2715, 0.17744, 1.4719,
                0.14243, 0.25179, 0.039256, -0.19574, 0.25275, -0.12224,
                -0.23064, -0.0449, 0.18679, -0.27084, 0.67684, -0.13295,
                0.13029, 0.2128, -0.25393, -0.34708, -0.013974, 0.17852,
                0.16488, 0.080326, 0.029319, -0.56489, -0.17003, 0.20811,
                0.43094, 0.2132, 0.26778, 0.063854, -0.23329, 0.18415,
                0.14159, 0.10566, 0.042333, 0.16718, 0.14764, 0.051008,
                0.07869, 0.29462, -0.031126, -0.024006, -0.13177, -0.38212,
                0.049503, 0.08338, 0.17229, 0.10892, 0.40207, 0.16887,
                0.20803, -0.16576, -0.10935, 0.25171, 0.2537, 0.12471,
                -0.065506, 0.11825, -0.083037, -0.12088, 0.17466, -0.12045,
                0.42763, 0.65073, 0.065299, -0.18887, -0.40152, -0.078146,
                -0.45914, -0.096453, 0.36708, -0.28231, 0.38404, -0.07597,
                -0.1878, 0.11948, -0.22832, -0.16095, 0.14309, -0.0090158,
                0.2809, 0.023625, 0.44597, -0.25256, -0.62236, 0.5481,
                -0.3839, 0.0094859, 0.2257, -0.99585, -0.28107, 0.067278,
                -0.10536, -0.049949, -0.025037, 0.070037, -0.14745, -0.053963,
                0.37517, -0.31097, 0.10935, -0.12523, -0.031915, -0.43703,
                -0.12165, 0.09749, 0.073047, 0.049151, -0.21212, -0.15012,
                -0.022766, -0.30876, 0.028561, -0.10836, 0.069416, -0.10536,
                -0.16433, -0.32558, 0.50645, 0.13393, -0.21098, -0.029829,
                0.093212, -0.45122, -1.6421, 0.078953, 0.35313, -0.28202,
                0.26932, -0.0094641, 0.099173, 0.074177, -0.29891, 0.056616,
                0.25049, -0.45163, 0.49712, 0.11657, -0.15597, -0.028287,
                0.072622, -0.26618, 0.18588, 0.043537, -0.16162,
                -0.17738, -0.12787, 0.12671, -0.12695, 0.13798, -0.17422,
                0.19985, 0.18507, -0.02821, -0.27801, -0.11924, 0.27196,
                -0.093397, -0.24152, 0.71304, 0.058818, 0.23003, -0.18196,
                -0.038031, 0.061856, -0.095681, 0.20094, -0.059437, -0.32578,
                0.23677, 0.18845, -0.093786, 0.29071, 0.074056, 0.16738,
                0.31971, -0.48415, -0.076829, 0.0065072, -0.12015, -0.12628,
                0.0627, 0.1041, 0.81267, 0.13162, 0.37337, -0.20733,
                -0.1292, 0.38116, 0.2539, 0.1688, -0.26463, 0.10273,
                0.28119, -0.15295, 0.1548, 0.24093, -0.18426, 0.23231,
                -0.19153, -0.15871, 0.1979, -0.10288, 0.2818, -0.44362,
                0.01553, -0.064855, 0.16203, 0.15307, 0.35869, -0.0072469,
                -0.056632, -0.18384, -0.020465, -0.059468, -0.1433, 0.3238,
                -0.16607, -0.014596, 0.36057, 0.54558, -0.34755, -0.22197,
                0.034603, 0.089877, 0.66228, -0.094153, 0.17281, 0.043724,
                -0.23963, 0.0040285, 0.10264, -0.060451, 0.30558, -0.12715,
                -0.44602, -0.36197, -0.20433, -0.15639, 0.75049, -0.49277,
                -0.314, 0.23212, 0.14506, -0.10745, 0.26306, -0.13694,
                0.49217, -0.3333, 0.13349, 0.33744, -0.04892, 0.7233,
                -0.035786, 0.36221, 0.28324, -0.18857, -0.0158, 0.19572,
                0.14628, -0.0049576, -0.20363, -0.21408, 0.21958, -0.1376,
                0.051295, -0.035402, -0.33176, -0.39541, 0.16886, -0.36042,
                -0.18925, -0.10028, -0.18858, -0.22911, -0.09778, -0.27021,
                -0.034178, 0.36786, 0.00639, -0.039546, -0.29866, 0.013515,
                0.025409
            ],
            [
                0.0013629, 0.35653, -0.055497, -0.16607, 0.0031402, -0.061926,
                -0.24759, -0.22897, -0.09105, 2.6751, -0.15062, 0.072403,
                0.0061949, -0.0065698, -0.26418, -0.19543, -0.15048, 1.2156,
                -0.12551, -0.12572, 0.023065, 0.024727, 0.14311, 0.10148,
                -0.10566, 0.07864, -0.10306, -0.11968, 0.04202, -0.36815,
                -0.087136, 0.38589, 0.0044597, -0.18259, -0.1226, -0.10454,
                0.16039, 0.27415, 0.042427, -0.049497, 0.041286, 0.12223,
                0.10821, -0.056199, 0.21754, 0.10983, -0.38878, -0.10935,
                -0.36647, 0.1342, -0.076634, 0.38148, -0.19979, 0.09391,
                0.35189, -0.11133, 0.095313, -0.29593, 0.29022, -0.1966,
                -0.10331, -0.21995, -0.041991, 0.16631, 0.01523, -0.29185,
                -0.05472, -0.040665, 0.084861, -0.009206, 0.24625, 0.081873,
                0.34256, -0.16768, -0.079394, 0.13206, 0.2156, -0.11199,
                -0.39589, 0.32299, 0.089602, -0.026041, -0.23981, 0.049861,
                0.055241, -0.50554, 0.23002, -0.54613, 0.58194, 0.096957,
                -0.015559, 0.069833, -0.009668, 0.19936, 0.19006, 0.32913,
                -0.064844, -0.22404, -0.031196, 0.1818, -0.071896, -0.072126,
                -0.082155, 0.064145, 0.11215, -1.0712, 0.29581, 0.081019,
                -0.24954, -0.087734, 0.015893, -0.15779, 0.055281, -0.080948,
                -0.11288, 0.099631, -0.17395, 0.1223, -0.099904, -0.20707,
                -0.014483, -0.10597, -0.0031063, 0.04332, -0.20624, 0.25975,
                -0.12992, -0.32278, 0.17298, -0.15952, -0.19577, -0.22784,
                0.022428, 0.19393, -0.059827, -0.011456, 0.17045, -0.041847,
                -0.1288, -0.067707, -1.9181, 0.29674, 0.27561, 0.29993,
                0.17498, -0.25321, -0.2017, 0.13409, -0.049065, 0.13186,
                -0.10889, 0.23631, 0.17895, 0.024289, 0.0092826, -0.29032,
                -0.27692, -0.22906, -0.13153, 0.10656, -0.25672, 0.12929,
                0.10399, -0.098408, -0.42153, -0.1332, 0.091175, 0.040061,
                0.2084, -0.17849, -0.057709, -0.10256, -0.095817, -0.43439,
                -0.064794, 0.12916, 0.085435, -0.3746, -0.069798, 0.020042,
                0.041425, -0.021527, -0.086333, -0.0095633, 0.025466, -0.16101,
                -0.089574, -0.21178, 0.088594, 0.087381, 0.052047, -0.20386,
                -0.29424, 0.10097, 0.076137, 0.10431, -0.19752, -0.34268,
                0.058982, 0.26035, -0.16364, 0.03294, -0.30368, 0.0734,
                0.071074, 0.17129, -0.14442, -0.041817, 0.20912, 0.032747,
                -0.10649, -0.33475, 0.0088305, -0.32619, 0.21179, 0.29881,
                -0.013775, -0.090821, -0.33841, -0.1129, 0.12137, 0.059202,
                -0.12133, -0.093398, 0.15426, -0.032649, -0.11216, 0.28842,
                -0.036565, -0.041662, -0.22413, 0.060877, 0.21641, 0.30208,
                0.16084, -0.027118, 0.26084, -0.090324, -0.1036, -0.01901,
                0.34244, 0.025017, 0.025958, 0.21387, 0.43512, -0.67789,
                -0.0039166, -0.30027, -0.058978, 0.17072, 0.14497, -0.19255,
                0.13603, -0.16038, -0.075959, 0.2282, 0.20681, -0.011637,
                0.086048, -0.034803, 0.23821, 0.21667, 0.10353, -0.012959,
                0.36174, -0.12104, -0.033488, -0.030755,
                0.43549, 0.1896, 0.45975, -0.34826, -0.16406, -0.12197,
                -0.064298, 0.19573, 0.017949, -0.12379, -0.0081198, 0.4002,
                0.17065, -0.10712, 0.088398, -0.11473, -0.069708, -0.09321,
                0.25621, -0.035815, 0.15968, -0.37266, -0.24035, -0.089325,
                0.10603, -0.16025, -0.054419, -0.30824, -0.26249, -0.11237,
                0.078259, 0.22398
            ],
            [
                -0.044058, 0.36611, 0.18032, -0.24942, -0.098095, 0.033261,
                0.119, -0.51164, -0.16415, 3.136, 0.20901, 0.29082,
                0.25193, -0.020379, -0.24789, -0.47501, -0.038328, 0.56434,
                -0.038566, -0.11559, 0.024392, -0.45873, -0.10009, 0.21731,
                0.16996, -0.12939, 0.0063318, -0.017798, -0.18673, -0.1167,
                -0.14384, -0.0097187, 0.45289, -0.036453, -0.40523, -0.31816,
                -0.23389, -0.012272, -0.21479, -0.17841, 0.34474, 0.31133,
                0.20543, -0.1896, 0.38995, 0.12103, -0.33685, -0.57051,
                0.20732, 0.087872, 0.071458, 0.046355, -0.17425, 0.27856,
                0.35989, -0.017122, 0.12197, -0.35806, 0.33181, -0.19827,
                -0.10386, -0.096699, 0.094231, 0.46722, -0.36612, -0.038628,
                0.063485, -0.25765, -0.20415, 0.075931, 0.085753, 0.28176,
                -0.12443, -0.19756, 0.17218, -0.20121, 0.048154, 0.1301,
                -0.51096, 0.41643, 0.16487, 0.083688, 0.025331, 0.0014575,
                0.26935, -0.46159, 0.18639, -0.6424, -0.2277, 0.032521,
                0.050105, 0.1683, -0.27886, -0.037346, 0.50521, -0.39343,
                0.25004, -0.091487, 0.044709, 0.15579, -0.19423, 0.29651,
                -0.27465, -0.33689, -0.11362, -0.43028, 0.016673, -0.015717,
                0.15385, -0.30998, -0.17927, -0.002689, -0.029884, -0.18535,
                -0.079747, -0.31545, 0.0024644, 0.19685, -0.061948, -0.2724,
                0.0372, 0.24951, 0.15755, -0.084023, -0.24132, 0.35744,
                -0.16309, -0.67866, -0.27942, -0.016828, 0.017248, -0.060522,
                -0.26155, 0.16951, 0.50993, -0.46213, -0.019627, 0.3955,
                0.0053794, -0.13616, -1.3947, 0.24283, 0.33351, 0.18875,
                0.33386, -0.1979, -0.45546, -0.14531, 0.32496, -0.24984,
                -0.38316, -0.047484, 0.3163, -0.27841, -0.31328, -0.13258,
                -0.15671, 0.050417, 0.2073, -0.13118, -0.40559, -0.34316,
                0.14348, -0.45976, -0.48611, -0.32394, -0.19056, 0.16412,
                0.22827, -0.054174, 0.039441, 0.079182,
                -0.034827, -0.043719, -0.56115, -0.18462, 0.012758, -0.058201,
                -0.4096, -0.28184, -0.035173, -0.27668, -0.44195, -0.094452,
                -0.36051, -0.23688, -0.22469, 0.22704, 0.070153, 0.079784,
                -0.050581, 0.19954, -0.53252, 0.38514, 0.20942, 0.3133,
                0.37957, -0.31456, -0.22611, -0.14732, 0.12792, 0.026238,
                -0.19538, 0.2053, 0.18387, 0.070116, 0.20402, -0.057152,
                0.16134, 0.023932, 0.04476, -0.0031943, 0.0076469, -0.032653,
                0.39232, 0.11799, -0.18832, -0.21732, -0.038809, -0.19023,
                -0.067095, 0.021589, 0.03139, 0.27935, -0.25991, -0.010694,
                0.071357, 0.20587, 0.030717, 0.14273, -0.012696, 0.30787,
                0.1761, -0.23735, 0.10864, -0.34518, 0.051447, 0.060717,
                -0.050337, -0.018071, -0.39068, -0.0020948, 0.21507, 0.30334,
                0.079873, -0.135, -0.0033115, -0.43378, 0.14857, -0.028767,
                -0.091394, -0.11293, 0.14341, -0.02577, 0.3054, -0.56747,
                0.30705, -0.085973, -0.021836, 0.14566, 0.57363, 0.27721,
                -0.25141, 0.12354, 0.0045573, 0.10348, 0.14283, 0.086515,
                -0.11795, 0.070627, 0.455, 0.14827, -0.33691, -0.26387,
                -0.40101, -0.034913, 0.032671, -0.42077, 0.058225, 0.38307,
                0.59657, 0.33333, 0.025108, -0.10701, 0.030241, -0.079168,
                -0.02454, 0.24922, 0.061272, 0.012772, -0.019862, 0.082316,
                0.49588, 0.09668, 0.43798, 0.062743, -0.053951, 0.18625,
                -0.097817, -6.7104e-05
            ],
            [
                -0.13019, 0.27764, -0.24159, -0.1229, -0.099673, 0.1548,
                0.045905, -0.048852, -0.070227, 2.6683, 0.018325, -0.14477,
                0.54896, 0.13912, -0.5663, -0.076259, 0.020997, 0.53276,
                -0.31074, 0.12863, 0.21444, 0.13455, 0.2385, 0.048554,
                -0.3388, 0.021484, -0.3094, -0.298, 0.54086, -0.53253,
                -0.26099, 0.26757, -0.056835, -0.15544, 0.27707, 0.085732,
                0.32584, 0.088359, -0.1003, -0.13693, -0.0411, 0.23295,
                -0.091702, 0.081552, 0.20786, -0.031941, -0.21681, 0.042795,
                -0.11452, 0.059274, -0.22558, -0.11495, 0.1403, -0.10689,
                0.41254, -0.10124, 0.02884, -0.41875, 0.10929, -0.15641,
                0.039418, -0.28243, -0.1422, 0.47574, 0.26084, -0.12307,
                -0.1724, -0.10748, 0.31673, 0.25172, 0.36466, 0.38725,
                0.28571, -0.093661, 0.055286, 0.2427, 0.17109,
                -0.15716, 0.14724, 0.36599, -0.065697, -0.23486, -0.19004,
                -0.042001, -0.2401, -0.17693, 0.14648, -0.5474, 0.7324,
                -0.15878, -0.34187, -0.22201, -0.15924, 0.31443, 0.23653,
                0.055091, 0.28245, 0.019587, 0.2493, 0.044079, -0.22886,
                0.36507, -0.076332, -0.073083, 0.23082, -0.47279, 0.020098,
                -0.18732, -0.15627, 0.29637, 0.081997, -0.019479, 0.31497,
                -0.26458, 0.17927, -0.11421, 0.1828, -0.19653, -0.15288,
                -0.0028553, -0.020426, -0.22377, -0.070952, -0.36273, 0.25286,
                -0.017657, 0.070154, -0.52035, -0.06306, -0.063401, 0.031118,
                0.17004, -0.12248, 0.344, 0.16571, -0.15052, 0.052269,
                -0.10966, -0.1355, -0.23961, -2.2062, 0.14258, 0.16897,
                0.035696, 0.1303, -0.10539, -0.2026, 0.12489, -0.01408,
                -0.22252, 0.40705, 0.1008, 0.032245, 0.39523, -0.077507,
                -0.21702, -0.09814, -0.20212, -0.019024, 0.012406, -0.33449,
                0.24188, -0.2918, 0.047731, -0.2364, -0.041121, 0.16712,
                -0.25078, 0.30386, -0.15052, 0.025404, -0.19863, 0.11755,
                -0.35061, 0.26133, 0.37007, 0.13698, 0.11351, 0.15074,
                -0.1246, -0.067939, 0.01238, -0.42154, -0.0091529, -0.10902,
                -0.086627, -0.059834, 0.10847, 0.013728, 0.043898, 0.16817,
                0.043819, -0.29907, -0.17329, 0.21429, 0.13766, -0.088914,
                -0.22155, 0.34939, 0.48677, -0.13624, -0.086712, 0.027144,
                -0.24487, 0.043639, 0.13863, 0.012576, -0.091621, 0.35802,
                0.084971, 0.019817, -0.20481, -0.15568, -0.52826, 0.064368,
                0.060007, -0.24596, 0.15001, -0.60509, -0.15379, 0.016719,
                0.15949, -0.097624, -0.07951, 0.19468, 0.098393, 0.15924,
                0.15078, -0.091977, -0.0268, -0.21947, -0.19844, 0.58255,
                0.074457, -0.18159, 0.10575, 0.39545, -0.62542, -0.070905,
                0.29926, 0.23098, -0.058532, 0.079409, 0.21516, -0.0001531,
                -0.35988, -0.2874, -0.024583, -0.30688, 0.61378, -0.2017,
                -0.18743, -0.22567, -0.11391, -0.13839, 0.065526, -0.38524,
                -0.26101, 0.15433, 0.19295, 0.016614, 0.27393, 0.015089,
                0.17114, 0.36981, 0.44011, -0.0013757, 0.15895, 0.52241,
                0.21377, 0.099801, -0.07774, -0.02571, -0.45929, 0.009682,
                -0.015693, -0.012355, -0.15352, 0.040929, 0.24291, -0.099217,
                -0.12023, 0.048583, -0.22753, -0.40505, -0.23716, -0.011524,
                -0.15346, 0.068119, -0.035336, -0.34307, 0.065718, -0.026112,
                0.083108, 0.27713, 0.020035, -0.20193, -0.17143, 0.55838,
                0.19698
            ],
            [
                -0.29365, -0.049916, 0.096439, -0.089388, 0.27109, 0.057496,
                -0.50298, 0.11331, -0.19913, 1.0869, -0.36474, 0.18028,
                -0.2439, -0.84879, -0.09803, 0.22358, 0.16649, 1.8263,
                -0.30784, -0.45779, -0.13423, -0.7684, 0.061036, 0.13364,
                -0.07578, -0.36814, -0.56498, 0.11553, 0.18909, 0.069852,
                0.10334, 0.54858, -0.017279, -0.42885, 0.17587, -0.48115,
                -0.21931, -0.39983, -0.05173, -0.46209, 0.46579, 0.21905,
                -0.14852, 0.11248, 0.21266, -0.13285, -0.1344, 0.22768,
                0.38002, -0.31141, -0.75913, 0.34262, -0.57856, -0.44662,
                0.17095, 0.13949, -0.28634, 0.066538, -0.21849, -0.48396,
                -0.73416, -0.45858, 0.20657, 0.0091145, -0.0039049, 0.01489,
                -0.25298, -0.022714, -0.027294, 0.41785, 0.11382, -0.33901,
                -0.032653, 0.042876, -0.1628, -0.083524, -0.36741, -0.26457,
                0.053942, -0.01116, -0.50069, -0.16943, 0.10525, -0.030164,
                0.4385, -0.13928, 1.141, 0.76126, 0.074075, -0.028966,
                0.066959, 0.20611, 0.27884, -0.17062, 0.0044823, -0.46235,
                -0.052986, 0.50416, -0.018854, -0.38912, 0.57516, 0.61789,
                0.45961, -0.1963, -0.51927, -0.51316, -0.8881, 0.28339,
                0.032175, 0.26376, -0.47802, -0.35921, -0.50878, -0.1828,
                0.26999, 0.24097, 0.099165, -0.031377, 0.089655, 0.32511,
                -0.42431, 0.01075, -0.32665, 0.15986, 0.16415, 0.38453,
                0.24862, -0.31164, 0.16802, -0.38192, 0.092993, -0.033324,
                -0.13209, 0.038213, -0.0029631, 0.06452, 0.0079986, -0.50266,
                -0.018759, 0.05632, -3.0279, -0.079183, 0.70083, 0.2262,
                0.36396, -0.096987, 0.19656, 0.012033, 0.23194, -0.030562,
                -0.28404, -0.37286, -0.005297, -0.33137, -0.44292, 0.28554,
                -0.71202, -0.0015515, 0.0093941, 0.31106, -0.20186, -0.10606,
                -0.0098406, 0.083881, 0.0014653, -0.43426, -0.13004, -0.14525,
                0.24627, -0.038385, -0.33198, 0.4009, -0.053365, 0.47144,
                -0.18795, 0.25009, -0.22505, 0.10527, 0.4418, 0.18197,
                -0.4826, 0.51301, -0.21059, -0.51911, -0.18121,
                0.69244, -0.36925, 0.13242, 0.17995, 0.024023, -0.092837,
                -0.16256, -0.25677, 0.058971, 0.4761, -0.12983, 0.019869,
                0.22802, -0.36084, -0.091776, 0.45292, -0.027555, -0.15405,
                -0.30351, 0.16619, -0.074507, 0.12211, -0.14763, -0.1045,
                0.39327, -0.058905, 0.6207, -0.49493, 0.023326, 0.37233,
                0.032352, -0.65445, -0.32216, 0.39367, -0.12799, -0.78568,
                -0.13649, -0.59398, -0.039309, -0.16203, -0.088509, 0.14446,
                -0.14543, 0.17516, 0.67057, -0.31062, -0.31735, 0.48737,
                0.51206, 0.12244, 0.58553, -0.3483, -0.070485, 0.65111,
                0.49588, -0.042622, 0.085238, -0.24129, -0.61676, 0.065639,
                0.21727, -0.31657, -0.20381, -0.18905, -0.0026379, -0.16428,
                0.29292, -0.043597, -0.10713, 0.015803, 0.10977, 0.099193,
                0.058263, -0.22138, 0.53114, 0.2194, 0.46687, -0.22339,
                0.45082, -0.34546, -0.10945, 0.013951, -0.22981, -0.61019,
                0.53618, -0.38039, -0.3018, 0.044355, -0.47215, 0.094294,
                -0.30885, -0.16255, 0.35686, -0.0010873, -0.13689, -0.24389,
                0.64798, 0.19567, -0.17806, -0.46973, -0.026857, 0.25365,
                0.099388, 0.057244, -0.32616, -0.59946, -0.070698, 0.044969,
                -0.83205, -0.37187, 0.28149, 0.1978, 0.047221, -0.22288,
                0.017735
            ],
            [
                0.031091, 0.56825, -0.03107, 0.004301, -0.03025, -0.2201,
                0.016359, -0.27483, 0.54576, 0.69811, -0.92913, -0.32617,
                0.34225, -0.36393, 0.17427, 0.10333, -0.22877, 0.62709,
                0.40462, -0.27718, 0.051787, 0.26553, 0.0028972, -0.37731,
                -0.092281, 0.26781, -0.51189, -0.34465, -0.089694, -0.13627,
                -0.073969, -0.23845, 0.4223, -0.2777, -0.68097, 0.63363,
                0.084718, 0.27264, -0.24687, -0.0064634, -0.19284, 0.022436,
                -0.15938, 0.50185, -0.5157, 0.43452, -0.07777, -0.14402,
                0.21616, -0.11667, 0.013378, 0.12769, 0.31937, 0.052952,
                0.20743, 0.25537, -0.42906, 0.058272, 0.24935, -0.38469,
                -0.70756, -0.28694, 0.21439, -0.49421, 0.071887, 0.028562,
                -0.399, 0.43777, -0.15642, 0.39649, -0.17438, -0.23927,
                0.038615, 0.2249, 0.72184, 0.36085, 0.0488, 0.73647,
                -0.155, 0.57072, -0.040056, 0.27431, -0.12806, -0.1526,
                0.14863, -0.0065193, 1.0662, -0.36083, 0.70706, 0.043639,
                0.47063,
                0.027411, -0.23215, -0.46128, -0.049949, 0.020991, -0.041238,
                0.43573, 0.30043, 0.35162, -0.29331, -0.68115, 0.25255,
                0.23526, -0.13769, -0.62365, 0.5771, -0.13667, -0.025982,
                -0.0593, -0.016144, -0.67935, 0.18109, 0.055385, -0.16223,
                0.43419, 0.23547, 0.1063, 0.14108, 0.13307, 0.58862,
                0.091575, -0.18413, 0.19386, 0.1448, -0.095191, 0.19922,
                0.075761, -0.019164, -0.077082, 0.23526, -0.80003, -0.27262,
                0.23965, -0.026556, -0.089179, -0.0030789, -0.20901, -0.13612,
                -0.22583, -2.2758, -0.18385, -0.16255, -0.42371, 0.15869,
                0.35939, 0.074838, 0.24797, -0.05182, -0.26273, -0.11361,
                -0.42738, 0.083827, -0.20052, 0.13626, -0.28534, 0.22639,
                -0.32933, -0.11853, 0.085149, -0.003949, -0.35466, -0.23401,
                -0.13937, 0.0068301, -0.12079, -0.1791, 0.73794, -0.24173,
                0.18064, -0.044553, 0.094685, -0.3187, 0.0014406, -0.42592,
                0.29623, -0.17538, -0.51084, -0.041933, 0.18008, -0.16072,
                -0.062, -0.2527, 0.35847, 0.18067, -0.11788, -0.038845,
                0.24696, 0.16047, -0.0051667, 0.24596, -0.24756, -0.15568,
                -0.37533, 0.098757, 0.64188, -0.0036217, -0.050042, 0.2193,
                -0.29602, -0.20227, 0.21918, 0.36836, 0.29122, 0.16094,
                -0.70123, -0.49206, 0.32254, 0.26288, 0.012186, 0.32156,
                0.49782, 0.0039978, 0.36138, -0.27197, -0.30703, 0.093815,
                -0.76536, -0.3478, 0.48628, 0.44034, -0.62981, -0.49056,
                0.12199, -0.15922, -0.081561, -0.14558, -0.16878, -0.37092,
                -0.33377, -0.29117, 0.591, 0.052894, -0.028036, -0.10446,
                -0.37111, -0.38053, -0.22004, -0.22515, 0.25515, -0.11202,
                -0.5538, 0.086523, 0.48785, 0.072203, 0.29461, 0.23643,
                -0.11222, -0.092494, -0.28043, -0.13144, 0.234, 0.11143,
                -0.67456, 0.43617, 0.023155, -0.57365, -0.39816, -0.73945,
                -0.44138, 0.21267, -0.018604, -0.25674, -0.025934, -0.23015,
                -0.25172, -0.40583, -0.083189, -0.054541, 0.15206, -0.31548,
                -0.14732, -0.23183, 0.75427, 0.018009, 0.30702, -0.22941,
                -0.013627, 0.38182, 0.26575, 0.63149, -0.88812, -0.6673,
                0.080639, 0.23583, 0.49035, 0.10201, -0.16356, -0.11952,
                0.60617, 0.38027, -0.0078457, 0.039968, -0.007106, 0.32,
                -0.26781,
                0.41864, 0.12264, -0.43825, 0.090428
            ],
        ]))
def disambiguation_with_glove(input_sentence, model_dir, hparams):
    """Input sentence is a list of lists of possible words at a given position

    Runs one prediction of the autoregressor model with ``input_sentence``
    passed as ``mask_allowables``, prints the prediction and timing
    information, and returns the prediction.

    Args:
        input_sentence: list of lists of possible words at a given position;
            its length is fed to the model as the ``"length"`` feature.
        model_dir: Estimator ``model_dir`` (also the profiler output dir).
        hparams: object providing ``learning_rate`` and ``profiler``; it is
            also forwarded to ``get_autoregressor_model_fn``.

    Returns:
        The first element yielded by ``estimator.predict``; in PREDICT mode
        the predictions additionally contain ``"predicted_words"``.
    """
    glove = Glove300()

    def create_input():
        def data_gen():
            # A single example holding only the start-of-sequence id.
            yield ({
                "inputs":
                np.array([[
                    LmInputDataPipeline(
                        glove, None)._vocab_generalized.get_special_unit_id(
                            SpecialUnit.START_OF_SEQUENCE)
                ]],
                         dtype=np.int32),
                "length":
                len(input_sentence)
            }, np.array([0]))

        data = tf.data.Dataset.from_generator(data_gen,
                                              output_types=({
                                                  "inputs": tf.int32,
                                                  "length": tf.int32
                                              }, tf.int32),
                                              output_shapes=({
                                                  "inputs": (
                                                      1,
                                                      1,
                                                  ),
                                                  "length": ()
                                              }, (1, )))
        return data

    def model_function(features, labels, mode, params):
        input_pipe = LmInputDataPipeline(glove)
        vocab_size = glove.vocab_size()
        embedding_size = input_pipe._vocab_generalized.vector_size()
        # NOTE(review): the non-PREDICT fallback builds its shape as a tuple
        # of (shape tensor, int); it is not reached from this function, which
        # only predicts -- confirm the intended shape before reusing it.
        id_to_embeding_fn = input_pipe.get_id_to_embedding_mapping(
        ) if mode == tf.estimator.ModeKeys.PREDICT else lambda x: tf.zeros(
            (tf.shape(x), embedding_size), tf.float32)
        #with tf.device(device_assignment_function) if hparams.size_based_device_assignment else without:
        with tf.device("/device:CPU:0"):
            concrete_model_fn = get_autoregressor_model_fn(
                vocab_size,
                id_to_embeding_fn,
                time_major_optimization=True,
                predict_as_pure_lm=False,
                mask_allowables=input_sentence,
                hparams=hparams)
            estimator_spec = concrete_model_fn(features, labels, mode, params)
        training_hooks = []
        # NOTE(review): result is unused; kept in case the call matters.
        to_restore = tf.contrib.framework.get_variables_to_restore()
        predictions = estimator_spec.predictions
        if mode == tf.estimator.ModeKeys.PREDICT:
            # NOTE(review): this hook is passed via `training_hooks`;
            # Estimator has a separate `prediction_hooks` argument for
            # PREDICT mode -- confirm the hook actually runs here.
            training_hooks.append(InitializeVocabularyHook(glove))
            # Translate predicted generalized ids back into words; the ids
            # are flattened for the lookup and reshaped afterwards.
            predicted_ids = tf.cast(predictions["paths"], dtype=tf.int64)
            words_shape = tf.shape(predicted_ids)
            to_vocab_id = input_pipe._vocab_generalized.generalized_id_to_vocab_id(
            )
            to_word = glove.id_to_word_op()
            predicted_ids = tf.reshape(predicted_ids, shape=[-1])
            predicted_words = to_word(to_vocab_id(predicted_ids))
            predicted_words = tf.reshape(predicted_words, shape=words_shape)
            # Mutates the dict that estimator_spec.predictions refers to.
            predictions["predicted_words"] = predicted_words
        if hparams.profiler:
            training_hooks.append(
                tf.train.ProfilerHook(output_dir=model_dir,
                                      save_secs=30,
                                      show_memory=True))
            # NOTE(review): source indentation was lost; FullLogHook is
            # assumed to belong to the profiler branch -- confirm.
            training_hooks.append(FullLogHook())
        estimator_spec_with_hooks = tf.estimator.EstimatorSpec(
            mode=estimator_spec.mode,
            loss=estimator_spec.loss,
            train_op=estimator_spec.train_op,
            eval_metric_ops=estimator_spec.eval_metric_ops,
            predictions=estimator_spec.predictions,
            training_hooks=training_hooks)
        return estimator_spec_with_hooks

    params = {
        "learning_rate": hparams.learning_rate,
        "number_of_alternatives": 5
    }
    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto(log_device_placement=False))
    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto())
    config = tf.estimator.RunConfig()
    estimator = tf.estimator.Estimator(model_function,
                                       params=params,
                                       model_dir=model_dir,
                                       config=config)
    t1 = datetime.datetime.now()
    # predict() only returns a generator; the model runs while it is being
    # consumed, so the stop time must be taken after the loop below.
    predictions = estimator.predict(create_input)
    predictions = islice(predictions, 1)
    for prediction in predictions:
        print(prediction)
    t2 = datetime.datetime.now()
    print("start:", t1)
    print("stop:", t2)
    print("duration:", t2 - t1)
    return prediction
def eval_lm_on_cached_simple_examples_with_glove(data_dir,
                                                 model_dir,
                                                 subset,
                                                 hparams,
                                                 take_first_n=20):
    """Predict on a cached dataset, dump the predictions and print them.

    Reads examples with ``read_dataset_from_dir``, runs the model as a pure
    language model, pickles the collected predictions to
    ``rtest_expected.pickle`` in the current directory and prints each
    prediction followed by timing information.

    Args:
        data_dir: directory passed to ``read_dataset_from_dir``.
        model_dir: Estimator ``model_dir`` (also the profiler output dir).
        subset: dataset subset passed to ``read_dataset_from_dir``.
        hparams: object providing ``learning_rate`` and ``profiler``; it is
            also forwarded to ``get_autoregressor_model_fn``.
        take_first_n: number of examples to read and of predictions to
            collect; ``None`` means no limit.
    """
    # Local import: no module-level import of pickle is visible here, and the
    # training counterpart of this function imports it locally as well.
    import pickle

    glove = Glove300(dry_run=True)
    BATCH_SIZE = 5

    def create_input():
        input_pipe = LmInputDataPipeline(glove, 5)
        embedding_size = LmInputDataPipeline(
            glove, None)._vocab_generalized.vector_size()
        train_data = read_dataset_from_dir(data_dir, subset, embedding_size)
        if take_first_n is not None:
            train_data = train_data.take(take_first_n)
        train_data = input_pipe.padded_batch(train_data, BATCH_SIZE)
        return train_data

    def model_function(features, labels, mode, params):
        input_pipe = LmInputDataPipeline(glove)
        vocab_size = glove.vocab_size()
        embedding_size = input_pipe._vocab_generalized.vector_size()
        # NOTE(review): the non-PREDICT fallback builds its shape as a tuple
        # of (shape tensor, int); it is not reached from this function, which
        # only predicts -- confirm the intended shape before reusing it.
        id_to_embeding_fn = input_pipe.get_id_to_embedding_mapping(
        ) if mode == tf.estimator.ModeKeys.PREDICT else lambda x: tf.zeros(
            (tf.shape(x), embedding_size), tf.float32)
        #with tf.device(device_assignment_function) if hparams.size_based_device_assignment else without:
        with tf.device("/device:CPU:0"):
            concrete_model_fn = get_autoregressor_model_fn(
                vocab_size,
                id_to_embeding_fn,
                time_major_optimization=True,
                predict_as_pure_lm=True,
                hparams=hparams)
            estimator_spec = concrete_model_fn(features, labels, mode, params)
        training_hooks = []
        predictions = estimator_spec.predictions
        if mode == tf.estimator.ModeKeys.PREDICT:
            # NOTE(review): this hook is passed via `training_hooks`;
            # Estimator has a separate `prediction_hooks` argument for
            # PREDICT mode -- confirm the hook actually runs here.
            training_hooks.append(InitializeVocabularyHook(glove))
            # Translate predicted generalized ids back into words; the ids
            # are flattened for the lookup and reshaped afterwards.
            predicted_ids = predictions["predicted_word_id"]
            words_shape = tf.shape(predicted_ids)
            to_vocab_id = input_pipe._vocab_generalized.generalized_id_to_vocab_id(
            )
            to_word = glove.id_to_word_op()
            predicted_ids = tf.reshape(predicted_ids, shape=[-1])
            predicted_words = to_word(to_vocab_id(predicted_ids))
            predicted_words = tf.reshape(predicted_words, shape=words_shape)
            # Mutates the dict that estimator_spec.predictions refers to.
            predictions["predicted_word"] = predicted_words
        if hparams.profiler:
            training_hooks.append(
                tf.train.ProfilerHook(output_dir=model_dir,
                                      save_secs=30,
                                      show_memory=True))
            # NOTE(review): source indentation was lost; FullLogHook is
            # assumed to belong to the profiler branch -- confirm.
            training_hooks.append(FullLogHook())
        estimator_spec_with_hooks = tf.estimator.EstimatorSpec(
            mode=estimator_spec.mode,
            loss=estimator_spec.loss,
            train_op=estimator_spec.train_op,
            eval_metric_ops=estimator_spec.eval_metric_ops,
            predictions=estimator_spec.predictions,
            training_hooks=training_hooks)
        return estimator_spec_with_hooks

    params = {
        "learning_rate": hparams.learning_rate,
        "number_of_alternatives": 1
    }
    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto(log_device_placement=False))
    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto())
    config = tf.estimator.RunConfig()
    estimator = tf.estimator.Estimator(model_function,
                                       params=params,
                                       model_dir=model_dir,
                                       config=config)
    t1 = datetime.datetime.now()
    # predict() only returns a generator; the model runs while the list is
    # being built, so the stop time is taken after materializing it.
    predictions = estimator.predict(create_input)
    predictions = [*islice(predictions, take_first_n)]
    t2 = datetime.datetime.now()
    with open("rtest_expected.pickle", "wb") as file_expected:
        pickle.dump(predictions, file_expected)
    for prediction in predictions:
        print(prediction)
    print("start:", t1)
    print("stop:", t2)
    print("duration:", t2 - t1)
def train_lm_on_cached_simple_examples_with_glove(data_dir, model_dir,
                                                  hparams):
    """Train the language model on examples cached in ``data_dir``.

    When the module-level flag ``CREATE_RTEST_INPUT`` is set, no training
    happens: the first 3000 input batches are pickled to
    ``retest_expected.pickle`` and the function returns. Otherwise the model
    is trained up to ``hparams.max_training_steps`` and timing information is
    printed.

    Args:
        data_dir: directory passed to ``read_dataset_from_dir``.
        model_dir: Estimator ``model_dir`` (also the profiler output dir).
        hparams: object providing ``learning_rate``, ``max_training_steps``,
            ``profiler``, ``size_based_device_assignment`` and
            ``write_target_text_to_summary``; it is also forwarded to
            ``get_autoregressor_model_fn``.
    """
    glove = Glove300(dry_run=False)
    BATCH_SIZE = 5

    def create_input():
        input_pipe = LmInputDataPipeline(glove, 5)
        embedding_size = LmInputDataPipeline(
            glove, None)._vocab_generalized.vector_size()
        examples = read_dataset_from_dir(data_dir, DatasetType.TRAIN,
                                         embedding_size)
        examples = examples.repeat().shuffle(1000, seed=0)
        return input_pipe.padded_batch(examples, BATCH_SIZE)

    def model_function(features, labels, mode, params):
        input_pipe = LmInputDataPipeline(glove)
        vocab_size = glove.vocab_size()
        embedding_size = input_pipe._vocab_generalized.vector_size()

        if mode == tf.estimator.ModeKeys.PREDICT:
            id_to_embeding_fn = input_pipe.get_id_to_embedding_mapping()
        else:
            # NOTE(review): the shape is a tuple of (shape tensor, int);
            # presumably this fallback is never called -- confirm.
            id_to_embeding_fn = lambda x: tf.zeros(
                (tf.shape(x), embedding_size), tf.float32)

        if hparams.size_based_device_assignment:
            device_scope = tf.device(device_assignment_function)
        else:
            device_scope = without
        with device_scope:
            build_spec = get_autoregressor_model_fn(
                vocab_size,
                id_to_embeding_fn,
                time_major_optimization=True,
                hparams=hparams)
            base_spec = build_spec(features, labels, mode, params)

        if hparams.write_target_text_to_summary:
            # Targets are flattened for the id -> word lookup and reshaped
            # back before being written as a text summary.
            targets = labels["targets"]
            targets_shape = tf.shape(targets)
            to_vocab_id = input_pipe._vocab_generalized.generalized_id_to_vocab_id(
            )
            to_word = glove.id_to_word_op()
            flat_words = to_word(to_vocab_id(tf.reshape(targets, shape=[-1])))
            tf.summary.text("targets_words",
                            tf.reshape(flat_words, shape=targets_shape))

        hooks = []
        if mode == tf.estimator.ModeKeys.PREDICT:
            hooks.append(InitializeVocabularyHook(glove))
        if hparams.profiler:
            hooks.append(
                tf.train.ProfilerHook(output_dir=model_dir,
                                      save_secs=30,
                                      show_memory=True))
            # NOTE(review): source indentation was lost; FullLogHook is
            # assumed to belong to the profiler branch -- confirm.
            hooks.append(FullLogHook())

        return tf.estimator.EstimatorSpec(
            mode=base_spec.mode,
            loss=base_spec.loss,
            train_op=base_spec.train_op,
            eval_metric_ops=base_spec.eval_metric_ops,
            predictions=base_spec.predictions,
            training_hooks=hooks)

    params = {
        "learning_rate": hparams.learning_rate,
        "number_of_alternatives": 1
    }

    if CREATE_RTEST_INPUT:
        # Dump raw input batches for a regression test instead of training.
        iterator = create_input().make_initializable_iterator()
        next_example = iterator.get_next()
        with tf.Session() as sess:
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            expected = [sess.run(next_example) for _ in range(3000)]
        with open("retest_expected.pickle", "wb") as rtest_expected:
            import pickle
            pickle.dump(expected, rtest_expected)
        return

    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto(log_device_placement=False))
    #config=tf.estimator.RunConfig(session_config=tf.ConfigProto())
    estimator = tf.estimator.Estimator(model_function,
                                       params=params,
                                       model_dir=model_dir,
                                       config=tf.estimator.RunConfig())

    started = datetime.datetime.now()
    estimator.train(create_input, max_steps=hparams.max_training_steps)
    finished = datetime.datetime.now()

    print("start:", started)
    print("stop:", finished)
    print("duration:", finished - started)
def prepare_training_dataset(ouput_path):
    """This will transform input corpus into language model training examples with embeddings vectors as inputs and save it to disk.
    Expect HUGE dataset in terms of occupied space.

    Examples whose ``"length"`` feature exceeds 40 are dropped. The remaining
    ones are written as serialized ``tf.train.Example`` protos to
    ``train.<10-digit index>.tfrecords`` files inside ``ouput_path``, at most
    2000 examples per file.

    When the module-level flag ``TEST_SERIALIZATION`` is set, every written
    example is also kept in memory and, after writing, compared with what is
    read back from the last written file.

    Args:
        ouput_path: directory (str or path-like) for the ``.tfrecords`` files.
    """
    if TEST_SERIALIZATION:
        test_examples = []
    ouput_path = Path(ouput_path)
    glove = Glove300()

    def create_input():
        simple_examples = SimpleExamplesCorpus()
        train_data = simple_examples.get_tokens_dataset(DatasetType.TRAIN)
        input_pipe = LmInputDataPipeline(glove, None)
        return input_pipe.load_data(train_data)

    dataset = create_input()

    def make_tf_record_example(features, labels) -> tf.train.Example:
        """Pack one (features, labels) pair into a ``tf.train.Example``."""
        # "inputs" is flattened to 1-D because a FloatList is flat.
        feature_inputs = tf.train.Feature(float_list=tf.train.FloatList(
            value=features["inputs"].reshape(-1)))
        feature_length = tf.train.Feature(int64_list=tf.train.Int64List(
            value=[features["length"]]))
        feature_targets = tf.train.Feature(int64_list=tf.train.Int64List(
            value=labels["targets"]))
        feature_dict = {
            "inputs": feature_inputs,
            "length": feature_length,
            "targets": feature_targets
        }
        features = tf.train.Features(feature=feature_dict)
        example = tf.train.Example(features=features)
        return example

    def max_length_condition(max_length):
        def check_length(features, labels):
            return tf.less_equal(features["length"], max_length)

        return check_length

    dataset = dataset.filter(max_length_condition(40))
    it = dataset.make_initializable_iterator()
    next_example = it.get_next()
    EXAMPLES_PER_FILE = 2000
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())
        glove.initialize_embeddings_in_graph(tf.get_default_graph(), sess)
        sess.run(it.initializer)
        for i in count(1):
            dataset_filename = str(ouput_path /
                                   "train.{:0=10}.tfrecords".format(i))
            writer = tf.python_io.TFRecordWriter(dataset_filename)
            try:
                for _ in range(EXAMPLES_PER_FILE):
                    features, labels = sess.run(next_example)
                    if TEST_SERIALIZATION:
                        test_examples.append((features, labels))
                    example = make_tf_record_example(features, labels)
                    writer.write(example.SerializeToString())
            except tf.errors.OutOfRangeError:
                # Input exhausted: stop after this (possibly partial) file.
                break
            finally:
                # Close on every path, including the final `break`, so the
                # last file is complete before it is read back below.
                writer.close()
    if TEST_SERIALIZATION:
        # NOTE(review): only the last written file is read back, while
        # test_examples holds the examples of all files; this looks valid
        # only when a single file was produced -- confirm.
        embedding_size = LmInputDataPipeline(
            glove, None)._vocab_generalized.vector_size()
        records_dataset = read_dataset_from_files(
            [dataset_filename], embedding_size=embedding_size)
        it = records_dataset.make_initializable_iterator()
        next_record = it.get_next()
        with tf.Session() as sess:
            sess.run(it.initializer)
            for expected_features, expected_labels in test_examples:
                actual_features, actual_labels = sess.run(next_record)
                assert (actual_features["inputs"] ==
                        expected_features["inputs"]).all()
                assert (actual_features["length"] ==
                        expected_features["length"]).all()
                assert actual_labels["targets"] == approx(
                    expected_labels["targets"])