Example #1
def test(args):
    print("Predicting ...")
    test_paths = os.listdir(os.path.join(args.dataset_dir, args.test_img_dir))
    print(len(test_paths), 'test images found')
    test_df = pd.DataFrame({'ImageId': test_paths, 'EncodedPixels': None})

    from skimage.morphology import binary_opening, disk

    test_df = test_df[:5000]  # cap prediction at the first 5000 images
    test_loader = make_dataloader(test_df,
                                  args,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  transform=None,
                                  mode='predict')

    model = UNet()
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()
    run_id = 1
    print("Resuming run #{}...".format(run_id))
    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    state = torch.load(str(model_path), map_location='cpu')
    state = {
        key.replace('module.', ''): value
        for key, value in state['model'].items()
    }
    model.load_state_dict(state)
    model.eval()  # inference mode: freeze dropout and batch-norm statistics

    out_pred_rows = []

    for batch_id, (inputs,
                   image_paths) in enumerate(tqdm(test_loader,
                                                  desc='Predict')):
        if args.gpu and torch.cuda.is_available():
            inputs = inputs.cuda()
        with torch.no_grad():  # no gradients needed at prediction time
            outputs = model(inputs)
        for i, image_name in enumerate(image_paths):
            mask = torch.sigmoid(outputs[i, 0]).data.cpu().numpy()
            cur_seg = binary_opening(mask > 0.5, disk(2))
            cur_rles = multi_rle_encode(cur_seg)
            if len(cur_rles) > 0:
                for c_rle in cur_rles:
                    out_pred_rows += [{
                        'ImageId': image_name,
                        'EncodedPixels': c_rle
                    }]
            else:
                out_pred_rows += [{
                    'ImageId': image_name,
                    'EncodedPixels': None
                }]

    submission_df = pd.DataFrame(out_pred_rows)[['ImageId', 'EncodedPixels']]
    submission_df.to_csv('submission.csv', index=False)
    print("done.")
Example #2
def main(args):
    train_image_dir = args.train_image_dir
    train_label_dir = args.train_label_dir
    val_image_dir = args.val_image_dir
    val_label_dir = args.val_label_dir

    batch_size = 4
    num_workers = 4
    optimizer = optim.SGD
    criterion = MixLoss(nn.BCEWithLogitsLoss(), 0.5, DiceLoss(), 1)

    thresh = 0.1
    recall_partial = partial(recall, thresh=thresh)
    precision_partial = partial(precision, thresh=thresh)
    fbeta_score_partial = partial(fbeta_score, thresh=thresh)

    model = UNet(1, 1, first_out_channels=16)
    model = nn.DataParallel(model.cuda())

    transforms = [
        tsfm.Window(-200, 1000),      # clip CT intensities to the HU window
        tsfm.MinMaxNorm(-200, 1000)   # then scale that window to [0, 1]
    ]
    ds_train = FracNetTrainDataset(train_image_dir, train_label_dir,
        transforms=transforms)
    dl_train = FracNetTrainDataset.get_dataloader(ds_train, batch_size, False,
        num_workers)
    ds_val = FracNetTrainDataset(val_image_dir, val_label_dir,
        transforms=transforms)
    dl_val = FracNetTrainDataset.get_dataloader(ds_val, batch_size, False,
        num_workers)

    databunch = DataBunch(dl_train, dl_val,
        collate_fn=FracNetTrainDataset.collate_fn)

    learn = Learner(
        databunch,
        model,
        opt_func=optimizer,
        loss_func=criterion,
        metrics=[dice, recall_partial, precision_partial, fbeta_score_partial]
    )

    learn.fit_one_cycle(
        200,
        1e-1,
        pct_start=0,
        div_factor=1000,
        callbacks=[
            ShowGraph(learn),
        ]
    )

    if args.save_model:
        save(model.module.state_dict(), "./model_weights.pth")
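MixLoss is called with alternating (loss, weight) arguments, so BCEWithLogitsLoss is weighted 0.5 and DiceLoss 1.0 here. A minimal sketch of such a combinator, assuming that interface, could be:

import torch.nn as nn

class MixLoss(nn.Module):
    """Weighted sum of losses: MixLoss(loss_a, w_a, loss_b, w_b, ...)."""

    def __init__(self, *args):
        super().__init__()
        self.losses = args[::2]
        self.weights = args[1::2]

    def forward(self, pred, target):
        return sum(w * loss(pred, target)
                   for loss, w in zip(self.losses, self.weights))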
Example #3
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    # argparse delivers saveTest as a string; convert it to a bool
    args.saveTest = args.saveTest == 'True'

    # Check if the save directory exists or not
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    cudnn.benchmark = True

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.TenCrop(args.resizedImageSize),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            #transforms.Lambda(lambda normalized: torch.stack([transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])(crop) for crop in normalized]))
            #transforms.RandomResizedCrop(224, interpolation=Image.NEAREST),
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomVerticalFlip(),
            #transforms.ToTensor(),
        ]),
        'test': transforms.Compose([
            transforms.Resize((args.imageSize, args.imageSize), interpolation=Image.NEAREST),
            transforms.ToTensor(),
            #transforms.Normalize([0.295, 0.204, 0.197], [0.221, 0.188, 0.182])
        ]),
    }

    # Data Loading
    data_dir = 'datasets/miccaiSegRefined'
    # json path for class definitions
    json_path = 'datasets/miccaiSegClasses.json'

    image_datasets = {x: miccaiSegDataset(os.path.join(data_dir, x), data_transforms[x],
                        json_path) for x in ['train', 'test']}

    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                  batch_size=args.batchSize,
                                                  shuffle=True,
                                                  num_workers=args.workers)
                  for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}

    # Get the dictionary for the id and RGB value pairs for the dataset
    classes = image_datasets['train'].classes
    key = utils.disentangleKey(classes)
    num_classes = len(key)

    # Initialize the model
    model = UNet(num_classes)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Load the saved model
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion)
    criterion = nn.CrossEntropyLoss()

    # Use a learning rate scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if use_gpu:
        model.cuda()
        criterion.cuda()

    # Initialize an evaluation Object
    evaluator = utils.Evaluate(key, use_gpu)

    for epoch in range(args.start_epoch, args.epochs):
        #adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        print('>>>>>>>>>>>>>>>>>>>>>>>Training<<<<<<<<<<<<<<<<<<<<<<<')
        train(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, key)

        # Evaluate on the validation set
        print('>>>>>>>>>>>>>>>>>>>>>>>Testing<<<<<<<<<<<<<<<<<<<<<<<')
        validate(dataloaders['test'], model, criterion, epoch, key, evaluator)

        # Calculate the metrics
        print('>>>>>>>>>>>>>>>>>> Evaluating the Metrics <<<<<<<<<<<<<<<<<')
        IoU = evaluator.getIoU()
        print('Mean IoU: {}, Class-wise IoU: {}'.format(torch.mean(IoU), IoU))
        PRF1 = evaluator.getPRF1()
        precision, recall, F1 = PRF1[0], PRF1[1], PRF1[2]
        print('Mean Precision: {}, Class-wise Precision: {}'.format(torch.mean(precision), precision))
        print('Mean Recall: {}, Class-wise Recall: {}'.format(torch.mean(recall), recall))
        print('Mean F1: {}, Class-wise F1: {}'.format(torch.mean(F1), F1))
        evaluator.reset()

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, filename=os.path.join(args.save_dir, 'checkpoint_{}.tar'.format(epoch)))
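save_checkpoint is not defined in the snippet; given that the resume path reads checkpoint['epoch'] and checkpoint['state_dict'] back with torch.load, a matching sketch is simply:

import torch

def save_checkpoint(state, filename='checkpoint.tar'):
    # state is a dict with 'epoch', 'state_dict' and 'optimizer' entries,
    # mirroring what the resume branch above expects from torch.load
    torch.save(state, filename)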
Example #4
def train(args):
    print("Traning")

    print("Prepaing data")
    masks = pd.read_csv(os.path.join(args.dataset_dir, args.train_masks))
    unique_img_ids = get_unique_img_ids(masks, args)
    train_df, valid_df = get_balanced_train_valid(masks, unique_img_ids, args)

    if args.stage == 0:
        train_shape = (256, 256)
        batch_size = args.stage0_batch_size
        extra_epoch = args.stage0_epochs
    elif args.stage == 1:
        train_shape = (384, 384)
        batch_size = args.stage1_batch_size
        extra_epoch = args.stage1_epochs
    elif args.stage == 2:
        train_shape = (512, 512)
        batch_size = args.stage2_batch_size
        extra_epoch = args.stage2_epochs
    elif args.stage == 3:
        train_shape = (768, 768)
        batch_size = args.stage3_batch_size
        extra_epoch = args.stage3_epochs
    else:
        raise ValueError('Unknown stage: {}'.format(args.stage))

    print("Stage {}".format(args.stage))

    train_transform = DualCompose([
        Resize(train_shape),
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        Shift(),
        Transpose(),
        # ImageOnly(RandomBrightness()),
        # ImageOnly(RandomContrast()),
    ])
    val_transform = DualCompose([
        Resize(train_shape),
    ])

    train_dataloader = make_dataloader(train_df,
                                       args,
                                       batch_size,
                                       args.shuffle,
                                       transform=train_transform)
    val_dataloader = make_dataloader(valid_df,
                                     args,
                                     batch_size // 2,
                                     args.shuffle,
                                     transform=val_transform)

    # Build model
    model = UNet()
    optimizer = Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.decay_fr, gamma=0.1)
    if args.gpu and torch.cuda.is_available():
        model = model.cuda()

    # Restore model ...
    run_id = 4

    model_path = Path('model_{run_id}.pt'.format(run_id=run_id))
    if not model_path.exists() and args.stage > 0:
        raise ValueError(
            'model_{run_id}.pt does not exist, initial train first.'.format(
                run_id=run_id))
    if model_path.exists():
        state = torch.load(str(model_path))
        last_epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        print('Restore model, epoch {}, step {:,}'.format(last_epoch, step))
    else:
        last_epoch = 1
        step = 0

    log_file = open('train_{run_id}.log'.format(run_id=run_id),
                    'at',
                    encoding='utf8')

    loss_fn = LossBinary(jaccard_weight=args.iou_weight)

    valid_losses = []

    print("Start training ...")
    # Fast-forward the scheduler to the epoch being resumed from
    for _ in range(last_epoch):
        scheduler.step()

    for epoch in range(last_epoch, last_epoch + extra_epoch):
        scheduler.step()
        model.train()
        random.seed()
        tq = tqdm(total=len(train_dataloader) * batch_size)
        tq.set_description('Run Id {}, Epoch {} of {}, lr {}'.format(
            run_id, epoch, last_epoch + extra_epoch,
            args.lr * (0.1**(epoch // args.decay_fr))))
        losses = []
        try:
            mean_loss = 0.
            for i, (inputs, targets) in enumerate(train_dataloader):
                if args.gpu and torch.cuda.is_available():
                    inputs = inputs.cuda()
                    targets = targets.cuda()

                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                optimizer.zero_grad()  # clear gradients from the previous step
                loss.backward()
                optimizer.step()

                step += 1
                tq.update(batch_size)
                losses.append(loss.item())
                mean_loss = np.mean(losses[-args.log_fr:])
                tq.set_postfix(loss="{:.5f}".format(mean_loss))

                if i and (i % args.log_fr) == 0:
                    write_event(log_file, step, loss=mean_loss)
            write_event(log_file, step, loss=mean_loss)
            tq.close()
            save_model(model, epoch, step, model_path)

            valid_metrics = validation(args, model, loss_fn, val_dataloader)
            write_event(log_file, step, **valid_metrics)
            valid_loss = valid_metrics['valid_loss']
            valid_losses.append(valid_loss)

        except KeyboardInterrupt:
            tq.close()
            print('Ctrl+C, saving snapshot')
            save_model(model, epoch, step, model_path)
            print('Terminated.')
            return  # stop instead of falling through to the next epoch
    print('Done.')
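save_model and write_event are assumed helpers. The restore branch reads state['epoch'], state['step'] and state['model'], so save_model presumably persists exactly those keys; write_event most likely appends one JSON object per line to the log file. Hedged sketches:

import json
import torch
from datetime import datetime

def save_model(model, epoch, step, model_path):
    # keys mirror what the restore branch above reads back
    torch.save({'model': model.state_dict(),
                'epoch': epoch,
                'step': step}, str(model_path))

def write_event(log, step, **data):
    # one JSON object per line, flushed so progress survives interruptions
    data['step'] = step
    data['dt'] = datetime.now().isoformat()
    log.write(json.dumps(data, sort_keys=True))
    log.write('\n')
    log.flush()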
Example #5
class BaseModel:
    def __init__(self, args):
        # Keep bookkeeping dicts per instance; as class attributes all
        # BaseModel instances would share the same mutable lists.
        self.losses = {'train': [], 'val': []}
        self.acces = {'train': [], 'val': []}
        self.scores = {'train': [], 'val': []}
        self.pred = {'train': [], 'val': []}
        self.true = {'train': [], 'val': []}
        self.args = args
        self.net = None
        print(args.model_name)
        if args.model_name == 'UNet':
            self.net = UNet(args.in_channels, args.num_classes)
            self.net.apply(weights_init)
        elif args.model_name == 'UNetResNet34':
            self.net = UNetResNet34(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNetResNet152':
            self.net = UNetResNet152(args.num_classes, dropout_2d=0.2)
        elif args.model_name == 'UNet11':
            self.net = UNet11(args.num_classes, pretrained=True)
        elif args.model_name == 'UNetVGG16':
            self.net = UNetVGG16(args.num_classes,
                                 pretrained=True,
                                 dropout_2d=0.0,
                                 is_deconv=True)
        elif args.model_name == 'deeplab50_v2':
            if args.ms:
                raise NotImplementedError
            else:
                self.net = deeplab50_v2(args.num_classes,
                                        pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v2':
            if args.ms:
                self.net = ms_deeplab_v2(args.num_classes,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v2(args.num_classes,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3':
            if args.ms:
                self.net = ms_deeplab_v3(args.num_classes,
                                         out_stride=args.out_stride,
                                         pretrained=args.pretrained,
                                         scales=args.ms_scales)
            else:
                self.net = deeplab_v3(args.num_classes,
                                      out_stride=args.out_stride,
                                      pretrained=args.pretrained)
        elif args.model_name == 'deeplab_v3_plus':
            if args.ms:
                self.net = ms_deeplab_v3_plus(args.num_classes,
                                              out_stride=args.out_stride,
                                              pretrained=args.pretrained,
                                              scales=args.ms_scales)
            else:
                self.net = deeplab_v3_plus(args.num_classes,
                                           out_stride=args.out_stride,
                                           pretrained=args.pretrained)

        self.interp = nn.Upsample(size=args.size, mode='bilinear')

        self.iterations = args.epochs
        self.lr_current = args.lr
        self.cuda = args.cuda
        self.phase = args.phase
        self.lr_policy = args.lr_policy
        self.cyclic_m = args.cyclic_m
        if self.lr_policy == 'cyclic':
            print('using cyclic')
            assert self.iterations % self.cyclic_m == 0
        if args.loss == 'CELoss':
            self.criterion = nn.CrossEntropyLoss(size_average=True)
        elif args.loss == 'DiceLoss':
            self.criterion = DiceLoss(num_classes=args.num_classes)
        elif args.loss == 'MixLoss':
            self.criterion = MixLoss(args.num_classes,
                                     weights=args.loss_weights)
        elif args.loss == 'LovaszLoss':
            self.criterion = LovaszSoftmax(per_image=args.loss_per_img)
        elif args.loss == 'FocalLoss':
            self.criterion = FocalLoss(args.num_classes, alpha=None, gamma=2)
        else:
            raise RuntimeError('must define loss')

        if 'deeplab' in args.model_name:
            self.optimizer = optim.SGD(
                [{
                    'params': get_1x_lr_params_NOscale(self.net),
                    'lr': args.lr
                }, {
                    'params': get_10x_lr_params(self.net),
                    'lr': 10 * args.lr
                }],
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        else:
            self.optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                              self.net.parameters()),
                                       lr=args.lr,
                                       momentum=args.momentum,
                                       weight_decay=args.weight_decay)
        self.iters = 0
        self.best_val = 0.0
        self.count = 0

    def init_model(self):
        if self.args.resume_model:
            saved_state_dict = torch.load(
                self.args.resume_model,
                map_location=lambda storage, loc: storage)
            if self.args.ms:
                new_params = self.net.Scale.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    # keep everything except the classifier head and decoder
                    if (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                         == 'decoder'):
                        new_params[i] = saved_state_dict[i]
                self.net.Scale.load_state_dict(new_params)
            else:
                new_params = self.net.state_dict().copy()
                for i in saved_state_dict:
                    # Scale.layer5.conv2d_list.3.weight
                    i_parts = i.split('.')
                    if (not i_parts[0] == 'layer5') and (not i_parts[0]
                                                         == 'decoder'):
                        new_params[i] = saved_state_dict[i]
                self.net.load_state_dict(new_params)

            print('Resuming training, loading {}...'.format(
                self.args.resume_model))
            # self.load_weights(self.net, self.args.resume_model)

        if self.args.mGPUs:
            self.net = nn.DataParallel(self.net)

        if self.args.cuda:
            self.net = self.net.cuda()
            cudnn.benchmark = True

    def _adjust_learning_rate(self, epoch):
        """Sets the learning rate to the initial LR decayed by 10 at every specified step
        # Adapted from PyTorch Imagenet example:
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py
        """
        if epoch < int(self.iterations * 0.5):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-4)
        elif epoch < int(self.iterations * 0.85):
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-5)
        else:
            self.lr_current = max(self.lr_current * self.args.gamma, 1e-6)
        self.optimizer.param_groups[0]['lr'] = self.lr_current
        if len(self.optimizer.param_groups) > 1:
            # deeplab optimizers carry a second group at 10x the base lr
            self.optimizer.param_groups[1]['lr'] = self.lr_current * 10

    def save_network(self, net, net_name, epoch, label=''):
        save_fname = '%s_%s_%s.pth' % (epoch, net_name, label)
        save_path = os.path.join(self.args.save_folder, self.args.exp_name,
                                 save_fname)
        torch.save(net.state_dict(), save_path)

    def load_weights(self, net, base_file):
        _, ext = os.path.splitext(base_file)
        if ext in ('.pkl', '.pth'):
            print('Loading weights into state dict...')
            net.load_state_dict(
                torch.load(base_file,
                           map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry, only .pth and .pkl files are supported.')

    def load_trained_model(self):
        path = os.path.join(self.args.save_folder, self.args.exp_name,
                            self.args.trained_model)
        print('Loading trained model from {}...'.format(path))
        if self.args.ms:
            self.load_weights(self.net.Scale, path)
        else:
            self.load_weights(self.net, path)

    def eval(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        output = []

        for i, image in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out, dim=1)
            output.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
        return np.array(output)

    def tta(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return np.argmax(results, 1)

    def tta_output(self, dataloaders):
        results = np.zeros(shape=(len(dataloaders[0].dataset),
                                  self.args.num_classes))
        for dataloader in dataloaders:
            output = self.eval(dataloader)
            results += output
        return results

    def test_val(self, dataloader):
        assert self.phase == 'test', "Command arg phase should be 'test'. "
        from tqdm import tqdm
        self.net.eval()
        predict = []
        true = []
        t1 = time.time()

        for i, (image, mask) in tqdm(enumerate(dataloader)):
            if self.cuda:
                image = Variable(image.cuda(), volatile=True)
                label_image = Variable(mask.cuda(), volatile=True)
            else:
                image = Variable(image, volatile=True)
                label_image = Variable(mask, volatile=True)

            # cls forward
            out = self.net(image)
            if isinstance(out, list):
                out_max = out[-1]
                if out_max.size(2) != label_image.size(2):
                    out = self.interp(out_max)
            else:
                if out.size(2) != image.size(2):
                    out = self.interp(out)
            # out [bs * num_tta, c, h, w]
            if self.args.use_tta:
                num_tta = len(tta_config)
                # out = F.softmax(out, dim=1)
                out = detta_score(
                    out.view(num_tta, -1, self.args.num_classes, out.size(2),
                             out.size(3)))  # [num_tta, bs, nclass, H, W]
                out = out.mean(dim=0)  # [bs, nclass, H, W]
            out = F.softmax(out, dim=1)
            if self.args.aug == 'heng':
                out = out[:, :, 11:11 + 202, 11:11 + 202]
            predict.extend([
                resize(pred[1].data.cpu().numpy(), (101, 101)) for pred in out
            ])
            # predict.extend([pred[1, :101, :101].data.cpu().numpy() for pred in out])
            # pred.extend(out.data.cpu().numpy())
            true.extend(label_image.data.cpu().numpy())
        # pred_all = np.argmax(np.array(pred), 1)
        for t in np.arange(0.05, 0.51, 0.01):
            pred_all = np.array(predict) > t
            true_all = np.array(true).astype(int)
            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)
            mean_iou, iou_t = mIoU(true_all, pred_all)
            print('threshold : {:.4f}'.format(t))
            print('mean IoU : {:.4f}, IoU threshold : {:.4f}'.format(
                mean_iou, iou_t))

        return predict, true

    def run_epoch(self, dataloader, writer, epoch, train=True, metrics=True):
        if train:
            self.net.train()
            flag = 'train'
        else:
            self.net.eval()
            flag = 'val'
        t2 = time.time()
        for image, mask in dataloader:
            if train and self.lr_policy != 'step':
                adjust_learning_rate(self.args.lr, self.optimizer, self.iters,
                                     self.iterations * len(dataloader), 0.9,
                                     self.cyclic_m, self.lr_policy)
                self.iters += 1

            if self.cuda:
                image = Variable(image.cuda(), volatile=(not train))
                label_image = Variable(mask.cuda(), volatile=(not train))
            else:
                image = Variable(image, volatile=(not train))
                label_image = Variable(mask, volatile=(not train))
            # cls forward
            out = self.net(image)

            if isinstance(out, list):
                out_max = None
                loss = 0.0
                for i, out_scale in enumerate(out):
                    if out_scale.size(2) != label_image.size(2):
                        out_scale = self.interp(out_scale)
                    if i == (len(out) - 1):
                        out_max = out_scale
                    loss += self.criterion(out_scale, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out_max.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            else:
                if out.size(-1) != label_image.size(-1):
                    out = self.interp(out)

                loss = self.criterion(out, label_image)
                label_image_np = label_image.data.cpu().numpy()
                sig_out_np = out.data.cpu().numpy()
                acc = accuracy(label_image_np, np.argmax(sig_out_np, 1))

                self.pred[flag].extend(sig_out_np)
                self.true[flag].extend(label_image_np)

                self.losses[flag].append(loss.data[0])
                self.acces[flag].append(acc)

            if train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if metrics:
            n = len(self.losses[flag])
            loss = sum(self.losses[flag]) / n
            scalars = [
                loss,
            ]
            names = [
                'loss',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_loss')

            all_acc = sum(self.acces[flag]) / n
            scalars = [
                all_acc,
            ]
            names = [
                'all_acc',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_acc')

            # all_score = sum(self.scores[flag]) / n
            # scalars = [all_score, ]
            # names = ['all_score', ]
            # write_scalars(writer, scalars, names, epoch, tag=flag + '_score')

            pred_all = np.argmax(np.array(self.pred[flag]), 1)
            true_all = np.array(self.true[flag]).astype(int)
            mean_iou, iou_t = mIoU(true_all, pred_all)

            # new_iou = intersection_over_union(true_all, pred_all)
            # new_iou_t = intersection_over_union_thresholds(true_all, pred_all)

            scalars = [
                mean_iou,
                iou_t,
            ]
            names = [
                'mIoU',
                'mIoU_threshold',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_IoU')

            scalars = [
                self.optimizer.param_groups[0]['lr'],
            ]
            names = [
                'learning_rate',
            ]
            write_scalars(writer, scalars, names, epoch, tag=flag + '_lr')

            print(
                '{} loss: {:.4f} | acc: {:.4f} | mIoU: {:.4f} | mIoU_threshold: {:.4f} |  n_iter: {} |  learning_rate: {} | time: {:.2f}'
                .format(flag, loss, all_acc, mean_iou, iou_t, epoch,
                        self.optimizer.param_groups[0]['lr'],
                        time.time() - t2))

            self.losses[flag] = []
            self.pred[flag] = []
            self.true[flag] = []
            self.acces[flag] = []
            self.scores[flag] = []

            if (not train) and (iou_t >= self.best_val):
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(self.net.module.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net.Scale,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                else:
                    if self.args.mGPUs:
                        self.save_network(self.net.module,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                    else:
                        self.save_network(self.net,
                                          self.args.model_name,
                                          epoch=epoch,
                                          label='best')
                print(
                    'val improved from {:.4f} to {:.4f}, saving best at val_iteration {}'
                    .format(self.best_val, iou_t, epoch))
                self.best_val = iou_t
                self.count = 0

            if (not train) and (self.best_val - iou_t > 0.003) and (
                    self.count < 10) and (self.lr_policy == 'step'):
                self.count += 1
            if (not train) and (self.count >= 10) and (self.lr_policy
                                                       == 'step'):
                self._adjust_learning_rate(epoch)
                self.count = 0

    def train_val(self, dataloader_train, dataloader_val, writer):
        val_epoch = 0
        for epoch in range(self.iterations):
            if (self.lr_policy == 'cyclic') and (
                    epoch % int(self.iterations / self.cyclic_m) == 0):
                print('-------start cycle {}------------'.format(
                    epoch // int(self.iterations / self.cyclic_m)))
                self.best_val = 0.0
            self.run_epoch(dataloader_train,
                           writer,
                           epoch,
                           train=True,
                           metrics=True)
            self.run_epoch(dataloader_val,
                           writer,
                           val_epoch,
                           train=False,
                           metrics=True)
            val_epoch += 1
            if (epoch + 1) % self.args.save_freq == 0:
                if self.args.ms:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net.Scale,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                else:
                    if self.args.mGPUs:
                        self.save_network(
                            self.net.module,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                    else:
                        self.save_network(
                            self.net,
                            self.args.model_name,
                            epoch=val_epoch,
                        )
                print('saving in val_iteration {}'.format(val_epoch))
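run_epoch calls a free function adjust_learning_rate(base_lr, optimizer, iter, max_iter, power, cyclic_m, lr_policy) that is not shown. A sketch under the assumption that 'poly' is the DeepLab polynomial decay and 'cyclic' is cosine annealing with cyclic_m warm restarts (for the two-group deeplab optimizer one would additionally scale the second group by 10, omitted here for brevity):

import math

def adjust_learning_rate(base_lr, optimizer, it, max_iter, power,
                         cyclic_m, lr_policy):
    # assumed semantics: 'poly' decays as in DeepLab,
    # 'cyclic' restarts a cosine schedule cyclic_m times
    if lr_policy == 'cyclic':
        cycle_len = max_iter // cyclic_m
        t = (it % cycle_len) / cycle_len
        lr = 0.5 * base_lr * (1 + math.cos(math.pi * t))
    else:
        lr = base_lr * (1 - it / max_iter) ** power
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr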
Example #6
def train(datafile):

    # model = ResUNet(n_classes=2)
    model = UNet(n_channels=3, n_classes=2)

    if torch.cuda.is_available():
        model.cuda()
    # criterion = SoftDiceLoss(batch_dice=True)
    criterion_CE = nn.CrossEntropyLoss()
    criterion_SD = SoftDiceLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LEARNING_RATE,
                                momentum=0.9)

    vis = PytorchVisdomLogger(name="GIANA", port=8080)

    giana_transform, giana_train_loader, giana_valid_loader = giana_data_pipeline(
        datafile)

    for epoch in range(EPOCHS):
        iteration = 0
        for iteration, (images, labels) in enumerate(giana_train_loader):
            # print('TRAIN', images.shape, labels.shape)
            images, labels = giana_transform.apply_transform([images, labels])

            labels_onehot = make_one_hot(labels, 2)
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()
                labels_onehot = labels_onehot.cuda()

            optimizer.zero_grad()
            model.train()
            predictions = model(images)
            predictions_softmax = F.softmax(predictions, dim=1)

            # loss = 0.75 * criterion_CE(predictions, labels.squeeze(1).long()) + 0.25 * criterion_SD(predictions_softmax, labels_onehot)
            loss = criterion_CE(predictions, labels.squeeze(1).long())
            # loss = criterion_SD(predictions_softmax, labels_onehot)
            loss.backward()
            optimizer.step()

            if iteration % PRINT_AFTER_ITERATIONS == 0:

                # print('Epoch: {0}, Iteration: {1}, Loss: {2}, Valid dice score: {3}'.format(epoch, iteration, loss, score))
                print('Epoch: {0}, Iteration: {1}, Loss: {2}'.format(
                    epoch, iteration, loss))

                image_args = {'normalize': True, 'range': (0, 1)}
                # viz.show_image_grid(images=images.cpu()[:, 0, ].unsqueeze(1), name='Images_train', image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.detach().cpu()[:, 0, ].unsqueeze(1),
                    name='Predictions_1',
                    image_args=image_args)
                vis.show_image_grid(
                    images=predictions_softmax.detach().cpu()[:, 1, ].unsqueeze(1),
                    name='Predictions_2',
                    image_args=image_args)
                vis.show_image_grid(images=labels.cpu(), name='Ground truth')
                vis.show_value(value=loss.item(),
                               name='Train_Loss',
                               label='Loss',
                               counter=epoch + (iteration / MAX_ITERATIONS))

            if iteration == MAX_ITERATIONS:
                break

        score = model.predict(giana_valid_loader, SCORE_TYPE,
                              MAX_VALIDATION_ITERATIONS, vis)
        vis.show_value(value=np.asarray([score]),
                       name='TestDiceScore',
                       label='Dice',
                       counter=epoch)
        print(
            '\n--------------------------------------------------\nEpoch: {0}, Score: {1}, Loss: {2}\n--------------------------------------------------\n'
            .format(epoch, score, loss))
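make_one_hot is assumed to turn an index mask of shape [N, 1, H, W] into a one-hot tensor of shape [N, C, H, W] for the Dice loss; a common implementation uses scatter_:

import torch

def make_one_hot(labels, num_classes):
    # labels: [N, 1, H, W] integer class indices
    # returns: [N, num_classes, H, W] float one-hot encoding
    one_hot = torch.zeros(labels.size(0), num_classes,
                          labels.size(2), labels.size(3),
                          device=labels.device)
    return one_hot.scatter_(1, labels.long(), 1)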
Example #7
from model.unet import UNet
from model.dataloader import testloader
from model.evaluate import evaluate
import torch

COLAB = True
BATCH_SIZE = 1
PATH = 'unet_augment.pt'
# PATH = '../drive/My Drive/Colab Notebooks/im2height.pt'

test_loader = testloader(colab=COLAB, batch_size=BATCH_SIZE)

net = UNet()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
if torch.cuda.is_available():
    net.cuda()
criterion = torch.nn.L1Loss()
evaluate(net, test_loader, criterion=criterion)
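evaluate comes from model.evaluate and is not shown; a plausible minimal version that reports the mean L1 error over the test set, written here as a sketch rather than the repository's actual implementation, would be:

import torch

def evaluate(net, loader, criterion):
    # average the criterion over the whole loader, without gradients
    net.eval()
    device = next(net.parameters()).device
    total, n = 0.0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            total += criterion(net(inputs), targets).item() * inputs.size(0)
            n += inputs.size(0)
    print('test loss: {:.4f}'.format(total / n))
    return total / n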