Beispiel #1
0
def main():
    """Segment the test images with a trained UNet and write a submission.

    Parses command-line arguments into the module-level ``args`` (forcing
    ``batch_size`` to 1). It then reads ``configs/core.config`` and
    ``configs/unet.config`` from the parent directory of ``args.dir`` and
    builds the UNet from them. The checkpoint ``args.model`` is loaded from
    that same parent directory when it exists; otherwise the model keeps its
    freshly initialized weights and only a message is printed. The test set
    is built with ``job``/``num_jobs`` (presumably for sharding across
    parallel jobs -- confirm in ``WaldoTestset``). Results go to
    ``args.dir`` via ``segment``, and finally ``make_submission`` is called
    with that directory and ``args.csv``.
    """
    global args
    args = parser.parse_args()
    args.batch_size = 1  # only segment one image for experiment

    # Configs and the checkpoint are looked up in the parent of args.dir.
    model_dir = os.path.dirname(args.dir)
    core_config_path = os.path.join(model_dir, 'configs/core.config')
    unet_config_path = os.path.join(model_dir, 'configs/unet.config')

    core_config = CoreConfig()
    core_config.read(core_config_path)
    print('Using core configuration from {}'.format(core_config_path))

    # loading Unet configuration
    unet_config = UnetConfig()
    unet_config.read(unet_config_path, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_config_path))

    offset_list = core_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = core_config.num_classes
    num_colors = core_config.num_colors
    num_offsets = len(core_config.offsets)
    # model configurations from unet config
    start_filters = unet_config.start_filters
    up_mode = unet_config.up_mode
    merge_mode = unet_config.merge_mode
    depth = unet_config.depth

    model = UNet(num_classes, num_offsets,
                 in_channels=num_colors, depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode,
                 merge_mode=merge_mode)

    model_path = os.path.join(model_dir, args.model)
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # Map every tensor to CPU storage so GPU-saved checkpoints load anywhere.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("loaded.")
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    # Inference only: switch layers such as BatchNorm/Dropout (if the UNet has
    # any) to evaluation behaviour before segmenting.
    model.eval()

    testset = WaldoTestset(args.test_data, args.train_image_size,
                           job=args.job, num_jobs=args.num_jobs)
    print('Total samples in the test set: {0}'.format(len(testset)))

    dataloader = torch.utils.data.DataLoader(
        testset, num_workers=1, batch_size=args.batch_size)

    segment_dir = args.dir
    if not os.path.exists(segment_dir):
        os.makedirs(segment_dir)
    segment(dataloader, segment_dir, model, core_config)
    make_submission(segment_dir, args.csv)
