コード例 #1
0
def get_model(model_path, model_type='UNet11', problem_type='binary'):
    """Build a segmentation network and load trained weights into it.

    :param model_path: path (str or path-like) to a checkpoint written by
        ``torch.save``; it must be a dict whose ``'model'`` entry holds the
        network's state dict.
    :param model_type: 'UNet', 'UNet16', 'UNet11', 'LinkNet34', 'AlbuNet'
    :param problem_type: 'binary' (1 output class), 'parts' (4 classes) or
        'instruments' (8 classes)
    :return: the constructed model with the checkpoint weights loaded,
        located on the CPU.
    :raises ValueError: if ``problem_type`` or ``model_type`` is not one of
        the supported values.
    """
    num_classes_by_problem = {'binary': 1, 'parts': 4, 'instruments': 8}
    if problem_type not in num_classes_by_problem:
        raise ValueError('Unknown problem_type {!r}; expected one of {}'.format(
            problem_type, sorted(num_classes_by_problem)))
    num_classes = num_classes_by_problem[problem_type]

    if model_type == 'UNet16':
        model = UNet16(num_classes=num_classes)
    elif model_type == 'UNet11':
        model = UNet11(num_classes=num_classes)
    elif model_type == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes)
    elif model_type == 'AlbuNet':
        model = AlbuNet(num_classes=num_classes)
    elif model_type == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        raise ValueError(
            "Unknown model_type {!r}; expected one of "
            "'UNet', 'UNet16', 'UNet11', 'LinkNet34', 'AlbuNet'".format(
                model_type))

    # Load the tensors onto the CPU so checkpoints saved on a GPU machine can
    # be read without CUDA; load_state_dict copies them into the model's own
    # (CPU) parameters regardless of where they were loaded.
    state = torch.load(str(model_path), map_location='cpu')
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'.  Strip only that leading prefix so that inner parameter names
    # which merely contain 'module.' are not mangled.
    prefix = 'module.'
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)

    # if torch.cuda.is_available():
    #    return model.cuda()

    return model
コード例 #2
0
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import albumentations as albu

from dataloaders import SIIMDataset
from losses import ComboLoss, soft_dice_loss
from models import AlbuNet
from MaskBinarizers import TripletMaskBinarization
import numpy as np
import cv2

# Use the first GPU when one is present; fall back to the CPU otherwise so the
# script does not crash on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = AlbuNet().to(device)
checkpoint_path = "../checkpoints/seg-9.pth"
# map_location places the checkpoint tensors on `device`, so a checkpoint
# saved on a GPU can also be restored on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

trainloader = DataLoader(
    SIIMDataset("../dataset"),
    batch_size=4,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; skip it on the CPU.
    pin_memory=device.type == 'cuda',
    shuffle=True,
)

# Adam with a very small learning rate, applied to the weights restored above.
optimizer = optim.Adam(model.parameters(), lr=0.00001)
# Weighted combination of BCE, Dice and focal terms; how the weights are
# applied is defined by losses.ComboLoss.
criterion = ComboLoss({'bce': 3, 'dice': 1, 'focal': 4})

# for idx, batch in enumerate(trainloader):
#     images, masks = batch
#     images = images.to(device).type(torch.float32)
コード例 #3
0
ファイル: train.py プロジェクト: tinve/kaggle_salt
def main():
    """Parse command-line options, build model/loss/loaders and train.

    The parsed arguments are written to ``<root>/params.json`` and the
    training loop itself is delegated to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model',
        type=str,
        default='UNet',
        # 'UNet16' has a construction branch below, so it must be accepted
        # here too; previously that branch was unreachable.
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'AlbuNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'AlbuNet':
        model = AlbuNet(num_classes=num_classes, pretrained=True)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    # device_ids must exist even without CUDA (or with an empty --device-ids),
    # because it also sizes the validation batches below.
    device_ids = None
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        # device_ids=None makes DataParallel use every visible GPU.
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    # Validation batches hold one image per GPU used by DataParallel, or a
    # single image when running on the CPU.
    if device_ids:
        valid_batch_size = len(device_ids)
    elif torch.cuda.is_available():
        valid_batch_size = torch.cuda.device_count()
    else:
        valid_batch_size = 1

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary',
                    batch_size=1):
        """Wrap ``file_names`` in a SaltDataset and a DataLoader."""
        return DataLoader(dataset=SaltDataset(file_names,
                                              transform=transform,
                                              problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Random vertical/horizontal flips followed by normalization."""
        return Compose(
            [VerticalFlip(p=0.5),
             HorizontalFlip(p=0.5),
             Normalize(p=1)], p=p)

    def val_transform(p=1):
        """Normalization only, no augmentation."""
        return Compose([Normalize(p=1)], p=p)

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size)

    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
コード例 #4
0
ファイル: main.py プロジェクト: tupm2208/pneumothorax
from MaskBinarizers import TripletMaskBinarization
from classify.dataloader import SIIMDataset
import numpy as np
import cv2
from tqdm import tqdm

from efficientnet_pytorch import EfficientNet
from sklearn.metrics import classification_report
from losses import ComboLoss, soft_dice_loss

# model



# Use the first GPU when one is present; fall back to the CPU otherwise so
# inference does not crash on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Segmentation network restored from a local checkpoint, in eval mode.
seg_model = AlbuNet().to(device)
# map_location lets GPU-saved checkpoints load on a CPU-only machine.
seg_model.load_state_dict(torch.load("checkpoints/seg-9.pth", map_location=device))
seg_model.eval()

# Classifier on an EfficientNet-B2 backbone with a single output.
# NOTE(review): from_pretrained fetches ImageNet weights that the
# load_state_dict call below immediately overwrites; building the network
# without pretrained weights would avoid that download. Confirm the
# constructor against the installed efficientnet_pytorch version.
classify_model = EfficientNet.from_pretrained('efficientnet-b2', num_classes=1)
classify_model = classify_model.to(device)
# NOTE(review): this is an absolute, machine-specific path, unlike the
# relative segmentation checkpoint above. Consider making it configurable.
classify_model.load_state_dict(torch.load(
    "/home/tupm/HDD/projects/FCN_python/pneumothorax/checkpoints/efficientb2-2.pkl",
    map_location=device))
classify_model.eval()


valloader = DataLoader(
    SIIMDataset("dataset", type="val"),
    batch_size=1,
    num_workers=1,
    pin_memory=True,
    shuffle=True,