def test_get_dataset_raises_error_for_empty_data_split(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'test_get_dataset_raises_error_for_empty_data_split.pbtxt',
      name='some_dataset_name')
  expected_exception_message = (
      'The dataset in the config {} does not '
      'have a tfrecord_path.'.format(dataset_config_pbtext_filename))
  with self.assertRaisesRegexp(ValueError, expected_exception_message):
    data_providers.get_dataset(dataset_config_pbtext_filename)
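# For context, a minimal sketch of the _test_dataset_config helper these tests
# rely on. The field names come from the kwargs used at the call sites; the
# proto message name (DeepVariantDatasetConfig) and the write path are
# assumptions, not the verbatim implementation. Assumes
# `from google.protobuf import text_format`.
def _test_dataset_config(filename, **kwargs):
  """Writes a dataset config pbtxt built from `kwargs`; returns its path."""
  dataset_config_pbtext_filename = test_utils.test_tmpfile(filename)
  dataset_config = deepvariant_pb2.DeepVariantDatasetConfig(**kwargs)
  with tf.gfile.GFile(dataset_config_pbtext_filename, 'w') as writer:
    writer.write(text_format.MessageToString(dataset_config))
  return dataset_config_pbtext_filename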
def test_get_dataset_raises_error_for_empty_num_examples(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'test_get_dataset_raises_error_for_empty_num_examples.pbtxt',
      name='some_dataset_name',
      tfrecord_path='/path/to/dataset')
  expected_exception_message = (
      'The dataset in the config {} does not have '
      'a num_examples.'.format(dataset_config_pbtext_filename))
  with self.assertRaisesRegexp(ValueError, expected_exception_message):
    data_providers.get_dataset(dataset_config_pbtext_filename)
def test_get_dataset(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'golden.dataset_config.pbtxt',
      name='some_dataset_name',
      tfrecord_path='/path/to/dataset',
      num_examples=1000)
  ds = data_providers.get_dataset(
      dataset_config_pbtext_filename, tensor_shape=[3, 4, 7])
  self.assertEqual('some_dataset_name', ds.name)
  self.assertEqual('/path/to/dataset', ds.source)
  self.assertEqual(1000, ds.num_examples)
  self.assertEqual([3, 4, 7], ds.tensor_shape)
def test_get_dataset(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'golden.dataset_config.pbtxt',
      name='some_dataset_name',
      tfrecord_path='/path/to/dataset',
      num_examples=1000)
  ds = data_providers.get_dataset(
      dataset_config_pbtext_filename,
      tensor_shape=[3, 4, pileup_image.DEFAULT_NUM_CHANNEL])
  self.assertEqual('some_dataset_name', ds.name)
  self.assertEqual('/path/to/dataset', ds.source)
  self.assertEqual(1000, ds.num_examples)
  self.assertEqual([3, 4, pileup_image.DEFAULT_NUM_CHANNEL], ds.tensor_shape)
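# For reference, the pbtxt written by the config above presumably serializes
# to something like this (field names taken from the kwargs):
#
#   name: "some_dataset_name"
#   tfrecord_path: "/path/to/dataset"
#   num_examples: 1000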
def test_reading_sharded_dataset(self, compressed_inputs):
  golden_dataset = make_golden_dataset(compressed_inputs)
  n_shards = 3
  sharded_path = test_utils.test_tmpfile('sharded@{}'.format(n_shards))
  io_utils.write_tfrecords(
      io_utils.read_tfrecords(golden_dataset.source), sharded_path)

  config_file = _test_dataset_config(
      'test_sharded.pbtxt',
      name='sharded_test',
      tfrecord_path=sharded_path,
      num_examples=golden_dataset.num_examples)

  self.assertDataSetExamplesMatchExpected(
      data_providers.get_dataset(config_file).get_slim_dataset(),
      golden_dataset)
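# Note: a spec like 'sharded@3' follows the usual sharded-file convention,
# which io_utils is expected to expand into the per-shard files
#   sharded-00000-of-00003, sharded-00001-of-00003, sharded-00002-of-00003
# so this test exercises reading a dataset whose tfrecord_path names multiple
# underlying files.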
def test_good_dataset(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'test_good_dataset.pbtxt',
      name='some_dataset_name',
      tfrecord_path='/path/to/dataset',
      num_examples=1000)
  ds = data_providers.get_dataset(
      dataset_config_pbtext_filename, tensor_shape=[100, 221, 7])

  # Test that the slim.DataSet we create from the dataset has the values
  # and fields we expect.
  with tf.Session():
    slim_ds = ds.get_slim_dataset()
    self.assertEqual(ds.num_examples, slim_ds.num_samples)
    self.assertItemsEqual(
        ['image', 'label', 'locus', 'variant', 'truth_variant'],
        slim_ds.decoder.list_items())
    self.assertEqual([100, 221, 7], ds.tensor_shape)
def test_good_dataset(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'test_good_dataset.pbtxt',
      name='some_dataset_name',
      tfrecord_path='/path/to/dataset',
      num_examples=1000)
  ds = data_providers.get_dataset(
      dataset_config_pbtext_filename,
      tensor_shape=[100, 221, pileup_image.DEFAULT_NUM_CHANNEL])

  # Test that the slim.DataSet we create from the dataset has the values
  # and fields we expect.
  with tf.Session():
    slim_ds = ds.get_slim_dataset()
    self.assertEqual(ds.num_examples, slim_ds.num_samples)
    self.assertItemsEqual(['image', 'label', 'locus', 'variant'],
                          slim_ds.decoder.list_items())
    self.assertEqual([100, 221, pileup_image.DEFAULT_NUM_CHANNEL],
                     ds.tensor_shape)
def run(target, is_chief, device_fn):
  """Run training.

  Args:
    target: The target of the TensorFlow standard server to use. Can be the
      empty string to run locally using an inprocess server.
    is_chief: Boolean indicating whether this process is the chief.
    device_fn: Device function used to assign ops to devices.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    model = modeling.get_model(FLAGS.model_name)
    dataset = data_providers.get_dataset(FLAGS.dataset_config_pbtxt)
    print('Running training on {} with model {}\n'.format(dataset, model))

    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the
      # variables across the different devices.
      images, labels, _ = data_providers.make_batches(
          dataset.get_slim_dataset(), model, FLAGS.batch_size, mode='TRAIN')
      endpoints = model.create(images, dataset.num_classes, is_training=True)
      labels = slim.one_hot_encoding(labels, dataset.num_classes)
      total_loss = loss(
          endpoints['Logits'], labels, label_smoothing=FLAGS.label_smoothing)

      # Set up the moving averages:
      moving_average_variables = slim.get_model_variables()
      moving_average_variables.extend(slim.losses.get_losses())
      moving_average_variables.append(total_loss)

      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, slim.get_or_create_global_step())
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(moving_average_variables))

      # Configure the learning rate using an exponential decay.
      decay_steps = int(((1.0 * dataset.num_examples) / FLAGS.batch_size) *
                        FLAGS.num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)

      opt = tf.train.RMSPropOptimizer(learning_rate, FLAGS.rmsprop_decay,
                                      FLAGS.rmsprop_momentum,
                                      FLAGS.rmsprop_epsilon)

      # Create training op
      train_tensor = slim.learning.create_train_op(
          total_loss,
          optimizer=opt,
          update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

      # Summaries:
      slim.summaries.add_histogram_summaries(slim.get_model_variables())
      slim.summaries.add_scalar_summaries(slim.losses.get_losses(), 'losses')
      slim.summaries.add_scalar_summary(total_loss, 'Total_Loss', 'losses')
      slim.summaries.add_scalar_summary(learning_rate, 'Learning_Rate',
                                        'training')
      slim.summaries.add_histogram_summaries(endpoints.values())
      slim.summaries.add_zero_fraction_summaries(endpoints.values())
      # redacted

      # Set start-up delay
      startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

      init_fn = model_init_function(model, dataset.num_classes,
                                    FLAGS.start_from_checkpoint)

      saver = tf.train.Saver(
          max_to_keep=FLAGS.max_checkpoints_to_keep,
          keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

      # Train model
      slim.learning.train(
          train_tensor,
          number_of_steps=FLAGS.number_of_steps,
          logdir=FLAGS.train_dir,
          master=target,
          init_fn=init_fn,
          is_chief=is_chief,
          saver=saver,
          startup_delay_steps=startup_delay_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
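# Worked example of the decay schedule configured in run() (illustrative
# numbers, not defaults from this codebase): with dataset.num_examples =
# 1000000, batch_size = 64, and num_epochs_per_decay = 2.0,
#   decay_steps = int((1000000 / 64.0) * 2.0) = 31250
# so with staircase=True the learning rate is multiplied by
# learning_rate_decay_factor once every 31250 global steps, i.e. once every
# two epochs.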
def eval_loop(master, dataset_config_pbtxt, checkpoint_dir, model_name,
              batch_size, moving_average_decay, max_examples, eval_dir,
              max_evaluations):
  logging.info('Running fixed eval for: %s', dataset_config_pbtxt)

  num_evaluations = 0
  for checkpoint_path in checkpoints_iterator(checkpoint_dir):
    logging.info('Using checkpoint %s %d', checkpoint_path, num_evaluations)

    g = tf.Graph()
    with g.as_default():
      tf_global_step = tf.train.get_or_create_global_step()

      # redacted
      model = modeling.get_model(model_name)
      dataset = data_providers.get_dataset(dataset_config_pbtxt)
      logging.info('Running evaluations on %s with model %s', dataset, model)

      images, labels, encoded_variant = data_providers.make_batches(
          dataset.get_slim_dataset(), model, batch_size, mode='EVAL')
      endpoints = model.create(images, dataset.num_classes, is_training=False)
      predictions = tf.argmax(endpoints['Predictions'], 1)

      # For eval, explicitly add moving_mean and moving_variance variables to
      # the MOVING_AVERAGE_VARIABLES collection.
      variable_averages = tf.train.ExponentialMovingAverage(
          moving_average_decay, tf_global_step)

      for var in tf.get_collection('moving_vars'):
        tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
      for var in slim.get_model_variables():
        tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

      variables_to_restore = variable_averages.variables_to_restore()
      variables_to_restore[tf_global_step.op.name] = tf_global_step

      names_to_values, names_to_updates = make_metrics(predictions, labels,
                                                       encoded_variant)

      for name, value in names_to_values.iteritems():
        slim.summaries.add_scalar_summary(value, name, print_summary=True)

      num_batches = int(
          math.floor(
              min(max_examples, dataset.num_examples) / float(batch_size)))
      num_samples = batch_size * num_batches
      logging.info('Dataset has %d samples, doing eval over %d',
                   dataset.num_examples, num_samples)

      names_to_values = slim.evaluation.evaluate_once(
          master=master,
          checkpoint_path=checkpoint_path,
          logdir=eval_dir,
          variables_to_restore=variables_to_restore,
          num_evals=num_batches,
          initial_op=tf.group(tf.global_variables_initializer(),
                              tf.local_variables_initializer()),
          eval_op=names_to_updates.values(),
          final_op=names_to_values,
      )

      # --- LOW LEVEL [WIP], hangs, initialization seems busted ---
      # This is (marginally) nicer as it can eliminate the slim dep.
      # saver = tf.train.Saver(variables_to_restore)
      # scaffold = tf.train.Scaffold(saver=saver)
      # names_to_values = tf.contrib.training.evaluate_once(
      #     checkpoint_path=checkpoint_path,
      #     master=FLAGS.master,
      #     scaffold=scaffold,
      #     eval_ops=names_to_updates.values(),
      #     final_ops=names_to_values,
      # )

      _write_checkpoint_metrics(
          checkpoint_path, names_to_values, eval_dir=eval_dir)

      num_evaluations += 1
      if max_evaluations is not None and num_evaluations >= max_evaluations:
        return
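# A minimal sketch of the _write_checkpoint_metrics helper called above. Only
# the signature is visible at the call site; the file layout and JSON encoding
# below are assumptions. Assumes `import json` and `import os.path`.
def _write_checkpoint_metrics(checkpoint_path, metrics_and_values, eval_dir):
  """Writes eval metrics for a checkpoint to a JSON file under eval_dir."""
  # Key the metrics file by the checkpoint's base name so successive
  # evaluations don't overwrite each other.
  metrics_path = os.path.join(eval_dir,
                              os.path.basename(checkpoint_path) + '.metrics')
  serializable = {name: float(value)
                  for name, value in metrics_and_values.iteritems()}
  with tf.gfile.GFile(metrics_path, 'w') as fout:
    json.dump(serializable, fout, sort_keys=True, indent=2)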
def test_get_dataset_raises_error_for_empty_name(self):
  dataset_config_pbtext_filename = _test_dataset_config(
      'test_get_dataset_raises_error_for_empty_name.pbtxt')
  with self.assertRaisesRegexp(ValueError,
                               'dataset_config needs to have a name'):
    data_providers.get_dataset(dataset_config_pbtext_filename)
def main(_):
  proto_utils.uses_fast_cpp_protos_or_die()

  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return
  logging_level.set_from_flag()

  g = tf.Graph()
  with g.as_default():
    tf_global_step = slim.get_or_create_global_step()

    model = modeling.get_model(FLAGS.model_name)
    dataset = data_providers.get_dataset(FLAGS.dataset_config_pbtxt)
    print('Running evaluations on {} with model {}\n'.format(dataset, model))

    batch = data_providers.make_training_batches(dataset.get_slim_dataset(),
                                                 model, FLAGS.batch_size)
    images, labels, encoded_truth_variants = batch
    endpoints = model.create(images, dataset.num_classes, is_training=False)
    predictions = tf.argmax(endpoints['Predictions'], 1)

    # For eval, explicitly add moving_mean and moving_variance variables to
    # the MOVING_AVERAGE_VARIABLES collection.
    variable_averages = tf.train.ExponentialMovingAverage(
        FLAGS.moving_average_decay, tf_global_step)

    for var in tf.get_collection('moving_vars'):
      tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
    for var in slim.get_model_variables():
      tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

    variables_to_restore = variable_averages.variables_to_restore()
    variables_to_restore[tf_global_step.op.name] = tf_global_step

    # Define the metrics:
    metrics = {
        'Accuracy': tf.contrib.metrics.streaming_accuracy,
        'Mean_absolute_error':
            tf.contrib.metrics.streaming_mean_absolute_error,
        'FPs': tf.contrib.metrics.streaming_false_positives,
        'FNs': tf.contrib.metrics.streaming_false_negatives,
    }

    def _make_selector(func):
      return select_variants_weights(func, encoded_truth_variants)

    selectors = {
        'All': None,
        'SNPs': _make_selector(variantutils.is_snp),
        'Indels': _make_selector(variantutils.is_indel),
        'Insertions': _make_selector(variantutils.has_insertion),
        'Deletions': _make_selector(variantutils.has_deletion),
        'BiAllelic': _make_selector(variantutils.is_biallelic),
        'MultiAllelic': _make_selector(variantutils.is_multiallelic),
        # These haven't proven particularly useful, but are commented out here
        # in case someone wants to do some more explorations.
        # 'HomRef': tf.equal(labels, 0),
        # 'Het': tf.equal(labels, 1),
        # 'HomAlt': tf.equal(labels, 2),
        # 'NonRef': tf.greater(labels, 0),
    }
    metrics = calling_metrics(metrics, selectors, predictions, labels)
    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
        metrics)

    for name, value in names_to_values.iteritems():
      slim.summaries.add_scalar_summary(value, name, print_summary=True)

    slim.evaluation.evaluation_loop(
        FLAGS.master,
        FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=FLAGS.batches_per_eval_step,
        eval_op=names_to_updates.values(),
        variables_to_restore=variables_to_restore,
        max_number_of_evaluations=FLAGS.max_evaluations,
        eval_interval_secs=FLAGS.eval_interval_secs)
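# Note on the selectors above: select_variants_weights presumably turns each
# variant predicate into a per-example weight tensor (1.0 where the predicate
# holds, 0.0 elsewhere), and calling_metrics then instantiates every metric
# once per selector, weighted accordingly. That way a single eval pass reports
# e.g. Accuracy over all examples, over SNPs only, over indels only, and so
# on; the 'All' selector uses None, i.e. unweighted metrics.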
def run(target, is_chief, device_fn):
  """Run training.

  Args:
    target: The target of the TensorFlow standard server to use. Can be the
      empty string to run locally using an inprocess server.
    is_chief: Boolean indicating whether this process is the chief.
    device_fn: Device function used to assign ops to devices.
  """
  if not FLAGS.dataset_config_pbtxt:
    logging.error('Need to specify --dataset_config_pbtxt')
    return

  g = tf.Graph()
  with g.as_default():
    model = modeling.get_model(FLAGS.model_name)
    dataset = data_providers.get_dataset(FLAGS.dataset_config_pbtxt)
    print('Running training on {} with model {}\n'.format(dataset, model))

    with tf.device(device_fn):
      # If ps_tasks is zero, the local device is used. When using multiple
      # (non-local) replicas, the ReplicaDeviceSetter distributes the
      # variables across the different devices.
      images, labels, _ = data_providers.make_training_batches(
          dataset.get_slim_dataset(), model, FLAGS.batch_size)
      endpoints = model.create(images, dataset.num_classes, is_training=True)
      labels = slim.one_hot_encoding(labels, dataset.num_classes)
      total_loss = model.loss(endpoints, labels)

      # Set up the moving averages:
      moving_average_variables = slim.get_model_variables()
      moving_average_variables.extend(slim.losses.get_losses())
      moving_average_variables.append(total_loss)

      variable_averages = tf.train.ExponentialMovingAverage(
          FLAGS.moving_average_decay, slim.get_or_create_global_step())
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           variable_averages.apply(moving_average_variables))

      # Configure the learning rate using an exponential decay.
      decay_steps = int(((1.0 * dataset.num_examples) / FLAGS.batch_size) *
                        FLAGS.num_epochs_per_decay)
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate,
          slim.get_or_create_global_step(),
          decay_steps,
          FLAGS.learning_rate_decay_factor,
          staircase=True)

      opt = tf.train.RMSPropOptimizer(learning_rate, FLAGS.rmsprop_decay,
                                      FLAGS.rmsprop_momentum,
                                      FLAGS.rmsprop_epsilon)

      # Create training op
      train_tensor = slim.learning.create_train_op(
          total_loss,
          optimizer=opt,
          update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

      # Summaries:
      slim.summaries.add_histogram_summaries(slim.get_model_variables())
      slim.summaries.add_scalar_summaries(slim.losses.get_losses(), 'losses')
      slim.summaries.add_scalar_summary(total_loss, 'Total_Loss', 'losses')
      slim.summaries.add_scalar_summary(learning_rate, 'Learning_Rate',
                                        'training')
      slim.summaries.add_histogram_summaries(endpoints.values())
      slim.summaries.add_zero_fraction_summaries(endpoints.values())
      # redacted

      # Set start-up delay
      startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps

      init_fn = model_init_function(model, dataset.num_classes,
                                    FLAGS.start_from_checkpoint)

      saver = tf.train.Saver(
          max_to_keep=FLAGS.max_checkpoints_to_keep,
          keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

      # Train model
      slim.learning.train(
          train_tensor,
          number_of_steps=FLAGS.number_of_steps,
          logdir=FLAGS.train_dir,
          master=target,
          init_fn=init_fn,
          is_chief=is_chief,
          saver=saver,
          startup_delay_steps=startup_delay_steps,
          save_summaries_secs=FLAGS.save_summaries_secs,
          save_interval_secs=FLAGS.save_interval_secs)
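# Note on the moving averages set up in run(): ExponentialMovingAverage keeps
# a shadow copy of each tracked variable, updated roughly as
#   shadow = decay * shadow + (1 - decay) * value
# and registering variable_averages.apply(...) under UPDATE_OPS makes
# slim.learning.create_train_op run that update on every training step. The
# eval jobs above then restore the shadow (averaged) values via
# variables_to_restore().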