def test_get_training_batches(self, compressed_inputs):
        golden_dataset = make_golden_dataset(compressed_inputs)
        batch_size = 16
        with tf.Session() as sess:
            mock_model = mock.MagicMock(autospec=modeling.DeepVariantModel)
            mock_model.preprocess_image.side_effect = functools.partial(
                tf.image.resize_image_with_crop_or_pad,
                target_height=107,
                target_width=221)
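            # The mocked preprocess_image simply pads or crops each pileup image
            # to 107 x 221, which is the image shape asserted on below.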
            batch = data_providers.make_training_batches(
                golden_dataset.get_slim_dataset(), mock_model, batch_size)

            # preprocess_image should have been called exactly once; we can't
            # assert on the call arguments because we don't have the actual
            # tensor objects to compare against.
            test_utils.assert_called_once_workaround(
                mock_model.preprocess_image)

            # Get our images, labels, and variants for further testing.
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)
            images, labels, variants = sess.run(batch)

            # Check that our images and labels have the expected shapes and that
            # labels are integer class indices in [0, 2] (not one-hot encoded).
            self.assertEqual(
                (batch_size, 107, 221, pileup_image.DEFAULT_NUM_CHANNEL),
                images.shape)
            self.assertEqual((batch_size, ), labels.shape)
            for label in labels:
                self.assertTrue(0 <= label <= 2)

            # Check that the variants tensor has the shape we expect and actually
            # contains variants, by decoding them and checking the reference_name.
            self.assertEqual((batch_size, ), variants.shape)
            for variant in variantutils.decode_variants(variants):
                self.assertEqual(variant.reference_name, 'chr20')

            # Shut down the TensorFlow queue runner threads.
            coord.request_stop()
            coord.join(threads)
def main(_):
    proto_utils.uses_fast_cpp_protos_or_die()

    if not FLAGS.dataset_config_pbtxt:
        logging.error('Need to specify --dataset_config_pbtxt')
        return
    logging_level.set_from_flag()

    g = tf.Graph()
    with g.as_default():
        tf_global_step = slim.get_or_create_global_step()

        model = modeling.get_model(FLAGS.model_name)
        dataset = data_providers.get_dataset(FLAGS.dataset_config_pbtxt)
        print('Running evaluations on {} with model {}\n'.format(
            dataset, model))

        batch = data_providers.make_training_batches(
            dataset.get_slim_dataset(), model, FLAGS.batch_size)
        images, labels, encoded_truth_variants = batch
        endpoints = model.create(images,
                                 dataset.num_classes,
                                 is_training=False)
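        # The 'Predictions' endpoint presumably holds per-class scores of shape
        # [batch_size, num_classes]; the argmax over the class axis below gives
        # the predicted genotype class (0=hom-ref, 1=het, 2=hom-alt, matching
        # the label convention used for the commented-out selectors further down).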
        predictions = tf.argmax(endpoints['Predictions'], 1)

        # For eval, explicitly add moving_mean and moving_variance variables to
        # the MOVING_AVERAGE_VARIABLES collection.
        variable_averages = tf.train.ExponentialMovingAverage(
            FLAGS.moving_average_decay, tf_global_step)

        for var in tf.get_collection('moving_vars'):
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        for var in slim.get_model_variables():
            tf.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, var)

        variables_to_restore = variable_averages.variables_to_restore()
        variables_to_restore[tf_global_step.op.name] = tf_global_step
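        # variables_to_restore() maps the shadow (moving-average) variable names
        # to the model variables, so the evaluation Saver loads the averaged
        # weights rather than the raw training-time values; the global step is
        # added so it is restored under its own name as well.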

        # Define the metrics:
        metrics = {
            'Accuracy': tf.contrib.metrics.streaming_accuracy,
            'Mean_absolute_error':
            tf.contrib.metrics.streaming_mean_absolute_error,
            'FPs': tf.contrib.metrics.streaming_false_positives,
            'FNs': tf.contrib.metrics.streaming_false_negatives,
        }

        def _make_selector(func):
            return select_variants_weights(func, encoded_truth_variants)

        selectors = {
            'All': None,
            'SNPs': _make_selector(variantutils.is_snp),
            'Indels': _make_selector(variantutils.is_indel),
            'Insertions': _make_selector(variantutils.has_insertion),
            'Deletions': _make_selector(variantutils.has_deletion),
            'BiAllelic': _make_selector(variantutils.is_biallelic),
            'MultiAllelic': _make_selector(variantutils.is_multiallelic),
            # These haven't proven particularly useful, so they are commented out
            # here, but kept in case someone wants to explore them further.
            # 'HomRef': tf.equal(labels, 0),
            # 'Het': tf.equal(labels, 1),
            # 'HomAlt': tf.equal(labels, 2),
            # 'NonRef': tf.greater(labels, 0),
        }
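        # Each selector is presumably a per-example 0/1 weight vector obtained by
        # decoding the encoded variants and applying the predicate, so
        # calling_metrics can restrict each metric to one variant class; a
        # selector of None means no restriction.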
        metrics = calling_metrics(metrics, selectors, predictions, labels)
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
            metrics)

        for name, value in names_to_values.items():
            slim.summaries.add_scalar_summary(value, name, print_summary=True)

        slim.evaluation.evaluation_loop(
            FLAGS.master,
            FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,
            num_evals=FLAGS.batches_per_eval_step,
            eval_op=list(names_to_updates.values()),
            variables_to_restore=variables_to_restore,
            max_number_of_evaluations=FLAGS.max_evaluations,
            eval_interval_secs=FLAGS.eval_interval_secs)
def run(target, is_chief, device_fn):
    """Run training.

  Args:
     target: The target of the TensorFlow standard server to use. Can be the
       empty string to run locally using an inprocess server.
     is_chief: Boolean indicating whether this process is the chief.
     device_fn: Device function used to assign ops to devices.
  """
    if not FLAGS.dataset_config_pbtxt:
        logging.error('Need to specify --dataset_config_pbtxt')
        return

    g = tf.Graph()
    with g.as_default():
        model = modeling.get_model(FLAGS.model_name)
        dataset = data_providers.get_dataset(FLAGS.dataset_config_pbtxt)
        print('Running training on {} with model {}\n'.format(dataset, model))

        with tf.device(device_fn):
            # If ps_tasks is zero, the local device is used. When using multiple
            # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
            # across the different devices.
            images, labels, _ = data_providers.make_training_batches(
                dataset.get_slim_dataset(), model, FLAGS.batch_size)
            endpoints = model.create(images,
                                     dataset.num_classes,
                                     is_training=True)
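            # Convert the integer class labels to one-hot rows for the loss;
            # with num_classes=3, for example, a label of 1 becomes [0, 1, 0].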
            labels = slim.one_hot_encoding(labels, dataset.num_classes)
            total_loss = model.loss(endpoints, labels)

            # Set up the moving averages:
            moving_average_variables = slim.get_model_variables()
            moving_average_variables.extend(slim.losses.get_losses())
            moving_average_variables.append(total_loss)

            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, slim.get_or_create_global_step())

            tf.add_to_collection(
                tf.GraphKeys.UPDATE_OPS,
                variable_averages.apply(moving_average_variables))
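            # Because the EMA update op is in UPDATE_OPS, create_train_op below
            # attaches it as a control dependency of the training step, so the
            # shadow averages are refreshed on every step.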

            # Configure the learning rate using an exponential decay.
            decay_steps = int(
                ((1.0 * dataset.num_examples) / FLAGS.batch_size) *
                FLAGS.num_epochs_per_decay)
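            # For example, with (hypothetical) 1,500,000 training examples, a
            # batch size of 32, and num_epochs_per_decay=2.0, this gives
            # decay_steps = int((1500000 / 32) * 2.0) = 93750, i.e. one decay
            # roughly every two epochs.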

            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.get_or_create_global_step(),
                decay_steps,
                FLAGS.learning_rate_decay_factor,
                staircase=True)
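            # With staircase=True the schedule is
            #   lr = FLAGS.learning_rate *
            #        FLAGS.learning_rate_decay_factor ** floor(global_step / decay_steps)
            # so the rate drops in discrete jumps rather than decaying continuously.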

            opt = tf.train.RMSPropOptimizer(learning_rate, FLAGS.rmsprop_decay,
                                            FLAGS.rmsprop_momentum,
                                            FLAGS.rmsprop_epsilon)

            # Create training op
            train_tensor = slim.learning.create_train_op(
                total_loss,
                optimizer=opt,
                update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

            # Summaries:
            slim.summaries.add_histogram_summaries(slim.get_model_variables())
            slim.summaries.add_scalar_summaries(slim.losses.get_losses(),
                                                'losses')
            slim.summaries.add_scalar_summary(total_loss, 'Total_Loss',
                                              'losses')
            slim.summaries.add_scalar_summary(learning_rate, 'Learning_Rate',
                                              'training')
            slim.summaries.add_histogram_summaries(endpoints.values())
            slim.summaries.add_zero_fraction_summaries(endpoints.values())
            # redacted

            # Set start-up delay
            startup_delay_steps = FLAGS.task * FLAGS.startup_delay_steps
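            # slim.learning.train delays non-chief workers by this many steps, so
            # replicas come online staggered by their task index rather than all
            # at once.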

            init_fn = model_init_function(model, dataset.num_classes,
                                          FLAGS.start_from_checkpoint)

            saver = tf.train.Saver(
                max_to_keep=FLAGS.max_checkpoints_to_keep,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            # Train model
            slim.learning.train(train_tensor,
                                number_of_steps=FLAGS.number_of_steps,
                                logdir=FLAGS.train_dir,
                                master=target,
                                init_fn=init_fn,
                                is_chief=is_chief,
                                saver=saver,
                                startup_delay_steps=startup_delay_steps,
                                save_summaries_secs=FLAGS.save_summaries_secs,
                                save_interval_secs=FLAGS.save_interval_secs)