Пример #1
0
 def testBasicTraining(self):
     """Run train, evaluate and predict end to end on the test input data.

     Trains for 100 steps with a learning rate of 0.5, evaluates for a single
     step, then asserts that the first prediction's probabilities are within
     0.1 of [0, 0, 0, 1].
     """
     # Comma-separated hyperparameter overrides handed to create_hparams.
     overrides = [
         'sequence_features=[Observation.code]',
         ('time_crossed_features=[Observation.code:'
          'Observation.value.quantity.value:'
          'Observation.value.quantity.unit:'
          'Observation.value.string]'),
         'time_concat_bucket_sizes=[12]',
         'learning_rate=0.5',
     ]
     hparams = model.create_hparams(','.join(overrides))
     # Each cross is given as a ':'-separated string; split it into its
     # component feature names.
     crossed = [spec.split(':') for spec in hparams.time_crossed_features]

     def build_input_fn(mode, batch_size, **extra):
         # The three phases below share every argument except the mode, the
         # batch size and (for predict) the shuffle flag. hparams is read
         # here, at call time, for every phase.
         return model.get_input_fn(
             mode=mode,
             input_files=[self.input_data_dir],
             label_name='label.length_of_stay_range.class',
             dedup=hparams.dedup,
             time_windows=hparams.time_windows,
             include_age=hparams.include_age,
             categorical_context_features=(
                 hparams.categorical_context_features),
             sequence_features=hparams.sequence_features,
             time_crossed_features=crossed,
             batch_size=batch_size,
             **extra)

     estimator = model.make_estimator(hparams, LABEL_VALUES,
                                      FLAGS.test_tmpdir)

     estimator.train(
         input_fn=build_input_fn(tf.estimator.ModeKeys.TRAIN, 10),
         steps=100)

     # Per the original author's note, this batch_size is meant to be larger
     # than the dataset, to ensure nothing relies on the static batch_size.
     estimator.evaluate(
         input_fn=build_input_fn(tf.estimator.ModeKeys.EVAL, 3),
         steps=1)

     # list() drains the prediction generator completely before we inspect
     # the first result.
     predictions = list(
         estimator.predict(
             input_fn=build_input_fn(
                 tf.estimator.ModeKeys.EVAL, 1, shuffle=False)))
     first_prediction = predictions[0]
     self.assertAllClose([0.0, 0.0, 0.0, 1.0],
                         first_prediction['probabilities'],
                         atol=0.1)
Пример #2
0
def main(unused_argv: List[str]):
  """Train and evaluate the model configured through command-line flags.

  Hyperparameters come from FLAGS.hparams_override. Training files are those
  matching 'train*' under FLAGS.input_dir and evaluation files those matching
  'validation*'. The estimator writes to FLAGS.output_dir.

  Args:
    unused_argv: positional command-line arguments; ignored.
  """
  hparams = model.create_hparams(FLAGS.hparams_override)
  tf.logging.info('Using hyperparameters %s', hparams)

  # Split every ':'-separated cross into its feature names, skipping empty
  # entries and the 'n/a' placeholder.
  crossed_features = []
  for cross in hparams.time_crossed_features:
    if not cross or cross == 'n/a':
      continue
    crossed_features.append(cross.split(':'))

  def input_fn_for(mode, file_pattern, batch_size):
    # Train and eval input functions differ only in these three arguments.
    return model.get_input_fn(
        mode=mode,
        input_files=glob.glob(os.path.join(FLAGS.input_dir, file_pattern)),
        label_name=FLAGS.label_name,
        dedup=hparams.dedup,
        time_windows=hparams.time_windows,
        include_age=hparams.include_age,
        categorical_context_features=hparams.categorical_context_features,
        sequence_features=hparams.sequence_features,
        time_crossed_features=crossed_features,
        batch_size=batch_size)

  train_input_fn = input_fn_for(
      tf.estimator.ModeKeys.TRAIN, 'train*', hparams.batch_size)
  # Fixing the batch size to get comparable evaluations.
  eval_input_fn = input_fn_for(
      tf.estimator.ModeKeys.EVAL, 'validation*', 32)

  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
  eval_spec = tf.estimator.EvalSpec(
      input_fn=eval_input_fn,
      steps=FLAGS.num_eval_steps,
      throttle_secs=60)

  estimator = model.make_estimator(
      hparams, FLAGS.label_values, FLAGS.output_dir)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Пример #3
