def main(unused_argv):
  """Builds a two-component segmenter spec and trains it.

  Creates the resource directory when it is missing, optionally computes the
  lexicon from the training corpus, assembles a MasterSpec made of a
  'lookahead' component feeding a 'segmenter' component, writes that spec to
  '<resource_path>/master_spec', parses --hyperparams (plain text first,
  base64-decoded text as a fallback), and finally runs training on
  character-based versions of the training and dev corpora.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)

  resource_dir = FLAGS.resource_path
  if not gfile.IsDirectory(resource_dir):
    gfile.MakeDirs(resource_dir)

  # Lexical resources are built from the training data and stored in the
  # resource directory; skipped unless --compute_lexicon is set.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(resource_dir, FLAGS.training_corpus_path)

  lstm_unit = 'wrapped_units.LayerNormBasicLSTMNetwork'
  tf_master = FLAGS.tf_master

  # 'lookahead': a right-to-left recurrent sequence model that encodes the
  # context to the right of each token. It has no loss of its own; it is only
  # trained through the components that link to it.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_fixed_feature(
      name='char',
      fml='input(-1).char input.char input(1).char',
      embedding_dim=32)
  lookahead.add_fixed_feature(
      name='char-bigram',
      fml='input.char-bigram',
      embedding_dim=32)
  lookahead.fill_from_resources(resource_dir, tf_master)

  # 'segmenter': consumes the lookahead activations through a token link.
  segmenter = spec_builder.ComponentSpecBuilder('segmenter')
  segmenter.set_network_unit(name=lstm_unit, hidden_layer_sizes='128')
  segmenter.set_transition_system(name='binary-segment-transitions')
  segmenter.add_token_link(
      source=lookahead,
      fml='input.focus stack.focus',
      embedding_dim=64)
  segmenter.fill_from_resources(resource_dir, tf_master)

  # Assemble the master spec, log it, and persist it next to the resources.
  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([lookahead.spec, segmenter.spec])
  logging.info('Constructed master spec: %s', str(master_spec))
  with gfile.GFile(resource_dir + '/master_spec', 'w') as spec_file:
    spec_file.write(str(master_spec).encode('utf-8'))

  # --hyperparams is tried as a text-format proto first; if that fails to
  # parse, it is base64-decoded and parsed again.
  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)

  # Build the TensorFlow graph: one trainer per default target, plus the
  # annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    trainers = []
    for target in spec_builder.default_targets_from_spec(master_spec):
      trainers.append(builder.add_training_from_config(target))
    assert len(trainers) == 1
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the word-based corpora; projectivization is disabled for both.
  training_set = ConllSentenceReader(
      FLAGS.training_corpus_path, projectivize=False).corpus()
  dev_set = ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Turn the word-based documents into char-based ones inside a throwaway
  # session bound to its own fresh graph.
  with tf.Session(graph=tf.Graph()) as conversion_session:
    char_training_set_op = gen_parser_ops.segmenter_training_data_constructor(
        training_set)
    char_dev_set_op = gen_parser_ops.char_token_generator(dev_set)
    char_training_set = conversion_session.run(char_training_set_op)
    char_dev_set = conversion_session.run(char_dev_set_op)

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  # No pretraining; the training step count is epochs times corpus size.
  pretrain_steps = [0]
  train_steps = [FLAGS.num_epochs * len(training_set)]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())

    # The char-based dev set is what gets annotated; the word-based dev set
    # is passed separately as the last corpus argument.
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.segmentation_summaries,
        pretrain_steps,
        train_steps,
        char_training_set,
        char_dev_set,
        dev_set,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        FLAGS.checkpoint_filename)
def main(unused_argv):
  """Builds a char-LSTM / lookahead / tagger / parser spec and trains it.

  Creates the resource directory when it is missing, optionally computes the
  lexicon from the training corpus, assembles a four-component MasterSpec,
  fills in a fixed set of hyperparameters, builds the TensorFlow graph with
  two trainers, and runs training against the training and dev corpora.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)

  resource_dir = FLAGS.resource_path
  if not gfile.IsDirectory(resource_dir):
    gfile.MakeDirs(resource_dir)

  # Lexical resources are built from the training data and stored in the
  # resource directory; skipped unless --compute_lexicon is set.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(resource_dir, FLAGS.training_corpus_path)

  lstm_unit = 'wrapped_units.LayerNormBasicLSTMNetwork'
  tf_master = FLAGS.tf_master

  # 'char_lstm': left-to-right component driven by a single 'chars' fixed
  # feature.
  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
  char2word.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
  char2word.add_fixed_feature(
      name='chars',
      fml='char-input.text-char',
      embedding_dim=16)
  char2word.fill_from_resources(resource_dir, tf_master)

  # 'lookahead': a right-to-left recurrent sequence model that encodes the
  # context to the right of each token. It has no loss of its own; it is only
  # trained through the downstream components. Its input is a link into the
  # char_lstm component.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_link(
      source=char2word,
      fml='input.last-char-focus',
      embedding_dim=32)
  lookahead.fill_from_resources(resource_dir, tf_master)

  # 'tagger': a simple left-to-right recurrent sequence tagger that reads the
  # lookahead activations.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_token_link(
      source=lookahead,
      fml='input.focus',
      embedding_dim=32)
  tagger.fill_from_resources(resource_dir, tf_master)

  # 'parser': feed-forward arc-standard parser linked to both the lookahead
  # and the tagger.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=lookahead,
      fml='input.focus',
      embedding_dim=32)
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32)

  # Recurrent link for the arc-standard parser: for each of the two topmost
  # stack tokens, connect to the last step that SHIFTed or REDUCEd that token,
  # so the parser can build compositional representations of phrases.
  parser.add_link(
      source=parser,  # the component links to itself
      name='rnn-stack',  # unique identifier for this link
      fml='stack.focus stack(1).focus',  # both stack tokens
      source_translator='shift-reduce-step',  # token index -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(resource_dir, tf_master)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend(
      [char2word.spec, lookahead.spec, tagger.spec, parser.spec])
  logging.info('Constructed master spec: %s', str(master_spec))

  # Fixed hyperparameters, applied field by field in this order.
  hyperparam_config = spec_pb2.GridPoint()
  hyperparam_settings = (
      ('decay_steps', 128000),
      ('learning_rate', 0.001),
      ('learning_method', 'adam'),
      ('adam_beta1', 0.9),
      ('adam_beta2', 0.9),
      ('adam_eps', 0.0001),
      ('gradient_clip_norm', 1),
      ('self_norm_alpha', 1.0),
      ('use_moving_average', True),
      ('dropout_rate', 0.7),
      ('seed', 1),
  )
  for field_name, field_value in hyperparam_settings:
    setattr(hyperparam_config, field_name, field_value)

  # Build the TensorFlow graph: one trainer per default target, plus the
  # annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    trainers = []
    for target in spec_builder.default_targets_from_spec(master_spec):
      trainers.append(builder.add_training_from_config(target))
    assert len(trainers) == 2
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the corpora. Only the training set honours the projectivize flag;
  # the dev set is never projectivized.
  training_set = sentence_io.ConllSentenceReader(
      FLAGS.training_corpus_path,
      projectivize=FLAGS.projectivize_training_set).corpus()
  dev_set = sentence_io.ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  # One entry per trainer; the second trainer gets 8x the steps of the first.
  pretrain_steps = [100, 0]
  tagger_steps = 1000
  train_steps = [tagger_steps, 8 * tagger_steps]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())

    # The dev set is passed twice: once as the tuning corpus and once as the
    # last corpus argument.
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.parser_summaries,
        pretrain_steps,
        train_steps,
        training_set,
        dev_set,
        dev_set,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        FLAGS.checkpoint_filename)