Beispiel #2
0
def main():
    """Train a UNet on the Waldo data and save sample validation outputs.

    Parses command-line arguments into the module-level ``args`` and loads
    the core and UNet configurations. An empty path keeps the defaults; a
    missing non-empty path exits via ``sys.exit``. The function then builds
    train/val loaders from ``<train_dir>/train`` and ``<train_dir>/val`` and
    optionally resumes from ``args.resume``. It trains with SGD for epochs
    ``[args.start_epoch, args.epochs)``, checkpointing after every epoch via
    ``save_checkpoint``. Finally it calls ``sample`` to write example outputs
    to ``<args.dir>/imgs``.

    Reads and updates the module-level ``best_loss`` (lowest validation loss
    so far). It must be initialized before ``main`` runs unless a checkpoint
    is resumed.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        from tensorboard_logger import configure
        print("Using tensorboard")
        configure("%s" % (args.dir))

    # loading core configuration
    c_config = CoreConfig()
    # An empty path means "use the defaults". This must be an elif chain:
    # os.path.exists('') is always False, so the existence check would
    # otherwise abort right after announcing the defaults.
    if args.core_config == '':
        print('No core config file given, using default core configuration')
    elif not os.path.exists(args.core_config):
        sys.exit('Cannot find the config file: {}'.format(args.core_config))
    else:
        c_config.read(args.core_config)
        print('Using core configuration from {}'.format(args.core_config))

    # loading Unet configuration (same empty-means-default rule as above)
    u_config = UnetConfig()
    if args.unet_config == '':
        print('No unet config file given, using default unet configuration')
    elif not os.path.exists(args.unet_config):
        sys.exit('Cannot find the unet configuration file: {}'.format(
            args.unet_config))
    else:
        # need train_image_size for validation
        u_config.read(args.unet_config, args.train_image_size)
        print('Using unet configuration from {}'.format(args.unet_config))

    offset_list = c_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = c_config.num_classes
    num_colors = c_config.num_colors
    num_offsets = len(c_config.offsets)
    # model configurations from unet config
    start_filters = u_config.start_filters
    up_mode = u_config.up_mode
    merge_mode = u_config.merge_mode
    depth = u_config.depth

    train_data = os.path.join(args.train_dir, 'train')
    val_data = os.path.join(args.train_dir, 'val')

    trainset = WaldoDataset(train_data, c_config, args.train_image_size)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=4,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    valset = WaldoDataset(val_data, c_config, args.train_image_size)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=4,
                                            batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print('Total samples: {0} \n'
          'Using {1} samples for training, '
          '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    # create model
    model = UNet(num_classes,
                 num_offsets,
                 in_channels=num_colors,
                 depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode,
                 merge_mode=merge_mode).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints may store the best loss under 'best_loss' or under
            # the legacy key 'best_prec1'; accept either.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            else:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                # Key read back by the resume branch above.
                'best_loss': best_loss,
                # Legacy key, kept for any consumer still reading it.
                'best_prec1': best_loss,
            }, is_best)
    print('Best validation loss: ', best_loss)

    # visualize some example outputs
    outdir = '{}/imgs'.format(args.dir)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    sample(model, valloader, outdir, c_config)
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
from torch.utils.data import DataLoader
from models.Unet import UNet
import os
import matplotlib.pyplot as plt

# hyperparams

# Pick the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=4, bilinear=True)
print(model)
# map_location=device lets weights saved from a GPU run load on a CPU-only
# host, matching the device fallback above.
model.load_state_dict(torch.load('weights/Unet_e3.pth', map_location=device))
model.to(device)

transform = transforms.Compose([
    utils.transforms.RandomMirror(),
    utils.transforms.ToTensor(),
    utils.transforms.Downsize(2)
])

dataset = utils.datasets.SteelDefectDataset(
    csv_file='train.csv',
    root_dir='data/severstal-steel-defect-detection',
    transform=transform)
test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch, data in tqdm(enumerate(test_loader),
                        total=len(test_loader),
def main():
    """Train a UNet that predicts per-class maps and per-offset boundaries.

    Parses command-line arguments into the module-level ``args``. If any of
    the cached train/val/test files is missing, it regenerates all three with
    ``DataProcess``. It optionally resumes from ``args.resume``, then trains
    with SGD for epochs ``[args.start_epoch, args.epochs)``, checkpointing
    each epoch via ``save_checkpoint``. Finally it saves ground-truth and
    predicted maps for one training batch under ``imgs/``.

    Reads and updates the module-level ``best_loss``. It must be initialized
    before ``main`` runs unless a checkpoint is resumed.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # All three cached files must exist; otherwise regenerate them.
    if not (os.path.exists(args.train_data) and os.path.exists(args.val_data)
            and os.path.exists(args.test_data)):
        # 0.9 is presumably the train/val split fraction -- confirm in
        # DataProcess.
        train, val, test = DataProcess(args.train_path, args.test_path, 0.9,
                                       args.img_channels)
        torch.save(train, args.train_data)
        torch.save(val, args.val_data)
        torch.save(test, args.test_data)

    s_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width)),
        tsf.ToTensor(),
    ])

    # The model predicts one boundary channel per offset in this list.
    offset_list = [(1, 1), (0, -2)]

    # Training and validation sets (produced separately by DataProcess).
    trainset = Dataset(args.train_data, s_trans, offset_list, args.num_classes,
                       args.img_height, args.img_width)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=1,
                                              batch_size=args.batch_size)

    valset = Dataset(args.val_data, s_trans, offset_list, args.num_classes,
                     args.img_height, args.img_width)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=1,
                                            batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print(
        'Total samples: {0} \n'
        'Using {1} samples for training, '
        '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    # create model
    model = UNet(args.num_classes,
                 len(offset_list),
                 in_channels=3,
                 depth=args.depth).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints may store the best loss under 'best_loss' or under
            # the legacy key 'best_prec1'; accept either.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            else:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                # Key read back by the resume branch above.
                'best_loss': best_loss,
                # Legacy key, kept for any consumer still reading it.
                'best_prec1': best_loss,
            }, is_best)
    # Function-call print with a single formatted string behaves the same on
    # Python 2 and 3.
    print('Best validation loss: {}'.format(best_loss))

    # Visualize some predicted masks on training data to get a better intuition
    # about the performance. Comment it if not necessary.
    if not os.path.exists('imgs'):
        os.makedirs('imgs')  # save_image does not create missing directories
    model.eval()  # predict with inference-mode layers, not training mode
    datailer = iter(trainloader)
    img, classification, bound = next(datailer)
    torchvision.utils.save_image(img, 'imgs/raw.png')
    for i in range(len(offset_list)):
        torchvision.utils.save_image(bound[:, i:i + 1, :, :],
                                     'imgs/bound_{}.png'.format(i))
    for i in range(args.num_classes):
        torchvision.utils.save_image(classification[:, i:i + 1, :, :],
                                     'imgs/class_{}.png'.format(i))
    img = torch.autograd.Variable(img).cuda()
    predictions = model(img)
    predictions = predictions.data
    # The first num_classes channels are class maps; the rest are boundaries.
    class_pred = predictions[:, :args.num_classes, :, :]
    bound_pred = predictions[:, args.num_classes:, :, :]
    for i in range(len(offset_list)):
        torchvision.utils.save_image(bound_pred[:, i:i + 1, :, :],
                                     'imgs/bound_pred{}.png'.format(i))
    for i in range(args.num_classes):
        torchvision.utils.save_image(class_pred[:, i:i + 1, :, :],
                                     'imgs/class_pred{}.png'.format(i))
