Beispiel #1
0
def build_train_graph(master_spec, hyperparam_config=None):
    """Assemble the TensorFlow training graph for a DRAGNN master spec.

    Args:
      master_spec: MasterSpec whose components define the network.
      hyperparam_config: GridPoint with the training hyperparameters. When it
        is falsy, a built-in Adam configuration is used instead.

    Returns:
      A tuple ``(graph, builder, trainers, annotator)``: the populated
      tf.Graph, the MasterBuilder that populated it, the list of results of
      ``add_training_from_config`` (one per component whose transition system
      name does not contain 'shift-only'), and the result of
      ``add_annotation`` with tracing enabled.
    """
    tf.logging.info('Building Graph...')
    if not hyperparam_config:
        fallback_hyperparams = dict(
            learning_method='adam',
            learning_rate=0.0005,
            adam_beta1=0.9,
            adam_beta2=0.9,
            adam_eps=0.00001,
            decay_steps=128000,
            dropout_rate=0.8,
            gradient_clip_norm=1,
            use_moving_average=True,
            seed=1,
        )
        hyperparam_config = spec_pb2.GridPoint(**fallback_hyperparams)

    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        # One target per component: unroll every earlier component without the
        # oracle and only the component itself with it.
        component_targets = []
        for position, component in enumerate(master_spec.component):
            system_name = component.transition_system.registered_name
            if 'shift-only' in system_name:
                continue
            component_targets.append(
                spec_pb2.TrainTarget(
                    name=component.name,
                    max_index=position + 1,
                    unroll_using_oracle=[False] * position + [True]))

        trainers = []
        for component_target in component_targets:
            trainers.append(
                builder.add_training_from_config(component_target))

        annotator = builder.add_annotation(enable_tracing=True)
        builder.add_saver()

    return graph, builder, trainers, annotator
  def RunTraining(self, hyperparam_config):
    """Builds a training graph for the link spec and runs one training step.

    Args:
      hyperparam_config: spec_pb2.GridPoint handed to the MasterBuilder.
    """
    master_spec = self.LoadSpec('master_spec_link.textproto')

    self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))

    # The input batch is the two dummy gold sentences, serialized.
    reader_strings = []
    for gold_text in (_DUMMY_GOLD_SENTENCE, _DUMMY_GOLD_SENTENCE_2):
      parsed_doc = sentence_pb2.Sentence()
      text_format.Parse(gold_text, parsed_doc)
      reader_strings.append(parsed_doc.SerializeToString())

    tf.logging.info('Generating graph with config: %s', hyperparam_config)
    with tf.Graph().as_default():
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

      target = spec_pb2.TrainTarget(name='testTraining-all')
      train = builder.add_training_from_config(target)
      with self.test_session() as sess:
        logging.info('Initializing')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # A single training iteration is enough to verify nothing crashes.
        logging.info('Training')
        feed = {train['input_batch']: reader_strings}
        sess.run(train['run'], feed_dict=feed)