def main(unused_argv):
  """Builds a char-LSTM / lookahead / tagger / parser spec and trains it.

  Creates the resource directory when it is missing, optionally computes the
  lexicon from the training corpus, assembles a four-component MasterSpec,
  fills in a fixed set of hyperparameters, builds the TensorFlow graph with
  two trainers, and runs training against the training and dev corpora.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)

  resource_dir = FLAGS.resource_path
  if not gfile.IsDirectory(resource_dir):
    gfile.MakeDirs(resource_dir)

  # Lexical resources are built from the training data and stored in the
  # resource directory; skipped unless --compute_lexicon is set.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(resource_dir, FLAGS.training_corpus_path)

  lstm_unit = 'wrapped_units.LayerNormBasicLSTMNetwork'
  tf_master = FLAGS.tf_master

  # 'char_lstm': left-to-right component driven by a single 'chars' fixed
  # feature.
  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
  char2word.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
  char2word.add_fixed_feature(
      name='chars',
      fml='char-input.text-char',
      embedding_dim=16)
  char2word.fill_from_resources(resource_dir, tf_master)

  # 'lookahead': a right-to-left recurrent sequence model that encodes the
  # context to the right of each token. It has no loss of its own; it is only
  # trained through the downstream components. Its input is a link into the
  # char_lstm component.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_link(
      source=char2word,
      fml='input.last-char-focus',
      embedding_dim=32)
  lookahead.fill_from_resources(resource_dir, tf_master)

  # 'tagger': a simple left-to-right recurrent sequence tagger that reads the
  # lookahead activations.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_token_link(
      source=lookahead,
      fml='input.focus',
      embedding_dim=32)
  tagger.fill_from_resources(resource_dir, tf_master)

  # 'parser': feed-forward arc-standard parser linked to both the lookahead
  # and the tagger.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(
      name='FeedForwardNetwork',
      hidden_layer_sizes='256',
      layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(
      source=lookahead,
      fml='input.focus',
      embedding_dim=32)
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32)

  # Recurrent link for the arc-standard parser: for each of the two topmost
  # stack tokens, connect to the last step that SHIFTed or REDUCEd that token,
  # so the parser can build compositional representations of phrases.
  parser.add_link(
      source=parser,  # the component links to itself
      name='rnn-stack',  # unique identifier for this link
      fml='stack.focus stack(1).focus',  # both stack tokens
      source_translator='shift-reduce-step',  # token index -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(resource_dir, tf_master)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend(
      [char2word.spec, lookahead.spec, tagger.spec, parser.spec])
  logging.info('Constructed master spec: %s', str(master_spec))

  # Fixed hyperparameters, applied field by field in this order.
  hyperparam_config = spec_pb2.GridPoint()
  hyperparam_settings = (
      ('decay_steps', 128000),
      ('learning_rate', 0.001),
      ('learning_method', 'adam'),
      ('adam_beta1', 0.9),
      ('adam_beta2', 0.9),
      ('adam_eps', 0.0001),
      ('gradient_clip_norm', 1),
      ('self_norm_alpha', 1.0),
      ('use_moving_average', True),
      ('dropout_rate', 0.7),
      ('seed', 1),
  )
  for field_name, field_value in hyperparam_settings:
    setattr(hyperparam_config, field_name, field_value)

  # Build the TensorFlow graph: one trainer per default target, plus the
  # annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    trainers = []
    for target in spec_builder.default_targets_from_spec(master_spec):
      trainers.append(builder.add_training_from_config(target))
    assert len(trainers) == 2
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the corpora. Only the training set honours the projectivize flag;
  # the dev set is never projectivized.
  training_set = sentence_io.ConllSentenceReader(
      FLAGS.training_corpus_path,
      projectivize=FLAGS.projectivize_training_set).corpus()
  dev_set = sentence_io.ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  # One entry per trainer; the second trainer gets 8x the steps of the first.
  pretrain_steps = [100, 0]
  tagger_steps = 1000
  train_steps = [tagger_steps, 8 * tagger_steps]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())

    # The dev set is passed twice: once as the tuning corpus and once as the
    # last corpus argument.
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.parser_summaries,
        pretrain_steps,
        train_steps,
        training_set,
        dev_set,
        dev_set,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        FLAGS.checkpoint_filename)
def main(unused_argv):
  """Trains the model described by the files under --model_dir.

  Validates the flags, reads 'config.txt', 'master.pbtxt',
  'hyperparameters.pbtxt' and (when present) 'targets.pbtxt' from the model
  directory, builds the TensorFlow graph, loads the training and tuning
  corpora, recreates the checkpoint directory from scratch, and runs
  training. The best checkpoint is written to '<model_dir>/checkpoints/best'.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Flag validation: a model directory is mandatory, and each step count must
  # be given exactly one way (as steps or as epochs).
  check.NotNone(FLAGS.model_dir, '--model_dir is required')
  check.Ne(
      FLAGS.pretrain_steps is None,
      FLAGS.pretrain_epochs is None,
      'Exactly one of --pretrain_steps or --pretrain_epochs is required')
  check.Ne(
      FLAGS.train_steps is None,
      FLAGS.train_epochs is None,
      'Exactly one of --train_steps or --train_epochs is required')

  # Everything the trainer reads or writes lives under the model directory.
  model_dir = FLAGS.model_dir
  config_path = os.path.join(model_dir, 'config.txt')
  master_path = os.path.join(model_dir, 'master.pbtxt')
  hyperparameters_path = os.path.join(model_dir, 'hyperparameters.pbtxt')
  targets_path = os.path.join(model_dir, 'targets.pbtxt')
  checkpoint_path = os.path.join(model_dir, 'checkpoints/best')
  tensorboard_dir = os.path.join(model_dir, 'tensorboard')

  # The config file holds a Python literal; wrapping it in defaultdict(bool)
  # makes any key that is absent read as False.
  with tf.gfile.FastGFile(config_path) as config_file:
    parsed_config = ast.literal_eval(config_file.read())
    config = collections.defaultdict(bool, parsed_config)
  train_corpus_path = config['train_corpus_path']
  tune_corpus_path = config['tune_corpus_path']
  projectivize_train_corpus = config['projectivize_train_corpus']

  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)

  # Start from the spec's default targets; an explicit targets file, when
  # present, replaces them.
  targets = spec_builder.default_targets_from_spec(master)
  if tf.gfile.Exists(targets_path):
    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target

  # Build the TensorFlow graph, seeded from the hyperparameters: one trainer
  # per target, plus the annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(hyperparameters.seed)
    builder = graph_builder.MasterBuilder(master, hyperparameters)
    trainers = []
    for target in targets:
      trainers.append(builder.add_training_from_config(target))
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the corpora. Only the training corpus honours the projectivize
  # setting; the tuning corpus is never projectivized.
  train_corpus = sentence_io.ConllSentenceReader(
      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
  tune_corpus = sentence_io.ConllSentenceReader(
      tune_corpus_path, projectivize=False).corpus()
  gold_tune_corpus = tune_corpus

  # Convert to char-based corpora, if requested.
  if config['convert_to_char_corpora']:
    # NB: |gold_tune_corpus| is deliberately left unconverted; it should
    # remain word-based for segmentation evaluation purposes.
    train_corpus = _convert_to_char_corpus(train_corpus)
    tune_corpus = _convert_to_char_corpus(tune_corpus)

  # Resolve step counts from either explicit steps or epochs, then make sure
  # there is exactly one entry per training target.
  num_train_sentences = len(train_corpus)
  pretrain_steps = _get_steps(
      FLAGS.pretrain_steps, FLAGS.pretrain_epochs, num_train_sentences)
  train_steps = _get_steps(
      FLAGS.train_steps, FLAGS.train_epochs, num_train_sentences)
  check.Eq(
      len(targets),
      len(pretrain_steps),
      'Length mismatch between training targets and --pretrain_steps')
  check.Eq(
      len(targets),
      len(train_steps),
      'Length mismatch between training targets and --train_steps')

  # Ready to train!
  tf.logging.info('Training on %d sentences.', len(train_corpus))
  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

  # Whatever currently occupies the checkpoint location (a directory tree or
  # a plain file) is removed before an empty directory is created.
  checkpoint_dir = os.path.dirname(checkpoint_path)
  if tf.gfile.IsDirectory(checkpoint_dir):
    tf.gfile.DeleteRecursively(checkpoint_dir)
  elif tf.gfile.Exists(checkpoint_dir):
    tf.gfile.Remove(checkpoint_dir)
  tf.gfile.MakeDirs(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.parser_summaries,
        pretrain_steps,
        train_steps,
        train_corpus,
        tune_corpus,
        gold_tune_corpus,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        checkpoint_path)

  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)
