def _load_model(self, base_dir, master_spec_name):
    """Build an annotation graph from a master spec and restore its weights.

    Args:
      base_dir: Directory containing the master spec and a checkpoint named
        "checkpoint"; also handed to `complete_master_spec` as the resource
        location.
      master_spec_name: File name of the text-format MasterSpec in `base_dir`.

    Returns:
      A function `annotate_sentence(sentence)` that runs the restored session
      on a one-element batch and returns the fetched
      `[annotations, traces]` pair.
    """
    spec_path = os.path.join(base_dir, master_spec_name)
    checkpoint_path = os.path.join(base_dir, "checkpoint")

    # Read the spec text first, then merge it into an empty proto.
    with open(spec_path) as spec_file:
        spec_text = spec_file.read()
    master_spec = spec_pb2.MasterSpec()
    text_format.Merge(spec_text, master_spec)
    spec_builder.complete_master_spec(master_spec, None, base_dir)

    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, spec_pb2.GridPoint())
        # Tracing is enabled so that annotator['traces'] can be fetched below.
        annotator = builder.add_annotation(enable_tracing=True)
        builder.add_saver()
        sess = tf.Session(graph=graph)
        builder.saver.restore(sess, checkpoint_path)

    def annotate_sentence(sentence):
        # The session and graph are captured by this closure and stay open
        # for as long as the returned function is alive.
        fetches = [annotator['annotations'], annotator['traces']]
        feeds = {annotator['input_batch']: [sentence]}
        with graph.as_default():
            return sess.run(fetches, feed_dict=feeds)

    return annotate_sentence
def load_model(self, base_dir, master_spec_name, checkpoint_name):
    """Load a DRAGNN model and return a sentence-annotation function.

    Args:
      base_dir: Directory containing the master spec and the checkpoint; also
        handed to `complete_master_spec` as the resource location.
      master_spec_name: File name of the text-format MasterSpec in `base_dir`.
      checkpoint_name: File name of the checkpoint in `base_dir` to restore.

    Returns:
      A function `annotate_sentence(sentence)` that runs the restored session
      on a one-element batch and returns the fetched
      `[annotations, traces]` pair.
    """
    spec_path = os.path.join(base_dir, master_spec_name)
    checkpoint_path = os.path.join(base_dir, checkpoint_name)

    # Read the master spec.
    with open(spec_path, "r") as spec_file:
        spec_text = spec_file.read()
    master_spec = spec_pb2.MasterSpec()
    text_format.Merge(spec_text, master_spec)
    spec_builder.complete_master_spec(master_spec, None, base_dir)

    # Turn off TensorFlow spam.
    logging.set_verbosity(logging.WARN)

    # Build the graph: an annotator (with tracing) plus a saver, which is only
    # used here to load the checkpoint.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, spec_pb2.GridPoint())
        annotator = builder.add_annotation(enable_tracing=True)
        builder.add_saver()

    sess = tf.Session(graph=graph)
    with graph.as_default():
        builder.saver.restore(sess, checkpoint_path)

    def annotate_sentence(sentence):
        fetches = [annotator['annotations'], annotator['traces']]
        feeds = {annotator['input_batch']: [sentence]}
        with graph.as_default():
            return sess.run(fetches, feed_dict=feeds)

    return annotate_sentence
