def get_model(model_path, model_type='UNet11', problem_type='binary'):
    """Build a segmentation model and load trained weights from a checkpoint.

    :param model_path: path to a checkpoint whose 'model' entry holds the state dict
    :param model_type: one of 'UNet', 'UNet16', 'UNet11', 'LinkNet34', 'AlbuNet'
    :param problem_type: one of 'binary', 'parts', 'instruments'
    :return: the model with weights loaded (left on CPU; caller moves it to a device)
    :raises ValueError: if ``model_type`` or ``problem_type`` is not recognised
    """
    num_classes_by_problem = {'binary': 1, 'parts': 4, 'instruments': 8}
    if problem_type not in num_classes_by_problem:
        raise ValueError('Unknown problem_type: {!r}'.format(problem_type))
    num_classes = num_classes_by_problem[problem_type]

    model_by_type = {
        'UNet16': UNet16,
        'UNet11': UNet11,
        'LinkNet34': LinkNet34,
        'AlbuNet': AlbuNet,
        'UNet': UNet,
    }
    if model_type not in model_by_type:
        raise ValueError('Unknown model_type: {!r}'.format(model_type))
    model = model_by_type[model_type](num_classes=num_classes)

    # map_location='cpu' lets the checkpoint load on machines without the GPU
    # it was saved from; load_state_dict then copies the weights into the
    # model's own (CPU) tensors, so the resulting device is unchanged.
    state = torch.load(str(model_path), map_location='cpu')

    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it so the keys match a plain single-device model.
    state = {key.replace('module.', ''): value
             for key, value in state['model'].items()}
    model.load_state_dict(state)

    return model
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import albumentations as albu
from dataloaders import SIIMDataset
from losses import ComboLoss, soft_dice_loss
from models import AlbuNet
from MaskBinarizers import TripletMaskBinarization
import numpy as np
import cv2

# Fine-tuning setup: resume an AlbuNet segmentation model from a checkpoint
# and prepare the SIIM training data pipeline.
device = torch.device('cuda:0')

model = AlbuNet().to(device)
checkpoint_path = "../checkpoints/seg-9.pth"
# map_location=device: without it, a checkpoint saved on a different GPU
# index would fail to load (or land on the wrong device) on this machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

trainloader = DataLoader(
    SIIMDataset("../dataset"),
    batch_size=4,
    num_workers=4,
    pin_memory=True,  # speeds up host-to-GPU transfers
    shuffle=True,
)

optimizer = optim.Adam(model.parameters(), lr=0.00001)
# Weighted combination of BCE, dice and focal terms (weights 3/1/4).
criterion = ComboLoss({'bce': 3, 'dice': 1, 'focal': 4})

# for idx, batch in enumerate(trainloader):
#     images, masks = batch
#     images = images.to(device).type(torch.float32)
def main():
    """CLI entry point: parse arguments, build the model and data loaders,
    then hand off to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0', help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--type', type=str, default='binary', choices=['binary', 'parts', 'instruments'])
    # 'UNet16' is handled in the selection below but was missing from choices,
    # making that branch unreachable from the command line; it is now listed.
    arg('--model', type=str, default='UNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'AlbuNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Number of output channels depends on the segmentation task.
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'AlbuNet':
        model = AlbuNet(num_classes=num_classes, pretrained=True)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    # Initialise device_ids unconditionally: the original code only defined it
    # inside the CUDA branch, so CPU-only runs crashed with NameError (and an
    # empty --device-ids crashed later with len(None)).
    device_ids = []
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        # An empty list means "all visible devices" for DataParallel (None).
        model = nn.DataParallel(model, device_ids=device_ids or None).cuda()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes, jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None, problem_type='binary', batch_size=1):
        return DataLoader(
            dataset=SaltDataset(file_names, transform=transform, problem_type=problem_type),
            shuffle=shuffle,
            num_workers=args.workers,
            batch_size=batch_size,
            pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names), len(val_file_names)))

    def train_transform(p=1):
        return Compose([
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1)
        ], p=p)

    def val_transform(p=1):
        return Compose([Normalize(p=1)], p=p)

    train_loader = make_loader(train_file_names, shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    # One validation sample per device; falls back to 1 on CPU-only machines
    # (the original `len(device_ids)` crashed there).
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=len(device_ids) if device_ids else 1)

    # Persist the run configuration alongside the checkpoints.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(
        init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
        args=args,
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        validation=valid,
        fold=args.fold,
        num_classes=num_classes)
from MaskBinarizers import TripletMaskBinarization
from classify.dataloader import SIIMDataset
import numpy as np
import cv2
from tqdm import tqdm
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import classification_report
from losses import ComboLoss, soft_dice_loss

# model
# Two-model evaluation setup: an AlbuNet segmentation network and an
# EfficientNet-B2 single-logit classifier, both restored from checkpoints
# and put into eval mode before validation.
device = torch.device('cuda:0')

seg_model = AlbuNet().to(device)
seg_model.load_state_dict(torch.load("checkpoints/seg-9.pth"))
seg_model.eval()

# num_classes=1: a single output logit (binary decision head).
classify_model = EfficientNet.from_pretrained('efficientnet-b2', num_classes=1)
classify_model = classify_model.to(device)
# NOTE(review): absolute machine-specific checkpoint path — presumably left
# over from the author's workstation; confirm/parameterise before reuse.
classify_model.load_state_dict(torch.load("/home/tupm/HDD/projects/FCN_python/pneumothorax/checkpoints/efficientb2-2.pkl"))
classify_model.eval()

# Validation loader (call continues past this chunk of the file).
valloader = DataLoader(
    SIIMDataset("dataset", type="val"),
    batch_size=1,
    num_workers=1,
    pin_memory=True,
    shuffle=True,