def main(unused_argv):
  """Builds a two-component segmenter spec and trains it.

  Creates the resource directory when it is missing, optionally computes the
  lexicon from the training corpus, assembles a MasterSpec made of a
  'lookahead' component feeding a 'segmenter' component, writes that spec to
  '<resource_path>/master_spec', parses --hyperparams (plain text first,
  base64-decoded text as a fallback), and finally runs training on
  character-based versions of the training and dev corpora.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)

  resource_dir = FLAGS.resource_path
  if not gfile.IsDirectory(resource_dir):
    gfile.MakeDirs(resource_dir)

  # Lexical resources are built from the training data and stored in the
  # resource directory; skipped unless --compute_lexicon is set.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(resource_dir, FLAGS.training_corpus_path)

  lstm_unit = 'wrapped_units.LayerNormBasicLSTMNetwork'
  tf_master = FLAGS.tf_master

  # 'lookahead': a right-to-left recurrent sequence model that encodes the
  # context to the right of each token. It has no loss of its own; it is only
  # trained through the components that link to it.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(name=lstm_unit, hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_fixed_feature(
      name='char',
      fml='input(-1).char input.char input(1).char',
      embedding_dim=32)
  lookahead.add_fixed_feature(
      name='char-bigram',
      fml='input.char-bigram',
      embedding_dim=32)
  lookahead.fill_from_resources(resource_dir, tf_master)

  # 'segmenter': consumes the lookahead activations through a token link.
  segmenter = spec_builder.ComponentSpecBuilder('segmenter')
  segmenter.set_network_unit(name=lstm_unit, hidden_layer_sizes='128')
  segmenter.set_transition_system(name='binary-segment-transitions')
  segmenter.add_token_link(
      source=lookahead,
      fml='input.focus stack.focus',
      embedding_dim=64)
  segmenter.fill_from_resources(resource_dir, tf_master)

  # Assemble the master spec, log it, and persist it next to the resources.
  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([lookahead.spec, segmenter.spec])
  logging.info('Constructed master spec: %s', str(master_spec))
  with gfile.GFile(resource_dir + '/master_spec', 'w') as spec_file:
    spec_file.write(str(master_spec).encode('utf-8'))

  # --hyperparams is tried as a text-format proto first; if that fails to
  # parse, it is base64-decoded and parsed again.
  hyperparam_config = spec_pb2.GridPoint()
  try:
    text_format.Parse(FLAGS.hyperparams, hyperparam_config)
  except text_format.ParseError:
    text_format.Parse(base64.b64decode(FLAGS.hyperparams), hyperparam_config)

  # Build the TensorFlow graph: one trainer per default target, plus the
  # annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    trainers = []
    for target in spec_builder.default_targets_from_spec(master_spec):
      trainers.append(builder.add_training_from_config(target))
    assert len(trainers) == 1
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the word-based corpora; projectivization is disabled for both.
  training_set = ConllSentenceReader(
      FLAGS.training_corpus_path, projectivize=False).corpus()
  dev_set = ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Turn the word-based documents into char-based ones inside a throwaway
  # session bound to its own fresh graph.
  with tf.Session(graph=tf.Graph()) as conversion_session:
    char_training_set_op = gen_parser_ops.segmenter_training_data_constructor(
        training_set)
    char_dev_set_op = gen_parser_ops.char_token_generator(dev_set)
    char_training_set = conversion_session.run(char_training_set_op)
    char_dev_set = conversion_session.run(char_dev_set_op)

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))

  # No pretraining; the training step count is epochs times corpus size.
  pretrain_steps = [0]
  train_steps = [FLAGS.num_epochs * len(training_set)]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())

    # The char-based dev set is what gets annotated; the word-based dev set
    # is passed separately as the last corpus argument.
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.segmentation_summaries,
        pretrain_steps,
        train_steps,
        char_training_set,
        char_dev_set,
        dev_set,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        FLAGS.checkpoint_filename)
def main(unused_argv):
  """Trains the model described by the files under --model_dir.

  Validates the flags, reads 'config.txt', 'master.pbtxt',
  'hyperparameters.pbtxt' and (when present) 'targets.pbtxt' from the model
  directory, builds the TensorFlow graph, loads the training and tuning
  corpora, recreates the checkpoint directory from scratch, and runs
  training. The best checkpoint is written to '<model_dir>/checkpoints/best'.

  Args:
    unused_argv: Positional command-line arguments; ignored.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # Flag validation: a model directory is mandatory, and each step count must
  # be given exactly one way (as steps or as epochs).
  check.NotNone(FLAGS.model_dir, '--model_dir is required')
  check.Ne(
      FLAGS.pretrain_steps is None,
      FLAGS.pretrain_epochs is None,
      'Exactly one of --pretrain_steps or --pretrain_epochs is required')
  check.Ne(
      FLAGS.train_steps is None,
      FLAGS.train_epochs is None,
      'Exactly one of --train_steps or --train_epochs is required')

  # Everything the trainer reads or writes lives under the model directory.
  model_dir = FLAGS.model_dir
  config_path = os.path.join(model_dir, 'config.txt')
  master_path = os.path.join(model_dir, 'master.pbtxt')
  hyperparameters_path = os.path.join(model_dir, 'hyperparameters.pbtxt')
  targets_path = os.path.join(model_dir, 'targets.pbtxt')
  checkpoint_path = os.path.join(model_dir, 'checkpoints/best')
  tensorboard_dir = os.path.join(model_dir, 'tensorboard')

  # The config file holds a Python literal; wrapping it in defaultdict(bool)
  # makes any key that is absent read as False.
  with tf.gfile.FastGFile(config_path) as config_file:
    parsed_config = ast.literal_eval(config_file.read())
    config = collections.defaultdict(bool, parsed_config)
  train_corpus_path = config['train_corpus_path']
  tune_corpus_path = config['tune_corpus_path']
  projectivize_train_corpus = config['projectivize_train_corpus']

  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)

  # Start from the spec's default targets; an explicit targets file, when
  # present, replaces them.
  targets = spec_builder.default_targets_from_spec(master)
  if tf.gfile.Exists(targets_path):
    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target

  # Build the TensorFlow graph, seeded from the hyperparameters: one trainer
  # per target, plus the annotation op and the saver.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(hyperparameters.seed)
    builder = graph_builder.MasterBuilder(master, hyperparameters)
    trainers = []
    for target in targets:
      trainers.append(builder.add_training_from_config(target))
    annotator = builder.add_annotation()
    builder.add_saver()

  # Load the corpora. Only the training corpus honours the projectivize
  # setting; the tuning corpus is never projectivized.
  train_corpus = sentence_io.ConllSentenceReader(
      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
  tune_corpus = sentence_io.ConllSentenceReader(
      tune_corpus_path, projectivize=False).corpus()
  gold_tune_corpus = tune_corpus

  # Convert to char-based corpora, if requested.
  if config['convert_to_char_corpora']:
    # NB: |gold_tune_corpus| is deliberately left unconverted; it should
    # remain word-based for segmentation evaluation purposes.
    train_corpus = _convert_to_char_corpus(train_corpus)
    tune_corpus = _convert_to_char_corpus(tune_corpus)

  # Resolve step counts from either explicit steps or epochs, then make sure
  # there is exactly one entry per training target.
  num_train_sentences = len(train_corpus)
  pretrain_steps = _get_steps(
      FLAGS.pretrain_steps, FLAGS.pretrain_epochs, num_train_sentences)
  train_steps = _get_steps(
      FLAGS.train_steps, FLAGS.train_epochs, num_train_sentences)
  check.Eq(
      len(targets),
      len(pretrain_steps),
      'Length mismatch between training targets and --pretrain_steps')
  check.Eq(
      len(targets),
      len(train_steps),
      'Length mismatch between training targets and --train_steps')

  # Ready to train!
  tf.logging.info('Training on %d sentences.', len(train_corpus))
  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)

  # Whatever currently occupies the checkpoint location (a directory tree or
  # a plain file) is removed before an empty directory is created.
  checkpoint_dir = os.path.dirname(checkpoint_path)
  if tf.gfile.IsDirectory(checkpoint_dir):
    tf.gfile.DeleteRecursively(checkpoint_dir)
  elif tf.gfile.Exists(checkpoint_dir):
    tf.gfile.Remove(checkpoint_dir)
  tf.gfile.MakeDirs(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as session:
    # Start from freshly initialized variables.
    session.run(tf.global_variables_initializer())
    trainer_lib.run_training(
        session,
        trainers,
        annotator,
        evaluation.parser_summaries,
        pretrain_steps,
        train_steps,
        train_corpus,
        tune_corpus,
        gold_tune_corpus,
        FLAGS.batch_size,
        summary_writer,
        FLAGS.report_every,
        builder.saver,
        checkpoint_path)

  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)