def RunTraining(self, hyperparam_config):
  master_spec = self.LoadSpec('master_spec_link.textproto')

  self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]
  tf.logging.info('Generating graph with config: %s', hyperparam_config)

  with tf.Graph().as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'testTraining-all'
    train = builder.add_training_from_config(target)
    with self.test_session() as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      # Run one iteration of training and verify nothing crashes.
      logging.info('Training')
      sess.run(train['run'],
               feed_dict={train['input_batch']: reader_strings})
def calculate_parse_metrics(gold_corpus, annotated_corpus):
  """Calculate POS/UAS/LAS accuracy based on gold and annotated sentences."""
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_tokens = 0
  num_correct_pos = 0
  num_correct_uas = 0
  num_correct_las = 0
  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    check.Eq(len(gold.token), len(annotated.token), 'Tokens are not aligned')
    # Materialize the token pairs so they can be counted and reused below.
    tokens = list(zip(gold.token, annotated.token))
    num_tokens += len(tokens)
    num_correct_pos += sum(1 for x, y in tokens if x.tag == y.tag)
    num_correct_uas += sum(1 for x, y in tokens if x.head == y.head)
    num_correct_las += sum(1 for x, y in tokens
                           if x.head == y.head and x.label == y.label)

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total num tokens: %d', num_tokens)
  pos = num_correct_pos * 100.0 / num_tokens
  uas = num_correct_uas * 100.0 / num_tokens
  las = num_correct_las * 100.0 / num_tokens
  tf.logging.info('POS: %.2f%%', pos)
  tf.logging.info('UAS: %.2f%%', uas)
  tf.logging.info('LAS: %.2f%%', las)
  return pos, uas, las
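# Hedged usage sketch for calculate_parse_metrics (not from the original
# source): build two tiny aligned corpora of serialized Sentence protos and
# compare them. Assumes sentence_pb2 from SyntaxNet is importable; the helper
# name _example_parse_metrics is hypothetical.
def _example_parse_metrics():
  gold = sentence_pb2.Sentence(text='x x')
  gold.token.add(word='x', start=0, end=0, tag='NN', head=1, label='nsubj')
  gold.token.add(word='x', start=2, end=2, tag='VB', head=-1, label='root')
  pred = sentence_pb2.Sentence()
  pred.CopyFrom(gold)
  pred.token[0].tag = 'JJ'  # introduce a single POS error
  pos, uas, las = calculate_parse_metrics([gold.SerializeToString()],
                                          [pred.SerializeToString()])
  # With one wrong tag out of two tokens: pos == 50.0, uas == las == 100.0.
  return pos, uas, las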
def calculate_segmentation_metrics(gold_corpus, annotated_corpus):
  """Calculate precision/recall/f1 based on gold and annotated sentences."""
  check.Eq(len(gold_corpus), len(annotated_corpus), 'Corpora are not aligned')
  num_gold_tokens = 0
  num_test_tokens = 0
  num_correct_tokens = 0

  def token_span(token):
    check.Ge(token.end, token.start)
    return (token.start, token.end)

  def ratio(numerator, denominator):
    check.Ge(numerator, 0)
    check.Ge(denominator, 0)
    if denominator > 0:
      # Use true division so integer counts do not truncate the ratio.
      return numerator / float(denominator)
    elif numerator == 0:
      return 0.0  # map 0/0 to 0
    else:
      return float('inf')  # map x/0 to inf

  for gold_str, annotated_str in zip(gold_corpus, annotated_corpus):
    gold = sentence_pb2.Sentence()
    annotated = sentence_pb2.Sentence()
    gold.ParseFromString(gold_str)
    annotated.ParseFromString(annotated_str)
    check.Eq(gold.text, annotated.text, 'Text is not aligned')
    gold_spans = set()
    test_spans = set()
    for token in gold.token:
      check.NotIn(token_span(token), gold_spans, 'Duplicate token')
      gold_spans.add(token_span(token))
    for token in annotated.token:
      check.NotIn(token_span(token), test_spans, 'Duplicate token')
      test_spans.add(token_span(token))
    num_gold_tokens += len(gold_spans)
    num_test_tokens += len(test_spans)
    num_correct_tokens += len(gold_spans.intersection(test_spans))

  tf.logging.info('Total num documents: %d', len(annotated_corpus))
  tf.logging.info('Total gold tokens: %d', num_gold_tokens)
  tf.logging.info('Total test tokens: %d', num_test_tokens)
  precision = 100 * ratio(num_correct_tokens, num_test_tokens)
  recall = 100 * ratio(num_correct_tokens, num_gold_tokens)
  f1 = ratio(2 * precision * recall, precision + recall)
  tf.logging.info('Precision: %.2f%%', precision)
  tf.logging.info('Recall: %.2f%%', recall)
  tf.logging.info('F1: %.2f%%', f1)
  return round(precision, 2), round(recall, 2), round(f1, 2)
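# Hedged usage sketch for calculate_segmentation_metrics (not from the
# original source): one gold sentence split into two tokens vs. a prediction
# that merges them into a single span. The helper name is hypothetical.
def _example_segmentation_metrics():
  gold = sentence_pb2.Sentence(text='ab')
  gold.token.add(word='a', start=0, end=0)
  gold.token.add(word='b', start=1, end=1)
  pred = sentence_pb2.Sentence(text='ab')
  pred.token.add(word='ab', start=0, end=1)  # single merged token
  precision, recall, f1 = calculate_segmentation_metrics(
      [gold.SerializeToString()], [pred.SerializeToString()])
  # No predicted span matches a gold span, so all three metrics are 0.0.
  return precision, recall, f1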
def pretty_print():
  _write_input(_read_output().strip())
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(
        batch_size=32,
        corpus_name='input-from-file-conll',
        task_context=task_context_path)
    sentence = sentence_pb2.Sentence()
    while True:
      documents, finished = sess.run(src)
      logging.info('Read %d documents', len(documents))
      # for d in documents:
      #   sentence.ParseFromString(d)
      #   as_asciitree(sentence)
      for d in documents:
        sentence.ParseFromString(d)
        tr = asciitree.LeftAligned()
        d = to_dict(sentence)
        print('Input: %s' % sentence.text)
        print('Parse:')
        tr_str = tr(d)
        pat = re.compile(r'\s*@\d+$')
        for tr_ln in tr_str.splitlines():
          print(pat.sub('', tr_ln))
      if finished:
        break
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(batch_size=32,
                                         corpus_name=FLAGS.corpus_name,
                                         task_context=FLAGS.task_context)
    sentence = sentence_pb2.Sentence()
    while True:
      documents, finished = sess.run(src)
      logging.info('Read %d documents', len(documents))
      for d in documents:
        sentence.ParseFromString(d)
        # print '...Sentence string before serialization: ', d
        tr = asciitree.LeftAligned()
        d = to_dict(sentence)
        print 'Input: %s' % sentence.text
        serializedStr = sentence.SerializeToString()
        # print '...Sentence string protobuf: ', serializedStr
        file = open("/Users/yihed/Documents/workspace/Other/src/thmp/data/serializedSentence.txt", "wb")
        # file = open("serializedSentence.txt", "wb")
        file.write(serializedStr)
        file.close()
        print 'Parse:'
        print tr(d)
      if finished:
        break
def assertParseable(self, reader, expected_num, expected_last):
  sentences, last = reader.read()
  self.assertEqual(expected_num, len(sentences))
  self.assertEqual(expected_last, last)
  for s in sentences:
    pb = sentence_pb2.Sentence()
    pb.ParseFromString(s)
    self.assertGreater(len(pb.token), 0)
def ReadNextDocument(self, sess, sentence):
  sentence_str, = sess.run([sentence])
  if sentence_str:
    sentence_doc = sentence_pb2.Sentence()
    sentence_doc.ParseFromString(sentence_str[0])
  else:
    sentence_doc = None
  return sentence_doc
def ReadNextDocument(self, sess, doc_source):
  doc_str, last = sess.run(doc_source)
  if doc_str:
    doc = sentence_pb2.Sentence()
    doc.ParseFromString(doc_str[0])
  else:
    doc = None
  return doc, last
def testReadFirstSentence(self):
  reader = sentence_io.ConllSentenceReader(self.filepath, 1)
  sentences, last = reader.read()
  self.assertEqual(1, len(sentences))
  pb = sentence_pb2.Sentence()
  pb.ParseFromString(sentences[0])
  self.assertFalse(last)
  self.assertEqual(
      u'I knew I could do it properly if given the right kind of support .',
      pb.text)
def testGiveMeAName(self):
  document = sentence_pb2.Sentence()
  document.token.add(start=0, end=0, word='hi', head=1, label='something')
  document.token.add(start=1, end=1, word='there')
  contents = render_parse_tree_graphviz.parse_tree_graph(document)
  self.assertIn('<polygon', contents)
  self.assertIn('text/html;charset=utf-8;base64', contents)
  self.assertIn('something', contents)
  self.assertIn('hi', contents)
  self.assertIn('there', contents)
def testModelExportProducesRunnableModel(self):
  # Get the master spec and params for this graph.
  master_spec = self.LoadSpec('ud-hungarian.master-spec')
  params_path = os.path.join(
      test_flags.source_root(),
      'dragnn/python/testdata'
      '/ud-hungarian.params')

  # Export the graph via SavedModel. (Here, we maintain a handle to the graph
  # for comparison, but that's usually not necessary.)
  export_path = os.path.join(test_flags.temp_dir(), 'export')
  dragnn_model_saver_lib.clean_output_paths(export_path)
  saver_graph = tf.Graph()

  shortened_to_original = dragnn_model_saver_lib.shorten_resource_paths(
      master_spec)

  dragnn_model_saver_lib.export_master_spec(master_spec, saver_graph)

  dragnn_model_saver_lib.export_to_graph(
      master_spec,
      params_path,
      export_path,
      saver_graph,
      export_moving_averages=False,
      build_runtime_graph=False)

  # Export the assets as well.
  dragnn_model_saver_lib.export_assets(master_spec, shortened_to_original,
                                       export_path)

  # Restore the graph from the checkpoint into a new Graph object.
  restored_graph = tf.Graph()
  restoration_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=10,
      inter_op_parallelism_threads=10)

  with tf.Session(graph=restored_graph, config=restoration_config) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                               export_path)

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_reader_string = test_doc.SerializeToString()
    test_inputs = [test_reader_string]
    tf_out = sess.run(
        'annotation/annotations:0',
        feed_dict={'annotation/ComputeSession/InputBatch:0': test_inputs})

    # We don't care about accuracy, only that the run sessions don't crash.
    del tf_out
def _add_sentence(self, tags, heads, labels, corpus):
  """Adds a sentence to the corpus."""
  sentence = sentence_pb2.Sentence()
  for tag, head, label in zip(tags, heads, labels):
    sentence.token.add(word='x', start=0, end=0, tag=tag, head=head,
                       label=label)
  corpus.append(sentence.SerializeToString())
def _create_fake_corpus():
  """Returns a list of fake serialized sentences for tests."""
  num_docs = 4
  corpus = []
  for num_tokens in range(1, num_docs + 1):
    sentence = sentence_pb2.Sentence()
    sentence.text = 'x' * num_tokens
    for i in range(num_tokens):
      token = sentence.token.add()
      token.word = 'x'
      token.start = i
      token.end = i
    corpus.append(sentence.SerializeToString())
  return corpus
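# Hedged usage sketch (not from the original source): round-trips the fake
# corpus to confirm each serialized sentence has the expected token count.
# The helper name is hypothetical.
def _example_fake_corpus_roundtrip():
  corpus = _create_fake_corpus()
  for expected_tokens, serialized in enumerate(corpus, start=1):
    sentence = sentence_pb2.Sentence.FromString(serialized)
    assert len(sentence.token) == expected_tokens
    assert sentence.text == 'x' * expected_tokens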
def annotate_text(self, text):
  sentence = sentence_pb2.Sentence(
      text=text,
      token=[sentence_pb2.Token(word=text, start=-1, end=-1)])

  # preprocess
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(
        [sentence.SerializeToString()])
    preprocessed = tmp_session.run(char_input)[0]
  segmented, _ = self.segmenter_model(preprocessed)

  annotations, traces = self.parser_model(segmented[0])
  assert len(annotations) == 1
  assert len(traces) == 1
  return sentence_pb2.Sentence.FromString(annotations[0])
def main(argv):
  del argv  # unused

  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(lexicon_dir, training_sentence,
                        training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(dry_run['traces'],
                      feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_1.html', 'w') as f:
    f.write(
        visualization.trace_html(traces[0], height='300px').encode('utf-8'))
def syntaxnet_tokenize(text):
  sentence = sentence_pb2.Sentence(
      text=text,
      token=[sentence_pb2.Token(word=text, start=-1, end=-1)])

  # preprocess
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(
        [sentence.SerializeToString()])
    preprocessed = tmp_session.run(char_input)[0]
  segmented, _ = segmenter_model(preprocessed)

  tokens = []
  for t in sentence_pb2.Sentence.FromString(segmented[0]).token:
    tokens.append(t.word)
  return tokens
def syntaxnet_sentence(tokens):
  pb_tokens = []
  last_start = 0
  for token in tokens:
    token_bytes = token.encode("utf8")
    pb_tokens.append(
        sentence_pb2.Token(word=token_bytes, start=last_start,
                           end=last_start + len(token_bytes) - 1))
    last_start = last_start + len(token_bytes) + 1

  annotations, traces = parser_model(
      sentence_pb2.Sentence(
          text=u" ".join(tokens).encode("utf8"),
          token=pb_tokens).SerializeToString())
  assert len(annotations) == 1
  assert len(traces) == 1
  return sentence_pb2.Sentence.FromString(annotations[0])
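# Hedged usage sketch (not from the original source): feeds pre-tokenized
# unicode words through syntaxnet_sentence and prints each token's head and
# label. Assumes parser_model has already been loaded as in the snippets
# above; the helper name is hypothetical.
def _example_parse_pretokenized():
  parsed = syntaxnet_sentence([u'I', u'saw', u'a', u'dog'])
  for token in parsed.token:
    # token.head is -1 for the root, otherwise a 0-based token index.
    print('%s\thead=%d\tlabel=%s' % (token.word, token.head, token.label))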
def annotate_text(text):
  """Segment and parse input text using syntaxnet models."""
  sentence = sentence_pb2.Sentence(
      text=text,
      token=[sentence_pb2.Token(word=text, start=-1, end=-1)])

  # preprocess
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(
        [sentence.SerializeToString()])
    preprocessed = tmp_session.run(char_input)[0]
  segmented, _ = SEGMENTER_MODEL(preprocessed)

  annotations, traces = PARSER_MODEL(segmented[0])
  assert len(annotations) == 1
  assert len(traces) == 1
  return sentence_pb2.Sentence.FromString(annotations[0]), traces[0]
def _get_sentence_dict():
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(
        batch_size=32,
        corpus_name='input-from-file-conll',
        task_context=task_context_path)
    sentence = sentence_pb2.Sentence()
    result_dict = None
    while True:
      documents, finished = sess.run(src)
      for d in documents:
        sentence.ParseFromString(d)
        d = to_dict(sentence)
        result_dict = d
      if finished:
        break
    return result_dict
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(batch_size=32,
                                         corpus_name=FLAGS.corpus_name,
                                         task_context=FLAGS.task_context)
    sentence = sentence_pb2.Sentence()
    while True:
      documents, finished = sess.run(src)
      logging.info('Read %d documents', len(documents))
      for d in documents:
        sentence.ParseFromString(d)
        tr = asciitree.LeftAligned()
        d = to_dict(sentence)
        print('Input: %s' % sentence.text)
        print('Parse:')
        print(tr(d))
      if finished:
        break
def print_output(output_file, use_text_format, use_gold_segmentation, output):
  """Writes a set of sentences in CoNLL format.

  Args:
    output_file: The file to write to.
    use_text_format: Whether this computation used text-format input.
    use_gold_segmentation: Whether this computation used gold segmentation.
    output: A list of sentences to write to the output file.
  """
  with gfile.GFile(output_file, 'w') as f:
    f.write('## tf:{}\n'.format(use_text_format))
    f.write('## gs:{}\n'.format(use_gold_segmentation))
    for serialized_sentence in output:
      sentence = sentence_pb2.Sentence()
      sentence.ParseFromString(serialized_sentence)
      f.write('# text = {}\n'.format(sentence.text.encode('utf-8')))
      for i, token in enumerate(sentence.token):
        head = token.head + 1
        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                (i + 1, token.word.encode('utf-8'), head,
                 token.label.encode('utf-8')))
      f.write('\n')
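# Hedged usage sketch (not from the original source): writes annotated
# sentences to a CoNLL-style file via print_output. Assumes 'processed' is a
# list of serialized Sentence protos produced by one of the pipelines in this
# collection; the output path and helper name are hypothetical.
def _example_write_conll(processed):
  print_output('/tmp/parsed.conll',
               use_text_format=False,
               use_gold_segmentation=False,
               output=processed)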
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    src = gen_parser_ops.document_source(batch_size=32,
                                         corpus_name=FLAGS.corpus_name,
                                         task_context=FLAGS.task_context)
    sentence = sentence_pb2.Sentence()
    while True:
      documents, finished = sess.run(src)
      logging.info('Read %d documents', len(documents))
      for d in documents:
        sentence.ParseFromString(d)
        tr = asciitree.LeftAligned()
        d = to_dict(sentence)
        print 'Input: %s' % sentence.text
        print 'Parse:'
        tr_str = tr(d)
        pat = re.compile(r'\s*@\d+$')
        for tr_ln in tr_str.splitlines():
          print pat.sub('', tr_ln)
      if finished:
        break
def inference(sess, graph, builder, annotator, text, enable_tracing=False):
  tokens = [
      sentence_pb2.Token(word=word, start=-1, end=-1)
      for word in text.split()
  ]
  sentence = sentence_pb2.Sentence()
  sentence.token.extend(tokens)

  if enable_tracing:
    annotations, traces = sess.run(
        [annotator['annotations'], annotator['traces']],
        feed_dict={annotator['input_batch']: [sentence.SerializeToString()]})
    # HTML(visualization.trace_html(traces[0]))
  else:
    annotations = sess.run(
        annotator['annotations'],
        feed_dict={annotator['input_batch']: [sentence.SerializeToString()]})

  parsed_sentence = sentence_pb2.Sentence.FromString(annotations[0])
  # HTML(render_parse_tree_graphviz.parse_tree_graph(parsed_sentence))
  return parsed_sentence
def _parse(self):
  with self.graph.as_default():
    num_epochs = None
    num_docs = 0
    result = []
    while True:
      tf_epochs, _, tf_documents = self.sess.run([
          self.parser.evaluation['epochs'],
          self.parser.evaluation['eval_metrics'],
          self.parser.evaluation['documents']
      ])
      # print len(tf_documents)
      # assert len(tf_documents) == 1
      # print type(tf_documents[len(tf_documents)-1])
      if len(tf_documents) > 0:
        for item in tf_documents:
          doc = sentence_pb2.Sentence()
          doc.ParseFromString(item)
          result.append(ConvertToString(doc))
      if num_epochs is None:
        num_epochs = tf_epochs
      elif num_epochs < tf_epochs:
        break
    return result
def RunFullTrainingAndInference(self,
                                test_name,
                                master_spec_path=None,
                                master_spec=None,
                                component_weights=None,
                                unroll_using_oracle=None,
                                num_evaluated_components=1,
                                expected_num_actions=None,
                                expected=None,
                                batch_size_limit=None):
  if not master_spec:
    master_spec = self.LoadSpec(master_spec_path)

  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  gold_reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]

  test_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
  test_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
  test_reader_strings = [
      test_doc.SerializeToString(), test_doc.SerializeToString(),
      test_doc_2.SerializeToString(), test_doc.SerializeToString()
  ]

  if batch_size_limit is not None:
    gold_reader_strings = gold_reader_strings[:batch_size_limit]
    test_reader_strings = test_reader_strings[:batch_size_limit]

  with tf.Graph().as_default():
    tf.set_random_seed(1)
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    target = spec_pb2.TrainTarget()
    target.name = 'testFullInference-train-%s' % test_name
    if component_weights:
      target.component_weights.extend(component_weights)
    else:
      target.component_weights.extend([0] * len(master_spec.component))
      target.component_weights[-1] = 1.0
    if unroll_using_oracle:
      target.unroll_using_oracle.extend(unroll_using_oracle)
    else:
      target.unroll_using_oracle.extend([False] * len(master_spec.component))
      target.unroll_using_oracle[-1] = True
    train = builder.add_training_from_config(target)
    oracle_trace = builder.add_training_from_config(
        target, prefix='train_traced-', trace_only=True)
    builder.add_saver()
    anno = builder.add_annotation(test_name)
    trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

    # Verifies that the summaries can be built.
    for component in builder.components:
      component.get_summaries()

    config = tf.ConfigProto(
        intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
    with self.test_session(config=config) as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      logging.info('Dry run oracle trace...')
      traces = sess.run(
          oracle_trace['traces'],
          feed_dict={oracle_trace['input_batch']: gold_reader_strings})

      # Check that the oracle traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      logging.info('Simulating training...')
      break_iter = 400
      is_resolved = False
      # needs ~100 iterations, but is not deterministic
      for i in range(0, 400):
        cost, eval_res_val = sess.run(
            [train['cost'], train['metrics']],
            feed_dict={train['input_batch']: gold_reader_strings})
        logging.info('cost = %s', cost)
        self.assertFalse(np.isnan(cost))
        total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
        correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
        if correct_val == total_val and not is_resolved:
          logging.info('... converged on iteration %d with (correct, total) '
                       '= (%d, %d)', i, correct_val, total_val)
          is_resolved = True
          # Run for slightly longer than convergence to help with quantized
          # weight tiebreakers.
          break_iter = i + 50
        if i == break_iter:
          break

      # If training failed, report total/correct actions for each component.
      if not expected_num_actions:
        expected_num_actions = 4 * num_evaluated_components
      if (correct_val != total_val or correct_val != expected_num_actions or
          total_val != expected_num_actions):
        for c in xrange(len(master_spec.component)):
          logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                        master_spec.component[c].name, eval_res_val[2 * c],
                        eval_res_val[2 * c + 1])

      assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
          correct_val, total_val)

      self.assertEqual(expected_num_actions, correct_val)
      self.assertEqual(expected_num_actions, total_val)

      builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

      logging.info('Running test.')
      logging.info('Printing annotations')
      annotations = sess.run(
          anno['annotations'],
          feed_dict={anno['input_batch']: test_reader_strings})
      logging.info('Put %d inputs in, got %d annotations out.',
                   len(test_reader_strings), len(annotations))

      # Also run the annotation graph with tracing enabled.
      annotations_with_trace, traces = sess.run(
          [trace['annotations'], trace['traces']],
          feed_dict={trace['input_batch']: test_reader_strings})

      # The result of the two annotation graphs should be identical.
      self.assertItemsEqual(annotations, annotations_with_trace)

      # Check that the inference traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      self.assertEqual(len(test_reader_strings), len(annotations))
      pred_sentences = []
      for annotation in annotations:
        pred_sentences.append(sentence_pb2.Sentence())
        pred_sentences[-1].ParseFromString(annotation)

      if expected is None:
        expected = _TAGGER_EXPECTED_SENTENCES

      expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

      for i, pred_sentence in enumerate(pred_sentences):
        self.assertProtoEquals(expected_sentences[i], pred_sentence)
def main(argv):
  del argv  # unused

  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(
      lexicon_dir, training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px',
            master_spec=master_spec).encode('utf-8'))
def add_sentence_for_segment_eval(starts, ends, corpus):
  """Adds a sentence to the corpus."""
  sentence = sentence_pb2.Sentence()
  for start, end in zip(starts, ends):
    sentence.token.add(word='x', start=start, end=end)
  corpus.append(sentence.SerializeToString())
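# Hedged usage sketch (not from the original source): builds a two-sentence
# corpus of byte-offset spans for segmentation evaluation. The helper name is
# hypothetical.
def _example_segment_eval_corpus():
  corpus = []
  add_sentence_for_segment_eval([0, 2], [0, 3], corpus)  # spans (0,0) and (2,3)
  add_sentence_for_segment_eval([0], [4], corpus)        # one span (0,4)
  return corpus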
def main(unused_argv):
  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  ## SEGMENTATION ##

  if not FLAGS.use_gold_segmentation:
    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
      text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
      spec_builder.complete_master_spec(master_spec, None,
                                        FLAGS.segmenter_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
      hyperparam_config = spec_pb2.GridPoint()
      hyperparam_config.use_moving_average = True
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation()
      builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
    with tf.Session(graph=tf.Graph()) as tmp_session:
      char_input = gen_parser_ops.char_token_generator(input_corpus)
      char_corpus = tmp_session.run(char_input)
    check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
      tf.logging.info('Initializing variables...')
      sess.run(tf.global_variables_initializer())
      tf.logging.info('Loading from checkpoint...')
      sess.run('save/restore_all',
               {'save/Const:0': FLAGS.segmenter_checkpoint_file})

      tf.logging.info('Processing sentences...')

      processed = []
      start_time = time.time()
      run_metadata = tf.RunMetadata()
      for start in range(0, len(char_corpus), FLAGS.max_batch_size):
        end = min(start + FLAGS.max_batch_size, len(char_corpus))
        feed_dict = {annotator['input_batch']: char_corpus[start:end]}
        if FLAGS.timeline_output_file and end == len(char_corpus):
          serialized_annotations = sess.run(
              annotator['annotations'],
              feed_dict=feed_dict,
              options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
              run_metadata=run_metadata)
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with open(FLAGS.timeline_output_file, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict)
        processed.extend(serialized_annotations)

      tf.logging.info('Processed %d documents in %.2f seconds.',
                      len(char_corpus), time.time() - start_time)
    input_corpus = processed
  else:
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  ## PARSING

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None,
                                      FLAGS.parser_resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.parser_checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'],
            feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)

  if FLAGS.output_file:
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      for serialized_sentence in processed:
        sentence = sentence_pb2.Sentence()
        sentence.ParseFromString(serialized_sentence)
        f.write('#' + sentence.text.encode('utf-8') + '\n')
        for i, token in enumerate(sentence.token):
          head = token.head + 1
          f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                  (i + 1, token.word.encode('utf-8'), head,
                   token.label.encode('utf-8')))
        f.write('\n\n')
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  # Rewrite resource locations.
  if FLAGS.resource_dir:
    for component in master_spec.component:
      for resource in component.resource:
        for part in resource.part:
          part.file_pattern = os.path.join(FLAGS.resource_dir,
                                           part.file_pattern)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')
  input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'],
            feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)

  pos, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)
  if FLAGS.log_file:
    with gfile.GFile(FLAGS.log_file, 'w') as f:
      f.write('%s\t%f\t%f\t%f\n' % (FLAGS.language_name, pos, uas, las))

  if FLAGS.output_file:
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      for serialized_sentence in processed:
        sentence = sentence_pb2.Sentence()
        sentence.ParseFromString(serialized_sentence)
        f.write(text_format.MessageToString(sentence) + '\n\n')