def default_targets_from_spec(spec):
  """Builds the default list of TrainTarget protos for a DRAGNN spec.

  Every component whose transition system name does not end in 'shift-only'
  gets its own target. The target for the component at position ``idx`` sets
  ``max_index`` to ``idx + 1`` and marks only that component (the last entry
  of ``unroll_using_oracle``) as unrolled with the oracle. For example, with
  components 'shift-only', 'tagger' and 'arc-standard', two targets are
  produced: one for the tagger and one for the arc-standard parser.

  Arguments:
    spec: DRAGNN spec.

  Returns:
    List of TrainTarget protos.
  """
  targets = []
  for position, component in enumerate(spec.component):
    system_name = component.transition_system.registered_name
    if system_name.endswith('shift-only'):
      # Skipped: per the original documentation these have no oracle.
      continue
    oracle_mask = [False] * position + [True]
    targets.append(
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=position + 1,
            unroll_using_oracle=oracle_mask))
  return targets
 def getBuilderAndTarget(
     self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
   """Creates a MasterBuilder and a TrainTarget from a spec file.

   The target gives weight 1.0 to, and unrolls with the oracle, only the last
   component of the spec; every other component gets weight 0 and no oracle.

   Args:
     test_name: Used in the target name and as the builder's pool scope.
     master_spec_path: Path passed to ``self.LoadSpec``.

   Returns:
     A ``(builder, target)`` tuple.
   """
   master_spec = self.LoadSpec(master_spec_path)
   num_components = len(master_spec.component)
   hyperparam_config = spec_pb2.GridPoint()

   target = spec_pb2.TrainTarget()
   target.name = 'test-%s-train' % test_name

   # Start from all-zero / all-False, then switch on the final component.
   weights = target.component_weights
   weights.extend([0] * num_components)
   weights[-1] = 1.0

   oracle_flags = target.unroll_using_oracle
   oracle_flags.extend([False] * num_components)
   oracle_flags[-1] = True

   builder = graph_builder.MasterBuilder(
       master_spec, hyperparam_config, pool_scope=test_name)
   return builder, target
def main(argv):
    """Build a single-component tagger, trace one oracle run, and save HTML.

    Reads the module-level ``lexicon_dir`` and ``training_sentence`` paths,
    builds the lexicon, constructs a one-component MasterSpec, runs a
    trace-only training pass over the single training sentence and writes the
    first trace as HTML to ``dragnn_tutorial_1.html``.

    Args:
      argv: Unused command-line arguments.
    """
    del argv  # unused
    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    lexicon.build_lexicon(lexicon_dir,
                          training_sentence,
                          training_corpus_format='sentence-prototext')

    # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
    # sequence tagger.
    tagger = spec_builder.ComponentSpecBuilder('tagger')
    tagger.set_network_unit(name='FeedForwardNetwork',
                            hidden_layer_sizes='256')
    tagger.set_transition_system(name='tagger')
    tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
    tagger.add_rnn_link(embedding_dim=-1)
    tagger.fill_from_resources(lexicon_dir)

    master_spec = spec_pb2.MasterSpec()
    master_spec.component.extend([tagger.spec])

    hyperparam_config = spec_pb2.GridPoint()

    # Build the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        target = spec_pb2.TrainTarget()
        target.name = 'all'
        target.unroll_using_oracle.extend([True])
        dry_run = builder.add_training_from_config(target, trace_only=True)

    # Read in serialized protos from training data. The file is closed as soon
    # as its text has been read.
    sentence = sentence_pb2.Sentence()
    with open(training_sentence) as sentence_file:
        text_format.Merge(sentence_file.read(), sentence)
    training_set = [sentence.SerializeToString()]

    with tf.Session(graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        # global_variables_initializer replaces the deprecated
        # initialize_all_variables.
        sess.run(tf.global_variables_initializer())
        traces = sess.run(dry_run['traces'],
                          feed_dict={dry_run['input_batch']: training_set})

    # The HTML is encoded to UTF-8 bytes, so the file is opened in binary mode.
    with open('dragnn_tutorial_1.html', 'wb') as f:
        f.write(
            visualization.trace_html(traces[0],
                                     height='300px').encode('utf-8'))
Beispiel #6
0
def main(unused_argv):
    """Train a DRAGNN model as configured by the command-line flags.

    The function reads ``<checkpoint_filename>.stats`` (a comma-separated line
    of three integers: job id, tagger steps, parser steps). If the recorded
    job id is older than ``FLAGS.job_id`` the previous TensorBoard directory
    and checkpoint are deleted and training starts from scratch; otherwise
    the checkpoint is restored and the pretraining schedule is shortened by
    the recorded tagger steps.

    Args:
      unused_argv: Unused command-line arguments.
    """
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    # Only the first match of each glob pattern is used.
    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        # Use a context manager so the stats file handle is always closed.
        with gfile.GFile(stats_file, 'r') as stats_in:
            stats = stats_in.readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        # No readable stats file: use a job id older than any real one.
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            # Best effort: failing to delete old files must not abort training.
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instatiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path,
                              training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        # Fall back to treating the flag as base64-encoded text format.
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)
    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        # One target per component whose transition system name does not
        # contain 'shift-only'; only the component itself uses the oracle.
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(tune_corpus_path,
                                   projectivize=False,
                                   morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            # Only the tagger pretraining budget is reduced, never below zero.
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, tune_set, tune_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename, stats)
  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    """Trains on dummy gold data until convergence, then checks inference.

    Args:
      test_name: Used for the target name, the builder pool scope and the
        annotation names.
      master_spec_path: Path given to ``self.LoadSpec`` when ``master_spec``
        is not supplied.
      master_spec: Optional spec to use directly instead of loading one.
      component_weights: Optional per-component weights for the train target.
        Defaults to 0 everywhere except 1.0 for the last component.
      unroll_using_oracle: Optional per-component oracle flags. Defaults to
        False everywhere except True for the last component.
      num_evaluated_components: Used to derive the default expected number of
        actions (4 per evaluated component).
      expected_num_actions: Expected total (and correct) action count after
        training; derived from ``num_evaluated_components`` when falsy.
      expected: Expected sentences, indexed by 0 and 1. Defaults to
        ``_TAGGER_EXPECTED_SENTENCES``.
      batch_size_limit: If not None, both the gold and the test batches are
        truncated to this many entries.
    """
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    # The order (doc, doc, doc_2, doc) mirrors the [0, 0, 1, 0] indexing of the
    # expected sentences at the end of this method.
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()

      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verifies that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        # Needs ~100 iterations, but is not deterministic.
        max_iterations = 400
        break_iter = max_iterations
        is_resolved = False
        for i in range(max_iterations):
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
          # View the metrics as (total, correct) pairs, one pair per row.
          metric_pairs = eval_res_val.reshape((-1, 2))
          total_val = metric_pairs[:, 0].sum()
          correct_val = metric_pairs[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50

          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          for c, component_spec in enumerate(master_spec.component):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          component_spec.name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        # Use the unittest assertion rather than a bare assert so the check is
        # not stripped under -O and matches the neighbouring assertions.
        self.assertEqual(
            correct_val, total_val,
            'Did not converge! %d vs %d.' % (correct_val, total_val))

        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The result of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentences.append(sentence_pb2.Sentence())
          pred_sentences[-1].ParseFromString(annotation)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)
Beispiel #8
0
def main(unused_argv):
    """Train a single-target DRAGNN model and optionally export a flow.

    Reads the hyperparameters and master spec from flags, builds the training
    graph, writes the graph to a TensorBoard events directory under
    ``FLAGS.output_folder``, trains with ``trainer_lib.run_training`` and, when
    ``FLAGS.flow`` is set, saves the converted model to that path.

    Args:
      unused_argv: Unused command-line arguments.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    # Read hyperparams and master spec.
    hyperparam_config = spec_pb2.GridPoint()
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    # Single-argument print calls behave the same on Python 2 and Python 3.
    print(hyperparam_config)
    master_spec = spec_pb2.MasterSpec()

    with gfile.GFile(FLAGS.master_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)

    # Make output folder
    if not gfile.Exists(FLAGS.output_folder):
        gfile.MakeDirs(FLAGS.output_folder)

    # Construct TF Graph.
    graph = tf.Graph()

    with graph.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

        # Construct default per-component targets.
        default_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if (component.transition_system.registered_name != 'shift-only')
        ]

        # Add default and manually specified targets.
        trainers = [
            builder.add_training_from_config(target)
            for target in default_targets
        ]
        check.Eq(len(trainers), 1,
                 "Expected only one training target (FF unit)")

        # Construct annotation and saves. Will use moving average if enabled.
        annotator = builder.add_annotation()
        builder.add_saver()

        # Add backwards compatible training summary. The return values are
        # not needed here; the calls are kept for their effect on the graph.
        for component in builder.components:
            component.get_summaries()
        tf.summary.merge_all()

        # Construct target to initialize variables.
        tf.group(tf.global_variables_initializer(), name='inits')

    # Prepare tensorboard dir.
    events_dir = os.path.join(FLAGS.output_folder, "tensorboard")
    empty_dir(events_dir)
    summary_writer = tf.summary.FileWriter(events_dir, graph)
    print("Wrote events (incl. graph) for Tensorboard to folder: %s" %
          events_dir)
    print("The graph can be viewed via")
    print("  tensorboard --logdir=" + events_dir)
    print("  then navigating to http://localhost:6006 and clicking on 'GRAPHS'")

    # NOTE(review): the seed is set after the graph has been built, so it
    # presumably does not affect ops created above. Confirm this is intended.
    with graph.as_default():
        tf.set_random_seed(hyperparam_config.seed)

    # Read train and dev corpora.
    print("Reading corpora...")
    train_corpus = read_corpus(FLAGS.train_corpus)
    dev_corpus = read_corpus(FLAGS.dev_corpus)

    # Prepare checkpoint folder.
    checkpoint_path = os.path.join(FLAGS.output_folder, 'checkpoints/best')
    checkpoint_dir = os.path.dirname(checkpoint_path)
    empty_dir(checkpoint_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        # Run training.
        trainer_lib.run_training(
            sess,
            trainers,
            annotator,
            evaluator,
            [0],  # pretrain_steps
            [FLAGS.train_steps],
            train_corpus,
            dev_corpus,
            dev_corpus,
            FLAGS.batch_size,
            summary_writer,
            FLAGS.report_every,
            builder.saver,
            checkpoint_path)

        # Convert model to a Myelin flow. The truthiness test also covers an
        # unset (None) flag, which len() would reject.
        if FLAGS.flow:
            tf.logging.info('Saving flow to %s', FLAGS.flow)
            flow = convert_model(master_spec, sess)
            flow.save(FLAGS.flow)

    tf.logging.info('Best checkpoint written to %s', checkpoint_path)
Beispiel #9
0
def main(argv):
  """Build a tagger + parser spec, trace one oracle run, and save HTML.

  Reads the module-level ``lexicon_dir`` and ``training_sentence`` paths,
  builds the lexicon, constructs a two-component MasterSpec (tagger feeding an
  arc-standard parser), runs a trace-only training pass over the single
  training sentence and writes the first trace as HTML to
  ``dragnn_tutorial_2.html``.

  Args:
    argv: Unused command-line arguments.
  """
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step to either SHIFT or REDUCE that
  # token. This allows the parser to build up compositional representations of
  # phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])

  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data. The file is closed as soon as
  # its text has been read.
  sentence = sentence_pb2.Sentence()
  with open(training_sentence) as sentence_file:
    text_format.Merge(sentence_file.read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    # global_variables_initializer replaces the deprecated
    # initialize_all_variables.
    sess.run(tf.global_variables_initializer())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  # The HTML is encoded to UTF-8 bytes, so the file is opened in binary mode.
  with open('dragnn_tutorial_2.html', 'wb') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px', master_spec=master_spec).encode('utf-8'))