Code Example #1
def main(argv=None):
    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)
    train()
Code Example #2
def _CreateCleanDirectory(path):
    if gfile.IsDirectory(path):
        gfile.DeleteRecursively(path)
    gfile.MkDir(path)
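
The same delete-then-recreate pattern carries over to TensorFlow 2, where gfile is exposed as the lowercase tf.io.gfile API. A minimal sketch, assuming only that tf.io.gfile is available; the path passed at the bottom is illustrative:

import tensorflow as tf


def create_clean_directory(path):
    # Remove any previous contents, then recreate the directory.
    if tf.io.gfile.isdir(path):
        tf.io.gfile.rmtree(path)
    tf.io.gfile.makedirs(path)


create_clean_directory('/tmp/output_dir')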
Code Example #3
File: keras_test.py    Project: zm714981790/tensorflow
 def tearDown(self):
     writer_cache.FileWriterCache.clear()
     if os.path.isdir(self._base_dir):
         gfile.DeleteRecursively(self._base_dir)
Code Example #4
    def testScalarsRealistically(self):
        """Test accumulator by writing values and then reading them."""
        def FakeScalarSummary(tag, value):
            value = tf.Summary.Value(tag=tag, simple_value=value)
            summary = tf.Summary(value=[value])
            return summary

        directory = os.path.join(self.get_temp_dir(), 'values_dir')
        if gfile.IsDirectory(directory):
            gfile.DeleteRecursively(directory)
        gfile.MkDir(directory)

        writer = tf.train.SummaryWriter(directory, max_queue=100)
        graph_def = tf.GraphDef(node=[tf.NodeDef(name='A', op='Mul')])
        # Add a graph to the summary writer.
        writer.add_graph(graph_def)

        # Write a bunch of events using the writer
        for i in xrange(30):
            summ_id = FakeScalarSummary('id', i)
            summ_sq = FakeScalarSummary('sq', i * i)
            writer.add_summary(summ_id, i * 5)
            writer.add_summary(summ_sq, i * 5)
        writer.flush()

        # Verify that we can load those events properly
        acc = ea.EventAccumulator(directory)
        acc.Reload()
        self.assertTagsEqual(
            acc.Tags(), {
                ea.IMAGES: [],
                ea.SCALARS: ['id', 'sq'],
                ea.HISTOGRAMS: [],
                ea.COMPRESSED_HISTOGRAMS: [],
                ea.GRAPH: True
            })
        id_events = acc.Scalars('id')
        sq_events = acc.Scalars('sq')
        self.assertEqual(30, len(id_events))
        self.assertEqual(30, len(sq_events))
        for i in xrange(30):
            self.assertEqual(i * 5, id_events[i].step)
            self.assertEqual(i * 5, sq_events[i].step)
            self.assertEqual(i, id_events[i].value)
            self.assertEqual(i * i, sq_events[i].value)

        # Write a few more events to test incremental reloading
        for i in xrange(30, 40):
            summ_id = FakeScalarSummary('id', i)
            summ_sq = FakeScalarSummary('sq', i * i)
            writer.add_summary(summ_id, i * 5)
            writer.add_summary(summ_sq, i * 5)
        writer.flush()

        # Verify we can now see all of the data
        acc.Reload()
        self.assertEqual(40, len(id_events))
        self.assertEqual(40, len(sq_events))
        for i in xrange(40):
            self.assertEqual(i * 5, id_events[i].step)
            self.assertEqual(i * 5, sq_events[i].step)
            self.assertEqual(i, id_events[i].value)
            self.assertEqual(i * i, sq_events[i].value)
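
Outside a test harness, the same read path can be driven directly. A minimal standalone sketch, assuming the event_accumulator module that the test above imports as ea (its location has moved between TensorFlow/TensorBoard releases); the log directory is illustrative:

from tensorflow.python.summary import event_accumulator as ea

acc = ea.EventAccumulator('/tmp/values_dir')
acc.Reload()                       # scan the directory for event files
for event in acc.Scalars('id'):    # each ScalarEvent has wall_time, step, value
    print(event.step, event.value)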
Code Example #5
def main(argv=None):
    if gfile.Exists(TRAIN_DIR):
        gfile.DeleteRecursively(TRAIN_DIR)
    gfile.MakeDirs(TRAIN_DIR)
    train()
Code Example #6
 def safe_create(self, output_dir):
   if gfile.Exists(output_dir):
     gfile.DeleteRecursively(output_dir)
   gfile.MakeDirs(output_dir)
Code Example #7
File: keras_test.py    Project: yueyedeai/estimator
 def tearDown(self):
     # Make sure nothing is stuck in limbo.
     writer_cache.FileWriterCache.clear()
     if os.path.isdir(self._base_dir):
         gfile.DeleteRecursively(self._base_dir)
     super(TestKerasEstimator, self).tearDown()
Code Example #8
 def tearDown(self):
     self.etcd.delete_prefix(self.data_source.data_source_meta.name)
     if gfile.Exists(self.data_source.raw_data_dir):
         gfile.DeleteRecursively(self.data_source.raw_data_dir)
Code Example #9
    def _checkpoint_if_preempted(self):
        """Checkpoint if any worker has received a preemption signal.

    This function handles preemption signal reported by any worker in the
    cluster. The current implementation relies on the fact that all workers in a
    MultiWorkerMirroredStrategy training cluster have a step number difference
    maximum of 1.
    - If the signal comes from the worker itself (i.e., where this failure
    handler sits), the worker will notify all peers to checkpoint after they
    finish CURRENT_STEP+1 steps, where CURRENT_STEP is the step this worker has
    just finished. And the worker will wait for all peers to acknowledge that
    they have received its preemption signal and the final-step number before
    the worker proceeds on training the final step.
    - If the signal comes from another member in the cluster but NO final-step
    info is available, proceed on training, because it will be available after
    finishing the next step.
    - If the signal comes from some other member in the cluster, and final-step
    info is available, if the worker has not finished these steps yet, keep
    training; otherwise, checkpoint and exit with a cluster-recognized restart
    code.
    """
        if self._final_checkpoint_countdown:
            run_count_config_key = _FINAL_RUN_COUNT_KEY

        else:
            run_count_config_key = _INITIAL_RUN_COUNT_KEY

        if self._received_checkpoint_step.is_set():

            if self._step_to_checkpoint == str(self._run_counter):
                self._save_checkpoint()

                if self._time_to_exit():
                    self._stop_poll_termination_signal_thread()
                    self._stop_cluster_wise_termination_watcher_thread()
                    if self._api_made_checkpoint_manager and not self._is_chief:
                        gfile.DeleteRecursively(
                            os.path.dirname(
                                self._write_checkpoint_manager.directory))
                    logging.info(
                        'PreemptionCheckpointHandler: checkpoint saved. Exiting.'
                    )

                    self._exit_fn()

                else:
                    logging.info('Continue training for the grace period.')
                    self._final_checkpoint_countdown = True
                    self._received_checkpoint_step.clear()

        elif self._received_own_sigterm.is_set():
            # Only the worker who gets termination signal first among the cluster
            # will enter this branch. The following will happen in chronological
            # order:
            # 1. The worker just receives a preemption signal and enters this branch
            # for the first time. It will set a step-to-checkpoint and let the cluster
            # know.
            # 2. If there is a long grace period, it will also set
            # _final_checkpoint_countdown, so that during this grace period, it will
            # re-enter this branch to check if grace period is ending.
            # 3. If it is, set a step-to-checkpoint key again.

            if self._final_checkpoint_countdown:
                if self._target_time_for_termination < time.time():
                    logging.info(
                        'Grace period almost ended. Final call to save a checkpoint!'
                    )
                else:
                    return

            step_to_save_at = str(self._run_counter + 1)

            logging.info(
                'Termination caught in main thread on preempted worker')

            if self._local_mode:
                self._step_to_checkpoint = step_to_save_at
                self._received_checkpoint_step.set()

            else:
                context.context().set_config_key_value(run_count_config_key,
                                                       step_to_save_at)
                logging.info('%s set to %s', run_count_config_key,
                             step_to_save_at)

                if not self._local_mode:
                    worker_count = multi_worker_util.worker_count(
                        self._cluster_resolver.cluster_spec(),
                        self._cluster_resolver.task_type)
                    for i in range(worker_count):
                        context.context().get_config_key_value(
                            f'{_ACKNOWLEDGE_KEY}_{run_count_config_key}_{i}')
                        logging.info(
                            'Sigterm acknowledgement from replica %d received',
                            i)

            self._setup_countdown_if_has_grace_period_and_not_already_counting_down(
            )
Code Example #10
 def tearDown(self):
   gfile.DeleteRecursively(self._base_dir)
Code Example #11
def tearDownModule():
    gfile.DeleteRecursively(test.get_temp_dir())
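
Several of the tearDown snippets in this list follow the same fixture shape: create a scratch directory in setUp, delete it recursively in tearDown. A minimal sketch of that pairing, assuming the TF1-style gfile import; the class name is illustrative:

import tempfile
import unittest

from tensorflow.python.platform import gfile


class TempDirTest(unittest.TestCase):

    def setUp(self):
        # Each test gets its own scratch directory.
        self._base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the whole tree so no state leaks between tests.
        if gfile.Exists(self._base_dir):
            gfile.DeleteRecursively(self._base_dir)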
Code Example #12

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_22(x):
    return tf.nn.max_pool(x,
                          ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1],
                          padding='SAME')


mnist = input_data.read_data_sets("MNISTZIP/", one_hot=True)
if gfile.Exists("E:/lmgod/Pythonws/CNNMNISTLOG"):
    gfile.DeleteRecursively("E:/lmgod/Pythonws/CNNMNISTLOG")
sess = tf.InteractiveSession()

with tf.name_scope("input"):
    x = tf.placeholder(tf.float32, [None, 784], name="x")
    y_ = tf.placeholder(tf.float32, [None, 10], name="y")
with tf.name_scope('input_reshape'):
    image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
    s1 = tf.summary.image('input', image_shaped_input, 10)
with tf.name_scope("conv1"):
    W_conv1 = weight_variable([5, 5, 1, 32])
    s2 = tf.summary.histogram("conv1_W", W_conv1)
    b_conv1 = bias_variable([32])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_22(h_conv1)
Code Example #13
File: saver_test.py    Project: jipatsaa/tensorflow-1
    def testNonSharded(self):
        save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
        try:
            gfile.DeleteRecursively(save_dir)
        except OSError:
            pass  # Ignore
        gfile.MakeDirs(save_dir)

        with self.test_session() as sess:
            v = tf.Variable(10.0, name="v")
            save = tf.train.Saver({"v": v}, max_to_keep=2)
            tf.initialize_all_variables().run()
            self.assertEqual([], save.last_checkpoints)

            s1 = save.save(sess, os.path.join(save_dir, "s1"))
            self.assertEqual([s1], save.last_checkpoints)
            self.assertTrue(gfile.Exists(s1))

            s2 = save.save(sess, os.path.join(save_dir, "s2"))
            self.assertEqual([s1, s2], save.last_checkpoints)
            self.assertTrue(gfile.Exists(s1))
            self.assertTrue(gfile.Exists(s2))

            s3 = save.save(sess, os.path.join(save_dir, "s3"))
            self.assertEqual([s2, s3], save.last_checkpoints)
            self.assertFalse(gfile.Exists(s1))
            self.assertTrue(gfile.Exists(s2))
            self.assertTrue(gfile.Exists(s3))

            # Create a second helper, identical to the first.
            save2 = tf.train.Saver(saver_def=save.as_saver_def())
            save2.set_last_checkpoints(save.last_checkpoints)

            # Create a third helper, with the same configuration but no knowledge of
            # previous checkpoints.
            save3 = tf.train.Saver(saver_def=save.as_saver_def())

            # Exercise the first helper.

            # Adding s2 again (old s2 is removed first, then new s2 appended)
            s2 = save.save(sess, os.path.join(save_dir, "s2"))
            self.assertEqual([s3, s2], save.last_checkpoints)
            self.assertFalse(gfile.Exists(s1))
            self.assertTrue(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))

            # Adding s1 (s3 should now be deleted as oldest in list)
            s1 = save.save(sess, os.path.join(save_dir, "s1"))
            self.assertEqual([s2, s1], save.last_checkpoints)
            self.assertFalse(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))
            self.assertTrue(gfile.Exists(s1))

            # Exercise the second helper.

            # Adding s2 again (old s2 is removed first, then new s2 appended)
            s2 = save2.save(sess, os.path.join(save_dir, "s2"))
            self.assertEqual([s3, s2], save2.last_checkpoints)
            # Created by the first helper.
            self.assertTrue(gfile.Exists(s1))
            # Deleted by the first helper.
            self.assertFalse(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))

            # Adding s1 (s3 should now be deleted as oldest in list)
            s1 = save2.save(sess, os.path.join(save_dir, "s1"))
            self.assertEqual([s2, s1], save2.last_checkpoints)
            self.assertFalse(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))
            self.assertTrue(gfile.Exists(s1))

            # Exercise the third helper.

            # Adding s2 again (but helper is unaware of previous s2)
            s2 = save3.save(sess, os.path.join(save_dir, "s2"))
            self.assertEqual([s2], save3.last_checkpoints)
            # Created by the first helper.
            self.assertTrue(gfile.Exists(s1))
            # Deleted by the first helper.
            self.assertFalse(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))

            # Adding s1 (s3 should not be deleted because helper is unaware of it)
            s1 = save3.save(sess, os.path.join(save_dir, "s1"))
            self.assertEqual([s2, s1], save3.last_checkpoints)
            self.assertFalse(gfile.Exists(s3))
            self.assertTrue(gfile.Exists(s2))
            self.assertTrue(gfile.Exists(s1))
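
The retention behaviour exercised above can be reproduced in a few lines: with max_to_keep=2 the Saver drops the oldest checkpoint as each new one is written. A minimal sketch under TF1-style graph mode, assuming tf.compat.v1; the directory and variable names are illustrative:

import os
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

save_dir = '/tmp/max_to_keep_demo'
tf.io.gfile.makedirs(save_dir)     # Saver.save expects the directory to exist

v = tf.Variable(10.0, name='v')
saver = tf.train.Saver({'v': v}, max_to_keep=2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(4):
        saver.save(sess, os.path.join(save_dir, 'ckpt'), global_step=i)
    # Only the two most recent checkpoints remain on disk.
    print(saver.last_checkpoints)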
Code Example #14
File: train.py    Project: avnishr/KFP-MNIST
def train():

    print(ARGS.log_dir)
    if gfile.Exists(ARGS.log_dir):
        gfile.DeleteRecursively(ARGS.log_dir)
    gfile.MakeDirs(ARGS.log_dir)

    log_file = ARGS.log_dir + '/train'
    print("The log file is ", log_file)

    # load dataset
    (trainX, trainy), (testX, testy) = fashion_mnist.load_data()

    # Data normalization: divide by 255, the maximum possible pixel value

    trainX = trainX / 255

    testX = testX / 255

    trainX = trainX.reshape(trainX.shape[0], 28, 28, 1)

    testX = testX.reshape(testX.shape[0], 28, 28, 1)

    cnn = tf.keras.models.Sequential()

    cnn.add(
        tf.keras.layers.Conv2D(32, (3, 3),
                               activation='relu',
                               input_shape=(28, 28, 1)))

    cnn.add(tf.keras.layers.MaxPooling2D(2, 2))

    cnn.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))

    cnn.add(tf.keras.layers.Flatten())

    cnn.add(tf.keras.layers.Dense(64, activation='relu'))

    cnn.add(tf.keras.layers.Dense(10, activation='softmax'))

    print(ARGS)

    optimizer = tf.keras.optimizers.get(ARGS.optimizer_name)

    optimizer.learning_rate = ARGS.learning_rate

    optimizer.momentum = ARGS.momentum

    print(optimizer)

    cnn.compile(optimizer=optimizer,
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

    cnn.summary()

    cnn.fit(trainX, trainy, epochs=ARGS.epochs, batch_size=ARGS.batch_size)

    test_loss, test_acc = cnn.evaluate(testX, testy, verbose=2)

    writer = tf.summary.create_file_writer(log_file)

    with writer.as_default():
        tf.summary.scalar('accuracy', test_acc, step=1)
        writer.flush()

    print("accuracy={}".format(test_acc))
Code Example #15
  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = checkpoint_management.latest_checkpoint(logdir1)

      assign_fn = variables_lib.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          None,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
          save_checkpoint_secs=None,
          save_summaries_steps=None)

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)
Code Example #16
 def tearDown(self):
     if self._model_dir:
         gfile.DeleteRecursively(self._model_dir)
     self._model_dir = None
     self._estimator = None
Code Example #17
  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on jenkins.
      gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights = variables_lib.get_variables_by_name('weights')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=weights)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=200, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=200),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train the biases of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      biases = variables_lib.get_variables_by_name('biases')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=biases)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and bias to get lower loss.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=400),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
Code Example #18
def main(argv=None):

    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        # Get images, sentences, labels batch for RN
        images, sentences, answers = generate_batch(
            batch_size=FLAGS.batch_size, flag='train')
        # preprocess
        images = tf.image.resize_bilinear(images, size=[128, 128])
        images = tf.image.resize_image_with_crop_or_pad(images,
                                                        target_height=136,
                                                        target_width=136)
        images = tf.random_crop(images, size=[FLAGS.batch_size, 128, 128, 3])
        rota_range = 0.05  # rads
        rota_range = rota_range * random.random()
        images = tf.contrib.image.rotate(images, angles=rota_range)

        tf.summary.image('image', tf.expand_dims(images[0], axis=0))
        print(images)
        print(sentences)
        print(answers)

        # Build a graph that computes the logits predictions from the inference model
        #        logits, cnn_features = model.inference(images, sentences, answers)
        logits, predict_labels = model.inference_demo(images, sentences,
                                                      answers,
                                                      FLAGS.batch_size)
        # Calculate loss
        loss = model.loss(logits, answers)

        # gradCAM
        #        softmax_linear = tf.nn.softmax(logits, dim = -1)[0, :]
        ##        print (softmax_linear)
        ##        print (softmax_linear[answers[0]])
        ##        print (loss)
        #        grads = tf.gradients(softmax_linear[answers[0]], cnn_features)
        #        print (grads, 'grads')
        ##        grads_oneimage = tf.expand_dims(grads[0], 0) # summary_image request [batch_size, height, width, nchannels]
        #        grads_oneimage = tf.maximum(grads[0][0], 0)
        #        grads_max = tf.reduce_max(grads_oneimage)
        #        importance_weights = tf.reduce_mean(tf.reduce_mean(grads_oneimage, 0, keep_dims=True), 1, keep_dims=True)
        #        print (importance_weights.shape, 'importance_weights.shape')
        #        grads_oneimage = tf.multiply(importance_weights, grads_oneimage)
        #        grads_oneimage = tf.reduce_mean(grads_oneimage, 2, keep_dims = True) / grads_max * 255
        #        print (grads_oneimage, 'grads_oneimage')
        #        grads = tf.gradients(softmax_linear[answers[0]], images)
        #        backpro_oneimage = tf.maximum(grads[0][0], 0)
        #        backpro_oneimage = tf.reduce_mean(backpro_oneimage, 2, keep_dims=True) / tf.reduce_max(backpro_oneimage) * 255
        #        grads_oneimage = tf.squeeze(tf.image.resize_bilinear(tf.expand_dims(grads_oneimage, 0), size = [320, 480]), axis = 0)

        #        grad_cam = tf.multiply(grads_oneimage, backpro_oneimage)
        #        tf.summary.image('grad-cam', tf.expand_dims(grad_cam, 0))

        #        acc = model.accuracy(logits, answers, embedding_matrix)

        #        acc = model.accuracy(logits, answers)
        acc = tf.reduce_mean(
            tf.cast(tf.equal(predict_labels, tf.cast(answers, tf.int64)),
                    tf.float32))
        tf.summary.scalar('accuracy', acc)
        # Build a graph that trains the model with one batch of examples and updates the model parameters
        train_op = model.train(loss, global_step)
        # Create a Saver
        saver = tf.train.Saver(tf.global_variables())
        # Build the summary operation based on the TF collection of Summaries
        summary_op = tf.summary.merge_all()
        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        # Start running operations on the Graph.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9,
                                    allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False,
                                                gpu_options=gpu_options))

        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            print('Restore from ' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)

        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        start_time = time.time()

        max_steps = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size)
        for epoch in range(FLAGS.num_epoch):
            average_loss = 0
            for step in xrange(max_steps):

                #                epoch_num = int(step * FLAGS.batch_size / NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
                #                step_tmp = int(step % int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN/FLAGS.batch_size))

                start_time = time.time()

                #            s = sess.run([sentences])
                #            print (s)
                #            _, loss_value, accuracy_value = sess.run([train_op, loss, acc])
                answers_value, _, loss_value, accuracy_value, predict_labels_value = sess.run(
                    [answers, train_op, loss, acc, predict_labels])

                #            print (logits_value[0])
                #            print (softmax_linear_value[0])
                # test dataset
                #            for b in range(FLAGS.batch_size):
                #                imsave('pic/' + str(step) + '_' + str(b) + '.bmp', images_value[b])
                #                for i in range(MX_LEN):
                #    #                    f.write(str(sentences_value[0, i]) + ' ')
                #                    if sentences_value[b, i] != 0 :
                #                        print question_index_to_word[str(sentences_value[b, i])], # python2
                #                print ('\n')
                #                print (answer_index_to_word[str(answers_value[b])])
                ##                print ('\n!')

                duration = time.time() - start_time

                #            for i in range(FLAGS.batch_size):
                #                # display
                #                prediction_str = str(predict_labels_value[i])
                #                label_str = str(answers_value[i])
                #                print (answer_index_to_word[prediction_str], answer_index_to_word[label_str])

                average_loss += loss_value

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'
                if step % 1 == 0:
                    #                num_examples_per_step = FLAGS.batch_size
                    examples_per_sec = FLAGS.batch_size / duration
                    #                sec_per_batch = float(duration)

                    format_str = (
                        '%s: epoch %d, step %d/%d, %.2f, loss = %.5f, average_loss = %.5f, accuracy = %.5f'
                    )
                    print(format_str %
                          (datetime.now(), epoch, step,
                           int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                               FLAGS.batch_size), examples_per_sec, loss_value,
                           average_loss / (step + 1), accuracy_value))
    #                format_str = ('%s: step %d, loss = %d')
    #                print (format_str % (datetime.now(), step, loss_value))

                if step % 100 == 0:
                    # print sentences to file
                    #                f = open(sentence_log, 'a')
                    #                f.write('step: ' + str(step) + ' ')
                    #                for i in range(MX_LEN):
                    ##                    f.write(str(sentences_value[0, i]) + ' ')
                    #                    if sentences_value[0, i] != 0 :
                    #                        f.write(question_index_to_word[str(sentences_value[0, i])] + ' ')
                    #                f.write(answer_index_to_word[str(answers_value[0])])
                    #                imsave('pic/' + str(step) + '.bmp', images_value[0])
                    #                f.write('\n')
                    #                f.close()

                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)

                if step % 1000 == 0:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=epoch * max_steps + step)
Code Example #19
    def test_export_savedmodel(self):
        tmpdir = tempfile.mkdtemp()
        est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)

        extra_file_name = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('my_extra_file'))
        extra_file = gfile.GFile(extra_file_name, mode='w')
        extra_file.write(EXTRA_FILE_CONTENT)
        extra_file.close()
        assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}

        export_dir_base = os.path.join(compat.as_bytes(tmpdir),
                                       compat.as_bytes('export'))
        export_dir = est.export_savedmodel(export_dir_base,
                                           serving_input_fn,
                                           assets_extra=assets_extra)

        self.assertTrue(gfile.Exists(export_dir_base))
        self.assertTrue(gfile.Exists(export_dir))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('saved_model.pb'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('variables'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('variables/variables.index'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(
                    compat.as_bytes(export_dir),
                    compat.as_bytes(
                        'variables/variables.data-00000-of-00001'))))

        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets'))))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets/my_vocab_file'))))
        self.assertEqual(
            compat.as_bytes(VOCAB_FILE_CONTENT),
            compat.as_bytes(
                gfile.GFile(
                    os.path.join(
                        compat.as_bytes(export_dir),
                        compat.as_bytes('assets/my_vocab_file'))).read()))

        expected_extra_path = os.path.join(
            compat.as_bytes(export_dir),
            compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
        self.assertTrue(
            gfile.Exists(
                os.path.join(compat.as_bytes(export_dir),
                             compat.as_bytes('assets.extra'))))
        self.assertTrue(gfile.Exists(expected_extra_path))
        self.assertEqual(
            compat.as_bytes(EXTRA_FILE_CONTENT),
            compat.as_bytes(gfile.GFile(expected_extra_path).read()))

        expected_vocab_file = os.path.join(compat.as_bytes(tmpdir),
                                           compat.as_bytes('my_vocab_file'))
        # Restore, to validate that the export was well-formed.
        with ops.Graph().as_default() as graph:
            with session_lib.Session(graph=graph) as sess:
                loader.load(sess, [tag_constants.SERVING], export_dir)
                assets = [
                    x.eval() for x in graph.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS)
                ]
                self.assertItemsEqual([expected_vocab_file], assets)
                graph_ops = [x.name for x in graph.get_operations()]
                self.assertTrue('input_example_tensor' in graph_ops)
                self.assertTrue('ParseExample/ParseExample' in graph_ops)
                self.assertTrue('linear/linear/feature/matmul' in graph_ops)

        # cleanup
        gfile.DeleteRecursively(tmpdir)
Code Example #20
def main(_):
    if gfile.Exists(FLAGS.log_dir):
        gfile.DeleteRecursively(FLAGS.log_dir)
    gfile.MakeDirs(FLAGS.log_dir)
    run_train()
Code Example #21
def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)

        # Get images and labels for CIFAR-10.
        # images, labels = cifar10.distorted_inputs()
        images, labels = cifar10.read_and_decode("imagetrain.tfrecords")

        images, labels = cifar10.distort_inputs_train(images, labels)
        tf.summary.image('input_image', images, 20)

        # Build a Graph that computes the logits predictions from the
        # inference model.

        inception_resnet_v2_arg_scope = cifar10.inception_resnet_v2_arg_scope

        with slim.arg_scope(inception_resnet_v2_arg_scope(is_training=True)):

            logits, end_points = cifar10.inference(images,
                                                   num_classes=3,
                                                   is_training=True)

        # Calculate loss.
        loss = cifar10.loss(logits=logits, labels=labels)
        train__acc = evaluation(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables())

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # Start running operations on the Graph.
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement))

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("restore from file")
        else:
            print('No checkpoint file found')

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value, accuracy = sess.run([train_op, loss, train__acc])
            duration = time.time() - start_time

            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % 10 == 0:
                num_examples_per_step = FLAGS.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = (
                    '%s: step %d, loss = %.5f (%.1f examples/sec; %.3f '
                    'sec/batch), train accuracy = %.2f%%')
                print(format_str %
                      (datetime.now(), step, loss_value, examples_per_sec,
                       sec_per_batch, accuracy * 100))

            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            # Save the model checkpoint periodically.
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

                if gfile.Exists(FLAGS.eval_dir):
                    gfile.DeleteRecursively(FLAGS.eval_dir)
                gfile.MakeDirs(FLAGS.eval_dir)
                bn_ince_resnet_eval.evaluate()

        summary_writer.close()
Code Example #22
def main(argv=None):  # pylint: disable=unused-argument
  if gfile.Exists(FLAGS.train_dir):
    gfile.DeleteRecursively(FLAGS.train_dir)
  gfile.MakeDirs(FLAGS.train_dir)
  train()
Code Example #23
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Returns:
      The string path to the exported directory.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    # Export dir must not end with / or it will break exports to keep. Strip /.
    if export_dir_base.endswith("/"):
      export_dir_base = export_dir_base[:-1]

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = compat.as_text(export_dir) + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(
        sess,
        os.path.join(
            compat.as_text(tmp_export_dir),
            compat.as_text(constants.EXPORT_BASE_NAME)),
        meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback and self._assets_to_copy:
      assets_dir = os.path.join(
          compat.as_bytes(tmp_export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY))
      gfile.MakeDirs(assets_dir)
      self._assets_callback(self._assets_to_copy, assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        if os.name == "nt":
          match = re.match(
              r"^" + export_dir_base.replace("\\", "/") + r"/(\d{8})$",
              path.path.replace("\\", "/"))
        else:
          match = re.match(r"^" + export_dir_base + r"/(\d{8})$", path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)

    return export_dir
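
The write-to-a-temporary-directory-then-rename step in export() above is a general trick for making a directory appear atomically. A minimal standalone sketch of just that step, assuming the same gfile module; the paths are illustrative:

from tensorflow.python.platform import gfile

export_dir = '/tmp/exports/00000001'
tmp_export_dir = export_dir + '-tmp'

# Write everything into the temporary directory first...
gfile.MakeDirs(tmp_export_dir)
# ...then publish it with a single rename, so readers never observe a
# half-written export directory.
gfile.Rename(tmp_export_dir, export_dir)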
Code Example #24
def run_and_gather_logs(name, test_name, test_args, benchmark_type):
    """Run the bazel test given by test_name.  Gather and return the logs.

    Args:
      name: Benchmark target identifier.
      test_name: A unique bazel target, e.g. "//path/to:test"
      test_args: A string containing all arguments to run the target with.
      benchmark_type: A string representing the BenchmarkType enum; the
        benchmark type for this target.

    Returns:
      A tuple (test_results, mangled_test_name), where
      test_results: A test_log_pb2.TestResults proto
      mangled_test_name: A string, the mangled test name.

    Raises:
      ValueError: If the test_name is not a valid target.
      subprocess.CalledProcessError: If the target itself fails.
      IOError: If there are problems gathering test log output from the test.
      MissingLogsError: If we couldn't find benchmark logs.
    """
    if not (test_name and test_name.startswith("//") and ".." not in test_name
            and not test_name.endswith(":") and not test_name.endswith(":all")
            and not test_name.endswith("...")
            and len(test_name.split(":")) == 2):
        raise ValueError(
            "Expected test_name parameter with a unique test, e.g.: "
            "--test_name=//path/to:test")
    test_executable = test_name.rstrip().strip("/").replace(":", "/")

    if gfile.Exists(os.path.join("bazel-bin", test_executable)):
        # Running in standalone mode from core of the repository
        test_executable = os.path.join("bazel-bin", test_executable)
    else:
        # Hopefully running in sandboxed mode
        test_executable = os.path.join(".", test_executable)

    test_adjusted_name = name
    gpu_config = gpu_info_lib.gather_gpu_devices()
    if gpu_config:
        gpu_name = gpu_config[0].model
        gpu_short_name_match = re.search(r"Tesla [KP][4,8]0", gpu_name)
        if gpu_short_name_match:
            gpu_short_name = gpu_short_name_match.group(0)
            test_adjusted_name = name + "|" + gpu_short_name.replace(" ", "_")

    temp_directory = tempfile.mkdtemp(prefix="run_and_gather_logs")
    mangled_test_name = (test_adjusted_name.strip("/").replace(
        "|", "_").replace("/", "_").replace(":", "_"))
    test_file_prefix = os.path.join(temp_directory, mangled_test_name)
    test_file_prefix = "%s." % test_file_prefix

    try:
        if not gfile.Exists(test_executable):
            raise ValueError("Executable does not exist: %s" % test_executable)
        test_args = shlex.split(test_args)

        # This key is defined in tf/core/util/reporter.h as
        # TestReporter::kTestReporterEnv.
        os.environ["TEST_REPORT_FILE_PREFIX"] = test_file_prefix
        start_time = time.time()
        subprocess.check_call([test_executable] + test_args)
        run_time = time.time() - start_time
        log_files = gfile.Glob("{}*".format(test_file_prefix))
        if not log_files:
            raise MissingLogsError("No log files found at %s." %
                                   test_file_prefix)

        return (process_test_logs(test_adjusted_name,
                                  test_name=test_name,
                                  test_args=test_args,
                                  benchmark_type=benchmark_type,
                                  start_time=int(start_time),
                                  run_time=run_time,
                                  log_files=log_files), mangled_test_name)

    finally:
        try:
            gfile.DeleteRecursively(temp_directory)
        except OSError:
            pass
Code Example #25
 def _tear_down(self):
   gfile.DeleteRecursively(self._temp_dir)
Code Example #26
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'train_dir', 'information_pursue_train',
    """Directory where to write event logs """
    """and checkpoint.""")
tf.app.flags.DEFINE_integer('model_nums', 5, """Number of models to train""")
tf.app.flags.DEFINE_integer('batch_size', 28,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string(
    'data_dir', 'information_pursue_data',
    """Path to the information_pursue data directory.""")
tf.app.flags.DEFINE_string('images_path', "256_ObjectCategories/",
                           "The images's path")

if __name__ == '__main__':
    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)

    if not gfile.Exists(FLAGS.data_dir):
        gfile.MakeDirs(FLAGS.data_dir)

    information_pursue_model = ip.Information_pursue(
        dataset_percent=20,
        data_dir=FLAGS.data_dir,
        images_path=FLAGS.images_path,
        model_nums=FLAGS.model_nums,
        train_dir=FLAGS.train_dir,
        batch_size=FLAGS.batch_size)
    information_pursue_model.run()
Code Example #27
    def testScalarsRealistically(self):
        """Test accumulator by writing values and then reading them."""
        def FakeScalarSummary(tag, value):
            value = tf.Summary.Value(tag=tag, simple_value=value)
            summary = tf.Summary(value=[value])
            return summary

        directory = os.path.join(self.get_temp_dir(), 'values_dir')
        if gfile.IsDirectory(directory):
            gfile.DeleteRecursively(directory)
        gfile.MkDir(directory)

        writer = tf.train.SummaryWriter(directory, max_queue=100)

        with tf.Graph().as_default() as graph:
            _ = tf.constant([2.0, 1.0])
        # Add a graph to the summary writer.
        writer.add_graph(graph)

        run_metadata = tf.RunMetadata()
        device_stats = run_metadata.step_stats.dev_stats.add()
        device_stats.device = 'test device'
        writer.add_run_metadata(run_metadata, 'test run')

        # Write a bunch of events using the writer
        for i in xrange(30):
            summ_id = FakeScalarSummary('id', i)
            summ_sq = FakeScalarSummary('sq', i * i)
            writer.add_summary(summ_id, i * 5)
            writer.add_summary(summ_sq, i * 5)
        writer.flush()

        # Verify that we can load those events properly
        acc = ea.EventAccumulator(directory)
        acc.Reload()
        self.assertTagsEqual(
            acc.Tags(), {
                ea.IMAGES: [],
                ea.SCALARS: ['id', 'sq'],
                ea.HISTOGRAMS: [],
                ea.COMPRESSED_HISTOGRAMS: [],
                ea.GRAPH: True,
                ea.RUN_METADATA: ['test run']
            })
        id_events = acc.Scalars('id')
        sq_events = acc.Scalars('sq')
        self.assertEqual(30, len(id_events))
        self.assertEqual(30, len(sq_events))
        for i in xrange(30):
            self.assertEqual(i * 5, id_events[i].step)
            self.assertEqual(i * 5, sq_events[i].step)
            self.assertEqual(i, id_events[i].value)
            self.assertEqual(i * i, sq_events[i].value)

        # Write a few more events to test incremental reloading
        for i in xrange(30, 40):
            summ_id = FakeScalarSummary('id', i)
            summ_sq = FakeScalarSummary('sq', i * i)
            writer.add_summary(summ_id, i * 5)
            writer.add_summary(summ_sq, i * 5)
        writer.flush()

        # Verify we can now see all of the data
        acc.Reload()
        self.assertEqual(40, len(id_events))
        self.assertEqual(40, len(sq_events))
        for i in xrange(40):
            self.assertEqual(i * 5, id_events[i].step)
            self.assertEqual(i * 5, sq_events[i].step)
            self.assertEqual(i, id_events[i].value)
            self.assertEqual(i * i, sq_events[i].value)
        self.assertProtoEquals(graph.as_graph_def(add_shapes=True),
                               acc.Graph())
Code Example #28
 def tearDown(self):
     if gfile.Exists(self.data_source.example_dumped_dir):
         gfile.DeleteRecursively(self.data_source.example_dumped_dir)
Code Example #29
def main(argv=None):  # pylint: disable=unused-argument
    cifar10.maybe_download_and_extract()
    if gfile.Exists(FLAGS.train_dir):
        gfile.DeleteRecursively(FLAGS.train_dir)
    gfile.MakeDirs(FLAGS.train_dir)
    train()
Code Example #30
def main(unused_argv):
    logging.set_verbosity(logging.INFO)
    check.IsTrue(FLAGS.checkpoint_filename)
    check.IsTrue(FLAGS.tensorboard_dir)
    check.IsTrue(FLAGS.resource_path)

    if not gfile.IsDirectory(FLAGS.resource_path):
        gfile.MakeDirs(FLAGS.resource_path)

    training_corpus_path = gfile.Glob(FLAGS.training_corpus_path)[0]
    tune_corpus_path = gfile.Glob(FLAGS.tune_corpus_path)[0]

    # SummaryWriter for TensorBoard
    tf.logging.info('TensorBoard directory: "%s"', FLAGS.tensorboard_dir)
    tf.logging.info('Deleting prior data if exists...')

    stats_file = '%s.stats' % FLAGS.checkpoint_filename
    try:
        stats = gfile.GFile(stats_file, 'r').readlines()[0].split(',')
        stats = [int(x) for x in stats]
    except errors.OpError:
        stats = [-1, 0, 0]

    tf.logging.info('Read ckpt stats: %s', str(stats))
    do_restore = True
    if stats[0] < FLAGS.job_id:
        do_restore = False
        tf.logging.info('Deleting last job: %d', stats[0])
        try:
            gfile.DeleteRecursively(FLAGS.tensorboard_dir)
            gfile.Remove(FLAGS.checkpoint_filename)
        except errors.OpError as err:
            tf.logging.error('Unable to delete prior files: %s', err)
        stats = [FLAGS.job_id, 0, 0]

    tf.logging.info('Creating the directory again...')
    gfile.MakeDirs(FLAGS.tensorboard_dir)
    tf.logging.info('Created! Instantiating SummaryWriter...')
    summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)
    tf.logging.info('Creating TensorFlow checkpoint dir...')
    gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))

    # Constructs lexical resources for SyntaxNet in the given resource path, from
    # the training data.
    if FLAGS.compute_lexicon:
        logging.info('Computing lexicon...')
        lexicon.build_lexicon(FLAGS.resource_path,
                              training_corpus_path,
                              morph_to_pos=True)

    tf.logging.info('Loading MasterSpec...')
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.dragnn_spec, 'r') as fin:
        text_format.Parse(fin.read(), master_spec)
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_path)
    logging.info('Constructed master spec: %s', str(master_spec))
    hyperparam_config = spec_pb2.GridPoint()

    # Build the TensorFlow graph.
    tf.logging.info('Building Graph...')
    hyperparam_config = spec_pb2.GridPoint()
    try:
        text_format.Parse(FLAGS.hyperparams, hyperparam_config)
    except text_format.ParseError:
        text_format.Parse(base64.b64decode(FLAGS.hyperparams),
                          hyperparam_config)
    g = tf.Graph()
    with g.as_default():
        builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
        component_targets = [
            spec_pb2.TrainTarget(name=component.name,
                                 max_index=idx + 1,
                                 unroll_using_oracle=[False] * idx + [True])
            for idx, component in enumerate(master_spec.component)
            if 'shift-only' not in component.transition_system.registered_name
        ]
        trainers = [
            builder.add_training_from_config(target)
            for target in component_targets
        ]
        annotator = builder.add_annotation()
        builder.add_saver()

    # Read in serialized protos from training data.
    training_set = ConllSentenceReader(
        training_corpus_path,
        projectivize=FLAGS.projectivize_training_set,
        morph_to_pos=True).corpus()
    tune_set = ConllSentenceReader(tune_corpus_path,
                                   projectivize=False,
                                   morph_to_pos=True).corpus()

    # Ready to train!
    logging.info('Training on %d sentences.', len(training_set))
    logging.info('Tuning on %d sentences.', len(tune_set))

    pretrain_steps = [10000, 0]
    tagger_steps = 100000
    train_steps = [tagger_steps, 8 * tagger_steps]

    with tf.Session(FLAGS.tf_master, graph=g) as sess:
        # Make sure to re-initialize all underlying state.
        sess.run(tf.global_variables_initializer())

        if do_restore:
            tf.logging.info('Restoring from checkpoint...')
            builder.saver.restore(sess, FLAGS.checkpoint_filename)

            prev_tagger_steps = stats[1]
            prev_parser_steps = stats[2]
            tf.logging.info('adjusting schedule from steps: %d, %d',
                            prev_tagger_steps, prev_parser_steps)
            pretrain_steps[0] = max(pretrain_steps[0] - prev_tagger_steps, 0)
            tf.logging.info('new pretrain steps: %d', pretrain_steps[0])

        trainer_lib.run_training(sess, trainers, annotator,
                                 evaluation.parser_summaries, pretrain_steps,
                                 train_steps, training_set, tune_set, tune_set,
                                 FLAGS.batch_size, summary_writer,
                                 FLAGS.report_every, builder.saver,
                                 FLAGS.checkpoint_filename, stats)