# Example #1
# 0
def _get_hook_name(component, variable_name, suffix):
    """Builds the name of a hook node.

    Specifically, the name of the hook node is:

      <component.name>/<variable_name><suffix><remainder>

    where <remainder> is whatever follows <variable_name> in the name of the op
    that produces the named variable.  Recall that component.get_variable() may
    return either the original variable or its moving average.  These might
    have names like:

      foo_component/bar_variable
      foo_component/bar_variable/ExponentialMovingAverage

    In the examples above, the <remainder> is "" for the original variable and
    "/ExponentialMovingAverage" for its moving average.  Calling this function
    with suffix="/baz_suffix" in either case would return hook node names:

      foo_component/bar_variable/baz_suffix
      foo_component/bar_variable/baz_suffix/ExponentialMovingAverage

    Note that the suffix is inserted after the variable name, not necessarily
    at the end of the entire op name.

    Args:
      component: Component that the hook node belongs to.
      variable_name: Variable that the hook node name is based on.
      suffix: Suffix to append to the variable name.

    Returns:
      Name of the hook node.

    Raises:
      ValueError: (via check.Ne) If the op name of the variable does not start
        with "<component.name>/<variable_name>", or if inserting |suffix|
        leaves the name unchanged.
    """
    variable = component.get_variable(variable_name)
    full_name = variable.op.name
    prefix = component.name + '/' + variable_name

    # Splice |suffix| in literally right after |prefix|.  Unlike re.sub(), this
    # does not interpret backslash escapes or group references in the inserted
    # text.
    if full_name.startswith(prefix):
        hook_name = prefix + suffix + full_name[len(prefix):]
    else:
        hook_name = full_name

    # If |prefix| did not match, |hook_name| is the unmodified |full_name|.
    # Enforce that some change was made.
    check.Ne(
        full_name, hook_name,
        'Failed to match expected variable prefix "{}" in variable "{}"'.
        format(prefix, full_name))

    return hook_name
# Example #2
# 0
def _add_hooks_for_trainable_params(component, params):
    """Registers hook nodes exposing a trainable matrix and views derived from it.

    For a |params| variable whose rank is statically 2, hook nodes are added
    for the (possibly averaged) matrix returned by the component and for its
    transpose.  The hooks cover each tensor itself, each variant yielded by
    _blocked_and_dtype_transformations() for it, and its shape.  Variables
    whose rank is not statically 2 are logged and skipped.

    Hook names have the form <params op name><view><suffix>, where <suffix> is
    whatever the matrix's op name adds after the params op name ("" when both
    are named alike).

    Args:
      component: Component for which to add hooks.
      params: Variable for which to add hooks.
    """
    base_name = params.op.name
    matrix = component.get_variable(var_params=params)

    # Skip anything not statically known to be a matrix (ndims may be None).
    if params.shape.ndims != 2:
        tf.logging.info('Not adding hooks for trainable params %s', base_name)
        return

    # Whatever the possibly-averaged matrix's name adds beyond |base_name| is
    # carried over to the end of every hook name.
    matrix_name = matrix.op.name
    trailer = re.sub('^' + re.escape(base_name), '', matrix_name)
    check.Ne(trailer, matrix_name,
             'Failed to find suffix for params %s' % base_name)

    def hook_name(view):
        """Returns the hook node name for a given view of the matrix."""
        return base_name + view + trailer

    # The transpose is built before any hook node is added.
    views = (('/matrix', matrix), ('/transposed', tf.transpose(matrix)))

    # Plain matrix and transpose.
    for view, tensor in views:
        _add_hook_node(tensor, hook_name(view))

    # Blocked / dtype-transformed variants of each.
    for view, tensor in views:
        for variant, variant_suffix in _blocked_and_dtype_transformations(
                tensor):
            _add_hook_node(variant, hook_name(view + variant_suffix))

    # Also expose the original shapes, which padding would otherwise obscure.
    for view, tensor in views:
        _add_hook_node(tf.shape(tensor), hook_name(view + '/shape'))