def load_model(self, base_dir, master_spec_name, checkpoint_name="checkpoint", rename=True):
    """Load a DRAGNN model, retrying once after renaming checkpoint variables.

    Args:
      base_dir: Directory containing the master spec and the checkpoint; also
        handed to `complete_master_spec` as the resource location.
      master_spec_name: File name of the text-format MasterSpec in `base_dir`.
      checkpoint_name: File name of the checkpoint in `base_dir` to restore.
      rename: If true, a failed load calls `self.rename_vars(base_dir,
        checkpoint_name)` and retries once with `rename=False`.

    Returns:
      A function `annotate_sentence(sentence)` that runs the restored session
      on a one-element batch and returns the fetched
      `[annotations, traces]` pair.

    Raises:
      Exception: If loading fails and `rename` is false.
    """
    sess = None
    try:
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name)) as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)

        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
            annotator = builder.add_annotation(enable_tracing=True)
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))
    except Exception:
        # `Exception` (not a bare except) so that KeyboardInterrupt/SystemExit
        # propagate instead of triggering a checkpoint rewrite.
        if sess is not None:
            # Do not leak the session of the failed attempt.
            sess.close()
        if rename:
            # NOTE(review): any load failure lands here, not only the
            # kernel/weights naming mismatch the message below describes.
            self.rename_vars(base_dir, checkpoint_name)
            return self.load_model(base_dir, master_spec_name, checkpoint_name, False)
        raise Exception(
            'Cannot load model: spec expects references to */kernel tensors '
            'instead of */weights. '
            'Try running with rename=True or run rename_vars() to convert '
            'existing checkpoint files into supported format')

    def annotate_sentence(sentence):
        with graph.as_default():
            return sess.run([annotator['annotations'], annotator['traces']],
                            feed_dict={annotator['input_batch']: [sentence]})

    return annotate_sentence
def load_master_spec(spec_file, resource_path):
    """Parse a text-format MasterSpec from `spec_file` and complete it.

    Args:
      spec_file: Path of the text-format MasterSpec proto to read.
      resource_path: Passed to `spec_builder.complete_master_spec` as the
        resource location.

    Returns:
      The parsed `spec_pb2.MasterSpec` after `complete_master_spec` ran on it.
    """
    tf.logging.info('Loading MasterSpec...')
    with gfile.FastGFile(spec_file, 'r') as spec_stream:
        spec_text = spec_stream.read()

    parsed_spec = spec_pb2.MasterSpec()
    text_format.Parse(spec_text, parsed_spec)
    spec_builder.complete_master_spec(parsed_spec, None, resource_path)

    logging.info('Constructed master spec: %s', str(parsed_spec))
    return parsed_spec
def load_master_spec(spec_file, resource_path):
    """Read the MasterSpec stored at `spec_file` and fill in its resources.

    Args:
      spec_file: Path of the text-format MasterSpec proto to read.
      resource_path: Resource location forwarded to
        `spec_builder.complete_master_spec`.

    Returns:
      The completed `spec_pb2.MasterSpec`.
    """
    tf.logging.info('Loading MasterSpec...')

    spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(spec_file, 'r') as handle:
        contents = handle.read()
        text_format.Parse(contents, spec)

    spec_builder.complete_master_spec(spec, None, resource_path)
    spec_dump = str(spec)
    logging.info('Constructed master spec: %s', spec_dump)
    return spec
def export(master_spec_path, params_path, resource_path, export_path,
           export_moving_averages):
    """Restore a model from a master spec and params; export a SavedModel.

    The spec at `master_spec_path` is parsed, stripped of its per-component
    resources, completed against `resource_path`, and then handed to
    `saver_lib` together with `params_path` to write the export.

    Args:
      master_spec_path: Path to a proto-text master spec.
      params_path: Path to the parameters file to export.
      resource_path: Path to resources in the master spec.
      export_path: Path to export the SavedModel to.
      export_moving_averages: Whether to export the moving average parameters.
    """
    # Old CoNLL checkpoints did not need a known-word-map, so write a
    # placeholder file when none exists.
    word_map_path = os.path.join(resource_path, 'known-word-map')
    if not tf.gfile.Exists(word_map_path):
        with tf.gfile.FastGFile(word_map_path, 'w') as placeholder:
            placeholder.write('This file intentionally left blank.')

    master_spec = spec_pb2.MasterSpec()
    with tf.gfile.FastGFile(master_spec_path) as spec_stream:
        text_format.Parse(spec_stream.read(), master_spec)

    # Workaround: the segmenter master-spec carried a spurious resource that
    # the spec-builder ignored and that then crashed the saver, so all
    # declared resources are dropped before completing the spec.
    for component in master_spec.component:
        del component.resource[:]

    spec_builder.complete_master_spec(master_spec, None, resource_path)

    # A trailing '/' would confuse the path utilities used by saver_lib.
    stripped_path = export_path.rstrip('/')
    saver_lib.clean_output_paths(stripped_path)

    short_to_original = saver_lib.shorten_resource_paths(master_spec)

    graph = tf.Graph()
    saver_lib.export_master_spec(master_spec, graph)
    saver_lib.export_to_graph(master_spec, params_path, stripped_path, graph,
                              export_moving_averages)
    saver_lib.export_assets(master_spec, short_to_original, stripped_path)
