def _load_model(self, base_dir, master_spec_name):
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name)) as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)

        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(
                master_spec,
                hyperparam_config
            )
            annotator = builder.add_annotation(enable_tracing=True)
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            builder.saver.restore(sess, os.path.join(base_dir, "checkpoint"))

        def annotate_sentence(sentence):
            with graph.as_default():
                return sess.run(
                    [annotator['annotations'], annotator['traces']],
                    feed_dict={annotator['input_batch']: [sentence]}
                )
        return annotate_sentence
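
A minimal usage sketch for the loader above (hypothetical; assumes `wrapper` is an instance of the class that defines _load_model, and that the model directory contains the master spec plus a "checkpoint" file):

from google.protobuf import text_format
from syntaxnet import sentence_pb2

annotate = wrapper._load_model('/path/to/model', 'master_spec.textproto')

sentence = sentence_pb2.Sentence()
text_format.Parse('text: "hello" token { word: "hello" start: 0 end: 4 }',
                  sentence)

# The closure feeds a serialized Sentence proto and returns serialized results.
annotations, traces = annotate(sentence.SerializeToString())
annotated = sentence_pb2.Sentence()
annotated.ParseFromString(annotations[0])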
Example #2
  def RunTraining(self, hyperparam_config):
    master_spec = self.LoadSpec('master_spec_link.textproto')

    self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]
    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

      target = spec_pb2.TrainTarget()
      target.name = 'testTraining-all'
      train = builder.add_training_from_config(target)
      with self.test_session() as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        # Run one iteration of training and verify nothing crashes.
        logging.info('Training')
        sess.run(train['run'], feed_dict={train['input_batch']: reader_strings})
Example #3
    def load_model(self, base_dir, master_spec_name, checkpoint_name="checkpoint", rename=True):
        try:
            master_spec = spec_pb2.MasterSpec()
            with open(os.path.join(base_dir, master_spec_name)) as f:
                text_format.Merge(f.read(), master_spec)
            spec_builder.complete_master_spec(master_spec, None, base_dir)

            graph = tf.Graph()
            with graph.as_default():
                hyperparam_config = spec_pb2.GridPoint()
                builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
                annotator = builder.add_annotation(enable_tracing=True)
                builder.add_saver()

            sess = tf.Session(graph=graph)
            with graph.as_default():
                builder.saver.restore(sess, os.path.join(base_dir, checkpoint_name))

            def annotate_sentence(sentence):
                with graph.as_default():
                    return sess.run([annotator['annotations'], annotator['traces']],
                                    feed_dict={annotator['input_batch']: [sentence]})
        except Exception:
            if rename:
                self.rename_vars(base_dir, checkpoint_name)
                return self.load_model(base_dir, master_spec_name, checkpoint_name, False)
            raise Exception(
                'Cannot load model: spec expects references to */kernel tensors instead of '
                '*/weights. Try running with rename=True or run rename_vars() to convert '
                'existing checkpoint files into the supported format.')

        return annotate_sentence
Example #4
def build_train_graph(master_spec, hyperparam_config=None):
    # Build the TensorFlow graph based on the DRAGNN network spec.
    tf.logging.info('Building Graph...')
    if not hyperparam_config:
        hyperparam_config = spec_pb2.GridPoint(learning_method='adam',
                                               learning_rate=0.0005,
                                               adam_beta1=0.9,
                                               adam_beta2=0.9,
                                               adam_eps=0.00001,
                                               decay_steps=128000,
                                               dropout_rate=0.8,
                                               gradient_clip_norm=1,
                                               use_moving_average=True,
                                               seed=1)
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation(enable_tracing=True)
        builder.add_saver()
        return graph, builder, trainers, annotator
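
A short sketch of driving the helper above (assumes `master_spec` was loaded and completed as in the earlier examples, and `train_data` is a placeholder list of serialized Sentence protos):

graph, builder, trainers, annotator = build_train_graph(master_spec)
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    for trainer in trainers:
        # Each trainer dict exposes 'run' and 'input_batch', as in the tests above.
        sess.run(trainer['run'],
                 feed_dict={trainer['input_batch']: train_data})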
Example #5
    def load_model(self, base_dir, master_spec_name, checkpoint_name):
        # Read the master spec
        master_spec = spec_pb2.MasterSpec()
        with open(os.path.join(base_dir, master_spec_name), "r") as f:
            text_format.Merge(f.read(), master_spec)
        spec_builder.complete_master_spec(master_spec, None, base_dir)
        logging.set_verbosity(logging.WARN)  # Turn off TensorFlow spam.

        # Initialize a graph
        graph = tf.Graph()
        with graph.as_default():
            hyperparam_config = spec_pb2.GridPoint()
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            # This is the component that will annotate test sentences.
            annotator = builder.add_annotation(enable_tracing=True)
            # "Savers" can save and load models; here, we're only going to load.
            builder.add_saver()

        sess = tf.Session(graph=graph)
        with graph.as_default():
            # sess.run(tf.global_variables_initializer())
            # sess.run('save/restore_all', {'save/Const:0': os.path.join(base_dir, checkpoint_name)})
            builder.saver.restore(sess, os.path.join(base_dir,
                                                     checkpoint_name))

        def annotate_sentence(sentence):
            with graph.as_default():
                return sess.run(
                    [annotator['annotations'], annotator['traces']],
                    feed_dict={annotator['input_batch']: [sentence]})

        return annotate_sentence
Example #6
def build_inference_graph(master_spec, enable_tracing=False):
    # Initialize a graph
    tf.logging.info('Building Graph...')
    graph = tf.Graph()
    with graph.as_default():
        hyperparam_config = spec_pb2.GridPoint()
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        # This is the component that will annotate test sentences.
        annotator = builder.add_annotation(enable_tracing=enable_tracing)
        builder.add_saver()
    return graph, builder, annotator
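
A possible companion snippet: restore a checkpoint into the inference graph and annotate a batch (`checkpoint_path` and `serialized_sentences` are placeholders; the restore/run pattern mirrors the load_model examples above):

graph, builder, annotator = build_inference_graph(master_spec)
with tf.Session(graph=graph) as sess:
    builder.saver.restore(sess, checkpoint_path)
    annotations = sess.run(
        annotator['annotations'],
        feed_dict={annotator['input_batch']: serialized_sentences})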
Example #7
  def getBuilderAndTarget(
      self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
    """Generates a MasterBuilder and TrainTarget based on a simple spec."""
    master_spec = self.LoadSpec(master_spec_path)
    hyperparam_config = spec_pb2.GridPoint()
    target = spec_pb2.TrainTarget()
    target.name = 'test-%s-train' % test_name
    target.component_weights.extend([0] * len(master_spec.component))
    target.component_weights[-1] = 1.0
    target.unroll_using_oracle.extend([False] * len(master_spec.component))
    target.unroll_using_oracle[-1] = True
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    return builder, target
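
A hypothetical test body that uses the helper, following the same pattern as the training test in Example #2 (_DUMMY_GOLD_SENTENCE is assumed to be defined at module level, as it is in that test):

  def testTrainingSmoke(self):
    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    reader_strings = [gold_doc.SerializeToString()]
    with tf.Graph().as_default():
      builder, target = self.getBuilderAndTarget('smoke')
      train = builder.add_training_from_config(target)
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train['run'],
                 feed_dict={train['input_batch']: reader_strings})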
Example #8
def main(argv):
    del argv  # unused
    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    lexicon.build_lexicon(lexicon_dir,
                          training_sentence,
                          training_corpus_format='sentence-prototext')

    # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
    # sequence tagger.
    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
    tagger.add_rnn_link(embedding_dim=-1)
    tagger.fill_from_resources(lexicon_dir)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([tagger.spec])

    hyperparam_config = spec_pb2.GridPoint()

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        target = spec_pb2.TrainTarget()
        target.name = 'all'
        target.unroll_using_oracle.extend([True])
        dry_run = builder.add_training_from_config(target, trace_only=True)

    # Read in serialized protos from training data.
    sentence = sentence_pb2.Sentence()
    text_format.Merge(open(training_sentence).read(), sentence)
    training_set = [sentence.SerializeToString()]

    with tf.Session(graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.initialize_all_variables())
        traces = sess.run(dry_run['traces'],
                          feed_dict={dry_run['input_batch']: training_set})

    with open('dragnn_tutorial_1.html', 'w') as f:
        f.write(
            visualization.trace_html(traces[0],
                                     height='300px').encode('utf-8'))
Example #9
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])

  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.initialize_all_variables())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px', master_spec=master_spec).encode('utf-8'))
Example #10
def main(unused_argv):
    logging.set_verbosity(logging.INFO)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path, FLAGS.training_corpus_path)

    # Construct the "lookahead" ComponentSpec. This is a simple right-to-left RNN
    # sequence model, which encodes the context to the right of each token. It has
    # no loss except for the downstream components.

    char2word = spec_builder.ComponentSpecBuilder('char_lstm')
    char2word.set_network_unit(name='wrapped_units.LayerNormBasicLSTMNetwork',
                               hidden_layer_sizes='256')
    char2word.set_transition_system(name='char-shift-only',
                                    left_to_right='true')
    char2word.add_fixed_feature(name='chars',
                                fml='char-input.text-char',
                                embedding_dim=16)
    char2word.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    lookahead = spec_builder.ComponentSpecBuilder('lookahead')
    lookahead.set_network_unit(name='wrapped_units.LayerNormBasicLSTMNetwork',
                               hidden_layer_sizes='256')
    lookahead.set_transition_system(name='shift-only', left_to_right='false')
    lookahead.add_link(source=char2word,
                       fml='input.last-char-focus',
                       embedding_dim=32)
    lookahead.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
    # sequence tagger.
    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='wrapped_units.LayerNormBasicLSTMNetwork',
                            hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_token_link(source=lookahead,
                          fml='input.focus',
                          embedding_dim=32)
    tagger.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    # Construct the ComponentSpec for parsing.
    parser = spec_builder.ComponentSpecBuilder('parser')
    parser.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256',
                            layer_norm_hidden='True')
    parser.set_transition_system(name='arc-standard')
    parser.add_token_link(source=lookahead,
                          fml='input.focus',
                          embedding_dim=32)
    parser.add_token_link(source=tagger,
                          fml='input.focus stack.focus stack(1).focus',
                          embedding_dim=32)

    # Recurrent connection for the arc-standard parser. For both tokens on the
    # stack, we connect to the last time step to either SHIFT or REDUCE that
    # token. This allows the parser to build up compositional representations of
    # phrases.
    parser.add_link(
        source=parser,  # recurrent connection
        name='rnn-stack',  # unique identifier
        fml='stack.focus stack(1).focus',  # look for both stack tokens
        source_translator='shift-reduce-step',  # maps token indices -> step
        embedding_dim=32)  # project down to 32 dims

    parser.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend(
        [char2word.spec, lookahead.spec, tagger.spec, parser.spec])
    logging.info('Constructed master spec: %s', str(master_spec))
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.decay_steps = 128000
    hyperparam_config.learning_rate = 0.001
    hyperparam_config.learning_method = 'adam'
    hyperparam_config.adam_beta1 = 0.9
    hyperparam_config.adam_beta2 = 0.9
    hyperparam_config.adam_eps = 0.0001
    hyperparam_config.gradient_clip_norm = 1
    hyperparam_config.self_norm_alpha = 1.0
    hyperparam_config.use_moving_average = True
    hyperparam_config.dropout_rate = 0.7
    hyperparam_config.seed = 1

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = spec_builder.default_targets_from_spec(master_spec)
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        assert len(trainers) == 2
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = sentence_io.ConllSentenceReader(
        FLAGS.training_corpus_path,
        projectivize=FLAGS.projectivize_training_set).corpus()
    dev_set = sentence_io.ConllSentenceReader(FLAGS.dev_corpus_path,
                                              projectivize=False).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(dev_set))

    pretrain_steps = [100, 0]
    tagger_steps = 1000
    train_steps = [tagger_steps, 8 * tagger_steps]

    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())
        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, dev_set, dev_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename)
Example #11
  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()

      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verifies that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        break_iter = 400
        is_resolved = False
        for i in range(400):  # needs ~100 iterations, but is not deterministic
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
          total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
          correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50

          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          for c in xrange(len(master_spec.component)):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          master_spec.component[c].name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
            correct_val, total_val)

        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The result of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentences.append(sentence_pb2.Sentence())
          pred_sentences[-1].ParseFromString(annotation)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)
Example #12
def main(unused_argv):

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    ## SEGMENTATION ##

    if not FLAGS.use_gold_segmentation:

        # Reads master spec.
        master_spec = spec_pb2.MasterSpec()
        with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
            text_format.Parse(fin.read(), master_spec)

        if FLAGS.complete_master_spec:
            spec_builder.complete_master_spec(master_spec, None,
                                              FLAGS.segmenter_resource_dir)

        # Graph building.
        tf.logging.info('Building the graph')
        g = tf.Graph()
        with g.as_default(), tf.device('/device:CPU:0'):
            hyperparam_config = spec_pb2.GridPoint()
            hyperparam_config.use_moving_average = True
            builder = graph_builder.MasterBuilder(master_spec,
                                                  hyperparam_config)
            annotator = builder.add_annotation()
            builder.add_saver()

        tf.logging.info('Reading documents...')
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()
        with tf.Session(graph=tf.Graph()) as tmp_session:
            char_input = gen_parser_ops.char_token_generator(input_corpus)
            char_corpus = tmp_session.run(char_input)
        check.Eq(len(input_corpus), len(char_corpus))

        session_config = tf.ConfigProto(
            log_device_placement=False,
            intra_op_parallelism_threads=FLAGS.threads,
            inter_op_parallelism_threads=FLAGS.threads)

        with tf.Session(graph=g, config=session_config) as sess:
            tf.logging.info('Initializing variables...')
            sess.run(tf.global_variables_initializer())
            tf.logging.info('Loading from checkpoint...')
            sess.run('save/restore_all',
                     {'save/Const:0': FLAGS.segmenter_checkpoint_file})

            tf.logging.info('Processing sentences...')

            processed = []
            start_time = time.time()
            run_metadata = tf.RunMetadata()
            for start in range(0, len(char_corpus), FLAGS.max_batch_size):
                end = min(start + FLAGS.max_batch_size, len(char_corpus))
                feed_dict = {annotator['input_batch']: char_corpus[start:end]}
                if FLAGS.timeline_output_file and end == len(char_corpus):
                    serialized_annotations = sess.run(
                        annotator['annotations'],
                        feed_dict=feed_dict,
                        options=tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE),
                        run_metadata=run_metadata)
                    trace = timeline.Timeline(
                        step_stats=run_metadata.step_stats)
                    with open(FLAGS.timeline_output_file, 'w') as trace_file:
                        trace_file.write(trace.generate_chrome_trace_format())
                else:
                    serialized_annotations = sess.run(annotator['annotations'],
                                                      feed_dict=feed_dict)
                processed.extend(serialized_annotations)

            tf.logging.info('Processed %d documents in %.2f seconds.',
                            len(char_corpus),
                            time.time() - start_time)

        input_corpus = processed
    else:
        input_corpus = sentence_io.ConllSentenceReader(
            FLAGS.input_file).corpus()

    ## PARSING ##

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.parser_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all',
                 {'save/Const:0': FLAGS.parser_checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write('#' + sentence.text.encode('utf-8') + '\n')
                    for i, token in enumerate(sentence.token):
                        head = token.head + 1
                        f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' %
                                (i + 1, token.word.encode('utf-8'), head,
                                 token.label.encode('utf-8')))
                    f.write('\n\n')
Example #13
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Parse the flags containing lists, using regular expressions.
    # This matches and extracts key=value pairs.
    component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                      FLAGS.inference_beam_size)
    # This matches strings separated by a comma. Does not return any empty
    # strings.
    components_to_locally_normalize = re.findall(r'[^,]+',
                                                 FLAGS.locally_normalize)

    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.master_spec) as fin:
        text_format.Parse(fin.read(), master_spec)

    # Rewrite resource locations.
    if FLAGS.resource_dir:
        for component in master_spec.component:
            for resource in component.resource:
                for part in resource.part:
                    part.file_pattern = os.path.join(FLAGS.resource_dir,
                                                     part.file_pattern)

    if FLAGS.complete_master_spec:
        spec_builder.complete_master_spec(master_spec, None,
                                          FLAGS.resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

        tf.logging.info('Processing sentences...')

        processed = []
        start_time = time.time()
        run_metadata = tf.RunMetadata()
        for start in range(0, len(input_corpus), FLAGS.max_batch_size):
            end = min(start + FLAGS.max_batch_size, len(input_corpus))
            feed_dict = {annotator['input_batch']: input_corpus[start:end]}
            for comp, beam_size in component_beam_sizes:
                feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
            for comp in components_to_locally_normalize:
                feed_dict['%s/LocallyNormalize:0' % comp] = True
            if FLAGS.timeline_output_file and end == len(input_corpus):
                serialized_annotations = sess.run(
                    annotator['annotations'],
                    feed_dict=feed_dict,
                    options=tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE),
                    run_metadata=run_metadata)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                with open(FLAGS.timeline_output_file, 'w') as trace_file:
                    trace_file.write(trace.generate_chrome_trace_format())
            else:
                serialized_annotations = sess.run(annotator['annotations'],
                                                  feed_dict=feed_dict)
            processed.extend(serialized_annotations)

        tf.logging.info('Processed %d documents in %.2f seconds.',
                        len(input_corpus),
                        time.time() - start_time)
        pos, uas, las = evaluation.calculate_parse_metrics(
            input_corpus, processed)
        if FLAGS.log_file:
            with gfile.GFile(FLAGS.log_file, 'w') as f:
                f.write('%s\t%f\t%f\t%f\n' %
                        (FLAGS.language_name, pos, uas, las))

        if FLAGS.output_file:
            with gfile.GFile(FLAGS.output_file, 'w') as f:
                for serialized_sentence in processed:
                    sentence = sentence_pb2.Sentence()
                    sentence.ParseFromString(serialized_sentence)
                    f.write(text_format.MessageToString(sentence) + '\n\n')
Example #14
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)

    check.NotNone(FLAGS.model_dir, '--model_dir is required')
    check.Ne(
        FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
        'Exactly one of --pretrain_steps or --pretrain_epochs is required')
    check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
             'Exactly one of --train_steps or --train_epochs is required')

    config_path = os.path.join(FLAGS.model_dir, 'config.txt')
    master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
    hyperparameters_path = os.path.join(FLAGS.model_dir,
                                        'hyperparameters.pbtxt')
    targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
    checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
    tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')

    with tf.gfile.FastGFile(config_path) as config_file:
        config = collections.defaultdict(bool,
                                         ast.literal_eval(config_file.read()))
    train_corpus_path = config['train_corpus_path']
    tune_corpus_path = config['tune_corpus_path']
    projectivize_train_corpus = config['projectivize_train_corpus']

    master = _read_text_proto(master_path, spec_pb2.MasterSpec)
    hyperparameters = _read_text_proto(hyperparameters_path,
                                       spec_pb2.GridPoint)
    targets = spec_builder.default_targets_from_spec(master)
    if tf.gfile.Exists(targets_path):
        targets = _read_text_proto(targets_path,
                                   spec_pb2.TrainingGridSpec).target

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(hyperparameters.seed)
        builder = graph_builder.MasterBuilder(master, hyperparameters)
        trainers = [
            builder.add_training_from_config(target) for target in targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    train_corpus = sentence_io.ConllSentenceReader(
        train_corpus_path, projectivize=projectivize_train_corpus).corpus()
    tune_corpus = sentence_io.ConllSentenceReader(tune_corpus_path,
                                                  projectivize=False).corpus()
    gold_tune_corpus = tune_corpus

    # Convert to char-based corpora, if requested.
    if config['convert_to_char_corpora']:
        # NB: Do not convert the |gold_tune_corpus|, which should remain word-based
        # for segmentation evaluation purposes.
        train_corpus = _convert_to_char_corpus(train_corpus)
        tune_corpus = _convert_to_char_corpus(tune_corpus)

    pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                                len(train_corpus))
    train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                             len(train_corpus))
    check.Eq(len(targets), len(pretrain_steps),
             'Length mismatch between training targets and --pretrain_steps')
    check.Eq(len(targets), len(train_steps),
             'Length mismatch between training targets and --train_steps')

    # Ready to train!
    tf.logging.info('Training on %d sentences.', len(train_corpus))
    tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

    tf.logging.info('Creating TensorFlow checkpoint dir...')
    summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if tf.gfile.IsDirectory(checkpoint_dir):
        tf.gfile.DeleteRecursively(checkpoint_dir)
    elif tf.gfile.Exists(checkpoint_dir):
        tf.gfile.Remove(checkpoint_dir)
    tf.gfile.MakeDirs(checkpoint_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())
        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, train_corpus, tune_corpus,
                                 gold_tune_corpus, FLAGS.batch_size,
                                 summary_writer, FLAGS.report_every,
                                 builder.saver, checkpoint_path)

    tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)
Example #15
def main(unused_argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Read hyperparams and master spec.
    hyperparam_config = spec_pb2.GridPoint()
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    print hyperparam_config
    master_spec = spec_pb2.MasterSpec()

    with gfile.GFile(FLAGS.master_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)

    # Make output folder
    if not gfile.Exists(FLAGS.output_folder):
        gfile.MakeDirs(FLAGS.output_folder)

    # Construct TF Graph.
    graph = tf.Graph()

    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        # Construct default per-component targets.
        default_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if (component.transition_system.registered_name != 'shift-only')
        ]

        # Add default and manually specified targets.
        trainers = []
        for target in default_targets:
            trainers += [builder.add_training_from_config(target)]
        check.Eq(len(trainers), 1,
                 "Expected only one training target (FF unit)")

        # Construct annotation and saver. Will use moving average if enabled.
        annotator = builder.add_annotation()
        builder.add_saver()

        # Add backwards compatible training summary.
        summaries = []
        for component in builder.components:
            summaries += component.get_summaries()
        merged_summaries = tf.summary.merge_all()

        # Construct target to initialize variables.
        tf.group(tf.global_variables_initializer(), name='inits')

    # Prepare tensorboard dir.
    events_dir = os.path.join(FLAGS.output_folder, "tensorboard")
    empty_dir(events_dir)
    summary_writer = tf.summary.FileWriter(events_dir, graph)
    print "Wrote events (incl. graph) for Tensorboard to folder:", events_dir
    print "The graph can be viewed via"
    print "  tensorboard --logdir=" + events_dir
    print "  then navigating to http://localhost:6006 and clicking on 'GRAPHS'"

    with graph.as_default():
        tf.set_random_seed(hyperparam_config.seed)

    # Read train and dev corpora.
    print "Reading corpora..."
    train_corpus = read_corpus(FLAGS.train_corpus)
    dev_corpus = read_corpus(FLAGS.dev_corpus)

    # Prepare checkpoint folder.
    checkpoint_path = os.path.join(FLAGS.output_folder, 'checkpoints/best')
    checkpoint_dir = os.path.dirname(checkpoint_path)
    empty_dir(checkpoint_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        # Run training.
        trainer_lib.run_training(
            sess,
            trainers,
            annotator,
            evaluator,
            [0],  # pretrain_steps
            [FLAGS.train_steps],
            train_corpus,
            dev_corpus,
            dev_corpus,
            FLAGS.batch_size,
            summary_writer,
            FLAGS.report_every,
            builder.saver,
            checkpoint_path)

        # Convert model to a Myelin flow.
        if len(FLAGS.flow) != 0:
            tf.logging.info('Saving flow to %s', FLAGS.flow)
            flow = convert_model(master_spec, sess)
            flow.save(FLAGS.flow)

    tf.logging.info('Best checkpoint written to %s', checkpoint_path)
Example #16
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=FLAGS.threads,
                                    inter_op_parallelism_threads=FLAGS.threads)

    master_spec = spec_pb2.MasterSpec()
    master_spec_file = FLAGS.parser_dir + "/master_spec"
    with open(master_spec_file, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)

    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = True
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        annotator = builder.add_annotation()
        builder.add_saver()

    with tf.Session(graph=g, config=session_config) as sess:
        tf.logging.info('Initializing variables...')
        sess.run(tf.global_variables_initializer())

        tf.logging.info('Loading from checkpoint...')
        checkpoint = FLAGS.parser_dir + "/checkpoints/best"
        sess.run('save/restore_all', {'save/Const:0': checkpoint})

        # Annotate the corpus.
        corpus = read_corpus(FLAGS.corpus)
        annotated = []
        annotation_time = 0
        for start in range(0, len(corpus), FLAGS.batch_size):
            end = min(start + FLAGS.batch_size, len(corpus))
            feed_dict = {annotator['input_batch']: corpus[start:end]}
            start_time = timeit.default_timer()
            output = sess.run(annotator['annotations'], feed_dict=feed_dict)
            annotation_time += (timeit.default_timer() - start_time)
            annotated.extend(output)

        tf.logging.info("Wall clock time for %s annotation: %f seconds",
                        len(annotated), annotation_time)

    output_file = FLAGS.output
    if FLAGS.evaluate and len(FLAGS.output) == 0:
        output_file = "/tmp/annotated.zip"
        tf.logging.info(
            '--output not provided, will write annotated docs to %s',
            output_file)

    if FLAGS.evaluate or len(FLAGS.output) != 0:
        # Write the annotated corpus to disk as a zip file.
        with zipfile.ZipFile(output_file, 'w') as outfile:
            for i in xrange(len(annotated)):
                outfile.writestr('test.' + str(i), annotated[i])
            tf.logging.info('Wrote %d annotated docs to %s', len(annotated),
                            output_file)

    if FLAGS.evaluate:
        # Evaluate against gold annotations.
        try:
            eval_output_lines = subprocess.check_output(
                [
                    'bazel-bin/nlp/parser/tools/evaluate-frames',
                    '--gold_documents=' + FLAGS.corpus, '--test_documents=' +
                    output_file, '--commons=' + FLAGS.commons
                ],
                stderr=subprocess.STDOUT)

            eval_output = {}
            eval_metric = -1
            for line in eval_output_lines.splitlines():
                line = line.rstrip()
                tf.logging.info("Evaluation Metric: %s", line)
                parts = line.split('\t')
                assert len(parts) == 2, line
                eval_output[parts[0]] = float(parts[1])
                if line.startswith("SLOT_F1"):
                    eval_metric = float(parts[1])
            assert eval_metric != -1, "Missing SLOT F1"
            tf.logging.info('Overall Evaluation Metric: %f', eval_metric)
        except subprocess.CalledProcessError as e:
            print("Evaluation failed: ", e.returncode, e.output)
Example #17
def main(unused_argv):
    logging.set_verbosity(logging.INFO)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path, FLAGS.training_corpus_path)

    # Construct the "lookahead" ComponentSpec. This is a simple right-to-left RNN
    # sequence model, which encodes the context to the right of each token. It has
    # no loss except for the downstream components.
    lookahead = spec_builder.ComponentSpecBuilder('lookahead')
    lookahead.set_network_unit(name='wrapped_units.LayerNormBasicLSTMNetwork',
                               hidden_layer_sizes='256')
    lookahead.set_transition_system(name='shift-only', left_to_right='false')
    lookahead.add_fixed_feature(name='char',
                                fml='input(-1).char input.char input(1).char',
                                embedding_dim=32)
    lookahead.add_fixed_feature(name='char-bigram',
                                fml='input.char-bigram',
                                embedding_dim=32)
    lookahead.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    # Construct the ComponentSpec for segmentation.
    segmenter = spec_builder.ComponentSpecBuilder('segmenter')
    segmenter.set_network_unit(name='wrapped_units.LayerNormBasicLSTMNetwork',
                               hidden_layer_sizes='128')
    segmenter.set_transition_system(name='binary-segment-transitions')
    segmenter.add_token_link(source=lookahead,
                             fml='input.focus stack.focus',
                             embedding_dim=64)
    segmenter.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

    # Build and write master_spec.
    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([lookahead.spec, segmenter.spec])
    logging.info('Constructed master spec: %s', str(master_spec))
    with gfile.GFile(FLAGS.resource_path + '/master_spec', 'w') as f:
        f.write(str(master_spec).encode('utf-8'))

    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = spec_builder.default_targets_from_spec(master_spec)
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        assert len(trainers) == 1
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(FLAGS.training_corpus_path,
                                       projectivize=False).corpus()
    dev_set = ConllSentenceReader(FLAGS.dev_corpus_path,
                                  projectivize=False).corpus()

    # Convert word-based docs to char-based documents for segmentation training
    # and evaluation.
    with tf.Session(graph=tf.Graph()) as tmp_session:
        char_training_set_op = gen_parser_ops.segmenter_training_data_constructor(
            training_set)
        char_dev_set_op = gen_parser_ops.char_token_generator(dev_set)
        char_training_set = tmp_session.run(char_training_set_op)
        char_dev_set = tmp_session.run(char_dev_set_op)

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(dev_set))

    pretrain_steps = [0]
    train_steps = [FLAGS.num_epochs * len(training_set)]

    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())
        trainer_lib.run_training(
            sess, trainers, annotator, evaluation.segmentation_summaries,
            pretrain_steps, train_steps, char_training_set, char_dev_set,
            dev_set, FLAGS.batch_size, summary_writer, FLAGS.report_every,
            builder.saver, FLAGS.checkpoint_filename)
Example #18
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instantiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path,
                              training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)
    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(tune_corpus_path,
                                   projectivize=False,
                                   morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, tune_set, tune_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename, stats)
def export_to_graph(master_spec,
                    params_path,
                    export_path,
                    external_graph,
                    export_moving_averages,
                    signature_name='model'):
    """Restores a model and exports it in SavedModel form.

  This method loads a graph specified by the master_spec and the params in
  params_path into the graph given in external_graph. It then saves the model
  in SavedModel format to the location specified in export_path.

  Args:
    master_spec: Proto master spec.
    params_path: Path to the parameters file to export.
    export_path: Path to export the SavedModel to.
    external_graph: A tf.Graph() object to build the graph inside.
    export_moving_averages: Whether to export the moving average parameters.
    signature_name: Name of the signature to insert.
  """
    tf.logging.info(
        'Exporting graph with signature_name "%s" and use_moving_averages = %s'
        % (signature_name, export_moving_averages))

    tf.logging.info('Building the graph')
    with external_graph.as_default(), tf.device('/device:CPU:0'):
        hyperparam_config = spec_pb2.GridPoint()
        hyperparam_config.use_moving_average = export_moving_averages
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        post_restore_hook = builder.build_post_restore_hook()
        annotation = builder.add_annotation()
        builder.add_saver()

    # Resets session.
    session_config = tf.ConfigProto(log_device_placement=False,
                                    intra_op_parallelism_threads=10,
                                    inter_op_parallelism_threads=10)

    with tf.Session(graph=external_graph, config=session_config) as session:
        tf.logging.info('Initializing variables...')
        session.run(tf.global_variables_initializer())

        tf.logging.info('Loading params...')
        session.run('save/restore_all', {'save/Const:0': params_path})

        tf.logging.info('Saving.')

        with tf.device('/device:CPU:0'):
            saved_model_builder = tf.saved_model.builder.SavedModelBuilder(
                export_path)

            signature_map = {
                signature_name:
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        'inputs':
                        tf.saved_model.utils.build_tensor_info(
                            annotation['input_batch'])
                    },
                    outputs={
                        'annotations':
                        tf.saved_model.utils.build_tensor_info(
                            annotation['annotations'])
                    },
                    method_name=tf.saved_model.signature_constants.
                    PREDICT_METHOD_NAME),
            }

            tf.logging.info('Input is: %s', annotation['input_batch'].name)
            tf.logging.info('Output is: %s', annotation['annotations'].name)

            saved_model_builder.add_meta_graph_and_variables(
                session,
                tags=_SAVED_MODEL_TAGS,
                legacy_init_op=tf.group(
                    post_restore_hook,
                    builder.build_warmup_graph(
                        tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)[0])),
                signature_def_map=signature_map,
                assets_collection=tf.get_collection(
                    tf.GraphKeys.ASSET_FILEPATHS))

            saved_model_builder.save()
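
For completeness, a hedged sketch of reading the exported SavedModel back for inference (assumes the same _SAVED_MODEL_TAGS and signature_name used at export time, with serialized Sentence protos as input):

def load_saved_model_and_annotate(export_path, serialized_sentences,
                                  signature_name='model'):
    """Sketch: run annotations from a SavedModel written by export_to_graph."""
    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph = tf.saved_model.loader.load(sess, _SAVED_MODEL_TAGS,
                                                export_path)
        signature = meta_graph.signature_def[signature_name]
        return sess.run(
            signature.outputs['annotations'].name,
            feed_dict={signature.inputs['inputs'].name: serialized_sentences})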