Example #1
def load_model(dragnn_spec,
               resource_path,
               checkpoint_filename,
               enable_tracing=False,
               tf_master=''):
    logging.set_verbosity(logging.WARN)
    # Validate required arguments.
    check.IsTrue(dragnn_spec)
    check.IsTrue(resource_path)
    check.IsTrue(checkpoint_filename)
    # Load master spec
    master_spec = load_master_spec(dragnn_spec, resource_path)
    # Build graph
    graph, builder, annotator = build_inference_graph(
        master_spec, enable_tracing=enable_tracing)
    with graph.as_default():
        # Restore model
        sess = tf.Session(target=tf_master, graph=graph)
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())
        builder.saver.restore(sess, checkpoint_filename)
    m = {}
    m['session'] = sess
    m['graph'] = graph
    m['builder'] = builder
    m['annotator'] = annotator
    return m
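
The returned dict bundles the live session, graph, builder, and annotator for inference. A minimal usage sketch (the paths below are illustrative placeholders, not part of the snippet):

m = load_model('parser_spec.textproto',   # hypothetical spec path
               '/tmp/resources',          # hypothetical lexicon dir
               '/tmp/model/checkpoint')   # hypothetical checkpoint
sess = m['session']   # tf.Session with variables already restored
# ... feed sentences through m['annotator'] via sess.run(...), then:
sess.close()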
Example #2
def main(unused_argv):
    if len(sys.argv) == 1:
        flags._global_parser.print_help()
        sys.exit(0)

    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.training_corpus_path)
    check.IsTrue(FLAGS.tune_corpus_path)
    check.IsTrue(FLAGS.resource_path)
    check.IsTrue(FLAGS.checkpoint_filename)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instantiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path, training_corpus_path, morph_to_pos=True)

    # Load master spec
    master_spec = model.load_master_spec(FLAGS.dragnn_spec, FLAGS.resource_path)
    # Build graph
    graph, builder, trainers, annotator = model.build_train_graph(master_spec)
    # Train
    train(graph, builder, trainers, annotator, summary_writer, do_restore, stats)
Example #3
  def fill_from_resources(self, resource_path, tf_master=''):
    """Fills in feature sizes and vocabularies using SyntaxNet lexicon.

    Must be called before the spec is ready to be used to build TensorFlow
    graphs. Requires a SyntaxNet lexicon built at the resource_path. Using the
    lexicon, this will call the SyntaxNet custom ops to return the number of
    features and vocabulary sizes based on the FML specifications and the
    lexicons. It will also compute the number of actions of the transition
    system.

    This will often CHECK-fail if the spec doesn't correspond to a valid
    transition system or feature setup.

    Args:
      resource_path: Path to the lexicon.
      tf_master: TensorFlow master executor (string, defaults to '' to use the
        local instance).
    """
    check.IsTrue(
        self.spec.transition_system.registered_name,
        'Set a transition system before calling fill_from_resources().')

    context = lexicon.create_lexicon_context(resource_path)
    for key, value in self.spec.transition_system.parameters.items():
      context.parameter.add(name=key, value=value)

    context.parameter.add(
        name='brain_parser_embedding_dims',
        value=';'.join(
            [str(x.embedding_dim) for x in self.spec.fixed_feature]))
    context.parameter.add(
        name='brain_parser_features',
        value=';'.join([x.fml for x in self.spec.fixed_feature]))
    context.parameter.add(
        name='brain_parser_predicate_maps',
        value=';'.join(['' for x in self.spec.fixed_feature]))
    context.parameter.add(
        name='brain_parser_embedding_names',
        value=';'.join([x.name for x in self.spec.fixed_feature]))
    context.parameter.add(
        name='brain_parser_transition_system',
        value=self.spec.transition_system.registered_name)

    # Propagate information from SyntaxNet C++ backends into the DRAGNN
    # self.spec.
    with tf.Session(tf_master) as sess:
      feature_sizes, domain_sizes, _, num_actions = sess.run(
          gen_parser_ops.feature_size(task_context_str=str(context)))
      self.spec.num_actions = int(num_actions)
      for i in range(len(feature_sizes)):
        self.spec.fixed_feature[i].size = int(feature_sizes[i])
        self.spec.fixed_feature[i].vocabulary_size = int(domain_sizes[i])

    for i in range(len(self.spec.linked_feature)):
      self.spec.linked_feature[i].size = len(
          self.spec.linked_feature[i].fml.split(' '))

    for resource in context.input:
      self.spec.resource.add(name=resource.name).part.add(
          file_pattern=resource.part[0].file_pattern)
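
As the docstring notes, a transition system must be set before this call. A minimal sketch using DRAGNN's ComponentSpecBuilder wrapper (the builder class and the lexicon path are assumptions, not shown in the snippet):

# Hypothetical sketch; names and paths are illustrative.
tagger = spec_builder.ComponentSpecBuilder('tagger')
tagger.set_transition_system(name='tagger')    # must precede the next call
tagger.fill_from_resources('/tmp/resources')   # reads the SyntaxNet lexicon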
Example #4
def _validate_embedded_fixed_features(comp):
  """Checks that the embedded fixed features of |comp| are set up properly."""
  for feature in comp.spec.fixed_feature:
    check.Gt(feature.embedding_dim, 0,
             'Embeddings requested for non-embedded feature: %s' % feature)
    if feature.is_constant:
      check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
                   'Constant embeddings must be pretrained: %s' % feature)
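
For reference, a fixed_feature channel that passes both checks needs a positive embedding_dim and, when is_constant is set, a pretrained matrix. A hypothetical textproto sketch using only the field names exercised above:

# fixed_feature {
#   name: 'words'
#   embedding_dim: 64                      # satisfies check.Gt(..., 0)
#   is_constant: true
#   pretrained_embedding_matrix { ... }    # required when is_constant
# }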
Example #5
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instantiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path,
                              training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)
    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(tune_corpus_path,
                                   projectivize=False,
                                   morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, tune_set, tune_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename, stats)
Example #6
 def testCheckIsTrue(self):
     check.IsTrue(1 == 1.0, 'foo')
     check.IsTrue(True, 'foo')
     check.IsTrue([0], 'foo')
     check.IsTrue({'x': 1}, 'foo')
     check.IsTrue(not 0, 'foo')
     check.IsTrue(not None, 'foo')
     with self.assertRaisesRegexp(ValueError, 'bar'):
         check.IsTrue(False, 'bar')
     with self.assertRaisesRegexp(ValueError, 'bar'):
         check.IsTrue(None, 'bar')
     with self.assertRaisesRegexp(ValueError, 'bar'):
         check.IsTrue(0, 'bar')
     with self.assertRaisesRegexp(ValueError, 'bar'):
         check.IsTrue([], 'bar')
     with self.assertRaisesRegexp(ValueError, 'bar'):
         check.IsTrue({}, 'bar')
     with self.assertRaisesRegexp(RuntimeError, 'baz'):
         check.IsTrue('', 'baz', RuntimeError)
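
The test pins down the contract of check.IsTrue: any falsy value raises, the message becomes the exception text, and an optional third argument overrides the exception type. A minimal sketch consistent with this behavior (the real dragnn check module may differ in details):

def IsTrue(expression, message='', errtype=ValueError):
    """Raises errtype(message) unless expression is truthy."""
    if not expression:
        raise errtype(message)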