示例#1
0
def _get_trace_proto_string():
    """Return a serialized MasterTrace holding one minimal component trace.

    The single component trace is named 'test_component' and contains one
    ComponentStepTrace constructed with an empty fixed_feature_trace list.
    """
    empty_step = trace_pb2.ComponentStepTrace(fixed_feature_trace=[])
    master_trace = trace_pb2.MasterTrace()
    master_trace.component_trace.add(
        name='test_component',
        step_trace=[empty_step],
    )
    return master_trace.SerializeToString()
示例#2
0
def _get_trace_proto_string():
    """Return a serialized MasterTrace whose component has a non-ASCII name.

    The single component trace contains one ComponentStepTrace constructed
    with an empty fixed_feature_trace list.
    """
    # According to Google Translate this is "component" in Chinese; a
    # non-ASCII name is used on purpose to exercise UTF-8 handling.
    component_name = '零件'
    empty_step = trace_pb2.ComponentStepTrace(fixed_feature_trace=[])
    master_trace = trace_pb2.MasterTrace()
    master_trace.component_trace.add(
        name=component_name,
        step_trace=[empty_step],
    )
    return master_trace.SerializeToString()
示例#3
0
  def RunFullTrainingAndInference(self,
                                  test_name,
                                  master_spec_path=None,
                                  master_spec=None,
                                  component_weights=None,
                                  unroll_using_oracle=None,
                                  num_evaluated_components=1,
                                  expected_num_actions=None,
                                  expected=None,
                                  batch_size_limit=None):
    """Trains on the dummy gold sentences, then verifies inference output.

    Builds training, trace-only training, saver, annotation and traced
    annotation graphs from the master spec. It then:
      1. runs the trace-only training graph once and checks that every
         returned trace has a component trace with at least one step,
      2. runs training steps until the summed 'correct' metric equals the
         summed 'total' metric (plus 50 extra iterations), or the iteration
         limit is reached,
      3. checks the final metrics against `expected_num_actions`,
      4. saves the model under FLAGS.test_tmpdir,
      5. runs both annotation graphs on the dummy test sentences, checks that
         they agree, and compares the parsed annotations with `expected`.

    Args:
      test_name: String used as the builder's pool scope, in the training
        target name, and as the annotation graph name.
      master_spec_path: Passed to self.LoadSpec() when `master_spec` is falsy.
      master_spec: Master spec to build from; loaded from `master_spec_path`
        when falsy.
      component_weights: Values for the training target's component_weights.
        When falsy, every component gets weight 0 except the last, which
        gets 1.0.
      unroll_using_oracle: Values for the training target's
        unroll_using_oracle. When falsy, every component gets False except
        the last, which gets True.
      num_evaluated_components: Used only to derive the default
        `expected_num_actions` (4 per evaluated component).
      expected_num_actions: Value that both the summed 'total' and summed
        'correct' metrics must equal after training. When falsy, defaults to
        4 * num_evaluated_components.
      expected: Indexable collection of expected sentences; entries 0 and 1
        are used. Defaults to _TAGGER_EXPECTED_SENTENCES when None.
      batch_size_limit: When not None, the gold and test inputs are truncated
        to at most this many serialized sentences.
    """
    if not master_spec:
      master_spec = self.LoadSpec(master_spec_path)

    gold_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE, gold_doc)
    gold_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_GOLD_SENTENCE_2, gold_doc_2)
    gold_reader_strings = [
        gold_doc.SerializeToString(), gold_doc_2.SerializeToString()
    ]

    test_doc = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE, test_doc)
    test_doc_2 = sentence_pb2.Sentence()
    text_format.Parse(_DUMMY_TEST_SENTENCE_2, test_doc_2)
    # The order here (doc, doc, doc_2, doc) must stay in sync with the
    # [0, 0, 1, 0] index list used to build expected_sentences below.
    test_reader_strings = [
        test_doc.SerializeToString(), test_doc.SerializeToString(),
        test_doc_2.SerializeToString(), test_doc.SerializeToString()
    ]

    if batch_size_limit is not None:
      gold_reader_strings = gold_reader_strings[:batch_size_limit]
      test_reader_strings = test_reader_strings[:batch_size_limit]

    with tf.Graph().as_default():
      tf.set_random_seed(1)
      hyperparam_config = spec_pb2.GridPoint()
      builder = graph_builder.MasterBuilder(
          master_spec, hyperparam_config, pool_scope=test_name)
      target = spec_pb2.TrainTarget()
      target.name = 'testFullInference-train-%s' % test_name
      if component_weights:
        target.component_weights.extend(component_weights)
      else:
        # Default: only the last component contributes to the cost.
        target.component_weights.extend([0] * len(master_spec.component))
        target.component_weights[-1] = 1.0
      if unroll_using_oracle:
        target.unroll_using_oracle.extend(unroll_using_oracle)
      else:
        # Default: only the last component is unrolled using the oracle.
        target.unroll_using_oracle.extend([False] * len(master_spec.component))
        target.unroll_using_oracle[-1] = True
      train = builder.add_training_from_config(target)
      oracle_trace = builder.add_training_from_config(
          target, prefix='train_traced-', trace_only=True)
      builder.add_saver()

      anno = builder.add_annotation(test_name)
      trace = builder.add_annotation(test_name + '-traced', enable_tracing=True)

      # Verifies that the summaries can be built.
      for component in builder.components:
        component.get_summaries()

      config = tf.ConfigProto(
          intra_op_parallelism_threads=0, inter_op_parallelism_threads=0)
      with self.test_session(config=config) as sess:
        logging.info('Initializing')
        sess.run(tf.global_variables_initializer())

        logging.info('Dry run oracle trace...')
        traces = sess.run(
            oracle_trace['traces'],
            feed_dict={oracle_trace['input_batch']: gold_reader_strings})

        # Check that the oracle traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        logging.info('Simulating training...')
        # Needs ~100 iterations, but is not deterministic.
        max_iterations = 400
        break_iter = max_iterations
        is_resolved = False
        for i in range(max_iterations):
          cost, eval_res_val = sess.run(
              [train['cost'], train['metrics']],
              feed_dict={train['input_batch']: gold_reader_strings})
          logging.info('cost = %s', cost)
          self.assertFalse(np.isnan(cost))
          # Metrics are laid out as consecutive (total, correct) pairs.
          metric_pairs = eval_res_val.reshape((-1, 2))
          total_val = metric_pairs[:, 0].sum()
          correct_val = metric_pairs[:, 1].sum()
          if correct_val == total_val and not is_resolved:
            logging.info('... converged on iteration %d with (correct, total) '
                         '= (%d, %d)', i, correct_val, total_val)
            is_resolved = True
            # Run for slightly longer than convergence to help with quantized
            # weight tiebreakers.
            break_iter = i + 50

          if i == break_iter:
            break

        # If training failed, report total/correct actions for each component.
        if not expected_num_actions:
          expected_num_actions = 4 * num_evaluated_components
        if (correct_val != total_val or correct_val != expected_num_actions or
            total_val != expected_num_actions):
          # enumerate() instead of xrange(): xrange does not exist on Python 3,
          # and a NameError here would mask the real training failure.
          for c, component_spec in enumerate(master_spec.component):
            logging.error('component %s:\nname=%s\ntotal=%s\ncorrect=%s', c,
                          component_spec.name, eval_res_val[2 * c],
                          eval_res_val[2 * c + 1])

        # Use a TestCase assertion rather than a bare `assert`, which is
        # removed when Python runs with -O.
        self.assertEqual(
            correct_val, total_val,
            'Did not converge! %d vs %d.' % (correct_val, total_val))

        self.assertEqual(expected_num_actions, correct_val)
        self.assertEqual(expected_num_actions, total_val)

        builder.saver.save(sess, os.path.join(FLAGS.test_tmpdir, 'model'))

        logging.info('Running test.')
        logging.info('Printing annotations')
        annotations = sess.run(
            anno['annotations'],
            feed_dict={anno['input_batch']: test_reader_strings})
        logging.info('Put %d inputs in, got %d annotations out.',
                     len(test_reader_strings), len(annotations))

        # Also run the annotation graph with tracing enabled.
        annotations_with_trace, traces = sess.run(
            [trace['annotations'], trace['traces']],
            feed_dict={trace['input_batch']: test_reader_strings})

        # The result of the two annotation graphs should be identical.
        self.assertItemsEqual(annotations, annotations_with_trace)

        # Check that the inference traces are not empty.
        for serialized_trace in traces:
          master_trace = trace_pb2.MasterTrace()
          master_trace.ParseFromString(serialized_trace)
          self.assertTrue(master_trace.component_trace)
          self.assertTrue(master_trace.component_trace[0].step_trace)

        self.assertEqual(len(test_reader_strings), len(annotations))
        pred_sentences = []
        for annotation in annotations:
          pred_sentence = sentence_pb2.Sentence()
          pred_sentence.ParseFromString(annotation)
          pred_sentences.append(pred_sentence)

        if expected is None:
          expected = _TAGGER_EXPECTED_SENTENCES

        # Mirrors the order of test_reader_strings: doc, doc, doc_2, doc.
        expected_sentences = [expected[i] for i in [0, 0, 1, 0]]

        for i, pred_sentence in enumerate(pred_sentences):
          self.assertProtoEquals(expected_sentences[i], pred_sentence)