Example #1
    def __init__(self,
                 meta_file,
                 ckpt=None,
                 data_dir=None,
                 input_height=224,
                 input_width=224,
                 max_bbox_jitter=0.025,
                 max_rotation=10,
                 max_shear=0.15,
                 max_pixel_shift=10,
                 max_pixel_scale_change=0.2):
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
                max_bbox_jitter=max_bbox_jitter,
                max_rotation=max_rotation,
                max_shear=max_shear,
                max_pixel_shift=max_pixel_shift,
                max_pixel_scale_change=max_pixel_scale_change)
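
This first example shows only the constructor: the dataset is built eagerly when data_dir is given, and passing data_dir=None leaves self.dataset as None for inference-only use. A minimal sketch of that behavior, assuming the enclosing class is the COVIDNetCTRunner from Example #3 and that the paths are placeholders:

runner = COVIDNetCTRunner('models/model.meta', ckpt='models/model')  # placeholder paths
assert runner.dataset is None  # no data_dir given, so no dataset is constructed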
Example #2
    def __init__(self,
                 meta_file,
                 ckpt=None,
                 data_dir=None,
                 input_height=512,
                 input_width=512,
                 lr=0.001,
                 momentum=0.9,
                 fc_only=False,
                 max_bbox_jitter=0.025,
                 max_rotation=10,
                 max_shear=0.15,
                 max_pixel_shift=10,
                 max_pixel_scale_change=0.2):
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
                max_bbox_jitter=max_bbox_jitter,
                max_rotation=max_rotation,
                max_shear=max_shear,
                max_pixel_shift=max_pixel_shift,
                max_pixel_scale_change=max_pixel_scale_change)

        # Load graph/checkpoint and add optimizer
        self.graph, self.sess, self.saver = load_graph(self.meta_file)
        with self.graph.as_default():
            self.train_op = self._add_optimizer(lr, momentum, fc_only)
            load_ckpt(self.ckpt, self.sess, self.saver)
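
Unlike Example #1, this constructor loads the graph and builds the training op at construction time. The _add_optimizer helper is not shown; a plausible sketch, reconstructed from the optimizer block in Example #3's trainval() (an assumption, not the original implementation):

    def _add_optimizer(self, lr, momentum, fc_only):
        """Hypothetical reconstruction mirroring Example #3's optimizer setup."""
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=momentum)
        global_step = tf.train.get_or_create_global_step()
        loss = self.graph.get_tensor_by_name(LOSS_TENSOR)
        grad_vars = optimizer.compute_gradients(loss)
        if fc_only:
            grad_vars = dense_grad_filter(grad_vars)  # keep only dense-layer gradients
        minimize_op = optimizer.apply_gradients(grad_vars, global_step)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # e.g. batch norm updates
        return tf.group(minimize_op, update_ops)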
Example #3
# Standard-library and third-party imports assumed by this example; the project-local
# helpers (COVIDxCTDataset, Metrics, create_session, dense_grad_filter, auto_body_crop,
# simple_summary) and the tensor-name constants come from the repository's own modules.
import os
from math import ceil

import cv2
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay


class COVIDNetCTRunner:
    """Primary training/testing/inference class"""
    def __init__(self, meta_file, ckpt=None, data_dir=None, input_height=224, input_width=224, max_bbox_jitter=0.025,
                 max_rotation=10, max_shear=0.15, max_pixel_shift=10, max_pixel_scale_change=0.2):
        self.meta_file = meta_file
        self.ckpt = ckpt
        self.input_height = input_height
        self.input_width = input_width
        if data_dir is None:
            self.dataset = None
        else:
            self.dataset = COVIDxCTDataset(
                data_dir,
                image_height=input_height,
                image_width=input_width,
                max_bbox_jitter=max_bbox_jitter,
                max_rotation=max_rotation,
                max_shear=max_shear,
                max_pixel_shift=max_pixel_shift,
                max_pixel_scale_change=max_pixel_scale_change
            )

    def load_graph(self):
        """Creates new graph and session"""
        graph = tf.Graph()
        with graph.as_default():
            # Create session and load model
            sess = create_session()

            # Load meta file
            print('Loading meta graph from ' + self.meta_file)
            saver = tf.train.import_meta_graph(self.meta_file)
        return graph, sess, saver

    def load_ckpt(self, sess, saver):
        """Helper for loading weights"""
        # Load weights
        if self.ckpt is not None:
            print('Loading weights from ' + self.ckpt)
            saver.restore(sess, self.ckpt)

    def trainval(self, epochs, output_dir, batch_size=1, learning_rate=0.001, momentum=0.9,
                 fc_only=False, train_split_file='train.txt', val_split_file='val.txt',
                 log_interval=20, val_interval=1000, save_interval=1000):
        """Run training with intermittent validation"""
        ckpt_path = os.path.join(output_dir, CKPT_NAME)
        graph, sess, saver = self.load_graph()
        with graph.as_default():
            # Create optimizer
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=momentum
            )

            # Create train op
            global_step = tf.train.get_or_create_global_step()
            loss = graph.get_tensor_by_name(LOSS_TENSOR)
            grad_vars = optimizer.compute_gradients(loss)
            if fc_only:
                grad_vars = dense_grad_filter(grad_vars)
            minimize_op = optimizer.apply_gradients(grad_vars, global_step)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = tf.group(minimize_op, update_ops)

            # Load checkpoint
            sess.run(tf.global_variables_initializer())
            self.load_ckpt(sess, saver)

            # Create train dataset
            dataset, num_images, batch_size = self.dataset.train_dataset(train_split_file, batch_size)
            data_next = dataset.make_one_shot_iterator().get_next()
            num_iters = ceil(num_images / batch_size) * epochs

            # Create feed and fetch dicts
            feed_dict = {TRAINING_PH_TENSOR: True}
            fetch_dict = {
                TRAIN_OP_KEY: train_op,
                LOSS_KEY: LOSS_TENSOR
            }

            # Add summaries
            summary_writer = tf.summary.FileWriter(os.path.join(output_dir, 'events'), graph)
            fetch_dict[TF_SUMMARY_KEY] = self._get_train_summary_op(graph)

            # Create validation function
            run_validation = self._get_validation_fn(sess, batch_size, val_split_file)

            # Baseline saving and validation
            print('Saving baseline checkpoint')
            saver = tf.train.Saver()  # recreate the Saver so the optimizer variables added above are saved too
            saver.save(sess, ckpt_path, global_step=0)
            print('Starting baseline validation')
            metrics = run_validation()
            self._log_and_print_metrics(metrics, 0, summary_writer)

            # Training loop
            print('Training with batch_size {} for {} steps'.format(batch_size, num_iters))
            for i in range(num_iters):
                # Run training step
                data = sess.run(data_next)
                feed_dict[IMAGE_INPUT_TENSOR] = data['image']
                feed_dict[LABEL_INPUT_TENSOR] = data['label']
                results = sess.run(fetch_dict, feed_dict)

                # Log and save
                step = i + 1
                if step % log_interval == 0:
                    summary_writer.add_summary(results[TF_SUMMARY_KEY], step)
                    print('[step: {}, loss: {}]'.format(step, results[LOSS_KEY]))
                if step % save_interval == 0:
                    print('Saving checkpoint at step {}'.format(step))
                    saver.save(sess, ckpt_path, global_step=step)
                if val_interval > 0 and step % val_interval == 0:
                    print('Starting validation at step {}'.format(step))
                    metrics = run_validation()
                    self._log_and_print_metrics(metrics, step, summary_writer)

            print('Saving checkpoint at last step')
            saver.save(sess, ckpt_path, global_step=num_iters)

    def test(self, batch_size=1, test_split_file='test.txt', plot_confusion=False):
        """Run test on a checkpoint"""
        graph, sess, saver = self.load_graph()
        with graph.as_default():
            # Load checkpoint
            self.load_ckpt(sess, saver)

            # Run test
            print('Starting test')
            metrics = self._get_validation_fn(sess, batch_size, test_split_file)()
            self._log_and_print_metrics(metrics)

            if plot_confusion:
                # Plot confusion matrix
                fig, ax = plt.subplots()
                disp = ConfusionMatrixDisplay(confusion_matrix=metrics['confusion matrix'],
                                              display_labels=CLASS_NAMES)
                disp.plot(include_values=True, cmap='Blues', ax=ax, xticks_rotation='horizontal', values_format='.5g')
                plt.show()

    def infer(self, image_file, autocrop=False):
        """Run inference on the given image"""
        # Load and preprocess image
        image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
        if autocrop:
            image, _ = auto_body_crop(image)
        image = cv2.resize(image, (self.input_width, self.input_height), interpolation=cv2.INTER_CUBIC)
        image = image.astype(np.float32) / 255.0
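        # Replicate the grayscale channel into 3 channels and add a batch dimension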
        image = np.expand_dims(np.stack((image, image, image), axis=-1), axis=0)

        # Create feed dict
        feed_dict = {IMAGE_INPUT_TENSOR: image, TRAINING_PH_TENSOR: False}

        # Run inference
        graph, sess, saver = self.load_graph()
        with graph.as_default():
            # Load checkpoint
            self.load_ckpt(sess, saver)

            # Run image through model
            class_, probs = sess.run([CLASS_PRED_TENSOR, CLASS_PROB_TENSOR], feed_dict=feed_dict)
            print('\nPredicted Class: ' + CLASS_NAMES[class_[0]])
            print('Confidences: ' + ', '.join(
                '{}: {}'.format(name, conf) for name, conf in zip(CLASS_NAMES, probs[0])))
            print('**DISCLAIMER**')
            print('Do not use this prediction for self-diagnosis. '
                  'You should check with your local authorities for '
                  'the latest advice on seeking medical assistance.')

    def _get_validation_fn(self, sess, batch_size=1, val_split_file='val.txt'):
        """Creates validation function to call in self.trainval() or self.test()"""
        # Create val dataset
        dataset, num_images, batch_size = self.dataset.validation_dataset(val_split_file, batch_size)
        dataset = dataset.repeat()  # repeat so the one-shot iterator can serve every validation run
        data_next = dataset.make_one_shot_iterator().get_next()
        num_iters = ceil(num_images / batch_size)

        # Create running accuracy metric
        metrics = Metrics()

        # Create feed and fetch dicts
        fetch_dict = {'classes': CLASS_PRED_TENSOR}
        feed_dict = {TRAINING_PH_TENSOR: False}

        def run_validation():
            metrics.reset()
            for i in range(num_iters):
                data = sess.run(data_next)
                feed_dict[IMAGE_INPUT_TENSOR] = data['image']
                results = sess.run(fetch_dict, feed_dict)
                metrics.update(data['label'], results['classes'])
            return metrics.values()

        return run_validation

    @staticmethod
    def _log_and_print_metrics(metrics, step=None, summary_writer=None, tag_prefix='val/'):
        """Helper for logging and printing"""
        # Pop temporarily and print
        cm = metrics.pop('confusion matrix')
        print('\tconfusion matrix:')
        print('\t' + str(cm).replace('\n', '\n\t'))

        # Print scalar metrics
        for name, val in sorted(metrics.items()):
            print('\t{}: {}'.format(name, val))

        # Log scalar metrics
        if summary_writer is not None:
            summary = simple_summary(metrics, tag_prefix)
            summary_writer.add_summary(summary, step)

        # Restore confusion matrix
        metrics['confusion matrix'] = cm

    @staticmethod
    def _get_train_summary_op(graph, tag_prefix='train/'):
        loss = graph.get_tensor_by_name(LOSS_TENSOR)
        loss_summary = tf.summary.scalar(tag_prefix + 'loss', loss)
        return loss_summary
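
Putting Example #3 together, a minimal end-to-end usage sketch; the paths below are placeholders, not part of the original code:

runner = COVIDNetCTRunner(
    meta_file='models/model.meta',   # placeholder metagraph path
    ckpt='models/model',             # placeholder checkpoint prefix
    data_dir='data/COVIDx-CT')       # placeholder data directory

# Fine-tune only the fully-connected layers, validating every 1000 steps by default
runner.trainval(epochs=1, output_dir='output', batch_size=8, fc_only=True)

# Evaluate on the test split and display the confusion matrix
runner.test(batch_size=8, plot_confusion=True)

# Single-image inference with automatic body cropping
runner.infer('ct_scan.png', autocrop=True)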