def export(master_spec_path, params_path, resource_path, export_path,
           export_moving_averages):
    """Load a graph from a master spec plus params and save it as a SavedModel.

    Args:
      master_spec_path: Path to a proto-text master spec.
      params_path: Path to the parameters file to export.
      resource_path: Path to resources in the master spec.
      export_path: Path to export the SavedModel to; trailing '/' characters
        are removed before use.
      export_moving_averages: Whether to export the moving average parameters.
    """
    # Older CoNLL checkpoints shipped without a known-word-map; create a
    # stand-in file if it is missing.
    known_word_map = os.path.join(resource_path, 'known-word-map')
    map_missing = not tf.gfile.Exists(known_word_map)
    if map_missing:
        with tf.gfile.FastGFile(known_word_map, 'w') as out_file:
            out_file.write('This file intentionally left blank.')

    export_graph = tf.Graph()

    spec = spec_pb2.MasterSpec()
    with tf.gfile.FastGFile(master_spec_path) as fin:
        spec_text = fin.read()
        text_format.Parse(spec_text, spec)

    # The segmenter master-spec once contained a spurious resource that was
    # not respected by the spec-builder and crashed the saver; clear every
    # component's resource list to sidestep that.
    for component in spec.component:
        del component.resource[:]
    spec_builder.complete_master_spec(spec, None, resource_path)

    # Strip any trailing '/' so the path utilities behave.
    target_dir = export_path.rstrip('/')
    saver_lib.clean_output_paths(target_dir)

    short_to_original = saver_lib.shorten_resource_paths(spec)
    saver_lib.export_master_spec(spec, export_graph)
    saver_lib.export_to_graph(spec, params_path, target_dir, export_graph,
                              export_moving_averages)
    saver_lib.export_assets(spec, short_to_original, target_dir)