Beispiel #5
0
def main():
    """Segment one test image with a trained UNet and display the objects.

    Configuration comes from ``args.dir``: ``configs/core.config``,
    ``configs/unet.config`` and the checkpoint named by ``args.model``. If
    the checkpoint is missing, the freshly initialized model is used and a
    message is printed. ``sample`` writes its outputs under
    ``<args.dir>/segment`` and returns the image with its class and
    adjacency predictions. The first prediction is turned into an object
    mask by ``ObjectSegmenter`` and rendered with ``visualize_mask``.
    """
    global args
    args = parser.parse_args()
    # only segment one image for experiment
    args.batch_size = 1

    core_cfg_file = os.path.join(args.dir, 'configs/core.config')
    unet_cfg_file = os.path.join(args.dir, 'configs/unet.config')

    core_cfg = CoreConfig()
    core_cfg.read(core_cfg_file)
    print('Using core configuration from {}'.format(core_cfg_file))

    # loading Unet configuration
    unet_cfg = UnetConfig()
    unet_cfg.read(unet_cfg_file, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_cfg_file))

    offsets = core_cfg.offsets
    print("offsets are: {}".format(offsets))

    n_classes = core_cfg.num_classes
    # Architecture hyper-parameters come from both config files.
    model = UNet(n_classes,
                 len(core_cfg.offsets),
                 in_channels=core_cfg.num_colors,
                 depth=unet_cfg.depth,
                 start_filts=unet_cfg.start_filters,
                 up_mode=unet_cfg.up_mode,
                 merge_mode=unet_cfg.merge_mode)

    ckpt_file = os.path.join(args.dir, args.model)
    if not os.path.isfile(ckpt_file):
        print("=> no checkpoint found at '{}'".format(ckpt_file))
    else:
        print("=> loading checkpoint '{}'".format(ckpt_file))
        # Keep every tensor on CPU storage regardless of where it was saved.
        state = torch.load(ckpt_file,
                           map_location=lambda storage, loc: storage)
        model.load_state_dict(state['state_dict'])
        print("loaded.")

    # convert the model into evaluation mode
    model.eval()

    testset = WaldoDataset(args.test_data, core_cfg, args.train_image_size)
    print('Total samples in the test set: {0}'.format(len(testset)))
    loader = torch.utils.data.DataLoader(testset,
                                         num_workers=1,
                                         batch_size=args.batch_size)

    out_dir = '{}/segment'.format(args.dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    img, class_pred, adj_pred = sample(model, loader, out_dir, core_cfg)

    segmenter = ObjectSegmenter(class_pred[0].detach().numpy(),
                                adj_pred[0].detach().numpy(),
                                n_classes,
                                offsets)
    mask_pred, object_class = segmenter.run_segmentation()
    visualize_mask({
        # from (color, height, width) to (height, width, color)
        'img': np.moveaxis(img[0].numpy(), 0, -1),
        'mask': mask_pred.astype(int),
        'object_class': object_class,
    }, core_cfg)
Beispiel #6
0
def main():
    """Train a single-output UNet mask model, visualize it, then test it.

    Parses command-line arguments into the module-level ``args``. If any of
    the cached train/val/test files is missing, it regenerates all three with
    ``DataProcess``. It optionally resumes from ``args.resume``, trains with
    Adam for epochs ``[args.start_epoch, args.epochs)``, and checkpoints each
    epoch via ``save_checkpoint``. It saves one training batch with its
    thresholded prediction as PNGs. Finally it reloads
    ``exp/<name>/model_best.pth.tar`` and runs ``Test`` on the test set.

    Reads and updates the module-level ``best_loss``. It must be initialized
    before ``main`` runs unless a checkpoint is resumed.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # All three cached files must exist; otherwise regenerate them.
    if not (os.path.exists(args.train_data) and os.path.exists(args.val_data)
            and os.path.exists(args.test_data)):
        train, val, test = DataProcess(args.train_path, args.test_path, 0.9,
                                       args.img_channels)
        t.save(train, args.train_data)
        t.save(val, args.val_data)
        t.save(test, args.test_data)

    s_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width)),
        tsf.ToTensor(),
    ])

    # Nearest-neighbour resizing; presumably applied to the target masks so
    # label values are not interpolated -- confirm in TrainDataset.
    t_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width),
                   interpolation=PIL.Image.NEAREST),
        tsf.ToTensor(),
    ])

    # Training and validation sets (produced separately by DataProcess).
    trainset = TrainDataset(args.train_data, s_trans, t_trans)
    trainloader = t.utils.data.DataLoader(trainset,
                                          num_workers=1,
                                          batch_size=args.batch_size,
                                          shuffle=True)

    valset = TrainDataset(args.val_data, s_trans, t_trans)
    valloader = t.utils.data.DataLoader(valset,
                                        num_workers=1,
                                        batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print(
        'Total samples: {0} \n'
        'Using {1} samples for training, '
        '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    testset = TestDataset(args.test_data, s_trans)
    testloader = t.utils.data.DataLoader(testset, num_workers=1, batch_size=1)

    # create model
    model = UNet(1, in_channels=3, depth=args.depth).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = t.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Restore into the global best_loss (used by the loop below).
            # Checkpoints may store it under 'best_loss' or the legacy key
            # 'best_prec1'; accept either.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            else:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = t.optim.Adam(model.parameters(), lr=1e-3)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                # Key read back by the resume branch above.
                'best_loss': best_loss,
                # Legacy key, kept for any consumer still reading it.
                'best_prec1': best_loss,
            }, is_best)
    # Function-call print with a single formatted string behaves the same on
    # Python 2 and 3.
    print('Best validation loss: {}'.format(best_loss))

    # Visualize some predicted masks on training data to get a better intuition
    # about the performance. Comment it if not necessary.
    model.eval()  # predict with inference-mode layers, not training mode
    datailer = iter(trainloader)
    img, mask = next(datailer)
    torchvision.utils.save_image(img, 'raw.png')
    torchvision.utils.save_image(mask, 'mask.png')
    img = t.autograd.Variable(img).cuda()
    img_pred = model(img)
    img_pred = img_pred.data
    torchvision.utils.save_image(img_pred > 0.5, 'predicted.png')

    # Load the best model and evaluate on test set
    checkpoint = t.load('exp/%s/' % (args.name) + 'model_best.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])
    Test(testloader, model)
Beispiel #7
0
    pred_img = torch.argmax(pred,dim=1)
    img = sitk.GetImageFromArray(np.squeeze(np.array(pred_img.numpy(),dtype='uint8'),axis=0))
    sitk.WriteImage(img, os.path.join(save_path, filename))

    # save_tool.save(filename)
    print('\nAverage loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'.format(
        val_loss, val_dice0, val_dice1, val_dice2))
    return val_loss, val_dice0, val_dice1, val_dice2

if __name__ == '__main__':
    # Evaluate the best saved 3D UNet on every test volume and write one
    # prediction per volume into ./output/<save>/result.
    args = config.args
    device = torch.device('cpu' if args.cpu else 'cuda')
    # model info
    model = UNet(1, [32, 48, 64, 96, 128], 3, net_mode='3d',
                 conv_block=RecombinationBlock).to(device)
    # map_location=device lets a checkpoint saved on GPU load under --cpu.
    ckpt = torch.load('./output/{}/best_model.pth'.format(args.save),
                      map_location=device)
    model.load_state_dict(ckpt['net'])
    # Inference only: switch layers such as BatchNorm/Dropout (if any) to
    # evaluation behaviour.
    model.eval()

    # data info
    test_data_path = r'F:\datasets\LiTS\test'
    result_save_path = r'./output/{}/result'.format(args.save)
    # makedirs also creates ./output/<save> if it does not exist yet.
    if not os.path.exists(result_save_path):
        os.makedirs(result_save_path)
    # Patch sizes and strides per axis (s = slices, h = height, w = width);
    # presumably used to cut each volume into overlapping patches -- confirm
    # in test_Datasets.
    cut_param = {'patch_s': 32,
                 'patch_h': 128,
                 'patch_w': 128,
                 'stride_s': 24,
                 'stride_h': 96,
                 'stride_w': 96}
    datasets = test_Datasets(test_data_path, cut_param,
                             resize_scale=args.resize_scale)
    for dataset, file_idx in datasets:
        test(model, dataset, result_save_path, 'result-' + file_idx)