def run_inference(model_architecture='resnet7_2_1',
                  checkpoint_path='./data/pytorch_model_epoch300.bin',
                  image_path='./data/52658141.png'):
    """Load a checkpoint, run it on a single image and report its severity.

    The severity is ``sum(i * pred[i])`` over the per-class outputs of the
    first entry returned by ``main_utils.inference`` (presumably class
    probabilities for a batch of one image -- TODO confirm).

    Args:
        model_architecture: Name of the model variant; only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Path of the checkpoint loaded into the model.
        image_path: Path of the image to score.

    Returns:
        The computed severity (it is also printed).

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    # Validate first: otherwise no model is ever bound and the call fails
    # later with an opaque UnboundLocalError.
    if model_architecture != 'resnet7_2_1':
        raise ValueError(
            f"Unsupported model architecture: {model_architecture!r}")
    device = 'cpu'
    output_channels = 4

    # Create an instance of a resnet model and load a checkpoint
    resnet_model = resnet7_2_1(pretrained=True,
                               pretrained_model_path=checkpoint_path,
                               output_channels=output_channels)
    resnet_model = resnet_model.to(device)

    # Load the input image and run model inference on it
    image = load_image(image_path)
    pred = main_utils.inference(resnet_model, image)[0]

    # Class-index-weighted sum of the per-class outputs
    severity = sum(i * p for i, p in enumerate(pred))
    print(f"{image_path} has severity of {severity}")
    return severity
def run_inference_gradcam(
        image_path, model_architecture='resnet7_2_1',
        checkpoint_path='/opt/mlmodel/data/pytorch_model_epoch300.bin'):
    """Score an image and compute a Grad-CAM map for it.

    Grad-CAM is taken at the ``'layer7.1.conv2'`` layer via
    ``main_utils.inference_gradcam``. The severity is
    ``sum(i * pred[i])`` over the per-class outputs of the first returned
    prediction (presumably class probabilities -- TODO confirm).

    Args:
        image_path: Path of the image to score.
        model_architecture: Name of the model variant; only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Path of the checkpoint loaded into the model.

    Returns:
        ``(severity, gcam_img, input_img)``, where the last two are passed
        through unchanged from ``main_utils.inference_gradcam``.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    # Validate first: otherwise no model is ever bound and the call fails
    # later with an opaque UnboundLocalError.
    if model_architecture != 'resnet7_2_1':
        raise ValueError(
            f"Unsupported model architecture: {model_architecture!r}")
    device = 'cpu'
    output_channels = 4

    # Create an instance of a resnet model and load a checkpoint
    resnet_model = resnet7_2_1(pretrained=True,
                               pretrained_model_path=checkpoint_path,
                               output_channels=output_channels)
    resnet_model = resnet_model.to(device)

    # Wrap the model with Grad-CAM
    model_gcam = GradCAM(model=resnet_model)

    # Load the input image and run inference with Grad-CAM
    image = load_image(image_path)
    pred, gcam_img, input_img = main_utils.inference_gradcam(
        model_gcam, image, 'layer7.1.conv2')

    # Class-index-weighted sum of the per-class outputs
    severity = sum(i * p for i, p in enumerate(pred[0]))
    return severity, gcam_img, input_img
def __call__(self, study_name):
    """Run inference for a study and return it with a cropped 16-bit image.

    Args:
        study_name: Identifier resolved to a PNG path via
            ``Path.png_path``.

    Returns:
        ``(run_inference result, image)``, where ``image`` is the loaded
        PNG after ``CenterCrop(2048)``, scaled by 65535 and cast to
        ``np.uint16``. The scaling assumes the loaded image lies in
        [0, 1] -- TODO confirm against ``load_image``.
    """
    png_path = Path.png_path(study_name)
    image = load_image(png_path)
    image = CenterCrop(2048)(image)
    image = (65535 * image).astype(np.uint16)
    # Pass the path by keyword: the first positional parameter of the
    # visible run_inference is model_architecture, not the image path.
    return run_inference(image_path=png_path), image
