def load_wsinet(args):
    """Construct a two-class WsiNet, restore its saved weights, and prepare it for inference.

    The state dict is read from
    ``<args.model_dir>/<args.cnn_model>/<args.fusion_mode>/<args.wsi_cls_name>``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``fea_len`` (passed as ``in_channels``), ``fusion_mode`` (passed as
            ``mode`` and used as a path component), and ``model_dir``,
            ``cnn_model``, ``wsi_cls_name`` (path components of the weights file).

    Returns:
        The WsiNet instance with the loaded weights, moved to CUDA and put in
        eval mode.
    """
    model = WsiNet(class_num=2, in_channels=args.fea_len, mode=args.fusion_mode)

    weights_file = os.path.join(
        args.model_dir,
        args.cnn_model,
        args.fusion_mode,
        args.wsi_cls_name,
    )
    # Returning each storage unchanged keeps tensors where torch.load first
    # deserializes them (CPU), so loading does not depend on the GPU the
    # checkpoint was saved from; the model is moved to CUDA afterwards.
    state = torch.load(weights_file, map_location=lambda stor, _location: stor)
    model.load_state_dict(state)

    model.cuda()
    model.eval()
    return model
def load_wsinet(args):
    """Construct a WsiNet from CLI settings, restore its best weights, and prepare it for inference.

    The state dict is read from
    ``<args.data_dir>/Models/SlideModels/BestModels/<args.model_type>/<args.mode>/<args.wsi_cls_name>``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``class_num`` (number of output classes), ``input_fea_num``
            (passed as ``in_channels``), ``mode`` (passed as ``mode`` and used
            as a path component), and ``data_dir``, ``model_type``,
            ``wsi_cls_name`` (path components of the weights file).

    Returns:
        The WsiNet instance with the loaded weights, moved to CUDA and put in
        eval mode.
    """
    model = WsiNet(
        class_num=args.class_num,
        in_channels=args.input_fea_num,
        mode=args.mode,
    )

    weights_file = os.path.join(
        args.data_dir,
        "Models/SlideModels/BestModels",
        args.model_type,
        args.mode,
        args.wsi_cls_name,
    )
    # Returning each storage unchanged keeps tensors where torch.load first
    # deserializes them (CPU); the model is moved to CUDA afterwards.
    state = torch.load(weights_file, map_location=lambda stor, _location: stor)
    model.load_state_dict(state)

    model.cuda()
    model.eval()
    return model
parser.add_argument('--maxepoch', type=int, default=100, help='number of epochs to train (default: 10)') parser.add_argument('--fusion_mode', type=str, default="pooling") args = parser.parse_args() return args if __name__ == '__main__': args = set_args() os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # prepare model net = WsiNet(class_num=2, in_channels=args.fea_len, mode=args.fusion_mode) net.cuda() # prepare dataset train_data_dir = os.path.join(args.data_dir, "SlideFeas", args.patch_model, "train") train_dataset = wsiDataSet(train_data_dir, pre_load=args.pre_load, testing=False) val_data_dir = os.path.join(args.data_dir, "SlideFeas", args.patch_model, "val") val_dataset = wsiDataSet(val_data_dir, pre_load=args.pre_load, testing=True) train_dataloader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
parser.add_argument("--mode", type=str, default="pooling") parser.add_argument("--class_num", type=int, default=3) parser.add_argument("--pre_load", action='store_true', default=True) parser.add_argument('--verbose', action='store_true') args = parser.parse_args() return args if __name__ == '__main__': args = set_args() os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # Model preparetion net = WsiNet(class_num=args.class_num, in_channels=args.input_fea_num, mode=args.mode) net.cuda() # Dataset preparetion train_data_root = os.path.join(args.data_dir, "Feas", args.model_type, "train") val_data_root = os.path.join(args.data_dir, "Feas", args.model_type, "val") # create dataset train_dataset = ThyroidDataSet(train_data_root, testing=False, pre_load=args.pre_load) val_dataset = ThyroidDataSet(val_data_root, testing=True, testing_num=128, pre_load=args.pre_load)
parser.add_argument('--save_overlay', action='store_true', default=True) args = parser.parse_args() return args if __name__ == "__main__": args = set_args() os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id) # load patch cls model patch_model = torch.load(args.patch_model) patch_model.cuda() patch_model.eval() # load wsi cls model wsinet = WsiNet(class_num=2, in_channels=2048, mode=args.fusion_mode) wsi_weights_path = os.path.join(args.wsi_model_dir, args.fusion_mode, args.wsi_model_name) wsi_weights_dict = torch.load(wsi_weights_path, map_location=lambda storage, loc: storage) wsinet.load_state_dict(wsi_weights_dict) wsinet.cuda() wsinet.eval() # test slide test_slide_list = [ele for ele in os.listdir(args.img_dir) if "jpg" in ele] total_num = len(test_slide_list) if args.save_overlay: overlay_save_dir = os.path.join(os.path.dirname(args.img_dir), "overlay"+str(args.test_patch_num)) pydaily.filesystem.overwrite_dir(overlay_save_dir) correct_num = 0 for ind, test_slide in enumerate(test_slide_list):