Example No. 1
def train(args):
    model = UNet(3, 3).to(device)
    batch_size = args.batch_size
    criterion = nn.BCEWithLogitsLoss()
    # criterion = DiceLoss()
    optimizer = optim.Adam(model.parameters())
    verse_data = DatasetVerse(dir_img,
                              dir_mask,
                              transform=x_transform,
                              target_transform=y_transform)
    dataloader = DataLoader(verse_data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=4)
    train_model(model, criterion, optimizer, dataloader)
Example No. 2
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print(device)  # check whether the GPU can be used
 model = UNet(1, 1).to(device)
 #model = NestedUNet(1, 1).to(device)
 # model.load_state_dict(torch.load("pretrain/weights_80.pth",map_location='cpu'))
 # Set random seeds to guarantee reproducibility
 torch.backends.cudnn.deterministic = True
 random.seed(1)
 torch.manual_seed(1)
 torch.cuda.manual_seed(1)
 np.random.seed(1)
 batch_size = 2
 learning_rate = 0.001
 criterion = torch.nn.BCELoss()
 optimizer = optim.Adam([{
     'params': model.parameters(),
     'initial_lr': learning_rate
 }],
                        lr=learning_rate)
 scheduler = lr_scheduler.StepLR(
     optimizer, step_size=10, gamma=0.8,
     last_epoch=0)  # multiply lr by 0.8 every 10 epochs; pay attention to the last_epoch setting!!
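 # note: with last_epoch >= 0 the scheduler requires every optimizer param group to carry
 # 'initial_lr', which is why it is set explicitly in the Adam param group above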
 x_transform = T.Compose([T.ToTensor(), T.Normalize([0.5], [0.5])])
 y_transform = T.ToTensor()
 cell_dataset = CellDataset1('dataset/dataset1/train/',
                             'dataset/dataset1/train_GT/SEG/',
                             transform=x_transform,
                             target_transform=y_transform)
 #cell_dataset = CellDataset2('dataset/dataset2/train/', 'dataset/dataset2/train_GT/SEG/', transform=x_transform,
 #                           target_transform=y_transform)  # for dataset 2
 train_size = int(0.8 * len(cell_dataset))  # split into training and validation sets, 8:2
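 # A possible continuation (not part of the original snippet): finish the 8:2 split with
 # torch.utils.data.random_split, e.g.
 #   val_size = len(cell_dataset) - train_size
 #   train_set, val_set = torch.utils.data.random_split(cell_dataset, [train_size, val_size])
 #   dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)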
Example No. 3
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if args.saveTest == 'True':
        args.saveTest = True
    elif args.saveTest == 'False':
        args.saveTest = False

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = 'datasets/miccaiSegRefined'
    # json path for class definitions
    json_path = 'datasets/miccaiSegClasses.json'

    image_datasets = {x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x],
                        json_path) for x in ['train', 'test']}

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
                  for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = UNet(num_classes)

    # # Optionally resume from a checkpoint
    # if args.resume:
    #     if os.path.isfile(args.resume):
    #         print("=> loading checkpoint '{}'".format(args.resume))
    #         checkpoint = torch.load(args.resume)
    #         #args.start_epoch = checkpoint['epoch']
    #         pretrained_dict = checkpoint['state_dict']
    #         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
    #         model.state_dict().update(pretrained_dict)
    #         model.load_state_dict(model.state_dict())
    #         print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    #     else:
    #         print("=> no checkpoint found at '{}'".format(args.resume))
    #
    #     # # Freeze the encoder weights
    #     # for param in model.encoder.parameters():
    #     #     param.requires_grad = False
    #
    #     optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)
    # else:
    optimizer = optim.Adam(model.parameters(), lr = args.lr, weight_decay = args.wd)

    # Load the saved model
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on validation set

        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        evaluator.reset()

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
Example No. 4
def train(args):
    print("Traning")

    print("Prepaing data")
    masks = pd.read_csv(os.path.join(args.dataset_dir, args.train_masks))
    unique_img_ids = get_unique_img_ids(masks, args)
    train_df, valid_df = get_balanced_train_valid(masks, unique_img_ids, args)

    if args.stage == 0:
        train_shape = (256, 256)
        batch_size = args.stage0_batch_size
        extra_epoch = args.stage0_epochs
    elif args.stage == 1:
        train_shape = (384, 384)
        batch_size = args.stage1_batch_size
        extra_epoch = args.stage1_epochs
    elif args.stage == 2:
        train_shape = (512, 512)
        batch_size = args.stage2_batch_size
        extra_epoch = args.stage2_epochs
    elif args.stage == 3:
        train_shape = (768, 768)
        batch_size = args.stage3_batch_size
        extra_epoch = args.stage3_epochs

    print("Stage {}".format(args.stage))

    train_transform = DualCompose([
        Resize(train_shape),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Shift(),
        Transpose(),
        # ImageOnly(RandomBrightness()),
        # ImageOnly(RandomContrast()),
    ])
    val_transform = DualCompose([
        Resize(train_shape),
    ])

    train_dataloader = make_dataloader(train_df,
                                       args,
                                       batch_size,
                                       args.shuffle,
                                       transform=train_transform)
    val_dataloader = make_dataloader(valid_df,
                                     args,
                                     batch_size // 2,
                                     args.shuffle,
                                     transform=val_transform)

    # Build model
    model = UNet()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.decay_fr, gamma=0.1)
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Restore model ...
    run_id = 4

    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    if not model_path.exists() and args.stage > 0:
        raise ValueError(
            'model_{run_id}.pt does not exist, initial train first.'.format(
                run_id=run_id))
    if model_path.exists():
        state = torch.load(str(model_path))
        last_epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restore model, epoch {}, step {:,}'.format(last_epoch, step))
    else:
        last_epoch = 1
        step = 0

    log_file = open('train_{run_id}.log'.format(run_id=run_id),
                    'at',
                    encoding='utf8')

    loss_fn = LossBinary(jaccard_weight=args.iou_weight)

    valid_losses = []

    print("Start training ...")
    for _ in range(last_epoch):
        scheduler.step()

    for epoch in range(last_epoch, last_epoch + extra_epoch):
        scheduler.step()
        model.train()
        random.seed()
        tq = tqdm(total=len(train_dataloader) * batch_size)
        tq.set_description('Run Id {}, Epoch {} of {}, lr {}'.format(
            run_id, epoch, last_epoch + extra_epoch,
            args.lr * (0.1**(epoch // args.decay_fr))))
        losses = []
        try:
            mean_loss = 0.
            for i, (inputs, targets) in enumerate(train_dataloader):
                inputs, targets = torch.tensor(inputs), torch.tensor(targets)
                if args.gpu and torch.cuda.is_available():
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-args.log_fr:])
                tq.set_postfix(loss="{:.5f}".format(mean_loss))

                if i and (i % args.log_fr) == 0:
                    write_event(log_file, step, loss=mean_loss)
            write_event(log_file, step, loss=mean_loss)
            tq.close()
            save_model(model, epoch, step, model_path)

            valid_metrics = validation(args, model, loss_fn, val_dataloader)
            write_event(log_file, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)

        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save_model(model, epoch, step, model_path)
            print('Terminated.')
    print('Done.')
Example No. 5
def main():
    args = get_args()

    # set GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu  # default: '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set model
    model = UNet(n_channels=1, n_classes=1).to(device)
    if len(args.gpu) > 1:  # if multi-gpu
        model = torch.nn.DataParallel(model)
    """set img size
        - UNet type architecture require input img size be divisible by 2^N,
        - Where N is the number of the Max Pooling layers (in the Vanila UNet N = 5)
    """

    img_size = args.img_size  #default: 512
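    # A minimal sketch (hypothetical helper, not part of the original script) for padding an
    # input tensor up to the next multiple of 2**N before feeding it to the UNet:
    #   def pad_to_multiple(x, n_pool=5):
    #       factor = 2 ** n_pool
    #       h, w = x.shape[-2:]
    #       return torch.nn.functional.pad(x, (0, (-w) % factor, 0, (-h) % factor))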

    # set transforms for dataset
    import torchvision.transforms as transforms
    from my_transforms import RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, GrayScale, Resize, ToTensor
    train_transforms = transforms.Compose([
        #Data Augmentations
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        #shear
        #rotation
        #scale
        #transformations to fit in Network
        GrayScale(),
        Resize(img_size),
        ToTensor(),
    ])
    eval_transforms = transforms.Compose(
        [GrayScale(), Resize(img_size),
         ToTensor()])

    # set Dataset and DataLoader
    train_dataset = LungSegDataset(transforms=train_transforms)
    val_dataset = LungSegDataset(split='val', transforms=eval_transforms)
    test_dataset = LungSegDataset(split='test', transforms=eval_transforms)

    from torch.utils.data import DataLoader
    dataloader = {
        'train':
        DataLoader(dataset=train_dataset,
                   batch_size=args.batch_size,
                   num_workers=args.n_workers,
                   shuffle=True),
        'val':
        DataLoader(dataset=val_dataset,
                   batch_size=args.batch_size,
                   num_workers=args.n_workers),
        'test':
        DataLoader(dataset=test_dataset,
                   batch_size=args.batch_size,
                   num_workers=args.n_workers)
    }

    # checkpoint dir
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoint')
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    checkpoint_path = args.load_model

    # set optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    # learning rate scheduler
    from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
    # scheduler = StepLR(optimizer, step_size = 3 , gamma = 0.8)
    ## option 2.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3)

    # # set criterion
    # if model.n_classes > 1:
    #     criterion = nn.CrossEntropyLoss()
    # else:
    #     criterion = nn.BCEWithLogitsLoss()
    criterion = nn.BCEWithLogitsLoss()

    train_and_validate(net=model,
                       criterion=criterion,
                       optimizer=optimizer,
                       dataloader=dataloader,
                       device=device,
                       epochs=args.epochs,
                       scheduler=scheduler,
                       load_model=checkpoint_path)
Example No. 6
def main(args):
    dataset_kwargs = {
        'transforms': {},
        'max_length': None,
        'sensor_resolution': None,
        'preload_events': False,
        'num_bins': 16,
        'voxel_method': {
            'method': 'random_k_events',
            'k': 60000,
            't': 0.5,
            'sliding_window_w': 500,
            'sliding_window_t': 0.1
        }
    }

    unet_kwargs = {
        'base_num_channels': 32,  # written as '64' in EVFlowNet tf code
        'num_encoders': 4,
        'num_residual_blocks': 2,  # transition
        'num_output_channels': 2,  # (x, y) displacement
        'skip_type': 'concat',
        'norm': None,
        'use_upsample_conv': True,
        'kernel_size': 3,
        'channel_multiplier': 2,
        'num_bins': 16
    }

    torch.autograd.set_detect_anomaly(True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ev_loader = EventDataLoader(args.h5_file_path,
                                batch_size=1,
                                num_workers=6,
                                shuffle=True,
                                pin_memory=True,
                                dataset_kwargs=dataset_kwargs)

    H, W = ev_loader.H, ev_loader.W

    model = UNet(unet_kwargs)
    model = model.to(device)
    model.train()
    crop = CropParameters(W, H, 4)

    print("=== Let's use", torch.cuda.device_count(), "GPUs!")
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-5,
                                 betas=(0.9, 0.999))
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
    # raise
    # tmp_voxel = crop.pad(torch.randn(1, 9, H, W).to(device))
    # F, P = profile(model, inputs=(tmp_voxel, ))

    for idx in range(10):
        # for i, item in enumerate(tqdm(ev_loader)):
        for i, item in enumerate(ev_loader):

            events = item['events']
            voxel = item['voxel'].to(device)
            voxel = crop.pad(voxel)

            model.zero_grad()
            optimizer.zero_grad()

            flow = model(voxel) * 10

            flow = torch.clamp(flow, min=-40, max=40)
            loss = compute_loss(events, flow)
            loss.backward()

            # cvshow_voxel_grid(voxel.squeeze()[0:2].cpu().numpy())
            # raise
            optimizer.step()

            if i % 10 == 0:
                print(
                    idx,
                    i,
                    '\t',
                    "{0:.2f}".format(loss.data.item()),
                    "{0:.2f}".format(torch.max(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 0]).item()),
                    "{0:.2f}".format(torch.max(flow[0, 1]).item()),
                    "{0:.2f}".format(torch.min(flow[0, 1]).item()),
                )

                xs, ys, ts, ps = events
                print_voxel = voxel[0].sum(axis=0).cpu().numpy()
                print_flow = flow[0].clone().detach().cpu().numpy()
                print_co = warp_events_with_flow_torch(
                    (xs[0][ps[0] == 1], ys[0][ps[0] == 1], ts[0][ps[0] == 1],
                     ps[0][ps[0] == 1]),
                    flow[0].clone().detach(),
                    sensor_size=(H, W))
                print_co = crop.pad(print_co)
                print_co = print_co.cpu().numpy()

                cvshow_all(idx=idx * 10000 + i,
                           voxel=print_voxel,
                           flow=flow[0].clone().detach().cpu().numpy(),
                           frame=None,
                           compensated=print_co)
Example No. 7
import time

num_epoches = 400
batch_size = 12
data_dir = "/userhome/Unet/unet/data/"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_dataloader, val_dataloader = create_dataset(data_dir,
                                                  repeat=1,
                                                  train_batch_size=12,
                                                  augment=True)

model = UNet(1, 2).to(device)
criterion = CrossEntropyWithLogits().to(device)
optimizer = Adam(model.parameters(), lr=0.0001, weight_decay=0.0005, eps=1e-08)

save_step = 200

##test data load time
# print("get-100-epoch")
# load_s = time.time()
# for i in range(2):
#     for sample in train_dataloader:
#         print(sample["image"].shape)
#         print(sample["mask"].shape)
# load_e = time.time()
# print("load data time: ", load_e - load_s)

# TODO: Initialize the params
val_loss = -1
Example No. 8
def train():
    # prepare the dataloader
    device = torch.device(args.devices if torch.cuda.is_available() else "cpu")
    #dataset = Training_Dataset(args.image_dir, (args.image_size, args.image_size), (args.noise,args.noise_param))
    # dataset = HongZhang_Dataset("/data_1/data/Noise2Noise/shenqingbiao/0202", "/data_1/data/Noise2Noise/hongzhang")
    # dataset = HongZhang_Dataset2("/data_1/data/红章图片", (256, 256))
    dataset = HongZhang_Dataset3("/data_1/data/红章图片/6_12", (256, 256))
    dataset_length = len(dataset)
    train_loader = DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=4)

    # choose the model
    if args.model == "unet":
        model = UNet(in_channels=args.image_channels,
                     out_channels=args.image_channels)
    elif args.model == "srresnet":
        model = SRResnet(args.image_channels, args.image_channels)
    elif args.model == "eesp":
        model = EESPNet_Seg(args.image_channels, 2)
    else:
        model = UNet(in_channels=args.image_channels,
                     out_channels=args.image_channels)
    model = model.to(device)

    # choose the loss type
    if args.loss == "l2":
        criterion = nn.MSELoss()
    elif args.loss == "l1":
        criterion = nn.L1Loss()
    elif args.loss == "ssim":
        criterion = SSIM()

    # resume the model if needed
    if args.resume_model:
        resume_model(model, args.resume_model)

    optim = Adam(model.parameters(),
                 lr=args.lr,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=True)
    #scheduler = lr_scheduler.StepLR(optim, step_size=args.scheduler_step, gamma=0.5)
    scheduler = lr_scheduler.MultiStepLR(optim, milestones=[20, 40], gamma=0.1)
    model.train()
    print(model)

    # start to train
    print("Starting Training Loop...")
    since = time.time()
    for epoch in range(args.epochs):
        print('Epoch {}/{}'.format(epoch, args.epochs - 1))
        print('-' * 10)
        running_loss = 0.0
        scheduler.step()
        for batch_idx, (target, source) in enumerate(train_loader):
            source = source.to(device)
            target = target.to(device)
            denoised_source = model(source)
            if args.loss == "ssim":
                loss = 1 - criterion(denoised_source, Variable(target))
            else:
                loss = criterion(denoised_source, Variable(target))
            optim.zero_grad()
            loss.backward()
            optim.step()

            running_loss += loss.item() * source.size(0)
            if batch_idx % args.steps_show == 0:
                print('{}/{} Current loss {}'.format(batch_idx,
                                                     len(train_loader),
                                                     loss.item()))
        epoch_loss = running_loss / dataset_length
        print('{} Loss: {:.4f}'.format('current ' + str(epoch), epoch_loss))
        if (epoch + 1) % args.save_per_epoch == 0:
            save_model(model, epoch + 1)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
Example No. 9
train_loader = DataLoader(MRBrainSDataset(defualt_path, split='train', is_transform=True, \
                img_norm=True, augmentations=Compose([Scale(224)])), \
                batch_size=batch_size,num_workers=num_workers,pin_memory=True,shuffle=True)
val_loader = DataLoader(MRBrainSDataset(defualt_path, split='val', is_transform=True, \
                img_norm=True, augmentations=Compose([Scale(224)])), \
                batch_size=1,num_workers=num_workers,pin_memory=True,shuffle=False)

# Setup Model and summary
model = UNet().to(device)
summary(model, (3, 224, 224), batch_size)  # summarize the network parameters
# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

# parameters that need to be learned
# base_learning_list = list(filter(lambda p: p.requires_grad, model.base_net.parameters()))
# learning_list = model.parameters()

# optimizer and learning rate settings
optimizer = torch.optim.SGD(model.parameters(),
                            lr=learning_rate,
                            momentum=momentum,
                            weight_decay=weight_decay)
# learning rate scheduler
scheduler = lr_scheduler.MultiStepLR(optimizer,
                                     milestones=[
                                         int(0.2 * end_epoch),
                                         int(0.6 * end_epoch),
                                         int(0.9 * end_epoch)
                                     ],
                                     gamma=0.01)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',patience=10, verbose=True)
criterion = cross_entropy2d
# criterion = BCEDiceLoss()
Example No. 10
class BaseModel:
    losses = {'train': [], 'val': []}
    acces = {'train': [], 'val': []}
    scores = {'train': [], 'val': []}
    pred = {'train': [], 'val': []}
    true = {'train': [], 'val': []}

    def __init__(self, args):
        self.args = args
        self.net = None
        print(args.model_name)
        if args.model_name == 'UNet':
            self.net = UNet(args.in_channels, args.num_classes)
            self.net.apply(weights_init)
        elif args.model_name == 'UNetResNet34':
            self.net = UNetResNet34(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNetResNet152':
            self.net = UNetResNet152(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNet11':
            self.net = UNet11(args.num_classes, pretrained=True)
        elif args.model_name == 'UNetVGG16':
            self.net = UNetVGG16(args.num_classes,
                                 pretrained=True,
                                 dropout_2d=0.0,
                                 is_deconv=True)
        elif args.model_name == 'deeplab50_v2':
            if args.ms:
                raise NotImplementedError
            else:
                self.net = deeplab50_v2(args.num_classes,
                                        pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v2':
            if args.ms:
                self.net = ms_deeplab_v2(args.num_classes,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v2(args.num_classes,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3':
            if args.ms:
                self.net = ms_deeplab_v3(args.num_classes,
                                         out_stride=args.out_stride,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v3(args.num_classes,
                                      out_stride=args.out_stride,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3_plus':
            if args.ms:
                self.net = ms_deeplab_v3_plus(args.num_classes,
                                              out_stride=args.out_stride,
                                              pretrained=args.pretrained,
                                              scales=args.ms_scales)
            else:
                self.net = deeplab_v3_plus(args.num_classes,
                                           out_stride=args.out_stride,
                                           pretrained=args.pretrained)

        self.interp = nn.Upsample(size=args.size, mode='bilinear')

        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        if args.loss == 'CELoss':
            self.criterion = nn.CrossEntropyLoss(size_average=True)
        elif args.loss == 'DiceLoss':
            self.criterion = DiceLoss(num_classes=args.num_classes)
        elif args.loss == 'MixLoss':
            self.criterion = MixLoss(args.num_classes,
                                     weights=args.loss_weights)
        elif args.loss == 'LovaszLoss':
            self.criterion = LovaszSoftmax(per_image=args.loss_per_img)
        elif args.loss == 'FocalLoss':
            self.criterion = FocalLoss(args.num_classes, alpha=None, gamma=2)
        else:
            raise RuntimeError('must define loss')

        if 'deeplab' in args.model_name:
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0
        self.best_val = 0.0
        self.count = 0

    def init_model(self):
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            if self.args.ms:
                new_params = self.net.Scale.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    if not (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                             == 'decoder'):
                        new_params[i] = saved_state_dict[i]
                self.net.Scale.load_state_dict(new_params)
            else:
                new_params = self.net.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # print i_parts
                    if (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                         == 'decoder'):
                        # if not i_parts[0] == 'layer5':
                        new_params[i] = saved_state_dict[i]
                self.net.load_state_dict(new_params)

            print('Resuming training, image net loading {}...'.format(
                self.args.resume_model))
            # self.load_weights(self.net, self.args.resume_model)

        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)

        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Sets the learning rate to the initial LR decayed by 10 at every specified step
        # Adapted from PyTorch Imagenet example:
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py
        """
        if epoch < int(self.iterations * 0.5):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-4)
        elif epoch < int(self.iterations * 0.85):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-5)
        else:
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-6)
        self.optimizer.param_groups[0]['lr'] = self.lr_current
        self.optimizer.param_groups[1]['lr'] = self.lr_current * 10

    def save_network(self, net, net_name, epoch, label=''):
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_path = os.path.join(self.args.save_folder, self.args.exp_name,
                                 save_fname)
        torch.save(net.state_dict(), save_path)

    def load_weights(self, net, base_file):
        other, ext = os.path.splitext(base_file)
        if ext == '.pkl' or ext == '.pth':
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    def load_trained_model(self):
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('eval cls, image net loading {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def eval(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []

        for i, image in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out)
            output.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
        return np.array(output)

    def tta(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return np.argmax(results, 1)

    def tta_output(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return results

    def test_val(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        true = []
        t1 = time.time()

        for i, (image, mask) in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
                label_image = Variable(mask.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)
                label_image = Variable(mask, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != label_image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out)
            if self.args.aug == 'heng':
                out = out[:, :, 11:11 + 202, 11:11 + 202]
            predict.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
            # predict.extend([pred[1, :101, :101].data.cpu().numpy() for pred in out])
            # pred.extend(out.data.cpu().numpy())
            true.extend(label_image.data.cpu().numpy())
        # pred_all = np.argmax(np.array(pred), 1)
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = np.array(predict) > t
            true_all = np.array(true).astype(np.int)
            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))

        return predict, true

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
                self.iters += 1

            if self.cuda:
                image = Variable(image.cuda(), volatile=(not train))
                label_image = Variable(mask.cuda(), volatile=(not train))
            else:
                image = Variable(image, volatile=(not train))
                label_image = Variable(mask, volatile=(not train))
            # cls forward
            out = self.net(image)

            if isinstance(out, list):
                out_max = None
                loss = 0.0
                for i, out_scale in enumerate(out):
                    if out_scale.size(2) != label_image.size(2):
                        out_scale = self.interp(out_scale)
                    if i == (len(out) - 1):
                        out_max = out_scale
                    loss += self.criterion(out_scale, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out_max.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            else:
                if out.size(-1) != label_image.size(-1):
                    out = self.interp(out)

                loss = self.criterion(out, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if metrics:
            n = len(self.losses[flag])
            loss = sum(self.losses[flag]) / n
            scalars = [
                loss,
            ]
            names = [
                'loss',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_loss')

            all_acc = sum(self.acces[flag]) / n
            scalars = [
                all_acc,
            ]
            names = [
                'all_acc',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_acc')

            # all_score = sum(self.scores[flag]) / n
            # scalars = [all_score, ]
            # names = ['all_score', ]
            # write_scalars(writer, scalars, names, epoch, tag=flag + '_score')

            pred_all = np.argmax(np.array(self.pred[flag]), 1)
            true_all = np.array(self.true[flag]).astype(np.int)
            mean_iou, iou_t = mIoU(true_all, pred_all)

            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)

            scalars = [
                mean_iou,
                iou_t,
            ]
            names = [
                'mIoU',
                'mIoU_threshold',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_IoU')

            scalars = [
                self.optimizer.param_groups[0]['lr'],
            ]
            names = [
                'learning_rate',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_lr')

            print(
                '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} |  n_iter: {} |  learning_rate: {} | time: {:.2f}'
                .format(flag, loss, all_acc, mean_iou, iou_t, epoch,
                        self.optimizer.param_groups[0]['lr'],
                        time.time() - t2))

            self.losses[flag] = []
            self.pred[flag] = []
            self.true[flag] = []
            self.acces[flag] = []
            self.scores[flag] = []

            if (not train) and (iou_t >= self.best_val):
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(self.net.module.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                else:
                    if self.args.mGPUs:
                        self.save_network(self.net.module,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                print(
                    'val improve from {:.4f} to {:.4f} saving in best val_iteration {}'
                    .format(self.best_val, iou_t, epoch))
                self.best_val = iou_t
                self.count = 0

            if (not train) and (self.best_val - iou_t > 0.003) and (
                    self.count < 10) and (self.lr_policy == 'step'):
                self.count += 1
            if (not train) and (self.count >= 10) and (self.lr_policy
                                                       == 'step'):
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        val_epoch = 0
        for epoch in range(self.iterations):
            if (self.lr_policy == 'cyclic') and (
                    epoch % int(self.iterations / self.cyclic_m) == 0):
                print('-------start cycle {}------------'.format(
                    epoch // int(self.iterations / self.cyclic_m)))
                self.best_val = 0.0
            self.run_epoch(dataloader_train,
                           writer,
                           epoch,
                           train=True,
                           metrics=True)
            self.run_epoch(dataloader_val,
                           writer,
                           val_epoch,
                           train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                else:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                print('saving in val_iteration {}'.format(val_epoch))
Example No. 11
def train(datafile):

    # model = ResUNet(n_classes=2)
    model = UNet(n_channels=3, n_classes=2)

    if torch.cuda.is_available():
        model.cuda()
    # criterion = SoftDiceLoss(batch_dice=True)
    criterion_CE = nn.CrossEntropyLoss()
    criterion_SD = SoftDiceLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=0.9)

    vis = PytorchVisdomLogger(name="GIANA", port=8080)

    giana_transform, giana_train_loader, giana_valid_loader = giana_data_pipeline(
        datafile)

    for epoch in range(EPOCHS):
        iteration = 0
        for iteration, (images, labels) in enumerate(giana_train_loader):
            # print('TRAIN', images.shape, labels.shape)
            images, labels = giana_transform.apply_transform([images, labels])

            labels_onehot = make_one_hot(labels, 2)
            # for images, labels in giana_pool.imap_unordered(giana_transform.apply_transform, giana_iter):
            if torch.cuda.is_available():
                images = images.cuda()
                labels_onehot = labels_onehot.cuda()

            optimizer.zero_grad()
            model.train()
            predictions = model(images)
            predictions_softmax = F.softmax(predictions, dim=1)

            # loss = 0.75 * criterion_CE(predictions, labels.squeeze().cuda().long()) + 0.25 * criterion_SD(predictions_softmax, labels_onehot)
            loss = criterion_CE(predictions, labels.squeeze().long().to(predictions.device))
            # loss = criterion_SD(predictions_softmax, labels_onehot)
            loss.backward()
            optimizer.step()

            # iteration += 1
            if iteration % PRINT_AFTER_ITERATIONS == 0:

                # print('Epoch: {0}, Iteration: {1}, Loss: {2}, Valid dice score: {3}'.format(epoch, iteration, loss, score))
                print('Epoch: {0}, Iteration: {1}, Loss: {2}'.format(
                    epoch, iteration, loss))

                image_args = {'normalize': True, 'range': (0, 1)}
                # viz.show_image_grid(images=images.cpu()[:, 0, ].unsqueeze(1), name='Images_train', image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.cpu()[:, 0, ].unsqueeze(1),
                    name='Predictions_1',
                    image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.cpu()[:, 1, ].unsqueeze(1),
                    name='Predictions_2',
                    image_args=image_args)
                vis.show_image_grid(images=labels.cpu(), name='Ground truth')
                vis.show_value(value=loss.item(),
                               name='Train_Loss',
                               label='Loss',
                               counter=epoch + (iteration / MAX_ITERATIONS))

            if iteration == MAX_ITERATIONS:
                break

        score = model.predict(giana_valid_loader, SCORE_TYPE,
                              MAX_VALIDATION_ITERATIONS, vis)
        vis.show_value(value=np.asarray([score]),
                       name='TestDiceScore',
                       label='Dice',
                       counter=epoch)
        print(
            '\n--------------------------------------------------\nEpoch: {0}, Score: {1}, Loss: {2}\n--------------------------------------------------\n'
            .format(epoch, score, loss))
Example No. 12
def train(train_sources, eval_source):
    path = sys.argv[1]
    dr = DataReader(path, train_sources)
    dr.read()
    print(len(dr.train.x))

    batch_size = 8
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    dataset_s_train = MultiDomainDataset(dr.train.x, dr.train.y, dr.train.vendor, device, DomainAugmentation())
    dataset_s_dev = MultiDomainDataset(dr.dev.x, dr.dev.y, dr.dev.vendor, device)
    dataset_s_test = MultiDomainDataset(dr.test.x, dr.test.y, dr.test.vendor, device)
    loader_s_train = DataLoader(dataset_s_train, batch_size, shuffle=True)

    dr_eval = DataReader(path, [eval_source])
    dr_eval.read()

    dataset_eval_dev = MultiDomainDataset(dr_eval.dev.x, dr_eval.dev.y, dr_eval.dev.vendor, device)
    dataset_eval_test = MultiDomainDataset(dr_eval.test.x, dr_eval.test.y, dr_eval.test.vendor, device)

    dataset_da_train = MultiDomainDataset(dr.train.x+dr_eval.train.x, dr.train.y+dr_eval.train.y, dr.train.vendor+dr_eval.train.vendor, device, DomainAugmentation())
    loader_da_train = DataLoader(dataset_da_train, batch_size, shuffle=True)

    segmentator = UNet()
    discriminator = Discriminator(n_domains=len(train_sources))
    discriminator.to(device)
    segmentator.to(device)

    sigmoid = nn.Sigmoid()
    selector = Selector()

    s_criterion = nn.BCELoss()
    d_criterion = nn.CrossEntropyLoss()
    s_optimizer = optim.AdamW(segmentator.parameters(), lr=0.0001, weight_decay=0.01)
    d_optimizer = optim.AdamW(discriminator.parameters(), lr=0.001, weight_decay=0.01)
    a_optimizer = optim.AdamW(segmentator.encoder.parameters(), lr=0.001, weight_decay=0.01)
    lmbd = 1/150
    s_train_losses = []
    s_dev_losses = []
    d_train_losses = []
    eval_domain_losses = []
    train_dices = []
    dev_dices = []
    eval_dices = []
    epochs = 3
    da_loader_iter = iter(loader_da_train)
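    # training schedule: epochs < 50 and >= 100 run the segmentation step, epochs >= 50 also
    # train the discriminator on detached predictions, and epochs >= 100 add the adversarial
    # encoder step with lambda growing by 1/150 per batch (epochs = 3 here, so only phase 1 runs)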
    for epoch in tqdm(range(epochs)):
        s_train_loss = 0.0
        d_train_loss = 0.0
        for index, sample in enumerate(loader_s_train):
            img = sample['image']
            target_mask = sample['target']

            da_sample = next(da_loader_iter, None)
            if epoch == 100:
                s_optimizer.defaults['lr'] = 0.001
                d_optimizer.defaults['lr'] = 0.0001
            if da_sample is None:
                da_loader_iter = iter(loader_da_train)
                da_sample = next(da_loader_iter, None)
            if epoch < 50 or epoch >= 100:
                # Training step of segmentator
                predicted_activations, inner_repr = segmentator(img)
                predicted_mask = sigmoid(predicted_activations)
                s_loss = s_criterion(predicted_mask, target_mask)
                s_optimizer.zero_grad()
                s_loss.backward()
                s_optimizer.step()
                s_train_loss += s_loss.cpu().detach().numpy()

            if epoch >= 50:
                # Training step of discriminator
                predicted_activations, inner_repr = segmentator(da_sample['image'])
                predicted_activations = predicted_activations.clone().detach()
                inner_repr = inner_repr.clone().detach()
                predicted_vendor = discriminator(predicted_activations, inner_repr)
                d_loss = d_criterion(predicted_vendor, da_sample['vendor'])
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
                d_train_loss += d_loss.cpu().detach().numpy()

            if epoch >= 100:
                # adversarial training step
                predicted_mask, inner_repr = segmentator(da_sample['image'])
                predicted_vendor = discriminator(predicted_mask, inner_repr)
                a_loss = -1 * lmbd * d_criterion(predicted_vendor, da_sample['vendor'])
                a_optimizer.zero_grad()
                a_loss.backward()
                a_optimizer.step()
                lmbd += 1/150
        inference_model = nn.Sequential(segmentator, selector, sigmoid)
        inference_model.to(device)
        inference_model.eval()
        d_train_losses.append(d_train_loss / len(loader_s_train))
        s_train_losses.append(s_train_loss / len(loader_s_train))
        s_dev_losses.append(calculate_loss(dataset_s_dev, inference_model, s_criterion, batch_size))
        eval_domain_losses.append(calculate_loss(dataset_eval_dev, inference_model, s_criterion, batch_size))

        train_dices.append(calculate_dice(inference_model, dataset_s_train))
        dev_dices.append(calculate_dice(inference_model, dataset_s_dev))
        eval_dices.append(calculate_dice(inference_model, dataset_eval_dev))

        segmentator.train()

    date_time = datetime.now().strftime("%m%d%Y_%H%M%S")
    model_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "model", "weights", "segmentator"+str(date_time)+".pth")
    torch.save(segmentator.state_dict(), model_path)

    util.plot_data([(s_train_losses, 'train_losses'), (s_dev_losses, 'dev_losses'), (d_train_losses, 'discriminator_losses'),
               (eval_domain_losses, 'eval_domain_losses')],
              'losses.png')
    util.plot_dice([(train_dices, 'train_dice'), (dev_dices, 'dev_dice'), (eval_dices, 'eval_dice')],
              'dices.png')

    inference_model = nn.Sequential(segmentator, selector, sigmoid)
    inference_model.to(device)
    inference_model.eval()

    print('Dice on annotated: ', calculate_dice(inference_model, dataset_s_test))
    print('Dice on unannotated: ', calculate_dice(inference_model, dataset_eval_test))
Example No. 13
def train(args):
    result_path = 'result/%s/'%args.model
    if not os.path.exists(result_path):
        os.makedirs(result_path)
        os.makedirs('%simage'%result_path)
        os.makedirs('%scheckpoint'%result_path)

    train_set = MyDataset('train', args.label_type, 512)
    train_loader = DataLoader(
        train_set,
        batch_size=args.batchsize,
        shuffle=True, 
        num_workers=args.num_workers)

    # device = 'cuda:0'
    device = 'cuda:6' if torch.cuda.device_count()>1 else 'cuda:0'


    out_channels = 1 if args.label_type=='msk' else 2 
    print(out_channels)
    if args.model=='unet':
        print('using unet as model!')
        model = UNet(out_channels=out_channels)
    elif args.model=='deeplab':
        print('using deeplab as model!')
        model = torch.hub.load('pytorch/vision:v0.9.0',
                 'deeplabv3_resnet101', pretrained=False)
    else:
        print('no model!')

    model = model.to(device)
    # model = nn.DataParallel(model)

    img_show = train_set.__getitem__(0)['x']
    img_show = torch.tensor(img_show).to(device).float()[None, :]

    model.train()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    loss_list = []
    loss_best = 10
    for epo in tqdm(range(1, args.epochs+1), ascii=True):
        epo_loss = []
        for idx, item in enumerate(train_loader):
            x = item['x'].to(device, dtype=torch.float)
            y = item['y'].to(device, dtype=torch.float)

            optimizer.zero_grad()

            if args.model=='unet':
                pred = model(x)
            elif args.model=='deeplab':
                pred = model(x)['out'][:,0][:,None]

            # print(y.shape, pred.shape)
            loss = criterion(pred, y)
            
            # print(loss.item())
            epo_loss.append(loss.data.item())

            loss.backward()
            optimizer.step()

        epo_loss_mean = np.array(epo_loss).mean()
        # print(epo_loss_mean)
        loss_list.append(epo_loss_mean)
        plot_loss(loss_list, '%simage/loss.png'%result_path)

        with torch.no_grad():
            if args.model=='unet':
                pred = model(img_show.clone())
            elif args.model=='deeplab':
                pred = model(img_show.clone())['out'][:,0][:,None]
            # y = model(img_show)
            # print(img_show.shape)
            if args.label_type=='msk':
                x = img_show[0].cpu().detach().numpy().transpose((1,2,0))
                y = pred[0, 0].cpu().detach().numpy()
            elif args.label_type=='flow':
                x = img_show[0].cpu().detach().numpy().transpose((1,2,0))
                y = pred[0].cpu().detach().numpy().transpose((1,2,0))
                
            plt.subplot(121)
            plt.imshow(x*255)       
            plt.subplot(122)
            plt.imshow(y[:,:,0])  
            plt.savefig('%simage/%d.png'%(result_path, epo))     
            plt.clf()
        #loss
        if epo % 3 ==0:
            torch.save(model, '%scheckpoint/%d.pt'%(result_path, epo))
            if epo_loss_mean < loss_best:
                loss_best = epo_loss_mean
                torch.save(model, '%scheckpoint/best.pt'%(result_path))
            np.save('%sloss.npy'%result_path, np.array(loss_list))
Example No. 14
import torch 
import torch.nn as nn 
from model.unet import UNet


if __name__ == '__main__':
    device = torch.device('cuda:0')

    LEARNING_RATE = 1e-3
    LR_DECAY_STEP = 2
    LR_DECAY_FACTOR = 0.5
    WEIGHT_DECAY = 5e-4
    BATCH_SIZE = 4
    MAX_EPOCHS = 30
    MODEL = UNet(1, 2).to(device)
    OPTIMIZER = torch.optim.Adam(MODEL.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(OPTIMIZER, step_size=LR_DECAY_STEP, gamma=LR_DECAY_FACTOR)
    CRITERION = nn.CrossEntropyLoss().to(device)

    tr_path_raw = 'data/tr/raw'
    tr_path_label = 'data/tr/label'
    ts_path_raw = 'data/ts/raw'
    ts_path_label = 'data/ts/label'

    checkpoints_dir = 'checkpoints'
    checkpoint_frequency = 1000
    dataloaders = make_dataloaders(tr_path_raw, tr_path_label, ts_path_raw, ts_path_label, BATCH_SIZE, n_workers=4)
    comment = 'liver_segmentation_U-Net_on_LITS_dataset_'
    verbose_train = 1
    verbose_val = 500
Example No. 15
def main():
    args = get_args()

    # set GPU device
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu  # default: '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set model
    model = UNet(n_channels=1, n_classes=1).to(device)
    if len(args.gpu) > 1:  # if multi-gpu
        model = torch.nn.DataParallel(model)

    img_size = args.img_size  # default: 512
    # set transforms for dataset
    import torchvision.transforms as transforms
    from my_transforms import GrayScale, Resize, ToTensor, histogram_equalize, gamma_correction
    custom_transforms = transforms.Compose([
        GrayScale(),
        Resize(img_size),
        histogram_equalize(),
        gamma_correction(0.5),
        ToTensor(),
    ])

    # set Dataset and DataLoader
    chn_train = chn_dataset(split='train', transforms=custom_transforms)
    chn_val = chn_dataset(split='val', transforms=custom_transforms)
    mcu_train = mcu_dataset(split='train', transforms=custom_transforms)
    mcu_val = mcu_dataset(split='val', transforms=custom_transforms)

    from torch.utils.data import DataLoader
    dataloader = {
        'train': {
            'chn':
            DataLoader(dataset=chn_train,
                       batch_size=args.batch_size,
                       num_workers=args.n_workers,
                       shuffle=True),
            'mcu':
            DataLoader(dataset=mcu_train,
                       batch_size=args.batch_size,
                       num_workers=args.n_workers,
                       shuffle=True)
        },
        'val': {
            'chn':
            DataLoader(dataset=chn_val,
                       batch_size=args.batch_size,
                       num_workers=args.n_workers),
            'mcu':
            DataLoader(dataset=mcu_val,
                       batch_size=args.batch_size,
                       num_workers=args.n_workers)
        }
    }

    # checkpoint dir
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoint')
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
    checkpoint_path = args.load_model

    # set optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

    # learning rate scheduler
    # from torch.optim.lr_scheduler import StepLR
    # scheduler = StepLR(optimizer, step_size = 3 , gamma = 0.8)
    # option 2.
    from torch.optim.lr_scheduler import ReduceLROnPlateau
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=5)

    criterion = nn.BCEWithLogitsLoss()

    train_and_validate(net=model,
                       criterion=criterion,
                       optimizer=optimizer,
                       dataloader=dataloader,
                       device=device,
                       epochs=args.epochs,
                       scheduler=scheduler,
                       load_model=checkpoint_path)