def run_inference_gradcam(model_architecture='resnet7_2_1',
                          checkpoint_path='./data/pytorch_model_epoch300.bin',
                          image_path='./data/52658141.png',
                          gcam_path='/mnt/images/52658141_gcam.png'):
    """Score an image, then save a Grad-CAM overlay for it to disk.

    Grad-CAM is taken at the ``'layer7.1.conv2'`` layer via
    ``main_utils.inference_gradcam``. The first returned Grad-CAM map is
    overlaid on the first returned input image and written to
    ``gcam_path`` with ``save_gradcam_overlay``.

    Args:
        model_architecture: Name of the model variant; only
            ``'resnet7_2_1'`` is supported.
        checkpoint_path: Path of the checkpoint loaded into the model.
        image_path: Path of the image to score.
        gcam_path: Destination path of the overlay image.

    Returns:
        The computed severity (it is also printed).

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    # Validate first: otherwise no model is ever bound and the call fails
    # later with an opaque UnboundLocalError.
    if model_architecture != 'resnet7_2_1':
        raise ValueError(
            f"Unsupported model architecture: {model_architecture!r}")
    device = 'cpu'
    output_channels = 4

    # Create an instance of a resnet model and load a checkpoint
    resnet_model = resnet7_2_1(pretrained=True,
                               pretrained_model_path=checkpoint_path,
                               output_channels=output_channels)
    resnet_model = resnet_model.to(device)

    # Wrap the model with Grad-CAM
    model_gcam = GradCAM(model=resnet_model)

    # Load the input image and run inference with Grad-CAM
    image = load_image(image_path)
    pred, gcam_img, input_img = main_utils.inference_gradcam(
        model_gcam, image, 'layer7.1.conv2')

    # Class-index-weighted sum of the per-class outputs
    severity = sum(i * p for i, p in enumerate(pred[0]))
    print(f"{image_path} has severity of {severity}")

    save_gradcam_overlay(gcam_path, gcam_img[0], input_img[0])
    print(f"Grad-CAM overlay saved at {gcam_path}")
    return severity
def _build_resnet_model(model_architecture, device, **model_kwargs):
    """Instantiate the requested ResNet variant and move it to ``device``.

    Extra keyword arguments (e.g. ``pretrained``, ``pretrained_model_path``)
    are forwarded to the model constructor.

    Raises:
        ValueError: If ``model_architecture`` is not supported.
    """
    if model_architecture != 'resnet7_2_1':
        raise ValueError(
            f"Unsupported model architecture: {model_architecture!r}")
    resnet_model = resnet7_2_1(output_channels=4, **model_kwargs)
    return resnet_model.to(device)


def main():
    """Command-line entry point: train, evaluate, or score a single image.

    Reads options from ``parser.get_args()``. The run id is suffixed with
    the commit sha and used to derive ``args.output_dir``, where a log
    file is configured. Then exactly one of these runs:

    - inference (prints the image's severity and returns),
    - training (``main_utils.train``), or
    - evaluation (``main_utils.evaluate``; results are written to
      ``eval_results.json`` in the output directory).
    """
    args = parser.get_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): these checks are asserts, so they vanish under
    # ``python -O``. Also, the do_inference branch below is only reachable
    # when do_train or do_eval is set as well -- confirm this is intended.
    assert args.do_train or args.do_eval, \
        "Either do_train or do_eval has to be True!"
    assert not (args.do_train and args.do_eval), \
        "do_train and do_eval cannot be both True!"

    # To track results from different commits (temporary). Check for a
    # missing run_id *before* concatenating to it.
    if args.run_id is not None:
        commit = str(sha) if args.commit_sha is None else args.commit_sha
        args.run_id = args.run_id + '_' + commit
        args.output_dir = os.path.join(args.output_dir, args.run_id)

    if args.do_train:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.do_eval:
        # output_dir has to exist if doing evaluation
        assert os.path.exists(args.output_dir), \
            "Output directory {} doesn't exist!".format(args.output_dir)

    # Configure a log file
    if args.do_train:
        log_path = os.path.join(args.output_dir, 'training.log')
    if args.do_eval:
        log_path = os.path.join(args.output_dir, 'evaluation.log')
    logging.basicConfig(filename=log_path, level=logging.INFO,
                        filemode='w',
                        format='%(asctime)s - %(name)s %(message)s',
                        datefmt='%m-%d %H:%M')

    # Log important info
    logger = logging.getLogger(__name__)
    logger.info("***** Code info *****")
    logger.info("  Git commit sha: %s", sha)

    # Print important info
    print('Model architecture:', args.model_architecture)
    print('Training folds:', args.training_folds)
    print('Evaluation folds:', args.evaluation_folds)
    print('Device being used:', device)
    print('Output directory:', args.output_dir)
    print('Logging in:\t {}'.format(log_path))
    print('Input image format:', args.image_format)
    print('Loss function: {}'.format(args.loss))

    if args.do_inference:
        # Load a checkpoint and score a single image
        resnet_model = _build_resnet_model(
            args.model_architecture, device,
            pretrained=True, pretrained_model_path=args.checkpoint_path)
        image = load_image(args.image_path)
        pred = main_utils.inference(resnet_model, image)[0]
        # Class-index-weighted sum of the per-class outputs
        severity = sum(i * p for i, p in enumerate(pred))
        print(f"{args.image_path} has severity of {severity}")
        return

    if args.do_train:
        # Create tensorboard and checkpoint directories if they don't exist
        args.tsbd_dir = os.path.join(args.output_dir, 'tsbd')
        args.checkpoints_dir = os.path.join(args.output_dir, 'checkpoints')
        for directory in (args.tsbd_dir, args.checkpoints_dir):
            os.makedirs(directory, exist_ok=True)

        # Avoid overwriting previous tensorboard and checkpoint data by
        # writing into a new numbered sub-directory for this run.
        args.tsbd_dir = os.path.join(
            args.tsbd_dir, 'tsbd_{}'.format(len(os.listdir(args.tsbd_dir))))
        os.makedirs(args.tsbd_dir, exist_ok=True)
        args.checkpoints_dir = os.path.join(
            args.checkpoints_dir,
            'checkpoints_{}'.format(len(os.listdir(args.checkpoints_dir))))
        os.makedirs(args.checkpoints_dir, exist_ok=True)

        # Create a freshly initialised model and train it
        resnet_model = _build_resnet_model(args.model_architecture, device)
        print("***** Training the model *****")
        main_utils.train(args, device, resnet_model)
        print("***** Finished training *****")

    if args.do_eval:
        # Load the checkpoint and evaluate it
        resnet_model = _build_resnet_model(
            args.model_architecture, device,
            pretrained=True, pretrained_model_path=args.checkpoint_path)
        print("***** Evaluating the model *****")
        eval_results, _, _ = main_utils.evaluate(args, device, resnet_model)
        print("***** Finished evaluation *****")

        results_path = os.path.join(args.output_dir, 'eval_results.json')
        with open(results_path, 'w') as fp:
            json.dump(eval_results, fp)