def testBasicTraining(self):
    """Run a full train / evaluate / predict cycle on the fixed test input.

    An estimator is trained for 100 steps on ``self.input_data_dir``,
    evaluated for a single step, and then used for prediction.  The test
    passes when the first prediction's ``probabilities`` are within 0.1 of
    ``[0.0, 0.0, 0.0, 1.0]``.
    """
    hparams = model.create_hparams(
        'sequence_features=[Observation.code],'
        'time_crossed_features=[Observation.code:'
        'Observation.value.quantity.value:Observation.value.quantity.unit:'
        'Observation.value.string],'
        'time_concat_bucket_sizes=[12],'
        'learning_rate=0.5')

    # Every cross is given as one colon-separated string; split each one
    # into the list of feature names that get_input_fn receives.
    crossed = [spec.split(':') for spec in hparams.time_crossed_features]

    def build_input_fn(mode, **extra):
        # All three phases read the same data with the same feature
        # configuration; only the mode and the batching options differ.
        return model.get_input_fn(
            mode=mode,
            input_files=[self.input_data_dir],
            label_name='label.length_of_stay_range.class',
            dedup=hparams.dedup,
            time_windows=hparams.time_windows,
            include_age=hparams.include_age,
            categorical_context_features=(
                hparams.categorical_context_features),
            sequence_features=hparams.sequence_features,
            time_crossed_features=crossed,
            **extra)

    estimator = model.make_estimator(hparams, LABEL_VALUES, FLAGS.test_tmpdir)

    train_fn = build_input_fn(tf.estimator.ModeKeys.TRAIN, batch_size=10)
    estimator.train(input_fn=train_fn, steps=100)

    # Use a batch_size larger than the dataset to ensure we don't rely
    # on the static batch_size anywhere.
    eval_fn = build_input_fn(tf.estimator.ModeKeys.EVAL, batch_size=3)
    estimator.evaluate(input_fn=eval_fn, steps=1)

    # Prediction reuses EVAL mode, one example per batch, without shuffling.
    predict_fn = build_input_fn(
        tf.estimator.ModeKeys.EVAL, batch_size=1, shuffle=False)
    results = list(estimator.predict(input_fn=predict_fn))

    self.assertAllClose(
        [0.0, 0.0, 0.0, 1.0], results[0]['probabilities'], atol=0.1)
def main(unused_argv: List[str]):
    """Train and evaluate an estimator configured from command-line flags.

    Hyperparameters come from ``FLAGS.hparams_override``.  Training files
    are those matching ``train*`` under ``FLAGS.input_dir`` and evaluation
    files are those matching ``validation*``.  The estimator is created
    with ``FLAGS.output_dir`` and run through
    ``tf.estimator.train_and_evaluate``.

    Args:
      unused_argv: Positional command-line arguments; not used.
    """
    hparams = model.create_hparams(FLAGS.hparams_override)
    tf.logging.info('Using hyperparameters %s', hparams)

    # Split each colon-separated cross into its feature names, dropping
    # empty entries and the 'n/a' placeholder.
    time_crossed_features = []
    for cross in hparams.time_crossed_features:
        if not cross or cross == 'n/a':
            continue
        time_crossed_features.append(cross.split(':'))

    def build_input_fn(mode, file_pattern, batch_size):
        # Train and eval share every setting except the mode, the file
        # pattern and the batch size.
        matched_files = glob.glob(os.path.join(FLAGS.input_dir, file_pattern))
        return model.get_input_fn(
            mode=mode,
            input_files=matched_files,
            label_name=FLAGS.label_name,
            dedup=hparams.dedup,
            time_windows=hparams.time_windows,
            include_age=hparams.include_age,
            categorical_context_features=(
                hparams.categorical_context_features),
            sequence_features=hparams.sequence_features,
            time_crossed_features=time_crossed_features,
            batch_size=batch_size)

    train_input_fn = build_input_fn(
        tf.estimator.ModeKeys.TRAIN, 'train*', hparams.batch_size)
    # Fixing the batch size to get comparable evaluations.
    eval_input_fn = build_input_fn(
        tf.estimator.ModeKeys.EVAL, 'validation*', 32)

    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=eval_input_fn,
        steps=FLAGS.num_eval_steps,
        throttle_secs=60)

    estimator = model.make_estimator(
        hparams, FLAGS.label_values, FLAGS.output_dir)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
