예제 #1
0
def run_inference(model_architecture='resnet7_2_1',
                  checkpoint_path='./data/pytorch_model_epoch300.bin',
                  image_path='./data/52658141.png'):
    """Run a pretrained model on one image and report its severity score.

    Args:
        model_architecture: Name of the network to build. Only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Checkpoint passed to the model builder as
            ``pretrained_model_path``.
        image_path: Image file to load with ``load_image``.

    Returns:
        The severity score: the index-weighted sum of the first prediction
        returned by ``main_utils.inference``. The score is also printed.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    device = 'cpu'

    # Create an instance of a resnet model and load a checkpoint.
    output_channels = 4
    if model_architecture == 'resnet7_2_1':
        resnet_model = resnet7_2_1(pretrained=True,
                                   pretrained_model_path=checkpoint_path,
                                   output_channels=output_channels)
    else:
        # Fail fast: otherwise resnet_model would be unbound below and the
        # caller would only see an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_architecture: {model_architecture!r}")
    resnet_model = resnet_model.to(device)

    # Load the input image.
    image = load_image(image_path)

    # Run model inference; keep the first prediction (presumably a batch of
    # one image).
    pred = main_utils.inference(resnet_model, image)[0]
    # Weight each output index by its score. If pred is a probability
    # distribution over the classes, this is the expected class index.
    severity = sum(i * p for i, p in enumerate(pred))

    print(f"{image_path} has severity of {severity}")
    return severity
예제 #2
0
def run_inference_gradcam(
        image_path,
        model_architecture='resnet7_2_1',
        checkpoint_path='/opt/mlmodel/data/pytorch_model_epoch300.bin'):
    """Predict the severity of one image and compute its Grad-CAM output.

    Args:
        image_path: Image file to load with ``load_image``.
        model_architecture: Name of the network to build. Only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Checkpoint passed to the model builder as
            ``pretrained_model_path``.

    Returns:
        ``(severity, gcam_img, input_img)``. ``severity`` is the
        index-weighted sum of the first prediction. ``gcam_img`` and
        ``input_img`` are returned exactly as produced by
        ``main_utils.inference_gradcam``.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    device = 'cpu'

    # Create an instance of a resnet model and load a checkpoint.
    output_channels = 4
    if model_architecture == 'resnet7_2_1':
        resnet_model = resnet7_2_1(pretrained=True,
                                   pretrained_model_path=checkpoint_path,
                                   output_channels=output_channels)
    else:
        # Fail fast: otherwise resnet_model would be unbound below and the
        # caller would only see an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_architecture: {model_architecture!r}")
    resnet_model = resnet_model.to(device)

    # Wrap the model with Grad-CAM.
    model_gcam = GradCAM(model=resnet_model)

    # Load the input image.
    image = load_image(image_path)

    # Run inference with Grad-CAM. 'layer7.1.conv2' is presumably the target
    # layer the class-activation map is computed from.
    pred, gcam_img, input_img = main_utils.inference_gradcam(
        model_gcam, image, 'layer7.1.conv2')
    pred = pred[0]
    # Weight each output index by its score. If pred is a probability
    # distribution over the classes, this is the expected class index.
    severity = sum(i * p for i, p in enumerate(pred))

    return severity, gcam_img, input_img
예제 #3
0
    def __call__(self, study_name):
        """Return the inference result and a 16-bit centre crop for a study.

        The PNG location is obtained from ``Path.png_path(study_name)``. The
        image is loaded, cropped with ``CenterCrop(2048)``, multiplied by
        65535 and cast with ``astype(np.uint16)``. ``run_inference`` is then
        called with the PNG path.

        Returns:
            A ``(run_inference(png_path), image)`` tuple, where ``image`` is
            the cropped and rescaled ``uint16`` image.
        """
        source = Path.png_path(study_name)

        pixels = load_image(source)
        pixels = CenterCrop(2048)(pixels)
        # Scale intensities (presumably in [0, 1]) up to the 16-bit range.
        # NOTE(review): astype does not clip, so out-of-range values are not
        # saturated. Confirm the loader's output range.
        scaled = (65535 * pixels).astype(np.uint16)

        return run_inference(source), scaled