def main(unused_argv):
    """Trains a model from the configuration stored under --model_dir.

    The steps are:
      1. Read the config, master spec, hyperparameters and, if present,
         training targets from --model_dir.
      2. Build the training graph.
      3. Load the train and tune corpora.
      4. Replace <model_dir>/checkpoints with a fresh empty directory.
      5. Run training, saving the best checkpoint at
         <model_dir>/checkpoints/best.

    Args:
      unused_argv: Ignored.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # For each step/epoch pair, exactly one of the two flags must be set.
    check.NotNone(FLAGS.model_dir, '--model_dir is required')
    check.Ne(
        FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
        'Exactly one of --pretrain_steps or --pretrain_epochs is required')
    check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
             'Exactly one of --train_steps or --train_epochs is required')

    def in_model_dir(relative_path):
        """Resolves |relative_path| against --model_dir."""
        return os.path.join(FLAGS.model_dir, relative_path)

    spec_targets_path = in_model_dir('targets.pbtxt')
    best_checkpoint = in_model_dir('checkpoints/best')
    summaries_dir = in_model_dir('tensorboard')

    # The config file holds a Python literal; keys it lacks read as False.
    with tf.gfile.FastGFile(in_model_dir('config.txt')) as handle:
        settings = collections.defaultdict(bool,
                                           ast.literal_eval(handle.read()))

    master_spec = _read_text_proto(in_model_dir('master.pbtxt'),
                                   spec_pb2.MasterSpec)
    grid_point = _read_text_proto(in_model_dir('hyperparameters.pbtxt'),
                                  spec_pb2.GridPoint)

    # Targets from targets.pbtxt, when it exists, override the defaults.
    default_targets = spec_builder.default_targets_from_spec(master_spec)
    targets = (_read_text_proto(spec_targets_path,
                                spec_pb2.TrainingGridSpec).target
               if tf.gfile.Exists(spec_targets_path) else default_targets)

    # Assemble the TensorFlow graph.
    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(grid_point.seed)
        builder = graph_builder.MasterBuilder(master_spec, grid_point)
        trainers = list(map(builder.add_training_from_config, targets))
        annotator = builder.add_annotation()
        builder.add_saver()

    # Load serialized protos for the training and tuning data.
    train_corpus = sentence_io.ConllSentenceReader(
        settings['train_corpus_path'],
        projectivize=settings['projectivize_train_corpus']).corpus()
    gold_tune_corpus = sentence_io.ConllSentenceReader(
        settings['tune_corpus_path'], projectivize=False).corpus()
    tune_corpus = gold_tune_corpus

    # Optionally switch to char-based corpora.  |gold_tune_corpus| stays
    # word-based for segmentation evaluation purposes.
    if settings['convert_to_char_corpora']:
        train_corpus = _convert_to_char_corpus(train_corpus)
        tune_corpus = _convert_to_char_corpus(tune_corpus)

    num_train_sentences = len(train_corpus)
    pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                                num_train_sentences)
    train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                             num_train_sentences)
    check.Eq(len(targets), len(pretrain_steps),
             'Length mismatch between training targets and --pretrain_steps')
    check.Eq(len(targets), len(train_steps),
             'Length mismatch between training targets and --train_steps')

    tf.logging.info('Training on %d sentences.', num_train_sentences)
    tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

    tf.logging.info('Creating TensorFlow checkpoint dir...')
    summary_writer = trainer_lib.get_summary_writer(summaries_dir)

    # Clear whatever occupies the checkpoint directory path, then recreate it.
    checkpoint_dir = os.path.dirname(best_checkpoint)
    if tf.gfile.IsDirectory(checkpoint_dir):
        tf.gfile.DeleteRecursively(checkpoint_dir)
    elif tf.gfile.Exists(checkpoint_dir):
        tf.gfile.Remove(checkpoint_dir)
    tf.gfile.MakeDirs(checkpoint_dir)

    with tf.Session(FLAGS.tf_master, graph=graph) as sess:
        # Re-initialize all underlying state before training.
        sess.run(tf.global_variables_initializer())
        trainer_lib.run_training(
            sess, trainers, annotator, evaluation.parser_summaries,
            pretrain_steps, train_steps, train_corpus, tune_corpus,
            gold_tune_corpus, FLAGS.batch_size, summary_writer,
            FLAGS.report_every, builder.saver, best_checkpoint)

    tf.logging.info('Best checkpoint written to:\n%s', best_checkpoint)
# Example #4
# 0
 def testCheckNe(self):
     """check.Ne passes on unequal values and raises on equal ones."""
     check.Ne(1, 2, 'foo')
     # By default, a failed check raises ValueError carrying the message.
     with self.assertRaisesRegex(ValueError, 'bar'):
         check.Ne(1, 1, 'bar')
     # An explicit error type as the fourth argument replaces the default.
     with self.assertRaisesRegex(RuntimeError, 'baz'):
         check.Ne(1, 1, 'baz', RuntimeError)