0
 def testInputFn(self, num_too_old_events):
     """Check the features and label the input_fn yields for one example.

     Args:
       num_too_old_events: index into the event list used to pick the start
         of the time window; the first `num_too_old_events` expected
         entries are dropped from every expected sequence output.
     """
     event_ids = [1528917644, 1528917645, 1528917646, 1528917647]
     timestamp = 1528917657
     cutoff = event_ids[num_too_old_events]

     # Everything that depends on how many events survive the window.
     num_kept = 3 - num_too_old_events
     expected_indices = [[0, col] for col in range(3)][:num_kept]
     expected_shape = [1, num_kept]

     input_fn = model.get_input_fn(
         tf.estimator.ModeKeys.TRAIN, [self.input_data_dir],
         'label.length_of_stay_range.class',
         dedup=True,
         time_windows=[timestamp - cutoff, 0],
         include_age=True,
         categorical_context_features=[],
         sequence_features=['Observation.code'],
         time_crossed_features=[[
             'Observation.code',
             'Observation.value.quantity.value',
             'Observation.value.quantity.unit',
         ]],
         batch_size=1,
         shuffle=False)
     feature_map, label = input_fn()

     with self.test_session() as sess:
         sess.run(tf.tables_initializer())
         tf.train.start_queue_runners(
             sess=sess, coord=tf.train.Coordinator())
         # Fetch the label together with the features in a single run.
         feature_map['label'] = label
         results = sess.run(feature_map)

         def check_sequence(key, expected_values):
             # Values, indices and dense shape of one sparse sequence output.
             sparse = results[key]
             self.assertAllEqual(expected_values, sparse.values)
             self.assertAllEqual(expected_indices, sparse.indices)
             self.assertAllEqual(expected_shape, sparse.dense_shape)

         self.assertAllEqual([b'above_14'], results['label'])

         #  Second "loinc:6" will be de-duped.
         all_codes = [b'loinc:2', b'loinc:4', b'loinc:6']
         check_sequence(
             model.SEQUENCE_KEY_PREFIX + 'Observation.code-til-%d' % 0,
             all_codes[num_too_old_events:])

         all_crosses = [
             b'loinc:2-1.000000-mg/L',
             b'loinc:4-2.000000-n/a',
             b'loinc:6-n/a-n/a',
         ]
         check_sequence(
             model.SEQUENCE_KEY_PREFIX +
             'Observation.code_'
             'Observation.value.quantity.value_'
             'Observation.value.quantity.unit-til-0',
             all_crosses[num_too_old_events:])

         age_key = model.CONTEXT_KEY_PREFIX + model.AGE_KEY
         self.assertAllClose([85.505402], results[age_key])
Пример #4
0
    def testInputFnBatchDedup(self):
        """Check input_fn output for a batch of two with dedup enabled."""
        input_fn = model.get_input_fn(
            tf.estimator.ModeKeys.TRAIN, [self.input_data_dir],
            'label.length_of_stay_range.class',
            dedup=True,
            time_windows=[12, 0],
            include_age=True,
            categorical_context_features=[],
            sequence_features=[
                'Observation.code',
                'Observation.value.quantity.value',
            ],
            time_crossed_features=[[
                'Observation.code',
                'Observation.value.quantity.value',
                'Observation.value.quantity.unit',
            ]],
            batch_size=2,
            shuffle=False)
        feature_map, label = input_fn()

        with self.test_session() as sess:
            sess.run(tf.tables_initializer())
            tf.train.start_queue_runners(
                sess=sess, coord=tf.train.Coordinator())
            # Fetch the label together with the features in a single run.
            feature_map['label'] = label
            results = sess.run(feature_map)

            def check_sparse(key, values, indices, dense_shape):
                # Compare all three components of one sparse output.
                sparse = results[key]
                self.assertAllEqual(values, sparse.values)
                self.assertAllEqual(indices, sparse.indices)
                self.assertAllEqual(dense_shape, sparse.dense_shape)

            self.assertAllEqual([b'above_14', b'above_14'], results['label'])

            #  First loinc:2 from example1 is out of range.
            #  Second "loinc:6" will be deduped.
            # Indices are reordered on axis 1 due to deduping.
            check_sparse(
                model.SEQUENCE_KEY_PREFIX + 'Observation.code-til-%d' % 0,
                values=[b'loinc:4', b'loinc:6', b'loinc:1', b'loinc:4'],
                indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
                dense_shape=[2, 2])

            #  First value from example1 is out of range.
            # Indices are reordered on axis 1 due to deduping.
            check_sparse(
                model.SEQUENCE_KEY_PREFIX +
                'Observation.value.quantity.value-til-%d' % 0,
                values=[2.0, 1.0, 2.0],
                indices=[[0, 0], [1, 0], [1, 1]],
                dense_shape=[2, 2])

            age_key = model.CONTEXT_KEY_PREFIX + model.AGE_KEY
            self.assertAllClose([85.505402, 85.505402], results[age_key])