def main(unused_argv):
    """Train the model described by `FLAGS.dragnn_spec`.

    A `<checkpoint_filename>.stats` file holding three comma-separated
    integers is read first. Its first value is compared with `FLAGS.job_id`:
    when it is smaller, the TensorBoard directory and checkpoint are deleted
    and training starts from scratch; otherwise the checkpoint is restored and
    the second value is subtracted from the pretraining schedule.

    Args:
      unused_argv: Ignored.
    """
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    # The corpus flags are glob patterns; the first match is used.
    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        # Use a context manager so the stats file handle is always closed.
        with gfile.GFile(stats_file, 'r') as stats_in:
            stats = [int(x) for x in stats_in.readlines()[0].split(',')]
    except errors.OpError:
        # No readable stats file: behave as if no job has run yet.
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            # Best effort: stale files are logged and training proceeds.
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instatiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path,
    # from the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path, training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        # Fall back to treating the flag as base64-encoded proto text.
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)

    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        # One training target per component, skipping 'shift-only' ones; only
        # the target's own component is unrolled using the oracle.
        component_targets = [
            spec_pb2.TrainTarget(
                name=component.name,
                max_index=idx + 1,
                unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(
        tune_corpus_path, projectivize=False, morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        # The tune set is passed twice, matching the original call.
        trainer_lib.run_training(
            sess, trainers, annotator, evaluation.parser_summaries,
            pretrain_steps, train_steps, training_set, tune_set, tune_set,
            FLAGS.batch_size, summary_writer, FLAGS.report_every,
            builder.saver, FLAGS.checkpoint_filename, stats)
def build_complete_master_spec(resource_path):
    """Build a MasterSpec with `build_master_spec()` and complete it.

    Args:
      resource_path: Resource location forwarded to
        `spec_builder.complete_master_spec`.

    Returns:
      The spec returned by `build_master_spec()` after completion.
    """
    tf.logging.info('Building MasterSpec...')
    completed_spec = build_master_spec()
    spec_builder.complete_master_spec(completed_spec, None, resource_path)

    spec_dump = str(completed_spec)
    logging.info('Constructed master spec: %s', spec_dump)
    return completed_spec
def main(unused_argv):
    """Run a segmenter over a CoNLL corpus and report segmentation metrics.

    The master spec named by `FLAGS.master_spec` is loaded, an annotation
    graph is built on CPU, the checkpoint in `FLAGS.checkpoint_file` is
    restored, and the character-tokenized corpus is annotated in batches of at
    most `FLAGS.max_batch_size`. The annotations are optionally written to
    `FLAGS.output_file` as text-format Sentence protos.

    Args:
      unused_argv: Ignored.
    """
    # Parse the list-valued flags with regular expressions: key=value pairs
    # for beam sizes, and non-empty comma-separated names for normalization.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as spec_stream:
        text_format.Parse(spec_stream.read(), master_spec)

    # Rewrite resource locations so they live under FLAGS.resource_dir.
    if FLAGS.resource_dir:
        resource_parts = (part
                          for component in master_spec.component
                          for resource in component.resource
                          for part in resource.part)
        for part in resource_parts:
            part.file_pattern = os.path.join(FLAGS.resource_dir,
                                             part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    graph = tf.Graph()
    with graph.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    # Convert the corpus to character tokens in a throwaway graph/session.
    with tf.Session(graph=tf.Graph()) as tmp_session:
        char_input = gen_parser_ops.char_token_generator(input_corpus)
        char_corpus = tmp_session.run(char_input)
    check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=graph, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')
        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        corpus_size = len(char_corpus)
        for batch_begin in range(0, corpus_size, FLAGS.max_batch_size):
            batch_end = min(batch_begin + FLAGS.max_batch_size, corpus_size)
            feeds = {
                annotator['input_batch']: char_corpus[batch_begin:batch_end]
            }
            feeds.update(('%s/InferenceBeamSize:0' % name, size)
                         for name, size in component_beam_sizes)
            feeds.update(('%s/LocallyNormalize:0' % name, True)
                         for name in components_to_locally_normalize)

            # Only the final batch is traced, and only if a timeline file
            # was requested.
            trace_this_batch = (FLAGS.timeline_output_file and
                                batch_end == corpus_size)
            if not trace_this_batch:
                batch_annotations = sess.run(annotator['annotations'],
                                             feed_dict=feeds)
            else:
                batch_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feeds,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            processed.extend(batch_annotations)

        elapsed = time.time() - start_time
        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(char_corpus), elapsed)
        evaluation.calculate_segmentation_metrics(input_corpus, processed)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as out:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    out.write(text_format.MessageToString(sentence) + '\n\n')
def main(unused_argv):
    """Annotate a CoNLL corpus with a restored model and report parse metrics.

    The master spec named by `FLAGS.master_spec` is loaded, an annotation
    graph is built on CPU, the checkpoint in `FLAGS.checkpoint_file` is
    restored, and the corpus is annotated in batches of at most
    `FLAGS.max_batch_size`. The three values returned by
    `evaluation.calculate_parse_metrics` are optionally written to
    `FLAGS.log_file`, and the annotations to `FLAGS.output_file`.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the list-valued flags with regular expressions: key=value pairs
    # for beam sizes, and non-empty comma-separated names for normalization.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as spec_stream:
        text_format.Parse(spec_stream.read(), master_spec)

    # Rewrite resource locations so they live under FLAGS.resource_dir.
    if FLAGS.resource_dir:
        resource_parts = (part
                          for component in master_spec.component
                          for resource in component.resource
                          for part in resource.part)
        for part in resource_parts:
            part.file_pattern = os.path.join(FLAGS.resource_dir,
                                             part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    graph = tf.Graph()
    with graph.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=graph, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')
        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        corpus_size = len(input_corpus)
        for batch_begin in range(0, corpus_size, FLAGS.max_batch_size):
            batch_end = min(batch_begin + FLAGS.max_batch_size, corpus_size)
            feeds = {
                annotator['input_batch']: input_corpus[batch_begin:batch_end]
            }
            feeds.update(('%s/InferenceBeamSize:0' % name, size)
                         for name, size in component_beam_sizes)
            feeds.update(('%s/LocallyNormalize:0' % name, True)
                         for name in components_to_locally_normalize)

            # Only the final batch is traced, and only if a timeline file
            # was requested.
            trace_this_batch = (FLAGS.timeline_output_file and
                                batch_end == corpus_size)
            if not trace_this_batch:
                batch_annotations = sess.run(annotator['annotations'],
                                             feed_dict=feeds)
            else:
                batch_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feeds,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            processed.extend(batch_annotations)

        elapsed = time.time() - start_time
        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus), elapsed)

        pos, uas, las = evaluation.calculate_parse_metrics(input_corpus,
                                                           processed)
        if FLAGS.log_file:
            metrics_line = '%s\t%f\t%f\t%f\n' % (FLAGS.language_name, pos,
                                                 uas, las)
            with gfile.GFile(FLAGS.log_file, 'w') as log_out:
                log_out.write(metrics_line)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as out:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    out.write(text_format.MessageToString(sentence) + '\n\n')
def main(unused_argv):
    """Train the model described by `FLAGS.dragnn_spec`.

    A `<checkpoint_filename>.stats` file holding three comma-separated
    integers is read first. Its first value is compared with `FLAGS.job_id`:
    when it is smaller, the TensorBoard directory and checkpoint are deleted
    and training starts from scratch; otherwise the checkpoint is restored and
    the second value is subtracted from the pretraining schedule.

    Args:
      unused_argv: Ignored.
    """
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    # The corpus flags are glob patterns; the first match is used.
    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        # Use a context manager so the stats file handle is always closed.
        with gfile.GFile(stats_file, 'r') as stats_in:
            first_line = stats_in.readlines()[0]
        stats = [int(x) for x in first_line.split(',')]
    except errors.OpError:
        # No readable stats file: behave as if no job has run yet.
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            # Best effort: stale files are logged and training proceeds.
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instatiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path,
    # from the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(
            FLAGS.resource_path, training_corpus_path, morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        # Fall back to treating the flag as base64-encoded proto text.
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)

    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        # One training target per component, skipping 'shift-only' ones; only
        # the target's own component is unrolled using the oracle.
        component_targets = [
            spec_pb2.TrainTarget(
                name=component.name,
                max_index=idx + 1,
                unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(
        tune_corpus_path, projectivize=False, morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        # The tune set is passed twice, matching the original call.
        trainer_lib.run_training(
            sess, trainers, annotator, evaluation.parser_summaries,
            pretrain_steps, train_steps, training_set, tune_set, tune_set,
            FLAGS.batch_size, summary_writer, FLAGS.report_every,
            builder.saver, FLAGS.checkpoint_filename, stats)
def build_complete_master_spec(resource_path):
    """Return the spec from `build_master_spec()`, completed with resources.

    Args:
      resource_path: Resource location forwarded to
        `spec_builder.complete_master_spec`.

    Returns:
      The spec returned by `build_master_spec()` after completion.
    """
    tf.logging.info('Building MasterSpec...')

    spec = build_master_spec()
    spec_builder.complete_master_spec(spec, None, resource_path)

    rendered = str(spec)
    logging.info('Constructed master spec: %s', rendered)
    return spec
def _read_master_spec(spec_path, resource_dir):
    """Parse the text-format MasterSpec at `spec_path`; complete it if flagged."""
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(spec_path) as fin:
        text_format.Parse(fin.read(), master_spec)
    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None, resource_dir)
    return master_spec


def _build_annotation_graph(master_spec):
    """Build a CPU-pinned annotation graph; return `(graph, annotator)`."""
    tf.logging.info('Building the graph')
    graph = tf.Graph()
    with graph.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()
    return graph, annotator


def _annotate_in_batches(graph, annotator, corpus, checkpoint_file,
                         extra_feeds):
    """Restore `checkpoint_file` into `graph` and annotate `corpus` in batches.

    `extra_feeds` is merged into every batch's feed dict. Returns the list of
    annotations fetched for all batches, in corpus order.
    """
    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=graph, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': checkpoint_file})

        tf.logging.info('Processing sentences...')
        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(corpus))
            feed_dict = {annotator['input_batch']: corpus[start:end]}
            feed_dict.update(extra_feeds)
            # Only the final batch is traced, and only if a timeline file
            # was requested.
            if FLAGS.timeline_output_file and end == len(corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(corpus), time.time() - start_time)
    return processed


def main(unused_argv):
    """Segment (unless gold segmentation is used) and parse a CoNLL corpus.

    When `FLAGS.use_gold_segmentation` is false, the input file is converted
    to character tokens and run through the segmenter model; its output
    becomes the parser's input. Otherwise the input file is fed to the parser
    directly. If `FLAGS.output_file` is set, the parsed sentences are written
    there as tab-separated token lines.

    Args:
      unused_argv: Ignored.
    """
    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    ## SEGMENTATION ##
    if not FLAGS.use_gold_segmentation:
        master_spec = _read_master_spec(FLAGS.segmenter_master_spec,
                                        FLAGS.segmenter_resource_dir)
        graph, annotator = _build_annotation_graph(master_spec)

        tf.logging.info('Reading documents...')
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()
        # Convert the corpus to character tokens in a throwaway graph/session.
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(input_corpus)
            char_corpus = tmp_session.run(char_input)
        check.Eq(len(input_corpus), len(char_corpus))

        # The segmenter stage feeds no beam-size or normalization overrides.
        input_corpus = _annotate_in_batches(
            graph, annotator, char_corpus, FLAGS.segmenter_checkpoint_file,
            {})
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##
    master_spec = _read_master_spec(FLAGS.parser_master_spec,
                                    FLAGS.parser_resource_dir)
    graph, annotator = _build_annotation_graph(master_spec)

    tf.logging.info('Reading documents...')
    parser_feeds = {}
    for comp, beam_size in component_beam_sizes:
        parser_feeds['%s/InferenceBeamSize:0' % comp] = beam_size
    for comp in components_to_locally_normalize:
        parser_feeds['%s/LocallyNormalize:0' % comp] = True
    processed = _annotate_in_batches(
        graph, annotator, input_corpus, FLAGS.parser_checkpoint_file,
        parser_feeds)

    if FLAGS.output_file:
        # NOTE(review): mixing `.encode('utf-8')` results with str literals
        # assumes Python 2 string semantics -- confirm before running this
        # under Python 3.
        with gfile.GFile(FLAGS.output_file, 'w') as f:
            for serialized_sentence in processed:
                sentence = sentence_pb2.Sentence()
                sentence.ParseFromString(serialized_sentence)
                f.write('#' + sentence.text.encode('utf-8') + '\n')
                for i, token in enumerate(sentence.token):
                    # Token ids are written 1-based and heads are shifted by
                    # one to match.
                    head = token.head + 1
                    f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                            (i + 1, token.word.encode('utf-8'), head,
                             token.label.encode('utf-8')))
                f.write('\n\n')