예제 #4
0
def run_inference_gradcam(model_architecture='resnet7_2_1',
                          checkpoint_path='./data/pytorch_model_epoch300.bin',
                          image_path='./data/52658141.png',
                          gcam_path='/mnt/images/52658141_gcam.png'):
    """Predict one image's severity and save a Grad-CAM overlay for it.

    Args:
        model_architecture: Name of the network to build. Only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Checkpoint passed to the model builder as
            ``pretrained_model_path``.
        image_path: Image file to load with ``load_image``.
        gcam_path: Destination passed to ``save_gradcam_overlay``.

    Returns:
        The severity score, which is the index-weighted sum of the first
        prediction. The score and the overlay location are also printed.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    device = 'cpu'

    # Create an instance of a resnet model and load a checkpoint.
    output_channels = 4
    if model_architecture == 'resnet7_2_1':
        resnet_model = resnet7_2_1(pretrained=True,
                                   pretrained_model_path=checkpoint_path,
                                   output_channels=output_channels)
    else:
        # Fail fast: otherwise resnet_model would be unbound below and the
        # caller would only see an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_architecture: {model_architecture!r}")
    resnet_model = resnet_model.to(device)

    # Wrap the model with Grad-CAM.
    model_gcam = GradCAM(model=resnet_model)

    # Load the input image.
    image = load_image(image_path)

    # Run inference with Grad-CAM. 'layer7.1.conv2' is presumably the target
    # layer the class-activation map is computed from.
    pred, gcam_img, input_img = main_utils.inference_gradcam(
        model_gcam, image, 'layer7.1.conv2')

    pred = pred[0]
    # Weight each output index by its score. If pred is a probability
    # distribution over the classes, this is the expected class index.
    severity = sum(i * p for i, p in enumerate(pred))

    print(f"{image_path} has severity of {severity}")

    # Overlay the first Grad-CAM output on the first input image and save it.
    save_gradcam_overlay(gcam_path, gcam_img[0], input_img[0])

    print(f"Grad-CAM overlay saved at {gcam_path}")

    return severity
예제 #5
0
def main():
    """Command-line entry point for training, evaluation or single-image inference.

    Reads options from ``parser.get_args()`` and then, in order:

    * requires exactly one of ``args.do_train`` / ``args.do_eval`` (asserted);
    * appends a commit identifier to ``args.run_id`` (``args.commit_sha`` if
      given, otherwise the name ``sha`` defined outside this function) and
      nests ``args.output_dir`` under the resulting run id;
    * creates the output directory when training, or asserts that it exists
      when evaluating;
    * configures logging to ``training.log`` or ``evaluation.log`` inside the
      output directory, opened with mode ``'w'`` so an existing log is
      truncated;
    * if ``args.do_inference`` is set, prints the severity for
      ``args.image_path`` and returns early;
    * if ``args.do_train`` is set, creates new numbered TensorBoard and
      checkpoint sub-directories, builds ``resnet7_2_1`` without a
      checkpoint and calls ``main_utils.train``;
    * if ``args.do_eval`` is set, builds the model from
      ``args.checkpoint_path``, calls ``main_utils.evaluate`` and writes the
      results to ``eval_results.json`` in the output directory.

    The model runs on CUDA when ``torch.cuda.is_available()`` is true, else on
    the CPU.
    """
    args = parser.get_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # assert torch.cuda.is_available(), "No GPU/CUDA is detected!"

    # Exactly one of do_train / do_eval must be set.
    # NOTE(review): assert statements are removed under ``python -O``;
    # raising an exception would make this validation unconditional.
    assert args.do_train or args.do_eval, \
     "Either do_train or do_eval has to be True!"
    assert not(args.do_train and args.do_eval), \
     "do_train and do_eval cannot be both True!"

    # To track results from different commits (temporary)
    # ``sha`` is not defined in this function; presumably a module-level git
    # commit hash. NOTE(review): prefer ``is None`` over ``== None``.
    if args.commit_sha == None:
        args.run_id = args.run_id + '_' + str(sha)
    else:
        args.run_id = args.run_id + '_' + args.commit_sha

    # NOTE(review): run_id has just been concatenated with a string, so a
    # None run_id would already have raised TypeError above; this check is
    # effectively always true.
    if not args.run_id == None:
        args.output_dir = os.path.join(args.output_dir, args.run_id)
    if not (os.path.exists(args.output_dir)) and args.do_train:
        os.makedirs(args.output_dir)
    if args.do_eval:
        # output_dir has to exist if doing evaluation
        assert os.path.exists(args.output_dir), \
         "Output directory {} doesn't exist!".format(args.output_dir)
        # if args.data_split_mode=='testing':
        # 	# Checkpoint has to exist if doing evaluation with testing split
        # 	assert os.path.exists(args.checkpoint_path), \
        # 		"Checkpoint doesn't exist!"
    '''
	Configure a log file
	'''
    # The asserts above guarantee one of these two branches runs, so
    # log_path is bound (unless asserts are disabled). filemode='w'
    # truncates any existing log file at this path.
    if args.do_train:
        log_path = os.path.join(args.output_dir, 'training.log')
    if args.do_eval:
        log_path = os.path.join(args.output_dir, 'evaluation.log')
    logging.basicConfig(filename=log_path,
                        level=logging.INFO,
                        filemode='w',
                        format='%(asctime)s - %(name)s %(message)s',
                        datefmt='%m-%d %H:%M')
    '''
	Log important info
	'''
    logger = logging.getLogger(__name__)
    logger.info("***** Code info *****")
    logger.info("  Git commit sha: %s", sha)
    '''
	Print important info
	'''
    print('Model architecture:', args.model_architecture)
    print('Training folds:', args.training_folds)
    print('Evaluation folds:', args.evaluation_folds)
    print('Device being used:', device)
    print('Output directory:', args.output_dir)
    print('Logging in:\t {}'.format(log_path))
    # NOTE(review): "formet" in the next message is a typo for "format"; it
    # is left unchanged here because it is runtime output.
    print('Input image formet:', args.image_format)
    print('Loss function: {}'.format(args.loss))

    # Single-image inference. Reached only after the do_train/do_eval asserts
    # have passed and logging has been configured above.
    if args.do_inference:
        '''
		Create an instance of a resnet model and load a checkpoint
		'''
        # NOTE(review): only 'resnet7_2_1' is handled. Any other value leaves
        # resnet_model unbound, and the .to(device) call raises
        # UnboundLocalError.
        output_channels = 4
        if args.model_architecture == 'resnet7_2_1':
            resnet_model = resnet7_2_1(
                pretrained=True,
                pretrained_model_path=args.checkpoint_path,
                output_channels=output_channels)
        resnet_model = resnet_model.to(device)
        '''
		Load the input image
		'''
        image = load_image(args.image_path)
        '''
		Run model inference on the image
		'''
        pred = main_utils.inference(resnet_model, image)
        pred = pred[0]
        # Index-weighted sum of the first prediction. It equals the expected
        # class index if pred is a probability distribution.
        severity = sum([i * pred[i] for i in range(len(pred))])

        print(f"{args.image_path} has severity of {severity}")

        return

    if args.do_train:
        '''
		Create tensorboard and checkpoint directories if they don't exist
		'''
        args.tsbd_dir = os.path.join(args.output_dir, 'tsbd')
        args.checkpoints_dir = os.path.join(args.output_dir, 'checkpoints')
        directories = [args.tsbd_dir, args.checkpoints_dir]
        for directory in directories:
            if not (os.path.exists(directory)):
                os.makedirs(directory)
        # Avoid overwriting previous tensorboard and checkpoint data
        # NOTE(review): the new index is the current number of entries. If an
        # earlier numbered directory was deleted, the new name can match an
        # existing directory, which is then reused rather than recreated.
        args.tsbd_dir = os.path.join(
            args.tsbd_dir, 'tsbd_{}'.format(len(os.listdir(args.tsbd_dir))))
        if not os.path.exists(args.tsbd_dir):
            os.makedirs(args.tsbd_dir)
        args.checkpoints_dir = os.path.join(
            args.checkpoints_dir,
            'checkpoints_{}'.format(len(os.listdir(args.checkpoints_dir))))
        if not os.path.exists(args.checkpoints_dir):
            os.makedirs(args.checkpoints_dir)
        '''
		Create an instance of a resnet model
		'''
        # NOTE(review): as above, an architecture other than 'resnet7_2_1'
        # leaves resnet_model unbound.
        output_channels = 4
        if args.model_architecture == 'resnet7_2_1':
            resnet_model = resnet7_2_1(output_channels=output_channels)
        resnet_model = resnet_model.to(device)
        '''
		Train the model
		'''
        print("***** Training the model *****")
        main_utils.train(args, device, resnet_model)
        print("***** Finished training *****")

    if args.do_eval:

        # Builds the model from checkpoint_path and evaluates it, returning
        # (eval_results, embeddings, labels_raw) from main_utils.evaluate.
        # NOTE(review): an architecture other than 'resnet7_2_1' leaves
        # resnet_model unbound here too.
        def run_eval_on_checkpoint(checkpoint_path):
            '''
			Create an instance of a resnet model and load a checkpoint
			'''
            output_channels = 4
            if args.model_architecture == 'resnet7_2_1':
                resnet_model = resnet7_2_1(
                    pretrained=True,
                    pretrained_model_path=checkpoint_path,
                    output_channels=output_channels)
            resnet_model = resnet_model.to(device)
            '''
			Evaluate the model
			'''
            print("***** Evaluating the model *****")
            eval_results, embeddings, labels_raw = main_utils.evaluate(
                args, device, resnet_model)
            print("***** Finished evaluation *****")

            return eval_results, embeddings, labels_raw

        eval_results, _, _ = run_eval_on_checkpoint(
            checkpoint_path=args.checkpoint_path)

        # Only eval_results is written to disk; embeddings and labels are
        # discarded.
        results_path = os.path.join(args.output_dir, 'eval_results.json')
        with open(results_path, 'w') as fp:
            json.dump(eval_results, fp)