def build_train_graph(master_spec, hyperparam_config=None):
  """Builds the TensorFlow graph based on the DRAGNN network spec."""
  tf.logging.info('Building Graph...')
  if not hyperparam_config:
    hyperparam_config = spec_pb2.GridPoint(
        learning_method='adam',
        learning_rate=0.0005,
        adam_beta1=0.9,
        adam_beta2=0.9,
        adam_eps=0.00001,
        decay_steps=128000,
        dropout_rate=0.8,
        gradient_clip_norm=1,
        use_moving_average=True,
        seed=1)
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if 'shift-only' not in component.transition_system.registered_name
    ]
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    annotator = builder.add_annotation(enable_tracing=True)
    builder.add_saver()
  return graph, builder, trainers, annotator
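

# A minimal usage sketch (an assumption, not part of the original module):
# runs a single training step for the first trainer returned by
# build_train_graph(). It relies on each trainer dict exposing 'run' and
# 'input_batch' keys, as the tests elsewhere in this file do.
def example_train_one_step(master_spec, serialized_sentences):
  graph, builder, trainers, annotator = build_train_graph(master_spec)
  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(trainers[0]['run'],
             feed_dict={trainers[0]['input_batch']: serialized_sentences})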
def RunTraining(self, hyperparam_config):
  master_spec = self.LoadSpec('master_spec_link.textproto')

  self.assertTrue(isinstance(hyperparam_config, spec_pb2.GridPoint))
  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]
  tf.logging.info('Generating graph with config: %s', hyperparam_config)
  with tf.Graph().as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    target = spec_pb2.TrainTarget()
    target.name = 'testTraining-all'
    train = builder.add_training_from_config(target)
    with self.test_session() as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      # Run one iteration of training and verify nothing crashes.
      logging.info('Training')
      sess.run(train['run'],
               feed_dict={train['input_batch']: reader_strings})
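

# A hypothetical companion test (not part of the original suite) showing one
# way to invoke RunTraining with an explicit hyperparameter configuration;
# the GridPoint field values are illustrative only.
def testTrainingWithExplicitGridPoint(self):
  self.RunTraining(
      spec_pb2.GridPoint(
          learning_method='adam', learning_rate=0.0005, seed=1))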
def default_targets_from_spec(spec):
  """Constructs a default set of TrainTarget protos from a DRAGNN spec.

  For each component in the DRAGNN spec, it adds a training target for that
  component's oracle and stops unrolling the graph with that component. It
  skips any 'shift-only' transition systems, which have no oracle. E.g., if
  there are three components, a 'shift-only', a 'tagger', and an
  'arc-standard', it will construct two training targets, one for the tagger
  and one for the arc-standard parser.

  Arguments:
    spec: DRAGNN spec.

  Returns:
    List of TrainTarget protos.
  """
  component_targets = [
      spec_pb2.TrainTarget(
          name=component.name,
          max_index=idx + 1,
          unroll_using_oracle=[False] * idx + [True])
      for idx, component in enumerate(spec.component)
      if not component.transition_system.registered_name.endswith('shift-only')
  ]
  return component_targets
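

# A sketch working through the docstring's example (the spec and component
# names here are hypothetical, built inline for illustration): a 'shift-only'
# component followed by a tagger and an arc-standard parser yields exactly
# two targets.
def example_default_targets():
  spec = spec_pb2.MasterSpec()
  for name, system in [('lookahead', 'shift-only'), ('tagger', 'tagger'),
                       ('parser', 'arc-standard')]:
    component = spec.component.add()
    component.name = name
    component.transition_system.registered_name = system
  targets = default_targets_from_spec(spec)
  # targets[0]: name='tagger', max_index=2, unroll_using_oracle=[False, True]
  # targets[1]: name='parser', max_index=3,
  #             unroll_using_oracle=[False, False, True]
  return targets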
def getBuilderAndTarget(
    self, test_name, master_spec_path='simple_parser_master_spec.textproto'):
  """Generates a MasterBuilder and TrainTarget based on a simple spec."""
  master_spec = self.LoadSpec(master_spec_path)
  hyperparam_config = spec_pb2.GridPoint()
  target = spec_pb2.TrainTarget()
  target.name = 'test-%s-train' % test_name
  target.component_weights.extend([0] * len(master_spec.component))
  target.component_weights[-1] = 1.0
  target.unroll_using_oracle.extend([False] * len(master_spec.component))
  target.unroll_using_oracle[-1] = True
  builder = graph_builder.MasterBuilder(
      master_spec, hyperparam_config, pool_scope=test_name)
  return builder, target
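

# A hypothetical usage sketch (not in the original suite): the helper is
# meant to be called inside a graph context, after which the target can be
# turned into a training op, mirroring the full-training test below.
def exampleBuilderAndTargetUsage(self):
  with tf.Graph().as_default():
    builder, target = self.getBuilderAndTarget('usage_example')
    train = builder.add_training_from_config(target)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      # train['run'] can now be fed serialized Sentence protos through
      # train['input_batch'], as RunTraining does above.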
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  lexicon.build_lexicon(
      lexicon_dir, training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    traces = sess.run(dry_run['traces'],
                      feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_1.html', 'w') as f:
    f.write(
        visualization.trace_html(traces[0], height='300px').encode('utf-8'))
def main(unused_argv):
  logging.set_verbosity(logging.INFO)

  check.IsTrue(FLAGS.checkpoint_filename)
  check.IsTrue(FLAGS.tensorboard_dir)
  check.IsTrue(FLAGS.resource_path)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
  tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

  # SummaryWriter for TensorBoard.
  tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
  tf.logging.info('Deleting prior data if it exists...')

  stats_file = '%s.stats' % FLAGS.checkpoint_filename
  try:
    stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
    stats = [int(x) for x in stats]
  except errors.OpError:
    stats = [-1, 0, 0]
  tf.logging.info('Read ckpt stats: %s', str(stats))

  do_restore = True
  if stats[0] < FLAGS.job_id:
    do_restore = False
    tf.logging.info('Deleting last job: %d', stats[0])
    try:
      gfile.DeleteRecursively(FLAGS.tensorboard_dir)
      gfile.Remove(FLAGS.checkpoint_filename)
    except errors.OpError as err:
      tf.logging.error('Unable to delete prior files: %s', err)
    stats = [FLAGS.job_id, 0, 0]

  tf.logging.info('Creating the directory again...')
  gfile.MakeDirs(FLAGS.tensorboard_dir)
  tf.logging.info('Created! Instantiating SummaryWriter...')
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(
        FLAGS.resource_path, training_corpus_path, morph_to_pos=True)

  tf.logging.info('Loading MasterSpec...')
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)
  spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
  logging.info('Constructed master spec: %s', str(master_spec))

  # Build the TensorFlow graph.
  tf.logging.info('Building Graph...')
  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)

  g = tf.Graph()
  with g.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if 'shift-only' not in component.transition_system.registered_name
    ]
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = ConllSentenceReader(
      training_corpus_path,
      projectivize=FLAGS.projectivize_training_set,
      morph_to_pos=True).corpus()
  tune_set = ConllSentenceReader(
      tune_corpus_path, projectivize=False, morph_to_pos=True).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(tune_set))

  pretrain_steps = [10000, 0]
  tagger_steps = 100000
  train_steps = [tagger_steps, 8 * tagger_steps]

  with tf.Session(FLAGS.tf_master, graph=g) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    if do_restore:
      tf.logging.info('Restoring from checkpoint...')
      builder.saver.restore(sess, FLAGS.checkpoint_filename)

      prev_tagger_steps = stats[1]
      prev_parser_steps = stats[2]
      tf.logging.info('Adjusting schedule from steps: %d, %d',
                      prev_tagger_steps, prev_parser_steps)
      pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
      tf.logging.info('New pretrain steps: %d', pretrain_steps[0])

    trainer_lib.run_training(
        sess, trainers, annotator, evaluation.parser_summaries,
        pretrain_steps, train_steps, training_set, tune_set, tune_set,
        FLAGS.batch_size, summary_writer, FLAGS.report_every,
        builder.saver, FLAGS.checkpoint_filename, stats)
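

# Format note on the '.stats' file read above (the values here are
# illustrative assumptions, not taken from the original code): it is a single
# comma-separated line of three integers,
# [job_id, completed_tagger_steps, completed_parser_steps]. For example, a
# stats file containing
#
#   3,12000,48000
#
# means job 3 already completed 12000 tagger steps and 48000 parser steps,
# so on restore pretrain_steps[0] is reduced by 12000.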
def RunFullTrainingAndInference(self,
                                test_name,
                                master_spec_path=None,
                                master_spec=None,
                                component_weights=None,
                                unroll_using_oracle=None,
                                num_evaluated_components=1,
                                expected_num_actions=None,
                                expected=None,
                                batch_size_limit=None):
  if not master_spec:
    master_spec = self.LoadSpec(master_spec_path)

  gold_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
  gold_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
  gold_reader_strings = [
      gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
  ]

  test_doc = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
  test_doc_2 = sentence_pb2.Sentence()
  text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
  test_reader_strings = [
      test_doc.SerializeToString(), test_doc.SerializeToString(),
      test_doc_2.SerializeToString(), test_doc.SerializeToString()
  ]

  if batch_size_limit is not None:
    gold_reader_strings = gold_reader_strings[:batch_size_limit]
    test_reader_strings = test_reader_strings[:batch_size_limit]

  with tf.Graph().as_default():
    tf.set_random_seed(1)
    hyperparam_config = spec_pb2.GridPoint()
    builder = graph_builder.MasterBuilder(
        master_spec, hyperparam_config, pool_scope=test_name)
    target = spec_pb2.TrainTarget()
    target.name = 'testFullInference-train-%s' % test_name
    if component_weights:
      target.component_weights.extend(component_weights)
    else:
      target.component_weights.extend([0] * len(master_spec.component))
      target.component_weights[-1] = 1.0
    if unroll_using_oracle:
      target.unroll_using_oracle.extend(unroll_using_oracle)
    else:
      target.unroll_using_oracle.extend([False] * len(master_spec.component))
      target.unroll_using_oracle[-1] = True
    train = builder.add_training_from_config(target)
    oracle_trace = builder.add_training_from_config(
        target, prefix='train_traced-', trace_only=True)
    builder.add_saver()
    anno = builder.add_annotation(test_name)
    trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

    # Verifies that the summaries can be built.
    for component in builder.components:
      component.get_summaries()

    config = tf.ConfigProto(
        intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
    with self.test_session(config=config) as sess:
      logging.info('Initializing')
      sess.run(tf.global_variables_initializer())

      logging.info('Dry run oracle trace...')
      traces = sess.run(
          oracle_trace['traces'],
          feed_dict={oracle_trace['input_batch']: gold_reader_strings})

      # Check that the oracle traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      logging.info('Simulating training...')
      break_iter = 400
      is_resolved = False
      # Needs ~100 iterations, but is not deterministic.
      for i in range(0, 400):
        cost, eval_res_val = sess.run(
            [train['cost'], train['metrics']],
            feed_dict={train['input_batch']: gold_reader_strings})
        logging.info('cost = %s', cost)
        self.assertFalse(np.isnan(cost))
        total_val = eval_res_val.reshape((-1, 2))[:, 0].sum()
        correct_val = eval_res_val.reshape((-1, 2))[:, 1].sum()
        if correct_val == total_val and not is_resolved:
          logging.info('... converged on iteration %d with (correct, total) '
                       '= (%d, %d)', i, correct_val, total_val)
          is_resolved = True
          # Run for slightly longer than convergence to help with quantized
          # weight tiebreakers.
          break_iter = i + 50
        if i == break_iter:
          break

      if not expected_num_actions:
        expected_num_actions = 4 * num_evaluated_components

      # If training failed, report total/correct actions for each component.
      if (correct_val != total_val or correct_val != expected_num_actions or
          total_val != expected_num_actions):
        for c in xrange(len(master_spec.component)):
          logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                        master_spec.component[c].name, eval_res_val[2 * c],
                        eval_res_val[2 * c + 1])

      assert correct_val == total_val, 'Did not converge! %d vs %d.' % (
          correct_val, total_val)
      self.assertEqual(expected_num_actions, correct_val)
      self.assertEqual(expected_num_actions, total_val)

      builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

      logging.info('Running test.')
      logging.info('Printing annotations')
      annotations = sess.run(
          anno['annotations'],
          feed_dict={anno['input_batch']: test_reader_strings})
      logging.info('Put %d inputs in, got %d annotations out.',
                   len(test_reader_strings), len(annotations))

      # Also run the annotation graph with tracing enabled.
      annotations_with_trace, traces = sess.run(
          [trace['annotations'], trace['traces']],
          feed_dict={trace['input_batch']: test_reader_strings})

      # The result of the two annotation graphs should be identical.
      self.assertItemsEqual(annotations, annotations_with_trace)

      # Check that the inference traces are not empty.
      for serialized_trace in traces:
        master_trace = trace_pb2.MasterTrace()
        master_trace.ParseFromString(serialized_trace)
        self.assertTrue(master_trace.component_trace)
        self.assertTrue(master_trace.component_trace[0].step_trace)

      self.assertEqual(len(test_reader_strings), len(annotations))
      pred_sentences = []
      for annotation in annotations:
        pred_sentences.append(sentence_pb2.Sentence())
        pred_sentences[-1].ParseFromString(annotation)

      if expected is None:
        expected = _TAGGER_EXPECTED_SENTENCES

      expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

      for i, pred_sentence in enumerate(pred_sentences):
        self.assertProtoEquals(expected_sentences[i], pred_sentence)
def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  # Read hyperparams and master spec.
  hyperparam_config = spec_pb2.GridPoint()
  text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  print hyperparam_config
  master_spec = spec_pb2.MasterSpec()
  with gfile.GFile(FLAGS.master_spec, 'r') as fin:
    text_format.Parse(fin.read(), master_spec)

  # Make output folder.
  if not gfile.Exists(FLAGS.output_folder):
    gfile.MakeDirs(FLAGS.output_folder)

  # Construct TF Graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)

    # Construct default per-component targets.
    default_targets = [
        spec_pb2.TrainTarget(
            name=component.name,
            max_index=idx + 1,
            unroll_using_oracle=[False] * idx + [True])
        for idx, component in enumerate(master_spec.component)
        if component.transition_system.registered_name != 'shift-only'
    ]

    # Add default and manually specified targets.
    trainers = []
    for target in default_targets:
      trainers += [builder.add_training_from_config(target)]
    check.Eq(len(trainers), 1, "Expected only one training target (FF unit)")

    # Construct annotation ops and the saver. Will use moving average if
    # enabled.
    annotator = builder.add_annotation()
    builder.add_saver()

    # Add backwards compatible training summary.
    summaries = []
    for component in builder.components:
      summaries += component.get_summaries()
    merged_summaries = tf.summary.merge_all()

    # Construct target to initialize variables.
    tf.group(tf.global_variables_initializer(), name='inits')

  # Prepare tensorboard dir.
  events_dir = os.path.join(FLAGS.output_folder, "tensorboard")
  empty_dir(events_dir)
  summary_writer = tf.summary.FileWriter(events_dir, graph)
  print "Wrote events (incl. graph) for TensorBoard to folder:", events_dir
  print "The graph can be viewed via"
  print "  tensorboard --logdir=" + events_dir
  print "  then navigating to http://localhost:6006 and clicking on 'GRAPHS'"

  with graph.as_default():
    tf.set_random_seed(hyperparam_config.seed)

  # Read train and dev corpora.
  print "Reading corpora..."
  train_corpus = read_corpus(FLAGS.train_corpus)
  dev_corpus = read_corpus(FLAGS.dev_corpus)

  # Prepare checkpoint folder.
  checkpoint_path = os.path.join(FLAGS.output_folder, 'checkpoints/best')
  checkpoint_dir = os.path.dirname(checkpoint_path)
  empty_dir(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())

    # Run training.
    trainer_lib.run_training(
        sess,
        trainers,
        annotator,
        evaluator,
        [0],  # pretrain_steps
        [FLAGS.train_steps],
        train_corpus,
        dev_corpus,
        dev_corpus,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        checkpoint_path)

    # Convert model to a Myelin flow.
    if len(FLAGS.flow) != 0:
      tf.logging.info('Saving flow to %s', FLAGS.flow)
      flow = convert_model(master_spec, sess)
      flow.save(FLAGS.flow)

  tf.logging.info('Best checkpoint written to %s', checkpoint_path)
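

# A sketch of a --hyperparams value this script could parse (the field names
# come from GridPoint usage elsewhere in this file; the values are
# illustrative assumptions, not recommended settings):
#
#   learning_method: 'adam'
#   learning_rate: 0.0005
#   dropout_rate: 0.8
#   seed: 1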
def main(argv):
  del argv  # unused
  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  lexicon.build_lexicon(
      lexicon_dir, training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32,
      source_layer='logits')

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step that SHIFTed or REDUCEd that
  # token. This allows the parser to build up compositional representations
  # of phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec, parser.spec])
  hyperparam_config = spec_pb2.GridPoint()

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True, True])
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]

  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  with open('dragnn_tutorial_2.html', 'w') as f:
    f.write(
        visualization.trace_html(
            traces[0], height='400px',
            master_spec=master_spec).encode('utf-8'))