def testInputFn(self, num_too_old_events):
    """Check the features and label the input fn yields for one example.

    The start of the single time window is derived from
    ``event_ids[num_too_old_events]``, and the assertions expect the first
    ``num_too_old_events`` of the three expected entries to be missing from
    both the plain and the crossed sequence feature.

    Args:
      num_too_old_events: Index into the event list that selects the
        cutoff, and the number of leading entries expected to be dropped.
    """
    event_ids = [1528917644, 1528917645, 1528917646, 1528917647]
    timestamp = 1528917657
    window_start = timestamp - event_ids[num_too_old_events]

    # Three entries are expected when nothing is cut off; both sparse
    # features share the same index layout and dense shape.
    num_kept = 3 - num_too_old_events
    expected_indices = [[0, 0], [0, 1], [0, 2]][:num_kept]
    expected_shape = [1, num_kept]

    input_fn = model.get_input_fn(
        tf.estimator.ModeKeys.TRAIN,
        [self.input_data_dir],
        'label.length_of_stay_range.class',
        dedup=True,
        time_windows=[window_start, 0],
        include_age=True,
        categorical_context_features=[],
        sequence_features=['Observation.code'],
        time_crossed_features=[[
            'Observation.code',
            'Observation.value.quantity.value',
            'Observation.value.quantity.unit',
        ]],
        batch_size=1,
        shuffle=False)
    feature_map, label = input_fn()
    # Fetch the label together with the features in a single run call.
    feature_map['label'] = label

    with self.test_session() as sess:
        sess.run(tf.tables_initializer())
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)
        results = sess.run(feature_map)

        self.assertAllEqual([b'above_14'], results['label'])

        code_key = model.SEQUENCE_KEY_PREFIX + 'Observation.code-til-%d' % 0
        codes = results[code_key]
        # Second "loinc:6" will be de-duped.
        expected_codes = [b'loinc:2', b'loinc:4', b'loinc:6']
        self.assertAllEqual(
            expected_codes[num_too_old_events:], codes.values)
        self.assertAllEqual(expected_indices, codes.indices)
        self.assertAllEqual(expected_shape, codes.dense_shape)

        cross_key = (model.SEQUENCE_KEY_PREFIX +
                     'Observation.code_Observation.value.quantity.value_'
                     'Observation.value.quantity.unit-til-0')
        crosses = results[cross_key]
        all_loincs = [
            b'loinc:2-1.000000-mg/L',
            b'loinc:4-2.000000-n/a',
            b'loinc:6-n/a-n/a',
        ]
        self.assertAllEqual(all_loincs[num_too_old_events:], crosses.values)
        self.assertAllEqual(expected_indices, crosses.indices)
        self.assertAllEqual(expected_shape, crosses.dense_shape)

        age_key = model.CONTEXT_KEY_PREFIX + model.AGE_KEY
        self.assertAllClose([85.505402], results[age_key])
def testInputFnBatchDedup(self):
    """Check input-fn output for a batch of two examples with dedup on.

    Uses a ``[12, 0]`` time window and verifies the label, the
    ``Observation.code`` and ``Observation.value.quantity.value`` sequence
    features (values, indices and dense shape) and the age context feature.
    """
    input_fn = model.get_input_fn(
        tf.estimator.ModeKeys.TRAIN,
        [self.input_data_dir],
        'label.length_of_stay_range.class',
        dedup=True,
        time_windows=[12, 0],
        include_age=True,
        categorical_context_features=[],
        sequence_features=[
            'Observation.code',
            'Observation.value.quantity.value',
        ],
        time_crossed_features=[[
            'Observation.code',
            'Observation.value.quantity.value',
            'Observation.value.quantity.unit',
        ]],
        batch_size=2,
        shuffle=False)
    feature_map, label = input_fn()
    # Fetch the label together with the features in a single run call.
    feature_map['label'] = label

    with self.test_session() as sess:
        sess.run(tf.tables_initializer())
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)
        results = sess.run(feature_map)

        self.assertAllEqual([b'above_14', b'above_14'], results['label'])

        code_key = model.SEQUENCE_KEY_PREFIX + 'Observation.code-til-%d' % 0
        codes = results[code_key]
        # First loinc:2 from example1 is out of range.
        # Second "loinc:6" will be deduped.
        self.assertAllEqual(
            [b'loinc:4', b'loinc:6', b'loinc:1', b'loinc:4'], codes.values)
        # Indices are reordered on axis 1 due to deduping.
        self.assertAllEqual(
            [[0, 0], [0, 1], [1, 0], [1, 1]], codes.indices)
        self.assertAllEqual([2, 2], codes.dense_shape)

        value_key = (model.SEQUENCE_KEY_PREFIX +
                     'Observation.value.quantity.value-til-%d' % 0)
        quantities = results[value_key]
        # First value from example1 is out of range.
        self.assertAllEqual([2.0, 1.0, 2.0], quantities.values)
        # Indices are reordered on axis 1 due to deduping.
        self.assertAllEqual([[0, 0], [1, 0], [1, 1]], quantities.indices)
        self.assertAllEqual([2, 2], quantities.dense_shape)

        age_key = model.CONTEXT_KEY_PREFIX + model.AGE_KEY
        self.assertAllClose([85.505402, 85.505402], results[age_key])