def _load_model(self, base_dir, master_spec_name):
  master_spec = spec_pb2.MasterSpec()
  with open(os.path.join(base_dir, master_spec_name)) as f:
    text_format.Merge(f.read(), master_spec)
  spec_builder.complete_master_spec(master_spec, None, base_dir)

  graph = tf.Graph()
  with graph.as_default():
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation(enable_tracing=True)
    builder.add_saver()

  sess = tf.Session(graph=graph)
  with graph.as_default():
    builder.saver.restore(sess, os.path.join(base_dir, "checkpoint"))

  def annotate_sentence(sentence):
    with graph.as_default():
      return sess.run(
          [annotator['annotations'], annotator['traces']],
          feed_dict={annotator['input_batch']: [sentence]})

  return annotate_sentence
def RunTraining(self, hyperparam_config):
  master_spec = self.LoadSpec('master_spec_link.textproto')

  self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]
  tf.logging.info('Generating graph with config: %s', hyperparam_config)
  with tf.Graph().as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'testTraining-all'
    train = builder.add_training_from_config(target)
    with self.test_session() as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      # Run one iteration of training and verify nothing crashes.
      logging.info('Training')
      sess.run(train['run'], feed_dict={train['input_batch']: reader_strings})
def load_model(self, base_dir, master_spec_name, checkpoint_name="checkpoint",
               rename=True):
  try:
    master_spec = spec_pb2.MasterSpec()
    with open(os.path.join(base_dir, master_spec_name)) as f:
      text_format.Merge(f.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, base_dir)

    graph = tf.Graph()
    with graph.as_default():
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation(enable_tracing=True)
      builder.add_saver()

    sess = tf.Session(graph=graph)
    with graph.as_default():
      builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

    def annotate_sentence(sentence):
      with graph.as_default():
        return sess.run(
            [annotator['annotations'], annotator['traces']],
            feed_dict={annotator['input_batch']: [sentence]})
  except Exception:
    if rename:
      self.rename_vars(base_dir, checkpoint_name)
      return self.load_model(base_dir, master_spec_name, checkpoint_name,
                             False)
    raise Exception(
        'Cannot load model: spec expects references to */kernel tensors '
        'instead of */weights. Try running with rename=True or run '
        'rename_vars() to convert existing checkpoint files into the '
        'supported format.')
  return annotate_sentence
def build_train_graph(master_spec, hyperparam_config=None):
  # Build the TensorFlow graph based on the DRAGNN network spec.
  tf.logging.info('Building Graph...')
  if not hyperparam_config:
    hyperparam_config = spec_pb2.GridPoint(
        learning_method='adam',
        learning_rate=0.0005,
        adam_beta1=0.9,
        adam_beta2=0.9,
        adam_eps=0.00001,
        decay_steps=128000,
        dropout_rate=0.8,
        gradient_clip_norm=1,
        use_moving_average=True,
        seed=1)
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if 'shift-only' not in component.transition_system.registered_name
    ]
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    annotator = builder.add_annotation(enable_tracing=True)
    builder.add_saver()
  return graph, builder, trainers, annotator
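# Usage sketch for build_train_graph (illustrative, not part of the original
# code): assumes `master_spec` is a completed MasterSpec and `training_batch`
# is a list of serialized syntaxnet Sentence protos; the trainer dict keys
# ('cost', 'input_batch') follow their use elsewhere in this section.
def _example_train_one_batch(master_spec, training_batch):
  import tensorflow as tf

  graph, builder, trainers, annotator = build_train_graph(master_spec)
  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # One step on the first non-'shift-only' training target.
    cost = sess.run(trainers[0]['cost'],
                    feed_dict={trainers[0]['input_batch']: training_batch})
  return cost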
def load_model(self, base_dir, master_spec_name, checkpoint_name):
  # Read the master spec.
  master_spec = spec_pb2.MasterSpec()
  with open(os.path.join(base_dir, master_spec_name), "r") as f:
    text_format.Merge(f.read(), master_spec)
  spec_builder.complete_master_spec(master_spec, None, base_dir)
  logging.set_verbosity(logging.WARN)  # Turn off TensorFlow spam.

  # Initialize a graph.
  graph = tf.Graph()
  with graph.as_default():
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    # This is the component that will annotate test sentences.
    annotator = builder.add_annotation(enable_tracing=True)
    # "Savers" can save and load models; here, we're only going to load.
    builder.add_saver()

  sess = tf.Session(graph=graph)
  with graph.as_default():
    # sess.run(tf.global_variables_initializer())
    # sess.run('save/restore_all',
    #          {'save/Const:0': os.path.join(base_dir, checkpoint_name)})
    builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

  def annotate_sentence(sentence):
    with graph.as_default():
      return sess.run(
          [annotator['annotations'], annotator['traces']],
          feed_dict={annotator['input_batch']: [sentence]})

  return annotate_sentence
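# Usage sketch (illustrative, not from the original code): pass one serialized
# Sentence proto through the annotator returned by load_model and parse the
# first annotation back into a Sentence, mirroring the test code later in this
# section. The syntaxnet import path is an assumption here.
def _example_annotate(annotate_sentence, serialized_sentence):
  from syntaxnet import sentence_pb2

  annotations, traces = annotate_sentence(serialized_sentence)
  parsed = sentence_pb2.Sentence()
  parsed.ParseFromString(annotations[0])
  return parsed, traces[0]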
def build_inference_graph(master_spec, enable_tracing=False):
  # Initialize a graph.
  tf.logging.info('Building Graph...')
  graph = tf.Graph()
  with graph.as_default():
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    # This is the component that will annotate test sentences.
    annotator = builder.add_annotation(enable_tracing=enable_tracing)
    builder.add_saver()
  return graph, builder, annotator
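# Usage sketch for build_inference_graph (illustrative, not part of the
# original code): assumes a trained checkpoint compatible with the spec and a
# batch of serialized Sentence protos; the saver and annotator keys follow
# their use in the loaders above.
def _example_annotate_batch(master_spec, checkpoint_path, serialized_batch):
  import tensorflow as tf

  graph, builder, annotator = build_inference_graph(master_spec)
  with tf.Session(graph=graph) as sess:
    with graph.as_default():
      builder.saver.restore(sess, checkpoint_path)
    return sess.run(annotator['annotations'],
                    feed_dict={annotator['input_batch']: serialized_batch})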
def getBuilderAndTarget(
    self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
  """Generates a MasterBuilder and TrainTarget based on a simple spec."""
  master_spec = self.LoadSpec(master_spec_path)
  hyperparam_config = spec_pb2.GridPoint()
  target = spec_pb2.TrainTarget()
  target.name = 'test-%s-train' % test_name
  target.component_weights.extend([0] * len(master_spec.component))
  target.component_weights[-1] = 1.0
  target.unroll_using_oracle.extend([False] * len(master_spec.component))
  target.unroll_using_oracle[-1] = True
  builder = graph_builder.MasterBuilder(
      master_spec, hyperparam_config, pool_scope=test_name)
  return builder, target
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_1.html', 'w') as f:
    f.write(
        visualization.trace_html(traces[0], height='300px').encode('utf-8'))
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations
  # of phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims

  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px',
            master_spec=master_spec).encode('utf-8'))
def main(unused_argv):
  logging.set_verbosity(logging.INFO)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(FLAGS.resource_path, FLAGS.training_corpus_path)

  # Construct the "lookahead" ComponentSpec. This is a simple right-to-left
  # RNN sequence model, which encodes the context to the right of each token.
  # It has no loss except for the downstream components.
  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
  char2word.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
  char2word.add_fixed_feature(
      name='chars', fml='char-input.text-char', embedding_dim=16)
  char2word.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_link(
      source=char2word, fml='input.last-char-focus', embedding_dim=32)
  lookahead.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  tagger.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32)

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations
  # of phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims

  parser.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend(
      [char2word.spec, lookahead.spec, tagger.spec, parser.spec])
  logging.info('Constructed master spec: %s', str(master_spec))

  hyperparam_config = spec_pb2.GridPoint()
  hyperparam_config.decay_steps = 128000
  hyperparam_config.learning_rate = 0.001
  hyperparam_config.learning_method = 'adam'
  hyperparam_config.adam_beta1 = 0.9
  hyperparam_config.adam_beta2 = 0.9
  hyperparam_config.adam_eps = 0.0001
  hyperparam_config.gradient_clip_norm = 1
  hyperparam_config.self_norm_alpha = 1.0
  hyperparam_config.use_moving_average = True
  hyperparam_config.dropout_rate = 0.7
  hyperparam_config.seed = 1

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = spec_builder.default_targets_from_spec(master_spec)
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    assert len(trainers) == 2
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = sentence_io.ConllSentenceReader(
      FLAGS.training_corpus_path,
      projectivize=FLAGS.projectivize_training_set).corpus()
  dev_set = sentence_io.ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  pretrain_steps = [100, 0]
  tagger_steps = 1000
  train_steps = [tagger_steps, 8 * tagger_steps]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    trainer_lib.run_training(sess, trainers, annotator,
                             evaluation.parser_summaries, pretrain_steps,
                             train_steps, training_set, dev_set, dev_set,
                             FLAGS.batch_size, summary_writer,
                             FLAGS.report_every, builder.saver,
                             FLAGS.checkpoint_filename)
def RunFullTrainingAndInference(self,
                                test_name,
                                master_spec_path=None,
                                master_spec=None,
                                component_weights=None,
                                unroll_using_oracle=None,
                                num_evaluated_components=1,
                                expected_num_actions=None,
                                expected=None,
                                batch_size_limit=None):
  if not master_spec:
    master_spec = self.LoadSpec(master_spec_path)

  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  gold_reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]

  test_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
  test_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
  test_reader_strings = [
      test_doc.SerializeToString(), test_doc.SerializeToString(),
      test_doc_2.SerializeToString(), test_doc.SerializeToString()
  ]

  if batch_size_limit is not None:
    gold_reader_strings = gold_reader_strings[:batch_size_limit]
    test_reader_strings = test_reader_strings[:batch_size_limit]

  with tf.Graph().as_default():
    tf.set_random_seed(1)
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    target = spec_pb2.TrainTarget()
    target.name = 'testFullInference-train-%s' % test_name
    if component_weights:
      target.component_weights.extend(component_weights)
    else:
      target.component_weights.extend([0] * len(master_spec.component))
      target.component_weights[-1] = 1.0
    if unroll_using_oracle:
      target.unroll_using_oracle.extend(unroll_using_oracle)
    else:
      target.unroll_using_oracle.extend([False] * len(master_spec.component))
      target.unroll_using_oracle[-1] = True
    train = builder.add_training_from_config(target)
    oracle_trace = builder.add_training_from_config(
        target, prefix='train_traced-', trace_only=True)
    builder.add_saver()
    anno = builder.add_annotation(test_name)
    trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

    # Verifies that the summaries can be built.
    for component in builder.components:
      component.get_summaries()

    config = tf.ConfigProto(
        intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
    with self.test_session(config=config) as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      logging.info('Dry run oracle trace...')
      traces = sess.run(
          oracle_trace['traces'],
          feed_dict={oracle_trace['input_batch']: gold_reader_strings})

      # Check that the oracle traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      logging.info('Simulating training...')
      break_iter = 400
      is_resolved = False
      # Needs ~100 iterations, but is not deterministic.
      for i in range(0, 400):
        cost, eval_res_val = sess.run(
            [train['cost'], train['metrics']],
            feed_dict={train['input_batch']: gold_reader_strings})
        logging.info('cost = %s', cost)
        self.assertFalse(np.isnan(cost))
        total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
        correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
        if correct_val == total_val and not is_resolved:
          logging.info('... converged on iteration %d with (correct, total) '
                       '= (%d, %d)', i, correct_val, total_val)
          is_resolved = True
          # Run for slightly longer than convergence to help with quantized
          # weight tiebreakers.
          break_iter = i + 50
        if i == break_iter:
          break

      # If training failed, report total/correct actions for each component.
      if not expected_num_actions:
        expected_num_actions = 4 * num_evaluated_components
      if (correct_val != total_val or correct_val != expected_num_actions or
          total_val != expected_num_actions):
        for c in xrange(len(master_spec.component)):
          logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                        master_spec.component[c].name, eval_res_val[2 * c],
                        eval_res_val[2 * c + 1])

      assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
          correct_val, total_val)

      self.assertEqual(expected_num_actions, correct_val)
      self.assertEqual(expected_num_actions, total_val)

      builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

      logging.info('Running test.')
      logging.info('Printing annotations')
      annotations = sess.run(
          anno['annotations'],
          feed_dict={anno['input_batch']: test_reader_strings})
      logging.info('Put %d inputs in, got %d annotations out.',
                   len(test_reader_strings), len(annotations))

      # Also run the annotation graph with tracing enabled.
      annotations_with_trace, traces = sess.run(
          [trace['annotations'], trace['traces']],
          feed_dict={trace['input_batch']: test_reader_strings})

      # The result of the two annotation graphs should be identical.
      self.assertItemsEqual(annotations, annotations_with_trace)

      # Check that the inference traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      self.assertEqual(len(test_reader_strings), len(annotations))
      pred_sentences = []
      for annotation in annotations:
        pred_sentences.append(sentence_pb2.Sentence())
        pred_sentences[-1].ParseFromString(annotation)

      if expected is None:
        expected = _TAGGER_EXPECTED_SENTENCES

      expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

      for i, pred_sentence in enumerate(pred_sentences):
        self.assertProtoEquals(expected_sentences[i], pred_sentence)
def main(unused_argv):
  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  ## SEGMENTATION ##

  if not FLAGS.use_gold_segmentation:
    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
      text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
      spec_builder.complete_master_spec(master_spec, None,
                                        FLAGS.segmenter_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
      hyperparam_config = spec_pb2.GridPoint()
      hyperparam_config.use_moving_average = True
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation()
      builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
    with tf.Session(graph=tf.Graph()) as tmp_session:
      char_input = gen_parser_ops.char_token_generator(input_corpus)
      char_corpus = tmp_session.run(char_input)
      check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
      tf.logging.info('Initializing variables...')
      sess.run(tf.global_variables_initializer())
      tf.logging.info('Loading from checkpoint...')
      sess.run('save/restore_all',
               {'save/Const:0': FLAGS.segmenter_checkpoint_file})

      tf.logging.info('Processing sentences...')

      processed = []
      start_time = time.time()
      run_metadata = tf.RunMetadata()
      for start in range(0, len(char_corpus), FLAGS.max_batch_size):
        end = min(start + FLAGS.max_batch_size, len(char_corpus))
        feed_dict = {annotator['input_batch']: char_corpus[start:end]}
        if FLAGS.timeline_output_file and end == len(char_corpus):
          serialized_annotations = sess.run(
              annotator['annotations'],
              feed_dict=feed_dict,
              options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
              run_metadata=run_metadata)
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with open(FLAGS.timeline_output_file, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict)
        processed.extend(serialized_annotations)

      tf.logging.info('Processed %d documents in %.2f seconds.',
                      len(char_corpus), time.time() - start_time)
    input_corpus = processed
  else:
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  ## PARSING ##

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None,
                                      FLAGS.parser_resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all',
             {'save/Const:0': FLAGS.parser_checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'],
            feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)

  if FLAGS.output_file:
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      for serialized_sentence in processed:
        sentence = sentence_pb2.Sentence()
        sentence.ParseFromString(serialized_sentence)
        f.write('#' + sentence.text.encode('utf-8') + '\n')
        for i, token in enumerate(sentence.token):
          head = token.head + 1
          f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                  (i + 1, token.word.encode('utf-8'), head,
                   token.label.encode('utf-8')))
        f.write('\n\n')
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  # Rewrite resource locations.
  if FLAGS.resource_dir:
    for component in master_spec.component:
      for resource in component.resource:
        for part in resource.part:
          part.file_pattern = os.path.join(FLAGS.resource_dir,
                                           part.file_pattern)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')
  input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'],
            feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)

    pos, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)
    if FLAGS.log_file:
      with gfile.GFile(FLAGS.log_file, 'w') as f:
        f.write('%s\t%f\t%f\t%f\n' % (FLAGS.language_name, pos, uas, las))

    if FLAGS.output_file:
      with gfile.GFile(FLAGS.output_file, 'w') as f:
        for serialized_sentence in processed:
          sentence = sentence_pb2.Sentence()
          sentence.ParseFromString(serialized_sentence)
          f.write(text_format.MessageToString(sentence) + '\n\n')
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  check.NotNone(FLAGS.model_dir, '--model_dir is required')
  check.Ne(FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
           'Exactly one of --pretrain_steps or --pretrain_epochs is required')
  check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
           'Exactly one of --train_steps or --train_epochs is required')

  config_path = os.path.join(FLAGS.model_dir, 'config.txt')
  master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
  hyperparameters_path = os.path.join(FLAGS.model_dir,
                                      'hyperparameters.pbtxt')
  targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
  checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
  tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')

  with tf.gfile.FastGFile(config_path) as config_file:
    config = collections.defaultdict(bool,
                                     ast.literal_eval(config_file.read()))
  train_corpus_path = config['train_corpus_path']
  tune_corpus_path = config['tune_corpus_path']
  projectivize_train_corpus = config['projectivize_train_corpus']

  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)
  targets = spec_builder.default_targets_from_spec(master)
  if tf.gfile.Exists(targets_path):
    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(hyperparameters.seed)
    builder = graph_builder.MasterBuilder(master, hyperparameters)
    trainers = [
        builder.add_training_from_config(target) for target in targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  train_corpus = sentence_io.ConllSentenceReader(
      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
  tune_corpus = sentence_io.ConllSentenceReader(
      tune_corpus_path, projectivize=False).corpus()
  gold_tune_corpus = tune_corpus

  # Convert to char-based corpora, if requested.
  if config['convert_to_char_corpora']:
    # NB: Do not convert the |gold_tune_corpus|, which should remain
    # word-based for segmentation evaluation purposes.
    train_corpus = _convert_to_char_corpus(train_corpus)
    tune_corpus = _convert_to_char_corpus(tune_corpus)

  pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                              len(train_corpus))
  train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                           len(train_corpus))
  check.Eq(len(targets), len(pretrain_steps),
           'Length mismatch between training targets and --pretrain_steps')
  check.Eq(len(targets), len(train_steps),
           'Length mismatch between training targets and --train_steps')

  # Ready to train!
  tf.logging.info('Training on %d sentences.', len(train_corpus))
  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

  checkpoint_dir = os.path.dirname(checkpoint_path)
  if tf.gfile.IsDirectory(checkpoint_dir):
    tf.gfile.DeleteRecursively(checkpoint_dir)
  elif tf.gfile.Exists(checkpoint_dir):
    tf.gfile.Remove(checkpoint_dir)
  tf.gfile.MakeDirs(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    trainer_lib.run_training(sess, trainers, annotator,
                             evaluation.parser_summaries, pretrain_steps,
                             train_steps, train_corpus, tune_corpus,
                             gold_tune_corpus, FLAGS.batch_size,
                             summary_writer, FLAGS.report_every,
                             builder.saver, checkpoint_path)

  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Read hyperparams and master spec.
  hyperparam_config = spec_pb2.GridPoint()
  text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  print hyperparam_config

  master_spec = spec_pb2.MasterSpec()
  with gfile.GFile(FLAGS.master_spec, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)

  # Make output folder.
  if not gfile.Exists(FLAGS.output_folder):
    gfile.MakeDirs(FLAGS.output_folder)

  # Construct TF Graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    # Construct default per-component targets.
    default_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if (component.transition_system.registered_name != 'shift-only')
    ]

    # Add default and manually specified targets.
    trainers = []
    for target in default_targets:
      trainers += [builder.add_training_from_config(target)]
    check.Eq(len(trainers), 1, "Expected only one training target (FF unit)")

    # Construct annotation and saver. Will use moving average if enabled.
    annotator = builder.add_annotation()
    builder.add_saver()

    # Add backwards compatible training summary.
    summaries = []
    for component in builder.components:
      summaries += component.get_summaries()
    merged_summaries = tf.summary.merge_all()

    # Construct target to initialize variables.
    tf.group(tf.global_variables_initializer(), name='inits')

  # Prepare tensorboard dir.
  events_dir = os.path.join(FLAGS.output_folder, "tensorboard")
  empty_dir(events_dir)
  summary_writer = tf.summary.FileWriter(events_dir, graph)
  print "Wrote events (incl. graph) for Tensorboard to folder:", events_dir
  print "The graph can be viewed via"
  print "  tensorboard --logdir=" + events_dir
  print "  then navigating to http://localhost:6006 and clicking on 'GRAPHS'"

  with graph.as_default():
    tf.set_random_seed(hyperparam_config.seed)

  # Read train and dev corpora.
  print "Reading corpora..."
  train_corpus = read_corpus(FLAGS.train_corpus)
  dev_corpus = read_corpus(FLAGS.dev_corpus)

  # Prepare checkpoint folder.
  checkpoint_path = os.path.join(FLAGS.output_folder, 'checkpoints/best')
  checkpoint_dir = os.path.dirname(checkpoint_path)
  empty_dir(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())

    # Run training.
    trainer_lib.run_training(
        sess, trainers, annotator, evaluator,
        [0],  # pretrain_steps
        [FLAGS.train_steps],
        train_corpus, dev_corpus, dev_corpus, FLAGS.batch_size,
        summary_writer, FLAGS.report_every, builder.saver, checkpoint_path)

    # Convert model to a Myelin flow.
    if len(FLAGS.flow) != 0:
      tf.logging.info('Saving flow to %s', FLAGS.flow)
      flow = convert_model(master_spec, sess)
      flow.save(FLAGS.flow)

  tf.logging.info('Best checkpoint written to %s', checkpoint_path)
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  master_spec = spec_pb2.MasterSpec()
  master_spec_file = FLAGS.parser_dir + "/master_spec"
  with open(master_spec_file, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)

  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())

    tf.logging.info('Loading from checkpoint...')
    checkpoint = FLAGS.parser_dir + "/checkpoints/best"
    sess.run('save/restore_all', {'save/Const:0': checkpoint})

    # Annotate the corpus.
    corpus = read_corpus(FLAGS.corpus)
    annotated = []
    annotation_time = 0
    for start in range(0, len(corpus), FLAGS.batch_size):
      end = min(start + FLAGS.batch_size, len(corpus))
      feed_dict = {annotator['input_batch']: corpus[start:end]}
      start_time = timeit.default_timer()
      output = sess.run(annotator['annotations'], feed_dict=feed_dict)
      annotation_time += (timeit.default_timer() - start_time)
      annotated.extend(output)

    tf.logging.info("Wall clock time for %s annotation: %f seconds",
                    len(annotated), annotation_time)

    output_file = FLAGS.output
    if FLAGS.evaluate and len(FLAGS.output) == 0:
      output_file = "/tmp/annotated.zip"
      tf.logging.info('--output not provided, will write annotated docs to %s',
                      output_file)

    if FLAGS.evaluate or len(FLAGS.output) != 0:
      # Write the annotated corpus to disk as a zip file.
      with zipfile.ZipFile(output_file, 'w') as outfile:
        for i in xrange(len(annotated)):
          outfile.writestr('test.' + str(i), annotated[i])
      tf.logging.info('Wrote %d annotated docs to %s', len(annotated),
                      output_file)

    if FLAGS.evaluate:
      # Evaluate against gold annotations.
      try:
        eval_output_lines = subprocess.check_output(
            [
                'bazel-bin/nlp/parser/tools/evaluate-frames',
                '--gold_documents=' + FLAGS.corpus,
                '--test_documents=' + output_file,
                '--commons=' + FLAGS.commons
            ],
            stderr=subprocess.STDOUT)

        eval_output = {}
        eval_metric = -1
        for line in eval_output_lines.splitlines():
          line = line.rstrip()
          tf.logging.info("Evaluation Metric: %s", line)
          parts = line.split('\t')
          assert len(parts) == 2, line
          eval_output[parts[0]] = float(parts[1])
          if line.startswith("SLOT_F1"):
            eval_metric = float(parts[1])
        assert eval_metric != -1, "Missing SLOT F1"
        tf.logging.info('Overall Evaluation Metric: %f', eval_metric)
      except subprocess.CalledProcessError as e:
        print("Evaluation failed: ", e.returncode, e.output)
def main(unused_argv):
  logging.set_verbosity(logging.INFO)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(FLAGS.resource_path, FLAGS.training_corpus_path)

  # Construct the "lookahead" ComponentSpec. This is a simple right-to-left
  # RNN sequence model, which encodes the context to the right of each token.
  # It has no loss except for the downstream components.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_fixed_feature(
      name='char',
      fml='input(-1).char input.char input(1).char',
      embedding_dim=32)
  lookahead.add_fixed_feature(
      name='char-bigram', fml='input.char-bigram', embedding_dim=32)
  lookahead.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the ComponentSpec for segmentation.
  segmenter = spec_builder.ComponentSpecBuilder('segmenter')
  segmenter.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='128')
  segmenter.set_transition_system(name='binary-segment-transitions')
  segmenter.add_token_link(
      source=lookahead, fml='input.focus stack.focus', embedding_dim=64)
  segmenter.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Build and write master_spec.
  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([lookahead.spec, segmenter.spec])
  logging.info('Constructed master spec: %s', str(master_spec))
  with gfile.GFile(FLAGS.resource_path + '/master_spec', 'w') as f:
    f.write(str(master_spec).encode('utf-8'))

  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = spec_builder.default_targets_from_spec(master_spec)
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    assert len(trainers) == 1
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = ConllSentenceReader(
      FLAGS.training_corpus_path, projectivize=False).corpus()
  dev_set = ConllSentenceReader(FLAGS.dev_corpus_path,
                                projectivize=False).corpus()

  # Convert word-based docs to char-based documents for segmentation training
  # and evaluation.
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_training_set_op = gen_parser_ops.segmenter_training_data_constructor(
        training_set)
    char_dev_set_op = gen_parser_ops.char_token_generator(dev_set)
    char_training_set = tmp_session.run(char_training_set_op)
    char_dev_set = tmp_session.run(char_dev_set_op)

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  pretrain_steps = [0]
  train_steps = [FLAGS.num_epochs * len(training_set)]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    trainer_lib.run_training(
        sess, trainers, annotator, evaluation.segmentation_summaries,
        pretrain_steps, train_steps, char_training_set, char_dev_set, dev_set,
        FLAGS.batch_size, summary_writer, FLAGS.report_every, builder.saver,
        FLAGS.checkpoint_filename)
def main(unused_argv):
  logging.set_verbosity(logging.INFO)

  check.IsTrue(FLAGS.checkpoint_filename)
  check.IsTrue(FLAGS.tensorboard_dir)
  check.IsTrue(FLAGS.resource_path)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
  tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

  # SummaryWriter for TensorBoard.
  tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
  tf.logging.info('Deleting prior data if exists...')

  stats_file = '%s.stats' % FLAGS.checkpoint_filename
  try:
    stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
    stats = [int(x) for x in stats]
  except errors.OpError:
    stats = [-1, 0, 0]
  tf.logging.info('Read ckpt stats: %s', str(stats))

  do_restore = True
  if stats[0] < FLAGS.job_id:
    do_restore = False
    tf.logging.info('Deleting last job: %d', stats[0])
    try:
      gfile.DeleteRecursively(FLAGS.tensorboard_dir)
      gfile.Remove(FLAGS.checkpoint_filename)
    except errors.OpError as err:
      tf.logging.error('Unable to delete prior files: %s', err)
    stats = [FLAGS.job_id, 0, 0]

  tf.logging.info('Creating the directory again...')
  gfile.MakeDirs(FLAGS.tensorboard_dir)
  tf.logging.info('Created! Instantiating SummaryWriter...')
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(
        FLAGS.resource_path, training_corpus_path, morph_to_pos=True)

  tf.logging.info('Loading MasterSpec...')
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)
  spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
  logging.info('Constructed master spec: %s', str(master_spec))

  # Build the TensorFlow graph.
  tf.logging.info('Building Graph...')
  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)

  g = tf.Graph()
  with g.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if 'shift-only' not in component.transition_system.registered_name
    ]
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = ConllSentenceReader(
      training_corpus_path,
      projectivize=FLAGS.projectivize_training_set,
      morph_to_pos=True).corpus()
  tune_set = ConllSentenceReader(
      tune_corpus_path, projectivize=False, morph_to_pos=True).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(tune_set))

  pretrain_steps = [10000, 0]
  tagger_steps = 100000
  train_steps = [tagger_steps, 8 * tagger_steps]

  with tf.Session(FLAGS.tf_master, graph=g) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())

    if do_restore:
      tf.logging.info('Restoring from checkpoint...')
      builder.saver.restore(sess, FLAGS.checkpoint_filename)

      prev_tagger_steps = stats[1]
      prev_parser_steps = stats[2]
      tf.logging.info('adjusting schedule from steps: %d, %d',
                      prev_tagger_steps, prev_parser_steps)
      pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
      tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

    trainer_lib.run_training(sess, trainers, annotator,
                             evaluation.parser_summaries, pretrain_steps,
                             train_steps, training_set, tune_set, tune_set,
                             FLAGS.batch_size, summary_writer,
                             FLAGS.report_every, builder.saver,
                             FLAGS.checkpoint_filename, stats)
def export_to_graph(master_spec,
                    params_path,
                    export_path,
                    external_graph,
                    export_moving_averages,
                    signature_name='model'):
  """Restores a model and exports it in SavedModel form.

  This method loads a graph specified by the master_spec and the params in
  params_path into the graph given in external_graph. It then saves the model
  in SavedModel format to the location specified in export_path.

  Args:
    master_spec: Proto master spec.
    params_path: Path to the parameters file to export.
    export_path: Path to export the SavedModel to.
    external_graph: A tf.Graph() object to build the graph inside.
    export_moving_averages: Whether to export the moving average parameters.
    signature_name: Name of the signature to insert.
  """
  tf.logging.info(
      'Exporting graph with signature_name "%s" and use_moving_averages = %s'
      % (signature_name, export_moving_averages))

  tf.logging.info('Building the graph')
  with external_graph.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = export_moving_averages
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    post_restore_hook = builder.build_post_restore_hook()
    annotation = builder.add_annotation()
    builder.add_saver()

  # Resets session.
  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=10,
      inter_op_parallelism_threads=10)

  with tf.Session(graph=external_graph, config=session_config) as session:
    tf.logging.info('Initializing variables...')
    session.run(tf.global_variables_initializer())

    tf.logging.info('Loading params...')
    session.run('save/restore_all', {'save/Const:0': params_path})

    tf.logging.info('Saving.')
    with tf.device('/device:CPU:0'):
      saved_model_builder = tf.saved_model.builder.SavedModelBuilder(
          export_path)
      signature_map = {
          signature_name:
              tf.saved_model.signature_def_utils.build_signature_def(
                  inputs={
                      'inputs':
                          tf.saved_model.utils.build_tensor_info(
                              annotation['input_batch'])
                  },
                  outputs={
                      'annotations':
                          tf.saved_model.utils.build_tensor_info(
                              annotation['annotations'])
                  },
                  method_name=tf.saved_model.signature_constants
                  .PREDICT_METHOD_NAME),
      }
      tf.logging.info('Input is: %s', annotation['input_batch'].name)
      tf.logging.info('Output is: %s', annotation['annotations'].name)
      saved_model_builder.add_meta_graph_and_variables(
          session,
          tags=_SAVED_MODEL_TAGS,
          legacy_init_op=tf.group(
              post_restore_hook,
              builder.build_warmup_graph(
                  tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)[0])),
          signature_def_map=signature_map,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
      saved_model_builder.save()